-
Notifications
You must be signed in to change notification settings - Fork 3.6k
[do not merge] diffusion: support multi image input #14236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -452,21 +452,26 @@ def get_pos_prompt_embeds(self, batch): | |||||
| def get_neg_prompt_embeds(self, batch): | ||||||
| return batch.negative_prompt_embeds[0] | ||||||
|
|
||||||
| def calculate_condition_image_size( | ||||||
| self, image, width, height | ||||||
| ) -> Optional[tuple[int, int]]: | ||||||
| def calculate_condition_image_size(self, images) -> Optional[tuple[int, int]]: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The return type hint for
Suggested change
|
||||||
| width, height = [], [] | ||||||
| target_area: int = 1024 * 1024 | ||||||
| if width is not None and height is not None: | ||||||
| if width * height > target_area: | ||||||
| scale = math.sqrt(target_area / (width * height)) | ||||||
| width = int(width * scale) | ||||||
| height = int(height * scale) | ||||||
| return width, height | ||||||
|
|
||||||
| return None | ||||||
|
|
||||||
| def resize_condition_image(self, image, target_width, target_height): | ||||||
| return image.resize((target_width, target_height), PIL.Image.Resampling.LANCZOS) | ||||||
| for image in images: | ||||||
| w, h = image.size | ||||||
| if w * h > target_area: | ||||||
| scale = math.sqrt(target_area / (w * h)) | ||||||
| w = int(w * scale) | ||||||
| h = int(h * scale) | ||||||
| width.append(w) | ||||||
| height.append(h) | ||||||
| return width, height | ||||||
|
|
||||||
| def resize_condition_image(self, images, target_width, target_height): | ||||||
| new_images = [] | ||||||
| for image, width, height in zip(images, target_width, target_height): | ||||||
| new_images.append( | ||||||
| image.resize((width, height), PIL.Image.Resampling.LANCZOS) | ||||||
| ) | ||||||
| return new_images | ||||||
|
Comment on lines
+468
to
+474
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
|
|
||||||
| def postprocess_image_latent(self, latent_condition, batch): | ||||||
| batch_size = batch.batch_size | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,9 @@ | |
| from sglang.multimodal_gen.runtime.models.vision_utils import resize | ||
| from sglang.multimodal_gen.utils import calculate_dimensions | ||
|
|
||
| CONDITION_IMAGE_SIZE = 384 * 384 | ||
| VAE_IMAGE_SIZE = 1024 * 1024 | ||
|
|
||
|
|
||
| def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor): | ||
| bool_mask = mask.bool() | ||
|
|
@@ -220,32 +223,31 @@ class QwenImageEditPipelineConfig(QwenImagePipelineConfig): | |
|
|
||
| task_type: ModelTaskType = ModelTaskType.I2I | ||
|
|
||
| def _get_condition_image_sizes(self, batch) -> list[tuple[int, int]]: | ||
| image_size = batch.original_condition_image_size[0] | ||
| edit_width, edit_height, _ = calculate_dimensions( | ||
| VAE_IMAGE_SIZE, image_size[0] / image_size[1] | ||
| ) | ||
| return [(edit_width, edit_height)] | ||
|
|
||
| def _prepare_edit_cond_kwargs( | ||
| self, batch, prompt_embeds, rotary_emb, device, dtype | ||
| ): | ||
| batch_size = batch.latents.shape[0] | ||
| assert batch_size == 1 | ||
| height = batch.height | ||
| width = batch.width | ||
| image_size = batch.original_condition_image_size | ||
| edit_width, edit_height, _ = calculate_dimensions( | ||
| 1024 * 1024, image_size[0] / image_size[1] | ||
| ) | ||
| condition_image_sizes = self._get_condition_image_sizes(batch) | ||
| vae_scale_factor = self.get_vae_scale_factor() | ||
|
|
||
| img_shapes = [ | ||
| [ | ||
| ( | ||
| 1, | ||
| height // vae_scale_factor // 2, | ||
| width // vae_scale_factor // 2, | ||
| ), | ||
| ( | ||
| 1, | ||
| edit_height // vae_scale_factor // 2, | ||
| edit_width // vae_scale_factor // 2, | ||
| ), | ||
| ], | ||
| (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2), | ||
| *[ | ||
| (1, h // vae_scale_factor // 2, w // vae_scale_factor // 2) | ||
| for w, h in condition_image_sizes | ||
| ], | ||
| ] | ||
| ] * batch_size | ||
| txt_seq_lens = [prompt_embeds[0].shape[1]] | ||
| (img_cos, img_sin), (txt_cos, txt_sin) = QwenImagePipelineConfig.get_freqs_cis( | ||
|
|
@@ -273,8 +275,11 @@ def _prepare_edit_cond_kwargs( | |
| "freqs_cis": ((img_cos, img_sin), (txt_cos, txt_sin)), | ||
| } | ||
|
|
||
| def resize_condition_image(self, image, target_width, target_height): | ||
| return resize(image, target_height, target_width, resize_mode="default") | ||
| def resize_condition_image(self, images, target_width, target_height): | ||
| new_images = [] | ||
| for img, width, height in zip(images, target_width, target_height): | ||
| new_images.append(resize(img, height, width, resize_mode="default")) | ||
| return new_images | ||
|
|
||
| def postprocess_image_latent(self, latent_condition, batch): | ||
| batch_size = batch.batch_size | ||
|
|
@@ -313,13 +318,96 @@ def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype): | |
| batch, batch.negative_prompt_embeds, rotary_emb, device, dtype | ||
| ) | ||
|
|
||
| def calculate_condition_image_size(self, image, width, height) -> tuple[int, int]: | ||
| calculated_width, calculated_height, _ = calculate_dimensions( | ||
| 1024 * 1024, width / height | ||
| ) | ||
| return calculated_width, calculated_height | ||
| def calculate_condition_image_size(self, images) -> tuple[int, int]: | ||
yhyang201 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| width, height = [], [] | ||
| for image in images: | ||
| image_width, image_height = image.size | ||
| calculated_width, calculated_height, _ = calculate_dimensions( | ||
| VAE_IMAGE_SIZE, image_width / image_height | ||
| ) | ||
| width.append(calculated_width) | ||
| height.append(calculated_height) | ||
| return width, height | ||
|
|
||
| def slice_noise_pred(self, noise, latents): | ||
| # remove noise over input image | ||
| noise = noise[:, : latents.size(1)] | ||
| return noise | ||
|
|
||
| def preprocess_image(self, image, image_processor): | ||
| if not isinstance(image, list): | ||
| image = [image] | ||
| condition_images = [] | ||
| for img in image: | ||
| image_width, image_height = img.size | ||
| condition_width, condition_height, _ = calculate_dimensions( | ||
| CONDITION_IMAGE_SIZE, image_width / image_height | ||
| ) | ||
| condition_images.append( | ||
| image_processor.resize(img, condition_height, condition_width) | ||
| ) | ||
|
|
||
| return condition_images | ||
|
|
||
|
|
||
| class QwenImageEditPlusPipelineConfig(QwenImageEditPipelineConfig): | ||
| task_type: ModelTaskType = ModelTaskType.I2I | ||
|
|
||
| def _get_condition_image_sizes(self, batch) -> list[tuple[int, int]]: | ||
| image = batch.condition_image | ||
| if not isinstance(image, list): | ||
| image = [image] | ||
|
|
||
| condition_image_sizes = [] | ||
| for img in image: | ||
| image_width, image_height = img.size | ||
| edit_width, edit_height, _ = calculate_dimensions( | ||
| VAE_IMAGE_SIZE, image_width / image_height | ||
| ) | ||
| condition_image_sizes.append((edit_width, edit_height)) | ||
|
|
||
| return condition_image_sizes | ||
|
|
||
| def prepare_image_processor_kwargs(self, batch) -> dict: | ||
| prompt = batch.prompt | ||
| prompt_list = [prompt] if isinstance(prompt, str) else prompt | ||
| image_list = batch.condition_image | ||
|
|
||
| prompt_template_encode = ( | ||
| "<|im_start|>system\nDescribe the key features of the input image " | ||
| "(color, shape, size, texture, objects, background), then explain how " | ||
| "the user's text instruction should alter or modify the image. Generate " | ||
| "a new image that meets the user's requirements while maintaining " | ||
| "consistency with the original input where appropriate.<|im_end|>\n" | ||
| "<|im_start|>user\n{}<|im_end|>\n" | ||
| "<|im_start|>assistant\n" | ||
| ) | ||
| img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" | ||
| if isinstance(image_list, list): | ||
| base_img_prompt = "" | ||
| for i, img in enumerate(image_list): | ||
| base_img_prompt += img_prompt_template.format(i + 1) | ||
| txt = [prompt_template_encode.format(base_img_prompt + p) for p in prompt_list] | ||
| return dict(text=txt, padding=True) | ||
|
Comment on lines
+371
to
+391
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a potential def prepare_image_processor_kwargs(self, batch) -> dict:
prompt = batch.prompt
prompt_list = [prompt] if isinstance(prompt, str) else prompt
image_list = batch.condition_image
prompt_template_encode = (
"<|im_start|>system\nDescribe the key features of the input image "
"(color, shape, size, texture, objects, background), then explain how "
"the user's text instruction should alter or modify the image. Generate "
"a new image that meets the user's requirements while maintaining "
"consistency with the original input where appropriate.<|im_end|>\n"
"<|im_start|>user\n{}\n<|im_end|>\n"
"<|im_start|>assistant\n"
)
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
base_img_prompt = ""
if image_list:
if not isinstance(image_list, list):
image_list = [image_list]
for i, img in enumerate(image_list):
base_img_prompt += img_prompt_template.format(i + 1)
txt = [prompt_template_encode.format(base_img_prompt + p) for p in prompt_list]
return dict(text=txt, padding=True) |
||
|
|
||
| def resize_condition_image(self, images, target_width, target_height): | ||
| if not isinstance(images, list): | ||
| images = [images] | ||
| new_images = [] | ||
| for img, width, height in zip(images, target_width, target_height): | ||
| new_images.append(resize(img, height, width, resize_mode="default")) | ||
| return new_images | ||
|
|
||
| def calculate_condition_image_size(self, images) -> tuple[int, int]: | ||
| calculated_widths = [] | ||
| calculated_heights = [] | ||
|
|
||
| for img in images: | ||
| image_width, image_height = img.size | ||
| calculated_width, calculated_height, _ = calculate_dimensions( | ||
| VAE_IMAGE_SIZE, image_width / image_height | ||
| ) | ||
| calculated_widths.append(calculated_width) | ||
| calculated_heights.append(calculated_height) | ||
|
|
||
| return calculated_widths, calculated_heights | ||
|
Comment on lines
+401
to
+413
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This method has two issues:
Since the logic is the same, this method can be removed from |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return type hint for
calculate_condition_image_size is incorrect. The function returns two lists of integers (width and height), but the type hint is tuple[int, int]. It should be tuple[list[int], list[int]] to match the actual return value.