diff --git a/language/llama2-70b/CONTRIBUTING.md b/language/llama2-70b/CONTRIBUTING.md
new file mode 100644
index 0000000000..2863937713
--- /dev/null
+++ b/language/llama2-70b/CONTRIBUTING.md
@@ -0,0 +1,9 @@
+# Contributing
+
+## Unit Tests
+
+To run the unit tests for the LLaMA 2 70B implementation, install the development dependencies:
+
+```bash
+pip install -r requirements-dev.txt
+```
diff --git a/language/llama2-70b/README.md b/language/llama2-70b/README.md
index 0f604d4f3d..359eb1baba 100644
--- a/language/llama2-70b/README.md
+++ b/language/llama2-70b/README.md
@@ -25,7 +25,7 @@ conda activate llama2-70b
 # Install packages
 conda install pybind11==2.10.4 -c conda-forge -y
 python -m pip install torch==2.2.0.dev20231006+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
-pip install transformers==4.31.0 nltk==3.8.1 evaluate==0.4.0 absl-py==1.4.0 rouge-score==0.1.2 sentencepiece==0.1.99 accelerate==0.21.0
+pip install transformers==4.31.0 nltk==3.8.1 evaluate==0.4.0 absl-py==1.4.0 rouge-score==0.1.2 sentencepiece==0.1.99 accelerate==1.11.0 httpx==0.28.1 more_itertools==10.8.0
 
 export CUR_DIR=${PWD}
 cd <inference-repo-root>/loadgen
@@ -187,6 +187,23 @@ python3 -u main.py --scenario Offline \
                 --device cuda:0 2>&1 | tee offline_performance_log.log
 ```
 
+For models hosted behind an OpenAI-compatible LLM API endpoint (e.g. via vLLM or TensorRT-LLM):
+
+```
+python3 -u main.py --scenario Offline \
+        --vllm \
+        --api-model-name ${MODEL_NAME} \
+        --api-server ${API_BASE} \
+        --model-path ${CHECKPOINT_PATH} \
+        --user-conf user.conf \
+        --total-sample-count 24576 \
+        --dataset-path ${DATASET_PATH} \
+        --output-log-dir offline-logs
+```
+
+- `${API_BASE}` is the base URL of the OpenAI-compatible endpoint, e.g. `http://server1:8000/`
+- **Multinode**: multiple LLM API endpoints can be provided by specifying `--api-server` multiple times.
+
 ### Server
 ```
 python -u main.py --scenario Server \
@@ -199,7 +216,7 @@ python -u main.py --scenario Server \
                 --output-log-dir server-logs
 ```
 
-The ServerSUT was not tested for GPU runs.
+The ServerSUT was not tested for GPU or LLM API runs.
 
 ## Run Accuracy Benchmarks
 
@@ -210,6 +227,7 @@ OUTPUT_LOG_DIR=offline-accuracy-logs
 
 mkdir -p "run_outputs"  # The script will dump all the outputs to 'run_outputs'.
 
+# for normal runs:
 python -u main.py --scenario Offline \
         --model-path ${CHECKPOINT_PATH} \
         --accuracy \
@@ -241,6 +259,20 @@ python consolidate_results.py --dataset-path ${DATASET_PATH} --model-dir ${CHECK
 
 For the GPU run - The above steps have been automated in `run_accuracy.sh`. You can also modify this script to use `--device cpu` to adapt it to a CPU-only run.
 
+For models hosted behind an OpenAI-compatible LLM API endpoint,
+replace the `python -u main.py` command in the normal run instructions with:
+```sh
+python3 -u main.py --scenario Offline \
+        --vllm \
+        --api-model-name ${MODEL_NAME} \
+        --api-server ${API_BASE} \
+        --model-path ${CHECKPOINT_PATH} \
+        --user-conf user.conf \
+        --total-sample-count 24576 \
+        --dataset-path ${DATASET_PATH} \
+        --output-log-dir ${OUTPUT_LOG_DIR} \
+        --accuracy
+```
 
 ### Server
 ```
diff --git a/language/llama2-70b/SUT_API.py b/language/llama2-70b/SUT_API.py
index 0b1ebd0c98..fab3ded0c4 100644
--- a/language/llama2-70b/SUT_API.py
+++ b/language/llama2-70b/SUT_API.py
@@ -1,32 +1,23 @@
-import os
-import time
-import numpy as np
 import array
-import torch
-from torch.nn.functional import pad
-from torch.utils.data import DataLoader
-from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
-from transformers.generation.streamers import BaseStreamer
-
+import asyncio
 import json
+import logging
 import pickle
-import time
-import threading
-import tqdm
 import queue
-
-import logging
-from typing import TYPE_CHECKING, Optional, List
+import threading
+import time
+import traceback
 from pathlib import Path
 
+import httpx
+import mlperf_loadgen as lg
 import more_itertools as mit
-from concurrent.futures.thread import ThreadPoolExecutor
-
+import numpy as np
 import requests
-from urllib3.exceptions import InsecureRequestWarning
-
-import mlperf_loadgen as lg
+import torch
 from dataset import Dataset
+from transformers import AutoTokenizer
+from transformers.generation.streamers import BaseStreamer
 
 logging.basicConfig(level=logging.INFO)
 log = logging.getLogger("Llama-70B-SUT")
@@ -109,7 +100,7 @@ def __init__(
         self.device = device
         self.api_servers = []
         if api_server:
-            self.api_servers.append(api_server)
+            self.api_servers.extend(api_server)
         self.api_model_name = api_model_name
         self.device = device
 
@@ -176,144 +167,161 @@ def stop(self):
         for worker in self.worker_threads:
             worker.join()
 
-    def query_api_vllm(self, inputs, idx):
+    async def query_batch(
+        self, http: httpx.AsyncClient, prompt: list[list[int]], api_server: str
+    ) -> list[str]:
+        """Query an LLM API server to get output completions for the given input prompt batch.
+
+        Args:
+            http: httpx AsyncClient for making HTTP requests.
+            prompt: Batch of input prompt token lists to be sent to the API server.
+            api_server: URL of the API server to which the request is to be sent.
+        Returns:
+            Batch of output completion texts from the LLM API server.
+        """
         headers = {
             "Content-Type": "application/json",
         }
 
         json_data = {
             "model": self.api_model_name,
-            "prompt": inputs,
+            "prompt": prompt,
             "min_tokens": 1,
             "max_tokens": 1024,
+            "n": 1,
+            "temperature": 1.0,
+            "top_p": 0.001,
         }
-        response_code = 0
-        print(f"Server path {self.api_servers[idx]}/v1/completions")
-        while response_code != 200:
-            try:
-                response = requests.post(
-                    f"{self.api_servers[idx]}/v1/completions",
-                    headers=headers,
-                    json=json_data,
-                    verify=False,
-                )
-                response_code = response.status_code
-            except Exception as e:
-                print(e)
-                print("connection failure")
-                break
-        return [resp["text"] for resp in json.loads(response.text)["choices"]]
+        try:
+            print(
+                f"query_batch: Sending prompts to API server: n_prompts={len(prompt)} api_server={api_server}"
+            )
+            response = await http.post(
+                f"{api_server}/v1/completions",
+                headers=headers,
+                json=json_data,
+            )
+            completions = [c["text"] for c in json.loads(response.text)["choices"]]
+            print(
+                f"query_batch: Received completions from API server: n_prompts={len(prompt)} api_server={api_server}"
+            )
+        except Exception as e:
+            # log exception trace: necessary as mlperf swallows exceptions
+            print("[ERROR] Exception occurred while querying API server:")
+            traceback.print_exception(e)
+
+            completions = []
+        return completions
+
+    async def query_servers(
+        self, http: httpx.AsyncClient, prompts: list[list[int]]
+    ) -> list[str]:
+        """Query the LLM API servers to get output completions for the given input prompts.
+
+        Args:
+            http: httpx AsyncClient for making HTTP requests.
+            prompts: List of input prompt token lists to be sent to the API servers.
+        Returns:
+            List of output completion texts, one for each prompt in the given prompts.
+        """
+        # subdivide full prompts load among servers for even distribution
+        print(
+            "query_servers: Distributing prompts over API servers: "
+            f"n_prompts={len(prompts)} n_servers={len(self.api_servers)}"
+        )
+
+        promises = []
+        for api_server, server_prompts in zip(
+            self.api_servers, mit.divide(len(self.api_servers), prompts)
+        ):
+            promises.append(self.query_batch(http, list(server_prompts), api_server))
+
+        outputs = [o for outs in await asyncio.gather(*promises) for o in outs]
 
-    def api_action_handler(self, chunk, server_idx):
-        output = self.query_api_vllm(chunk, server_idx)
-        return output
+        return outputs
 
     def process_queries(self):
         """Processor of the queued queries.
User may choose to add batching logic""" - while True: - qitem = self.query_queue.get() - if qitem is None: - break - - query_ids = [q.index for q in qitem] - - fname = "q" + "_".join([str(i) for i in query_ids]) - fname = f"run_outputs/{fname}.pkl" - _p = Path(fname) - if self.use_cached_outputs and _p.exists(): - # Read cache - with _p.open(mode="rb") as f: - d = pickle.load(f) - processed_output = d["outputs"] - tik1 = None - tik2 = None - tik3 = None - tok = None - else: - # Construct / collate batch - max_seq_len = 1024 - - tik1 = time.time() - - # OpenAI-API servers don't require padding and can take input tokens - # directly, so we build our input_ids_tensor as a jagged list - input_ids_tensor = [] - for q in qitem: - # input_ids_tensor.append(self.data_object.input_ids[q.index].tolist()) - input_ids_tensor += self.data_object.input_ids[q.index].tolist( - ) - - # NOTE(mgoin): I don't think this has to be a torch tensor - # input_ids_tensor = torch.cat(input_ids_tensor) - - # print(input_ids_tensor) - - assert len(input_ids_tensor) <= self.batch_size - - tik2 = time.time() - - # NOTE(mgoin): I don't think threading is necessary since we are submitting all queries in one request - # The API server should take care of mini-batches and - # scheduling - if self.api_servers: - """ - decoded = self.tokenizer.batch_decode(input_ids_tensor) - cleaned = [entry.replace('','').replace('','') for entry in decoded] - cleaned_chunks = [list(c) for c in mit.divide(len(self.api_servers), cleaned)] - """ - cleaned_chunks = [input_ids_tensor] - with ThreadPoolExecutor( - max_workers=len(self.api_servers) - ) as executor: - # needs to be tested - output_chunks = list( - executor.map( - self.api_action_handler, - cleaned_chunks, - range(len(self.api_servers)), + async def process(): + # init common http client to take advantage of connection pooling + async with httpx.AsyncClient( + verify=False, + # 1hr timeout + timeout=httpx.Timeout(3600), + ) as http: + while True: + qitem = self.query_queue.get() + if qitem is None: + break + + query_ids = [q.index for q in qitem] + + fname = "q" + "_".join([str(i) for i in query_ids]) + fname = f"run_outputs/{fname}.pkl" + _p = Path(fname) + if self.use_cached_outputs and _p.exists(): + # Read cache + with _p.open(mode="rb") as f: + d = pickle.load(f) + processed_output = d["outputs"] + tik1 = None + tik2 = None + tik3 = None + tok = None + else: + tik1 = time.time() + + # OpenAI-API servers don't require padding and can take input tokens + # directly, so we build our input_ids_tensor as a jagged list + input_ids_tensor = [] + for q in qitem: + input_ids_tensor += self.data_object.input_ids[ + q.index + ].tolist() + # collect prompt tokens for selected query ids + tik2 = time.time() + + # NOTE(mgoin): I don't think threading is necessary since we are submitting all queries in one request + # The API server should take care of mini-batches and scheduling + if len(self.api_servers) > 0: + outputs = await self.query_servers(http, input_ids_tensor) + else: + print( + "Error: Specify at least one API to which the request is to be sent!" ) - ) - output = [] - for row in output_chunks: - output += row - else: - print( - "Error: Specify at least one API to which the request is to be sent!" 
- ) - exit(1) - - tik3 = time.time() - - processed_output = self.tokenizer(output)["input_ids"] - # for i in range(len(qitem)): - for i in range(len(processed_output)): - # NOTE(mgoin): Not optimal to make numpy arrays just to - # serialize - unpadded = np.array(processed_output[i]) - n_tokens = unpadded.shape[0] - response_array = array.array("B", unpadded.tobytes()) - bi = response_array.buffer_info() - response = [ - lg.QuerySampleResponse( - qitem[i].id, - bi[0], - bi[1], - n_tokens)] - lg.QuerySamplesComplete(response) - - tok = time.time() - - with self.sample_counter_lock: - self.sample_counter += len(qitem) - print(f"Samples run: {self.sample_counter}") - if tik1: - print(f"\tBatchMaker time: {tik2 - tik1}") - print(f"\tInference time: {tik3 - tik2}") - print(f"\tPostprocess time: {tok - tik3}") - print(f"\t==== Total time: {tok - tik1}") - else: - print(f"\tLoaded from cache: {_p}") + exit(1) + + tik3 = time.time() + + processed_output = self.tokenizer(outputs)["input_ids"] + # for i in range(len(qitem)): + for i in range(len(processed_output)): + # NOTE(mgoin): Not optimal to make numpy arrays just to + # serialize + unpadded = np.array(processed_output[i]) + n_tokens = unpadded.shape[0] + response_array = array.array("B", unpadded.tobytes()) + bi = response_array.buffer_info() + response = [ + lg.QuerySampleResponse(qitem[i].id, bi[0], bi[1], n_tokens) + ] + lg.QuerySamplesComplete(response) + + tok = time.time() + + with self.sample_counter_lock: + self.sample_counter += len(qitem) + print(f"Samples run: {self.sample_counter}") + if tik1: + print(f"\tBatchMaker time: {tik2 - tik1}") + print(f"\tInference time: {tik3 - tik2}") + print(f"\tPostprocess time: {tok - tik3}") + print(f"\t==== Total time: {tok - tik1}") + else: + print(f"\tLoaded from cache: {_p}") + + asyncio.run(process()) def get_sut(self): self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries) @@ -334,7 +342,7 @@ def issue_queries(self, query_samples): print(f"IssueQuery started with {len(query_samples)} samples") while len(query_samples) > 0: self.query_queue.put(query_samples[: self.batch_size]) - query_samples = query_samples[self.batch_size:] + query_samples = query_samples[self.batch_size :] print(f"IssueQuery done") def flush_queries(self): @@ -360,12 +368,13 @@ def __init__( super().__init__( model_path=model_path, - api_server=None, - api_model_name=None, + api_server=api_server, + api_model_name=api_model_name, dtype=dtype, device=device, total_sample_count=total_sample_count, dataset_path=dataset_path, + batch_size=batch_size, workers=workers, ) @@ -384,8 +393,7 @@ def start(self): self.worker_threads[j] = worker # Create first token response thread - self.ft_response_thread = threading.Thread( - target=self.process_first_tokens) + self.ft_response_thread = threading.Thread(target=self.process_first_tokens) self.ft_response_thread.start() def process_first_tokens(self): @@ -435,14 +443,12 @@ def stream_api_vllm(self, input, response_ids, idx): for line in resp.iter_lines(): if line: decoded = line.decode() - if decoded.startswith( - "data") and "[DONE]" not in decoded: + if decoded.startswith("data") and "[DONE]" not in decoded: inter = json.loads(decoded[6:])["choices"][0][ "logprobs" ] if "top_logprobs" in inter: - token_s = list( - inter["top_logprobs"][0].keys())[0] + token_s = list(inter["top_logprobs"][0].keys())[0] token = self.llama_vocab[token_s] if first: self.first_token_queue.put( @@ -468,9 +474,7 @@ def async_process_query(self, input_ids_tensor, qitem_id, idx): print("WARNING: 
caught low token count")
             print(input_ids_tensor)
             print(output_tokens)
-        response_array = array.array(
-            "B", np.array(
-                output_tokens, np.int32).tobytes())
+        response_array = array.array("B", np.array(output_tokens, np.int32).tobytes())
         bi = response_array.buffer_info()
         response = [lg.QuerySampleResponse(qitem_id, bi[0], bi[1], n_tokens)]
         lg.QuerySamplesComplete(response)
@@ -520,9 +524,7 @@ def process_queries(self):
                     "B", np.array(output_tokens, np.int32).tobytes()
                 )
                 bi = response_array.buffer_info()
-                response = [
-                    lg.QuerySampleResponse(
-                        qitem.id, bi[0], bi[1], n_tokens)]
+                response = [lg.QuerySampleResponse(qitem.id, bi[0], bi[1], n_tokens)]
                 lg.QuerySamplesComplete(response)
 
     def issue_queries(self, query_samples):
diff --git a/language/llama2-70b/build.sh b/language/llama2-70b/build.sh
index 87afb992fa..2382dfcb17 100644
--- a/language/llama2-70b/build.sh
+++ b/language/llama2-70b/build.sh
@@ -2,7 +2,7 @@ set -e
 
 conda install pybind11==2.10.4 -c conda-forge -y
 conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch-nightly -c nvidia
-python -m pip install transformers==4.31.0 nltk==3.8.1 evaluate==0.4.0 absl-py==1.4.0 rouge-score==0.1.2 sentencepiece==0.1.99 accelerate==0.21.0
+python -m pip install transformers==4.31.0 nltk==3.8.1 evaluate==0.4.0 absl-py==1.4.0 rouge-score==0.1.2 sentencepiece==0.1.99 accelerate==1.11.0 httpx==0.28.0 more_itertools==10.8.0
 
 cd ../../loadgen && python3 -m pip install .
diff --git a/language/llama2-70b/main.py b/language/llama2-70b/main.py
index 84ccf849ad..b7c9bc2866 100644
--- a/language/llama2-70b/main.py
+++ b/language/llama2-70b/main.py
@@ -50,10 +50,7 @@ def get_args():
         help="Model name",
     )
     parser.add_argument("--dataset-path", type=str, default=None, help="")
-    parser.add_argument(
-        "--accuracy",
-        action="store_true",
-        help="Run accuracy mode")
+    parser.add_argument("--accuracy", action="store_true", help="Run accuracy mode")
     parser.add_argument(
         "--dtype",
         type=str,
@@ -107,7 +104,11 @@ def get_args():
         default=1,
         help="Number of workers to process queries",
     )
-    parser.add_argument("--vllm", action="store_true", help="vllm mode")
+    parser.add_argument(
+        "--vllm",
+        action="store_true",
+        help="Access the model via an OpenAI-compatible LLM API endpoint (API mode, aka vllm mode)",
+    )
     parser.add_argument(
         "--api-model-name",
         type=str,
@@ -117,8 +118,9 @@ def get_args():
     parser.add_argument(
         "--api-server",
         type=str,
-        default=None,
-        help="Specify an api endpoint call to use api mode",
+        action="append",
+        default=[],
+        help="Specify an API endpoint to use in OpenAI-compatible API mode (aka vllm mode); repeat the flag for multiple endpoints",
     )
     parser.add_argument(
         "--lg-model-name",
@@ -141,13 +143,12 @@ def main():
     args = get_args()
 
     if args.vllm:
-        resp = verify_model_name(
-            args.api_model_name,
-            args.api_server + "/v1/models")
-        if resp["error"]:
-            print(f"\n\n\033[91mError:\033[0m", end=" ")
-            print(resp["error"])
-            sys.exit(1)
+        for server in args.api_server:
+            resp = verify_model_name(args.api_model_name, server + "/v1/models")
+            if resp["error"]:
+                print(f"\n\n\033[91mError:\033[0m", end=" ")
+                print(resp["error"])
+                sys.exit(1)
 
     settings = lg.TestSettings()
     settings.scenario = scenario_map[args.scenario.lower()]
@@ -203,12 +204,7 @@ def main():
     sut.start()
     lgSUT = lg.ConstructSUT(sut.issue_queries, sut.flush_queries)
     log.info("Starting Benchmark run")
-    lg.StartTestWithLogSettings(
-        lgSUT,
-        sut.qsl,
-        settings,
-        log_settings,
-        args.audit_conf)
+    lg.StartTestWithLogSettings(lgSUT, sut.qsl, settings, log_settings, args.audit_conf)
 
     # Stop sut after completion
     sut.stop()
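Since `--api-server` is now an `action="append"` flag and `verify_model_name` is checked against every entry (see the `main.py` changes above), a multinode run simply repeats the flag. A minimal sketch, assuming two hypothetical OpenAI-compatible endpoints `http://server1:8000` and `http://server2:8000` that both serve `${MODEL_NAME}`:

```sh
python3 -u main.py --scenario Offline \
        --vllm \
        --api-model-name ${MODEL_NAME} \
        --api-server http://server1:8000 \
        --api-server http://server2:8000 \
        --model-path ${CHECKPOINT_PATH} \
        --user-conf user.conf \
        --total-sample-count 24576 \
        --dataset-path ${DATASET_PATH} \
        --output-log-dir offline-logs
```

With more than one endpoint, `SUT_API.query_servers` splits each issued batch across the servers with `mit.divide`, so prompts are distributed evenly rather than duplicated.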
diff --git a/language/llama2-70b/requirements-dev.txt b/language/llama2-70b/requirements-dev.txt
new file mode 100644
index 0000000000..eec27128f8
--- /dev/null
+++ b/language/llama2-70b/requirements-dev.txt
@@ -0,0 +1,5 @@
+iniconfig==2.3.0
+pluggy==1.6.0
+pygments==2.19.2
+pytest==8.4.2
+pytest-asyncio==1.2.0
diff --git a/language/llama2-70b/test_SUT_API.py b/language/llama2-70b/test_SUT_API.py
new file mode 100644
index 0000000000..9ab26b1335
--- /dev/null
+++ b/language/llama2-70b/test_SUT_API.py
@@ -0,0 +1,67 @@
+#
+# NTUHPC
+# MLPerf
+# SUT_API unit tests
+#
+
+from typing import Iterable
+import httpx
+import pytest
+from SUT_API import SUT
+
+from unittest.mock import AsyncMock, Mock, patch
+
+
+@pytest.fixture
+def sut() -> Iterable[SUT]:
+    # Create a SUT instance with dummy parameters for testing
+    with (
+        patch("SUT_API.Dataset") as dataset,
+        patch("SUT_API.AutoTokenizer") as tokenizer,
+    ):
+        yield SUT(
+            model_path="model_path",
+            dtype="float16",
+            batch_size=1,
+            dataset_path="dataset_path",
+            total_sample_count=10,
+            device="cpu",
+            api_server=["http://server1", "http://server2"],
+            api_model_name="dummy_model_name",
+            workers=1,
+        )
+
+
+@pytest.mark.asyncio
+async def test_query_batch(sut: SUT):
+    for api_server in sut.api_servers:
+        # mock the HTTP post method to return a dummy response from the LLM API server
+        http = Mock(spec=httpx.AsyncClient)
+        http.post = AsyncMock()
+        http.post.return_value.text = (
+            '{"choices": [{"text": "Output 1"}, {"text": "Output 2"}]}'
+        )
+
+        prompt = [[1, 2]]
+        outputs = await sut.query_batch(http, prompt, api_server)
+
+        # all completion texts in the mocked response are returned
+        assert outputs == ["Output 1", "Output 2"]
+        http.post.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_query_api_servers(sut: SUT):
+    test_inputs = [[1], [2], [3], [4]]
+
+    with patch.object(sut, "query_batch") as mock_query_batch:
+        # compute the expected batch size when distributing data over api_servers
+        n_batch = len(test_inputs) // len(sut.api_servers)
+
+        mock_query_batch.return_value = ["Outputs"] * n_batch
+
+        outputs = await sut.query_servers(Mock(spec=httpx.AsyncClient), test_inputs)
+
+        assert outputs == ["Outputs"] * len(test_inputs)
+        # expect 2 calls, one batch for each api_server
+        assert mock_query_batch.call_count == 2
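As a usage note for the tests above: a minimal sketch of running them locally, assuming the working directory is `language/llama2-70b` and that the runtime dependencies of `SUT_API.py` (torch, transformers, httpx, more_itertools, mlperf_loadgen) are already installed from the main setup:

```sh
pip install -r requirements-dev.txt
pytest test_SUT_API.py -v
```

The tests rely on `pytest-asyncio` (pinned in `requirements-dev.txt`) to execute the `@pytest.mark.asyncio` coroutine tests.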