# limitations under the License.

# pyre-strict
"""
Data producer module for DLRMv3 inference.

This module provides classes for producing and managing query data during inference,
supporting both single-threaded and multi-threaded data production modes.
"""

import logging
import threading
# ... remaining imports elided ...


class QueryItem:
31- """An item that we queue for processing by the thread pool."""
37+ """
38+ Container for a query item to be processed by the inference thread pool.
39+
40+ Attributes:
41+ query_ids: List of unique identifiers for the queries in this batch.
42+ samples: The sample data containing features for the queries.
43+ start: Time when the query was first received.
44+ dt_queue: Time spent in the queue before processing.
45+ dt_batching: Time spent on batching the data.
46+ """

    def __init__(
        self,
        # ... remaining parameters and body of __init__ elided ...


class SingleThreadDataProducer:
    """
    Single-threaded data producer for synchronous query processing.

    This producer processes queries on the main thread without any parallelism,
    suitable for debugging or low-throughput scenarios.

    Args:
        ds: The dataset to fetch samples from.
        run_one_item: Callback function to process a single QueryItem.
    """

    def __init__(self, ds: Dataset, run_one_item) -> None:  # pyre-ignore [2]
        self.ds = ds
        self.run_one_item = run_one_item  # pyre-ignore [4]

    def enqueue(
        self, query_ids: List[int], content_ids: List[int], t0: float, dt_queue: float
    ) -> None:
        """
        Enqueue queries for immediate synchronous processing.

        Args:
            query_ids: List of unique query identifiers.
            content_ids: List of content/sample identifiers to fetch.
            t0: Timestamp when the query batch was created.
            dt_queue: Time spent waiting in the queue.
        """
        with torch.profiler.record_function("data batching"):
            t0_batching: float = time.time()
            samples: Union[Samples, List[Samples]] = self.ds.get_samples(content_ids)
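            # Note: get_samples may return either a single batched Samples
            # object or a list of per-query Samples, per the Union type above.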
        # ... remainder of enqueue (QueryItem construction) elided ...
        self.run_one_item(query)

    def finish(self) -> None:
        """Finalize the producer. No-op for single-threaded mode."""
        pass

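# Illustrative usage sketch (not part of the original module): drives
# SingleThreadDataProducer with a stand-in dataset and a callback that reads
# the timing fields recorded on each QueryItem. `_StubDataset` and
# `_example_single_thread_usage` are hypothetical names; the stub only
# duck-types the Dataset interface used by the producer.
def _example_single_thread_usage() -> None:
    class _StubDataset:
        # Minimal stand-in: echoes the content ids back as the "samples".
        def get_samples(self, content_ids: List[int]) -> List[int]:
            return content_ids

    def run_one_item(query: QueryItem) -> None:
        # A real callback would run inference on query.samples; here we only
        # report the latency the producer recorded while batching.
        print(
            f"ids={query.query_ids} "
            f"queue={query.dt_queue:.4f}s batching={query.dt_batching:.4f}s"
        )

    producer = SingleThreadDataProducer(_StubDataset(), run_one_item)
    producer.enqueue(query_ids=[0, 1], content_ids=[10, 11], t0=time.time(), dt_queue=0.0)
    producer.finish()
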

class MultiThreadDataProducer:
    """
    Multi-threaded data producer for parallel query processing.

    Uses a thread pool to fetch and batch data in parallel with model inference,
    improving throughput for high-load scenarios.

    Args:
        ds: The dataset to fetch samples from.
        threads: Number of worker threads to use.
        run_one_item: Callback function to process a single QueryItem.
    """

    def __init__(
        self,
        ds: Dataset,
        # ... remaining parameters and body of __init__ elided ...

    def handle_tasks(
        self, tasks_queue: Queue[Optional[Tuple[List[int], List[int], float, float]]]
    ) -> None:
        """
        Worker thread main loop to process tasks from the queue.

        Each worker maintains its own CUDA stream for parallel execution.

        Args:
            tasks_queue: Queue containing task tuples or None for termination.
        """
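        # Each worker uses its own CUDA stream so that the data preparation it
        # performs can overlap with work running on other streams.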
        stream = torch.cuda.Stream()
        while True:
            query_and_content_ids = tasks_queue.get()
            # ... remainder of the worker loop elided ...

    def enqueue(
        self, query_ids: List[int], content_ids: List[int], t0: float, dt_queue: float
    ) -> None:
        """
        Enqueue queries for asynchronous processing by worker threads.

        Args:
            query_ids: List of unique query identifiers.
            content_ids: List of content/sample identifiers to fetch.
            t0: Timestamp when the query batch was created.
            dt_queue: Time spent waiting in the queue.
        """
        with torch.profiler.record_function("data batching"):
            self.tasks.put((query_ids, content_ids, t0, dt_queue))

    def finish(self) -> None:
        """
        Signal all worker threads to terminate and wait for completion.

        Sends None to each worker to trigger graceful shutdown.
        """
        for _ in self.workers:
            self.tasks.put(None)
        for worker in self.workers:
            worker.join()
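

# Illustrative usage sketch (not part of the original module): exercises
# MultiThreadDataProducer end to end. The stand-in dataset and the keyword
# arguments below follow the class docstring (ds, threads, run_one_item) but
# are assumptions, since part of __init__ is not shown here; a real harness
# would pass the benchmark Dataset and an inference callback. A CUDA-capable
# device is required because each worker creates a torch.cuda.Stream().
def _example_multi_thread_usage() -> None:
    results: List[QueryItem] = []
    lock = threading.Lock()

    class _StubDataset:
        # Minimal stand-in: echoes the content ids back as the "samples".
        def get_samples(self, content_ids: List[int]) -> List[int]:
            return content_ids

    def run_one_item(query: QueryItem) -> None:
        # run_one_item is invoked from worker threads, so shared state needs a lock.
        with lock:
            results.append(query)

    producer = MultiThreadDataProducer(_StubDataset(), threads=4, run_one_item=run_one_item)
    t0 = time.time()
    for batch_id in range(8):
        producer.enqueue(query_ids=[batch_id], content_ids=[batch_id], t0=t0, dt_queue=0.0)
    # finish() sends one None sentinel per worker and joins them, so every
    # enqueued batch is processed before it returns.
    producer.finish()
    print(f"processed {len(results)} batches")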