@@ -97,7 +97,7 @@ def __init__(
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
-        rope_scaling: dict[str, Any] | None = None,
+        rope_parameters: dict[str, Any] | None = None,
         max_position_embeddings: int = 8192,
         cache_config: CacheConfig | None = None,
         quant_config: QuantizationConfig | None = None,
@@ -150,7 +150,7 @@ def __init__(
             rotary_dim=self.head_dim,
             max_position=self.max_position_embeddings,
             base=self.rope_theta,
-            rope_scaling=rope_scaling,
+            rope_parameters=rope_parameters,
             is_neox_style=True,
         )
         self.attn = Attention(
@@ -199,14 +199,8 @@ def __init__(
         self.config = config
         self.layer_idx = layer_idx
 
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        if rope_scaling is not None and getattr(
-            config, "original_max_position_embeddings", None
-        ):
-            rope_scaling["original_max_position_embeddings"] = (
-                config.original_max_position_embeddings
-            )
+        if ompe := getattr(config, "original_max_position_embeddings", None):
+            config.rope_parameters["original_max_position_embeddings"] = ompe
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
 
         self.self_attn = Lfm2Attention(
@@ -215,8 +209,8 @@ def __init__(
             hidden_size=config.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
+            rope_theta=config.rope_parameters["rope_theta"],
+            rope_parameters=config.rope_parameters,
             max_position_embeddings=max_position_embeddings,
             cache_config=cache_config,
             quant_config=quant_config,
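
The net effect of the diff is that the base frequency and any scaling fields are now read from a single config.rope_parameters dict instead of separate rope_theta / rope_scaling attributes. A minimal sketch of how that dict is consumed, under the assumption that it at least contains a "rope_theta" key (the config object and its field values below are hypothetical, not taken from this PR):

from types import SimpleNamespace

# Hypothetical config; only the keys touched by this diff are shown.
config = SimpleNamespace(
    rope_parameters={"rope_theta": 10000.0},   # base frequency now lives inside the dict
    original_max_position_embeddings=4096,
    max_position_embeddings=8192,
)

# Mirrors the decoder-layer logic above: fold the original context length into
# the same dict, then read rope_theta back out of it when building attention.
if ompe := getattr(config, "original_max_position_embeddings", None):
    config.rope_parameters["original_max_position_embeddings"] = ompe

rope_theta = config.rope_parameters["rope_theta"]
print(rope_theta, config.rope_parameters)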