@@ -18,16 +18,10 @@
 from megatron.core import parallel_state
 from megatron.core.inference.contexts import StaticInferenceContext
 from megatron.core.inference.inference_request import InferenceRequest, Status
-from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
-    GPTInferenceWrapper,
-)
-from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
-    InferenceWrapperConfig,
-)
+from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
+from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
 from megatron.core.inference.sampling_params import SamplingParams
-from megatron.core.inference.text_generation_controllers.text_generation_controller import (
-    TextGenerationController,
-)
+from megatron.core.inference.text_generation_controllers.text_generation_controller import TextGenerationController
 
 
 class InferenceReporter(L.Callback):
@@ -61,9 +55,7 @@ def __init__(
         self.checkpoint_name = checkpoint_name
         self.dataset_name = dataset_name
         self.output_dir = os.path.join(output_dir, f"{checkpoint_name}-{dataset_name}")
-        self.inference_batch_times_seqlen_threshold = (
-            inference_batch_times_seqlen_threshold
-        )
+        self.inference_batch_times_seqlen_threshold = inference_batch_times_seqlen_threshold
         self.inference_params_dtype = inference_params_dtype
         self.inference_max_seq_length = inference_max_seq_length
         self.max_batch_size = max_batch_size
@@ -72,9 +64,7 @@ def __init__(
         self.sample_idx = 0
         self.text_generation_controller: TextGenerationController | None = None
 
-    def setup(
-        self, trainer: L.Trainer, pl_module: L.LightningModule, stage: str
-    ) -> None:
+    def setup(self, trainer: L.Trainer, pl_module: L.LightningModule, stage: str) -> None:
         pl_module.tokenizer.detokenize = pl_module.tokenizer.ids_to_text
 
         # Add noop methods to avoid exceptions - we don't need text processing
@@ -93,9 +83,7 @@ def on_validation_batch_end(
         L.seed_everything(self.random_seed)
 
         prompt_tokens = self._get_prompt_tokens(batch)
-        generated_tokens, prompt_logprobs, prompt_logits = self._run_inference(
-            pl_module, prompt_tokens
-        )
+        generated_tokens, prompt_logprobs, prompt_logits = self._run_inference(pl_module, prompt_tokens)
 
         input_text = pl_module.tokenizer.detokenize(prompt_tokens)
         generated_text = pl_module.tokenizer.detokenize(generated_tokens)
@@ -113,21 +101,13 @@ def on_validation_batch_end(
         self.sample_idx += 1
 
     def _get_prompt_tokens(self, batch: Any) -> list[int]:
-        assert len(batch["tokens"]) == 1, (
-            "Only one sample at a time generation supported at the moment"
-        )
+        assert len(batch["tokens"]) == 1, "Only one sample at a time generation supported at the moment"
         tokens = batch["tokens"][0]
 
         # Add the label token (last token from original sequence) to prompt_tokens
-        if (
-            torch.distributed.get_rank() == 0
-            and "labels" in batch
-            and len(batch["labels"]) > 0
-        ):
+        if torch.distributed.get_rank() == 0 and "labels" in batch and len(batch["labels"]) > 0:
             last_label = batch["labels"][0][-1].item()
-            tokens = torch.cat(
-                [tokens, torch.tensor([last_label], device=tokens.device)]
-            )
+            tokens = torch.cat([tokens, torch.tensor([last_label], device=tokens.device)])
 
         device = "cuda" if torch.cuda.is_available() else "cpu"
         seq_len = torch.tensor(
@@ -145,9 +125,7 @@ def _get_prompt_tokens(self, batch: Any) -> list[int]:
         torch.distributed.broadcast(tokens, src=0)
         return tokens.cpu().tolist()
 
-    def _get_inference_engine(
-        self, pl_module: L.LightningModule
-    ) -> TextGenerationController:
+    def _get_inference_engine(self, pl_module: L.LightningModule) -> TextGenerationController:
         if self.text_generation_controller is not None:
             return self.text_generation_controller
 
@@ -162,9 +140,7 @@ def _get_inference_engine(
         )
 
         inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
-        inference_wrapped_model = GPTInferenceWrapper(
-            pl_module.module, inference_wrapper_config, inference_context
-        )
+        inference_wrapped_model = GPTInferenceWrapper(pl_module.module, inference_wrapper_config, inference_context)
 
         self.text_generation_controller = TextGenerationController(
             inference_wrapped_model=inference_wrapped_model,
@@ -184,9 +160,9 @@ def _run_inference(
             status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS,
         )
 
-        results = self._get_inference_engine(
-            pl_module
-        ).generate_all_output_tokens_static_batch({request_id: inference_request})
+        results = self._get_inference_engine(pl_module).generate_all_output_tokens_static_batch(
+            {request_id: inference_request}
+        )
 
         result = results[request_id]
         generated_tokens = result.generated_tokens.tolist()
@@ -213,9 +189,7 @@ def _log_artifacts(
         ):
             return
 
-        artifact_path = (
-            f"inference/validation/step_{trainer.global_step}/batch_{batch_idx}"
-        )
+        artifact_path = f"inference/validation/step_{trainer.global_step}/batch_{batch_idx}"
         data_map = {
             "token_ids": generated_tokens,
             "prompt_logprobs": prompt_logprobs,
@@ -232,21 +206,15 @@ def _log_artifacts(
         )
 
         ctx = (
-            tempfile.TemporaryDirectory()
-            if has_logger
-            else nullcontext(os.path.join(self.output_dir, artifact_path))
+            tempfile.TemporaryDirectory() if has_logger else nullcontext(os.path.join(self.output_dir, artifact_path))
         )
         with ctx as base_dir:
             for dir_name, data in data_map.items():
                 if data:
                     dir_path = os.path.join(base_dir, dir_name)
                     os.makedirs(dir_path, exist_ok=True)
-                    file_path = os.path.join(
-                        dir_path, f"{dir_name}_{self.sample_idx}.json"
-                    )
+                    file_path = os.path.join(dir_path, f"{dir_name}_{self.sample_idx}.json")
                     with open(file_path, "w") as f:
                         json.dump(data, f)
                     if has_logger:
-                        trainer.logger.experiment.log_artifact(
-                            file_path, f"{artifact_path}/{dir_name}"
-                        )
+                        trainer.logger.experiment.log_artifact(file_path, f"{artifact_path}/{dir_name}")
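
For context, a minimal sketch of how the InferenceReporter callback above might be wired into a Lightning run. The constructor argument values here are hypothetical; only the argument names are inferred from the attributes assigned in __init__ (the full signature is not shown in this diff), and `model` stands in for whatever LightningModule is being validated:

import torch
import lightning as L

# Hypothetical values throughout; argument names mirror the attributes
# assigned in the __init__ body shown in the diff above.
reporter = InferenceReporter(
    checkpoint_name="my-checkpoint",
    dataset_name="my-dataset",
    output_dir="/tmp/inference_reports",
    inference_batch_times_seqlen_threshold=1000,
    inference_params_dtype=torch.bfloat16,
    inference_max_seq_length=4096,
    max_batch_size=1,  # _get_prompt_tokens asserts one sample at a time
)

trainer = L.Trainer(callbacks=[reporter])
trainer.validate(model)  # on_validation_batch_end runs inference and logs artifacts per batch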