@@ -77,15 +77,15 @@ def __init__(
         max_position_embeddings: int,
         rope_theta: float = 10000,
         rope_scaling: dict[str, Any] | None = None,
+        head_dim: int | None = None,
         layer_id: int = 0,
         dtype: jnp.dtype = jnp.float16,
         rngs: nnx.Rngs = None,
     ):
         self.hidden_size = hidden_size
         self.num_heads = num_heads
-        head_size = hidden_size // num_heads
-        self.head_size = head_size
-        self.scaling = head_size**-0.5
+        self.head_dim = head_dim = head_dim or hidden_size // num_heads
+        self.scaling = head_dim**-0.5

         self.q_proj = LinearBase(
             input_size=hidden_size,
@@ -112,7 +112,7 @@ def __init__(
             params_dtype=dtype,
         )
         self.c_proj = LinearBase(
-            input_size=num_heads * head_size,
+            input_size=num_heads * head_dim,
             output_size=hidden_size,
             use_bias=False,
             kernel_axes=("tensor", None),
@@ -122,17 +122,17 @@ def __init__(

         # Use torch version of RotaryEmbedding directly
         self.rotary_emb = RotaryEmbedding(
-            head_size=head_size,
-            rotary_dim=head_size,
+            head_size=head_dim,
+            rotary_dim=head_dim,
             max_position_embeddings=max_position_embeddings,
             base=rope_theta,
             is_neox_style=True,
             dtype=dtype,
         )
-        self.scaling = head_size**-0.5
+        self.scaling = head_dim**-0.5
         self.attn = RadixAttention(
             num_heads=num_heads,
-            head_dim=head_size,
+            head_dim=head_dim,
             scaling=self.scaling,
             num_kv_heads=num_heads,
             layer_id=layer_id,
@@ -150,9 +150,9 @@ def __call__(
         k, _ = self.k_proj(hidden_states)
         v, _ = self.v_proj(hidden_states)

-        q = q.reshape(-1, self.num_heads, self.head_size)
-        k = k.reshape(-1, self.num_heads, self.head_size)
-        v = v.reshape(-1, self.num_heads, self.head_size)
+        q = q.reshape(-1, self.num_heads, self.head_dim)
+        k = k.reshape(-1, self.num_heads, self.head_dim)
+        v = v.reshape(-1, self.num_heads, self.head_dim)

         q, k = self.rotary_emb(positions, q, k)
         attn_output, kv_fused = self.attn(q, k, v, forward_batch, token_to_kv_pool)
@@ -169,7 +169,7 @@ def __init__(
         rngs: nnx.Rngs = None,
     ):
         self.layer_id = layer_id
-
+        head_dim = getattr(config, "head_dim_padded", None)
         self.ln_1 = RMSNorm(
             config.hidden_size,
             epsilon=config.layer_norm_epsilon,
@@ -186,6 +186,7 @@ def __init__(
             config.max_position_embeddings,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
+            head_dim=head_dim,
             layer_id=layer_id,
             dtype=dtype,
             rngs=rngs,
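
For context, a minimal sketch (not the repository's code; names are illustrative) of what the change above does: the attention module now accepts an optional head_dim override, read from a head_dim_padded attribute on the config, and falls back to the usual hidden_size // num_heads split when no override is given; the softmax scaling and the c_proj/reshape sizes then follow from whichever value wins.

# Hedged sketch of the head_dim fallback introduced above; resolve_head_dim
# is a hypothetical helper, not part of the repository.
def resolve_head_dim(hidden_size: int, num_heads: int,
                     head_dim: int | None = None) -> tuple[int, float]:
    head_dim = head_dim or hidden_size // num_heads  # default: even split of hidden_size
    scaling = head_dim ** -0.5                       # attention softmax scale
    return head_dim, scaling

# Without an override: 4096 // 32 == 128, scale == 128 ** -0.5.
assert resolve_head_dim(4096, 32) == (128, 128 ** -0.5)
# With a padded head dim, the q/k/v reshapes and c_proj use
# num_heads * head_dim columns while hidden_size itself is unchanged.
assert resolve_head_dim(4096, 32, head_dim=192)[0] == 192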