Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions src/transformers/integrations/eetq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import get_module_from_name
from ..utils import is_accelerate_available, is_eetq_available, logging


Expand All @@ -24,6 +26,30 @@

logger = logging.get_logger(__name__)

class EetqQuantize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer

def convert(self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, **kwargs) -> dict[str, torch.Tensor]:
target_key, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value

from eetq import EetqLinear, quantize_and_preprocess_weights

module, tensor_name = get_module_from_name(model, param_name)
new_value, weight_scale = quantize_and_preprocess_weights(param_value)

# Samity check
if isinstance(module, EetqLinear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.int8:
raise ValueError("Expect quantized weights but got an unquantized weight")
else:
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")

return {target_key:new_value,
f"{target_key}_scales": weight_scale}

def _replace_with_eetq_linear(
model,
Expand Down Expand Up @@ -56,8 +82,7 @@ def _replace_with_eetq_linear(
model._modules[name] = eetq.EetqLinear(
in_features, out_features, module.bias is not None, module.weight.device
)
if pre_quantized:
model._modules[name].register_scale(module.weight.device)
model._modules[name].register_scale(module.weight.device)
has_been_replaced = True

# Force requires grad to False to avoid unexpected errors
Expand Down
65 changes: 65 additions & 0 deletions src/transformers/integrations/fbgemm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

from ..activations import ACT2FN
from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import get_module_from_name
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging


Expand All @@ -29,6 +31,69 @@
logger = logging.get_logger(__name__)


class FbgemmFp8Quantize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer

def convert(self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, **kwargs) -> dict[str, torch.Tensor]:
target_key, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value

from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
module, tensor_name = get_module_from_name(model, target_key)
# Sanity checks
if isinstance(module, FbgemmFp8Linear):
if self.pre_quantized or tensor_name == "bias":
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
raise ValueError("Expect quantized weights but got an unquantized weight")
else:
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
if isinstance(module, FbgemmFp8Llama4TextExperts):
if not (self.pre_quantized or tensor_name == "bias"):
if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")

if isinstance(module, FbgemmFp8Llama4TextExperts):
if tensor_name == "gate_up_proj":
# Process each expert separately
# Transpose the second and third dimension
transposed_param = param_value.transpose(1, 2)

# Reshape to 2D for quantization
original_shape = transposed_param.shape
flattened_param = transposed_param.reshape(-1, original_shape[-1])

# Quantize using per row instead of per column
new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)

# Reshape back to original dimensions
new_value = new_value_flat.reshape(original_shape)
new_value = new_value.transpose(1, 2)
weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
elif tensor_name == "down_proj":
# Process each expert separately
# Transpose the weights for proper quantization
transposed_param = param_value.transpose(1, 2)

# Reshape to 2D for quantization
original_shape = transposed_param.shape
flattened_param = transposed_param.reshape(-1, original_shape[-1])

# Quantize using per column
new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)

# Reshape back to original dimensions
new_value = new_value_flat.reshape(original_shape)
new_value = new_value.transpose(1, 2)
weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
else:
new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)
weight_scale = weight_scale.view(weight_scale.shape[0], 1)

return {target_key: new_value,
f"{target_key}_scale": weight_scale}

class FbgemmFp8Linear(torch.nn.Linear):
def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
super().__init__(in_features, out_features, bias)
Expand Down
68 changes: 68 additions & 0 deletions src/transformers/integrations/fp_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,75 @@
from fp_quant import FPQuantDtype

from transformers.utils.quantization_config import FPQuantConfig
from ..quantizers.quantizers_utils import get_module_from_name
from ..core_model_loading import ConversionOps

class FpQuantQuantize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer

def convert(self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, **kwargs) -> dict[str, torch.Tensor]:
target_key, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value

module, _ = get_module_from_name(model, target_key)

# TODO: check if we need this or not, commented for now
# if target_device == "cpu" and param_name.endswith("weight"):
# # Works agains hard-coded missing key dispatch to CPU
# return

# The module holds either:
# * `weight` when `store_master_weights=True`
# * `qweight` and `scales` when `store_master_weights=False` and `pseudoquantization=False`
# * `dqweight` when `store_master_weights=False` and `pseudoquantization=True`

if target_key.endswith(".qweight"):
# Loading a real quantized checkpoint without master weights
qweight = torch.nn.Parameter(
param_value,
requires_grad=False,
)
module.weight = None
module.dqweight = None

weight_key = target_keys.rsplit(".", 1)[0] + ".weight"
dqweight_key = target_keys.rsplit(".", 1)[0] + ".dqweight"
return {target_key: qweight,
weight_key: None,
dqweight_key: None
}

if param_name.endswith(".dqweight"):
# Loading a pseudo-quantized checkpoint without master weights
dqweight = torch.nn.Parameter(param_value.to(target_device))

weight_key = target_keys.rsplit(".", 1)[0] + ".weight"
dqweight_key = target_keys.rsplit(".", 1)[0] + ".dqweight"
scales_key = target_keys.rsplit(".", 1)[0] + ".scales"

return {
target_key:dqweight,
weight_key:None,
dqweight_key:None,
scales_key:None
}

# Loading master weights or an unquantized checkpoint
weight = torch.nn.Parameter(param_value.to(target_device))
module.weight = weight
# Let pre-forward handle the quantization and set None where necessary
module.pre_forward()

prefix_target_key = target_keys.rsplit(".", 1)[0]

return {target_key: weight,
prefix_target_key + ".act_global_scale": module.act_global_scale,
prefix_target_key + "backward_hadamard_matrix": module.backward_hadamard_matrix
prefix_target_key + "forward_hadamard_matrix": module.forward_hadamard_matrix
prefix_target_key + "qweight": module.qweight
prefix_target_key + "scales": module.scales
}

def adapt_fp_quant_config(config: FPQuantConfig):
if config.forward_dtype == "mxfp4":
Expand Down
34 changes: 34 additions & 0 deletions src/transformers/integrations/higgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
if is_hadamard_available():
from fast_hadamard_transform import hadamard_transform

from ..core_model_loading import ConversionOps


def pad_to_block(tensor, dims, had_block_size, value=0):
pad_dims = [0 for _ in range(2 * len(tensor.shape))]
Expand Down Expand Up @@ -439,6 +441,38 @@ def get_higgs_grid(p: int, n: int) -> "torch.Tensor":
raise NotImplementedError(f"Unsupported p={p}, n={n}")


class HiggsQuantize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer

def convert(self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, **kwargs) -> dict[str, torch.Tensor]:
target_key, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value

flute_dict = quantize_with_higgs(
value,
self.hf_quantizer.quantization_config.bits,
self.hf_quantizer.quantization_config.p,
self.hf_quantizer.quantization_config.group_size,
self.hf_quantizer.quantization_config.hadamard_size,
)
del value

quantized_dict = {}
module, _ = get_module_from_name(model, target_key)
module_name = target_key.rsplit(".", 1)[0]
for key, value in flute_dict.items():
if key in module._parameters:
quantized_dict[module_name + "." + key] = torch.nn.Parameter(value, requires_grad=False)
elif key in module._buffers:
quantized_dict[module_name + "." + key] = torch.nn.Buffer(value)
elif key == "tune_metadata":
module.tune_metadata = value
self.hf_quantizer.quantization_config.tune_metadata[module_name] = value.to_dict()
else:
raise ValueError(f"Unexpected key {key} in module {module}")
return quantized_dict

def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256, hadamard_size: int = 1024):
assert len(weight.shape) == 2, "Only 2D weights are supported for now"

Expand Down
82 changes: 82 additions & 0 deletions src/transformers/integrations/hqq.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,89 @@

logger = logging.get_logger(__name__)

from ..quantizers.quantizers_utils import get_module_from_name


def HqqQuantize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer

def convert(self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, missing_keys=None **kwargs) -> dict[str, torch.Tensor]:
target_key, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value


module, tensor_name = get_module_from_name(model, param_name)
module_name = param_name.rsplit(".", 1)[0]
parent_module, node = get_module_from_name(model, module_name)

quant_config = model.config.quantization_config["quant_config"]
skip_modules = model.config.quantization_config["skip_modules"]

# In this case we do not quantize this layer (it's explicitly skipped) -> simply load param
if any(skip_module in module.name for skip_module in skip_modules):
return {target_key: value}

# We need this hack as the model is not pre-prepared as an empty skeleton on meta device
if self.pre_quantized:
# Save them for later
if not hasattr(self.hf_quantizer, "hqq_params"):
self.hf_quantizer.hqq_params = defaultdict(dict)
self.hf_quantizer.hqq_params[module_name].update({tensor_name: value})
hqq_params = self.hf_quantizer.hqq_params[module_name]

# If they are all present and saved, make it a HQQLinear layer! (we cannot do it param after param because
# hqq does not support it...)
if all(k in hqq_params for k in self.hf_quantizer.hqq_keys) and ("bias" in hqq_params or module.bias is None):
hqq_layer = HQQLinear(
linear_layer=None,
quant_config=None,
compute_dtype=self.hf_quantizer.dtype,
device=value.device,
del_orig=False,
)
hqq_layer.load_state_dict(hqq_params)

if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
if self.using_multi_gpu:
hqq_layer = self._patch_layer_for_multigpu(hqq_layer)

setattr(parent_module, node, hqq_layer)
del self.hqq_params[module_name], module
return {}
return {}

# Load param in the module (without caring about device or dtype, it will be changed later)
module.load_state_dict({tensor_name: param_value}, strict=False, assign=True)

# If both the weight and bias have already been loaded, time to quantize!
module_is_ready = module.weight.device.type != "meta" and (
module.bias is None or module.bias.device.type != "meta"
)

if module_is_ready:
module_tag = ".".join(module.name.split(".")[-2:])
if "weight_quant_params" in quant_config:
module_quant_config = quant_config
elif module_tag in quant_config:
module_quant_config = quant_config[module_tag]

hqq_layer = HQQLinear(
module,
quant_config=module_quant_config,
compute_dtype=self.dtype,
device=target_device,
del_orig=True,
)

if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)

if self.using_multi_gpu:
hqq_layer = self._patch_layer_for_multigpu(hqq_layer)

setattr(parent_module, node, hqq_layer)
# Name all modules inside the model
def autoname_modules(model):
for name, module in model.named_modules():
Expand Down
16 changes: 16 additions & 0 deletions src/transformers/integrations/quanto.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import get_module_from_name
from ..utils import is_optimum_quanto_available, is_torch_available, logging


Expand All @@ -20,6 +22,20 @@

logger = logging.get_logger(__name__)

class QuantoQuantize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer

def convert(self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, **kwargs) -> dict[str, torch.Tensor]:
target_key, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value

from ..modeling_utils import _load_parameter_into_model
_load_parameter_into_model(model, target_key, param_value)
module, _ = get_module_from_name(model, param_name)
module.freeze()
module.weight.requires_grad = False
return {target_key: module.weight}

def replace_with_quanto_layers(
model,
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/quantizers/quantizer_eetq.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,7 @@ def is_serializable(self, safe_serialization=None):
@property
def is_trainable(self) -> bool:
return True

def get_quantize_ops(self):
from ..integrations.eetq import EetqQuantize
return EetqQuantize(self)
4 changes: 4 additions & 0 deletions src/transformers/quantizers/quantizer_fbgemm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ def update_tp_plan(self, config):

return config

def get_quantize_ops(self):
from ..integrations.fbgemm_fp8 import FbgemmFp8Quantize
return FbgemmFp8Quantize(self)

def is_serializable(self, safe_serialization=None):
return True

Expand Down
4 changes: 4 additions & 0 deletions src/transformers/quantizers/quantizer_fp_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,7 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
return True
else:
return False

def get_quantize_ops(self):
from ..integrations.fp_quant import FpQuantQuantize
return FpQuantQuantize(self)
Loading