Commit 613b8dd

[gaudi] Vlm rebase and issue fix in benchmark test (#3263)
Signed-off-by: Wang, Yi A <[email protected]>
1 parent 8394776 commit 613b8dd

File tree: 13 files changed, +1084 -666 lines changed

backends/gaudi/server/text_generation_server/models/__init__.py

Lines changed: 6 additions & 5 deletions
@@ -83,9 +83,6 @@
 from text_generation_server.models.custom_modeling.flash_neox_modeling import (
     FlashGPTNeoXForCausalLM,
 )
-from text_generation_server.models.pali_gemma import (
-    PaliGemmaBatch,
-)
 from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
     PaliGemmaForConditionalGeneration,
 )
@@ -153,7 +150,6 @@
 )
 
 VLM_BATCH_TYPES = {
-    PaliGemmaBatch,
     FlashVlmCausalLMBatch,
     FlashMllamaCausalLMBatch,
 }
@@ -635,6 +631,7 @@ def get_model(
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
+            support_chunking=False,
         )
     elif model_type == BAICHUAN:
         return FlashCausalLM(
@@ -784,6 +781,8 @@ def get_model(
             kv_cache_dtype=kv_cache_dtype,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
+            # TODO: Fix bug in rust image_text_replacement implementation
+            support_chunking=False,
         )
     elif model_type == QWEN2_5_VL:
         return FlashVlmCausalLM(
@@ -799,6 +798,8 @@ def get_model(
             lora_adapter_ids=lora_adapter_ids,
             config_class=Qwen2_5_VLConfig,
             processor_class=Qwen2_5_VLProcessor,
+            # TODO: Fix bug in rust image_text_replacement implementation
+            support_chunking=False,
        )
     elif model_type == QWEN3:
         return FlashCausalLM(
@@ -824,6 +825,7 @@ def get_model(
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
+            support_chunking=False,
         )
     elif model_type == IDEFICS2:
         return FlashVlmCausalLM(
@@ -868,7 +870,6 @@ def get_model(
             default_dtype=torch.bfloat16,
             trust_remote_code=trust_remote_code,
             lora_adapter_ids=lora_adapter_ids,
-            batch_class=PaliGemmaBatch,
         )
     elif model_type == LLAVA_NEXT:
         return FlashVlmCausalLM(
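
The recurring change in this file is passing support_chunking=False to the affected model constructors. As a rough illustration only (ToyModel and prefill_chunks are hypothetical names, not TGI's real classes), a flag like this can be used to fall back to a single full-prompt prefill when chunking is unsupported:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyModel:
    # Stand-in for the constructor flag added in this diff: models whose
    # prompt processing cannot be split set this to False.
    support_chunking: bool = True


def prefill_chunks(model: ToyModel, tokens: List[int], chunk_size: int = 4) -> List[List[int]]:
    """Split the prompt into prefill chunks only when the model allows it."""
    if not model.support_chunking:
        return [tokens]  # one full-prompt prefill pass
    return [tokens[i : i + chunk_size] for i in range(0, len(tokens), chunk_size)]


print(prefill_chunks(ToyModel(support_chunking=False), list(range(10))))  # [[0..9]]
print(prefill_chunks(ToyModel(support_chunking=True), list(range(10))))   # chunks of 4
```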

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py

Lines changed: 45 additions & 50 deletions
@@ -1356,55 +1356,36 @@ def get_image_features(
         hidden_state = self.vision_model(pixel_values)
         return hidden_state
 
-    def forward(
+    def get_vision_embeds(
         self,
-        input_ids: torch.LongTensor = None,
-        pixel_values: torch.FloatTensor = None,
-        pixel_attention_mask=None,
-        position_ids: Optional[torch.LongTensor] = None,
-        cu_seqlen_prefill: Optional[torch.Tensor] = None,
-        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
-        slots: torch.Tensor = None,
-        seqlen: Seqlen = None,
-        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[Union[int, List[int]]] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        image_sizes: torch.Tensor = None,
-        lm_head_indices: Optional[torch.Tensor] = None,
-        adapter_data: Optional[torch.Tensor] = None,
-        **lm_kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-
-        def _get_padding_mask(input_ids, pad_token_id=0):
-            return (input_ids != pad_token_id).long()
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+    ):
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=self.config.vision_config.vision_feature_layer,
+            vision_feature_select_strategy=self.config.vision_config.vision_feature_select_strategy,
+            image_sizes=image_sizes,
+        )
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        image_features = self.multi_modal_projector(vision_flat)
+        return image_features
 
-        attention_mask = _get_padding_mask(input_ids)
-        attention_mask = attention_mask.view(seqlen.input_lengths.shape[0], -1)
+    def get_inputs_embeds(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
+        pixel_values: torch.FloatTensor = None,
+        image_sizes: Optional[torch.LongTensor] = None,
+    ):
         inputs_embeds = self.text_model.model.embed_tokens(input_ids)
-        vision_feature_layer = (
-            vision_feature_layer
-            if vision_feature_layer is not None
-            else self.config.vision_config.vision_feature_layer
-        )
-        vision_feature_select_strategy = (
-            vision_feature_select_strategy
-            if vision_feature_select_strategy is not None
-            else self.config.vision_config.vision_feature_select_strategy
-        )
-
-        if pixel_values is not None:
-            image_features = self.get_image_features(
-                pixel_values=pixel_values,
-                vision_feature_layer=vision_feature_layer,
-                vision_feature_select_strategy=vision_feature_select_strategy,
-                image_sizes=image_sizes,
-            )
-            original_inputs_embeds_shape = inputs_embeds.shape
-
-            vision_flat = image_features.view(-1, image_features.size(-1))
-            projected_vision_flat = self.multi_modal_projector(vision_flat)
 
+        if vision_embeds is not None:
+            # When we generate, we don't want to replace the potential image_token_id that we generated by images
+            # that simply don't exist
+            original_inputs_embeds_shape = inputs_embeds.shape
             special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
                 -1
             )
@@ -1414,19 +1395,33 @@ def _get_padding_mask(input_ids, pad_token_id=0):
             final_mask_1d = final_mask[..., 0].reshape(-1)
             num_tokens_to_fill = final_mask_1d.sum()
 
-            if num_tokens_to_fill != projected_vision_flat.size(0):
+            if num_tokens_to_fill != vision_embeds.size(0):
                 raise ValueError(
                     f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
-                    f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
+                    f"but multi_modal_projector returned {vision_embeds.size(0)}"
                 )
 
             expanded_mask = final_mask_1d.unsqueeze(-1).expand(
                 -1, inputs_embeds.size(-1)
             )
-            inputs_embeds = inputs_embeds.masked_scatter(
-                expanded_mask, projected_vision_flat
-            )
+            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
             inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
+        return inputs_embeds
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        cu_seqlen_prefill: Optional[torch.Tensor] = None,
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,
+        slots: torch.Tensor = None,
+        seqlen: Seqlen = None,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
+        **lm_kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         logits, speculative_logits = self.text_model(
             inputs_embeds,
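
The new get_inputs_embeds replaces image-placeholder token embeddings with the projected vision features via masked_scatter. A standalone sketch of that merging step (toy sizes; embed, image_token_id, and the random vision_embeds are stand-ins, not the model's real weights or config):

```python
import torch

hidden_size, image_token_id = 8, 7
embed = torch.nn.Embedding(32, hidden_size)

input_ids = torch.tensor([[1, 7, 7, 4, 5]])          # two image placeholders
inputs_embeds = embed(input_ids)                      # (1, 5, hidden_size)
vision_embeds = torch.randn(2, hidden_size)           # one row per image token

# Mask over placeholder positions, then expand it to the embedding dimension.
mask = (input_ids == image_token_id).unsqueeze(-1)    # (1, 5, 1)
num_tokens_to_fill = mask.sum()
if num_tokens_to_fill != vision_embeds.size(0):
    raise ValueError("mask/vision embedding count mismatch")

expanded_mask = mask.expand(-1, -1, hidden_size)      # (1, 5, hidden_size)
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)
print(inputs_embeds.shape)                            # torch.Size([1, 5, 8])
```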

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py

Lines changed: 104 additions & 91 deletions
@@ -163,111 +163,124 @@ def _merge_input_ids_with_image_features(
         )
         return inputs_embeds
 
-    def forward(
+    def get_vision_embeds(
         self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlen_prefill: Optional[torch.Tensor],
-        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        slots: torch.Tensor,
-        seqlen: Seqlen,
-        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
-        lm_head_indices: Optional[torch.Tensor] = None,
-        pixel_values: torch.FloatTensor = None,
-        # Unused for this model
-        pixel_attention_mask=None,
-        image_sizes: Optional[torch.LongTensor] = None,
-        adapter_data: Optional[torch.Tensor] = None,
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
     ):
-        inputs_embeds = self.text_model.embed_tokens(input_ids)
-        if pixel_values is not None and len(pixel_values) > 0:
-            # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
-            # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
-            # 1. Extract the input embeddings
-
-            # 2. Merge text and images
-            num_images, num_patches, channels, height, width = pixel_values.shape
-            pixel_values = pixel_values.view(
-                num_images * num_patches, channels, height, width
-            )
-            image_features = self.vision_tower(pixel_values)
+        # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
+        # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
+        # 1. Extract the input embeddings
+
+        # 2. Merge text and images
+        num_images, num_patches, channels, height, width = pixel_values.shape
+        pixel_values = pixel_values.view(
+            num_images * num_patches, channels, height, width
+        )
+        image_features = self.vision_tower(pixel_values)
 
-            # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
-            # Already done within the clip model
-            selected_image_feature = image_features.last_hidden_state
+        # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
+        # Already done within the clip model
+        selected_image_feature = image_features.last_hidden_state
 
-            if self.config.vision_feature_select_strategy == "default":
-                selected_image_feature = selected_image_feature[:, 1:]
-            elif self.config.vision_feature_select_strategy == "full":
-                selected_image_feature = selected_image_feature
-            else:
-                raise RuntimeError(
-                    f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
-                )
+        if self.config.vision_feature_select_strategy == "default":
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif self.config.vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise RuntimeError(
+                f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
+            )
 
-            image_features = self.multi_modal_projector(selected_image_feature)
+        image_features = self.multi_modal_projector(selected_image_feature)
 
-            # split up image_features for each of the individual images
-            # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
-            # if we assume each image has 5 image features (base image + 4 patches)
-            split_sizes = [num_patches] * num_images
-            image_features = torch.split(image_features, split_sizes, dim=0)
+        # split up image_features for each of the individual images
+        # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
+        # if we assume each image has 5 image features (base image + 4 patches)
+        split_sizes = [num_patches] * num_images
+        image_features = torch.split(image_features, split_sizes, dim=0)
 
-            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-            height = width = (
-                self.config.vision_config.image_size
-                // self.config.vision_config.patch_size
-            )
+        # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+        height = width = (
+            self.config.vision_config.image_size // self.config.vision_config.patch_size
+        )
 
-            new_image_features = []
-            for image_idx, image_feature in enumerate(image_features):
-                if image_feature.shape[0] > 1:
-                    base_image_feature = image_feature[0]
-                    image_feature = image_feature[1:]
-
-                    if height * width != base_image_feature.shape[0]:
-                        raise ValueError(
-                            "The number of patches is not consistent with the image size."
-                        )
-
-                    # Dimensions are intentionally swapped to be bug-compatible with
-                    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
-                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                        image_sizes[image_idx],
-                        self.config.image_grid_pinpoints,
-                        self.config.vision_config.image_size,
-                    )
-                    image_feature = image_feature.view(
-                        num_patch_height, num_patch_width, height, width, -1
+        new_image_features = []
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+
+                if height * width != base_image_feature.shape[0]:
+                    raise ValueError(
+                        "The number of patches is not consistent with the image size."
                     )
-                    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
-                    image_feature = torch.cat(
-                        (
-                            image_feature,
-                            self.image_newline[:, None, None].expand(
-                                *image_feature.shape[:-1], 1
-                            ),
+
+                # Dimensions are intentionally swapped to be bug-compatible with
+                # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+                image_feature = image_feature.view(
+                    num_patch_height, num_patch_width, height, width, -1
+                )
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                image_feature = torch.cat(
+                    (
+                        image_feature,
+                        self.image_newline[:, None, None].expand(
+                            *image_feature.shape[:-1], 1
                         ),
-                        dim=-1,
-                    )
-                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                    image_feature = torch.cat(
-                        (base_image_feature, image_feature), dim=0
-                    )
-                else:
-                    image_feature = image_feature[0]
-                    image_feature = torch.cat(
-                        (image_feature, self.image_newline[None]), dim=0
-                    )
-                new_image_features.append(image_feature)
-            image_features = torch.stack(new_image_features, dim=0)
+                    ),
+                    dim=-1,
+                )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+            else:
+                image_feature = image_feature[0]
+                image_feature = torch.cat(
+                    (image_feature, self.image_newline[None]), dim=0
+                )
+            new_image_features.append(image_feature)
+        image_features = torch.stack(new_image_features, dim=0)
+        return image_features.view(-1, image_features.shape[-1])
+
+    def get_inputs_embeds(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
+        pixel_values: torch.FloatTensor = None,
+        image_sizes: Optional[torch.LongTensor] = None,
+    ):
+        inputs_embeds = self.text_model.embed_tokens(input_ids)
 
+        if vision_embeds is not None:
+            # When we generate, we don't want to replace the potential image_token_id that we generated by images
+            # that simply don't exist
             inputs_embeds = self._merge_input_ids_with_image_features(
-                input_ids, inputs_embeds, image_features
+                input_ids, inputs_embeds, vision_embeds
             )
+        return inputs_embeds
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+        lm_head_indices: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
+    ):
 
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,