@@ -163,111 +163,124 @@ def _merge_input_ids_with_image_features(
             )
         return inputs_embeds
 
-    def forward(
+    def get_vision_embeds(
         self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlen_prefill: Optional[torch.Tensor],
-        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-        slots: torch.Tensor,
-        seqlen: Seqlen,
-        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
-        lm_head_indices: Optional[torch.Tensor] = None,
-        pixel_values: torch.FloatTensor = None,
-        # Unused for this model
-        pixel_attention_mask=None,
-        image_sizes: Optional[torch.LongTensor] = None,
-        adapter_data: Optional[torch.Tensor] = None,
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
     ):
-        inputs_embeds = self.text_model.embed_tokens(input_ids)
-        if pixel_values is not None and len(pixel_values) > 0:
-            # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
-            # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
-            # 1. Extract the input embeddings
-
-            # 2. Merge text and images
-            num_images, num_patches, channels, height, width = pixel_values.shape
-            pixel_values = pixel_values.view(
-                num_images * num_patches, channels, height, width
-            )
-            image_features = self.vision_tower(pixel_values)
+        # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
+        # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
+        # 1. Extract the input embeddings
+
+        # 2. Merge text and images
+        num_images, num_patches, channels, height, width = pixel_values.shape
+        pixel_values = pixel_values.view(
+            num_images * num_patches, channels, height, width
+        )
+        image_features = self.vision_tower(pixel_values)
 
-            # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
-            # Already done within the clip model
-            selected_image_feature = image_features.last_hidden_state
+        # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
+        # Already done within the clip model
+        selected_image_feature = image_features.last_hidden_state
 
-            if self.config.vision_feature_select_strategy == "default":
-                selected_image_feature = selected_image_feature[:, 1:]
-            elif self.config.vision_feature_select_strategy == "full":
-                selected_image_feature = selected_image_feature
-            else:
-                raise RuntimeError(
-                    f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
-                )
+        if self.config.vision_feature_select_strategy == "default":
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif self.config.vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise RuntimeError(
+                f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
+            )
 
-            image_features = self.multi_modal_projector(selected_image_feature)
+        image_features = self.multi_modal_projector(selected_image_feature)
 
-            # split up image_features for each of the individual images
-            # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
-            # if we assume each image has 5 image features (base image + 4 patches)
-            split_sizes = [num_patches] * num_images
-            image_features = torch.split(image_features, split_sizes, dim=0)
+        # split up image_features for each of the individual images
+        # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
+        # if we assume each image has 5 image features (base image + 4 patches)
+        split_sizes = [num_patches] * num_images
+        image_features = torch.split(image_features, split_sizes, dim=0)
 
-            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-            height = width = (
-                self.config.vision_config.image_size
-                // self.config.vision_config.patch_size
-            )
+        # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+        height = width = (
+            self.config.vision_config.image_size // self.config.vision_config.patch_size
+        )
 
-            new_image_features = []
-            for image_idx, image_feature in enumerate(image_features):
-                if image_feature.shape[0] > 1:
-                    base_image_feature = image_feature[0]
-                    image_feature = image_feature[1:]
-
-                    if height * width != base_image_feature.shape[0]:
-                        raise ValueError(
-                            "The number of patches is not consistent with the image size."
-                        )
-
-                    # Dimensions are intentionally swapped to be bug-compatible with
-                    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
-                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                        image_sizes[image_idx],
-                        self.config.image_grid_pinpoints,
-                        self.config.vision_config.image_size,
-                    )
-                    image_feature = image_feature.view(
-                        num_patch_height, num_patch_width, height, width, -1
+        new_image_features = []
+        for image_idx, image_feature in enumerate(image_features):
+            if image_feature.shape[0] > 1:
+                base_image_feature = image_feature[0]
+                image_feature = image_feature[1:]
+
+                if height * width != base_image_feature.shape[0]:
+                    raise ValueError(
+                        "The number of patches is not consistent with the image size."
                     )
-                    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
-                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
-                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
-                    image_feature = torch.cat(
-                        (
-                            image_feature,
-                            self.image_newline[:, None, None].expand(
-                                *image_feature.shape[:-1], 1
-                            ),
+
+                # Dimensions are intentionally swapped to be bug-compatible with
+                # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
+                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                    image_sizes[image_idx],
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+                image_feature = image_feature.view(
+                    num_patch_height, num_patch_width, height, width, -1
+                )
+                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+                image_feature = unpad_image(image_feature, image_sizes[image_idx])
+                image_feature = torch.cat(
+                    (
+                        image_feature,
+                        self.image_newline[:, None, None].expand(
+                            *image_feature.shape[:-1], 1
                         ),
-                        dim=-1,
-                    )
-                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-                    image_feature = torch.cat(
-                        (base_image_feature, image_feature), dim=0
-                    )
-                else:
-                    image_feature = image_feature[0]
-                    image_feature = torch.cat(
-                        (image_feature, self.image_newline[None]), dim=0
-                    )
-                new_image_features.append(image_feature)
-            image_features = torch.stack(new_image_features, dim=0)
+                    ),
+                    dim=-1,
+                )
+                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
+            else:
+                image_feature = image_feature[0]
+                image_feature = torch.cat(
+                    (image_feature, self.image_newline[None]), dim=0
+                )
+            new_image_features.append(image_feature)
+        image_features = torch.stack(new_image_features, dim=0)
+        return image_features.view(-1, image_features.shape[-1])
+
+    def get_inputs_embeds(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
+        pixel_values: torch.FloatTensor = None,
+        image_sizes: Optional[torch.LongTensor] = None,
+    ):
+        inputs_embeds = self.text_model.embed_tokens(input_ids)
 
+        if vision_embeds is not None:
+            # When we generate, we don't want to replace the potential image_token_id that we generated by images
+            # that simply don't exist
             inputs_embeds = self._merge_input_ids_with_image_features(
-                input_ids, inputs_embeds, image_features
+                input_ids, inputs_embeds, vision_embeds
             )
+        return inputs_embeds
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        slots: torch.Tensor,
+        seqlen: Seqlen,
+        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
+        lm_head_indices: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
+    ):
 
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
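For context on the new call flow: the old monolithic `forward` is split into `get_vision_embeds` (vision tower + projector + anyres patch merging, returning a flat `(num_image_tokens, hidden_size)` tensor), `get_inputs_embeds` (text embeddings with the vision embeddings scattered into the image-token positions), and a `forward` that only consumes the resulting `inputs_embeds`. The sketch below is illustrative only and not part of this PR; the `embed_prompt` helper is hypothetical, and the KV-cache/`Seqlen`/HPU paged-attention arguments that the real `forward` needs are omitted because they are backend-specific.

```python
from typing import Optional

import torch


def embed_prompt(
    model,  # hypothetical: an instance exposing the three methods from this diff
    input_ids: torch.Tensor,
    pixel_values: Optional[torch.FloatTensor] = None,
    image_sizes: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
    """Build inputs_embeds once, before calling model.forward(...)."""
    vision_embeds = None
    if pixel_values is not None and len(pixel_values) > 0:
        # Vision tower + multi-modal projector + spatial_unpad merging,
        # flattened to (num_image_tokens, hidden_size).
        vision_embeds = model.get_vision_embeds(
            pixel_values=pixel_values, image_sizes=image_sizes
        )
    # Text embeddings, with vision embeddings written into the positions that
    # hold config.image_token_index (skipped when no images were passed).
    return model.get_inputs_embeds(
        input_ids=input_ids,
        vision_embeds=vision_embeds,
        pixel_values=pixel_values,
        image_sizes=image_sizes,
    )
```

The returned tensor is what the new `forward` signature expects as `inputs_embeds`, alongside the usual position ids, KV cache, slots, and HPU attention metadata.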