Contributor

@akoumpa akoumpa commented Oct 28, 2025

HF

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
)

NeMo AutoModel

from nemo_automodel import NeMoAutoModelForCausalLM
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh(
    "cuda",
    mesh_shape=(1, 1, 1, 1, 2),
    mesh_dim_names=("pp", "dp_replicate", "dp_shard", "cp", "tp"),
)

model = NeMoAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    device_mesh=mesh,
    distributed={"tp_size": 2, "cp_size": 1, "pp_size": 1, "dp_size": 2, "backend": "nccl"},
)

This PR extends the Auto API so that from_pretrained accepts two additional options, device_mesh and distributed. The goal is to provide a drop-in class, used the same way as the HF class above, that also supports loading models for distributed processing.
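As a rough illustration of how the mesh_shape and mesh_dim_names in the example relate, the sketch below pairs each dimension name with its size and multiplies the sizes to get the number of ranks the mesh spans. The helper describe_mesh is hypothetical and is not part of NeMo AutoModel or PyTorch; it only mirrors the values from the snippet above and needs no GPU.

```python
from math import prod

# Dimension names copied from the init_device_mesh call in the description.
MESH_DIM_NAMES = ("pp", "dp_replicate", "dp_shard", "cp", "tp")


def describe_mesh(mesh_shape, dim_names=MESH_DIM_NAMES):
    """Hypothetical helper: map each mesh dimension to its size and count ranks.

    Returns (sizes, world_size), where world_size is the product of all
    dimension sizes, i.e. the number of processes the mesh spans.
    """
    if len(mesh_shape) != len(dim_names):
        raise ValueError(
            f"mesh_shape has {len(mesh_shape)} dims but {len(dim_names)} names were given"
        )
    if any(size < 1 for size in mesh_shape):
        raise ValueError("every mesh dimension must be at least 1")
    sizes = dict(zip(dim_names, mesh_shape))
    return sizes, prod(mesh_shape)


sizes, world_size = describe_mesh((1, 1, 1, 1, 2))
print(sizes)       # {'pp': 1, 'dp_replicate': 1, 'dp_shard': 1, 'cp': 1, 'tp': 2}
print(world_size)  # 2
```

With the shape from the description, only the tp dimension is larger than 1, so the mesh spans two ranks; presumably the example would be launched with two processes (for instance via torchrun with two processes per node).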

copy-pr-bot bot commented Oct 28, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@akoumpa akoumpa changed the title from "Refactor: AutoModel entry point" to "refactor: AutoModel entry point" on Oct 28, 2025
@akoumpa akoumpa force-pushed the akoumparouli/refactor_auto_entrypoint branch from 56564c8 to 31f29b8 on October 30, 2025 05:51
Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>
@akoumpa akoumpa force-pushed the akoumparouli/refactor_auto_entrypoint branch from 32c3357 to c375b42 on November 6, 2025 06:10
Signed-off-by: Alexandros Koumparoulis <[email protected]>