Commit 8394776

[gaudi] Perf optimization (#3256)
Signed-off-by: Wang, Yi A <[email protected]>
1 parent 79183d1 commit 8394776

24 files changed: +229 -66 lines

backends/gaudi/server/text_generation_server/layers/attention/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -11,6 +11,7 @@
     attention,
     paged_attention,
     paged_attention_mla,
+    set_block_mapping,
 )
 
 
@@ -22,6 +23,7 @@
     "get_kv_scales",
     "paged_attention",
     "paged_attention_mla",
+    "set_block_mapping",
     "SUPPORTS_WINDOWING",
     "KVCache",
     "KVCompressCache",

backends/gaudi/server/text_generation_server/layers/attention/hpu.py
Lines changed: 23 additions & 1 deletion

@@ -8,6 +8,7 @@
 from vllm_hpu_extension.utils import ModuleFusedSDPA
 import os
 from text_generation_server.models.globals import BLOCK_SIZE
+import math
 
 SUPPORTS_WINDOWING = False
 
@@ -106,6 +107,21 @@ def attention(
     return attn_output
 
 
+def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size):
+    block_mapping = torch.nn.functional.one_hot(
+        hpu_attention_meta.block_groups, num_classes=batch_size
+    )
+    dtype = hpu_attention_meta.block_usage.dtype
+    device = hpu_attention_meta.block_usage.device
+    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
+    mask = mask >= hpu_attention_meta.block_usage.unsqueeze(-1)
+    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
+    hpu_attention_meta = hpu_attention_meta._replace(
+        attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)
+    )
+    return hpu_attention_meta
+
+
 def paged_attention(
     query: torch.Tensor,
     kv_cache: KVCache,
@@ -176,4 +192,10 @@ def paged_attention_mla(
     return output.view(batch_size, head_num, -1)
 
 
-__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
+__all__ = [
+    "SUPPORTS_WINDOWING",
+    "attention",
+    "paged_attention",
+    "paged_attention_mla",
+    "set_block_mapping",
+]
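The new set_block_mapping helper turns the per-batch paged-attention metadata into two dense tensors: a one-hot block_mapping that assigns each KV-cache block to the sequence that owns it, and an additive attn_bias that masks the unused tail slots of partially filled blocks with -inf. A minimal standalone sketch of that computation (the BLOCK_SIZE, block_groups, and block_usage values below are made up for illustration):

import math
import torch

BLOCK_SIZE = 4  # toy value for this sketch; the real one comes from models.globals

# block_groups[i] = index of the sequence that owns KV-cache block i
block_groups = torch.tensor([0, 0, 1])
# block_usage[i] = number of valid tokens currently stored in block i
block_usage = torch.tensor([4.0, 2.0, 3.0])
batch_size = 2

# One-hot (num_blocks, batch_size) matrix mapping blocks back to sequences.
block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size)

# Bias of shape (num_blocks, BLOCK_SIZE): 0 for valid slots, -inf for padding.
slot = torch.arange(0, BLOCK_SIZE, dtype=torch.int32).unsqueeze(0)
mask = slot >= block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=block_usage.dtype).masked_fill_(mask, -math.inf)

print(block_mapping)  # tensor([[1, 0], [1, 0], [0, 1]])
print(attn_bias)      # -inf in the padded slots of blocks 1 and 2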

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
Lines changed: 5 additions & 0 deletions

@@ -28,6 +28,7 @@
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -415,6 +416,10 @@ def forward(
         seqlen: torch.Tensor,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)
 
         # Get rotary cos and sin for this forward
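This change, repeated in the modeling files that follow, finalizes the HPU attention metadata once at the top of the model's forward, before the embeddings are computed; the batch size is taken from input_ids (or from inputs_embeds for the models that receive embeddings directly). A condensed sketch of that call site, with a hypothetical minimal model class standing in for the real flash_*_modeling.py classes:

from typing import Optional

import torch

from text_generation_server.layers.attention import (
    HPUPagedAttentionMetadata,
    set_block_mapping,
)


class ToyFlashModel(torch.nn.Module):
    """Hypothetical stand-in for the flash_*_modeling.py model classes."""

    def forward(
        self,
        input_ids: torch.Tensor,
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    ) -> torch.Tensor:
        # hpu_attention_meta can be None, so the mapping is only built when present.
        if hpu_attention_meta is not None:
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, input_ids.shape[0]
            )
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, hpu_attention_meta)
        return hidden_states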

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
Lines changed: 5 additions & 0 deletions

@@ -26,6 +26,7 @@
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -678,6 +679,10 @@ def forward(
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)
 
         # Get rotary cos and sin for this forward

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
Lines changed: 5 additions & 0 deletions

@@ -33,6 +33,7 @@
     Seqlen,
     attention,
     paged_attention,
+    set_block_mapping,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -569,6 +570,10 @@ def forward(
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)
 
         # Get rotary cos and sin for this forward

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
Lines changed: 5 additions & 0 deletions

@@ -34,6 +34,7 @@
     Seqlen,
     attention,
     paged_attention_mla,
+    set_block_mapping,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -645,6 +646,10 @@ def forward(
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)
 
         # Get rotary cos and sin for this forward

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
Lines changed: 5 additions & 0 deletions

@@ -28,6 +28,7 @@
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -466,6 +467,10 @@ def forward(
         adapter_data: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds
 
         # Get rotary cos and sin for this forward

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
Lines changed: 5 additions & 0 deletions

@@ -28,6 +28,7 @@
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -388,6 +389,10 @@ def forward(
         adapter_data: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds
 
         # Get rotary cos and sin for this forward

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
Lines changed: 5 additions & 0 deletions

@@ -27,6 +27,7 @@
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -383,6 +384,10 @@ def forward(
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds
 
         residual = None

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
Lines changed: 5 additions & 0 deletions

@@ -28,6 +28,7 @@
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -324,6 +325,10 @@ def forward(
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.wte(input_ids)
 
         # Get rotary cos and sin for this forward
