
Commit 3a87c2d

Author: linjianma
Commit message: fix a bug
1 parent 8f8fe82 · commit 3a87c2d

File tree: 4 files changed, +30 -24 lines


recommendation/dlrm_v3/accuracy.py

Lines changed: 2 additions & 2 deletions
@@ -23,8 +23,8 @@

 import numpy as np
 import torch
-from generative_recommenders.dlrm_v3.configs import get_hstu_configs
-from generative_recommenders.dlrm_v3.utils import MetricsLogger
+from configs import get_hstu_configs
+from utils import MetricsLogger

 logger: logging.Logger = logging.getLogger("main")
recommendation/dlrm_v3/main.py

Lines changed: 17 additions & 14 deletions
@@ -19,9 +19,9 @@

 import argparse
 import array
-import json
 import logging
 import random
+import threading

 logging.basicConfig(level=logging.INFO)
 import math
@@ -53,7 +53,6 @@
     SUPPORTED_DATASETS,
 )

-
 logger: logging.Logger = logging.getLogger("main")

 torch.multiprocessing.set_start_method("spawn", force=True)
@@ -99,16 +98,16 @@ def get_args():  # pyre-ignore [3]
         "--find-peak-performance", default=False, help="Whether to find peak performance in the benchmark"
     )
     parser.add_argument(
-        "--dataset-path-prefix", default="", help="Prefix to the dataset path. Example: /home/username/"
+        "--dataset-path-prefix", default=f"/home/{os.getlogin()}/", help="Prefix to the dataset path. Example: /home/username/"
     )
     parser.add_argument(
-        "--warmup-ratio", default=0.1, help="The ratio of the dataset used to warmup SUT"
+        "--warmup-ratio", default=0.3, help="The ratio of the dataset used to warmup SUT"
     )
     parser.add_argument(
         "--num-queries", default=500000, help="Number of queries to run in the benchmark"
     )
     parser.add_argument(
-        "--target-qps", default=1500, help="Benchmark target QPS. Needs to be tuned for different implementations to balance latency and throughput"
+        "--target-qps", default=1000, help="Benchmark target QPS. Needs to be tuned for different implementations to balance latency and throughput"
     )
     parser.add_argument(
         "--numpy-rand-seed", default=123, help="Numpy random seed"
@@ -332,6 +331,7 @@ def __init__(
             get_num_queries(input_queries, self.total_requests) // self.total_requests
         )
         self.repeat: int = 0
+        self._lock = threading.Lock()

     def get_num_requests(self, warmup_ratio: float) -> List[int]:
         return [
@@ -359,6 +359,7 @@ def init_sut(self) -> None:
         self.ts = self.start_ts
         self.ds.set_ts(self.start_ts)
         self.cnt = 0
+        self.repeat = 0

     def load_query_samples(self, query_ids: List[Optional[int]]) -> None:
         length = len(query_ids)
@@ -382,25 +383,27 @@ def unload_query_samples(self, sample_list: List[int]) -> None:
     def get_samples(self, id_list: List[int]) -> Samples:
         batch_size: int = len(id_list)
         ts_idx: int = 0
-        while self.num_requests_cumsum[ts_idx] <= self.cnt:
-            ts_idx += 1
-        offset: int = 0 if ts_idx == 0 else self.num_requests_cumsum[ts_idx - 1]
+        with self._lock:
+            current_cnt: int = self.cnt
+            while self.num_requests_cumsum[ts_idx] <= current_cnt:
+                ts_idx += 1
+            offset: int = 0 if ts_idx == 0 else self.num_requests_cumsum[ts_idx - 1]
+            self.repeat += 1
+            if self.repeat == self.num_repeats:
+                self.repeat = 0
+                self.cnt += batch_size
         output: Samples = self.ds.get_samples_with_ts(
-            self.run_order[ts_idx][self.cnt - offset : self.cnt + batch_size - offset],
+            self.run_order[ts_idx][current_cnt - offset : current_cnt + batch_size - offset],
             ts_idx + self.start_ts,
         )
-        self.repeat += 1
-        if self.repeat == self.num_repeats:
-            self.repeat = 0
-            self.cnt += batch_size
         return output

     def get_item_count(self) -> int:
         return self.total_requests


 def run(
-    dataset: str = "debug",
+    dataset: str = "sampled-streaming-100b",
     model_path: str = "",
     scenario_name: str = "Server",
     batchsize: int = 16,
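The fix itself is in get_samples: the read-modify-write of self.cnt was unguarded, so concurrent LoadGen issue threads could observe the same counter value and serve overlapping slices of run_order. A minimal sketch of the locking pattern the fix applies, with illustrative names rather than the benchmark's real SUT class (the real code additionally folds in the repeat logic and the ts_idx lookup):

import threading

class Cursor:
    """Hands out contiguous, non-overlapping index ranges."""

    def __init__(self) -> None:
        self.cnt = 0
        self._lock = threading.Lock()

    def take(self, batch_size: int) -> range:
        # Snapshot and advance atomically; without the lock, two threads
        # can read the same self.cnt and return overlapping ranges.
        with self._lock:
            current = self.cnt
            self.cnt += batch_size
        return range(current, current + batch_size)

cursor = Cursor()
slices = []

def worker() -> None:
    for _ in range(1000):
        slices.append(cursor.take(16))  # list.append is thread-safe in CPython

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Every index in [0, 64000) was handed out exactly once.
assert sorted(s.start for s in slices) == list(range(0, 64000, 16))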

recommendation/dlrm_v3/model_family.py

Lines changed: 7 additions & 4 deletions
@@ -306,10 +306,10 @@ def load(self, model_path: str) -> None:
             processes.append(p)

     def distributed_setup(self, rank: int, world_size: int, model_path: str) -> None:
-        nprocs_per_rank = 16
-        start_core: int = nprocs_per_rank * rank
-        cores: set[int] = set([start_core + i for i in range(nprocs_per_rank)])
-        os.sched_setaffinity(0, cores)
+        # nprocs_per_rank = 16
+        # start_core: int = nprocs_per_rank * rank + 128
+        # cores: set[int] = set([start_core + i for i in range(nprocs_per_rank)])
+        # os.sched_setaffinity(0, cores)
         set_is_inference(is_inference=not self.compute_eval)
         model = get_hstu_model(
             table_config=self.table_config,
@@ -366,6 +366,9 @@ def distributed_setup(self, rank: int, world_size: int, model_path: str) -> None
             max_num_candidates=max_num_candidates,
             num_candidates=num_candidates,
         )
+        # mt_target_preds = torch.empty(1, 2048 * 20).to(device="cpu")
+        # mt_target_labels = None
+        # mt_target_weights = None
         assert mt_target_preds is not None
         mt_target_preds = mt_target_preds.detach().to(device="cpu")
         if mt_target_labels is not None:
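Note that distributed_setup comments the core-pinning logic out rather than deleting it, and the commented version carries a new "+ 128" core offset, which reads like ongoing tuning rather than a permanent removal. For reference, a standalone sketch of what that pinning did (the function and parameter names here are mine, not the repo's):

import os

def pin_rank_to_cores(rank: int, nprocs_per_rank: int = 16, core_offset: int = 0) -> None:
    # Pin the calling process (pid 0 means "this process") to a contiguous
    # block of cores, one block per rank. Linux-only: os.sched_setaffinity
    # is not available on macOS or Windows.
    start_core = core_offset + nprocs_per_rank * rank
    cores = {start_core + i for i in range(nprocs_per_rank)}
    os.sched_setaffinity(0, cores)

# With the commented-out offset of 128, rank 2 would get cores 160-175:
# pin_rank_to_cores(rank=2, nprocs_per_rank=16, core_offset=128)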

recommendation/dlrm_v3/user.conf

Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,6 @@
 # Please set these fields depending on the performance of your system to
 # override default LoadGen settings.
-*.SingleStream.target_latency = 150
-*.MultiStream.target_latency = 150
-*.Server.target_latency = 150
-*.Server.min_duration = 20000
+*.SingleStream.target_latency = 100
+*.MultiStream.target_latency = 100
+*.Server.target_latency = 100
+# *.Server.min_duration = 20000
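LoadGen applies user.conf on top of its built-in defaults (and mlperf.conf, if supplied), so commenting out min_duration falls back to the default rather than setting it to zero; target_latency values in these files are in milliseconds. A hedged sketch of how such overrides are typically consumed in MLPerf reference code (the model name "dlrm-v3" is an assumption, not confirmed by this commit):

import mlperf_loadgen as lg

settings = lg.TestSettings()
settings.scenario = lg.TestScenario.Server
settings.mode = lg.TestMode.PerformanceOnly
# FromConfig(path, model, scenario) matches lines such as
# "*.Server.target_latency = 100" against the model/scenario pair.
settings.FromConfig("user.conf", "dlrm-v3", "Server")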
