
Conversation

@jiawenliu64 (Contributor)

Generate INT4 MP8 checkpoint:
```
torchrun --nproc-per-node=8 -m models.llama4.scripts.quantize --ckpt_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct --output_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct-INT4 --quantization_mode int4_mixed --world_size 8
```
Verify generated INT4 MP8 checkpoint with int4_mixed on single GPU (output):
```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=1 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct-INT4 --world_size 1 --quantization-mode int4_mixed
```
Generate FP8 MP8 checkpoint:
```
torchrun --nproc-per-node=8 -m models.llama4.scripts.quantize --ckpt_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct --output_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct-FP8 --quantization_mode fp8_mixed --world_size 8
```
Verify generated FP8 MP8 checkpoint with fp8_mixed (output):
```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=8 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct-FP8 --world_size 8 --quantization-mode fp8_mixed
```

Verify BF16 MP8 checkpoint (output):
```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=8 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct --world_size 8
```
Verify BF16 MP8 checkpoint with fp8_mixed (output):
```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=8 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct --world_size 8 --quantization-mode fp8_mixed
```
Verify BF16 MP8 checkpoint with int4_mixed on single GPU (output):
```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=1 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct --world_size 1 --quantization-mode int4_mixed
```
jiawenliu64 requested a review from jianyuh on April 24, 2025.
facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on April 24, 2025.
jiawenliu64 changed the title from "Enable loading precompiled INT4 weights in Llama4" to "Enable loading pre-quantized INT4 weights in Llama4" on April 24, 2025.

```
self.int4_weight = int4_weight
dtype = torch.get_default_dtype()
if int4_weight:
```
Contributor

This feels like complexity that truly doesn't belong at this layer. Can we please keep it outside, in the quantization code, somehow?

Contributor

We don't want llama-models to become torchao or vLLM or whatever, really. It is not a full-fledged, all-powerful inference engine.
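
One way to read these comments, sketched below with purely hypothetical names (this Experts stand-in, Int4Experts, and swap_in_int4_experts are illustrative, not code from this repo): the model layer keeps no int4_weight flag at all, and the quantization code swaps in an INT4-specific module from the outside.
```
# Hypothetical sketch only: the layer itself carries no quantization flags,
# and the quantization code replaces modules from outside the model definition.
import torch
import torch.nn as nn


class Experts(nn.Module):
    # Plain BF16 experts layer; no int4_weight branch in the constructor.
    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        self.w = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))


class Int4Experts(nn.Module):
    # The INT4 variant lives in quantization land and owns the packed layout.
    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        self.register_buffer(
            "packed_w", torch.empty(num_experts, dim, hidden_dim // 2, dtype=torch.uint8)
        )
        self.register_buffer(
            "scales", torch.empty(num_experts, hidden_dim, dtype=torch.bfloat16)
        )


def swap_in_int4_experts(model: nn.Module) -> nn.Module:
    # Called from quantization code, so the model/generation code never sees int4_weight.
    for name, child in list(model.named_children()):
        if isinstance(child, Experts):
            num_experts, dim, hidden_dim = child.w.shape
            setattr(model, name, Int4Experts(num_experts, dim, hidden_dim))
        else:
            swap_in_int4_experts(child)
    return model
```
With this shape, the Transformer definition and generation path stay unchanged, and the INT4 packing details are a private concern of the quantizer.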

```
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
```
Contributor

If you move the model.load_state_dict() call into convert_to_quantized_model(), then you can do the following:

  • change the structure of the Transformer from the outside in this code path (whatever you are doing with the Experts)
  • move all of this scale-checkpoint-path complexity into quantization land

Nobody reading generation.py should need to know about quantization unless they want to dig into it.
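
A rough sketch of the proposed split, assuming a hypothetical signature for convert_to_quantized_model() and a hypothetical build_model() helper (neither is the PR's actual API): generation.py builds the Transformer and hands the rest off to the quantization module, which owns the scale-checkpoint paths and the load_state_dict() call.
```
# Illustrative only; the real convert_to_quantized_model signature may differ.
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


def convert_to_quantized_model(
    model: nn.Module,
    state_dict: Dict[str, Any],
    quantization_mode: str,
    ckpt_dir: str,
) -> nn.Module:
    # Quantization land: restructure the Experts, read any fp8/int4 scale
    # checkpoints from ckpt_dir, then load the (possibly repacked) state dict.
    model.load_state_dict(state_dict, strict=False)
    return model


def build_model(model_args, state_dict, quantization_mode: Optional[str], ckpt_dir: str) -> nn.Module:
    # generation.py side: no quantization knowledge beyond this single dispatch.
    torch.set_default_tensor_type(torch.BFloat16Tensor)
    model = Transformer(model_args)  # Transformer as defined in models.llama4
    if quantization_mode is None:
        model.load_state_dict(state_dict, strict=False)
    else:
        model = convert_to_quantized_model(model, state_dict, quantization_mode, ckpt_dir)
    return model
```
Under this structure the Experts restructuring and the scale checkpoints stay entirely inside the quantization module, which is the point of the comment above.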
