diff --git a/test/bench/bench_vllm.py b/test/bench/bench_vllm.py
new file mode 100644
index 00000000..4338a3d2
--- /dev/null
+++ b/test/bench/bench_vllm.py
@@ -0,0 +1,238 @@
+import argparse
+import itertools
+import json
+import random
+import time
+
+import torch
+
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.inputs import TokensPrompt
+from vllm.sampling_params import SamplingParams
+
+# import os
+# import logging
+# os.environ["VLLM_LOGGING_LEVEL"] = "ERROR"
+# logging.getLogger("vllm").setLevel(logging.ERROR)
+
+
+def parse_list(value: str):
+    """Parse an argument that is either a single int or a list of ints.
+
+    Examples:
+        "1"       -> [1]
+        "[1,2,4]" -> [1, 2, 4]
+        "1,2,4"   -> [1, 2, 4]
+    """
+    value = value.strip()
+    # Try to parse as a JSON list first.
+    if value.startswith("[") and value.endswith("]"):
+        try:
+            result = json.loads(value)
+            if isinstance(result, list):
+                return [int(x) for x in result]
+        except (json.JSONDecodeError, ValueError):
+            pass
+
+    # Try to parse as comma-separated values.
+    if "," in value:
+        try:
+            return [int(x.strip()) for x in value.split(",")]
+        except ValueError:
+            pass
+
+    # Fall back to a single integer.
+    try:
+        return [int(value)]
+    except ValueError:
+        raise argparse.ArgumentTypeError(
+            f"expected an int or a list of ints, got: {value}"
+        )
+
+
+def run_one_case(
+    engine: LLMEngine,
+    batch_size: int,
+    input_len: int,
+    output_len: int,
+    vocab_size: int,
+):
+    # ------------------------------------------------------------
+    # 1. Random input token IDs
+    # ------------------------------------------------------------
+    input_ids_list = [
+        [random.randint(0, vocab_size - 1) for _ in range(input_len)]
+        for _ in range(batch_size)
+    ]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=1.0,
+        top_p=0.9,
+        top_k=50,
+    )
+
+    request_ids = []
+    for i, input_ids in enumerate(input_ids_list):
+        rid = f"req_{i}"
+        engine.add_request(
+            request_id=rid,
+            prompt=TokensPrompt(prompt_token_ids=input_ids),
+            params=sampling_params,
+        )
+        request_ids.append(rid)
+
+    # ------------------------------------------------------------
+    # 2. Step until the first decode token appears for all requests (prefill timing)
+    # ------------------------------------------------------------
+    t0 = time.perf_counter()
+    pre_decode = 0  # decode tokens emitted while other requests are still prefilling
+    pending = set(request_ids)
+    while pending:
+        outputs = engine.step()
+        for out in outputs:
+            if len(out.outputs[0].token_ids) > 0:
+                if out.request_id in pending:
+                    pending.remove(out.request_id)
+                else:
+                    pre_decode += 1
+    torch.cuda.synchronize()
+    t1 = time.perf_counter()
+
+    prefill_time = t1 - t0
+    prefill_tokens = batch_size * input_len
+
+    # ------------------------------------------------------------
+    # 3. Decode until all requests finish
+    # ------------------------------------------------------------
+    decode_start = time.perf_counter()
+
+    while engine.has_unfinished_requests():
+        engine.step()
+
+    torch.cuda.synchronize()
+    decode_end = time.perf_counter()
+
+    decode_time = decode_end - decode_start
+    # Each request emitted its first token during the prefill phase, and
+    # `pre_decode` extra tokens were mixed into prefill steps; exclude both.
+    decode_tokens = batch_size * (output_len - 1) - pre_decode
+
+    return {
+        "batch_size": batch_size,
+        "input_len": input_len,
+        "output_len": output_len,
+        "prefill_tput": prefill_tokens / prefill_time,
+        "decode_tput": decode_tokens / decode_time,
+    }
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, required=True)
+    parser.add_argument(
+        "--batch-size",
+        type=parse_list,
+        default=[1],
+        help=(
+            "number of prompts in a batch (int or list, e.g. '1', '1,2,4', '[1,2,4]')"
+        ),
+    )
+    parser.add_argument(
+        "--input-len",
+        type=parse_list,
+        default=[256],
+        help="input sequence length(s)",
+    )
+    parser.add_argument(
+        "--output-len",
+        type=parse_list,
+        default=[256],
+        help="output sequence length(s)",
+    )
+    parser.add_argument(
+        "--tensor-parallel-size",
+        "--tp",
+        type=int,
+        default=1,
+        help="number of tensor-parallel ranks",
+    )
+    parser.add_argument("--dtype", type=str, default="auto")
+    parser.add_argument("--num-iters-warmup", type=int, default=2)
+    args = parser.parse_args()
+
+    # ------------------------------------------------------------
+    # Engine init (TP supported here)
+    # ------------------------------------------------------------
+    engine_args = EngineArgs(
+        model=args.model,
+        dtype=args.dtype,
+        tensor_parallel_size=args.tensor_parallel_size,
+        trust_remote_code=True,
+        distributed_executor_backend="mp",
+        enable_chunked_prefill=False,
+    )
+
+    engine = LLMEngine.from_engine_args(engine_args)
+    vocab_size = engine.model_config.get_vocab_size()
+
+    # ------------------------------------------------------------
+    # Sweep all combinations
+    # ------------------------------------------------------------
+    print("\n=== Running benchmark ===")
+    results = []
+
+    try:
+        for bs, il, ol in itertools.product(
+            args.batch_size, args.input_len, args.output_len
+        ):
+            # Warmup
+            for _ in range(args.num_iters_warmup):
+                run_one_case(
+                    engine,
+                    batch_size=bs,
+                    input_len=il,
+                    output_len=ol,
+                    vocab_size=vocab_size,
+                )
+
+            res = run_one_case(
+                engine,
+                batch_size=bs,
+                input_len=il,
+                output_len=ol,
+                vocab_size=vocab_size,
+            )
+            results.append(res)
+
+            print(
+                f"[TP={args.tensor_parallel_size} | "
+                f"bs={bs} in={il} out={ol}] "
+                f"prefill={res['prefill_tput']:.2f} tok/s | "
+                f"decode={res['decode_tput']:.2f} tok/s"
+            )
+    except Exception as e:
+        print(f"Error occurred: {e}")
+
+    # ------------------------------------------------------------
+    # Summary
+    # ------------------------------------------------------------
+    print("\n=== Summary ===")
+    print(f"{'bs':>3} {'in_len':>7} {'out_len':>8} {'prefill_tok/s':>14} {'decode_tok/s':>14}")
+    for r in results:
+        print(
+            f"{r['batch_size']:3d} "
+            f"{r['input_len']:7d} "
+            f"{r['output_len']:8d} "
+            f"{r['prefill_tput']:14.2f} "
+            f"{r['decode_tput']:14.2f}"
+        )
+
+
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    random.seed(0)
+    main()
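A typical invocation of the new script might look like the following sketch; the model path is a placeholder, and `--batch-size`, `--input-len`, and `--output-len` each accept a single int or a list, as parsed by `parse_list` above:

    # Sweep batch sizes 1, 2, 4 at a fixed sequence shape on a single GPU.
    python test/bench/bench_vllm.py --model <model-name-or-path> --tp 1 \
        --batch-size 1,2,4 --input-len 256 --output-len 128

Each (batch_size, input_len, output_len) combination is warmed up `--num-iters-warmup` times before the measured run, and per-case prefill/decode throughputs are printed as the sweep progresses, followed by the summary table.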