diff --git a/language/llama3.1-8b/SUT_VLLM.py b/language/llama3.1-8b/SUT_VLLM.py
index 94ee14abdd..914ce4fea6 100644
--- a/language/llama3.1-8b/SUT_VLLM.py
+++ b/language/llama3.1-8b/SUT_VLLM.py
@@ -1,3 +1,19 @@
+# Copyright 2025 The MLPerf Authors. All Rights Reserved.
+# Copyright 2025 Arm Limited and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
 import asyncio
 import os
 import time
@@ -49,9 +65,6 @@ def __init__(
         self.dtype = dtype
         self.tensor_parallel_size = tensor_parallel_size
 
-        if not torch.cuda.is_available():
-            assert False, "torch gpu is not available, exiting..."
-
         self.dataset_path = dataset_path
         self.data_object = Dataset(
             self.model_path,
@@ -117,10 +130,11 @@ def process_queries(self):
             # self.data_object.input[q.index] for q in qitem]
             # for in_text in input_text_tensor:
             #     log.info(f"Input: {in_text}")
+            token_prompts = [TokensPrompt(prompt_token_ids=ids) for ids in input_ids_tensor]
 
             tik2 = time.time()
             outputs = self.model.generate(
-                prompt_token_ids=input_ids_tensor, sampling_params=self.sampling_params
+                prompts=token_prompts, sampling_params=self.sampling_params
             )
             pred_output_tokens = []
             for output in outputs:
@@ -162,6 +176,8 @@ def load_model(self):
             self.model_path,
             dtype=self.dtype,
             tensor_parallel_size=self.tensor_parallel_size,
+            max_model_len=4096,
+            max_num_batched_tokens=4096,
         )
 
         log.info("Loaded model")
diff --git a/language/llama3.1-8b/requirements.txt b/language/llama3.1-8b/requirements.txt
index a62f68e7ba..4d468be845 100644
--- a/language/llama3.1-8b/requirements.txt
+++ b/language/llama3.1-8b/requirements.txt
@@ -1,9 +1,9 @@
-transformers==4.46.2
+transformers==4.57.1
 nltk==3.8.1
 evaluate==0.4.0
 absl-py==1.4.0
 rouge-score==0.1.2
 sentencepiece==0.2.0
 accelerate==0.21.0
-vllm==0.6.3
+vllm==0.11.0
 pybind11==2.10.4