@@ -92,7 +92,7 @@ def __init__(
         self.scaling = self.head_dim**-0.5
         self.rope_theta = config.rope_theta

-        self.qkv = QKVParallelLinear(
+        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
@@ -129,7 +129,7 @@ def __init__(
     def forward(
         self, hidden_states: torch.Tensor, positions: torch.Tensor
     ) -> torch.Tensor:
-        qkv, _ = self.qkv(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         v = v.contiguous()
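For context, the fused qkv_proj output is split back into query/key/value slices using q_size and kv_size. A minimal sketch of the sizing involved, assuming q_size = num_heads * head_dim and kv_size = num_kv_heads * head_dim for the current tensor-parallel rank (the concrete numbers and variable names below are illustrative, not taken from this file):

    import torch

    # Illustrative per-rank head counts and head size.
    num_heads, num_kv_heads, head_dim = 8, 2, 64
    q_size = num_heads * head_dim        # 512
    kv_size = num_kv_heads * head_dim    # 128

    # The fused projection produces one tensor of width q_size + 2 * kv_size.
    qkv = torch.randn(4, q_size + 2 * kv_size)  # (num_tokens, 768)

    # Same split as in forward(): query, key, value slices along the last dim.
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    assert q.shape == (4, q_size)
    assert k.shape == (4, kv_size) and v.shape == (4, kv_size)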
@@ -606,9 +606,9 @@ def _load_weights_other(
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            (".qkv", ".q_proj", "q"),
-            (".qkv", ".k_proj", "k"),
-            (".qkv", ".v_proj", "v"),
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
         ]

         tp_rank = get_tensor_model_parallel_rank()
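The mapping's first element must match the module attribute name, which is why it is updated alongside the qkv -> qkv_proj rename: the loader rewrites checkpoint names like "...attn.q_proj.weight" into the fused parameter's name before looking it up. A rough sketch of the usual pattern that consumes this mapping (the exact loop in this file may differ):

    # Hypothetical helper, for illustration only.
    def load_weights_sketch(model, weights, stacked_params_mapping):
        params_dict = dict(model.named_parameters())
        loaded = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                # e.g. "...attn.q_proj.weight" -> "...attn.qkv_proj.weight";
                # with the old ".qkv" target this key would not exist.
                fused_name = name.replace(shard_name, param_name)
                param = params_dict[fused_name]
                # The fused parameter's weight loader copies this shard into
                # the matching slice of qkv_proj for the current TP rank.
                param.weight_loader(param, loaded_weight, shard_id)
                loaded.add(fused_name)
                break
            else:
                # Not part of a stacked parameter: load directly.
                param = params_dict[name]
                param.weight_loader(param, loaded_weight)
                loaded.add(name)
        return loaded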