[Qwen3]: TypeError: liger_fused_linear_cross_entropy() got an unexpected keyword argument 'return_dict' #925

@yeshsurya

Description

🐛 Describe the bug

The return_dict parameter is a standard transformers argument that controls the output format (a ModelOutput object vs. a plain tuple). PEFT passes this parameter to the base model. Liger Kernel's model implementations, however, pass all remaining **kwargs to LigerForCausalLMLoss, which forwards them until they reach liger_fused_linear_cross_entropy(). That function does not accept return_dict, so the call fails with a TypeError.
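
The failure pattern can be reproduced without Liger Kernel, PEFT, or transformers installed. The sketch below uses hypothetical stand-in functions (fused_linear_cross_entropy, causal_lm_loss, lce_forward) that only mimic the shape of the call chain in the trace. Their signatures are assumptions for illustration, not the real library signatures.

```python
# Hypothetical stand-ins that mimic the call chain; NOT the real Liger Kernel functions.

def fused_linear_cross_entropy(hidden, weight, target, ignore_index=-100, reduction="mean"):
    """Stand-in for a kernel that accepts only a fixed set of keyword arguments."""
    return 0.0  # actual loss computation omitted


def causal_lm_loss(hidden, weight, labels, **kwargs):
    """Stand-in for a loss helper that forwards every extra kwarg to the kernel."""
    return fused_linear_cross_entropy(hidden, weight, labels, **kwargs)


def lce_forward(hidden, weight, labels, **kwargs):
    """Stand-in for a model forward that passes **kwargs straight into the loss."""
    return causal_lm_loss(hidden, weight, labels, **kwargs)


message = ""
try:
    # A wrapper (e.g. PEFT) adds a transformers-level flag the kernel does not declare.
    lce_forward(None, None, None, return_dict=True)
except TypeError as err:
    message = str(err)

print(message)
# fused_linear_cross_entropy() got an unexpected keyword argument 'return_dict'
```

Every layer accepts **kwargs, so nothing rejects return_dict until the innermost function, which has a closed signature.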

Full Stack Trace
Traceback (most recent call last):
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/transformers/trainer.py", line 3884, in compute_loss
    outputs = model(**inputs)
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2179, in forward
    loss = self.module(*inputs, **kwargs)
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/peft/peft_model.py", line 1850, in forward
    return self.base_model(
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 222, in forward
    return self.model.forward(*args, **kwargs)
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/liger_kernel/transformers/model/qwen3.py", line 95, in lce_forward
    loss = LigerForCausalLMLoss(
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/liger_kernel/transformers/model/loss_utils.py", line 58, in LigerForCausalLMLoss
    loss = fixed_fused_linear_cross_entropy(
  File "/opt/conda/envs/ptca/lib/python3.10/site-packages/liger_kernel/transformers/model/loss_utils.py", line 20, in fixed_fused_linear_cross_entropy
    loss = F.liger_fused_linear_cross_entropy(
TypeError: liger_fused_linear_cross_entropy() got an unexpected keyword argument 'return_dict'
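
One possible library-side fix is to forward only the keyword arguments the innermost function actually declares. The sketch below does this with inspect.signature. It is a hypothetical illustration, not the fix adopted by the Liger Kernel maintainers, and fused_linear_cross_entropy and causal_lm_loss are again stand-ins whose signatures are assumed.

```python
import inspect
from typing import Any, Callable, Dict


def filter_kwargs_for(fn: Callable[..., Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return only the entries of kwargs that fn can accept by keyword."""
    params = inspect.signature(fn).parameters
    # If fn itself takes **kwargs it can absorb anything, so pass everything through.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    accepted = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return {k: v for k, v in kwargs.items() if k in accepted}


def fused_linear_cross_entropy(hidden, weight, target, ignore_index=-100, reduction="mean"):
    """Hypothetical stand-in for the kernel; returns the options it received."""
    return {"ignore_index": ignore_index, "reduction": reduction}


def causal_lm_loss(hidden, weight, labels, **kwargs):
    """Hypothetical loss helper: forward only kwargs the kernel declares, drop the rest."""
    safe = filter_kwargs_for(fused_linear_cross_entropy, kwargs)
    return fused_linear_cross_entropy(hidden, weight, labels, **safe)


result = causal_lm_loss(None, None, None, return_dict=True, reduction="sum")
print(result)  # {'ignore_index': -100, 'reduction': 'sum'}
```

Signature-based filtering silently drops anything unrecognized, which can also hide misspelled options. A narrower alternative is an explicit deny-list that removes only known transformers-level flags such as return_dict.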

Reproduce

This issue occurs when:

- Using PEFT (LoRA, QLoRA, etc.) with any Liger Kernel-supported model
- Training with the transformers Trainer
- The Trainer's compute_loss calls model(**inputs), where inputs contains return_dict
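
Until a library fix is available, one hypothetical user-side workaround is to wrap the patched forward so that return_dict is removed before it can reach the loss. The sketch below shows only the wrapping technique. The forward function is a stand-in, and how to install such a wrapper on a real PEFT or Liger-patched model is an assumption not covered here. Dropping the flag also means the forward falls back to its default output type.

```python
import functools
from typing import Any, Callable


def drop_kwargs(*names: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator factory: remove the named keyword arguments before calling fn."""
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            for name in names:
                kwargs.pop(name, None)  # no-op if the caller did not pass it
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@drop_kwargs("return_dict")
def forward(input_ids, labels=None, **kwargs):
    """Hypothetical stand-in for a patched forward that passes **kwargs to its loss."""
    return sorted(kwargs)  # show which extra kwargs would reach the loss


print(forward([1, 2, 3], labels=[1, 2, 3], return_dict=True, use_cache=False))  # ['use_cache']
```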

Versions

Operating System: Linux-5.15.0-1074-azure-x86_64-with-glibc2.35
Python version: 3.10.18
Liger Kernel version: 0.6.3
PyTorch version: 2.7.1+cu126
CUDA version: 12.6
HIP(ROCm) version: Not available
Triton version: 3.3.1
Transformers version: 4.55.0
XPU version: XPU Not Available
