Skip to content

Commit 95d7e20

Browse files
authored
feat: manually set tensor-parallel-size (#92)
1 parent 341f9f6 commit 95d7e20

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

run-vllm.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,16 @@ def __call__(
         m = params["_max_tokens"]
         kwargs["max_num_batched_tokens"] = m
         kwargs["max_model_len"] = min(m, model_max_tokens or m, model_seq_length or m)
-        tensor_parallel_size = math.gcd(
-            torch.cuda.device_count(),
-            math.gcd(
-                getattr(config, "num_attention_heads", 720720),
-                getattr(config, "num_key_value_heads", 720720),
-            ),
-        )
+        if kwargs["tensor_parallel_size"] > 0:
+            tensor_parallel_size = kwargs["tensor_parallel_size"]
+        else:
+            tensor_parallel_size = math.gcd(
+                torch.cuda.device_count(),
+                math.gcd(
+                    getattr(config, "num_attention_heads", 720720),
+                    getattr(config, "num_key_value_heads", 720720),
+                ),
+            )
         self.llm = LLM(
             model=model,
             tensor_parallel_size=tensor_parallel_size,
@@ -146,6 +149,9 @@ def __call__(
 parser.add_argument("--max-tokens", type=int, default=0, help="Maximum number of tokens.")
 parser.add_argument("--prefix", type=str, default="", help="Prefix for the prompt.")
 parser.add_argument("--dtype", type=str, default="", help="Data type.")
+parser.add_argument(
+    "--tensor-parallel-size", type=int, default=-1, help="Tensor Parallel Size."
+)
 parser.add_argument("--quantization", type=str, default=None, help="Quantization method.")
 args = parser.parse_args()
 kwargs = {}
@@ -162,6 +168,9 @@ def __call__(
 for t in chat_templates:
     if args.model in t["models"]:
         kwargs["chat_template"] = t["chat_template"]
+
+kwargs["tensor_parallel_size"] = args.tensor_parallel_size
+
 pfgen.run_tasks(
     args.mode,
     Callback(),

0 commit comments

Comments (0)