diff --git a/nemo_automodel/components/_peft/lora.py b/nemo_automodel/components/_peft/lora.py
index a3daf85de..2adb5c2a1 100644
--- a/nemo_automodel/components/_peft/lora.py
+++ b/nemo_automodel/components/_peft/lora.py
@@ -340,16 +340,11 @@ def apply_lora_to_linear_modules(
     for w in model.parameters():
         w.requires_grad_(False)
 
-    is_causal_lm = False
-    try:
-        if hasattr(model, "config") and "CausalLM" in model.config.architectures[0]:
-            # for example, LlamaForCausalLM
-            is_causal_lm = True
-    except AttributeError:
-        is_causal_lm = False
-
     matcher = ModuleMatcher(
-        peft_config.target_modules, peft_config.exclude_modules, peft_config.match_all_linear, is_causal_lm
+        peft_config.target_modules,
+        peft_config.exclude_modules,
+        peft_config.match_all_linear,
+        model,
     )
     num_modules_matched = 0
     for name, module in list(model.named_modules()):
diff --git a/nemo_automodel/components/_peft/module_matcher.py b/nemo_automodel/components/_peft/module_matcher.py
index b2974263b..fd7c971bb 100644
--- a/nemo_automodel/components/_peft/module_matcher.py
+++ b/nemo_automodel/components/_peft/module_matcher.py
@@ -37,6 +37,28 @@ def wildcard_match(pattern, key):
     return match is not None
 
 
+def _get_model_embedding_ptrs(model: nn.Module) -> list[int]:
+    ptrs = []
+    for _, module in model.named_modules():
+        if not isinstance(module, nn.Embedding):
+            continue
+        ptrs.append(module.weight.data_ptr())
+    return ptrs
+
+
+def _get_tied_target_modules(model: nn.Module) -> list[str]:
+    if model is None:
+        return []
+    tied_target_modules = []
+    embedding_ptrs = set(_get_model_embedding_ptrs(model))
+    for name, module in model.named_modules():
+        if not isinstance(module, nn.Linear):
+            continue
+        if module.weight.data_ptr() in embedding_ptrs:
+            tied_target_modules.append(name)
+    return tied_target_modules
+
+
 @dataclass
 class ModuleMatcher:
     """
@@ -55,13 +77,13 @@ class ModuleMatcher:
             on the first two layers.
         exclude_modules (List[str], optional): A list of module names to exclude from applying LoRA to.
         match_all_linear (bool, optional): Whether to match all linear layers.
-        is_causal_lm (bool, optional): Whether the model is a causal language model.
+        model (nn.Module, optional): The model being matched; used to detect Linear layers whose weights are tied to an embedding.
     """
 
     target_modules: List[str] = field(default_factory=lambda: ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"])
    exclude_modules: List[str] = field(default_factory=list)
     match_all_linear: bool = field(default=False)
-    is_causal_lm: bool = field(default=False)
+    model: nn.Module = field(default=None)
 
     def __post_init__(self):
         """
@@ -77,6 +99,7 @@ def __post_init__(self):
            and (not isinstance(self.exclude_modules, list) or len(self.exclude_modules) == 0)
         ):
             raise ValueError("Expected match_all_linear to be true or target_modules/exclude_modules to be non-empty")
+        self.tied_target_modules = _get_tied_target_modules(self.model)
 
     # --------------------------------------------------------------------- #
     #   Public API                                                          #
@@ -85,12 +108,10 @@ def match(self, m: nn.Module, name: str = None, prefix: str = None):
         """
         Return (pattern, full_name) if the module matches; otherwise None.
         """
-        full_name = f"{prefix}.{name}" if prefix else name
-
-        if self.is_causal_lm:
-            if "lm_head" in full_name:
-                return False
 
+        full_name = f"{prefix}.{name}" if prefix else name
+        if full_name in self.tied_target_modules:
+            return False
         # 1. matching by layer type takes absolute precedence
         if self.match_all_linear and isinstance(m, nn.Linear):
             return True
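Note (illustrative, not part of the patch): the new helpers replace the old name-based "lm_head" heuristic with a storage-based check. Any nn.Linear whose weight shares storage with an nn.Embedding is a tied output head and is skipped by ModuleMatcher, since adapting it would also perturb the embedding. Below is a minimal, self-contained sketch of that detection; ToyTiedLM and tied_linear_names are made-up names for illustration, not code from the repo.

import torch.nn as nn


class ToyTiedLM(nn.Module):
    """Hypothetical stand-in for a causal LM with tied input/output embeddings."""

    def __init__(self, vocab: int = 32, hidden: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.proj = nn.Linear(hidden, hidden)            # ordinary layer, eligible for LoRA
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
        self.lm_head.weight = self.embed.weight          # weight tying, as in many *ForCausalLM models


def tied_linear_names(model: nn.Module) -> list[str]:
    # Same idea as _get_tied_target_modules: a Linear sharing a data pointer
    # with an Embedding is tied and should not receive a LoRA adapter.
    emb_ptrs = {m.weight.data_ptr() for m in model.modules() if isinstance(m, nn.Embedding)}
    return [
        name
        for name, m in model.named_modules()
        if isinstance(m, nn.Linear) and m.weight.data_ptr() in emb_ptrs
    ]


print(tied_linear_names(ToyTiedLM()))  # ['lm_head'] -> skipped by the matcher; 'proj' remains eligible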