diff --git a/models/checkpoint.py b/models/checkpoint.py index b579865c0..106d798e2 100644 --- a/models/checkpoint.py +++ b/models/checkpoint.py @@ -12,7 +12,10 @@ import numpy as np import torch -from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size +from fairscale.nn.model_parallel.initialize import ( + get_model_parallel_rank, + get_model_parallel_world_size, +) def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]: @@ -37,6 +40,7 @@ def maybe_reshard_state_dict( moe_num_experts: Optional[int] = None, map_location: Union[str, torch.device] = "cpu", mmap: bool = True, + is_scale: bool = False, ) -> Dict[str, torch.Tensor]: if str(map_location) == "cpu": torch.set_default_tensor_type(torch.BFloat16Tensor) @@ -45,7 +49,10 @@ def maybe_reshard_state_dict( ckpt_paths = np.array(sorted(ckpt_paths)) - new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank() + new_mp_size, new_mp_rank = ( + get_model_parallel_world_size(), + get_model_parallel_rank(), + ) old_mp_size = len(ckpt_paths) old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank) @@ -65,6 +72,7 @@ def maybe_reshard_state_dict( size=max(new_mp_size // old_mp_size, 1), rank=new_mp_rank % max(new_mp_size // old_mp_size, 1), repeat_qk_qv=max(new_mp_size // n_kv_heads, 1), + is_scale=is_scale, ) @@ -102,6 +110,7 @@ def reshard_mp( size: int, rank: int, repeat_qk_qv: int = 1, + is_scale: bool = False, ) -> Dict[str, torch.Tensor]: """ Reshard a list of state dicts into a single state dict given a change in MP size. @@ -116,7 +125,10 @@ def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor: def process_key(key: str) -> torch.Tensor: if row_regex.search(key): - return concat_or_chunk([s[key] for s in state_dicts], dim=-1) + if is_scale and "shared_" in key: + return concat_or_chunk([s[key] for s in state_dicts], dim=0) + else: + return concat_or_chunk([s[key] for s in state_dicts], dim=-1) elif column_regex.search(key): if "w13" in key or "fc1_weight" in key: dims = state_dicts[0][key].size() @@ -133,7 +145,10 @@ def process_key(key: str) -> torch.Tensor: elif key == "output.bias" or key == "fc.weight": return concat_or_chunk([s[key] for s in state_dicts], dim=0) elif "w_" in key: - return concat_or_chunk([s[key] for s in state_dicts], dim=-2) + if is_scale and ("shared_" in key or "swiglu_" in key): + return concat_or_chunk([s[key] for s in state_dicts], dim=-1) + else: + return concat_or_chunk([s[key] for s in state_dicts], dim=-2) else: return concat_or_chunk([s[key] for s in state_dicts], dim=0) else: diff --git a/models/llama4/args.py b/models/llama4/args.py index d1685d2e6..3920eeebe 100644 --- a/models/llama4/args.py +++ b/models/llama4/args.py @@ -19,6 +19,7 @@ class QuantizationArgs(BaseModel): scheme: Optional[QuantizationScheme] = None group_size: Optional[int] = None spinquant: bool = False + int4_weight: bool = False class LoRAArgs(BaseModel): diff --git a/models/llama4/ffn.py b/models/llama4/ffn.py index 03aa0abea..950d10596 100644 --- a/models/llama4/ffn.py +++ b/models/llama4/ffn.py @@ -7,6 +7,8 @@ from typing import Any, Dict, List +import torch + from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region from torch import nn @@ -19,13 +21,38 @@ def __init__( dim: int, hidden_dim: int, do_reduce: bool = True, + int4_weight: bool = False, ): super().__init__() self.do_reduce = 
do_reduce - self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x) - self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x) - self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x) + self.w1 = ColumnParallelLinear( + dim // 2 if int4_weight else dim, + hidden_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.w2 = RowParallelLinear( + hidden_dim // 2 if int4_weight else hidden_dim, + dim, + bias=False, + input_is_parallel=True, + init_method=lambda x: x, + ) + self.w3 = ColumnParallelLinear( + dim // 2 if int4_weight else dim, + hidden_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + if int4_weight: + # Since we pack 2*int4 into 1*int8 and leverage float8_e4m3fn to bypass gradient check in nn.Parameter, we use torch.float8_e4m3fn here + dtype = torch.float8_e4m3fn + self.w1.weight.data = self.w1.weight.data.to(dtype) + self.w2.weight.data = self.w2.weight.data.to(dtype) + self.w3.weight.data = self.w3.weight.data.to(dtype) self._register_load_state_dict_pre_hook(self.load_hook) def load_hook( diff --git a/models/llama4/generation.py b/models/llama4/generation.py index 67dfad460..b6f83ec2f 100644 --- a/models/llama4/generation.py +++ b/models/llama4/generation.py @@ -24,7 +24,7 @@ from ..checkpoint import maybe_reshard_state_dict from ..datatypes import GenerationResult, QuantizationMode -from .args import ModelArgs +from .args import ModelArgs, QuantizationArgs from .chat_format import ChatFormat, RawContent, RawMessage from .datatypes import LLMInput, MaskedEmbedding, TransformerInput from .model import Transformer @@ -82,19 +82,39 @@ def build( state_dict = maybe_reshard_state_dict( ckpt_paths, - n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads, + n_kv_heads=(model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads), moe_num_experts=model_args.moe_args.num_experts, ) print("Loaded checkpoint") if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed: from .quantization.loader import convert_to_quantized_model + scale_state_dict = None + if quantization_mode == QuantizationMode.int4_mixed: + scale_ckpt_paths = sorted(Path(ckpt_dir).glob("*.pt")) + + if len(scale_ckpt_paths) > 0: + print(f"Loading a scale checkpoint (shards={len(scale_ckpt_paths)}, current-mp-size={world_size})") + scale_state_dict = maybe_reshard_state_dict( + scale_ckpt_paths, + n_kv_heads=(model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads), + moe_num_experts=model_args.moe_args.num_experts, + is_scale=True, + ) + model_args.quantization_args = QuantizationArgs() + model_args.quantization_args.int4_weight = True + print("Loaded scale checkpoint") torch.set_default_tensor_type(torch.BFloat16Tensor) model = Transformer(model_args) print("Loading state dict...") model.load_state_dict(state_dict, strict=False) print("Done...") - model = convert_to_quantized_model(model, ckpt_dir, quantization_mode) + model = convert_to_quantized_model( + model, + ckpt_dir, + quantization_mode=quantization_mode, + scale_state_dict=scale_state_dict, + ) else: if torch.cuda.is_bf16_supported(): torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) diff --git a/models/llama4/model.py b/models/llama4/model.py index 75f281d5c..40d0c8d7a 100644 --- a/models/llama4/model.py +++ b/models/llama4/model.py @@ -16,6 +16,7 @@ 
RowParallelLinear, VocabParallelEmbedding, ) +from models.quantize_impls import load_int4 from torch import nn from .args import ModelArgs @@ -269,6 +270,7 @@ def __init__(self, layer_id: int, args: ModelArgs): ffn_dim_multiplier=args.ffn_dim_multiplier, multiple_of=args.multiple_of, moe_args=args.moe_args, + int4_weight=(args.quantization_args.int4_weight if args.quantization_args is not None else False), ) else: hidden_dim = int(4 * args.dim) @@ -280,6 +282,7 @@ def __init__(self, layer_id: int, args: ModelArgs): self.feed_forward = FeedForward( dim=args.dim, hidden_dim=hidden_dim, + int4_weight=(args.quantization_args.int4_weight if args.quantization_args is not None else False), ) self.layer_id = layer_id self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) diff --git a/models/llama4/moe.py b/models/llama4/moe.py index d66afdb56..59595fa09 100644 --- a/models/llama4/moe.py +++ b/models/llama4/moe.py @@ -12,7 +12,7 @@ import fairscale.nn.model_parallel.initialize as fs_init import torch from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region -from torch import Tensor, nn +from torch import nn, Tensor from torch.nn import functional as F from .args import MoEArgs @@ -25,10 +25,16 @@ def __init__( num_local_experts: int, dim: int, hidden_dim: int, + int4_weight: bool = False, ) -> None: super().__init__() + self.int4_weight = int4_weight dtype = torch.get_default_dtype() + if int4_weight: + # Since we pack 2*int4 into 1*int8 and leverage float8_e4m3fn to bypass gradient check in nn.Parameter, we use torch.float8_e4m3fn here + dtype = torch.float8_e4m3fn + self.num_local_experts = num_local_experts self.dim = dim divide_factor = fs_init.get_model_parallel_world_size() @@ -36,7 +42,7 @@ def __init__( self.w1: nn.Parameter = nn.Parameter( torch.empty( num_local_experts, - dim, + dim // 2 if int4_weight else dim, divide_exact(hidden_dim, divide_factor), dtype=dtype, ) @@ -45,7 +51,11 @@ def __init__( self.w2: nn.Parameter = nn.Parameter( torch.empty( num_local_experts, - divide_exact(hidden_dim, divide_factor), + ( + divide_exact(hidden_dim, divide_factor) // 2 + if int4_weight + else divide_exact(hidden_dim, divide_factor) + ), dim, dtype=dtype, ) @@ -54,7 +64,7 @@ def __init__( self.w3: nn.Parameter = nn.Parameter( torch.empty( num_local_experts, - dim, + dim // 2 if int4_weight else dim, divide_exact(hidden_dim, divide_factor), dtype=dtype, ) @@ -76,9 +86,13 @@ def load_hook( if prefix + "moe_w_in_eD_F" in state_dict: e = self.num_local_experts D = self.dim - state_dict[prefix + "w1"] = state_dict.pop(prefix + "moe_w_in_eD_F").view(e, D, -1) + state_dict[prefix + "w1"] = state_dict.pop(prefix + "moe_w_in_eD_F").view( + e, D // 2 if self.int4_weight else D, -1 + ) state_dict[prefix + "w2"] = state_dict.pop(prefix + "moe_w_out_eF_D").view(e, -1, D) - state_dict[prefix + "w3"] = state_dict.pop(prefix + "moe_w_swiglu_eD_F").view(e, D, -1) + state_dict[prefix + "w3"] = state_dict.pop(prefix + "moe_w_swiglu_eD_F").view( + e, D // 2 if self.int4_weight else D, -1 + ) def forward( self, @@ -125,6 +139,7 @@ def __init__( ffn_dim_multiplier: float, multiple_of: int, moe_args: MoEArgs, + int4_weight: bool = False, ) -> None: super().__init__() @@ -150,10 +165,11 @@ def __init__( num_local_experts, dim, hidden_dim, + int4_weight=int4_weight, ) self.router_DE: nn.Parameter = nn.Parameter(torch.empty(dim, moe_args.num_experts, dtype=dtype)) - self.shared_expert = FeedForward(dim, hidden_dim, do_reduce=False) + self.shared_expert = FeedForward(dim, hidden_dim, 
do_reduce=False, int4_weight=int4_weight) self._register_load_state_dict_pre_hook(self.load_hook) diff --git a/models/llama4/quantization/loader.py b/models/llama4/quantization/loader.py index 2bafc38c1..6ba88518d 100644 --- a/models/llama4/quantization/loader.py +++ b/models/llama4/quantization/loader.py @@ -11,7 +11,7 @@ import torch from fairscale.nn.model_parallel.initialize import get_model_parallel_rank -from torch import Tensor, nn +from torch import nn, Tensor from torch.nn import functional as F from ...datatypes import QuantizationMode @@ -49,6 +49,7 @@ def convert_to_quantized_model( quantization_mode: Optional[str] = None, fp8_activation_scale_ub: Optional[float] = 1200.0, use_rich_progress: bool = True, + scale_state_dict: Optional[dict] = None, ) -> Transformer: from ...quantize_impls import ( Fp8ScaledWeights, @@ -75,13 +76,12 @@ def should_quantize_block(block: nn.Module) -> bool: use_rich_progress = use_rich_progress and rank == 0 progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block) if quantization_mode == QuantizationMode.int4_mixed: - int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt") - if os.path.isfile(int4_scales_path): - log_status(f"Rank {rank}: Loading int4 scales") - int4_scales = torch.load(int4_scales_path, weights_only=True) + if scale_state_dict is not None: def apply_quantization(key, weight): - scale = int4_scales[key] + scale = scale_state_dict[key] + if "experts" in key: + scale = scale.squeeze(1) return load_int4( weight, scale, @@ -92,7 +92,7 @@ def apply_quantization(key, weight): log_status(f"Rank {rank}: Quantizing int4 weights from bf16") def apply_quantization(_, weight): - return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda")) + return quantize_int4(weight, output_device=torch.device("cuda")) else: fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt") @@ -130,7 +130,11 @@ def apply_quantization(_, weight): prefix = f"layers.{block.layer_id}.feed_forward" moe = block.feed_forward moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts) - + state_dict_key_map = { + "w1": "moe_w_in_eD_F", + "w2": "moe_w_out_eF_D", + "w3": "moe_w_swiglu_eD_F", + } for key in ("w1", "w3", "w2"): param = getattr(moe.experts, key) update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}") @@ -138,7 +142,7 @@ def apply_quantization(_, weight): moe.experts, key, apply_quantization( - f"{prefix}.experts.{key}", + f"{prefix}.experts.{state_dict_key_map[key]}", param.transpose(1, 2).contiguous(), ), ) @@ -146,10 +150,15 @@ def apply_quantization(_, weight): if quantization_mode == QuantizationMode.int4_mixed: # Quantize shared experts moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert) + state_dict_key_map = { + "w1": "w_in_shared_FD.weight", + "w2": "w_out_shared_DF.weight", + "w3": "w_swiglu_FD.weight", + } for key in ("w1", "w3", "w2"): param = getattr(moe.shared_expert, key) update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}") - param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight) + param.weight = apply_quantization(f"{prefix}.{state_dict_key_map[key]}", param.weight) processed_blocks += 1 update_status(message=None, completed=processed_blocks) diff --git a/models/llama4/scripts/quantize.py b/models/llama4/scripts/quantize.py index 9e58e861f..f2ba52afb 100644 --- a/models/llama4/scripts/quantize.py +++ 
b/models/llama4/scripts/quantize.py @@ -18,6 +18,7 @@ from models.llama4.generation import QuantizationMode from models.llama4.model import MoE, Transformer, TransformerBlock from models.quantize_impls import int4_row_quantize, pack_int4 +from torch import nn try: import fbgemm_gpu.experimental.gen_ai # noqa: F401 @@ -112,16 +113,24 @@ def ffn_quantize( model.load_state_dict(checkpoint, strict=False) print("Done...") + def should_quantize_block(block: nn.Module) -> bool: + if not isinstance(block, TransformerBlock): + return False + + is_moe = isinstance(block.feed_forward, MoE) + if quantization_mode == QuantizationMode.fp8_mixed: + # skip quantization on first and last layers + return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1)) + + return is_moe + fp8_scales = {} int4_scales = {} old_keys = set(checkpoint.keys()) new_state_dict = checkpoint for _, block in model.named_modules(): if isinstance(block, TransformerBlock): - if block.layer_id == 0 or block.layer_id == (model.n_layers - 1): - continue - - if not isinstance(block.feed_forward, MoE): + if not should_quantize_block(block): continue # IMPORTANT NOTE: @@ -147,12 +156,12 @@ def ffn_quantize( weight = param.transpose(1, 2).contiguous() if weight.ndim >= 3: - wq, scale = zip(*[int4_row_quantize(i.cuda()) for i in weight]) - wq = torch.stack([pack_int4(i.cuda()) for i in wq], dim=0) + wq, scale = zip(*[int4_row_quantize(i) for i in weight]) + wq = torch.stack([pack_int4(i) for i in wq], dim=0) w_scale = torch.stack(scale, dim=0) else: - wq, w_scale = int4_row_quantize(weight.cuda()) - wq = pack_int4(wq.cuda()) + wq, w_scale = int4_row_quantize(weight) + wq = pack_int4(wq) state_dict_key_map = { "w1": "moe_w_in_eD_F", @@ -167,8 +176,29 @@ def ffn_quantize( new_state_dict[f"{prefix}.experts.{state_dict_key_map[key]}"] = torch.nn.Parameter( wq.view(torch.float8_e4m3fn) ) - int4_scales[f"{prefix}.experts.{key}"] = w_scale - print(f"Quantized {prefix}.experts.{state_dict_key_map[key]} {wq.shape=} {w_scale.shape=}") + int4_scales[f"{prefix}.experts.{state_dict_key_map[key]}"] = w_scale + print(f"Quantized {prefix}.experts.{key} {wq.shape=} {w_scale.shape=}") + + param = getattr(moe.shared_expert, key) + weight = param.weight + if weight.ndim >= 3: + wq, scale = zip(*[int4_row_quantize(i) for i in weight]) + wq = torch.stack([pack_int4(i) for i in wq], dim=0) + w_scale = torch.stack(scale, dim=0) + else: + wq, w_scale = int4_row_quantize(weight) + wq = pack_int4(wq) + + state_dict_key_map = { + "w1": "w_in_shared_FD.weight", + "w2": "w_out_shared_DF.weight", + "w3": "w_swiglu_FD.weight", + } + new_state_dict[f"{prefix}.{state_dict_key_map[key]}"] = torch.nn.Parameter( + wq.view(torch.float8_e4m3fn) + ) + int4_scales[f"{prefix}.{state_dict_key_map[key]}"] = w_scale + print(f"Quantized {prefix}.{key} {wq.shape=} {w_scale.shape=}") else: for key in ("w1", "w3", "w2"): @@ -186,7 +216,7 @@ def ffn_quantize( wq = wq.transpose(1, 2).reshape(*new_shape).contiguous() new_state_dict[f"{prefix}.experts.{state_dict_key_map[key]}"] = torch.nn.Parameter(wq) - fp8_scales[f"{prefix}.experts.{key}"] = w_scale + fp8_scales[f"{prefix}.experts.{state_dict_key_map[key]}"] = w_scale print(f"Quantized {prefix}.experts.{state_dict_key_map[key]} {wq.shape=} {w_scale.shape=}") new_keys = set(new_state_dict.keys()) diff --git a/models/quantize_impls.py b/models/quantize_impls.py index 4127425ce..e0352c79f 100644 --- a/models/quantize_impls.py +++ b/models/quantize_impls.py @@ -250,7 +250,7 @@ def load_int4( w_scale (Tensor): [n, k/2] 
input INT4 scale. """ return Int4Weights( - weight=w.to(torch.int8).to(device=output_device), + weight=w.view(torch.int8).to(device=output_device), scale=scale.to(device=output_device), shape=w.shape, )
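
Reviewer note on the int4 storage scheme this patch relies on: the quantizer packs two int4 values into one byte (which is why the in-feature dimensions become `dim // 2` when `int4_weight` is set in ffn.py and moe.py) and stores the packed buffer under `torch.float8_e4m3fn` so that `nn.Parameter` accepts it, since only floating-point dtypes may require gradients; `load_int4` then reinterprets the bytes with `w.view(torch.int8)` rather than casting. The sketch below is illustrative only. It assumes simple per-row symmetric scaling (scale = max|w| / 7) and a low-nibble/high-nibble packing order; the `toy_*` helpers are hypothetical stand-ins, not the `int4_row_quantize` / `pack_int4` / `load_int4` implementations in models/quantize_impls.py, whose exact conventions may differ.

```python
import torch


def toy_int4_row_quantize(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-row symmetric quantization of a 2D float tensor to int4 values in [-8, 7]."""
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    q = torch.clamp(torch.round(w / scale), -8, 7).to(torch.int8)
    return q, scale.squeeze(-1)


def toy_pack_int4(q: torch.Tensor) -> torch.Tensor:
    """Pack adjacent int4 pairs along the last dim into one byte (low nibble, high nibble)."""
    assert q.shape[-1] % 2 == 0
    lo = (q[..., 0::2] & 0x0F).to(torch.uint8)
    hi = (q[..., 1::2] & 0x0F).to(torch.uint8) << 4
    return (lo | hi).view(torch.int8)


def toy_unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    """Recover the signed int4 pairs from a packed int8 tensor."""
    lo = packed & 0x0F                                    # low nibble, 0..15
    hi = (packed.view(torch.uint8) >> 4).to(torch.int8)   # high nibble, 0..15
    lo = torch.where(lo > 7, lo - 16, lo)                 # map 8..15 back to -8..-1
    hi = torch.where(hi > 7, hi - 16, hi)
    return torch.stack([lo, hi], dim=-1).flatten(-2)


if __name__ == "__main__":
    torch.manual_seed(0)
    w = torch.randn(4, 8)
    q, scale = toy_int4_row_quantize(w)
    packed = toy_pack_int4(q)

    # The halved last dimension is what motivates the `dim // 2 if int4_weight else dim`
    # parameter shapes introduced in ffn.py and moe.py.
    assert packed.shape == (4, 4)

    # Reinterpret the packed bytes as float8_e4m3fn so the buffer can live in an
    # nn.Parameter (only floating-point dtypes may require gradients), mirroring
    # `torch.nn.Parameter(wq.view(torch.float8_e4m3fn))` in scripts/quantize.py.
    stored = torch.nn.Parameter(packed.view(torch.float8_e4m3fn))

    # At load time the bytes are viewed back to int8 (cf. `w.view(torch.int8)` in
    # load_int4) and dequantized with the per-row scale.
    restored = toy_unpack_int4(stored.data.view(torch.int8)).float() * scale.unsqueeze(-1)
    print("max abs error:", (restored - w).abs().max().item())
```

The round-trip error printed at the end is bounded by half a quantization step per element, which is the expected behavior for this kind of per-row int4 scheme; it is meant only to make the packing convention and the `float8_e4m3fn` storage trick in this patch easier to follow.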