Commit 3a3d14e

Apply isort and black reformatting
Signed-off-by: clumsy <[email protected]>
1 parent 8c5f102 commit 3a3d14e

2 files changed: +24 additions, -68 deletions
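The commit is a pure formatting pass: isort regroups and normalizes the imports, and black collapses statements that had been hand-wrapped across several lines whenever they fit within the configured line width. This repository evidently allows lines well past black's default 88 characters, since several collapsed lines below run over 100 characters; the exact limit is not visible in this diff. A minimal before/after sketch with a throwaway stdlib-only example, using the tools' default settings rather than this repository's configuration:

# Before: hand-wrapped import and call, imports in no particular order.
import sys
import json
from os.path import (
    join
)


def dump(payload, directory):
    return json.dump(
        payload, open(join(directory, "out.json"), "w")
    )

# After running isort and then black with default settings: imports are sorted
# and grouped (straight imports before from-imports within the stdlib section),
# and the wrapped import and call collapse because each fits on one line.
import json
import sys
from os.path import join


def dump(payload, directory):
    return json.dump(payload, open(join(directory, "out.json"), "w"))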

nemo/lightning/pytorch/callbacks/inference_reporter.py

Lines changed: 18 additions & 50 deletions
@@ -18,16 +18,10 @@
 from megatron.core import parallel_state
 from megatron.core.inference.contexts import StaticInferenceContext
 from megatron.core.inference.inference_request import InferenceRequest, Status
-from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
-    GPTInferenceWrapper,
-)
-from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
-    InferenceWrapperConfig,
-)
+from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
+from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
 from megatron.core.inference.sampling_params import SamplingParams
-from megatron.core.inference.text_generation_controllers.text_generation_controller import (
-    TextGenerationController,
-)
+from megatron.core.inference.text_generation_controllers.text_generation_controller import TextGenerationController


 class InferenceReporter(L.Callback):
@@ -61,9 +55,7 @@ def __init__(
         self.checkpoint_name = checkpoint_name
         self.dataset_name = dataset_name
         self.output_dir = os.path.join(output_dir, f"{checkpoint_name}-{dataset_name}")
-        self.inference_batch_times_seqlen_threshold = (
-            inference_batch_times_seqlen_threshold
-        )
+        self.inference_batch_times_seqlen_threshold = inference_batch_times_seqlen_threshold
         self.inference_params_dtype = inference_params_dtype
         self.inference_max_seq_length = inference_max_seq_length
         self.max_batch_size = max_batch_size
@@ -72,9 +64,7 @@ def __init__(
         self.sample_idx = 0
         self.text_generation_controller: TextGenerationController | None = None

-    def setup(
-        self, trainer: L.Trainer, pl_module: L.LightningModule, stage: str
-    ) -> None:
+    def setup(self, trainer: L.Trainer, pl_module: L.LightningModule, stage: str) -> None:
         pl_module.tokenizer.detokenize = pl_module.tokenizer.ids_to_text

         # Add noop methods to avoid exceptions - we don't need text processing
@@ -93,9 +83,7 @@ def on_validation_batch_end(
         L.seed_everything(self.random_seed)

         prompt_tokens = self._get_prompt_tokens(batch)
-        generated_tokens, prompt_logprobs, prompt_logits = self._run_inference(
-            pl_module, prompt_tokens
-        )
+        generated_tokens, prompt_logprobs, prompt_logits = self._run_inference(pl_module, prompt_tokens)

         input_text = pl_module.tokenizer.detokenize(prompt_tokens)
         generated_text = pl_module.tokenizer.detokenize(generated_tokens)
@@ -113,21 +101,13 @@ def on_validation_batch_end(
         self.sample_idx += 1

     def _get_prompt_tokens(self, batch: Any) -> list[int]:
-        assert len(batch["tokens"]) == 1, (
-            "Only one sample at a time generation supported at the moment"
-        )
+        assert len(batch["tokens"]) == 1, "Only one sample at a time generation supported at the moment"
         tokens = batch["tokens"][0]

         # Add the label token (last token from original sequence) to prompt_tokens
-        if (
-            torch.distributed.get_rank() == 0
-            and "labels" in batch
-            and len(batch["labels"]) > 0
-        ):
+        if torch.distributed.get_rank() == 0 and "labels" in batch and len(batch["labels"]) > 0:
             last_label = batch["labels"][0][-1].item()
-            tokens = torch.cat(
-                [tokens, torch.tensor([last_label], device=tokens.device)]
-            )
+            tokens = torch.cat([tokens, torch.tensor([last_label], device=tokens.device)])

         device = "cuda" if torch.cuda.is_available() else "cpu"
         seq_len = torch.tensor(
@@ -145,9 +125,7 @@ def _get_prompt_tokens(self, batch: Any) -> list[int]:
         torch.distributed.broadcast(tokens, src=0)
         return tokens.cpu().tolist()

-    def _get_inference_engine(
-        self, pl_module: L.LightningModule
-    ) -> TextGenerationController:
+    def _get_inference_engine(self, pl_module: L.LightningModule) -> TextGenerationController:
         if self.text_generation_controller is not None:
             return self.text_generation_controller

@@ -162,9 +140,7 @@ def _get_inference_engine(
         )

         inference_context = StaticInferenceContext.from_config(inference_wrapper_config)
-        inference_wrapped_model = GPTInferenceWrapper(
-            pl_module.module, inference_wrapper_config, inference_context
-        )
+        inference_wrapped_model = GPTInferenceWrapper(pl_module.module, inference_wrapper_config, inference_context)

         self.text_generation_controller = TextGenerationController(
             inference_wrapped_model=inference_wrapped_model,
@@ -184,9 +160,9 @@ def _run_inference(
             status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS,
         )

-        results = self._get_inference_engine(
-            pl_module
-        ).generate_all_output_tokens_static_batch({request_id: inference_request})
+        results = self._get_inference_engine(pl_module).generate_all_output_tokens_static_batch(
+            {request_id: inference_request}
+        )

         result = results[request_id]
         generated_tokens = result.generated_tokens.tolist()
@@ -213,9 +189,7 @@ def _log_artifacts(
         ):
             return

-        artifact_path = (
-            f"inference/validation/step_{trainer.global_step}/batch_{batch_idx}"
-        )
+        artifact_path = f"inference/validation/step_{trainer.global_step}/batch_{batch_idx}"
         data_map = {
             "token_ids": generated_tokens,
             "prompt_logprobs": prompt_logprobs,
@@ -232,21 +206,15 @@ def _log_artifacts(
         )

         ctx = (
-            tempfile.TemporaryDirectory()
-            if has_logger
-            else nullcontext(os.path.join(self.output_dir, artifact_path))
+            tempfile.TemporaryDirectory() if has_logger else nullcontext(os.path.join(self.output_dir, artifact_path))
         )
         with ctx as base_dir:
             for dir_name, data in data_map.items():
                 if data:
                     dir_path = os.path.join(base_dir, dir_name)
                     os.makedirs(dir_path, exist_ok=True)
-                    file_path = os.path.join(
-                        dir_path, f"{dir_name}_{self.sample_idx}.json"
-                    )
+                    file_path = os.path.join(dir_path, f"{dir_name}_{self.sample_idx}.json")
                     with open(file_path, "w") as f:
                         json.dump(data, f)
                     if has_logger:
-                        trainer.logger.experiment.log_artifact(
-                            file_path, f"{artifact_path}/{dir_name}"
-                        )
+                        trainer.logger.experiment.log_artifact(file_path, f"{artifact_path}/{dir_name}")

tests/lightning/pytorch/callbacks/test_inference_reporter.py

Lines changed: 6 additions & 18 deletions
@@ -97,9 +97,7 @@ def test_run_inference(callback, mock_pl_module):
     mock_result.logits = None

     mock_controller = Mock()
-    mock_controller.generate_all_output_tokens_static_batch.return_value = {
-        "0": mock_result
-    }
+    mock_controller.generate_all_output_tokens_static_batch.return_value = {"0": mock_result}
     callback.text_generation_controller = mock_controller

     tokens, logprobs, logits = callback._run_inference(mock_pl_module, [1, 2, 3])
@@ -115,9 +113,7 @@ def test_run_inference(callback, mock_pl_module):


 @patch(f"{inference_reporter.__name__}.parallel_state")
-def test_log_artifacts_skips_non_primary_ranks(
-    mock_parallel_state, callback, mock_trainer
-):
+def test_log_artifacts_skips_non_primary_ranks(mock_parallel_state, callback, mock_trainer):
     mock_parallel_state.get_tensor_model_parallel_rank.return_value = 1
     mock_parallel_state.get_data_parallel_rank.return_value = 0

@@ -127,9 +123,7 @@ def test_log_artifacts_skips_non_primary_ranks(


 @patch(f"{inference_reporter.__name__}.parallel_state")
-def test_log_artifacts_logs_on_primary_rank(
-    mock_parallel_state, callback, mock_trainer
-):
+def test_log_artifacts_logs_on_primary_rank(mock_parallel_state, callback, mock_trainer):
     mock_parallel_state.get_tensor_model_parallel_rank.return_value = 0
     mock_parallel_state.get_data_parallel_rank.return_value = 0
     callback.sample_idx = 5
@@ -154,9 +148,7 @@ def test_log_artifacts_logs_on_primary_rank(


 @patch(f"{inference_reporter.__name__}.parallel_state")
-def test_log_artifacts_saves_to_disk_without_logger(
-    mock_parallel_state, callback, mock_trainer, tmp_path
-):
+def test_log_artifacts_saves_to_disk_without_logger(mock_parallel_state, callback, mock_trainer, tmp_path):
     mock_parallel_state.get_tensor_model_parallel_rank.return_value = 0
     mock_parallel_state.get_data_parallel_rank.return_value = 0
     mock_trainer.logger = None
@@ -199,18 +191,14 @@ def test_on_validation_batch_end_integration(
     mock_result.logits = None

     mock_controller = Mock()
-    mock_controller.generate_all_output_tokens_static_batch.return_value = {
-        "0": mock_result
-    }
+    mock_controller.generate_all_output_tokens_static_batch.return_value = {"0": mock_result}
     callback.text_generation_controller = mock_controller

     batch = {"tokens": [torch.tensor([1, 2, 3])], "labels": [torch.tensor([4, 5, 6])]}

     with patch("torch.cuda.is_available", return_value=False):
         with patch("lightning.seed_everything"):
-            callback.on_validation_batch_end(
-                mock_trainer, mock_pl_module, None, batch, 0, 0
-            )
+            callback.on_validation_batch_end(mock_trainer, mock_pl_module, None, batch, 0, 0)

     assert callback.sample_idx == 1
     assert mock_trainer.logger.experiment.log_artifact.call_count > 0
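One detail that explains why some multi-line constructs collapse in a pass like this while others survive: black only collapses a construct that fits within the configured line width, and by default it honors a "magic trailing comma", keeping a call or collection exploded if it already ends in a trailing comma. The call sites collapsed in this commit carried no trailing commas; the multi-line imports did, but imports are exactly what isort normalizes onto a single line when they fit (and a project can also configure black to skip the magic trailing comma). A minimal sketch of the distinction, assuming black's default settings:

# Both calls fit on one line within black's default line width.
def collapsed(a, b):
    # No trailing comma: black rewrites this as `return max(a, b)`.
    return max(
        a, b
    )


def kept(a, b):
    # Trailing comma after b: the exploded layout is preserved.
    return max(
        a,
        b,
    )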
