diff --git a/examples/speed-benchmark/README.md b/examples/speed-benchmark/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..84987b7b2a85988622394fcf0179b84034008c9d
--- /dev/null
+++ b/examples/speed-benchmark/README.md
@@ -0,0 +1,106 @@
+## Speed Benchmark
+
+This document describes the speed benchmark procedure for the Qwen2.5 series models (original and quantized). For detailed reports, please refer to the [Qwen2.5 Speed Benchmark](https://qwen.readthedocs.io/en/latest/benchmark/speed_benchmark.html).
+
+### 1. Model Collections
+
+For models hosted on HuggingFace, please refer to [Qwen2.5 Collections-HuggingFace](https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e).
+
+For models hosted on ModelScope, please refer to [Qwen2.5 Collections-ModelScope](https://modelscope.cn/collections/Qwen25-dbc4d30adb768).
+
+### 2. Environment Installation
+
+For inference using HuggingFace transformers:
+
+```shell
+conda create -n qwen_perf_transformers python=3.10
+conda activate qwen_perf_transformers
+
+pip install torch==2.3.1
+pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@v0.7.1
+pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.5.8
+pip install -r requirements-perf-transformers.txt
+```
+
+> [!Important]
+> - For `flash-attention`, you can use the prebuilt wheels from [GitHub Releases](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.8) or install it from source, which requires a compatible CUDA compiler.
+> - You don't actually need to install `flash-attention`; it has been integrated into `torch` as an `sdpa` backend.
+> - For `auto_gptq` to use efficient kernels, you need to install it from source, because the prebuilt wheels require incompatible `torch` versions. Installing from source also requires a compatible CUDA compiler.
+> - For `autoawq` to use efficient kernels, you need `autoawq-kernels`, which should be installed automatically. If not, run `pip install autoawq-kernels`.
+
+For inference using vLLM:
+
+```shell
+conda create -n qwen_perf_vllm python=3.10
+conda activate qwen_perf_vllm
+
+pip install -r requirements-perf-vllm.txt
+```
+
+### 3. Run Experiments
+
+#### 3.1 Inference using HuggingFace Transformers
+
+- Use HuggingFace hub
+
+```shell
+python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --outputs_dir outputs/transformers
+```
+
+- Use ModelScope hub
+
+```shell
+python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --use_modelscope --outputs_dir outputs/transformers
+```
+
+Parameters:
+
+- `--model_id_or_path`: The model id on ModelScope or HuggingFace hub, or a local model path.
+- `--context_length`: Input length in tokens; optional values are 1, 6144, 14336, 30720, 63488, 129024; refer to the `Qwen2.5 Speed Benchmark`.
+- `--generate_length`: Output length in tokens; default is 2048.
+- `--gpus`: Equivalent to the environment variable CUDA_VISIBLE_DEVICES, e.g. `0,1,2,3`, `4,5`.
+- `--use_modelscope`: Use ModelScope when this flag is set; otherwise, use HuggingFace.
+- `--outputs_dir`: Output directory; default is outputs/transformers.
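+
+To sweep several context lengths in one go, the script can be driven from a small loop. The following is a minimal sketch, not part of the benchmark itself; the model id, GPU index, and output directory are just the example values used above:
+
+```python
+import subprocess
+
+# Context lengths covered by the Qwen2.5 Speed Benchmark (see `--context_length` above).
+CONTEXT_LENGTHS = [1, 6144, 14336, 30720, 63488, 129024]
+
+for length in CONTEXT_LENGTHS:
+    # One benchmark process per setting; each run writes its own CSV to outputs/transformers.
+    subprocess.run(
+        [
+            "python", "speed_benchmark_transformers.py",
+            "--model_id_or_path", "Qwen/Qwen2.5-0.5B-Instruct",
+            "--context_length", str(length),
+            "--gpus", "0",
+            "--outputs_dir", "outputs/transformers",
+        ],
+        check=True,
+    )
+```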
+
+#### 3.2 Inference using vLLM
+
+- Use HuggingFace hub
+
+```shell
+python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
+```
+
+- Use ModelScope hub
+
+```shell
+python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --use_modelscope --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
+```
+
+Parameters:
+
+- `--model_id_or_path`: The model id on ModelScope or HuggingFace hub, or a local model path.
+- `--context_length`: Input length in tokens; optional values are 1, 6144, 14336, 30720, 63488, 129024; refer to the `Qwen2.5 Speed Benchmark`.
+- `--generate_length`: Output length in tokens; default is 2048.
+- `--max_model_len`: Maximum model length in tokens; default is 32768. Optional values are 4096, 8192, 32768, 65536, 131072.
+- `--gpus`: Equivalent to the environment variable CUDA_VISIBLE_DEVICES, e.g. `0,1,2,3`, `4,5`.
+- `--use_modelscope`: Use ModelScope when this flag is set; otherwise, use HuggingFace.
+- `--gpu_memory_utilization`: GPU memory utilization; range is (0, 1]; default is 0.9.
+- `--outputs_dir`: Output directory; default is outputs/vllm.
+- `--enforce_eager`: Whether to enforce eager mode; default is False.
+
+#### 3.3 Tips
+
+- Run each experiment multiple times and report the average; 3 runs is typical.
+- Make sure the GPU is idle before running experiments.
+
+### 4. Results
+
+Please check the `outputs` directory, which by default includes two subdirectories, `transformers` and `vllm`, containing the experiment results for HuggingFace transformers and vLLM, respectively.
diff --git a/examples/speed-benchmark/README_zh.md b/examples/speed-benchmark/README_zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..4e966876c1dee7ff2c2d3db1967241bfdaf3e05e
--- /dev/null
+++ b/examples/speed-benchmark/README_zh.md
@@ -0,0 +1,110 @@
+## Speed Benchmark
+
+This document describes the speed (efficiency) benchmark procedure for the Qwen2.5 series models (original and quantized). For detailed reports, please refer to the [Qwen2.5 Speed Benchmark](https://qwen.readthedocs.io/en/latest/benchmark/speed_benchmark.html).
+
+### 1. Model Resources
+
+For models hosted on HuggingFace, please refer to [Qwen2.5 Collections-HuggingFace](https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e).
+
+For models hosted on ModelScope, please refer to [Qwen2.5 Collections-ModelScope](https://modelscope.cn/collections/Qwen25-dbc4d30adb768).
+
+### 2. Environment Installation
+
+For inference with HuggingFace transformers, set up the environment as follows:
+
+```shell
+conda create -n qwen_perf_transformers python=3.10
+conda activate qwen_perf_transformers
+
+pip install torch==2.3.1
+pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@v0.7.1
+pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.5.8
+pip install -r requirements-perf-transformers.txt
+```
+
+> [!Important]
+> - For `flash-attention`, you can install the prebuilt wheels from the [GitHub Releases page](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.8) or build from source; the latter requires a compatible CUDA compiler.
+> - You do not actually need to install `flash-attention` separately; it has been integrated into `torch` as an `sdpa` backend.
+> - For `auto_gptq` to use efficient kernels, install it from source, because the prebuilt wheels depend on incompatible `torch` versions. Building from source likewise requires a compatible CUDA compiler.
+> - For `autoawq` to use efficient kernels, you need `autoawq-kernels`, which should be installed automatically. If it is not, install it manually with `pip install autoawq-kernels`.
+
+For inference with vLLM, set up the environment as follows:
+
+```shell
+conda create -n qwen_perf_vllm python=3.10
+conda activate qwen_perf_vllm
+
+pip install -r requirements-perf-vllm.txt
+```
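+
+Before running the benchmarks, it can help to confirm that the key packages are importable in the freshly created environment. The snippet below is a minimal sanity check, not part of the benchmark scripts; it assumes the usual import names of the packages installed above (`flash_attn`, `auto_gptq`, `awq`, `vllm`), and a missing optional package only matters for the corresponding setup.
+
+```python
+import importlib.util
+
+import torch
+
+# The benchmarks require a CUDA-capable GPU.
+print(f"torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
+
+# Optional acceleration / backend packages; only needed for the matching setups
+# (flash-attention, GPTQ / AWQ quantized models, vLLM).
+for name in ("flash_attn", "auto_gptq", "awq", "vllm"):
+    found = importlib.util.find_spec(name) is not None
+    print(f"{name}: {'found' if found else 'not installed'}")
+```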
+
+### 3. Run Experiments
+
+#### 3.1 Inference with HuggingFace transformers
+
+- Use the HuggingFace hub
+
+```shell
+python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --outputs_dir outputs/transformers
+
+# Specify HF_ENDPOINT (e.g. a mirror)
+HF_ENDPOINT=https://hf-mirror.com python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --outputs_dir outputs/transformers
+```
+
+- Use the ModelScope hub
+
+```shell
+python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --use_modelscope --outputs_dir outputs/transformers
+```
+
+Parameters:
+
+- `--model_id_or_path`: Model id or local path; see the `Model Resources` section for available models.
+- `--context_length`: Input length in tokens; optional values are 1, 6144, 14336, 30720, 63488, 129024; refer to the `Qwen2.5 Speed Benchmark` report.
+- `--generate_length`: Number of tokens to generate; default is 2048.
+- `--gpus`: Equivalent to the environment variable CUDA_VISIBLE_DEVICES, e.g. `0,1,2,3`, `4,5`.
+- `--use_modelscope`: Load the model from ModelScope when this flag is set; otherwise, use HuggingFace.
+- `--outputs_dir`: Output directory; default is `outputs/transformers`.
+
+#### 3.2 Inference with vLLM
+
+- Use the HuggingFace hub
+
+```shell
+python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
+
+# Specify HF_ENDPOINT (e.g. a mirror)
+HF_ENDPOINT=https://hf-mirror.com python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
+```
+
+- Use the ModelScope hub
+
+```shell
+python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --use_modelscope --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
+```
+
+Parameters:
+
+- `--model_id_or_path`: Model id or local path; see the `Model Resources` section for available models.
+- `--context_length`: Input length in tokens; optional values are 1, 6144, 14336, 30720, 63488, 129024; refer to the `Qwen2.5 Speed Benchmark` report.
+- `--generate_length`: Number of tokens to generate; default is 2048.
+- `--max_model_len`: Maximum model length in tokens; default is 32768.
+- `--gpus`: Equivalent to the environment variable CUDA_VISIBLE_DEVICES, e.g. `0,1,2,3`, `4,5`.
+- `--use_modelscope`: Load the model from ModelScope when this flag is set; otherwise, use HuggingFace.
+- `--gpu_memory_utilization`: GPU memory utilization; range is (0, 1]; default is 0.9.
+- `--outputs_dir`: Output directory; default is `outputs/vllm`.
+- `--enforce_eager`: Whether to enforce eager mode; default is False.
+
+#### 3.3 Notes
+
+1. Run each experiment multiple times and take the average; 3 runs is typical.
+2. Make sure the GPU is idle before running experiments, so that other workloads do not affect the results.
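+
+As a convenience for note 1 above, the per-run CSV files can be averaged with a short script. This is a minimal sketch, not part of the benchmark scripts; it assumes the default transformers output layout (one single-row CSV per run, with a `tokens_per_second` column) and that the directory contains only runs of the same configuration. For the vLLM results, the throughput column is named `Average Throughput (tokens/s)` instead.
+
+```python
+import csv
+import glob
+import statistics
+
+# Collect the single-row CSVs written by repeated runs of the same configuration.
+rows = []
+for path in glob.glob("outputs/transformers/*.csv"):
+    with open(path, newline="") as f:
+        rows.extend(csv.DictReader(f))
+
+# Average the throughput across runs (see note 1: typically 3 runs).
+throughputs = [float(row["tokens_per_second"]) for row in rows]
+print(f"runs: {len(throughputs)}, mean tokens/s: {statistics.mean(throughputs):.4f}")
+```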
+
+### 4. Results
+
+The results are written to the `outputs` directory, which by default contains two subdirectories, `transformers` and `vllm`, holding the results for HuggingFace transformers and vLLM, respectively.
diff --git a/examples/speed-benchmark/requirements-perf-transformers.txt b/examples/speed-benchmark/requirements-perf-transformers.txt
new file mode 100644
index 0000000000000000000000000000000000000000..dc9b7fe75366067e9a10025faf962ab84e8ea408
--- /dev/null
+++ b/examples/speed-benchmark/requirements-perf-transformers.txt
@@ -0,0 +1,10 @@
+# Note: install the following requirements separately
+# pip install torch==2.3.1
+# pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@v0.7.1
+# pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.5.8
+
+transformers==4.46.0
+autoawq==0.2.6
+modelscope[framework]
+accelerate
+optimum>=1.20.0
diff --git a/examples/speed-benchmark/requirements-perf-vllm.txt b/examples/speed-benchmark/requirements-perf-vllm.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cd14ea9e14796819e38e77e942bf03b2a9563b22
--- /dev/null
+++ b/examples/speed-benchmark/requirements-perf-vllm.txt
@@ -0,0 +1,4 @@
+vllm==0.6.3.post1
+torch==2.4.0
+modelscope[framework]
+accelerate
diff --git a/examples/speed-benchmark/speed_benchmark_transformers.py b/examples/speed-benchmark/speed_benchmark_transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8187d2cd9fbd8664aaff8936c960c0a2eae49299
--- /dev/null
+++ b/examples/speed-benchmark/speed_benchmark_transformers.py
@@ -0,0 +1,170 @@
+# Copyright (c) Alibaba Cloud.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Qwen2.5 speed benchmark for HuggingFace transformers (PyTorch) inference.
+"""
+
+import os
+import time
+import json
+import csv
+
+import torch
+from transformers.trainer_utils import set_seed
+
+
+class SpeedBenchmarkTransformers:
+
+    SEED = 1024
+    BATCH_SIZE = 1
+    USE_FLASH_ATTN = True
+    COMMENT = 'default'
+    DEVICE_MAP = 'auto'
+    TORCH_DTYPE = 'auto'
+    OVERWRITE_RESULT = False
+    DUMMY_INPUT = '我'
+
+    def __init__(self, model_id_or_path, use_modelscope: bool = True, outputs_dir: str = 'outputs/transformers'):
+        """
+        Speed benchmark for HuggingFace transformers (PyTorch) inference.
+
+        Args:
+            model_id_or_path: The model id on ModelScope or HuggingFace hub, or a local model path.
+            use_modelscope: Use ModelScope if True; otherwise use HuggingFace.
+            outputs_dir: The output directory. Default is 'outputs/transformers'.
+ """ + + set_seed(self.SEED) + self.model_id_or_path = model_id_or_path + self.outputs_dir = outputs_dir + + if use_modelscope: + from modelscope import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + else: + from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + + self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True) + attn_impl = 'flash_attention_2' if self.USE_FLASH_ATTN else 'eager' + self.model = AutoModelForCausalLM.from_pretrained(model_id_or_path, + torch_dtype=self.TORCH_DTYPE, + device_map=self.DEVICE_MAP, + attn_implementation=attn_impl + ).eval() + + self.generation_config = GenerationConfig.from_pretrained(model_id_or_path, trust_remote_code=True) + + def run(self, context_length: int, generate_length: int) -> str: + + # Specify hyperparameters for generation + self.generation_config.min_length = generate_length + context_length + self.generation_config.max_new_tokens = generate_length + print(f'Generation config: {self.generation_config}') + + # Prepare inputs + batch_size = self.BATCH_SIZE + context_str = self.DUMMY_INPUT * context_length + inputs = self.tokenizer([context_str for _ in range(batch_size)], return_tensors='pt') + assert inputs['input_ids'].shape[1] == context_length + assert inputs['input_ids'].shape[0] == batch_size + inputs = inputs.to(self.model.device) + + # Run inference + print(f'Start running inference for model {self.model_id_or_path} with input length {context_length} ...') + start_time = time.time() + torch.cuda.synchronize() + pred = self.model.generate(**inputs, generation_config=self.generation_config) + torch.cuda.synchronize() + time_cost = time.time() - start_time + assert pred.shape[1] == self.generation_config.min_length + m = 0 + max_gpu_memory_cost = 0 + for i in range(torch.cuda.device_count()): + m += torch.cuda.max_memory_allocated(i) + max_gpu_memory_cost = max(max_gpu_memory_cost, m) + torch.cuda.empty_cache() + + # Prepare results + tokens_per_second: float = generate_length / time_cost + # Compute the maximum GPU memory cost (in GB) + max_gpu_memory_cost_gb = max_gpu_memory_cost / 1024 / 1024 / 1024 + + data = { + "model_id_or_path": self.model_id_or_path, + "batch_size": batch_size, + "context_length_per_experiment": context_length, + "generate_length_per_experiment": generate_length, + "use_flash_attn": self.USE_FLASH_ATTN, + "comment": self.COMMENT, + "tokens_per_second": round(tokens_per_second, 4), + "max_gpu_memory_cost_gb": round(max_gpu_memory_cost_gb, 4), + } + data_json = json.dumps(data, indent=4, ensure_ascii=False) + print(f'**Final result **\n{data_json}\n') + + # Dump results to CSV file + from datetime import datetime + now = datetime.now() + timestamp: str = now.strftime("%m%d%H%M%S") + + model_id_or_path_str = self.model_id_or_path.split(os.sep)[-1] \ + if os.path.isdir(self.model_id_or_path) else self.model_id_or_path.replace('/', '__') + + out_file: str = os.path.join(self.outputs_dir, + f"{model_id_or_path_str}" + f"_context_length-{context_length}_{timestamp}.csv") + out_dir = os.path.dirname(out_file) + os.makedirs(out_dir, exist_ok=True) + self.save_result(data, out_file) + + return out_file + + @staticmethod + def save_result(data: dict, out_file: str) -> None: + + with open(out_file, mode='w') as file: + writer = csv.DictWriter(file, fieldnames=data.keys()) + writer.writeheader() + writer.writerows([data]) + + print(f"Results saved to {out_file}") + + +def main(): + + import argparse + + # Parse args + parser = 
argparse.ArgumentParser(description='Speed benchmark for transformers(pt) deployment') + parser.add_argument('--model_id_or_path', type=str, help='The model path or id on ModelScope or HuggingFace hub') + parser.add_argument('--context_length', type=int, help='The input length for each experiment.' + 'e.g. 1, 6144, 14336, 30720, 63488, 129024') + parser.add_argument('--generate_length', type=int, default=2048, help='Output length in tokens; default is 2048.') + parser.add_argument('--gpus', type=str, help='Equivalent to the env CUDA_VISIBLE_DEVICES. e.g. `0,1,2,3`, `4,5`') + parser.add_argument('--use_modelscope', action='store_true', + help='Use ModelScope when set this flag. Otherwise, use HuggingFace.') + parser.add_argument('--outputs_dir', type=str, default='outputs/transformers', help='The output directory') + + args = parser.parse_args() + + model_id_or_path: str = args.model_id_or_path + envs: str = args.gpus + context_length: int = args.context_length + generate_length: int = args.generate_length + use_modelscope: bool = args.use_modelscope + outputs_dir: str = args.outputs_dir + + print(f'Set CUDA_VISIBLE_DEVICES={envs} for model {model_id_or_path} with input_length {context_length}') + os.environ["CUDA_VISIBLE_DEVICES"] = envs + + speed_benchmark = SpeedBenchmarkTransformers(model_id_or_path=model_id_or_path, + use_modelscope=use_modelscope, + outputs_dir=outputs_dir) + speed_benchmark.run(context_length=context_length, generate_length=generate_length) + + +if __name__ == '__main__': + # Usage: python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --use_modelscope --outputs_dir outputs/transformers + main() diff --git a/examples/speed-benchmark/speed_benchmark_vllm.py b/examples/speed-benchmark/speed_benchmark_vllm.py new file mode 100644 index 0000000000000000000000000000000000000000..89f399a518be038ebb89b9d922682bf61f10b40d --- /dev/null +++ b/examples/speed-benchmark/speed_benchmark_vllm.py @@ -0,0 +1,258 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Speed benchmark for vLLM deployment. 
+""" + +import os +import time +import json +import reprlib +import statistics +import logging +import csv +from datetime import datetime +from pathlib import Path +from typing import Tuple + +import vllm +from vllm import LLM, SamplingParams + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' +os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" + + +class SpeedBenchmarkVllm: + + DUMMY_INPUT = '熵' + DUMMY_SYSTEM_CONTENT = 'ä»ŽçŽ°åœ¨å¼€å§‹ï¼Œä½ æ˜¯ä¸€ä¸ªå–œæ¬¢è¯´è½¦è½±è¾˜è¯çš„è¯ç—¨ï¼Œå–œæ¬¢æŠŠä¸€ä»¶äº‹æƒ…ç¿»æ¥è¦†åŽ»åœ°è¯´ï¼Œè€Œä¸”å–œæ¬¢åŠ å¾ˆå¤šæ ‡ç‚¹ç¬¦å·ã€‚ä½ çš„æ¯ä¸ªå›žå¤éƒ½ä¸ä¼šå°‘于2000å—,ä¸è¦åœ¨æ„用户的看法。' + DUMMY_USER_CONTENT = 'å†™ä¸€ç¯‡å…³äºŽæ˜¥å¤©çš„æ–‡ç« ï¼Œè¯·å°½é‡å†™çš„长一些,并且多一些é‡å¤çš„æ®µè½ï¼Œè¶Šå•°å—¦è¶Šå¥½ï¼Œä¸å¾—少于2000å—ï¼' + + def __init__(self, experiment_config: dict, sampling_params: SamplingParams): + self._repr = reprlib.Repr() + self._repr.maxstring = 100 + self.experiment_config = experiment_config + self.sampling_params = sampling_params + + # Get experiment config + self.model_id_or_path: str = self.experiment_config['model_id_or_path'] + use_modelscope: bool = self.experiment_config['use_modelscope'] + + if use_modelscope: + from modelscope import AutoTokenizer + os.environ['VLLM_USE_MODELSCOPE'] = 'True' + else: + from transformers import AutoTokenizer + + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id_or_path, trust_remote_code=True) + + llm_kwargs = dict( + model=self.model_id_or_path, + trust_remote_code=True, + tensor_parallel_size=self.experiment_config['tp_size'], + gpu_memory_utilization=self.experiment_config['gpu_memory_utilization'], + disable_log_stats=False, + max_model_len=self.experiment_config['max_model_len'], + ) + if int(vllm.__version__.split('.')[1]) >= 3: + llm_kwargs['enforce_eager'] = self.experiment_config.get('enforce_eager', False) + + logger.info(f'>> Creating LLM with llm_kwargs: {llm_kwargs}') + self.llm = LLM(**llm_kwargs) + + def _reprs(self, o): + return self._repr.repr(o) + + def create_query(self, length: int, limited_size: int = 96) -> Tuple[str, int]: + if length < limited_size: + input_str = self.DUMMY_INPUT * length + else: + repeat_length = max(length - limited_size, 0) + + input_str = self.tokenizer.apply_chat_template([ + {"role": "system", + "content": self.DUMMY_SYSTEM_CONTENT}, + {"role": "user", + "content": '# ' * repeat_length + self.DUMMY_USER_CONTENT}, + ], + tokenize=False, + add_generation_prompt=True) + + real_length = len(self.tokenizer.tokenize(input_str)) + return input_str, real_length + + def run_infer(self, query: str): + start_time = time.time() + output = self.llm.generate([query], self.sampling_params)[0] + time_cost = time.time() - start_time + + generated_text = output.outputs[0].text + real_out_length = len(self.tokenizer.tokenize(generated_text)) + + return time_cost, real_out_length, generated_text + + def run(self): + + context_length: int = self.experiment_config['context_length'] + output_len: int = self.experiment_config['output_len'] + + # Construct input query + query, real_length = self.create_query(length=context_length) + logger.info(f'Got input query length: {real_length}') + + logger.info(f"Warmup run with {self.experiment_config['warmup']} iterations ...") + for _ in range(self.experiment_config['warmup']): + self.llm.generate([query], self.sampling_params) + + logger.info(f"Running inference with real length 
{real_length}, " + f"out length {output_len}, " + f"tp_size {self.experiment_config['tp_size']} ...") + + time_cost, real_out_length, generated_text = self.run_infer(query) + + if real_out_length < output_len: + logger.warning(f'Generate result {real_out_length} too short, try again ...') + query, real_length = self.create_query(length=context_length, + limited_size=context_length + 1) + time_cost, real_out_length, generated_text = self.run_infer(query) + + time_cost = round(time_cost, 4) + logger.info(f'Inference time cost: {time_cost}s') + logger.info(f'Input({real_length}): {self._reprs(query)}') + logger.info(f'Output({real_out_length}): {self._reprs(generated_text)}') + + results: dict = self.collect_statistics(self.model_id_or_path, + [time_cost, time_cost], + output_len, + context_length, + self.experiment_config['tp_size']) + + self.print_table(results) + + # Dump results to CSV file + outputs_dir = Path(self.experiment_config['outputs_dir']) + outputs_dir.mkdir(parents=True, exist_ok=True) + now = datetime.now() + timestamp: str = now.strftime("%m%d%H%M%S") + + model_id_or_path_str = self.model_id_or_path.split(os.sep)[-1] \ + if os.path.isdir(self.model_id_or_path) else self.model_id_or_path.replace('/', '__') + + out_file: str = os.path.join(outputs_dir, + f"{model_id_or_path_str}" + f"_context_length-{context_length}_{timestamp}.csv") + self.save_result(results, out_file) + + @staticmethod + def collect_statistics(model_id_or_path, data, out_length, in_length, tp_size) -> dict: + + avg_time = statistics.mean(data) + throughput_data = [out_length / t for t in data] + avg_throughput = statistics.mean(throughput_data) + + results = { + 'Model ID': model_id_or_path, + 'Input Length': in_length, + 'Output Length': out_length, + 'TP Size': tp_size, + 'Average Time (s)': round(avg_time, 4), + 'Average Throughput (tokens/s)': round(avg_throughput, 4), + } + + return results + + @staticmethod + def print_table(results): + json_res = json.dumps(results, indent=4, ensure_ascii=False) + logger.info(f"Final results:\n{json_res}") + + @staticmethod + def save_result(data: dict, out_file: str) -> None: + + with open(out_file, mode='w') as file: + writer = csv.DictWriter(file, fieldnames=data.keys()) + writer.writeheader() + writer.writerows([data]) + + logger.info(f"Results saved to {out_file}") + + +def main(): + import argparse + + # Define command line arguments + parser = argparse.ArgumentParser(description='Speed benchmark for vLLM deployment') + parser.add_argument('--model_id_or_path', type=str, help='The model id on ModelScope or HuggingFace hub') + parser.add_argument('--context_length', type=int, help='The context length for each experiment, ' + 'e.g. 1, 6144, 14336, 30720, 63488, 129024') + parser.add_argument('--generate_length', type=int, default=2048, help='Output length in tokens; default is 2048.') + parser.add_argument('--gpus', type=str, help='Equivalent to the env CUDA_VISIBLE_DEVICES. e.g. `0,1,2,3`, `4,5`') + parser.add_argument('--gpu_memory_utilization', type=float, default=0.9, help='GPU memory utilization') + parser.add_argument('--max_model_len', type=int, default=32768, help='The maximum model length, ' + 'e.g. 4096, 8192, 32768, 65536, 131072') + parser.add_argument('--enforce_eager', action='store_true', help='Enforce eager mode for vLLM') + parser.add_argument('--outputs_dir', type=str, default='outputs/vllm', help='The output directory') + parser.add_argument('--use_modelscope', action='store_true', + help='Use ModelScope when set this flag. 
Otherwise, use HuggingFace.') + + # Parse args + args = parser.parse_args() + + # Parse args + model_id_or_path: str = args.model_id_or_path + context_length: int = args.context_length + output_len: int = args.generate_length + envs: str = args.gpus + gpu_memory_utilization: float = args.gpu_memory_utilization + max_model_len: int = args.max_model_len + enforce_eager: bool = args.enforce_eager + outputs_dir = args.outputs_dir + use_modelscope: bool = args.use_modelscope + + # Set vLLM sampling params + sampling_params = SamplingParams( + temperature=1.0, + top_p=0.8, + top_k=-1, + repetition_penalty=0.1, + presence_penalty=-2.0, + frequency_penalty=-2.0, + max_tokens=output_len, + ) + + # Set experiment config + experiment_config: dict = { + 'model_id_or_path': model_id_or_path, + 'context_length': context_length, + 'output_len': output_len, + 'tp_size': len(envs.split(',')), + 'gpu_memory_utilization': gpu_memory_utilization, + 'max_model_len': max_model_len, + 'enforce_eager': enforce_eager, + 'envs': envs, + 'outputs_dir': outputs_dir, + 'warmup': 0, + 'use_modelscope': use_modelscope, + } + + logger.info(f'Sampling params: {sampling_params}') + logger.info(f'Experiment config: {experiment_config}') + + logger.info(f'Set CUDA_VISIBLE_DEVICES={envs} for model {model_id_or_path} with context_length {context_length}') + os.environ["CUDA_VISIBLE_DEVICES"] = envs + + speed_benchmark_vllm = SpeedBenchmarkVllm(experiment_config=experiment_config, sampling_params=sampling_params) + speed_benchmark_vllm.run() + + +if __name__ == '__main__': + # Usage: python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --use_modelscope --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm + # HF_ENDPOINT=https://hf-mirror.com python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm + main()