diff --git a/examples/speed-benchmark/README.md b/examples/speed-benchmark/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..84987b7b2a85988622394fcf0179b84034008c9d
--- /dev/null
+++ b/examples/speed-benchmark/README.md
@@ -0,0 +1,106 @@
+## Speed Benchmark
+
+This document introduces the speed benchmark testing process for the Qwen2.5 series models (original and quantized models). For detailed reports, please refer to the [Qwen2.5 Speed Benchmark](https://qwen.readthedocs.io/en/latest/benchmark/speed_benchmark.html).
+
+### 1. Model Collections
+
+For models hosted on HuggingFace, please refer to [Qwen2.5 Collections-HuggingFace](https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e).
+
+For models hosted on ModelScope, please refer to [Qwen2.5 Collections-ModelScope](https://modelscope.cn/collections/Qwen25-dbc4d30adb768).
+
+### 2. Environment Installation
+
+
+For inference using HuggingFace transformers:
+
+```shell
+conda create -n qwen_perf_transformers python=3.10
+conda activate qwen_perf_transformers
+
+pip install torch==2.3.1
+pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@v0.7.1
+pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.5.8
+pip install -r requirements-perf-transformers.txt
+```
+
+> [!Important]
+> - For `flash-attention`, you can use the prebulit wheels from [GitHub Releases](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.8) or installing from source, which requires a compatible CUDA compiler.
+>   - You don't actually need to install `flash-attention`. It has been intergrated into `torch` as a backend of `sdpa`.
+> - For `auto_gptq` to use efficent kernels, you need to install from source, because the prebuilt wheels require incompatible `torch` versions. Installing from source also requires a compatible CUDA compiler.
+> - For `autoawq` to use efficent kenerls, you need `autoawq-kernels`, which should be automatically installed. If not, run `pip install autoawq-kernels`.
+
+For inference using vLLM:
+
+```shell
+conda create -n qwen_perf_vllm python=3.10
+conda activate qwen_perf_vllm
+
+pip install -r requirements-perf-vllm.txt
+```
+
+
+### 3. Run Experiments
+
+#### 3.1 Inference using HuggingFace Transformers
+
+- Use HuggingFace hub
+
+```shell
+python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --outputs_dir outputs/transformers
+```
+
+- Use ModelScope hub
+
+```shell
+python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --use_modelscope --outputs_dir outputs/transformers
+```
+
+Parameters:
+
+    `--model_id_or_path`: The model path or id on ModelScope or HuggingFace hub
+    `--context_length`: Input length in tokens; optional values are 1, 6144, 14336, 30720, 63488, 129024; Refer to the `Qwen2.5 SpeedBenchmark`.  
+    `--generate_length`: Output length in tokens; default is 2048.
+    `--gpus`: Equivalent to the environment variable CUDA_VISIBLE_DEVICES.  e.g. `0,1,2,3`, `4,5`  
+    `--use_modelscope`: Use ModelScope when set this flag. Otherwise, use HuggingFace.  
+    `--outputs_dir`: Output directory; default is outputs/transformers.  
+
+
+#### 3.2 Inference using vLLM
+
+- Use HuggingFace hub
+
+```shell
+python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
+```
+
+
+- Use ModelScope hub
+
+```shell
+python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --use_modelscope --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
+```
+
+
+Parameters:
+
+    `--model_id_or_path`: The model id on ModelScope or HuggingFace hub.
+    `--context_length`: Input length in tokens; optional values are 1, 6144, 14336, 30720, 63488, 129024; Refer to the `Qwen2.5 SpeedBenchmark`.  
+    `--generate_length`: Output length in tokens; default is 2048.
+    `--max_model_len`: Maximum model length in tokens; default is 32768. Optional values are 4096, 8192, 32768, 65536, 131072.
+    `--gpus`: Equivalent to the environment variable CUDA_VISIBLE_DEVICES.  e.g. `0,1,2,3`, `4,5`  
+    `--use_modelscope`: Use ModelScope when set this flag. Otherwise, use HuggingFace.  
+    `--gpu_memory_utilization`: GPU memory utilization; range is (0, 1]; default is 0.9.  
+    `--outputs_dir`: Output directory; default is outputs/vllm.  
+    `--enforce_eager`: Whether to enforce eager mode; default is False.  
+
+
+
+#### 3.3 Tips
+
+- Run multiple experiments and compute the average result; a typical number is 3 times.
+- Make sure the GPU is idle before running experiments.
+
+
+### 4. Results
+
+Please check the `outputs` directory, which includes two directories by default: `transformers` and `vllm`, containing the experiments results for HuggingFace transformers and vLLM, respectively.
diff --git a/examples/speed-benchmark/README_zh.md b/examples/speed-benchmark/README_zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..4e966876c1dee7ff2c2d3db1967241bfdaf3e05e
--- /dev/null
+++ b/examples/speed-benchmark/README_zh.md
@@ -0,0 +1,110 @@
+## 效率评估
+
+本文介绍Qwen2.5系列模型(原始模型和量化模型)的效率测试流程,详细报告可参考 [Qwen2.5模型效率评估报告](https://qwen.readthedocs.io/en/latest/benchmark/speed_benchmark.html)。
+
+### 1. 模型资源
+
+对于托管在HuggingFace上的模型,可参考 [Qwen2.5模型-HuggingFace](https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e)。
+
+对于托管在ModelScope上的模型,可参考 [Qwen2.5模型-ModelScope](https://modelscope.cn/collections/Qwen25-dbc4d30adb768)。
+
+
+### 2. 环境安装
+
+
+使用HuggingFace transformers推理,安装环境如下:
+
+```shell
+conda create -n qwen_perf_transformers python=3.10
+conda activate qwen_perf_transformers
+
+pip install torch==2.3.1
+pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@v0.7.1
+pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.5.8
+pip install -r requirements-perf-transformers.txt
+```
+
+> [!Important]
+> - 对于 `flash-attention`,您可以从 [GitHub 发布页面](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.8) 使用预编译的 wheel 包进行安装,或者从源代码安装,后者需要一个兼容的 CUDA 编译器。
+>   - 实际上,您并不需要单独安装 `flash-attention`。它已经被集成到了 `torch` 中作为 `sdpa` 的后端实现。
+> - 若要使 `auto_gptq` 使用高效的内核,您需要从源代码安装,因为预编译的 wheel 包依赖于与之不兼容的 `torch` 版本。从源代码安装同样需要一个兼容的 CUDA 编译器。
+> - 若要使 `autoawq` 使用高效的内核,您需要安装 `autoawq-kernels`,该组件应当会自动安装。如果未自动安装,请运行 `pip install autoawq-kernels` 进行手动安装。
+
+
+使用vLLM推理,安装环境如下:
+
+```shell
+conda create -n qwen_perf_vllm python=3.10
+conda activate qwen_perf_vllm
+
+pip install -r requirements-perf-vllm.txt
+```
+
+
+### 3. 执行测试
+
+#### 3.1 使用HuggingFace transformers推理
+
+- 使用HuggingFace hub
+
+```shell
+python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --outputs_dir outputs/transformers
+
+# 指定HF_ENDPOINT
+HF_ENDPOINT=https://hf-mirror.com python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --outputs_dir outputs/transformers
+```
+
+- 使用ModelScope hub
+
+```shell
+python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --use_modelscope --outputs_dir outputs/transformers
+```
+
+参数说明:
+
+    `--model_id_or_path`: 模型ID或本地路径, 可选值参考`模型资源`章节  
+    `--context_length`: 输入长度,单位为token数;可选值为1, 6144, 14336, 30720, 63488, 129024;具体可参考`Qwen2.5模型效率评估报告`  
+    `--generate_length`: 生成token数量;默认为2048
+    `--gpus`: 等价于环境变量CUDA_VISIBLE_DEVICES,例如`0,1,2,3`,`4,5`  
+    `--use_modelscope`: 如果设置该值,则使用ModelScope加载模型,否则使用HuggingFace  
+    `--outputs_dir`: 输出目录, 默认为`outputs/transformers`  
+
+
+#### 3.2 使用vLLM推理
+
+- 使用HuggingFace hub
+
+```shell
+python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
+
+# 指定HF_ENDPOINT
+HF_ENDPOINT=https://hf-mirror.com python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
+```
+
+- 使用ModelScope hub
+
+```shell
+python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --use_modelscope --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
+```
+
+参数说明:
+
+    `--model_id_or_path`: 模型ID或本地路径, 可选值参考`模型资源`章节  
+    `--context_length`: 输入长度,单位为token数;可选值为1, 6144, 14336, 30720, 63488, 129024;具体可参考`Qwen2.5模型效率评估报告`  
+    `--generate_length`: 生成token数量;默认为2048
+    `--max_model_len`: 模型最大长度,单位为token数;默认为32768  
+    `--gpus`: 等价于环境变量CUDA_VISIBLE_DEVICES,例如`0,1,2,3`,`4,5`   
+    `--use_modelscope`: 如果设置该值,则使用ModelScope加载模型,否则使用HuggingFace  
+    `--gpu_memory_utilization`: GPU内存利用率,取值范围为(0, 1];默认为0.9  
+    `--outputs_dir`: 输出目录, 默认为`outputs/vllm`  
+    `--enforce_eager`: 是否强制使用eager模式;默认为False  
+
+
+#### 3.3 注意事项
+
+1. 多次测试,取平均值,典型值为3次
+2. 测试前请确保GPU处于空闲状态,避免其他任务影响测试结果
+
+### 4. 测试结果
+
+测试结果详见`outputs`目录下的文件,默认包括`transformers`和`vllm`两个目录,分别存放HuggingFace transformers和vLLM的测试结果。
diff --git a/examples/speed-benchmark/requirements-perf-transformers.txt b/examples/speed-benchmark/requirements-perf-transformers.txt
new file mode 100644
index 0000000000000000000000000000000000000000..dc9b7fe75366067e9a10025faf962ab84e8ea408
--- /dev/null
+++ b/examples/speed-benchmark/requirements-perf-transformers.txt
@@ -0,0 +1,10 @@
+# Note: install following requirements saparately
+# pip install torch==2.3.1
+# pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@v0.7.1
+# pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.5.8
+
+transformers==4.46.0
+autoawq==0.2.6
+modelscope[framework]
+accelerate
+optimum>=1.20.0
diff --git a/examples/speed-benchmark/requirements-perf-vllm.txt b/examples/speed-benchmark/requirements-perf-vllm.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cd14ea9e14796819e38e77e942bf03b2a9563b22
--- /dev/null
+++ b/examples/speed-benchmark/requirements-perf-vllm.txt
@@ -0,0 +1,4 @@
+vllm==0.6.3.post1
+torch==2.4.0
+modelscope[framework]
+accelerate
diff --git a/examples/speed-benchmark/speed_benchmark_transformers.py b/examples/speed-benchmark/speed_benchmark_transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8187d2cd9fbd8664aaff8936c960c0a2eae49299
--- /dev/null
+++ b/examples/speed-benchmark/speed_benchmark_transformers.py
@@ -0,0 +1,170 @@
+# Copyright (c) Alibaba Cloud.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Qwen2.5 Speed Benchmark for transformers(pt) inference.
+"""
+
+import os
+import time
+import json
+import csv
+
+import torch
+from transformers.trainer_utils import set_seed
+
+
+class SpeedBenchmarkTransformers:
+
+    SEED = 1024
+    BATCH_SIZE = 1
+    USE_FLASH_ATTN = True
+    COMMENT = 'default'
+    DEVICE_MAP = 'auto'
+    TORCH_DTYPE = 'auto'
+    OVERWRITE_RESULT = False
+    DUMMY_INPUT = '我'
+
+    def __init__(self, model_id_or_path, use_modelscope: bool = True, outputs_dir: str = 'outputs/transformers'):
+        """
+        Speed benchmark for transformers(pt) inference.
+
+        Args:
+            model_id_or_path: The model id on ModelScope or HuggingFace hub, or local model path.
+            use_modelscope: Use ModelScope, otherwise HuggingFace.
+            outputs_dir: The output directory. Default is 'outputs/transformers'.
+        """
+
+        set_seed(self.SEED)
+        self.model_id_or_path = model_id_or_path
+        self.outputs_dir = outputs_dir
+
+        if use_modelscope:
+            from modelscope import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+        else:
+            from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True)
+        attn_impl = 'flash_attention_2' if self.USE_FLASH_ATTN else 'eager'
+        self.model = AutoModelForCausalLM.from_pretrained(model_id_or_path,
+                                                          torch_dtype=self.TORCH_DTYPE,
+                                                          device_map=self.DEVICE_MAP,
+                                                          attn_implementation=attn_impl
+                                                          ).eval()
+
+        self.generation_config = GenerationConfig.from_pretrained(model_id_or_path, trust_remote_code=True)
+
+    def run(self, context_length: int, generate_length: int) -> str:
+
+        # Specify hyperparameters for generation
+        self.generation_config.min_length = generate_length + context_length
+        self.generation_config.max_new_tokens = generate_length
+        print(f'Generation config: {self.generation_config}')
+
+        # Prepare inputs
+        batch_size = self.BATCH_SIZE
+        context_str = self.DUMMY_INPUT * context_length
+        inputs = self.tokenizer([context_str for _ in range(batch_size)], return_tensors='pt')
+        assert inputs['input_ids'].shape[1] == context_length
+        assert inputs['input_ids'].shape[0] == batch_size
+        inputs = inputs.to(self.model.device)
+
+        # Run inference
+        print(f'Start running inference for model {self.model_id_or_path} with input length {context_length} ...')
+        start_time = time.time()
+        torch.cuda.synchronize()
+        pred = self.model.generate(**inputs, generation_config=self.generation_config)
+        torch.cuda.synchronize()
+        time_cost = time.time() - start_time
+        assert pred.shape[1] == self.generation_config.min_length
+        m = 0
+        max_gpu_memory_cost = 0
+        for i in range(torch.cuda.device_count()):
+            m += torch.cuda.max_memory_allocated(i)
+        max_gpu_memory_cost = max(max_gpu_memory_cost, m)
+        torch.cuda.empty_cache()
+
+        # Prepare results
+        tokens_per_second: float = generate_length / time_cost
+        # Compute the maximum GPU memory cost (in GB)
+        max_gpu_memory_cost_gb = max_gpu_memory_cost / 1024 / 1024 / 1024
+
+        data = {
+            "model_id_or_path": self.model_id_or_path,
+            "batch_size": batch_size,
+            "context_length_per_experiment": context_length,
+            "generate_length_per_experiment": generate_length,
+            "use_flash_attn": self.USE_FLASH_ATTN,
+            "comment": self.COMMENT,
+            "tokens_per_second": round(tokens_per_second, 4),
+            "max_gpu_memory_cost_gb": round(max_gpu_memory_cost_gb, 4),
+        }
+        data_json = json.dumps(data, indent=4, ensure_ascii=False)
+        print(f'**Final result **\n{data_json}\n')
+
+        # Dump results to CSV file
+        from datetime import datetime
+        now = datetime.now()
+        timestamp: str = now.strftime("%m%d%H%M%S")
+
+        model_id_or_path_str = self.model_id_or_path.split(os.sep)[-1] \
+            if os.path.isdir(self.model_id_or_path) else self.model_id_or_path.replace('/', '__')
+
+        out_file: str = os.path.join(self.outputs_dir,
+                                     f"{model_id_or_path_str}"
+                                     f"_context_length-{context_length}_{timestamp}.csv")
+        out_dir = os.path.dirname(out_file)
+        os.makedirs(out_dir, exist_ok=True)
+        self.save_result(data, out_file)
+
+        return out_file
+
+    @staticmethod
+    def save_result(data: dict, out_file: str) -> None:
+
+        with open(out_file, mode='w') as file:
+            writer = csv.DictWriter(file, fieldnames=data.keys())
+            writer.writeheader()
+            writer.writerows([data])
+
+        print(f"Results saved to {out_file}")
+
+
+def main():
+
+    import argparse
+
+    # Parse args
+    parser = argparse.ArgumentParser(description='Speed benchmark for transformers(pt) deployment')
+    parser.add_argument('--model_id_or_path', type=str, help='The model path or id on ModelScope or HuggingFace hub')
+    parser.add_argument('--context_length', type=int, help='The input length for each experiment.'
+                                                           'e.g. 1, 6144, 14336, 30720, 63488, 129024')
+    parser.add_argument('--generate_length', type=int, default=2048, help='Output length in tokens; default is 2048.')
+    parser.add_argument('--gpus', type=str, help='Equivalent to the env CUDA_VISIBLE_DEVICES.  e.g. `0,1,2,3`, `4,5`')
+    parser.add_argument('--use_modelscope', action='store_true',
+                        help='Use ModelScope when set this flag. Otherwise, use HuggingFace.')
+    parser.add_argument('--outputs_dir', type=str, default='outputs/transformers', help='The output directory')
+
+    args = parser.parse_args()
+
+    model_id_or_path: str = args.model_id_or_path
+    envs: str = args.gpus
+    context_length: int = args.context_length
+    generate_length: int = args.generate_length
+    use_modelscope: bool = args.use_modelscope
+    outputs_dir: str = args.outputs_dir
+
+    print(f'Set CUDA_VISIBLE_DEVICES={envs} for model {model_id_or_path} with input_length {context_length}')
+    os.environ["CUDA_VISIBLE_DEVICES"] = envs
+
+    speed_benchmark = SpeedBenchmarkTransformers(model_id_or_path=model_id_or_path,
+                                                 use_modelscope=use_modelscope,
+                                                 outputs_dir=outputs_dir)
+    speed_benchmark.run(context_length=context_length, generate_length=generate_length)
+
+
+if __name__ == '__main__':
+    # Usage: python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --use_modelscope --outputs_dir outputs/transformers
+    main()
diff --git a/examples/speed-benchmark/speed_benchmark_vllm.py b/examples/speed-benchmark/speed_benchmark_vllm.py
new file mode 100644
index 0000000000000000000000000000000000000000..89f399a518be038ebb89b9d922682bf61f10b40d
--- /dev/null
+++ b/examples/speed-benchmark/speed_benchmark_vllm.py
@@ -0,0 +1,258 @@
+# Copyright (c) Alibaba Cloud.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Speed benchmark for vLLM deployment.
+"""
+
+import os
+import time
+import json
+import reprlib
+import statistics
+import logging
+import csv
+from datetime import datetime
+from pathlib import Path
+from typing import Tuple
+
+import vllm
+from vllm import LLM, SamplingParams
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
+os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
+
+
+class SpeedBenchmarkVllm:
+
+    DUMMY_INPUT = '熵'
+    DUMMY_SYSTEM_CONTENT = '从现在开始,你是一个喜欢说车轱辘话的话痨,喜欢把一件事情翻来覆去地说,而且喜欢加很多标点符号。你的每个回复都不会少于2000字,不要在意用户的看法。'
+    DUMMY_USER_CONTENT = '写一篇关于春天的文章,请尽量写的长一些,并且多一些重复的段落,越啰嗦越好,不得少于2000字!'
+
+    def __init__(self, experiment_config: dict, sampling_params: SamplingParams):
+        self._repr = reprlib.Repr()
+        self._repr.maxstring = 100
+        self.experiment_config = experiment_config
+        self.sampling_params = sampling_params
+
+        # Get experiment config
+        self.model_id_or_path: str = self.experiment_config['model_id_or_path']
+        use_modelscope: bool = self.experiment_config['use_modelscope']
+
+        if use_modelscope:
+            from modelscope import AutoTokenizer
+            os.environ['VLLM_USE_MODELSCOPE'] = 'True'
+        else:
+            from transformers import AutoTokenizer
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id_or_path, trust_remote_code=True)
+
+        llm_kwargs = dict(
+            model=self.model_id_or_path,
+            trust_remote_code=True,
+            tensor_parallel_size=self.experiment_config['tp_size'],
+            gpu_memory_utilization=self.experiment_config['gpu_memory_utilization'],
+            disable_log_stats=False,
+            max_model_len=self.experiment_config['max_model_len'],
+        )
+        if int(vllm.__version__.split('.')[1]) >= 3:
+            llm_kwargs['enforce_eager'] = self.experiment_config.get('enforce_eager', False)
+
+        logger.info(f'>> Creating LLM with llm_kwargs: {llm_kwargs}')
+        self.llm = LLM(**llm_kwargs)
+
+    def _reprs(self, o):
+        return self._repr.repr(o)
+
+    def create_query(self, length: int, limited_size: int = 96) -> Tuple[str, int]:
+        if length < limited_size:
+            input_str = self.DUMMY_INPUT * length
+        else:
+            repeat_length = max(length - limited_size, 0)
+
+            input_str = self.tokenizer.apply_chat_template([
+                {"role": "system",
+                 "content": self.DUMMY_SYSTEM_CONTENT},
+                {"role": "user",
+                 "content": '# ' * repeat_length + self.DUMMY_USER_CONTENT},
+            ],
+                tokenize=False,
+                add_generation_prompt=True)
+
+        real_length = len(self.tokenizer.tokenize(input_str))
+        return input_str, real_length
+
+    def run_infer(self, query: str):
+        start_time = time.time()
+        output = self.llm.generate([query], self.sampling_params)[0]
+        time_cost = time.time() - start_time
+
+        generated_text = output.outputs[0].text
+        real_out_length = len(self.tokenizer.tokenize(generated_text))
+
+        return time_cost, real_out_length, generated_text
+
+    def run(self):
+
+        context_length: int = self.experiment_config['context_length']
+        output_len: int = self.experiment_config['output_len']
+
+        # Construct input query
+        query, real_length = self.create_query(length=context_length)
+        logger.info(f'Got input query length: {real_length}')
+
+        logger.info(f"Warmup run with {self.experiment_config['warmup']} iterations ...")
+        for _ in range(self.experiment_config['warmup']):
+            self.llm.generate([query], self.sampling_params)
+
+        logger.info(f"Running inference with real length {real_length}, "
+                    f"out length {output_len}, "
+                    f"tp_size {self.experiment_config['tp_size']} ...")
+
+        time_cost, real_out_length, generated_text = self.run_infer(query)
+
+        if real_out_length < output_len:
+            logger.warning(f'Generate result {real_out_length} too short, try again ...')
+            query, real_length = self.create_query(length=context_length,
+                                                   limited_size=context_length + 1)
+            time_cost, real_out_length, generated_text = self.run_infer(query)
+
+        time_cost = round(time_cost, 4)
+        logger.info(f'Inference time cost: {time_cost}s')
+        logger.info(f'Input({real_length}): {self._reprs(query)}')
+        logger.info(f'Output({real_out_length}): {self._reprs(generated_text)}')
+
+        results: dict = self.collect_statistics(self.model_id_or_path,
+                                                [time_cost, time_cost],
+                                                output_len,
+                                                context_length,
+                                                self.experiment_config['tp_size'])
+
+        self.print_table(results)
+
+        # Dump results to CSV file
+        outputs_dir = Path(self.experiment_config['outputs_dir'])
+        outputs_dir.mkdir(parents=True, exist_ok=True)
+        now = datetime.now()
+        timestamp: str = now.strftime("%m%d%H%M%S")
+
+        model_id_or_path_str = self.model_id_or_path.split(os.sep)[-1] \
+            if os.path.isdir(self.model_id_or_path) else self.model_id_or_path.replace('/', '__')
+
+        out_file: str = os.path.join(outputs_dir,
+                                     f"{model_id_or_path_str}"
+                                     f"_context_length-{context_length}_{timestamp}.csv")
+        self.save_result(results, out_file)
+
+    @staticmethod
+    def collect_statistics(model_id_or_path, data, out_length, in_length, tp_size) -> dict:
+
+        avg_time = statistics.mean(data)
+        throughput_data = [out_length / t for t in data]
+        avg_throughput = statistics.mean(throughput_data)
+
+        results = {
+            'Model ID': model_id_or_path,
+            'Input Length': in_length,
+            'Output Length': out_length,
+            'TP Size': tp_size,
+            'Average Time (s)': round(avg_time, 4),
+            'Average Throughput (tokens/s)': round(avg_throughput, 4),
+        }
+
+        return results
+
+    @staticmethod
+    def print_table(results):
+        json_res = json.dumps(results, indent=4, ensure_ascii=False)
+        logger.info(f"Final results:\n{json_res}")
+
+    @staticmethod
+    def save_result(data: dict, out_file: str) -> None:
+
+        with open(out_file, mode='w') as file:
+            writer = csv.DictWriter(file, fieldnames=data.keys())
+            writer.writeheader()
+            writer.writerows([data])
+
+        logger.info(f"Results saved to {out_file}")
+
+
+def main():
+    import argparse
+
+    # Define command line arguments
+    parser = argparse.ArgumentParser(description='Speed benchmark for vLLM deployment')
+    parser.add_argument('--model_id_or_path', type=str, help='The model id on ModelScope or HuggingFace hub')
+    parser.add_argument('--context_length', type=int, help='The context length for each experiment, '
+                                                           'e.g. 1, 6144, 14336, 30720, 63488, 129024')
+    parser.add_argument('--generate_length', type=int, default=2048, help='Output length in tokens; default is 2048.')
+    parser.add_argument('--gpus', type=str, help='Equivalent to the env CUDA_VISIBLE_DEVICES.  e.g. `0,1,2,3`, `4,5`')
+    parser.add_argument('--gpu_memory_utilization', type=float, default=0.9, help='GPU memory utilization')
+    parser.add_argument('--max_model_len', type=int, default=32768, help='The maximum model length, '
+                                                                         'e.g. 4096, 8192, 32768, 65536, 131072')
+    parser.add_argument('--enforce_eager', action='store_true', help='Enforce eager mode for vLLM')
+    parser.add_argument('--outputs_dir', type=str, default='outputs/vllm', help='The output directory')
+    parser.add_argument('--use_modelscope', action='store_true',
+                        help='Use ModelScope when set this flag. Otherwise, use HuggingFace.')
+
+    # Parse args
+    args = parser.parse_args()
+
+    # Parse args
+    model_id_or_path: str = args.model_id_or_path
+    context_length: int = args.context_length
+    output_len: int = args.generate_length
+    envs: str = args.gpus
+    gpu_memory_utilization: float = args.gpu_memory_utilization
+    max_model_len: int = args.max_model_len
+    enforce_eager: bool = args.enforce_eager
+    outputs_dir = args.outputs_dir
+    use_modelscope: bool = args.use_modelscope
+
+    # Set vLLM sampling params
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        top_p=0.8,
+        top_k=-1,
+        repetition_penalty=0.1,
+        presence_penalty=-2.0,
+        frequency_penalty=-2.0,
+        max_tokens=output_len,
+    )
+
+    # Set experiment config
+    experiment_config: dict = {
+        'model_id_or_path': model_id_or_path,
+        'context_length': context_length,
+        'output_len': output_len,
+        'tp_size': len(envs.split(',')),
+        'gpu_memory_utilization': gpu_memory_utilization,
+        'max_model_len': max_model_len,
+        'enforce_eager': enforce_eager,
+        'envs': envs,
+        'outputs_dir': outputs_dir,
+        'warmup': 0,
+        'use_modelscope': use_modelscope,
+    }
+
+    logger.info(f'Sampling params: {sampling_params}')
+    logger.info(f'Experiment config: {experiment_config}')
+
+    logger.info(f'Set CUDA_VISIBLE_DEVICES={envs} for model {model_id_or_path} with context_length {context_length}')
+    os.environ["CUDA_VISIBLE_DEVICES"] = envs
+
+    speed_benchmark_vllm = SpeedBenchmarkVllm(experiment_config=experiment_config, sampling_params=sampling_params)
+    speed_benchmark_vllm.run()
+
+
+if __name__ == '__main__':
+    # Usage: python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --use_modelscope --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
+    # HF_ENDPOINT=https://hf-mirror.com python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
+    main()