Description
🐛 Describe the bug
The return_dict parameter is a standard transformers parameter that controls the output format (ModelOutput object vs. tuple). PEFT passes this parameter to the base model, but Liger Kernel's model implementations forward all **kwargs to LigerForCausalLMLoss, which eventually calls liger_fused_linear_cross_entropy(); that function does not accept return_dict, so the call fails with a TypeError.
Full Stack Trace
Traceback (most recent call last):
File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/trainer.py", line 3884, in compute_loss
outputs = model(**inputs)
File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2179, in forward
loss = self.module(*inputs, **kwargs)
File "/opt/conda/envs/ptca/lib/python3.10/site-packages/peft/peft_model.py", line 1850, in forward
return self.base_model(
File "/opt/conda/envs/ptca/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 222, in forward
return self.model.forward(*args, **kwargs)
File "/opt/conda/envs/ptca/lib/python3.10/site-packages/liger_kernel/transformers/model/qwen3.py", line 95, in lce_forward
loss = LigerForCausalLMLoss(
File "/opt/conda/envs/ptca/lib/python3.10/site-packages/liger_kernel/transformers/model/loss_utils.py", line 58, in LigerForCausalLMLoss
loss = fixed_fused_linear_cross_entropy(
File "/opt/conda/envs/ptca/lib/python3.10/site-packages/liger_kernel/transformers/model/loss_utils.py", line 20, in fixed_fused_linear_cross_entropy
loss = F.liger_fused_linear_cross_entropy(
TypeError: liger_fused_linear_cross_entropy() got an unexpected keyword argument 'return_dict'
Reproduce
This issue occurs when:
- Using PEFT (LoRA/QLoRA/etc.) with any Liger Kernel supported model
- Training with the transformers Trainer
- The trainer's compute_loss calls model(**inputs), where inputs contains return_dict
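One possible mitigation (a sketch on my part, not an official Liger Kernel fix) is to filter the forwarded kwargs against the callee's signature before the call, so unexpected keys like return_dict are silently dropped. filter_kwargs and fused_loss below are hypothetical names for illustration.

```python
import inspect

def filter_kwargs(func, kwargs):
    # Keep only the keyword arguments that the callee actually declares.
    params = inspect.signature(func).parameters
    # If the callee itself takes **kwargs, everything is safe to forward.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in params}

def fused_loss(hidden_size, reduction="mean"):
    # Hypothetical stand-in for a loss function with a fixed signature.
    return f"loss(hidden_size={hidden_size}, reduction={reduction})"

kwargs = {"reduction": "sum", "return_dict": True}  # return_dict would crash
safe = filter_kwargs(fused_loss, kwargs)            # drops return_dict
result = fused_loss(64, **safe)                     # no TypeError
```

The trade-off is that signature filtering hides genuinely misspelled arguments as well, so an upstream fix that explicitly pops return_dict (or accepts and ignores it) would be more robust.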
Versions
Operating System: Linux-5.15.0-1074-azure-x86_64-with-glibc2.35
Python version: 3.10.18
Liger Kernel version: 0.6.3
PyTorch version: 2.7.1+cu126
CUDA version: 12.6
HIP(ROCm) version: Not available
Triton version: 3.3.1
Transformers version: 4.55.0
XPU version: XPU Not Available