
Commit ca57a6c: dlrmv3
Author: linjianma
Parent: 8999c4d

19 files changed: +4262, -0 lines

recommendation/dlrm_v3/README.md

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
# MLCommons (MLPerf) DLRMv3 Inference Benchmarks

## Install generative-recommenders

```
cd generative_recommenders/
pip install -e .
```

## Build loadgen

```
cd generative_recommenders/generative_recommenders/dlrm_v3/inference/thirdparty/loadgen/
CFLAGS="-std=c++14 -O3" python -m pip install .
```

## Generate synthetic dataset

```
cd generative_recommenders/dlrm_v3/
python streaming_synthetic_data.py
```

## Inference benchmark

```
cd generative_recommenders/generative_recommenders/dlrm_v3/inference/
WORLD_SIZE=8 python main.py --dataset streaming-100b
```

The config file is `dlrm_v3/inference/gin/streaming_100b.gin`.
`WORLD_SIZE` is the number of GPUs used in the inference benchmark.

To load a checkpoint from training, modify `run.model_path` inside the inference
gin config file. (We will release the checkpoint soon.)

To achieve the best performance, tune `run.target_qps` and `run.batch_size` in
the config file, for example as sketched below.
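A minimal sketch (an assumption, not from this repo's gin file) of setting those
bindings from Python with `gin.parse_config`; the binding names follow this
README, while the values are illustrative placeholders:

```
import gin

# Hypothetical overrides layered on top of streaming_100b.gin; tune the
# values for your hardware before the benchmark reads its gin bindings.
gin.parse_config([
    "run.model_path = '/path/to/checkpoint'",
    "run.target_qps = 1000",
    "run.batch_size = 16",
])
```
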
## Accuracy test

Setting `run.compute_eval` runs the accuracy test and dumps prediction outputs to
`mlperf_log_accuracy.json`. To check the accuracy, run

```
python accuracy.py --path path/to/mlperf_log_accuracy.json
```

## Run unit tests

```
python tests/inference_test.py
```

recommendation/dlrm_v3/accuracy.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-strict
"""
Tool to calculate accuracy for loadgen accuracy output found in mlperf_log_accuracy.json
"""

import argparse
import json
import logging

import numpy as np
import torch
from generative_recommenders.dlrm_v3.configs import get_hstu_configs
from generative_recommenders.dlrm_v3.utils import MetricsLogger

logger: logging.Logger = logging.getLogger("main")


def get_args() -> argparse.Namespace:
    """Parse commandline."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--path",
        required=True,
        help="path to mlperf_log_accuracy.json",
    )
    args = parser.parse_args()
    return args


def main() -> None:
    args = get_args()
    logger.warning("Parsing loadgen accuracy log...")
    with open(args.path, "r") as f:
        results = json.load(f)
    hstu_config = get_hstu_configs(dataset="sampled-streaming-100b")
    metrics = MetricsLogger(
        multitask_configs=hstu_config.multitask_configs,
        batch_size=1,
        window_size=3000,
        device=torch.device("cpu"),
        rank=0,
    )
    logger.warning(f"results have {len(results)} entries")
    for result in results:
        data = np.frombuffer(bytes.fromhex(result["data"]), np.float32)
        num_candidates = data[-1].astype(int)
        assert len(data) == 1 + num_candidates * 3
        mt_target_preds = torch.from_numpy(data[0:num_candidates])
        mt_target_labels = torch.from_numpy(data[num_candidates : num_candidates * 2])
        mt_target_weights = torch.from_numpy(
            data[num_candidates * 2 : num_candidates * 3]
        )
        num_candidates = torch.tensor([num_candidates])
        metrics.update(
            predictions=mt_target_preds.view(1, -1),
            labels=mt_target_labels.view(1, -1),
            weights=mt_target_weights.view(1, -1),
            num_candidates=num_candidates,
        )
    for k, v in metrics.compute().items():
        logger.warning(f"{k}: {v}")


if __name__ == "__main__":
    main()
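For reference, a minimal sketch (inferred from the parsing code above, not part
of this commit) of how one `mlperf_log_accuracy.json` entry is laid out: `n`
predictions, `n` labels, `n` weights, then `n` itself, all packed as float32
and hex-encoded under the `"data"` key. `encode_entry` is a hypothetical helper
for illustration:

```
import numpy as np

def encode_entry(
    preds: np.ndarray, labels: np.ndarray, weights: np.ndarray
) -> dict:
    # Pack [preds | labels | weights | num_candidates] as float32 and
    # hex-encode, matching the layout that accuracy.py asserts on.
    n = len(preds)
    data = np.concatenate(
        [preds, labels, weights, np.array([n], dtype=np.float32)]
    ).astype(np.float32)
    return {"data": data.tobytes().hex()}

entry = encode_entry(
    np.array([0.9, 0.1], dtype=np.float32),
    np.array([1.0, 0.0], dtype=np.float32),
    np.array([1.0, 1.0], dtype=np.float32),
)
# np.frombuffer(bytes.fromhex(entry["data"]), np.float32) round-trips the data.
```
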
Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-strict

import gc
import os
from datetime import datetime
from typing import Any, Dict, Optional, Set

import gin

import torch
from generative_recommenders.dlrm_v3.utils import MetricsLogger
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim.optimizer import Optimizer
from torchrec.distributed.types import ShardedTensor


class SparseState(Stateful):
    def __init__(self, model: torch.nn.Module, sparse_tensor_keys: Set[str]) -> None:
        self.model = model
        self.sparse_tensor_keys = sparse_tensor_keys

    def state_dict(self) -> Dict[str, torch.Tensor]:
        out_dict: Dict[str, torch.Tensor] = {}
        is_sharded_tensor: Optional[bool] = None
        for k, v in self.model.state_dict().items():
            if k in self.sparse_tensor_keys:
                if is_sharded_tensor is None:
                    is_sharded_tensor = isinstance(v, ShardedTensor)
                assert is_sharded_tensor == isinstance(v, ShardedTensor)
                out_dict[k] = v
        return out_dict

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        incompatible_keys = self.model.load_state_dict(state_dict, strict=False)
        assert not incompatible_keys.unexpected_keys


def is_sparse_key(k: str, v: torch.Tensor) -> bool:
    return isinstance(v, ShardedTensor) or "embedding_collection" in k


def load_dense_state_dict(model: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
    own_state = model.state_dict()
    own_state_dense_keys = {k for k, v in own_state.items() if not is_sparse_key(k, v)}
    state_dict_dense_keys = {
        k for k, v in state_dict.items() if not is_sparse_key(k, v)
    }
    assert (
        own_state_dense_keys == state_dict_dense_keys
    ), f"expects {own_state_dense_keys} but gets {state_dict_dense_keys}"
    for name in state_dict_dense_keys:
        param = state_dict[name]
        if isinstance(param, torch.nn.Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        own_state[name].copy_(param)


@gin.configurable
def save_dmp_checkpoint(
    model: torch.nn.Module,
    optimizer: Optimizer,
    metric_logger: MetricsLogger,
    rank: int,
    batch_idx: int,
    path: str = "",
) -> None:
    if path == "":
        return
    now = datetime.now()
    formatted_datetime = now.strftime("%Y_%m_%d_%H_%M_%S")
    path = f"{path}/{batch_idx}"
    if not os.path.exists(path) and rank == 0:
        os.makedirs(path)
    sparse_path = f"{path}/sparse/"
    if not os.path.exists(sparse_path) and rank == 0:
        os.makedirs(sparse_path)
    non_sparse_ckpt = f"{path}/non_sparse.ckpt"

    sparse_tensor_keys = {
        k for k, v in model.state_dict().items() if isinstance(v, ShardedTensor)
    }
    if rank == 0:
        dense_state_dict = {
            k: v
            for k, v in model.state_dict().items()
            if not isinstance(v, ShardedTensor)
        }
        class_metric_state_dict = {
            "train": [m.state_dict() for m in metric_logger.class_metrics["train"]],
            "eval": [m.state_dict() for m in metric_logger.class_metrics["eval"]],
        }
        regression_metric_state_dict = {
            "train": [
                m.state_dict() for m in metric_logger.regression_metrics["train"]
            ],
            "eval": [m.state_dict() for m in metric_logger.regression_metrics["eval"]],
        }
        torch.save(
            {
                "dense_dict": dense_state_dict,
                "optimizer_dict": optimizer.state_dict(),
                "class_metrics": class_metric_state_dict,
                "reg_metrics": regression_metric_state_dict,
                "global_step": metric_logger.global_step,
                "sparse_tensor_keys": sparse_tensor_keys,
            },
            non_sparse_ckpt,
        )
    torch.distributed.barrier()
    sparse_dict = {"sparse_dict": SparseState(model, sparse_tensor_keys)}
    torch.distributed.checkpoint.save(
        sparse_dict,
        storage_writer=torch.distributed.checkpoint.FileSystemWriter(sparse_path),
    )
    torch.distributed.barrier()
    print("checkpoint successfully saved")


@gin.configurable
def load_sparse_checkpoint(
    model: torch.nn.Module,
    path: str = "",
) -> None:
    if path == "":
        return
    sparse_path = f"{path}/sparse/"

    sparse_tensor_keys = {
        k for k, v in model.state_dict().items() if is_sparse_key(k, v)
    }
    sparse_dict = {"sparse_dict": SparseState(model, sparse_tensor_keys)}
    gc.collect()
    torch.distributed.checkpoint.load(
        sparse_dict,
        storage_reader=torch.distributed.checkpoint.FileSystemReader(sparse_path),
    )
    gc.collect()
    print("sparse checkpoint successfully loaded")


@gin.configurable
def load_nonsparse_checkpoint(
    model: torch.nn.Module,
    device: torch.device,
    optimizer: Optional[Optimizer] = None,
    metric_logger: Optional[MetricsLogger] = None,
    path: str = "",
) -> None:
    if path == "":
        return
    non_sparse_ckpt = f"{path}/non_sparse.ckpt"

    non_sparse_state_dict = torch.load(non_sparse_ckpt, map_location=device)
    load_dense_state_dict(model, non_sparse_state_dict["dense_dict"])
    print("dense checkpoint successfully loaded")
    if optimizer is not None:
        optimizer.load_state_dict(non_sparse_state_dict["optimizer_dict"])
        print("optimizer checkpoint successfully loaded")
    if metric_logger is not None:
        metric_logger.global_step = non_sparse_state_dict["global_step"]
        class_metric_state_dict = non_sparse_state_dict["class_metrics"]
        regression_metric_state_dict = non_sparse_state_dict["reg_metrics"]
        for i, m in enumerate(metric_logger.class_metrics["train"]):
            m.load_state_dict(class_metric_state_dict["train"][i])
        for i, m in enumerate(metric_logger.class_metrics["eval"]):
            m.load_state_dict(class_metric_state_dict["eval"][i])
        for i, m in enumerate(metric_logger.regression_metrics["train"]):
            m.load_state_dict(regression_metric_state_dict["train"][i])
        for i, m in enumerate(metric_logger.regression_metrics["eval"]):
            m.load_state_dict(regression_metric_state_dict["eval"][i])


@gin.configurable
def load_dmp_checkpoint(
    model: torch.nn.Module,
    optimizer: Optimizer,
    metric_logger: MetricsLogger,
    device: torch.device,
    path: str = "",
) -> None:
    load_sparse_checkpoint(model=model, path=path)
    load_nonsparse_checkpoint(
        model=model,
        optimizer=optimizer,
        metric_logger=metric_logger,
        path=path,
        device=device,
    )

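A minimal usage sketch (not part of this commit) of the save/load split above:
rank 0 writes the dense parameters, optimizer state, and metric state to
`non_sparse.ckpt` with `torch.save`, while every rank collectively writes the
sharded embedding tables under `sparse/` with `torch.distributed.checkpoint`.
`build_model_and_optimizer` and the paths are hypothetical placeholders:

```
import torch

# Hypothetical setup; in the real benchmark the model is a sharded
# torchrec DistributedModelParallel instance.
model, optimizer, metric_logger = build_model_and_optimizer()

rank = torch.distributed.get_rank()
save_dmp_checkpoint(
    model=model,
    optimizer=optimizer,
    metric_logger=metric_logger,
    rank=rank,
    batch_idx=1000,
    path="/checkpoints/run0",  # writes /checkpoints/run0/1000/{non_sparse.ckpt, sparse/}
)
load_dmp_checkpoint(
    model=model,
    optimizer=optimizer,
    metric_logger=metric_logger,
    device=torch.device("cuda"),
    path="/checkpoints/run0/1000",
)
```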