Skip to content

Commit a8c5286

Browse files
Andrew Gupytorchmergebot
authored andcommitted
[FSDP][6/N] Check valid param freezing for ModuleWrapPolicy (pytorch#104427)
This PR adds improved error/warning messaging when auto wrapping with `ModuleWrapPolicy` in the presence of frozen parameters. - For `use_orig_params=False`, FSDP requires uniform `requires_grad` for each FSDP instance. This PR adds a `ValueError` at wrapping time with a message that mentions the violating module and the frozen/non-frozen parameter names. - For `use_orig_params=True`, FSDP allows non-uniform `requires_grad` for each FSDP instance. However, it will result in higher-than-expected gradient memory usage. This PR adds a `UserWarning` at wrapping time with a message that mentions the violating module, how much extra gradient memory will be used (in units of numel), and the frozen/non-frozen parameter names. - There is a possibility that this warning will be spammy/verbose, but my current thinking is that it is okay for now unless users complain. <details> <summary> Why DFS via named_children() vs. Using named_modules()</summary> ``` LoraModel( (embed_tokens): Embedding(100, 32) (layers): ModuleList( (0-3): 4 x LoraDecoder( (attn): LoraAttention( (q_proj): Linear(in_features=32, out_features=32, bias=False) (lora_A): Linear(in_features=32, out_features=8, bias=False) (lora_B): Linear(in_features=8, out_features=32, bias=False) (k_proj): Linear(in_features=32, out_features=32, bias=False) (v_proj): Linear(in_features=32, out_features=32, bias=False) (o_proj): Linear(in_features=32, out_features=32, bias=False) ) (mlp): LoraMLP( (proj1): Linear(in_features=32, out_features=128, bias=False) (proj2): Linear(in_features=128, out_features=32, bias=False) ) (inp_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True) (post_attn_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True) ) ) (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True) ) ``` Reverse topological order with stack-based DFS via `named_children()`: ``` [ 'embed_tokens', 'layers.0.attn.q_proj', 'layers.0.attn.lora_A', 'layers.0.attn.lora_B', 'layers.0.attn.k_proj', 'layers.0.attn.v_proj', 'layers.0.attn.o_proj', 'layers.0.attn', 'layers.0.mlp.proj1', 'layers.0.mlp.proj2', 'layers.0.mlp', 'layers.0.inp_layernorm', 'layers.0.post_attn_layernorm', 'layers.0', 'layers.1.attn.q_proj', 'layers.1.attn.lora_A', 'layers.1.attn.lora_B', 'layers.1.attn.k_proj', 'layers.1.attn.v_proj', 'layers.1.attn.o_proj', 'layers.1.attn', 'layers.1.mlp.proj1', 'layers.1.mlp.proj2', 'layers.1.mlp', 'layers.1.inp_layernorm', 'layers.1.post_attn_layernorm', 'layers.1', 'layers.2.attn.q_proj', 'layers.2.attn.lora_A', 'layers.2.attn.lora_B', 'layers.2.attn.k_proj', 'layers.2.attn.v_proj', 'layers.2.attn.o_proj', 'layers.2.attn', 'layers.2.mlp.proj1', 'layers.2.mlp.proj2', 'layers.2.mlp', 'layers.2.inp_layernorm', 'layers.2.post_attn_layernorm', 'layers.2', 'layers.3.attn.q_proj', 'layers.3.attn.lora_A', 'layers.3.attn.lora_B', 'layers.3.attn.k_proj', 'layers.3.attn.v_proj', 'layers.3.attn.o_proj', 'layers.3.attn', 'layers.3.mlp.proj1', 'layers.3.mlp.proj2', 'layers.3.mlp', 'layers.3.inp_layernorm', 'layers.3.post_attn_layernorm', 'layers.3', 'layers', 'norm', '' ] ``` Reverse topological order with `named_modules()`: ``` [ 'norm', 'layers.3.post_attn_layernorm', 'layers.3.inp_layernorm', 'layers.3.mlp.proj2', 'layers.3.mlp.proj1', 'layers.3.mlp', 'layers.3.attn.o_proj', 'layers.3.attn.v_proj', 'layers.3.attn.k_proj', 'layers.3.attn.lora_B', 'layers.3.attn.lora_A', 'layers.3.attn.q_proj', 'layers.3.attn', 'layers.3', 'layers.2.post_attn_layernorm', 'layers.2.inp_layernorm', 'layers.2.mlp.proj2', 'layers.2.mlp.proj1', 'layers.2.mlp', 'layers.2.attn.o_proj', 'layers.2.attn.v_proj', 'layers.2.attn.k_proj', 'layers.2.attn.lora_B', 'layers.2.attn.lora_A', 'layers.2.attn.q_proj', 'layers.2.attn', 'layers.2', 'layers.1.post_attn_layernorm', 'layers.1.inp_layernorm', 'layers.1.mlp.proj2', 'layers.1.mlp.proj1', 'layers.1.mlp', 'layers.1.attn.o_proj', 'layers.1.attn.v_proj', 'layers.1.attn.k_proj', 'layers.1.attn.lora_B', 'layers.1.attn.lora_A', 'layers.1.attn.q_proj', 'layers.1.attn', 'layers.1', 'layers.0.post_attn_layernorm', 'layers.0.inp_layernorm', 'layers.0.mlp.proj2', 'layers.0.mlp.proj1', 'layers.0.mlp', 'layers.0.attn.o_proj', 'layers.0.attn.v_proj', 'layers.0.attn.k_proj', 'layers.0.attn.lora_B', 'layers.0.attn.lora_A', 'layers.0.attn.q_proj', 'layers.0.attn', 'layers.0', 'layers', 'embed_tokens', '' ] ``` With the stack-based DFS via `named_children()`, reversing the topological order gives us each level in the module tree in the registered order, wheres with `named_modules()`, reversing the topological order gives us each level in reverse. Both are valid orders, but we prefer the former since it allows us to error/warn on the _first-registered_ module that violates the frozen/non-frozen condition. </details> Pull Request resolved: pytorch#104427 Approved by: https://github.com/ezyang
1 parent aec8418 commit a8c5286

File tree

2 files changed

+309
-1
lines changed

2 files changed

+309
-1
lines changed

test/distributed/fsdp/test_wrap.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
import torch.nn as nn
1212
import torch.nn.functional as F
13+
from torch.distributed.fsdp._wrap_utils import _validate_frozen_params
1314
from torch.distributed.fsdp.fully_sharded_data_parallel import (
1415
BackwardPrefetch,
1516
CPUOffload,
@@ -57,6 +58,56 @@ def __init__(self):
5758
self.sync_bn = nn.SyncBatchNorm(10)
5859

5960

61+
class LoraModel(nn.Module):
62+
"""This is a toy LoRA decoder model."""
63+
64+
def __init__(self):
65+
super().__init__()
66+
self.embed_tokens = nn.Embedding(100, 32)
67+
self.layers = nn.ModuleList([LoraDecoder() for _ in range(4)])
68+
self.norm = nn.LayerNorm(32)
69+
self.embed_tokens.weight.requires_grad_(False)
70+
self.norm.weight.requires_grad_(False)
71+
self.norm.bias.requires_grad_(False)
72+
73+
74+
class LoraDecoder(nn.Module):
75+
def __init__(self):
76+
super().__init__()
77+
self.attn = LoraAttention()
78+
self.mlp = LoraMLP()
79+
self.inp_layernorm = nn.LayerNorm(32)
80+
self.post_attn_layernorm = nn.LayerNorm(32)
81+
self.inp_layernorm.weight.requires_grad_(False)
82+
self.inp_layernorm.bias.requires_grad_(False)
83+
self.post_attn_layernorm.weight.requires_grad_(False)
84+
self.post_attn_layernorm.bias.requires_grad_(False)
85+
86+
87+
class LoraAttention(nn.Module):
88+
def __init__(self):
89+
super().__init__()
90+
self.q_proj = nn.Linear(32, 32, bias=False)
91+
self.lora_A = nn.Linear(32, 8, bias=False)
92+
self.lora_B = nn.Linear(8, 32, bias=False)
93+
self.k_proj = nn.Linear(32, 32, bias=False)
94+
self.v_proj = nn.Linear(32, 32, bias=False)
95+
self.o_proj = nn.Linear(32, 32, bias=False)
96+
self.q_proj.weight.requires_grad_(False)
97+
self.k_proj.weight.requires_grad_(False)
98+
self.v_proj.weight.requires_grad_(False)
99+
self.o_proj.weight.requires_grad_(False)
100+
101+
102+
class LoraMLP(nn.Module):
103+
def __init__(self):
104+
super().__init__()
105+
self.proj1 = nn.Linear(32, 128, bias=False)
106+
self.proj2 = nn.Linear(128, 32, bias=False)
107+
self.proj1.weight.requires_grad_(False)
108+
self.proj2.weight.requires_grad_(False)
109+
110+
60111
class WrapMethod(Enum):
61112
FSDP_CTOR = auto()
62113
# FSDP_CTOR is the supported way forward, but keep WRAP_API in case we miss
@@ -650,6 +701,116 @@ def test_auto_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
650701
self.assertTrue(isinstance(model.module[2][0], nn.Linear))
651702
self.assertTrue(isinstance(model.module[2][1], nn.Linear))
652703

704+
@unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
705+
def test_frozen_params(self):
706+
"""
707+
Tests that mixing frozen/non-frozen parameters in an FSDP instance
708+
raises for ``use_orig_params=False`` and warns for ``True``.
709+
"""
710+
for use_orig_params in [True, False]:
711+
self._test_frozen_params(use_orig_params)
712+
713+
def _test_frozen_params(self, use_orig_params: bool):
714+
model = LoraModel().cuda()
715+
policy = ModuleWrapPolicy({LoraAttention, LoraMLP, LoraDecoder})
716+
msg = "layers.0.attn has both parameters with requires_grad=True and False. "
717+
if use_orig_params:
718+
msg += "We do not recommend wrapping such modules"
719+
ctx = self.assertWarnsRegex(UserWarning, msg)
720+
else:
721+
msg += "FSDP does not support wrapping such modules when use_orig_params=False."
722+
ctx = self.assertRaisesRegex(ValueError, msg)
723+
with ctx:
724+
FSDP(
725+
model,
726+
process_group=self.process_group,
727+
auto_wrap_policy=policy,
728+
use_orig_params=use_orig_params,
729+
)
730+
731+
732+
class TestWrapUtils(TestCase):
733+
def test_validate_frozen_params(self):
734+
"""Tests the method ``_validate_frozen_params()``."""
735+
for use_orig_params in [True, False]:
736+
self._test_validate_frozen_params(use_orig_params)
737+
738+
def _test_validate_frozen_params(self, use_orig_params: bool):
739+
model = LoraModel()
740+
# Wrap only LoRA modules
741+
modules_to_wrap = {
742+
module
743+
for module_name, module in model.named_modules()
744+
if "lora_A" in module_name or "lora_B" in module_name
745+
}
746+
_validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
747+
# Additionally wrap attention
748+
for module in model.modules():
749+
if isinstance(module, LoraAttention):
750+
modules_to_wrap.add(module)
751+
_validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
752+
# Additionally wrap decoders
753+
for module in model.modules():
754+
if isinstance(module, LoraDecoder):
755+
modules_to_wrap.add(module)
756+
_validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
757+
# Do not wrap the LoRA-A modules (meaning mixed frozen/non-frozen)
758+
for module_name, module in model.named_modules():
759+
if "lora_A" in module_name:
760+
modules_to_wrap.remove(module)
761+
regex = "layers.0.attn has both parameters with requires_grad=True and False."
762+
if use_orig_params:
763+
# Wrapping the attention manages all parameters except those from
764+
# the LoRA-B module, which is separately wrapped and all nonfrozen
765+
lorab_numel = sum(
766+
p.numel() for p in model.layers[0].attn.lora_B.parameters()
767+
)
768+
attn_frozen_param_numel = sum(
769+
p.numel()
770+
for p in model.layers[0].attn.parameters()
771+
if not p.requires_grad
772+
)
773+
attn_nonfrozen_param_numel = (
774+
sum(
775+
p.numel()
776+
for p in model.layers[0].attn.parameters()
777+
if p.requires_grad
778+
)
779+
- lorab_numel
780+
)
781+
attn_total_param_numel = (
782+
attn_frozen_param_numel + attn_nonfrozen_param_numel
783+
)
784+
regex += (
785+
" We do not recommend wrapping such modules since the "
786+
r"gradient memory usage will be higher than expected \("
787+
f"{attn_total_param_numel} numel instead of {attn_nonfrozen_param_numel} numel "
788+
r"before sharding via reduce-scatter\). "
789+
)
790+
else:
791+
regex += " FSDP does not support wrapping such modules when use_orig_params=False. "
792+
regex += "If possible, wrap the frozen parameters with FSDP separately.\n"
793+
regex += (
794+
"The following parameters have requires_grad=True:\n"
795+
r"\['layers.0.attn.lora_A.weight'\]\n"
796+
"The following parameters have requires_grad=False:\n"
797+
r"\['layers.0.attn.q_proj.weight', 'layers.0.attn.k_proj.weight', "
798+
r"'layers.0.attn.v_proj.weight', 'layers.0.attn.o_proj.weight'\]"
799+
)
800+
if use_orig_params:
801+
ctx = self.assertWarnsRegex(UserWarning, regex)
802+
else:
803+
ctx = self.assertRaisesRegex(ValueError, regex)
804+
with ctx:
805+
_validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
806+
# Now ignore those LoRA-A modules' parameters
807+
ignored_params = set()
808+
for module_name, module in model.named_modules():
809+
if "lora_A" in module_name:
810+
for param in module.parameters():
811+
ignored_params.add(param)
812+
_validate_frozen_params(model, modules_to_wrap, ignored_params, use_orig_params)
813+
653814

654815
instantiate_parametrized_tests(TestFSDPWrap)
655816
instantiate_parametrized_tests(TestAutoWrap)

torch/distributed/fsdp/_wrap_utils.py

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import collections
12
import functools
23
import inspect
34
import warnings
45
from functools import partial
5-
from typing import Any, Callable, Dict, Set, Type, Union
6+
from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union
67

78
import torch.nn as nn
89
from torch.distributed.fsdp._common_utils import (
@@ -64,6 +65,13 @@ def _auto_wrap(
6465
root_module, mixed_precision._module_classes_to_ignore
6566
)
6667
_warn_on_overridden_mixed_precision(overridden_module_classes)
68+
use_orig_params = fsdp_kwargs.get("use_orig_params", False)
69+
_validate_frozen_params(
70+
root_module,
71+
set(target_module_to_kwargs.keys()),
72+
ignored_params,
73+
use_orig_params,
74+
)
6775
wrap_fn = _construct_wrap_fn(root_module, target_module_to_kwargs, fsdp_fn)
6876
_post_order_apply(root_module, wrap_fn)
6977
return
@@ -121,3 +129,142 @@ def _warn_on_overridden_mixed_precision(
121129
"These modules will be wrapped as separate FSDP instacnes with mixed "
122130
"precision disabled."
123131
)
132+
133+
134+
def _validate_frozen_params(
135+
root_module: nn.Module,
136+
modules_to_wrap: Set[nn.Module],
137+
ignored_params: Set[nn.Parameter],
138+
use_orig_params: bool,
139+
):
140+
"""
141+
This checks that, given ``modules_to_wrap``, each module would manage
142+
parameters that are uniformly frozen or non-frozen. This uniformity
143+
requirement is strict for ``use_orig_params=False`` (hard error) and highly
144+
recommended for ``use_orig_params=True`` (user warning).
145+
"""
146+
post_order_named_modules = _get_post_order_named_modules(root_module)
147+
visited_modules: Set[nn.Module] = set()
148+
for module_name, module in post_order_named_modules:
149+
if module in modules_to_wrap:
150+
param_to_fqn = _get_managed_param_to_fqn(
151+
module, ignored_params, visited_modules, module_name
152+
)
153+
frozen_param_fqns: List[str] = []
154+
frozen_param_numel = 0
155+
nonfrozen_param_fqns: List[str] = []
156+
nonfrozen_param_numel = 0
157+
for param, fqn in param_to_fqn.items():
158+
if param.requires_grad:
159+
nonfrozen_param_fqns.append(fqn)
160+
nonfrozen_param_numel += param.numel()
161+
else:
162+
frozen_param_fqns.append(fqn)
163+
frozen_param_numel += param.numel()
164+
if len(frozen_param_fqns) > 0 and len(nonfrozen_param_fqns) > 0:
165+
msg = f"{module_name} has both parameters with requires_grad=True and False."
166+
if use_orig_params:
167+
total_param_numel = frozen_param_numel + nonfrozen_param_numel
168+
msg += (
169+
" We do not recommend wrapping such modules since "
170+
"the gradient memory usage will be higher than expected "
171+
f"({total_param_numel} numel instead of {nonfrozen_param_numel} numel "
172+
"before sharding via reduce-scatter). "
173+
)
174+
else:
175+
msg += " FSDP does not support wrapping such modules when use_orig_params=False. "
176+
msg += "If possible, wrap the frozen parameters with FSDP separately.\n"
177+
msg += (
178+
f"The following parameters have requires_grad=True:\n{nonfrozen_param_fqns}\n"
179+
f"The following parameters have requires_grad=False:\n{frozen_param_fqns}"
180+
)
181+
if use_orig_params:
182+
warnings.warn(msg)
183+
else:
184+
raise ValueError(msg)
185+
186+
187+
def _get_post_order_named_modules(
188+
root_module: nn.Module,
189+
) -> List[Tuple[str, nn.Module]]:
190+
"""
191+
This returns the named modules following a post-order traversal, which is a
192+
valid reverse topological sort. We achieve this using the reverse of a
193+
stack-based DFS order instead of reversing ``root_module.named_modules()``
194+
since the former gives the modules in registration order at each level in
195+
the module tree (as opposed to the reverse), which allows us to error/warn
196+
on the first registered module that violates the condition.
197+
198+
For example, consider the following module structure:
199+
M(
200+
S1(),
201+
S2(
202+
SS1(),
203+
SS2(),
204+
),
205+
S3(),
206+
)
207+
The reverse DFS order is [S1, SS1, SS2, S2, S3, M], while the reverse
208+
``named_modules()`` order is [S3, SS2, SS1, S2, S1, M].
209+
"""
210+
visited_modules = {root_module}
211+
stack = [("", root_module)]
212+
# Append and reverse at the end for linear-time algorithm
213+
reverse_post_order_named_modules: List[Tuple[str, nn.Module]] = []
214+
while stack:
215+
module_name, module = stack.pop()
216+
reverse_post_order_named_modules.append((module_name, module))
217+
for child_module_name, child_module in module.named_children():
218+
if child_module is None: # only for overrides of `named_children()`
219+
continue
220+
if child_module not in visited_modules:
221+
visited_modules.add(child_module)
222+
if module_name != "":
223+
child_module_name = module_name + "." + child_module_name
224+
stack.append((child_module_name, child_module))
225+
post_order_named_modules = list(reversed(reverse_post_order_named_modules))
226+
return post_order_named_modules
227+
228+
229+
def _get_managed_param_to_fqn(
230+
module_to_wrap: nn.Module,
231+
ignored_params: Set[nn.Parameter],
232+
visited_modules: Set[nn.Module],
233+
root_prefix: str,
234+
) -> Dict[nn.Parameter, str]:
235+
"""
236+
This returns a dict that maps managed parameter to its FQN for the given
237+
``module_to_wrap``. The dict's keys are exactly the parameters that would
238+
be managed by the module, where this is achieved by calling this function
239+
on the modules to wrap in reverse topological order, destructively updating
240+
``visited_modules``, and not traversing into those modules. The FQNs are
241+
prefixed from the root (via ``root_prefix``) to be more informative.
242+
243+
NOTE: This function is meant to be called pre-wrapping and iteratively in
244+
reverse topological order to cover the full module tree. This differs from
245+
the ``_get_param_to_fqn()`` function meant to be called post-wrapping and
246+
on the full module tree in one shot. Given those differences, we do not try
247+
to unify the two.
248+
"""
249+
param_to_fqn: Dict[nn.Parameter, str] = {}
250+
# Run BFS (or any tree traversal works)
251+
queue = collections.deque([(module_to_wrap, root_prefix)])
252+
visited_modules.add(module_to_wrap)
253+
while queue:
254+
module, prefix = queue.popleft()
255+
for param_name, param in module.named_parameters(recurse=False):
256+
if param not in ignored_params:
257+
fqn = param_name if prefix == "" else prefix + "." + param_name
258+
param_to_fqn[param] = fqn
259+
for child_module_name, child_module in module.named_children():
260+
if child_module is None: # only for overrides of `named_children()`
261+
continue
262+
if child_module not in visited_modules:
263+
visited_modules.add(child_module)
264+
child_prefix = (
265+
child_module_name
266+
if prefix == ""
267+
else prefix + "." + child_module_name
268+
)
269+
queue.append((child_module, child_prefix))
270+
return param_to_fqn

0 commit comments

Comments
 (0)