SYM_INT8_RTN = ggml_tensor_qtype["sym_int8_rtn"]
ASYM_INT4_RTN = ggml_tensor_qtype["asym_int4_rtn"]
WOQ_INT4 = ggml_tensor_qtype["woq_int4"]
+TORCH_FP8E5 = ggml_tensor_qtype["torch_fp8_e5m2"]
+TORCH_FP8E4 = ggml_tensor_qtype["torch_fp8_e4m3"]
RTN_DTYPE = {
    SYM_INT4_RTN: torch.uint8,
    ASYM_INT4_RTN: torch.uint8,
@@ -106,39 +108,44 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                       imatrix: torch.Tensor = None,
                       in_features: int = None,
                       enable_scale_search: bool = False):
-    QK = ggml.ggml_qk_size(qtype)
-    block_size_in_bytes = ggml.ggml_type_size(qtype)
-
-    invalidInputError(tensor.dtype == torch.float,
-                      "Input tensor must be float32")
-    src = tensor.data.data_ptr()
-    src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))
-    n = tensor.numel()  # all elements
-    k = tensor.shape[-1]
-    invalidInputError(k % QK == 0,
-                      f"Last dim of input tensor must be multiple of {QK}")
-
-    dst_size = (n // QK) * block_size_in_bytes
-    if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
-        dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
-                                 device=device)
-        dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
-        if qtype == ASYM_INT4_RTN:
-            scale = torch.empty((n // k) * 2, dtype=torch.float32,
-                                device=device)
-        else:
-            scale = torch.empty(n // k, dtype=torch.float32,
-                                device=device)
-    elif qtype == NF4:
-        # Deepspeed zero3 requires unified dtype,
-        # thus here uses bfloat16 consistent to other layers
-        # dst_size above is computed based on uint8, and for bfloat16,
-        # buffer size should be half
-        dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
-                                 device=device)
+    if qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+        fp8_dtype = torch.float8_e5m2 if qtype == TORCH_FP8E5 else torch.float8_e4m3fn
+        dst_tensor = torch.empty(tensor.shape, device=device, dtype=fp8_dtype)
+        scale = torch.zeros(1, device=device, dtype=torch.float32)
    else:
-        dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
-                                 device=device)
+        QK = ggml.ggml_qk_size(qtype)
+        block_size_in_bytes = ggml.ggml_type_size(qtype)
+
+        invalidInputError(tensor.dtype == torch.float,
+                          "Input tensor must be float32")
+        src = tensor.data.data_ptr()
+        src = ctypes.cast(src, ctypes.POINTER(ctypes.c_float))
+        n = tensor.numel()  # all elements
+        k = tensor.shape[-1]
+        invalidInputError(k % QK == 0,
+                          f"Last dim of input tensor must be multiple of {QK}")
+
+        dst_size = (n // QK) * block_size_in_bytes
+        if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
+            dst_tensor = torch.empty(dst_size, dtype=RTN_DTYPE[qtype],
+                                     device=device)
+            dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
+            if qtype == ASYM_INT4_RTN:
+                scale = torch.empty((n // k) * 2, dtype=torch.float32,
+                                    device=device)
+            else:
+                scale = torch.empty(n // k, dtype=torch.float32,
+                                    device=device)
+        elif qtype == NF4:
+            # Deepspeed zero3 requires unified dtype,
+            # thus here uses bfloat16 consistent to other layers
+            # dst_size above is computed based on uint8, and for bfloat16,
+            # buffer size should be half
+            dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
+                                     device=device)
+        else:
+            dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
+                                     device=device)

    if not convert_shape_only and device != 'meta':
        dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
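Note on the two new qtypes: `torch_fp8_e5m2` keeps more exponent bits (range) while `torch_fp8_e4m3` keeps more mantissa bits (precision). A quick way to compare them in plain PyTorch (illustrative only, not part of the patch; requires a build that ships the FP8 dtypes, roughly PyTorch 2.1+):

```python
import torch

# E5M2: larger dynamic range, coarser precision.
# E4M3FN: smaller range, finer precision.
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0

# Allocating an FP8 buffer plus a per-tensor float32 scale mirrors what the
# new branch in ggml_convert_qtype sets up before the quantization kernel runs.
w = torch.randn(16, 32)
dst = torch.empty(w.shape, dtype=torch.float8_e5m2)
scale = torch.zeros(1, dtype=torch.float32)
print(dst.dtype, dst.shape, scale.shape)
```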
@@ -158,6 +165,17 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                                                           enable_scale_search,
                                                           imatrix)
            return dst_tensor, scale.type(torch.float16)
+        elif qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+            import xe_linear
+            tensor_device = tensor.device
+            tensor_xpu = tensor.to("xpu")
+            dst_tensor = dst_tensor.to("xpu")
+            scale = scale.to("xpu")
+
+            xe_linear.dynamic_scaled_fp8_quant(dst_tensor, tensor_xpu, scale)
+
+            # scale = scale.to(tensor_device)
+            dst_tensor = dst_tensor.to(tensor_device)
        else:
            ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
    else:
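`xe_linear.dynamic_scaled_fp8_quant` is an XPU-only kernel, so its semantics are not visible in this diff. As a rough reference (an assumption about what the kernel computes, not its actual implementation; the scale convention could equally be the inverse), per-tensor dynamic FP8 quantization can be sketched in plain PyTorch as:

```python
import torch

def dynamic_scaled_fp8_quant_ref(t: torch.Tensor,
                                 fp8_dtype=torch.float8_e5m2):
    """Sketch of per-tensor dynamic FP8 quantization.

    Derives one scale from the tensor's absolute maximum so the scaled
    values fit the FP8 range, then casts; the returned pair mirrors the
    (dst_tensor, scale) buffers filled by the new branch above.
    """
    fp8_max = torch.finfo(fp8_dtype).max
    scale = t.abs().max().clamp(min=1e-12) / fp8_max
    q = (t / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return q, scale.reshape(1)

# For checking: q.float() * scale approximately recovers t.
```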
@@ -171,6 +189,8 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                                               hist, imatrix)
    if qtype in [SYM_INT8_RTN, SYM_INT4_RTN, ASYM_INT4_RTN]:
        return dst_tensor, scale.type(torch.float16)
+    elif qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+        return dst_tensor, scale
    else:
        return dst_tensor

@@ -179,7 +199,7 @@ def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int
    if qtype == NF4:
        invalidInputError(tensor.dtype == torch.bfloat16,
                          "NF4 Input tensor must be bfloat16")
-    else:
+    elif qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
        invalidInputError(tensor.dtype == torch.uint8,
                          "Input tensor except NF4 must be uint8")

@@ -208,7 +228,7 @@ def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int
    if qtype == NF4:
        invalidInputError(tensor.dtype == torch.bfloat16,
                          "NF4 Input tensor must be bfloat16")
-    else:
+    elif qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
        invalidInputError(tensor.dtype == torch.uint8,
                          "Input tensor must be uint8")

@@ -319,7 +339,8 @@ def __new__(cls,
                qtype=None,
                imatrix=None,
                in_features=None,
-                enable_scale_search=False):
+                enable_scale_search=False,
+                torch_fp8_scale=None):
        if data is None:
            data = torch.empty(0)

@@ -332,6 +353,7 @@ def __new__(cls,
        self.imatrix = imatrix
        self.in_features = in_features
        self.enable_scale_search = enable_scale_search
+        self.torch_fp8_scale = torch_fp8_scale
        return self

    def ggml_mse(self, w, ggml_qtype, device):
@@ -391,7 +413,11 @@ def quantize(self, device=None):
                                             imatrix=self.imatrix,
                                             in_features=self.in_features,
                                             enable_scale_search=self.enable_scale_search)
-            self.data = w_quantized
+            if self.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+                self.data = w_quantized[0]
+                self.torch_fp8_scale = w_quantized[1]
+            else:
+                self.data = w_quantized
            self.quantized = True
            self._shape = w.shape
        return self
@@ -414,6 +440,8 @@ def to(self: T, tensor: Tensor, non_blocking: bool=...) -> T:

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+        if self.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+            dtype = None
        if (device is not None and device.type == "cpu" and self.data.device.type == "cpu"):
            return self.quantize(device.type)
        elif device is not None and device.type == "meta" and self.data.device.type == "meta":
@@ -424,24 +452,28 @@ def to(self, *args, **kwargs):
            self.data = ggml_q_format_convet_cpu2xpu(self.data,
                                                     reduce(mul, self._shape, 1),
                                                     self.qtype)
+            fp8_scale = None if self.torch_fp8_scale is None else self.torch_fp8_scale.to(device)
            new_param = FP4Params(super().to(device=device,
                                             dtype=dtype,
                                             non_blocking=non_blocking),
                                  requires_grad=self.requires_grad,
                                  quantized=self.quantized,
                                  _shape=self._shape,
                                  qtype=self.qtype,
-                                  enable_scale_search=self.enable_scale_search)
+                                  enable_scale_search=self.enable_scale_search,
+                                  torch_fp8_scale=fp8_scale)
            return new_param
        elif (device is not None and device.type == "cpu" and self.data.device.type == "xpu"):
+            fp8_scale = None if self.torch_fp8_scale is None else self.torch_fp8_scale.to(device)
            new_param = FP4Params(super().to(device=device,
                                             dtype=dtype,
                                             non_blocking=non_blocking),
                                  requires_grad=self.requires_grad,
                                  quantized=self.quantized,
                                  _shape=self._shape,
                                  qtype=self.qtype,
-                                  enable_scale_search=self.enable_scale_search)
+                                  enable_scale_search=self.enable_scale_search,
+                                  torch_fp8_scale=fp8_scale)
            ggml_xpu = new_param.data
            new_param.data = ggml_q_format_convet_xpu2cpu(ggml_xpu,
                                                          reduce(mul, new_param._shape, 1),
@@ -614,6 +646,7 @@ def forward(self, x: torch.Tensor):
        # Due to inconsistent training status in some models like Baichuan-7b-Chat,
        # we should check both self.training and torch.is_inference_mode_enabled().
        is_training = self.training and not torch.is_inference_mode_enabled()
+
        if is_training:
            # below logic is only for training
            autocast_dtype = get_autocast_dtype(x.device.type)
@@ -643,6 +676,8 @@ def forward(self, x: torch.Tensor):

        if self.weight.device.type == "xpu":
            if is_training and x_2d.requires_grad:
+                invalidInputError(self.weight.qtype not in [TORCH_FP8E5, TORCH_FP8E4],
+                                  "TORCH_FP8 training is not supported.")
                result = MatMulLowBit.apply(x_2d, self.weight, self.out_len)
            else:
                do_empty_cache = self.low_memory_mode and x_2d.shape[0] >= 1024
@@ -654,7 +689,11 @@ def forward(self, x: torch.Tensor):
                else:
                    w = self.weight.data

-                if use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \
+                if self.weight.qtype in [TORCH_FP8E5, TORCH_FP8E4]:
+                    import xe_linear
+                    result = xe_linear.run_linear_fp8(x_2d, w, self.bias,
+                                                      self.weight.torch_fp8_scale)
+                elif use_batch_forward(x_2d, self.weight.qtype, self.out_len) and \
                        (x_2d.dtype == torch.half or self.conver_to_half):
                    import xe_batch
                    result = xe_batch.batch_forward(x_2d, w, self.qtype)
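`xe_linear.run_linear_fp8` receives the bias directly, which is presumably why the separate `result += self.bias` step is skipped for the FP8 qtypes in the hunk below. Under the per-tensor-scale assumption sketched earlier, a plain-PyTorch equivalent of such a call (a reference sketch, not the XPU kernel) would be:

```python
import torch

def run_linear_fp8_ref(x_2d: torch.Tensor, w_fp8: torch.Tensor,
                       bias, scale: torch.Tensor) -> torch.Tensor:
    """Reference FP8 weight-only linear: dequantize weight, matmul, add bias."""
    w = w_fp8.to(x_2d.dtype) * scale.to(x_2d.dtype)  # dequantize FP8 weight
    out = x_2d @ w.t()                               # [tokens, out_features]
    if bias is not None:
        out = out + bias
    return out
```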
@@ -682,13 +721,13 @@ def forward(self, x: torch.Tensor):
                else:
                    invalidInputError(False, "mp_group is not None, but no supported backend found")

-            if self.bias is not None:
+            if self.bias is not None and self.weight.qtype not in [TORCH_FP8E5, TORCH_FP8E4]:
                result += self.bias
        else:
            # CPU logic
            # todo may need to set a different number on different platforms
-            invalidInputError(self.qtype != NF3 and self.qtype != NF4 and self.qtype != FP8E4
-                              and self.qtype != FP4 and self.qtype != FP8E5,
+            invalidInputError(self.qtype not in [NF3, NF4, FP8E4, FP4, FP8E5,
+                                                 TORCH_FP8E5, TORCH_FP8E4],
                              "NF3, NF4, FP4 and FP8 quantization are currently not"
                              " supported on CPU")
            if self.training and x.requires_grad: