23 changes: 19 additions & 4 deletions models/checkpoint.py
@@ -12,7 +12,10 @@

import numpy as np
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
get_model_parallel_world_size,
)


def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
@@ -37,6 +40,7 @@ def maybe_reshard_state_dict(
moe_num_experts: Optional[int] = None,
map_location: Union[str, torch.device] = "cpu",
mmap: bool = True,
is_scale: bool = False,
) -> Dict[str, torch.Tensor]:
if str(map_location) == "cpu":
torch.set_default_tensor_type(torch.BFloat16Tensor)
@@ -45,7 +49,10 @@

ckpt_paths = np.array(sorted(ckpt_paths))

new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
new_mp_size, new_mp_rank = (
get_model_parallel_world_size(),
get_model_parallel_rank(),
)
old_mp_size = len(ckpt_paths)
old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)

@@ -65,6 +72,7 @@ def maybe_reshard_state_dict(
size=max(new_mp_size // old_mp_size, 1),
rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
is_scale=is_scale,
)


@@ -102,6 +110,7 @@ def reshard_mp(
size: int,
rank: int,
repeat_qk_qv: int = 1,
is_scale: bool = False,
) -> Dict[str, torch.Tensor]:
"""
Reshard a list of state dicts into a single state dict given a change in MP size.
@@ -116,7 +125,10 @@ def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:

def process_key(key: str) -> torch.Tensor:
if row_regex.search(key):
return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
if is_scale and "shared_" in key:
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
else:
return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
elif column_regex.search(key):
if "w13" in key or "fc1_weight" in key:
dims = state_dicts[0][key].size()
@@ -133,7 +145,10 @@ def process_key(key: str) -> torch.Tensor:
elif key == "output.bias" or key == "fc.weight":
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
elif "w_" in key:
return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
if is_scale and ("shared_" in key or "swiglu_" in key):
return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
else:
return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
else:
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
else:
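For context on the new is_scale branches: the int4 scale tensors are laid out differently from the packed weights they accompany, so when the model-parallel size changes they must be merged along a different dimension. A minimal sketch of the intended merge, assuming concat_or_chunk concatenates when several old shards map onto one new rank (the stub and the shapes are illustrative, not the real checkpoint layout):

import torch

def concat_or_chunk(tensors, dim):
    # Assumed behavior: merging old MP shards concatenates along `dim`.
    return torch.cat(tensors, dim=dim) if len(tensors) > 1 else tensors[0]

# A row-parallel weight is sharded along its last dim, so two old shards merge with dim=-1 ...
weight_shards = [torch.zeros(8, 4), torch.zeros(8, 4)]
merged_weight = concat_or_chunk(weight_shards, dim=-1)   # (8, 8)

# ... while a shared-expert scale is sharded along dim 0, which is what the
# `is_scale and "shared_" in key` branch selects instead.
scale_shards = [torch.zeros(4, 1), torch.zeros(4, 1)]
merged_scale = concat_or_chunk(scale_shards, dim=0)       # (8, 1)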
1 change: 1 addition & 0 deletions models/llama4/args.py
@@ -19,6 +19,7 @@ class QuantizationArgs(BaseModel):
scheme: Optional[QuantizationScheme] = None
group_size: Optional[int] = None
spinquant: bool = False
int4_weight: bool = False


class LoRAArgs(BaseModel):
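The new flag is what build() in generation.py flips when it finds int4 scale shards next to the weights; a minimal usage sketch (the import path is assumed from the repo layout):

from models.llama4.args import QuantizationArgs

qargs = QuantizationArgs()    # scheme=None, group_size=None, spinquant=False, int4_weight=False
qargs.int4_weight = True      # mirrors the QuantizationMode.int4_mixed path in build()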
33 changes: 30 additions & 3 deletions models/llama4/ffn.py
@@ -7,6 +7,8 @@

from typing import Any, Dict, List

import torch

from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from torch import nn
@@ -19,13 +21,38 @@ def __init__(
dim: int,
hidden_dim: int,
do_reduce: bool = True,
int4_weight: bool = False,
):
super().__init__()
self.do_reduce = do_reduce

self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self.w1 = ColumnParallelLinear(
dim // 2 if int4_weight else dim,
hidden_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.w2 = RowParallelLinear(
hidden_dim // 2 if int4_weight else hidden_dim,
dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
)
self.w3 = ColumnParallelLinear(
dim // 2 if int4_weight else dim,
hidden_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
if int4_weight:
# Since we pack 2*int4 into 1*int8 and leverage float8_e4m3fn to bypass gradient check in nn.Parameter, we use torch.float8_e4m3fn here
dtype = torch.float8_e4m3fn
self.w1.weight.data = self.w1.weight.data.to(dtype)
self.w2.weight.data = self.w2.weight.data.to(dtype)
self.w3.weight.data = self.w3.weight.data.to(dtype)
self._register_load_state_dict_pre_hook(self.load_hook)

def load_hook(
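The dim // 2 and hidden_dim // 2 input sizes follow from the packing described in the inline comment: two 4-bit values share one int8 byte, so the packed input dimension is half of the logical one, and the one-byte storage is exposed as float8_e4m3fn so that nn.Parameter's floating-point requirement is satisfied. A minimal sketch of that packing (pack_int4 is an illustrative helper, not the one used by models/quantize_impls):

import torch

def pack_int4(w_int4: torch.Tensor) -> torch.Tensor:
    # w_int4 carries int4 values in an int8 tensor (range [-8, 7]); pack pairs along the last dim.
    lo = w_int4[..., 0::2] & 0x0F
    hi = (w_int4[..., 1::2] & 0x0F) << 4
    packed = lo | hi                       # still int8, but the last dim is halved
    # Parameters that require grad must be floating point, so the byte storage is
    # reinterpreted (not numerically converted) as float8_e4m3fn.
    return packed.view(torch.float8_e4m3fn)

w = torch.randint(-8, 8, (4, 16), dtype=torch.int8)
print(pack_int4(w).shape)                  # torch.Size([4, 8])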
26 changes: 23 additions & 3 deletions models/llama4/generation.py
@@ -24,7 +24,7 @@

from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode
from .args import ModelArgs
from .args import ModelArgs, QuantizationArgs
from .chat_format import ChatFormat, RawContent, RawMessage
from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
from .model import Transformer
@@ -82,19 +82,39 @@ def build(

state_dict = maybe_reshard_state_dict(
ckpt_paths,
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
n_kv_heads=(model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads),
moe_num_experts=model_args.moe_args.num_experts,
)
print("Loaded checkpoint")
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model

scale_state_dict = None
if quantization_mode == QuantizationMode.int4_mixed:
scale_ckpt_paths = sorted(Path(ckpt_dir).glob("*.pt"))

if len(scale_ckpt_paths) > 0:
print(f"Loading a scale checkpoint (shards={len(scale_ckpt_paths)}, current-mp-size={world_size})")
scale_state_dict = maybe_reshard_state_dict(
scale_ckpt_paths,
n_kv_heads=(model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads),
moe_num_experts=model_args.moe_args.num_experts,
is_scale=True,
)
model_args.quantization_args = QuantizationArgs()
model_args.quantization_args.int4_weight = True
print("Loaded scale checkpoint")
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
Review comment (Contributor):
if you move the model.load_state_dict() to convert_to_quantized_model() then you can do the following:
  • change the structure of the Transformer from the outside in this code path (whatever you are doing with Experts)
  • move all this scale ckpt paths complexity into quantization land
nobody reading generation.py should know about quantization unless they want to dig into it.

print("Done...")
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)
model = convert_to_quantized_model(
model,
ckpt_dir,
quantization_mode=quantization_mode,
scale_state_dict=scale_state_dict,
)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
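The int4 path distinguishes two cases purely by what sits next to the weights in ckpt_dir: if per-rank scale files are present they are resharded like a regular state dict (with is_scale=True) and handed to convert_to_quantized_model; otherwise the loader falls back to quantizing from bf16. A sketch of the expected directory contents (the scale file names follow the pre-existing int4_scales_{rank}.pt convention in quantization/loader.py; the weight-shard names are an assumption):

from pathlib import Path

# ckpt_dir/
#   consolidated.00.pth, consolidated.01.pth, ...   <- weight shards, one per original MP rank
#   int4_scales_0.pt,    int4_scales_1.pt,    ...   <- optional per-rank int4 scales
ckpt_dir = "/path/to/llama4-int4"                        # illustrative
scale_ckpt_paths = sorted(Path(ckpt_dir).glob("*.pt"))   # same glob as build() above
use_precomputed_scales = len(scale_ckpt_paths) > 0       # False -> quantize_int4 from bf16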
3 changes: 3 additions & 0 deletions models/llama4/model.py
@@ -16,6 +16,7 @@
RowParallelLinear,
VocabParallelEmbedding,
)
from models.quantize_impls import load_int4
from torch import nn

from .args import ModelArgs
@@ -269,6 +270,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
ffn_dim_multiplier=args.ffn_dim_multiplier,
multiple_of=args.multiple_of,
moe_args=args.moe_args,
int4_weight=(args.quantization_args.int4_weight if args.quantization_args is not None else False),
)
else:
hidden_dim = int(4 * args.dim)
@@ -280,6 +282,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=hidden_dim,
int4_weight=(args.quantization_args.int4_weight if args.quantization_args is not None else False),
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
30 changes: 23 additions & 7 deletions models/llama4/moe.py
@@ -12,7 +12,7 @@
import fairscale.nn.model_parallel.initialize as fs_init
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from torch import Tensor, nn
from torch import nn, Tensor
from torch.nn import functional as F

from .args import MoEArgs
@@ -25,18 +25,24 @@ def __init__(
num_local_experts: int,
dim: int,
hidden_dim: int,
int4_weight: bool = False,
) -> None:
super().__init__()

self.int4_weight = int4_weight
dtype = torch.get_default_dtype()
if int4_weight:
Review comment (Contributor):
this feels like complexity that truly doesn't belong at this layer. can we please keep it outside into quantization code somehow?

Review comment (Contributor):
we don't want llama-models to become torchao or vllm or whatever really. it is not a full fledged all powerful inference engine.

# Since we pack 2*int4 into 1*int8 and leverage float8_e4m3fn to bypass gradient check in nn.Parameter, we use torch.float8_e4m3fn here
dtype = torch.float8_e4m3fn

self.num_local_experts = num_local_experts
self.dim = dim
divide_factor = fs_init.get_model_parallel_world_size()

self.w1: nn.Parameter = nn.Parameter(
torch.empty(
num_local_experts,
dim,
dim // 2 if int4_weight else dim,
divide_exact(hidden_dim, divide_factor),
dtype=dtype,
)
@@ -45,7 +51,11 @@ def __init__(
self.w2: nn.Parameter = nn.Parameter(
torch.empty(
num_local_experts,
divide_exact(hidden_dim, divide_factor),
(
divide_exact(hidden_dim, divide_factor) // 2
if int4_weight
else divide_exact(hidden_dim, divide_factor)
),
dim,
dtype=dtype,
)
@@ -54,7 +64,7 @@ def __init__(
self.w3: nn.Parameter = nn.Parameter(
torch.empty(
num_local_experts,
dim,
dim // 2 if int4_weight else dim,
divide_exact(hidden_dim, divide_factor),
dtype=dtype,
)
@@ -76,9 +86,13 @@ def load_hook(
if prefix + "moe_w_in_eD_F" in state_dict:
e = self.num_local_experts
D = self.dim
state_dict[prefix + "w1"] = state_dict.pop(prefix + "moe_w_in_eD_F").view(e, D, -1)
state_dict[prefix + "w1"] = state_dict.pop(prefix + "moe_w_in_eD_F").view(
e, D // 2 if self.int4_weight else D, -1
)
state_dict[prefix + "w2"] = state_dict.pop(prefix + "moe_w_out_eF_D").view(e, -1, D)
state_dict[prefix + "w3"] = state_dict.pop(prefix + "moe_w_swiglu_eD_F").view(e, D, -1)
state_dict[prefix + "w3"] = state_dict.pop(prefix + "moe_w_swiglu_eD_F").view(
e, D // 2 if self.int4_weight else D, -1
)

def forward(
self,
@@ -125,6 +139,7 @@ def __init__(
ffn_dim_multiplier: float,
multiple_of: int,
moe_args: MoEArgs,
int4_weight: bool = False,
) -> None:
super().__init__()

@@ -150,10 +165,11 @@ def __init__(
num_local_experts,
dim,
hidden_dim,
int4_weight=int4_weight,
)

self.router_DE: nn.Parameter = nn.Parameter(torch.empty(dim, moe_args.num_experts, dtype=dtype))
self.shared_expert = FeedForward(dim, hidden_dim, do_reduce=False)
self.shared_expert = FeedForward(dim, hidden_dim, do_reduce=False, int4_weight=int4_weight)

self._register_load_state_dict_pre_hook(self.load_hook)

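With int4_weight=True the packed dimension of every expert weight is halved, exactly as in FeedForward; a quick shape check with made-up sizes (not Llama 4's real dimensions):

num_local_experts, dim, hidden_dim, mp_size = 4, 1024, 8192, 8
hidden_local = hidden_dim // mp_size       # divide_exact(hidden_dim, divide_factor)

bf16_shapes = {
    "w1": (num_local_experts, dim, hidden_local),
    "w2": (num_local_experts, hidden_local, dim),
    "w3": (num_local_experts, dim, hidden_local),
}
int4_shapes = {                            # two int4 values per stored byte
    "w1": (num_local_experts, dim // 2, hidden_local),
    "w2": (num_local_experts, hidden_local // 2, dim),
    "w3": (num_local_experts, dim // 2, hidden_local),
}
print(bf16_shapes)
print(int4_shapes)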
29 changes: 19 additions & 10 deletions models/llama4/quantization/loader.py
@@ -11,7 +11,7 @@

import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from torch import Tensor, nn
from torch import nn, Tensor
from torch.nn import functional as F

from ...datatypes import QuantizationMode
@@ -49,6 +49,7 @@ def convert_to_quantized_model(
quantization_mode: Optional[str] = None,
fp8_activation_scale_ub: Optional[float] = 1200.0,
use_rich_progress: bool = True,
scale_state_dict: Optional[dict] = None,
) -> Transformer:
from ...quantize_impls import (
Fp8ScaledWeights,
@@ -75,13 +76,12 @@ def should_quantize_block(block: nn.Module) -> bool:
use_rich_progress = use_rich_progress and rank == 0
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block)
if quantization_mode == QuantizationMode.int4_mixed:
int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
if os.path.isfile(int4_scales_path):
log_status(f"Rank {rank}: Loading int4 scales")
int4_scales = torch.load(int4_scales_path, weights_only=True)
if scale_state_dict is not None:

def apply_quantization(key, weight):
scale = int4_scales[key]
scale = scale_state_dict[key]
if "experts" in key:
scale = scale.squeeze(1)
return load_int4(
weight,
scale,
@@ -92,7 +92,7 @@ def apply_quantization(key, weight):
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")

def apply_quantization(_, weight):
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
return quantize_int4(weight, output_device=torch.device("cuda"))

else:
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
@@ -130,26 +130,35 @@ def apply_quantization(_, weight):
prefix = f"layers.{block.layer_id}.feed_forward"
moe = block.feed_forward
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)

state_dict_key_map = {
"w1": "moe_w_in_eD_F",
"w2": "moe_w_out_eF_D",
"w3": "moe_w_swiglu_eD_F",
}
for key in ("w1", "w3", "w2"):
param = getattr(moe.experts, key)
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
setattr(
moe.experts,
key,
apply_quantization(
f"{prefix}.experts.{key}",
f"{prefix}.experts.{state_dict_key_map[key]}",
param.transpose(1, 2).contiguous(),
),
)

if quantization_mode == QuantizationMode.int4_mixed:
# Quantize shared experts
moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
state_dict_key_map = {
"w1": "w_in_shared_FD.weight",
"w2": "w_out_shared_DF.weight",
"w3": "w_swiglu_FD.weight",
}
for key in ("w1", "w3", "w2"):
param = getattr(moe.shared_expert, key)
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)
param.weight = apply_quantization(f"{prefix}.{state_dict_key_map[key]}", param.weight)

processed_blocks += 1
update_status(message=None, completed=processed_blocks)
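The two state_dict_key_map tables are needed because the resharded scale checkpoint is keyed by the original tensor names (moe_w_in_eD_F, w_in_shared_FD.weight, ...) rather than by the renamed w1/w2/w3 attributes. A sketch of the keys looked up for one layer (the layer index is illustrative; the key strings come from the mappings above):

prefix = "layers.3.feed_forward"

expert_scale_keys = {
    "w1": f"{prefix}.experts.moe_w_in_eD_F",
    "w2": f"{prefix}.experts.moe_w_out_eF_D",
    "w3": f"{prefix}.experts.moe_w_swiglu_eD_F",
}
shared_expert_scale_keys = {
    "w1": f"{prefix}.w_in_shared_FD.weight",
    "w2": f"{prefix}.w_out_shared_DF.weight",
    "w3": f"{prefix}.w_swiglu_FD.weight",
}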