Training requires upgrading transformers from 4.49.0 due to dtype mismatch error, causing potential performance concerns #9

@YcZhangSing

Excellent work!
However, I'm encountering an environment-related issue. I would greatly appreciate your help in resolving it.

I set up the environment according to your configuration in setup.sh:

Package                           Version       Editable project location
--------------------------------- ------------- -----------------------------------------------
accelerate                        1.4.0
aiohappyeyeballs                  2.6.1
aiohttp                           3.11.18
aiohttp-cors                      0.8.1
aiosignal                         1.3.2
airportsdata                      20250224
annotated-types                   0.7.0
antlr4-python3-runtime            4.13.2
anyio                             4.9.0
astor                             0.8.1
async-timeout                     5.0.1
attrs                             25.3.0
av                                14.3.0
bitsandbytes                      0.45.5
black                             25.1.0
blake3                            1.0.4
cachetools                        5.5.2
certifi                           2025.4.26
charset-normalizer                3.4.2
click                             8.1.8
cloudpickle                       3.1.1
colorful                          0.5.6
compressed-tensors                0.9.1
datasets                          3.6.0
deepspeed                         0.15.4
depyf                             0.18.0
dill                              0.3.8
diskcache                         5.6.3
distlib                           0.3.9
distro                            1.9.0
docker-pycreds                    0.4.0
einops                            0.8.1
exceptiongroup                    1.2.2
fastapi                           0.115.12
filelock                          3.18.0
flake8                            7.2.0
flash_attn                        2.7.4.post1
frozenlist                        1.6.0
fsspec                            2025.3.0
gguf                              0.10.0
gitdb                             4.0.12
GitPython                         3.1.44
google-api-core                   2.24.2
google-auth                       2.40.1
googleapis-common-protos          1.70.0
grpcio                            1.71.0
h11                               0.16.0
hf_transfer                       0.1.9
hf-xet                            1.1.0
hjson                             3.1.0
httpcore                          1.0.9
httptools                         0.6.4
httpx                             0.28.1
huggingface-hub                   0.31.1
idna                              3.10
importlib_metadata                8.7.0
iniconfig                         2.1.0
inquirerpy                        0.3.4
interegular                       0.3.3
isort                             6.0.1
Jinja2                            3.1.6
jiter                             0.9.0
jsonschema                        4.23.0
jsonschema-specifications         2025.4.1
lark                              1.2.2
latex2sympy2_extended             1.10.1
liger_kernel                      0.5.2
lm-format-enforcer                0.10.11
markdown-it-py                    3.0.0
MarkupSafe                        3.0.2
math-verify                       0.7.0
mccabe                            0.7.0
mdurl                             0.1.2
mistral_common                    1.5.4
mpmath                            1.3.0
msgpack                           1.1.0
msgspec                           0.19.0
multidict                         6.4.3
multiprocess                      0.70.16
mypy_extensions                   1.1.0
nest-asyncio                      1.6.0
networkx                          3.4.2
ninja                             1.11.1.4
numpy                             1.26.4
nvidia-cublas-cu12                12.4.5.8
nvidia-cuda-cupti-cu12            12.4.127
nvidia-cuda-nvrtc-cu12            12.4.127
nvidia-cuda-runtime-cu12          12.4.127
nvidia-cudnn-cu12                 9.1.0.70
nvidia-cufft-cu12                 11.2.1.3
nvidia-cufile-cu12                1.11.1.6
nvidia-curand-cu12                10.3.5.147
nvidia-cusolver-cu12              11.6.1.9
nvidia-cusparse-cu12              12.3.1.170
nvidia-cusparselt-cu12            0.6.3
nvidia-ml-py                      12.575.51
nvidia-nccl-cu12                  2.21.5
nvidia-nvjitlink-cu12             12.4.127
nvidia-nvtx-cu12                  12.4.127
open-r1                           0.1.0.dev0    /data1_hdd/yaxiong/zyc_temp/UI-R1/src/ui_r1/src
openai                            1.77.0
opencensus                        0.11.4
opencensus-context                0.1.3
opencv-python-headless            4.11.0.86
outlines                          0.1.11
outlines_core                     0.1.26
packaging                         25.0
pandas                            2.2.3
parameterized                     0.9.0
partial-json-parser               0.2.1.1.post5
pathspec                          0.12.1
pfzy                              0.3.4
pillow                            11.2.1
pip                               25.1
platformdirs                      4.3.8
pluggy                            1.5.0
prometheus_client                 0.21.1
prometheus-fastapi-instrumentator 7.1.0
prompt_toolkit                    3.0.51
propcache                         0.3.1
proto-plus                        1.26.1
protobuf                          5.29.4
psutil                            7.0.0
py-cpuinfo                        9.0.0
py-spy                            0.4.0
pyarrow                           20.0.0
pyasn1                            0.6.1
pyasn1_modules                    0.4.2
pycodestyle                       2.13.0
pycountry                         24.6.1
pydantic                          2.11.4
pydantic_core                     2.33.2
pyflakes                          3.3.2
Pygments                          2.19.1
pytest                            8.3.5
python-dateutil                   2.9.0.post0
python-dotenv                     1.1.0
pytz                              2025.2
PyYAML                            6.0.2
pyzmq                             26.4.0
qwen-vl-utils                     0.0.11
ray                               2.46.0
referencing                       0.36.2
regex                             2024.11.6
requests                          2.32.3
rich                              14.0.0
rpds-py                           0.24.0
rsa                               4.9.1
safetensors                       0.5.3
sentencepiece                     0.2.0
sentry-sdk                        2.27.0
setproctitle                      1.3.6
setuptools                        78.1.1
six                               1.17.0
smart-open                        7.1.0
smmap                             5.0.2
sniffio                           1.3.1
starlette                         0.46.2
sympy                             1.13.1
tensorboardX                      2.6.2.2
tiktoken                          0.9.0
tokenizers                        0.21.1
tomli                             2.2.1
torch                             2.5.1
torchaudio                        2.5.1
torchvision                       0.20.1
tqdm                              4.67.1
transformers                      4.49.0
triton                            3.1.0
trl                               0.16.0
typing_extensions                 4.13.2
typing-inspection                 0.4.0
tzdata                            2025.2
urllib3                           2.4.0
uvicorn                           0.34.2
uvloop                            0.21.0
virtualenv                        20.31.1
vllm                              0.7.2
wandb                             0.18.3
watchfiles                        1.0.5
wcwidth                           0.2.13
websockets                        15.0.1
wheel                             0.45.1
wrapt                             1.17.2
xformers                          0.0.28.post3
xgrammar                          0.1.19
xxhash                            3.5.0
yarl                              1.20.0
zipp                              3.21.0

Running train.sh in this environment fails with the following error:

[rank1]: AssertionError: Input and cos/sin must have the same dtype, got torch.float32 and torch.bfloat16

I found that this is a bug in transformers 4.49.0 (see modelscope/ms-swift#3156), so I upgraded transformers to 4.50.3. With that version train.sh runs normally, but pip reports a dependency conflict:

open-r1 0.1.0.dev0 requires transformers==4.49.0, but you have transformers 4.50.3 which is incompatible.
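
To confirm which transformers version is actually active and whether it matches the pin from the pip message, a small check like the one below can help. This is only a sketch using the standard library: the helper names (`parse_release`, `matches_pin`, `installed_version`) are hypothetical, and the parsing assumes plain numeric release strings such as 4.49.0, not dev or rc versions.

```python
from importlib import metadata


def parse_release(version: str) -> tuple:
    """Turn '4.50.3' into (4, 50, 3); assumes a plain numeric release string."""
    return tuple(int(part) for part in version.split("."))


def matches_pin(installed: str, pinned: str) -> bool:
    """True when the installed version equals an exact (==) pin."""
    return parse_release(installed) == parse_release(pinned)


def installed_version(package: str):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    pinned = "4.49.0"  # the pin that pip reports for open-r1
    found = installed_version("transformers")
    if found is None:
        print("transformers is not installed")
    else:
        status = "matches" if matches_pin(found, pinned) else "differs from"
        print(f"transformers {found} {status} the open-r1 pin {pinned}")
```

This only reports the mismatch; it says nothing about whether the newer version changes training behavior.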

Will using transformers 4.50.3 in this way affect model performance? I ask because training runs without errors in this configuration, but after 8 epochs the model's accuracy is only about 20%.
