
Commit e032156

Support torch_fp8 (#13196)
* support torch_fp8
1 parent: 3accc31 · commit: e032156


2 files changed (+84, -42 lines)


python/llm/src/ipex_llm/ggml/quantize.py

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,9 @@
     "sym_int8_rtn": 32,
     "asym_int4_rtn": 33,
     "woq_int4": 34,
+    "torch_fp8_e5m2": 35,
+    "torch_fp8": 35,
+    "torch_fp8_e4m3": 36
 }
 
 # mixed precison from llama.cpp
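These keys in ggml_tensor_qtype are the strings matched against the load_in_low_bit option, with "torch_fp8" kept as an alias of "torch_fp8_e5m2" (both map to qtype 35) while "torch_fp8_e4m3" gets its own id (36). A minimal usage sketch, assuming the new names are accepted by from_pretrained the same way the existing low-bit strings are; the model id is only a placeholder:

from ipex_llm.transformers import AutoModelForCausalLM

# Assumption: "torch_fp8" is a valid load_in_low_bit value after this change;
# the model id below is a placeholder.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             load_in_low_bit="torch_fp8")
model = model.to("xpu")  # the torch_fp8 forward path in low_bit_linear.py runs on XPU only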

python/llm/src/ipex_llm/transformers/low_bit_linear.py

Lines changed: 81 additions & 42 deletions
@@ -86,6 +86,8 @@
 SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]
 ASYM_INT4_RTN = ggml_tensor_qtype["asym_int4_rtn"]
 WOQ_INT4 = ggml_tensor_qtype["woq_int4"]
+TORCH_FP8E5 = ggml_tensor_qtype["torch_fp8_e5m2"]
+TORCH_FP8E4 = ggml_tensor_qtype["torch_fp8_e4m3"]
 RTN_DTYPE = {
     SYM_INT4_RTN: torch.uint8,
     ASYM_INT4_RTN: torch.uint8,
@@ -106,39 +108,44 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                        imatrix: torch.Tensor=None,
                        in_features: int=None,
                        enable_scale_search: bool=False):
-    QK = ggml.ggml_qk_size(qtype)
-    block_size_in_bytes = ggml.ggml_type_size(qtype)
-
-    invalidInputError(tensor.dtype == torch.float,
-                      "Input tensor must be float32")
-    src = tensor.data.data_ptr()
-    src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))
-    n = tensor.numel()  # all elements
-    k = tensor.shape[-1]
-    invalidInputError(k % QK == 0,
-                      f"Last dim of input tensor must be multiple of {QK}")
-
-    dst_size = (n // QK) * block_size_in_bytes
-    if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
-        dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
-                                 device=device)
-        dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
-        if qtype == ASYM_INT4_RTN:
-            scale = torch.empty((n // k) * 2, dtype=torch.float32,
-                                device=device)
-        else:
-            scale = torch.empty(n // k, dtype=torch.float32,
-                                device=device)
-    elif qtype == NF4:
-        # Deepspeed zero3 requires unified dtype,
-        # thus here uses bfloat16 consistent to other layers
-        # dst_size above is computed based on uint8, and for bfloat16,
-        # buffer size should be half
-        dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
-                                 device=device)
+    if qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+        fp8_dtype = torch.float8_e5m2 if qtype == TORCH_FP8E5 else torch.float8_e4m3fn
+        dst_tensor = torch.empty(tensor.shape, device=device, dtype=fp8_dtype)
+        scale = torch.zeros(1, device=device, dtype=torch.float32)
     else:
-        dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
-                                 device=device)
+        QK = ggml.ggml_qk_size(qtype)
+        block_size_in_bytes = ggml.ggml_type_size(qtype)
+
+        invalidInputError(tensor.dtype == torch.float,
+                          "Input tensor must be float32")
+        src = tensor.data.data_ptr()
+        src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))
+        n = tensor.numel()  # all elements
+        k = tensor.shape[-1]
+        invalidInputError(k % QK == 0,
+                          f"Last dim of input tensor must be multiple of {QK}")
+
+        dst_size = (n // QK) * block_size_in_bytes
+        if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
+            dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
+                                     device=device)
+            dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
+            if qtype == ASYM_INT4_RTN:
+                scale = torch.empty((n // k) * 2, dtype=torch.float32,
+                                    device=device)
+            else:
+                scale = torch.empty(n // k, dtype=torch.float32,
+                                    device=device)
+        elif qtype == NF4:
+            # Deepspeed zero3 requires unified dtype,
+            # thus here uses bfloat16 consistent to other layers
+            # dst_size above is computed based on uint8, and for bfloat16,
+            # buffer size should be half
+            dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
+                                     device=device)
+        else:
+            dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
+                                     device=device)
 
     if not convert_shape_only and device != 'meta':
         dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
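Unlike the ggml block formats above, the torch_fp8 path keeps the weight's original shape in a native PyTorch fp8 dtype and allocates a single fp32 scale, i.e. per-tensor scaling. The two qtypes trade range for precision; a quick standalone check of their dynamic ranges (plain PyTorch, nothing ipex-llm specific):

import torch

for dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
    info = torch.finfo(dtype)
    # e5m2 has the wider range (max 57344) but fewer mantissa bits;
    # e4m3fn tops out at 448 with more precision per value.
    print(dtype, "max =", info.max, "eps =", info.eps)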
@@ -158,6 +165,17 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                                           enable_scale_search,
                                           imatrix)
             return dst_tensor, scale.type(torch.float16)
+        elif qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+            import xe_linear
+            tensor_device = tensor.device
+            tensor_xpu = tensor.to("xpu")
+            dst_tensor = dst_tensor.to("xpu")
+            scale = scale.to("xpu")
+
+            xe_linear.dynamic_scaled_fp8_quant(dst_tensor, tensor_xpu, scale)
+
+            # scale = scale.to(tensor_device)
+            dst_tensor = dst_tensor.to(tensor_device)
         else:
             ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
     else:
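The quantization itself is delegated to the XPU kernel xe_linear.dynamic_scaled_fp8_quant(dst_tensor, tensor_xpu, scale), whose implementation is not part of this diff. A rough pure-PyTorch sketch of per-tensor dynamic scaled fp8 quantization, stated as an assumption about the semantics rather than a description of the kernel:

import torch

def dynamic_scaled_fp8_quant_ref(dst: torch.Tensor, src: torch.Tensor, scale: torch.Tensor):
    # Per-tensor dynamic scaling: scale = amax / fp8_max, dst = cast(src / scale).
    # Hypothetical reference only; the real kernel runs on XPU and may differ.
    fp8_max = torch.finfo(dst.dtype).max  # 57344 for e5m2, 448 for e4m3fn
    scale.copy_((src.abs().max().float() / fp8_max).clamp(min=1e-12))
    dst.copy_((src / scale).clamp(-fp8_max, fp8_max).to(dst.dtype))
    return dst, scale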
@@ -171,6 +189,8 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                                                hist, imatrix)
     if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
         return dst_tensor, scale.type(torch.float16)
+    elif qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+        return dst_tensor, scale
     else:
         return dst_tensor
 
@@ -179,7 +199,7 @@ def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int
     if qtype == NF4:
         invalidInputError(tensor.dtype == torch.bfloat16,
                           "NF4 Input tensor must be bfloat16")
-    else:
+    elif qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
         invalidInputError(tensor.dtype == torch.uint8,
                           "Input tensor except NF4 must be uint8")
 
@@ -208,7 +228,7 @@ def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int
     if qtype == NF4:
         invalidInputError(tensor.dtype == torch.bfloat16,
                           "NF4 Input tensor must be bfloat16")
-    else:
+    elif qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
         invalidInputError(tensor.dtype == torch.uint8,
                           "Input tensor must be uint8")
 
@@ -319,7 +339,8 @@ def __new__(cls,
                 qtype=None,
                 imatrix=None,
                 in_features=None,
-                enable_scale_search=False):
+                enable_scale_search=False,
+                torch_fp8_scale=None):
         if data is None:
             data = torch.empty(0)
 
@@ -332,6 +353,7 @@ def __new__(cls,
         self.imatrix = imatrix
         self.in_features = in_features
         self.enable_scale_search = enable_scale_search
+        self.torch_fp8_scale = torch_fp8_scale
         return self
 
     def ggml_mse(self, w, ggml_qtype, device):
@@ -391,7 +413,11 @@ def quantize(self, device=None):
                                                  imatrix=self.imatrix,
                                                  in_features=self.in_features,
                                                  enable_scale_search=self.enable_scale_search)
-                self.data = w_quantized
+                if self.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+                    self.data = w_quantized[0]
+                    self.torch_fp8_scale = w_quantized[1]
+                else:
+                    self.data = w_quantized
             self.quantized = True
             self._shape = w.shape
         return self
@@ -414,6 +440,8 @@ def to(self: T, tensor: Tensor, non_blocking: bool=...) -> T:
 
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+        if self.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+            dtype = None
         if (device is not None and device.type == "cpu" and self.data.device.type == "cpu"):
             return self.quantize(device.type)
         elif device is not None and device.type == "meta" and self.data.device.type == "meta":
@@ -424,24 +452,28 @@ def to(self, *args, **kwargs):
             self.data = ggml_q_format_convet_cpu2xpu(self.data,
                                                      reduce(mul, self._shape, 1),
                                                      self.qtype)
+            fp8_scale = None if self.torch_fp8_scale is None else self.torch_fp8_scale.to(device)
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
                                              non_blocking=non_blocking),
                                   requires_grad=self.requires_grad,
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_scale_search=self.enable_scale_search)
+                                  enable_scale_search=self.enable_scale_search,
+                                  torch_fp8_scale=fp8_scale)
             return new_param
         elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
+            fp8_scale = None if self.torch_fp8_scale is None else self.torch_fp8_scale.to(device)
             new_param = FP4Params(super().to(device=device,
                                              dtype=dtype,
                                              non_blocking=non_blocking),
                                   requires_grad=self.requires_grad,
                                   quantized=self.quantized,
                                   _shape=self._shape,
                                   qtype=self.qtype,
-                                  enable_scale_search=self.enable_scale_search)
+                                  enable_scale_search=self.enable_scale_search,
+                                  torch_fp8_scale=fp8_scale)
             ggml_xpu = new_param.data
             new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu,
                                                           reduce(mul, new_param._shape, 1),
@@ -614,6 +646,7 @@ def forward(self, x: torch.Tensor):
         # Due to inconsistent training status in some models like Baichuan-7b-Chat,
         # we should check both self.training and torch.is_inference_mode_enabled().
         is_training = self.training and not torch.is_inference_mode_enabled()
+
         if is_training:
             # below logic is only for training
             autocast_dtype = get_autocast_dtype(x.device.type)
@@ -643,6 +676,8 @@ def forward(self, x: torch.Tensor):
 
         if self.weight.device.type == "xpu":
             if is_training and x_2d.requires_grad:
+                invalidInputError(self.weight.qtype not in [TORCH_FP8E5, TORCH_FP8E4],
+                                  "TORCH_FP8 training is not supported.")
                 result = MatMulLowBit.apply(x_2d, self.weight, self.out_len)
             else:
                 do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
@@ -654,7 +689,11 @@ def forward(self, x: torch.Tensor):
                 else:
                     w = self.weight.data
 
-                if use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \
+                if self.weight.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+                    import xe_linear
+                    result = xe_linear.run_linear_fp8(x_2d, w, self.bias,
+                                                      self.weight.torch_fp8_scale)
+                elif use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \
                         (x_2d.dtype == torch.half or self.conver_to_half):
                     import xe_batch
                     result = xe_batch.batch_forward(x_2d, w, self.qtype)
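xe_linear.run_linear_fp8 takes the fp8 weight, the bias and the per-tensor scale in a single call, which is why the generic result += self.bias in the next hunk is skipped for these qtypes. The kernel itself is outside this diff; a plain-PyTorch sketch of what such a call presumably computes (a dequantize-then-matmul reference under that assumption, not the actual fused XPU implementation):

import torch
import torch.nn.functional as F

def run_linear_fp8_ref(x_2d: torch.Tensor, w_fp8: torch.Tensor,
                       bias: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Dequantize the per-tensor-scaled fp8 weight to the activation dtype,
    # then fall back to a regular linear; bias is applied here, not by the caller.
    w = w_fp8.to(x_2d.dtype) * scale.to(x_2d.dtype)
    return F.linear(x_2d, w, bias)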
@@ -682,13 +721,13 @@ def forward(self, x: torch.Tensor):
                 else:
                     invalidInputError(False, "mp_group is not None, but no supported backend found")
 
-            if self.bias is not None:
+            if self.bias is not None and self.weight.qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
                 result += self.bias
         else:
             # CPU logic
             # todo may need to set a different number on different platforms
-            invalidInputError(self.qtype != NF3 and self.qtype != NF4 and self.qtype != FP8E4
-                              and self.qtype != FP4 and self.qtype != FP8E5,
+            invalidInputError(self.qtype not in [NF3, NF4, FP8E4, FP4, FP8E5,
+                                                 TORCH_FP8E5, TORCH_FP8E4],
                               "NF3, NF4, FP4 and FP8 quantization are currently not"
                               " supported on CPU")
             if self.training and x.requires_grad:
