@@ -36,9 +36,7 @@ def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
3636 self ._sin_k_cached = None
3737 self .scaling_factor = scaling_factor
3838 self .dynamic_args = None
39- self ._update_cos_sin_cache (
40- torch .float32 , inv_freq .device , max_position_embeddings
41- )
39+ self .max_position_embeddings = max_position_embeddings
4240
4341 def forward (
4442 self ,
@@ -270,7 +268,9 @@ def _update_cos_sin_cache(self, dtype, device, seqlen):
270268 self ._sin_cached = torch .sin (freqs ).to (dtype )
271269
272270 def get_cos_sin (self , position_ids : torch .Tensor ):
273-
271+ self ._update_cos_sin_cache (
272+ torch .float32 , position_ids .device , seqlen = self .max_position_embeddings
273+ )
274274 cos = torch .index_select (self ._cos_cached , 0 , position_ids )
275275 sin = torch .index_select (self ._sin_cached , 0 , position_ids )
276276
@@ -298,9 +298,6 @@ def __init__(
298298 self ._cos_k_cached = None
299299 self ._sin_k_cached = None
300300 self .dynamic_args = None
301- self ._update_cos_sin_cache (
302- torch .float32 , short_inv_freq .device , max_position_embeddings
303- )
304301
305302 def _update_cos_sin_cache (self , dtype , device , seqlen ):
306303 # Reset the tables if the sequence length has changed,
@@ -354,9 +351,6 @@ def __init__(
354351 self ._cos_k_cached = None
355352 self ._sin_k_cached = None
356353 self .dynamic_args = None
357- self ._update_cos_sin_cache (
358- torch .float32 , short_inv_freq .device , max_position_embeddings
359- )
360354
361355 def _update_cos_sin_cache (self , dtype , device , seqlen ):
362356 if (
@@ -598,6 +592,9 @@ def get_cos_sin(
598592 position_ids : torch .Tensor ,
599593 ):
600594 slen = position_ids .shape [0 ]
595+ self ._update_cos_sin_cache (
596+ torch .float32 , position_ids .device , seqlen = self .max_position_embeddings
597+ )
601598
602599 cos = self ._cos_cached [position_ids ].gather (1 , self ._sections [:slen ])
603600 sin = self ._sin_cached [position_ids ].gather (1 , self ._sections [:slen ])
0 commit comments