diff --git a/run-vllm.py b/run-vllm.py index 888ffc8..a53f7a3 100644 --- a/run-vllm.py +++ b/run-vllm.py @@ -33,8 +33,8 @@ def __call__( m = params["_max_tokens"] kwargs["max_num_batched_tokens"] = m kwargs["max_model_len"] = min(m, model_max_tokens or m, model_seq_length or m) - if kwargs["tensor_parallel_size"] > 0: - tensor_parallel_size = kwargs["tensor_parallel_size"] + if params["tensor_parallel_size"] > 0: + tensor_parallel_size = params["tensor_parallel_size"] else: tensor_parallel_size = math.gcd( torch.cuda.device_count(),