    def forward(self, in_idx):
        ...
        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits
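
# Note: forward() returns raw logits of shape (batch_size, num_tokens, vocab_size);
# the .to(self.cfg["dtype"]) cast matches the hidden states to the output head's
# configured precision (e.g. bfloat16, assuming that is what cfg["dtype"] holds).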


def assign(left, right, tensor_name="unknown"):
    # Fail fast on shape mismatches so a weight is never silently mis-assigned
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")

    # Copy torch tensors directly; convert anything else (e.g. NumPy arrays) first
    if isinstance(right, torch.Tensor):
        return torch.nn.Parameter(right.clone().detach())
    else:
        return torch.nn.Parameter(torch.tensor(right))
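
# A minimal, hypothetical demo of assign() (not part of the original code;
# uncomment to try it): a tensor with a matching shape is wrapped into an
# nn.Parameter, while a mismatched shape raises the ValueError above.
#
# demo_layer = torch.nn.Linear(4, 4, bias=False)
# demo_layer.weight = assign(demo_layer.weight, torch.rand(4, 4), "demo.weight")
# assign(demo_layer.weight, torch.rand(2, 4), "demo.weight")  # raises ValueError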


def load_weights_into_llama(model, param_config, params):
    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")

    for l in range(param_config["n_layers"]):

        # Load attention weights
        model.trf_blocks[l].att.W_query.weight = assign(
            model.trf_blocks[l].att.W_query.weight,
            params[f"model.layers.{l}.self_attn.q_proj.weight"],
            f"model.layers.{l}.self_attn.q_proj.weight"
        )
        model.trf_blocks[l].att.W_key.weight = assign(
            model.trf_blocks[l].att.W_key.weight,
            params[f"model.layers.{l}.self_attn.k_proj.weight"],
            f"model.layers.{l}.self_attn.k_proj.weight"
        )
        model.trf_blocks[l].att.W_value.weight = assign(
            model.trf_blocks[l].att.W_value.weight,
            params[f"model.layers.{l}.self_attn.v_proj.weight"],
            f"model.layers.{l}.self_attn.v_proj.weight"
        )
        model.trf_blocks[l].att.out_proj.weight = assign(
            model.trf_blocks[l].att.out_proj.weight,
            params[f"model.layers.{l}.self_attn.o_proj.weight"],
            f"model.layers.{l}.self_attn.o_proj.weight"
        )
        model.trf_blocks[l].norm1.weight = assign(
            model.trf_blocks[l].norm1.weight,
            params[f"model.layers.{l}.input_layernorm.weight"],
            f"model.layers.{l}.input_layernorm.weight"
        )

        # Load FeedForward weights
        model.trf_blocks[l].ff.fc1.weight = assign(
            model.trf_blocks[l].ff.fc1.weight,
            params[f"model.layers.{l}.mlp.gate_proj.weight"],
            f"model.layers.{l}.mlp.gate_proj.weight"
        )
        model.trf_blocks[l].ff.fc2.weight = assign(
            model.trf_blocks[l].ff.fc2.weight,
            params[f"model.layers.{l}.mlp.up_proj.weight"],
            f"model.layers.{l}.mlp.up_proj.weight"
        )
        model.trf_blocks[l].ff.fc3.weight = assign(
            model.trf_blocks[l].ff.fc3.weight,
            params[f"model.layers.{l}.mlp.down_proj.weight"],
            f"model.layers.{l}.mlp.down_proj.weight"
        )
        model.trf_blocks[l].norm2.weight = assign(
            model.trf_blocks[l].norm2.weight,
            params[f"model.layers.{l}.post_attention_layernorm.weight"],
            f"model.layers.{l}.post_attention_layernorm.weight"
        )

    # Load output layer weights
    model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight")

    if "lm_head.weight" in params.keys():
        model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
    else:
        # Checkpoints without a separate lm_head reuse the token embedding (weight tying)
        model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
        print("Model uses weight tying.")
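

# A minimal end-to-end sketch (an assumption, not part of the original code):
# it presumes a Hugging Face checkpoint whose tensor names match the keys used
# above, e.g. the gated meta-llama/Llama-3.2-1B repo, and that `model` and a
# matching `param_config` dict are already defined. Uncomment and adapt to use.
#
# from huggingface_hub import hf_hub_download
# from safetensors.torch import load_file
#
# weights_file = hf_hub_download(
#     repo_id="meta-llama/Llama-3.2-1B",  # placeholder; any repo with matching tensor names
#     filename="model.safetensors",
# )
# params = load_file(weights_file)  # dict mapping tensor names to torch.Tensors
# load_weights_into_llama(model, param_config, params)
# model.to("cuda" if torch.cuda.is_available() else "cpu")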