Commit 760bde1

add docstring

1 parent 642e1d6

File tree

12 files changed: +1061 / -6 lines changed

recommendation/dlrm_v3/accuracy.py

Lines changed: 7 additions & 0 deletions

@@ -42,6 +42,13 @@ def get_args() -> argparse.Namespace:


 def main() -> None:
+    """
+    Main function to calculate accuracy metrics from loadgen output.
+
+    Reads the mlperf_log_accuracy.json file, parses the results, and computes
+    accuracy metrics using the MetricsLogger. Each result entry contains
+    predictions, labels, and weights packed as float32 numpy arrays.
+    """
     args = get_args()
     logger.warning("Parsing loadgen accuracy log...")
     with open(args.path, "r") as f:
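For context, a loadgen accuracy log is a JSON list of entries whose "data" field is a hex-encoded byte string. Below is a minimal sketch of how one such entry could be unpacked, assuming (this layout is not shown in the diff) that predictions, labels, and weights are three equal-length float32 arrays concatenated back to back:

import json

import numpy as np

with open("mlperf_log_accuracy.json", "r") as f:
    results = json.load(f)

for entry in results:
    # "data" is a hex string; decode it into a flat float32 array.
    packed = np.frombuffer(bytes.fromhex(entry["data"]), dtype=np.float32)
    # Assumed layout: predictions | labels | weights, equal lengths.
    n = len(packed) // 3
    predictions, labels, weights = packed[:n], packed[n : 2 * n], packed[2 * n :]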

recommendation/dlrm_v3/checkpoint.py

Lines changed: 56 additions & 0 deletions

@@ -13,6 +13,12 @@
 # limitations under the License.

 # pyre-strict
+"""
+Checkpoint utilities for saving and loading DLRMv3 model checkpoints.
+
+This module provides functions for saving and loading distributed model checkpoints,
+including both sparse (embedding) and dense (non-embedding) components.
+"""

 import gc
 import os
@@ -29,6 +35,17 @@


 class SparseState(Stateful):
+    """
+    Stateful wrapper for sparse (embedding) tensors in a model.
+
+    This class implements the Stateful interface for distributed checkpointing,
+    allowing sparse tensors to be saved and loaded separately from dense tensors.
+
+    Args:
+        model: The PyTorch model containing sparse tensors.
+        sparse_tensor_keys: Set of keys identifying sparse tensors in the model's state dict.
+    """
+
     def __init__(self, model: torch.nn.Module, sparse_tensor_keys: Set[str]) -> None:
         self.model = model
         self.sparse_tensor_keys = sparse_tensor_keys
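For readers unfamiliar with the Stateful protocol from torch.distributed.checkpoint, a wrapper like SparseState typically only needs to expose state_dict/load_state_dict filtered to the sparse keys. The method bodies below are an illustrative sketch, not the code from this commit:

from typing import Any, Dict, Set

import torch
from torch.distributed.checkpoint.stateful import Stateful


class SparseStateSketch(Stateful):
    def __init__(self, model: torch.nn.Module, sparse_tensor_keys: Set[str]) -> None:
        self.model = model
        self.sparse_tensor_keys = sparse_tensor_keys

    def state_dict(self) -> Dict[str, Any]:
        # Keep only the tensors identified as sparse/embedding state.
        full = self.model.state_dict()
        return {k: v for k, v in full.items() if k in self.sparse_tensor_keys}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # strict=False: dense keys are restored by the non-sparse loader.
        self.model.load_state_dict(state_dict, strict=False)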
@@ -79,6 +96,20 @@ def save_dmp_checkpoint(
     batch_idx: int,
     path: str = "",
 ) -> None:
+    """
+    Save a distributed model checkpoint including sparse and dense components.
+
+    Saves the model's sparse tensors using distributed checkpointing; dense
+    tensors, optimizer state, and metrics use standard PyTorch serialization.
+
+    Args:
+        model: The model to checkpoint.
+        optimizer: The optimizer whose state should be saved.
+        metric_logger: The metrics logger containing training/eval metrics.
+        rank: The current process rank in distributed training.
+        batch_idx: The current batch index (used for checkpoint naming).
+        path: Base path for saving the checkpoint. If empty, no checkpoint is saved.
+    """
     if path == "":
         return
     now = datetime.now()
@@ -161,6 +192,18 @@ def load_nonsparse_checkpoint(
     metric_logger: Optional[MetricsLogger] = None,
     path: str = "",
 ) -> None:
+    """
+    Load non-sparse (dense) components from a checkpoint.
+
+    Loads dense model parameters, and optionally optimizer state and metrics.
+
+    Args:
+        model: The model to load dense parameters into.
+        device: The device to load tensors onto.
+        optimizer: Optional optimizer to restore state for.
+        metric_logger: Optional metrics logger to restore state for.
+        path: Base path of the checkpoint. If empty, no loading is performed.
+    """
     if path == "":
         return
     non_sparse_ckpt = f"{path}/non_sparse.ckpt"
@@ -193,6 +236,19 @@ def load_dmp_checkpoint(
     device: torch.device,
     path: str = "",
 ) -> None:
+    """
+    Load a complete distributed model checkpoint (both sparse and dense components).
+
+    This is a convenience function that calls both load_sparse_checkpoint and
+    load_nonsparse_checkpoint.
+
+    Args:
+        model: The model to load the checkpoint into.
+        optimizer: The optimizer to restore state for.
+        metric_logger: The metrics logger to restore state for.
+        device: The device to load tensors onto.
+        path: Base path of the checkpoint. If empty, no loading is performed.
+    """
     load_sparse_checkpoint(model=model, path=path)
     load_nonsparse_checkpoint(
         model=model,
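Taken together, the docstrings imply a save/load round trip along these lines (a hypothetical usage sketch; model, optimizer, metric_logger, rank, and the path are placeholders):

# Save at a given batch index; an empty path makes both calls no-ops.
save_dmp_checkpoint(
    model=model,
    optimizer=optimizer,
    metric_logger=metric_logger,
    rank=rank,
    batch_idx=1000,
    path="/tmp/dlrm_v3_ckpt",
)

# Restore both sparse and dense components in one call.
load_dmp_checkpoint(
    model=model,
    optimizer=optimizer,
    metric_logger=metric_logger,
    device=torch.device("cuda"),
    path="/tmp/dlrm_v3_ckpt",
)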

recommendation/dlrm_v3/configs.py

Lines changed: 30 additions & 0 deletions

@@ -13,6 +13,11 @@
 # limitations under the License.

 # pyre-strict
+"""
+Configuration module for the DLRMv3 model.
+
+This module provides configuration helpers for the HSTU model architecture and its embedding tables.
+"""
 from typing import Dict

 from generative_recommenders.modules.dlrm_hstu import DlrmHSTUConfig
@@ -27,6 +32,19 @@


 def get_hstu_configs(dataset: str = "debug") -> DlrmHSTUConfig:
+    """
+    Create and return the HSTU model configuration.
+
+    Builds a complete DlrmHSTUConfig with default hyperparameters for the HSTU
+    architecture including attention settings, embedding dimensions, dropout rates,
+    and feature name mappings.
+
+    Args:
+        dataset: Dataset identifier (currently unused, reserved for dataset-specific configs).
+
+    Returns:
+        DlrmHSTUConfig: Complete configuration object for the HSTU model.
+    """
     hstu_config = DlrmHSTUConfig(
         hstu_num_heads=4,
         hstu_attn_linear_dim=128,
@@ -97,6 +115,18 @@ def get_hstu_configs(dataset: str = "debug") -> DlrmHSTUConfig:


 def get_embedding_table_config(dataset: str = "debug") -> Dict[str, EmbeddingConfig]:
+    """
+    Create and return embedding table configurations.
+
+    Defines the embedding table configurations for item IDs, category IDs, and user IDs
+    with their respective dimensions and data types.
+
+    Args:
+        dataset: Dataset identifier (currently unused, reserved for dataset-specific configs).
+
+    Returns:
+        Dict mapping table names to their EmbeddingConfig objects.
+    """
     return {
         "item_id": EmbeddingConfig(
             num_embeddings=HASH_SIZE,
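A brief usage sketch of the two helpers above (the import path is assumed from the file location in this diff; the attribute names come from torchrec's EmbeddingConfig):

from recommendation.dlrm_v3.configs import (
    get_embedding_table_config,
    get_hstu_configs,
)

hstu_config = get_hstu_configs(dataset="debug")
table_configs = get_embedding_table_config(dataset="debug")

# Each entry maps a table name to a torchrec EmbeddingConfig.
for name, cfg in table_configs.items():
    print(name, cfg.num_embeddings, cfg.embedding_dim)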

recommendation/dlrm_v3/data_producer.py

Lines changed: 71 additions & 1 deletion

@@ -13,6 +13,12 @@
 # limitations under the License.

 # pyre-strict
+"""
+Data producer module for DLRMv3 inference.
+
+This module provides classes for producing and managing query data during inference,
+supporting both single-threaded and multi-threaded data production modes.
+"""

 import logging
 import threading
@@ -28,7 +34,16 @@


 class QueryItem:
-    """An item that we queue for processing by the thread pool."""
+    """
+    Container for a query item to be processed by the inference thread pool.
+
+    Attributes:
+        query_ids: List of unique identifiers for the queries in this batch.
+        samples: The sample data containing features for the queries.
+        start: Time when the query was first received.
+        dt_queue: Time spent in the queue before processing.
+        dt_batching: Time spent on batching the data.
+    """

     def __init__(
         self,
@@ -46,13 +61,33 @@ def __init__(


 class SingleThreadDataProducer:
+    """
+    Single-threaded data producer for synchronous query processing.
+
+    This producer processes queries on the main thread without any parallelism,
+    suitable for debugging or low-throughput scenarios.
+
+    Args:
+        ds: The dataset to fetch samples from.
+        run_one_item: Callback function to process a single QueryItem.
+    """
+
     def __init__(self, ds: Dataset, run_one_item) -> None:  # pyre-ignore [2]
         self.ds = ds
         self.run_one_item = run_one_item  # pyre-ignore [4]

     def enqueue(
         self, query_ids: List[int], content_ids: List[int], t0: float, dt_queue: float
     ) -> None:
+        """
+        Enqueue queries for immediate synchronous processing.
+
+        Args:
+            query_ids: List of unique query identifiers.
+            content_ids: List of content/sample identifiers to fetch.
+            t0: Timestamp when the query batch was created.
+            dt_queue: Time spent waiting in the queue.
+        """
         with torch.profiler.record_function("data batching"):
             t0_batching: float = time.time()
             samples: Union[Samples, List[Samples]] = self.ds.get_samples(content_ids)
@@ -81,10 +116,23 @@ def enqueue(
             self.run_one_item(query)

     def finish(self) -> None:
+        """Finalize the producer. No-op for single-threaded mode."""
         pass


 class MultiThreadDataProducer:
+    """
+    Multi-threaded data producer for parallel query processing.
+
+    Uses a thread pool to fetch and batch data in parallel with model inference,
+    improving throughput for high-load scenarios.
+
+    Args:
+        ds: The dataset to fetch samples from.
+        threads: Number of worker threads to use.
+        run_one_item: Callback function to process a single QueryItem.
+    """
+
     def __init__(
         self,
         ds: Dataset,
@@ -108,6 +156,14 @@ def __init__(
     def handle_tasks(
         self, tasks_queue: Queue[Optional[Tuple[List[int], List[int], float, float]]]
     ) -> None:
+        """
+        Worker thread main loop to process tasks from the queue.
+
+        Each worker maintains its own CUDA stream for parallel execution.
+
+        Args:
+            tasks_queue: Queue containing task tuples or None for termination.
+        """
         stream = torch.cuda.Stream()
         while True:
             query_and_content_ids = tasks_queue.get()
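The per-worker CUDA stream mentioned in the docstring lets each thread's batching and host-to-device copies overlap with inference on the default stream. A self-contained sketch of that pattern (not the commit's code; assumes a CUDA build):

import torch

def worker_loop() -> None:
    stream = torch.cuda.Stream()  # one private stream per worker thread
    with torch.cuda.stream(stream):
        # Stand-in for batching/copy work issued on the worker's stream.
        batch = torch.randn(1024, device="cuda")
    # Make the consumer stream wait before it reads `batch`.
    torch.cuda.current_stream().wait_stream(stream)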
@@ -147,10 +203,24 @@ def handle_tasks(
     def enqueue(
         self, query_ids: List[int], content_ids: List[int], t0: float, dt_queue: float
     ) -> None:
+        """
+        Enqueue queries for asynchronous processing by worker threads.
+
+        Args:
+            query_ids: List of unique query identifiers.
+            content_ids: List of content/sample identifiers to fetch.
+            t0: Timestamp when the query batch was created.
+            dt_queue: Time spent waiting in the queue.
+        """
         with torch.profiler.record_function("data batching"):
             self.tasks.put((query_ids, content_ids, t0, dt_queue))

     def finish(self) -> None:
+        """
+        Signal all worker threads to terminate and wait for completion.
+
+        Sends None to each worker to trigger graceful shutdown.
+        """
         for _ in self.workers:
             self.tasks.put(None)
         for worker in self.workers:
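The shutdown protocol described in finish() is the standard sentinel pattern: one None per worker, then join. A self-contained sketch:

import threading
from queue import Queue
from typing import Optional

tasks: "Queue[Optional[int]]" = Queue()

def worker() -> None:
    while True:
        item = tasks.get()
        if item is None:  # sentinel: exit gracefully
            break
        print("processing", item)

workers = [threading.Thread(target=worker) for _ in range(2)]
for w in workers:
    w.start()
for i in range(4):
    tasks.put(i)
for _ in workers:  # one sentinel per worker
    tasks.put(None)
for w in workers:
    w.join()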
