 import torch.nn.functional as F

 from torch.nn import CrossEntropyLoss
-from torch.nn.attention.flex_attention import create_block_mask
 from torch.nn.attention.flex_attention import flex_attention
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
@@ -256,8 +255,6 @@ def lce_forward(


 # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/flex_attention.py#L12
-
-
 def flex_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
@@ -279,26 +276,24 @@ def causal_mod(score, b, h, q_idx, kv_idx):
         score = score + causal_mask[b][0][q_idx][kv_idx]
         return score

-    # We only got `attention_mask` tensors, so we recreate `causal_mask` function as specific llama causal attention
-    # TODO: Consider other customized `attention_mask` in the future, e.g., shared prefix
-    def causal_mask_fn(b, h, q_idx, kv_idx):
-        return q_idx >= kv_idx
+    # def causal_mask_fn(b, h, q_idx, kv_idx):
+    #     return q_idx >= kv_idx

-    # To construct block attention mask that leverages sparsity.
-    sparse_causal_mask = create_block_mask(causal_mask_fn, None, None, query.shape[-2], query.shape[-2], device="cuda")
+    # TODO: Construct block attention mask that leverages sparsity
+    # sparse_causal_mask = create_block_mask(
+    #     causal_mask_fn, B=None, H=None, Q_LEN=query.shape[-2], KV_LEN=key.shape[-2], device=query.device, BLOCK_SIZE=1
+    # )

     attn_output, attention_weights = flex_attention(
         query,
         key,
         value,
         score_mod=causal_mod,
-        block_mask=sparse_causal_mask,
+        # block_mask=sparse_causal_mask,
         enable_gqa=True,
         scale=scaling,
-        # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
-        # For simplification, we thus always return it as no additional computations are introduced.
         return_lse=True,
-        kernel_options={  # different harware might need different configs
+        kernel_options={
            "BLOCK_M": 32,
            "BLOCK_N": 32,
            "BLOCK_M1": 16,
@@ -307,7 +302,7 @@ def causal_mask_fn(b, h, q_idx, kv_idx):
            "BLOCK_N2": 16,
         },
     )
-    # lse is returned in float32
+
     attention_weights = attention_weights.to(value.dtype)
     attn_output = attn_output.transpose(1, 2).contiguous()

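For context, a minimal standalone sketch (not part of this patch) of how the block-mask path that is commented out above could be wired up once re-enabled, assuming PyTorch >= 2.5.1 and a CUDA device; the tensor shapes and dtype below are illustrative only:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask_fn(b, h, q_idx, kv_idx):
    # A query position may only attend to itself and earlier key positions.
    return q_idx >= kv_idx

B, H, S, D = 1, 8, 1024, 64
query = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

# B=None / H=None broadcast the block-sparse mask over batch and heads.
sparse_causal_mask = create_block_mask(causal_mask_fn, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

# return_lse=True yields the log-sum-exp alongside the attention output (lse comes back in float32).
attn_output, lse = flex_attention(
    query,
    key,
    value,
    block_mask=sparse_causal_mask,
    return_lse=True,
)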