File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed
Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -89,7 +89,8 @@ def test(
8989
9090 # flex attention version
9191 # TODO(jansel): turn the above kernel into a flex attention kernel
92- flex_out = flex_attention (q , k , v )
92+ flex_compiled = torch .compile (flex_attention , fullgraph = True )
93+ flex_out = flex_compiled (q , k , v )
9394 torch .testing .assert_close (flex_out , ref_out , atol = 1e-2 , rtol = 1e-2 )
9495
9596 # sdpa version
@@ -106,7 +107,7 @@ def test(
106107 spda_sec = do_bench (
107108 lambda : torch .nn .functional .scaled_dot_product_attention (q , k , v )
108109 )
109- flex_sec = do_bench (lambda : flex_attention (q , k , v ))
110+ flex_sec = do_bench (lambda : flex_compiled (q , k , v ))
110111 helion_sec = do_bench (lambda : attention (q , k , v ))
111112 print (
112113 f"Helion time: { helion_sec :.4f} ms, flex time: { flex_sec :.4f} , torch time: { spda_sec :.4f} "
You can’t perform that action at this time.
0 commit comments