238 changes: 238 additions & 0 deletions test/bench/bench_vllm.py
@@ -0,0 +1,238 @@
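"""Offline throughput benchmark for vLLM.

Sweeps batch size, input length, and output length over synthetic
random-token prompts and reports prefill and decode throughput
(tokens/s) for every combination.
"""
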
import argparse
import itertools
import json
import random
import time

import torch


from vllm.engine.llm_engine import LLMEngine
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import TokensPrompt
from vllm.sampling_params import SamplingParams

# Uncomment to silence vLLM's own logging during the benchmark:
# import os
# import logging
# os.environ["VLLM_LOGGING_LEVEL"] = "ERROR"
# logging.getLogger("vllm").setLevel(logging.ERROR)


def parse_list(value: str):
"""Parse parse_list argument: can be a single int or a list of ints.

Examples:
"1" -> 1
"[1,2,4]" -> [1, 2, 4]
"1,2,4" -> [1, 2, 4]
"""
value = value.strip()
    # Try to parse as a JSON list first
    if value.startswith("[") and value.endswith("]"):
        try:
            result = json.loads(value)
            if isinstance(result, list):
                return [int(x) for x in result]
        except (json.JSONDecodeError, ValueError):
            pass

# Try to parse as comma-separated values
if "," in value:
try:
return [int(x.strip()) for x in value.split(",")]
except ValueError:
pass

# Try to parse as a single integer
try:
return [int(value)]
except ValueError:
raise argparse.ArgumentTypeError(
f"batch-size must be an int or list[int], got: {value}"
)


def run_one_case(
engine: LLMEngine,
batch_size: int,
input_len: int,
output_len: int,
vocab_size: int,
):
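    """Benchmark one (batch_size, input_len, output_len) configuration.

    Returns a dict with the configuration and the measured prefill and
    decode throughput in tokens per second.
    """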
# ------------------------------------------------------------
# 1. Random input token IDs
# ------------------------------------------------------------
input_ids_list = [
[random.randint(0, vocab_size - 1) for _ in range(input_len)]
for _ in range(batch_size)
]

sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=1.0,
top_p=0.9,
top_k=50,
)

request_ids = []
for i, input_ids in enumerate(input_ids_list):
rid = f"req_{i}"
engine.add_request(
request_id=rid,
prompt=TokensPrompt(prompt_token_ids=input_ids),
params=sampling_params,
)
request_ids.append(rid)

# ------------------------------------------------------------
# 2. Run until first decode token appears for all requests (prefill timing)
# ------------------------------------------------------------
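    # A request leaves `pending` once its first generated token arrives.
    # Requests that finish prefill early start decoding while others are
    # still prefilling; those early decode tokens are tallied in
    # `pre_decode` and excluded from the decode-phase count below.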
t0 = time.perf_counter()
    pre_decode = 0
    pending = set(request_ids)
while pending:
outputs = engine.step()
for out in outputs:
if len(out.outputs[0].token_ids) > 0:
if out.request_id in pending:
pending.remove(out.request_id)
else:
pre_decode += 1
torch.cuda.synchronize()
t1 = time.perf_counter()

prefill_time = t1 - t0
prefill_tokens = batch_size * input_len

# ------------------------------------------------------------
# 3. Decode until all requests finish
# ------------------------------------------------------------
decode_start = time.perf_counter()

while engine.has_unfinished_requests():
outputs = engine.step()

torch.cuda.synchronize()
decode_end = time.perf_counter()

    decode_time = decode_end - decode_start
    # Each request's first token was produced during the prefill-timed phase
    # (hence output_len - 1); pre_decode removes the decode tokens that were
    # also generated during that phase.
    decode_tokens = batch_size * (output_len - 1) - pre_decode

return {
"batch_size": batch_size,
"input_len": input_len,
"output_len": output_len,
"prefill_tput": prefill_tokens / prefill_time,
"decode_tput": decode_tokens / decode_time,
}


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
parser.add_argument(
"--batch-size",
type=parse_list,
default=[1],
help=(
"number of prompts in a batch (int or list, e.g. '1', '1,2,4', '[1,2,4]')"
),
)
parser.add_argument(
"--input-len",
type=parse_list,
default=[256],
help="input sequence length(s)",
)
parser.add_argument(
"--output-len",
type=parse_list,
default=[256],
help="output sequence length(s)",
)
parser.add_argument(
"--tensor-parallel-size",
"--tp",
type=int,
default=1,
help="total rank for tensor parallel",
)
parser.add_argument("--dtype", type=str, default="auto")
parser.add_argument("--num-iters-warmup", type=int, default=2)
args = parser.parse_args()

# ------------------------------------------------------------
# Engine init (TP supported here)
# ------------------------------------------------------------
    engine_args = EngineArgs(
        model=args.model,
        dtype=args.dtype,
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=True,
        distributed_executor_backend="mp",  # multiprocessing workers for TP
        # Keep prefill and decode in separate steps so run_one_case() can
        # time the two phases independently.
        enable_chunked_prefill=False,
    )

engine = LLMEngine.from_engine_args(engine_args)
vocab_size = engine.model_config.get_vocab_size()
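    # The vocab size bounds the random token IDs drawn in run_one_case(),
    # so no tokenizer round-trip is needed.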

# ------------------------------------------------------------
# Sweep all combinations
# ------------------------------------------------------------
print("\n=== Running benchmark ===")
results = []

try:
for bs, il, ol in itertools.product(
args.batch_size, args.input_len, args.output_len
):
            # Warmup runs so one-time startup costs don't skew the measurement
for _ in range(args.num_iters_warmup):
run_one_case(
engine,
batch_size=bs,
input_len=il,
output_len=ol,
vocab_size=vocab_size,
)

res = run_one_case(
engine,
batch_size=bs,
input_len=il,
output_len=ol,
vocab_size=vocab_size,
)
results.append(res)

            print(
                f"[TP={args.tensor_parallel_size} | "
                f"bs={bs} in={il} out={ol}] "
                f"prefill={res['prefill_tput']:.2f} tok/s | "
                f"decode={res['decode_tput']:.2f} tok/s"
            )
    except Exception as e:
        print(f"Error occurred: {e}")
# ------------------------------------------------------------
# Summary
# ------------------------------------------------------------
print("\n=== Summary ===")
print("bs in_len out_len prefill_tok/s decode_tok/s")
for r in results:
print(
f"{r['batch_size']:3d} "
f"{r['input_len']:7d} "
f"{r['output_len']:8d} "
f"{r['prefill_tput']:14.2f} "
f"{r['decode_tput']:14.2f}"
)


if __name__ == "__main__":
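    # Fixed seeds keep the synthetic prompts reproducible across runs.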
torch.manual_seed(0)
random.seed(0)
main()