
Commit 8f8fe82

Author: linjianma
Commit message: remove gin
1 parent: 71cf628

File tree

7 files changed: +128 additions, -132 deletions

recommendation/dlrm_v3/README.md

Lines changed: 46 additions & 31 deletions

````diff
@@ -1,53 +1,68 @@
-# MLCommons (MLPerf) DLRMv3 Inference Benchmarks
+# MLPerf Inference reference implementation for DLRMv3
 
-## Install generative-recommenders
+## Install dependencies and build loadgen
 
 ```
-cd generative_recommenders/
-pip install -e .
+sh setup.sh
 ```
 
-## Build loadgen
+## Dataset download
 
-```
-cd generative_recommenders/generative_recommenders/dlrm_v3/inference/thirdparty/loadgen/
-CFLAGS="-std=c++14 -O3" python -m pip install .
-```
-
-## Generate synthetic dataset
-
-```
-cd generative_recommenders/dlrm_v3/
-python streaming_synthetic_data.py
-```
+TODO: pending MLPerf system setup
 
 ## Inference benchmark
 
 ```
-cd generative_recommenders/generative_recommenders/dlrm_v3/inference/
-WORLD_SIZE=8 python main.py --dataset streaming-100b
+WORLD_SIZE=8 python main.py --dataset sampled-streaming-100b
 ```
 
-The config file is listed in `dlrm_v3/inference/gin/streaming_100b.gin`.
 `WORLD_SIZE` is the number of GPUs used in the inference benchmark.
 
-To load checkpoint from training, modify `run.model_path` inside the inference
-gin config file. (We will relase the checkpoint soon.)
-
-To achieve the best performance, tune `run.target_qps` and `run.batch_size` in
-the config file.
+```
+usage: main.py [-h] [--dataset {streaming-100b,sampled-streaming-100b}] [--model-path MODEL_PATH] [--scenario-name {SingleStream,MultiStream,Server,Offline}] [--batchsize BATCHSIZE]
+               [--output-trace OUTPUT_TRACE] [--data-producer-threads DATA_PRODUCER_THREADS] [--compute-eval COMPUTE_EVAL] [--find-peak-performance FIND_PEAK_PERFORMANCE]
+               [--dataset-path-prefix DATASET_PATH_PREFIX] [--warmup-ratio WARMUP_RATIO] [--num-queries NUM_QUERIES] [--target-qps TARGET_QPS] [--numpy-rand-seed NUMPY_RAND_SEED]
+               [--sparse-quant SPARSE_QUANT] [--dataset-percentage DATASET_PERCENTAGE]
+
+options:
+  -h, --help            show this help message and exit
+  --dataset {streaming-100b,sampled-streaming-100b}
+                        name of the dataset
+  --model-path MODEL_PATH
+                        path to the model checkpoint. Example: /home/username/ckpts/streaming_100b/89/
+  --scenario-name {SingleStream,MultiStream,Server,Offline}
+                        inference benchmark scenario
+  --batchsize BATCHSIZE
+                        batch size used in the benchmark
+  --output-trace OUTPUT_TRACE
+                        Whether to output trace
+  --data-producer-threads DATA_PRODUCER_THREADS
+                        Number of threads used in data producer
+  --compute-eval COMPUTE_EVAL
+                        If true, will run AccuracyOnly mode and output both predictions and labels for accuracy calculations
+  --find-peak-performance FIND_PEAK_PERFORMANCE
+                        Whether to find peak performance in the benchmark
+  --dataset-path-prefix DATASET_PATH_PREFIX
+                        Prefix to the dataset path. Example: /home/username/
+  --warmup-ratio WARMUP_RATIO
+                        The ratio of the dataset used to warm up the SUT
+  --num-queries NUM_QUERIES
+                        Number of queries to run in the benchmark
+  --target-qps TARGET_QPS
+                        Benchmark target QPS. Needs to be tuned for different implementations to balance latency and throughput
+  --numpy-rand-seed NUMPY_RAND_SEED
+                        Numpy random seed
+  --sparse-quant SPARSE_QUANT
+                        Whether to quantize the sparse arch
+  --dataset-percentage DATASET_PERCENTAGE
+                        Percentage of the dataset to run in the benchmark
+```
 
 ## Accuracy test
 
-Setting `run.compute_eval` will run the accuracy test and dump prediction outputs in
+Setting `run.compute_eval` will run the accuracy test and dump prediction outputs in
 `mlperf_log_accuracy.json`. To check the accuracy, run
 
 ```
-python accuracy.py -- --path path/to/mlperf_log_accuracy.json
-```
-
-## Run unit tests
-
-```
-python tests/inference_test.py
+python accuracy.py --path path/to/mlperf_log_accuracy.json
 ```
````
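For reference, loadgen's `mlperf_log_accuracy.json` is a JSON array of records whose `data` field is a hex string of raw response bytes. Below is a minimal sketch of inspecting it, assuming each payload decodes as float32 values; the authoritative decoding lives in `accuracy.py`, so treat the layout here as a guess:

```python
import json

import numpy as np

# Hypothetical reader for the loadgen accuracy log; the float32
# interpretation of "data" is an assumption, not taken from accuracy.py.
with open("mlperf_log_accuracy.json") as f:
    records = json.load(f)

for rec in records[:5]:
    payload = np.frombuffer(bytes.fromhex(rec["data"]), dtype=np.float32)
    print(rec["qsl_idx"], payload)
```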

recommendation/dlrm_v3/gin/streaming_100b.gin

Lines changed: 0 additions & 17 deletions
This file was deleted.

recommendation/dlrm_v3/main.py

Lines changed: 65 additions & 44 deletions

```diff
@@ -30,8 +30,6 @@
 import time
 from typing import Any, Dict, List, Optional, Union
 
-import gin
-
 # pyre-ignore [21]
 import mlperf_loadgen as lg  # @manual
 import numpy as np
@@ -64,10 +62,6 @@
 
 USER_CONF = f"{os.path.dirname(__file__)}/user.conf"
 
-SUPPORTED_CONFIGS = {
-    "sampled-streaming-100b": "streaming_100b.gin",
-}
-
 
 SCENARIO_MAP = {  # pyre-ignore [5]
     "SingleStream": lg.TestScenario.SingleStream,
@@ -81,7 +75,49 @@ def get_args():  # pyre-ignore [3]
     """Parse commandline."""
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--dataset", default="debug", choices=SUPPORTED_DATASETS, help="dataset"
+        "--dataset", default="sampled-streaming-100b", choices=SUPPORTED_DATASETS, help="name of the dataset"
+    )
+    parser.add_argument(
+        "--model-path", default="", help="path to the model checkpoint. Example: /home/username/ckpts/streaming_100b/89/"
+    )
+    parser.add_argument(
+        "--scenario-name", default="Server", choices={"SingleStream", "MultiStream", "Server", "Offline"}, help="inference benchmark scenario"
+    )
+    parser.add_argument(
+        "--batchsize", default=20, help="batch size used in the benchmark"
+    )
+    parser.add_argument(
+        "--output-trace", default=False, help="Whether to output trace"
+    )
+    parser.add_argument(
+        "--data-producer-threads", default=16, help="Number of threads used in data producer"
+    )
+    parser.add_argument(
+        "--compute-eval", default=False, help="If true, will run AccuracyOnly mode and output both predictions and labels for accuracy calculations"
+    )
+    parser.add_argument(
+        "--find-peak-performance", default=False, help="Whether to find peak performance in the benchmark"
+    )
+    parser.add_argument(
+        "--dataset-path-prefix", default="", help="Prefix to the dataset path. Example: /home/username/"
+    )
+    parser.add_argument(
+        "--warmup-ratio", default=0.1, help="The ratio of the dataset used to warm up the SUT"
+    )
+    parser.add_argument(
+        "--num-queries", default=500000, help="Number of queries to run in the benchmark"
+    )
+    parser.add_argument(
+        "--target-qps", default=1500, help="Benchmark target QPS. Needs to be tuned for different implementations to balance latency and throughput"
+    )
+    parser.add_argument(
+        "--numpy-rand-seed", default=123, help="Numpy random seed"
+    )
+    parser.add_argument(
+        "--sparse-quant", default=False, help="Whether to quantize the sparse arch"
+    )
+    parser.add_argument(
+        "--dataset-percentage", default=0.001, help="Percentage of the dataset to run in the benchmark"
     )
     args, unknown_args = parser.parse_known_args()
     logger.warning(f"unknown_args: {unknown_args}")
```
```diff
@@ -363,33 +399,24 @@ def get_item_count(self) -> int:
         return self.total_requests
 
 
-@gin.configurable
 def run(
     dataset: str = "debug",
     model_path: str = "",
     scenario_name: str = "Server",
     batchsize: int = 16,
-    out_dir: str = "",
     output_trace: bool = False,
     data_producer_threads: int = 4,
     compute_eval: bool = False,
     find_peak_performance: bool = False,
-    new_path_prefix: str = "",
-    train_split_percentage: float = 0.75,
+    dataset_path_prefix: str = "",
     warmup_ratio: float = 0.1,
-    # below will override mlperf rules compliant settings - don't use for official submission
-    duration: Optional[int] = None,
     target_qps: Optional[int] = None,
-    max_latency: Optional[float] = None,
     num_queries: Optional[int] = None,
-    samples_per_query_multistream: int = 8,
-    max_num_samples: int = -1,
     numpy_rand_seed: int = 123,
-    dev_mode: bool = False,
     sparse_quant: bool = False,
     dataset_percentage: float = 1.0,
 ) -> None:
-    set_dev_mode(dev_mode)
+    set_dev_mode(False)
     if scenario_name not in SCENARIO_MAP:
         raise NotImplementedError("valid scenarios:" + str(list(SCENARIO_MAP.keys())))
     scenario = SCENARIO_MAP[scenario_name]
```
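With gin gone, `run()` takes plain keyword arguments, so it can be driven directly from Python as well as from the CLI. A hypothetical call, assuming `recommendation/dlrm_v3` is on `sys.path`:

```python
# Hypothetical direct invocation; keyword names mirror the new run()
# signature above, and the values shown are illustrative defaults.
from main import run

run(
    dataset="sampled-streaming-100b",
    scenario_name="Server",
    batchsize=20,
    target_qps=1500,
    dataset_percentage=0.001,
)
```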
```diff
@@ -408,7 +435,7 @@ def run(
         compute_eval=compute_eval,
     )
     is_streaming: bool = "streaming" in dataset
-    dataset, kwargs = get_dataset(dataset, new_path_prefix)
+    dataset, kwargs = get_dataset(dataset, dataset_path_prefix)
 
     ds: Dataset = dataset(
         hstu_config=hstu_config,
@@ -430,11 +457,6 @@
         logger.error("{} not found".format(user_conf))
         sys.exit(1)
 
-    if out_dir:
-        output_dir = os.path.abspath(out_dir)
-        os.makedirs(output_dir, exist_ok=True)
-        os.chdir(output_dir)
-
     # warmup
     warmup_ids = list(range(batchsize))
     ds.load_query_samples(warmup_ids)
@@ -453,7 +475,7 @@
         if not is_streaming
         else ds.get_item_count()
     )
-    train_size: int = round(train_split_percentage * count) if not is_streaming else 0
+    train_size: int = 0
 
     settings = lg.TestSettings()
     settings.FromConfig(user_conf, model_path, scenario_name)
@@ -489,21 +511,10 @@ def flush_queries() -> None:
     if find_peak_performance:
         settings.mode = lg.TestMode.FindPeakPerformance
 
-    if duration:
-        settings.min_duration_ms = duration
-        settings.max_duration_ms = duration
-
     if target_qps:
         settings.server_target_qps = float(target_qps)
         settings.offline_expected_qps = float(target_qps)
 
-    if samples_per_query_multistream:
-        settings.multi_stream_samples_per_query = samples_per_query_multistream
-
-    if max_latency:
-        settings.server_target_latency_ns = int(max_latency * NANO_SEC)
-        settings.multi_stream_expected_latency_ns = int(max_latency * NANO_SEC)
-
     # inference benchmark warmup
     if is_streaming:
         ds.init_sut()
```
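After this change, `target_qps` is the only runtime override left; the duration, multistream samples-per-query, and max-latency knobs (which the deleted comment flagged as overriding MLPerf-compliant settings) now come solely from `user.conf`. A small sketch of the surviving knobs on loadgen's side:

```python
import mlperf_loadgen as lg

# Mirrors the `if target_qps:` branch kept above; values illustrative.
settings = lg.TestSettings()
settings.scenario = lg.TestScenario.Server
settings.mode = lg.TestMode.PerformanceOnly
settings.server_target_qps = 1500.0
settings.offline_expected_qps = 1500.0
```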
```diff
@@ -549,7 +560,7 @@ def flush_queries() -> None:
     sut = lg.ConstructSUT(issue_queries, flush_queries)
     qsl = lg.ConstructQSL(
         count,
-        min(count, max_num_samples) if max_num_samples > 0 else count,
+        count,
         load_query_samples,
         ds.unload_query_samples,
     )
```
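For context, the SUT/QSL pair above is the standard loadgen harness. A self-contained sketch of the lifecycle with stub callbacks (not the real `main.py` ones):

```python
import mlperf_loadgen as lg

def issue_queries(query_samples):
    # Stub SUT: complete every query immediately with an empty payload.
    lg.QuerySamplesComplete(
        [lg.QuerySampleResponse(qs.id, 0, 0) for qs in query_samples]
    )

def flush_queries():
    pass

count = 64
settings = lg.TestSettings()
settings.scenario = lg.TestScenario.Server
settings.mode = lg.TestMode.PerformanceOnly

sut = lg.ConstructSUT(issue_queries, flush_queries)
qsl = lg.ConstructQSL(count, count, lambda ids: None, lambda ids: None)
lg.StartTest(sut, qsl, settings)
lg.DestroyQSL(qsl)
lg.DestroySUT(sut)
```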
```diff
@@ -572,18 +583,28 @@ def flush_queries() -> None:
     if int(os.environ.get("WORLD_SIZE", 1)) > 1:
         model_family.predict(None)
 
-    if out_dir:
-        with open("results.json", "w") as f:
-            json.dump(final_results, f, sort_keys=True, indent=4)
-
 
 def main() -> None:
     set_verbose_level(1)
     args = get_args()
     logger.info(args)
-    gin_path = f"{os.path.dirname(__file__)}/gin/{SUPPORTED_CONFIGS[args.dataset]}"
-    gin.parse_config_file(gin_path)
-    run(dataset=args.dataset)
+    run(
+        dataset=args.dataset,
+        model_path=args.model_path,
+        scenario_name=args.scenario_name,
+        batchsize=args.batchsize,
+        output_trace=args.output_trace,
+        data_producer_threads=args.data_producer_threads,
+        compute_eval=args.compute_eval,
+        find_peak_performance=args.find_peak_performance,
+        dataset_path_prefix=args.dataset_path_prefix,
+        warmup_ratio=args.warmup_ratio,
+        target_qps=args.target_qps,
+        num_queries=args.num_queries,
+        numpy_rand_seed=args.numpy_rand_seed,
+        sparse_quant=args.sparse_quant,
+        dataset_percentage=args.dataset_percentage,
+    )
 
 
 if __name__ == "__main__":
```
recommendation/dlrm_v3/requirements.txt

Lines changed: 6 additions & 0 deletions

```diff
@@ -0,0 +1,6 @@
+torch==2.8.0
+fbgemm_gpu==1.3.0
+torchrec==1.3.0
+gin_config==0.5.0
+pandas==2.3.2
+tensorboard==2.20.0
```
Lines changed: 4 additions & 0 deletions

```diff
@@ -0,0 +1,4 @@
+#!/usr/bin/env bash
+
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 WORLD_SIZE=8 \
+python main.py --dataset sampled-streaming-100b 2>&1 | tee /home/$USER/dlrmv3-inference-benchmark.log
```

recommendation/dlrm_v3/setup.sh

Lines changed: 7 additions & 0 deletions

```diff
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+
+conda create --name dlrmv3 python=3.13
+conda activate dlrmv3
+pip install -r requirements.txt
+git_dir=$(git rev-parse --show-toplevel)
+pip install $git_dir/loadgen
```
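Note that `conda activate` only takes effect in shells where conda's hook has been initialized; when running this script non-interactively, sourcing the hook first (e.g. `source "$(conda info --base)/etc/profile.d/conda.sh"`) or using `conda run -n dlrmv3 ...` is typically needed.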

recommendation/dlrm_v3/utils.py

Lines changed: 0 additions & 40 deletions

```diff
@@ -274,14 +274,6 @@ def reset(self, mode: str = "train"):
 
 # the datasets we support
 SUPPORTED_DATASETS = [
-    "debug",
-    "movielens-1m",
-    "movielens-20m",
-    "movielens-13b",
-    "movielens-18b",
-    "kuairand-1k",
-    "streaming-400m",
-    "streaming-200b",
     "streaming-100b",
     "sampled-streaming-100b",
]
@@ -290,38 +282,6 @@ def reset(self, mode: str = "train"):
 @gin.configurable
 def get_dataset(name: str, new_path_prefix: str = ""):
     assert name in SUPPORTED_DATASETS, f"dataset {name} not supported"
-    if name == "debug":
-        return DLRMv3RandomDataset, {}
-    if name == "streaming-400m":
-        return (
-            DLRMv3SyntheticStreamingDataset,
-            {
-                "ratings_file_prefix": os.path.join(
-                    new_path_prefix, "data/streaming-400m/"
-                ),
-                "train_ts": 8,
-                "total_ts": 10,
-                "num_files": 3,
-                "num_users": 150_000,
-                "num_items": 1_500_000,
-                "num_categories": 128,
-            },
-        )
-    if name == "streaming-200b":
-        return (
-            DLRMv3SyntheticStreamingDataset,
-            {
-                "ratings_file_prefix": os.path.join(
-                    new_path_prefix, "data/streaming-200b/"
-                ),
-                "train_ts": 90,
-                "total_ts": 100,
-                "num_files": 100,
-                "num_users": 10_000_000,
-                "num_items": 1_000_000_000,
-                "num_categories": 128,
-            },
-        )
     if name == "streaming-100b":
         return (
             DLRMv3SyntheticStreamingDataset,
```