Skip to content

Commit 25fdc5f

Browse files
authored
[gaudi] Move the _update_cos_sin_cache into get_cos_sin (#3254)
Signed-off-by: yuanwu <[email protected]>
1 parent 613b8dd commit 25fdc5f

File tree

1 file changed

+7
-10
lines changed
  • backends/gaudi/server/text_generation_server/layers

1 file changed

+7
-10
lines changed

backends/gaudi/server/text_generation_server/layers/rotary.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@ def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
3636
self._sin_k_cached = None
3737
self.scaling_factor = scaling_factor
3838
self.dynamic_args = None
39-
self._update_cos_sin_cache(
40-
torch.float32, inv_freq.device, max_position_embeddings
41-
)
39+
self.max_position_embeddings = max_position_embeddings
4240

4341
def forward(
4442
self,
@@ -270,7 +268,9 @@ def _update_cos_sin_cache(self, dtype, device, seqlen):
270268
self._sin_cached = torch.sin(freqs).to(dtype)
271269

272270
def get_cos_sin(self, position_ids: torch.Tensor):
273-
271+
self._update_cos_sin_cache(
272+
torch.float32, position_ids.device, seqlen=self.max_position_embeddings
273+
)
274274
cos = torch.index_select(self._cos_cached, 0, position_ids)
275275
sin = torch.index_select(self._sin_cached, 0, position_ids)
276276

@@ -298,9 +298,6 @@ def __init__(
298298
self._cos_k_cached = None
299299
self._sin_k_cached = None
300300
self.dynamic_args = None
301-
self._update_cos_sin_cache(
302-
torch.float32, short_inv_freq.device, max_position_embeddings
303-
)
304301

305302
def _update_cos_sin_cache(self, dtype, device, seqlen):
306303
# Reset the tables if the sequence length has changed,
@@ -354,9 +351,6 @@ def __init__(
354351
self._cos_k_cached = None
355352
self._sin_k_cached = None
356353
self.dynamic_args = None
357-
self._update_cos_sin_cache(
358-
torch.float32, short_inv_freq.device, max_position_embeddings
359-
)
360354

361355
def _update_cos_sin_cache(self, dtype, device, seqlen):
362356
if (
@@ -598,6 +592,9 @@ def get_cos_sin(
598592
position_ids: torch.Tensor,
599593
):
600594
slen = position_ids.shape[0]
595+
self._update_cos_sin_cache(
596+
torch.float32, position_ids.device, seqlen=self.max_position_embeddings
597+
)
601598

602599
cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
603600
sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])

0 commit comments

Comments (0)