Commit 8205ba2a authored by Xingjun.Wang, committed by GitHub

Add speed benchmark examples (#1068)


* add qwen2.5 perf report

* update readme

* rebuild docs and fix format issue

* remove fuzzy in speed_benchmark.po

* fix issue

* recover function_call.po

* update

* remove unused code in speed_benchmark.po

* add example

* add readme for speed benchmark scripts

* update readme

* update readme

* update

* refine code

* fix pr issue

* fix some issue for PR

* update installation

* add --generate_length param

* update

* update requirements

* Update README_zh.md

* Update README.md

---------

Co-authored-by: Ren Xuancheng <jklj077@users.noreply.github.com>
parent f45f6b41
## Speed Benchmark
This document describes how to run the speed benchmark for the Qwen2.5 series models (both the original and the quantized models). For the detailed report, please refer to the [Qwen2.5 Speed Benchmark](https://qwen.readthedocs.io/en/latest/benchmark/speed_benchmark.html).
### 1. Model Collections
For models hosted on HuggingFace, please refer to [Qwen2.5 Collections-HuggingFace](https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e).
For models hosted on ModelScope, please refer to [Qwen2.5 Collections-ModelScope](https://modelscope.cn/collections/Qwen25-dbc4d30adb768).
### 2. Environment Installation
For inference using HuggingFace transformers:
```shell
conda create -n qwen_perf_transformers python=3.10
conda activate qwen_perf_transformers
pip install torch==2.3.1
pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@v0.7.1
pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.5.8
pip install -r requirements-perf-transformers.txt
```
> [!Important]
> - For `flash-attention`, you can use the prebuilt wheels from [GitHub Releases](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.8) or install from source, which requires a compatible CUDA compiler.
> - You don't actually need to install `flash-attention`: it has been integrated into `torch` as a backend of `sdpa` (see the quick check below).
> - For `auto_gptq` to use efficient kernels, you need to install it from source, because the prebuilt wheels require incompatible `torch` versions. Installing from source also requires a compatible CUDA compiler.
> - For `autoawq` to use efficient kernels, you need `autoawq-kernels`, which should be installed automatically. If not, run `pip install autoawq-kernels`.
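As a quick check that the `sdpa` flash backend is usable without a separate `flash-attention` install, you can run the following minimal Python sketch (not part of the benchmark scripts); it assumes a CUDA GPU and the `torch==2.3.1` installed above.
```python
# Minimal sanity check (not part of the repo): verify torch's SDPA flash backend works,
# which makes a separate flash-attention install optional. Assumes a CUDA GPU.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print("flash SDPA backend OK:", tuple(out.shape))
```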
For inference using vLLM:
```shell
conda create -n qwen_perf_vllm python=3.10
conda activate qwen_perf_vllm
pip install -r requirements-perf-vllm.txt
```
### 3. Run Experiments
#### 3.1 Inference using HuggingFace Transformers
- Use HuggingFace hub
```shell
python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --outputs_dir outputs/transformers
```
- Use ModelScope hub
```shell
python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --use_modelscope --outputs_dir outputs/transformers
```
Parameters:
- `--model_id_or_path`: The model path or ID on the ModelScope or HuggingFace hub.
- `--context_length`: Input length in tokens; optional values are 1, 6144, 14336, 30720, 63488, 129024; refer to the Qwen2.5 Speed Benchmark report.
- `--generate_length`: Output length in tokens; default is 2048.
- `--gpus`: Equivalent to the environment variable CUDA_VISIBLE_DEVICES, e.g. `0,1,2,3`, `4,5`.
- `--use_modelscope`: Use ModelScope when this flag is set; otherwise, use HuggingFace.
- `--outputs_dir`: Output directory; default is `outputs/transformers`.
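If you prefer to drive the benchmark from Python instead of the CLI, the following minimal sketch uses the `SpeedBenchmarkTransformers` class defined in `speed_benchmark_transformers.py`; the argument values are illustrative only.
```python
# Minimal sketch: run the transformers benchmark programmatically (same effect as the CLI).
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # same role as --gpus 0

from speed_benchmark_transformers import SpeedBenchmarkTransformers

bench = SpeedBenchmarkTransformers(
    model_id_or_path="Qwen/Qwen2.5-0.5B-Instruct",
    use_modelscope=False,                 # True to load the model from ModelScope
    outputs_dir="outputs/transformers",
)
csv_path = bench.run(context_length=1, generate_length=2048)
print("results written to", csv_path)
```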
#### 3.2 Inference using vLLM
- Use HuggingFace hub
```shell
python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
```
- Use ModelScope hub
```shell
python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --use_modelscope --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
```
Parameters:
- `--model_id_or_path`: The model ID on the ModelScope or HuggingFace hub, or a local path.
- `--context_length`: Input length in tokens; optional values are 1, 6144, 14336, 30720, 63488, 129024; refer to the Qwen2.5 Speed Benchmark report.
- `--generate_length`: Output length in tokens; default is 2048.
- `--max_model_len`: Maximum model length in tokens; default is 32768; optional values are 4096, 8192, 32768, 65536, 131072.
- `--gpus`: Equivalent to the environment variable CUDA_VISIBLE_DEVICES, e.g. `0,1,2,3`, `4,5`.
- `--use_modelscope`: Use ModelScope when this flag is set; otherwise, use HuggingFace.
- `--gpu_memory_utilization`: GPU memory utilization; range is (0, 1]; default is 0.9.
- `--outputs_dir`: Output directory; default is `outputs/vllm`.
- `--enforce_eager`: Whether to enforce eager mode; default is False.
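The vLLM benchmark can likewise be driven from Python by building the `experiment_config` dictionary that `speed_benchmark_vllm.py` assembles from its CLI flags; a minimal sketch with illustrative values:
```python
# Minimal sketch: run the vLLM benchmark programmatically with the same config keys
# that the CLI builds from its flags. Values are examples, not recommendations.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from vllm import SamplingParams
from speed_benchmark_vllm import SpeedBenchmarkVllm

experiment_config = {
    "model_id_or_path": "Qwen/Qwen2.5-0.5B-Instruct",
    "context_length": 1,
    "output_len": 2048,
    "tp_size": 1,                       # number of GPUs listed in --gpus
    "gpu_memory_utilization": 0.9,
    "max_model_len": 32768,
    "enforce_eager": False,
    "envs": "0",
    "outputs_dir": "outputs/vllm",
    "warmup": 0,
    "use_modelscope": False,
}
sampling_params = SamplingParams(temperature=1.0, top_p=0.8, max_tokens=2048)

SpeedBenchmarkVllm(experiment_config=experiment_config, sampling_params=sampling_params).run()
```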
#### 3.3 Tips
- Run each experiment multiple times and average the results; three runs is typical (one way to do this is sketched after this list).
- Make sure the GPU is idle before running experiments.
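One way to do the averaging, assuming the CSV column names written by `speed_benchmark_transformers.py` (`tokens_per_second`, `max_gpu_memory_cost_gb`):
```python
# Hedged sketch: average results over all CSVs that the benchmark wrote to a directory.
import csv
import glob
import statistics

rows = []
for path in glob.glob("outputs/transformers/*.csv"):
    with open(path, newline="") as f:
        rows.extend(csv.DictReader(f))

if not rows:
    raise SystemExit("no result CSVs found")

tps = statistics.mean(float(r["tokens_per_second"]) for r in rows)
mem = statistics.mean(float(r["max_gpu_memory_cost_gb"]) for r in rows)
print(f"runs={len(rows)}  avg tokens/s={tps:.2f}  avg peak GPU mem={mem:.2f} GB")
```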
### 4. Results
Please check the `outputs` directory, which by default contains two subdirectories, `transformers` and `vllm`, holding the experiment results for HuggingFace transformers and vLLM, respectively.
## Speed Benchmark
This document describes the speed benchmark workflow for the Qwen2.5 series models (original and quantized). For the detailed report, see the [Qwen2.5 Speed Benchmark report](https://qwen.readthedocs.io/en/latest/benchmark/speed_benchmark.html).
### 1. Model Collections
For models hosted on HuggingFace, see [Qwen2.5 Collections-HuggingFace](https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e).
For models hosted on ModelScope, see [Qwen2.5 Collections-ModelScope](https://modelscope.cn/collections/Qwen25-dbc4d30adb768).
### 2. Environment Installation
For inference with HuggingFace transformers, set up the environment as follows:
```shell
conda create -n qwen_perf_transformers python=3.10
conda activate qwen_perf_transformers
pip install torch==2.3.1
pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@v0.7.1
pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.5.8
pip install -r requirements-perf-transformers.txt
```
> [!Important]
> - For `flash-attention`, you can install the prebuilt wheels from the [GitHub releases page](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.8) or build from source; the latter requires a compatible CUDA compiler.
> - You do not actually need to install `flash-attention` separately; it is already integrated into `torch` as a backend of `sdpa`.
> - For `auto_gptq` to use efficient kernels, install it from source, because the prebuilt wheels depend on incompatible `torch` versions. Building from source likewise requires a compatible CUDA compiler.
> - For `autoawq` to use efficient kernels, you need `autoawq-kernels`, which should be installed automatically. If it is not, install it manually with `pip install autoawq-kernels`.
For inference with vLLM, set up the environment as follows:
```shell
conda create -n qwen_perf_vllm python=3.10
conda activate qwen_perf_vllm
pip install -r requirements-perf-vllm.txt
```
### 3. Run Experiments
#### 3.1 Inference with HuggingFace transformers
- Use the HuggingFace hub
```shell
python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --outputs_dir outputs/transformers
# Specify HF_ENDPOINT (e.g. a mirror site)
HF_ENDPOINT=https://hf-mirror.com python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --outputs_dir outputs/transformers
```
- Use the ModelScope hub
```shell
python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --use_modelscope --outputs_dir outputs/transformers
```
Parameters:
- `--model_id_or_path`: Model ID or local path; see the `Model Collections` section for available values.
- `--context_length`: Input length in tokens; optional values are 1, 6144, 14336, 30720, 63488, 129024; refer to the Qwen2.5 Speed Benchmark report.
- `--generate_length`: Number of tokens to generate; default is 2048.
- `--gpus`: Equivalent to the environment variable CUDA_VISIBLE_DEVICES, e.g. `0,1,2,3`, `4,5`.
- `--use_modelscope`: If set, load the model from ModelScope; otherwise use HuggingFace.
- `--outputs_dir`: Output directory; default is `outputs/transformers`.
#### 3.2 Inference with vLLM
- Use the HuggingFace hub
```shell
python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
# Specify HF_ENDPOINT (e.g. a mirror site)
HF_ENDPOINT=https://hf-mirror.com python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
```
- Use the ModelScope hub
```shell
python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --use_modelscope --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
```
Parameters:
- `--model_id_or_path`: Model ID or local path; see the `Model Collections` section for available values.
- `--context_length`: Input length in tokens; optional values are 1, 6144, 14336, 30720, 63488, 129024; refer to the Qwen2.5 Speed Benchmark report.
- `--generate_length`: Number of tokens to generate; default is 2048.
- `--max_model_len`: Maximum model length in tokens; default is 32768.
- `--gpus`: Equivalent to the environment variable CUDA_VISIBLE_DEVICES, e.g. `0,1,2,3`, `4,5`.
- `--use_modelscope`: If set, load the model from ModelScope; otherwise use HuggingFace.
- `--gpu_memory_utilization`: GPU memory utilization; range is (0, 1]; default is 0.9.
- `--outputs_dir`: Output directory; default is `outputs/vllm`.
- `--enforce_eager`: Whether to enforce eager mode; default is False.
#### 3.3 Tips
1. Run each experiment multiple times and average the results; three runs is typical.
2. Make sure the GPU is idle before running experiments so that other tasks do not affect the results.
### 4. Results
The results are written to the `outputs` directory, which by default contains two subdirectories, `transformers` and `vllm`, holding the results for HuggingFace transformers and vLLM, respectively.
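# requirements-perf-transformers.txt: dependencies for the transformers benchmark (file boundary inferred from the install instructions above)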
# Note: install the following requirements separately
# pip install torch==2.3.1
# pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@v0.7.1
# pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.5.8
transformers==4.46.0
autoawq==0.2.6
modelscope[framework]
accelerate
optimum>=1.20.0
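# requirements-perf-vllm.txt: dependencies for the vLLM benchmark (file boundary inferred from the install instructions above)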
vllm==0.6.3.post1
torch==2.4.0
modelscope[framework]
accelerate
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Qwen2.5 Speed Benchmark for transformers(pt) inference.
"""
import os
import time
import json
import csv
import torch
from transformers.trainer_utils import set_seed
class SpeedBenchmarkTransformers:
SEED = 1024
BATCH_SIZE = 1
USE_FLASH_ATTN = True
COMMENT = 'default'
DEVICE_MAP = 'auto'
TORCH_DTYPE = 'auto'
OVERWRITE_RESULT = False
    # Dummy single-character input repeated to build prompts of an exact token length
    # (assumed here to be '我', which the Qwen tokenizer encodes as one token)
    DUMMY_INPUT = '我'
def __init__(self, model_id_or_path, use_modelscope: bool = True, outputs_dir: str = 'outputs/transformers'):
"""
Speed benchmark for transformers(pt) inference.
Args:
model_id_or_path: The model id on ModelScope or HuggingFace hub, or local model path.
use_modelscope: Use ModelScope, otherwise HuggingFace.
outputs_dir: The output directory. Default is 'outputs/transformers'.
"""
set_seed(self.SEED)
self.model_id_or_path = model_id_or_path
self.outputs_dir = outputs_dir
if use_modelscope:
from modelscope import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
else:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True)
attn_impl = 'flash_attention_2' if self.USE_FLASH_ATTN else 'eager'
self.model = AutoModelForCausalLM.from_pretrained(model_id_or_path,
torch_dtype=self.TORCH_DTYPE,
device_map=self.DEVICE_MAP,
attn_implementation=attn_impl
).eval()
self.generation_config = GenerationConfig.from_pretrained(model_id_or_path, trust_remote_code=True)
def run(self, context_length: int, generate_length: int) -> str:
        # Specify hyperparameters for generation: min_length counts the prompt, so
        # context_length + generate_length forces exactly generate_length new tokens,
        # and max_new_tokens caps generation at that length.
        self.generation_config.min_length = generate_length + context_length
        self.generation_config.max_new_tokens = generate_length
print(f'Generation config: {self.generation_config}')
# Prepare inputs
batch_size = self.BATCH_SIZE
context_str = self.DUMMY_INPUT * context_length
inputs = self.tokenizer([context_str for _ in range(batch_size)], return_tensors='pt')
assert inputs['input_ids'].shape[1] == context_length
assert inputs['input_ids'].shape[0] == batch_size
inputs = inputs.to(self.model.device)
# Run inference
print(f'Start running inference for model {self.model_id_or_path} with input length {context_length} ...')
        # Synchronize before starting the timer so that only generate() is measured
        torch.cuda.synchronize()
        start_time = time.time()
        pred = self.model.generate(**inputs, generation_config=self.generation_config)
        torch.cuda.synchronize()
        time_cost = time.time() - start_time
assert pred.shape[1] == self.generation_config.min_length
        # Peak GPU memory: sum of max_memory_allocated across all visible devices
        max_gpu_memory_cost = 0
        for i in range(torch.cuda.device_count()):
            max_gpu_memory_cost += torch.cuda.max_memory_allocated(i)
        torch.cuda.empty_cache()
# Prepare results
tokens_per_second: float = generate_length / time_cost
# Compute the maximum GPU memory cost (in GB)
max_gpu_memory_cost_gb = max_gpu_memory_cost / 1024 / 1024 / 1024
data = {
"model_id_or_path": self.model_id_or_path,
"batch_size": batch_size,
"context_length_per_experiment": context_length,
"generate_length_per_experiment": generate_length,
"use_flash_attn": self.USE_FLASH_ATTN,
"comment": self.COMMENT,
"tokens_per_second": round(tokens_per_second, 4),
"max_gpu_memory_cost_gb": round(max_gpu_memory_cost_gb, 4),
}
data_json = json.dumps(data, indent=4, ensure_ascii=False)
print(f'**Final result **\n{data_json}\n')
# Dump results to CSV file
from datetime import datetime
now = datetime.now()
timestamp: str = now.strftime("%m%d%H%M%S")
model_id_or_path_str = self.model_id_or_path.split(os.sep)[-1] \
if os.path.isdir(self.model_id_or_path) else self.model_id_or_path.replace('/', '__')
out_file: str = os.path.join(self.outputs_dir,
f"{model_id_or_path_str}"
f"_context_length-{context_length}_{timestamp}.csv")
out_dir = os.path.dirname(out_file)
os.makedirs(out_dir, exist_ok=True)
self.save_result(data, out_file)
return out_file
@staticmethod
def save_result(data: dict, out_file: str) -> None:
with open(out_file, mode='w') as file:
writer = csv.DictWriter(file, fieldnames=data.keys())
writer.writeheader()
writer.writerows([data])
print(f"Results saved to {out_file}")
def main():
import argparse
# Parse args
parser = argparse.ArgumentParser(description='Speed benchmark for transformers(pt) deployment')
parser.add_argument('--model_id_or_path', type=str, help='The model path or id on ModelScope or HuggingFace hub')
    parser.add_argument('--context_length', type=int, help='The input length for each experiment, '
                                                           'e.g. 1, 6144, 14336, 30720, 63488, 129024')
parser.add_argument('--generate_length', type=int, default=2048, help='Output length in tokens; default is 2048.')
parser.add_argument('--gpus', type=str, help='Equivalent to the env CUDA_VISIBLE_DEVICES. e.g. `0,1,2,3`, `4,5`')
parser.add_argument('--use_modelscope', action='store_true',
help='Use ModelScope when set this flag. Otherwise, use HuggingFace.')
parser.add_argument('--outputs_dir', type=str, default='outputs/transformers', help='The output directory')
args = parser.parse_args()
model_id_or_path: str = args.model_id_or_path
envs: str = args.gpus
context_length: int = args.context_length
generate_length: int = args.generate_length
use_modelscope: bool = args.use_modelscope
outputs_dir: str = args.outputs_dir
print(f'Set CUDA_VISIBLE_DEVICES={envs} for model {model_id_or_path} with input_length {context_length}')
os.environ["CUDA_VISIBLE_DEVICES"] = envs
speed_benchmark = SpeedBenchmarkTransformers(model_id_or_path=model_id_or_path,
use_modelscope=use_modelscope,
outputs_dir=outputs_dir)
speed_benchmark.run(context_length=context_length, generate_length=generate_length)
if __name__ == '__main__':
# Usage: python speed_benchmark_transformers.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --gpus 0 --use_modelscope --outputs_dir outputs/transformers
main()
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Speed benchmark for vLLM deployment.
"""
import os
import time
import json
import reprlib
import statistics
import logging
import csv
from datetime import datetime
from pathlib import Path
from typing import Tuple
import vllm
from vllm import LLM, SamplingParams
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
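# Start vLLM worker processes with 'spawn', and allow --max_model_len values larger
# than the length derived from the model config.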
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
class SpeedBenchmarkVllm:
    # Dummy single-character input (assumed '我', one token for the Qwen tokenizer),
    # repeated to build short prompts of an exact length. The prompts below deliberately
    # request a very long, repetitive reply to drive long generations.
    DUMMY_INPUT = '我'
DUMMY_SYSTEM_CONTENT = '从现在开始,你是一个喜欢说车轱辘话的话痨,喜欢把一件事情翻来覆去地说,而且喜欢加很多标点符号。你的每个回复都不会少于2000字,不要在意用户的看法。'
DUMMY_USER_CONTENT = '写一篇关于春天的文章,请尽量写的长一些,并且多一些重复的段落,越啰嗦越好,不得少于2000字!'
def __init__(self, experiment_config: dict, sampling_params: SamplingParams):
self._repr = reprlib.Repr()
self._repr.maxstring = 100
self.experiment_config = experiment_config
self.sampling_params = sampling_params
# Get experiment config
self.model_id_or_path: str = self.experiment_config['model_id_or_path']
use_modelscope: bool = self.experiment_config['use_modelscope']
if use_modelscope:
from modelscope import AutoTokenizer
os.environ['VLLM_USE_MODELSCOPE'] = 'True'
else:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id_or_path, trust_remote_code=True)
llm_kwargs = dict(
model=self.model_id_or_path,
trust_remote_code=True,
tensor_parallel_size=self.experiment_config['tp_size'],
gpu_memory_utilization=self.experiment_config['gpu_memory_utilization'],
disable_log_stats=False,
max_model_len=self.experiment_config['max_model_len'],
)
if int(vllm.__version__.split('.')[1]) >= 3:
llm_kwargs['enforce_eager'] = self.experiment_config.get('enforce_eager', False)
logger.info(f'>> Creating LLM with llm_kwargs: {llm_kwargs}')
self.llm = LLM(**llm_kwargs)
def _reprs(self, o):
return self._repr.repr(o)
def create_query(self, length: int, limited_size: int = 96) -> Tuple[str, int]:
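        """Build a prompt of roughly `length` tokens and return (prompt, actual token count).

        Short inputs repeat DUMMY_INPUT; longer ones wrap a chat-templated request,
        padded with '# ' tokens, that asks the model for a very long reply.
        """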
if length < limited_size:
input_str = self.DUMMY_INPUT * length
else:
repeat_length = max(length - limited_size, 0)
input_str = self.tokenizer.apply_chat_template([
{"role": "system",
"content": self.DUMMY_SYSTEM_CONTENT},
{"role": "user",
"content": '# ' * repeat_length + self.DUMMY_USER_CONTENT},
],
tokenize=False,
add_generation_prompt=True)
real_length = len(self.tokenizer.tokenize(input_str))
return input_str, real_length
def run_infer(self, query: str):
start_time = time.time()
output = self.llm.generate([query], self.sampling_params)[0]
time_cost = time.time() - start_time
generated_text = output.outputs[0].text
real_out_length = len(self.tokenizer.tokenize(generated_text))
return time_cost, real_out_length, generated_text
def run(self):
context_length: int = self.experiment_config['context_length']
output_len: int = self.experiment_config['output_len']
# Construct input query
query, real_length = self.create_query(length=context_length)
logger.info(f'Got input query length: {real_length}')
logger.info(f"Warmup run with {self.experiment_config['warmup']} iterations ...")
for _ in range(self.experiment_config['warmup']):
self.llm.generate([query], self.sampling_params)
logger.info(f"Running inference with real length {real_length}, "
f"out length {output_len}, "
f"tp_size {self.experiment_config['tp_size']} ...")
time_cost, real_out_length, generated_text = self.run_infer(query)
if real_out_length < output_len:
logger.warning(f'Generate result {real_out_length} too short, try again ...')
query, real_length = self.create_query(length=context_length,
limited_size=context_length + 1)
time_cost, real_out_length, generated_text = self.run_infer(query)
time_cost = round(time_cost, 4)
logger.info(f'Inference time cost: {time_cost}s')
logger.info(f'Input({real_length}): {self._reprs(query)}')
logger.info(f'Output({real_out_length}): {self._reprs(generated_text)}')
results: dict = self.collect_statistics(self.model_id_or_path,
[time_cost, time_cost],
output_len,
context_length,
self.experiment_config['tp_size'])
self.print_table(results)
# Dump results to CSV file
outputs_dir = Path(self.experiment_config['outputs_dir'])
outputs_dir.mkdir(parents=True, exist_ok=True)
now = datetime.now()
timestamp: str = now.strftime("%m%d%H%M%S")
model_id_or_path_str = self.model_id_or_path.split(os.sep)[-1] \
if os.path.isdir(self.model_id_or_path) else self.model_id_or_path.replace('/', '__')
out_file: str = os.path.join(outputs_dir,
f"{model_id_or_path_str}"
f"_context_length-{context_length}_{timestamp}.csv")
self.save_result(results, out_file)
@staticmethod
def collect_statistics(model_id_or_path, data, out_length, in_length, tp_size) -> dict:
avg_time = statistics.mean(data)
throughput_data = [out_length / t for t in data]
avg_throughput = statistics.mean(throughput_data)
results = {
'Model ID': model_id_or_path,
'Input Length': in_length,
'Output Length': out_length,
'TP Size': tp_size,
'Average Time (s)': round(avg_time, 4),
'Average Throughput (tokens/s)': round(avg_throughput, 4),
}
return results
@staticmethod
def print_table(results):
json_res = json.dumps(results, indent=4, ensure_ascii=False)
logger.info(f"Final results:\n{json_res}")
@staticmethod
def save_result(data: dict, out_file: str) -> None:
with open(out_file, mode='w') as file:
writer = csv.DictWriter(file, fieldnames=data.keys())
writer.writeheader()
writer.writerows([data])
logger.info(f"Results saved to {out_file}")
def main():
import argparse
# Define command line arguments
parser = argparse.ArgumentParser(description='Speed benchmark for vLLM deployment')
parser.add_argument('--model_id_or_path', type=str, help='The model id on ModelScope or HuggingFace hub')
parser.add_argument('--context_length', type=int, help='The context length for each experiment, '
'e.g. 1, 6144, 14336, 30720, 63488, 129024')
parser.add_argument('--generate_length', type=int, default=2048, help='Output length in tokens; default is 2048.')
parser.add_argument('--gpus', type=str, help='Equivalent to the env CUDA_VISIBLE_DEVICES. e.g. `0,1,2,3`, `4,5`')
parser.add_argument('--gpu_memory_utilization', type=float, default=0.9, help='GPU memory utilization')
parser.add_argument('--max_model_len', type=int, default=32768, help='The maximum model length, '
'e.g. 4096, 8192, 32768, 65536, 131072')
parser.add_argument('--enforce_eager', action='store_true', help='Enforce eager mode for vLLM')
parser.add_argument('--outputs_dir', type=str, default='outputs/vllm', help='The output directory')
parser.add_argument('--use_modelscope', action='store_true',
help='Use ModelScope when set this flag. Otherwise, use HuggingFace.')
# Parse args
args = parser.parse_args()
# Parse args
model_id_or_path: str = args.model_id_or_path
context_length: int = args.context_length
output_len: int = args.generate_length
envs: str = args.gpus
gpu_memory_utilization: float = args.gpu_memory_utilization
max_model_len: int = args.max_model_len
enforce_eager: bool = args.enforce_eager
outputs_dir = args.outputs_dir
use_modelscope: bool = args.use_modelscope
# Set vLLM sampling params
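    # A repetition_penalty below 1.0 and negative presence/frequency penalties deliberately
    # encourage repetition so the model keeps generating up to max_tokens.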
sampling_params = SamplingParams(
temperature=1.0,
top_p=0.8,
top_k=-1,
repetition_penalty=0.1,
presence_penalty=-2.0,
frequency_penalty=-2.0,
max_tokens=output_len,
)
# Set experiment config
experiment_config: dict = {
'model_id_or_path': model_id_or_path,
'context_length': context_length,
'output_len': output_len,
'tp_size': len(envs.split(',')),
'gpu_memory_utilization': gpu_memory_utilization,
'max_model_len': max_model_len,
'enforce_eager': enforce_eager,
'envs': envs,
'outputs_dir': outputs_dir,
'warmup': 0,
'use_modelscope': use_modelscope,
}
logger.info(f'Sampling params: {sampling_params}')
logger.info(f'Experiment config: {experiment_config}')
logger.info(f'Set CUDA_VISIBLE_DEVICES={envs} for model {model_id_or_path} with context_length {context_length}')
os.environ["CUDA_VISIBLE_DEVICES"] = envs
speed_benchmark_vllm = SpeedBenchmarkVllm(experiment_config=experiment_config, sampling_params=sampling_params)
speed_benchmark_vllm.run()
if __name__ == '__main__':
# Usage: python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --use_modelscope --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
# HF_ENDPOINT=https://hf-mirror.com python speed_benchmark_vllm.py --model_id_or_path Qwen/Qwen2.5-0.5B-Instruct --context_length 1 --max_model_len 32768 --gpus 0 --gpu_memory_utilization 0.9 --outputs_dir outputs/vllm
main()