@@ -33,13 +33,16 @@ def __call__(
         m = params["_max_tokens"]
         kwargs["max_num_batched_tokens"] = m
         kwargs["max_model_len"] = min(m, model_max_tokens or m, model_seq_length or m)
-        tensor_parallel_size = math.gcd(
-            torch.cuda.device_count(),
-            math.gcd(
-                getattr(config, "num_attention_heads", 720720),
-                getattr(config, "num_key_value_heads", 720720),
-            ),
-        )
+        if kwargs["tensor_parallel_size"] > 0:
+            tensor_parallel_size = kwargs["tensor_parallel_size"]
+        else:
+            tensor_parallel_size = math.gcd(
+                torch.cuda.device_count(),
+                math.gcd(
+                    getattr(config, "num_attention_heads", 720720),
+                    getattr(config, "num_key_value_heads", 720720),
+                ),
+            )
         self.llm = LLM(
             model=model,
             tensor_parallel_size=tensor_parallel_size,
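For context on the branch this hunk adds: when no explicit size is given, the pre-existing heuristic picks the largest tensor-parallel degree that divides both the visible GPU count and the model's head counts (vLLM requires the attention head count to be divisible by the TP degree). The `720720` fallback is lcm(1..16), so a missing config attribute never constrains the result for up to 16 GPUs. A minimal standalone sketch of that heuristic, with a helper name that is mine, not from the patch:

```python
import math

def auto_tensor_parallel_size(num_gpus: int, num_heads: int, num_kv_heads: int) -> int:
    """Largest TP degree dividing the GPU count and both head counts."""
    # gcd picks the biggest degree that evenly splits attention and KV heads
    # across the available GPUs; 720720 (= lcm(1..16)) acts as a
    # "divisible by any reasonable GPU count" default when a config field is missing.
    return math.gcd(num_gpus, math.gcd(num_heads, num_kv_heads))

assert auto_tensor_parallel_size(8, 32, 8) == 8    # all 8 GPUs usable
assert auto_tensor_parallel_size(6, 32, 8) == 2    # 6 GPUs, but only TP=2 divides the heads
assert auto_tensor_parallel_size(4, 720720, 720720) == 4  # missing-config fallback
```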
@@ -146,6 +149,9 @@ def __call__(
     parser.add_argument("--max-tokens", type=int, default=0, help="Maximum number of tokens.")
     parser.add_argument("--prefix", type=str, default="", help="Prefix for the prompt.")
     parser.add_argument("--dtype", type=str, default="", help="Data type.")
+    parser.add_argument(
+        "--tensor-parallel-size", type=int, default=-1, help="Tensor Parallel Size."
+    )
     parser.add_argument("--quantization", type=str, default=None, help="Quantization method.")
     args = parser.parse_args()
     kwargs = {}
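The `-1` default here is a sentinel meaning "auto": since the constructor above checks `kwargs["tensor_parallel_size"] > 0`, any non-positive value falls through to the GCD heuristic, while a positive value is taken verbatim. A sketch of that precedence (the helper name is hypothetical):

```python
def resolve_tp_size(requested: int, autodetected: int) -> int:
    # A positive --tensor-parallel-size wins; the -1 default means auto-detect.
    return requested if requested > 0 else autodetected
```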
@@ -162,6 +168,9 @@ def __call__(
     for t in chat_templates:
         if args.model in t["models"]:
             kwargs["chat_template"] = t["chat_template"]
+
+    kwargs["tensor_parallel_size"] = args.tensor_parallel_size
+
     pfgen.run_tasks(
         args.mode,
         Callback(),
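Taken together, the three hunks thread the new flag from the CLI into the constructor: omitting `--tensor-parallel-size` preserves the old auto-detected behavior, while passing, say, `--tensor-parallel-size 2` forces sharding across exactly two GPUs even when more are visible, assuming the requested degree divides the model's attention head count, which vLLM validates at engine startup.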