
Commit 6eca242

feat: introduce ScopedModuleOffloading in KD to reduce memory usage (#774)
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent 9995e4a commit 6eca242

3 files changed (+110, -7 lines)


nemo_automodel/components/training/utils.py

Lines changed: 26 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc
 import math
 from typing import Iterable

@@ -299,3 +300,28 @@ def scale_grads_and_clip_grad_norm(
         pp_axis_name=pp_axis_name,
         foreach=foreach,
     )
+
+
+def move_to_device(model, device):
+    # FSDP modules do not move buffers to the device automatically
+    for v in model.buffers():
+        v.data = v.data.to(device)
+    model.to(device)
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+class ScopedModuleOffloading:
+    def __init__(self, model, enabled=False):
+        self.model = model
+        self.enabled = enabled
+
+    def __enter__(self):
+        if self.enabled:
+            move_to_device(self.model, "cuda")
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.enabled:
+            move_to_device(self.model, "cpu")
+        return False  # Re-raise exceptions by default
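
For orientation, a minimal usage sketch of the new context manager (not part of the commit; the toy nn.Linear stands in for a real model, and the behavior shown follows the implementation above):

    import torch
    import torch.nn as nn

    from nemo_automodel.components.training.utils import ScopedModuleOffloading

    model = nn.Linear(8, 8)  # toy stand-in for a large model; starts on the CPU

    with ScopedModuleOffloading(model, enabled=torch.cuda.is_available()):
        # With offloading enabled, parameters and buffers live on the GPU here.
        x = torch.randn(2, 8, device=model.weight.device)
        y = model(x)

    # On exit (if enabled) the weights are moved back to the CPU and the CUDA
    # cache is emptied; with enabled=False the context manager is a no-op.
    assert model.weight.device.type == "cpu"

Exceptions raised inside the scope still propagate (__exit__ returns False), but the model is moved back to the CPU first.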

nemo_automodel/recipes/llm/kd.py

Lines changed: 12 additions & 7 deletions
@@ -51,7 +51,7 @@
 from nemo_automodel.components.loggers.metric_logger import MetricsSample
 from nemo_automodel.components.loss.linear_ce import FusedLinearCrossEntropy
 from nemo_automodel.components.training.rng import ScopedRNG
-from nemo_automodel.components.training.utils import count_tail_padding
+from nemo_automodel.components.training.utils import ScopedModuleOffloading, count_tail_padding
 from nemo_automodel.recipes.llm.train_ft import (
     TrainFinetuneRecipeForNextTokenPrediction,
     calculate_loss,
@@ -134,12 +134,14 @@ def setup(self): # noqa: C901 – same complexity as parent
         if self.pp_enabled:
             raise ValueError("Pipeline parallelism support will be added in the future for knowledge distillation")

+        self._offload_teacher_model = self.cfg.get("offload_teacher_model", False)
         # teacher specific
+        teacher_device = self.dist_env.device if not self._offload_teacher_model else "cpu"
         self.teacher_model = _build_teacher_model(
             self.cfg.get("teacher_model", None),
             self.cfg.get("seed", 42),
             self.cfg.get("packed_sequence.packed_sequence_size", 0) > 0,
-            self.dist_env.device,
+            teacher_device,
             self.model_wrapper,
             self.device_mesh,
         )
@@ -173,6 +175,14 @@ def _forward_backward_step(
         model = self.model_parts[0]
         sync_ctx = get_sync_ctx(model, idx == num_batches - 1) if is_train else nullcontext()
         with train_ctx(), sync_ctx:
+            # No grad for teacher forward
+            with (
+                ScopedModuleOffloading(self.teacher_model, enabled=self._offload_teacher_model),
+                torch.inference_mode(),
+            ):
+                teacher_logits = self.teacher_model(**batch)
+                teacher_logits = getattr(teacher_logits, "logits", teacher_logits).detach().clone()
+
             # Student forward
             student_keep_last = isinstance(self.loss_fn, FusedLinearCrossEntropy)
             if student_keep_last:
@@ -191,11 +201,6 @@ def _forward_backward_step(
                 hidden_states=student_out.hidden_states[-1] if "hidden_states" in student_out else None,
                 num_label_tokens=num_label_tokens,
             )
-            # No grad for teacher forward
-            with torch.no_grad():
-                teacher_logits = self.teacher_model(**batch)
-                teacher_logits = getattr(teacher_logits, "logits", teacher_logits).detach()
-
             # Reminder: kd_loss is normalized by num_label_tokens,
             # which typically is larger than the number of labels in this batch,
             # because it contains the total number of labels for all batches contained
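
Net effect: with the new offload_teacher_model config key enabled (YAML spelling assumed), the teacher is built on the CPU, moved to the GPU only for its no-grad forward pass, and moved back before the student's forward/backward, so teacher weights and student activations never need to fit in GPU memory at the same time. A self-contained sketch of that per-step pattern (not the recipe's code; the toy modules, batch, and mse_loss below are illustrative stand-ins, and the clone is taken outside inference mode so the logits are ordinary autograd-usable tensors):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from nemo_automodel.components.training.utils import ScopedModuleOffloading

    offload = torch.cuda.is_available()  # plays the role of cfg.get("offload_teacher_model", False)
    device = "cuda" if offload else "cpu"

    teacher = nn.Linear(16, 16)               # stays on the CPU between steps when offloading
    student = nn.Linear(16, 16).to(device)
    batch = torch.randn(4, 16, device=device)

    # Teacher forward: the teacher's weights occupy the GPU only inside this block.
    with ScopedModuleOffloading(teacher, enabled=offload), torch.inference_mode():
        raw = teacher(batch)

    # Clone outside inference mode to get a normal tensor for the loss graph.
    teacher_logits = raw.clone()

    # Student forward/backward: the teacher's weights are back on the CPU by now.
    loss = F.mse_loss(student(batch), teacher_logits)
    loss.backward()

The price is two host-device copies of the teacher weights per step; per the commit title, the goal is lower peak GPU memory, not throughput.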

tests/unit_tests/training/test_train_utils.py

Lines changed: 72 additions & 0 deletions
@@ -17,6 +17,11 @@
 import pytest
 import torch

+import pytest
+import torch
+import torch.nn as nn
+
+from nemo_automodel.components.training.utils import move_to_device, ScopedModuleOffloading
 from nemo_automodel.components.training.utils import clip_grad_norm, count_tail_padding


@@ -127,3 +132,70 @@ def test_clip_grad_norm_returns_zero_when_max_grad_norm_is_none():
     )

     assert grad_norm == 0
+
+
+class _TinyModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(4, 2, bias=False)
+        self.register_buffer("scale", torch.ones(1))
+
+
+def _all_tensors_on_device(module: nn.Module, device_type: str) -> bool:
+    for p in module.parameters():
+        if p.device.type != device_type:
+            return False
+    for b in module.buffers():
+        if b.device.type != device_type:
+            return False
+    return True
+
+
+def test_move_to_device_cpu():
+    model = _TinyModule()
+    # Ensure starts on CPU
+    assert _all_tensors_on_device(model, "cpu")
+
+    # Move to CPU (idempotent)
+    move_to_device(model, "cpu")
+    assert _all_tensors_on_device(model, "cpu")
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_move_to_device_cuda():
+    model = _TinyModule()
+    # Move to CUDA
+    move_to_device(model, "cuda")
+    assert _all_tensors_on_device(model, "cuda")
+
+    # Move back to CPU to leave environment clean
+    move_to_device(model, "cpu")
+    assert _all_tensors_on_device(model, "cpu")
+
+
+def test_scoped_offloading_disabled_noop_and_reraises():
+    model = _TinyModule()
+    assert _all_tensors_on_device(model, "cpu")
+
+    with pytest.raises(ValueError):
+        with ScopedModuleOffloading(model, enabled=False):
+            # Should not move devices and should re-raise exceptions
+            assert _all_tensors_on_device(model, "cpu")
+            raise ValueError("boom")
+
+    # After context, still on CPU
+    assert _all_tensors_on_device(model, "cpu")
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_scoped_offloading_enabled_moves_and_reraises():
+    model = _TinyModule()
+    assert _all_tensors_on_device(model, "cpu")
+
+    # Enter moves to CUDA, exit moves back to CPU and re-raises exceptions
+    with pytest.raises(RuntimeError):
+        with ScopedModuleOffloading(model, enabled=True):
+            assert _all_tensors_on_device(model, "cuda")
+            raise RuntimeError("fail inside context")
+
+    assert _all_tensors_on_device(model, "cpu")
