From 9a61f6044239c7a57f503a9c29b8ef4a15b6f94c Mon Sep 17 00:00:00 2001
From: yhyang201
Date: Mon, 1 Dec 2025 17:09:07 +0000
Subject: [PATCH] support multi image input

Co-authored-by: Zyann
---
 .../configs/pipeline_configs/base.py          |  28 +++-
 .../configs/pipeline_configs/flux.py          |  33 +++--
 .../configs/pipeline_configs/qwen_image.py    | 132 +++++++++++++++---
 .../configs/sample/qwenimage.py               |   8 ++
 python/sglang/multimodal_gen/registry.py      |   7 +
 .../runtime/entrypoints/openai/image_api.py   |  19 ++-
 .../runtime/pipelines/qwen_image.py           |   6 +-
 .../runtime/pipelines_core/schedule_batch.py  |   4 +-
 .../pipelines_core/stages/image_encoding.py   |  94 ++++++++-----
 9 files changed, 247 insertions(+), 84 deletions(-)

diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/base.py b/python/sglang/multimodal_gen/configs/pipeline_configs/base.py
index f0c9c85d384..95b1c6baa8e 100644
--- a/python/sglang/multimodal_gen/configs/pipeline_configs/base.py
+++ b/python/sglang/multimodal_gen/configs/pipeline_configs/base.py
@@ -173,25 +173,40 @@ def postprocess_image(self, image):
     # Wan2.2 TI2V parameters
     boundary_ratio: float | None = None
 
+    # i2i, i2v
+    multi_image_input: bool = False
+
     # Compilation
     # enable_torch_compile: bool = False
 
     # calculate the adjust size for condition image
     # width: original condition image width
     # height: original condition image height
-    def calculate_condition_image_size(self, image, width, height) -> tuple[int, int]:
-        vae_scale_factor = self.vae_config.arch_config.spatial_compression_ratio
-        height, width = get_default_height_width(image, vae_scale_factor, height, width)
+    def calculate_condition_image_size(self, images) -> tuple[int, int]:
+        width, height = [], []
+        for image in images:
+            vae_scale_factor = self.vae_config.arch_config.spatial_compression_ratio
+            h, w = get_default_height_width(image, vae_scale_factor)
+            height.append(h)
+            width.append(w)
         return width, height
 
+    def support_multi_image_input(self) -> bool:
+        return self.multi_image_input
+
     ## For timestep preparation stage
     def prepare_sigmas(self, sigmas, num_inference_steps):
         return sigmas
 
     ## For ImageVAEEncodingStage
-    def resize_condition_image(self, image, target_width, target_height):
-        return image.resize((target_width, target_height), PIL.Image.Resampling.LANCZOS)
+    def resize_condition_image(self, images, target_width, target_height):
+        new_images = []
+        for image, width, height in zip(images, target_width, target_height):
+            new_images.append(
+                image.resize((width, height), PIL.Image.Resampling.LANCZOS)
+            )
+        return new_images
 
     def prepare_image_processor_kwargs(self, batch):
         return {}
@@ -262,6 +277,9 @@ def get_decode_scale_and_shift(self, device, dtype, vae):
         shift_factor = getattr(vae, "shift_factor", None)
         return scaling_factor, shift_factor
 
+    def preprocess_image(self, image, image_processor):
+        return image
+
     # called after latents are prepared
     def maybe_pack_latents(self, latents, batch_size, batch):
         return latents
diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py b/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py
index 402cbf686d5..c380f304961 100644
--- a/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py
+++ b/python/sglang/multimodal_gen/configs/pipeline_configs/flux.py
@@ -452,21 +452,26 @@ def get_pos_prompt_embeds(self, batch):
     def get_neg_prompt_embeds(self, batch):
         return batch.negative_prompt_embeds[0]
 
-    def calculate_condition_image_size(
-        self, image, width, height
-    ) -> Optional[tuple[int, int]]:
+    def 
calculate_condition_image_size(self, images) -> Optional[tuple[int, int]]: + width, height = [], [] target_area: int = 1024 * 1024 - if width is not None and height is not None: - if width * height > target_area: - scale = math.sqrt(target_area / (width * height)) - width = int(width * scale) - height = int(height * scale) - return width, height - - return None - - def resize_condition_image(self, image, target_width, target_height): - return image.resize((target_width, target_height), PIL.Image.Resampling.LANCZOS) + for image in images: + w, h = image.size + if w * h > target_area: + scale = math.sqrt(target_area / (w * h)) + w = int(w * scale) + h = int(h * scale) + width.append(w) + height.append(h) + return width, height + + def resize_condition_image(self, images, target_width, target_height): + new_images = [] + for image, width, height in zip(images, target_width, target_height): + new_images.append( + image.resize((width, height), PIL.Image.Resampling.LANCZOS) + ) + return new_images def postprocess_image_latent(self, latent_condition, batch): batch_size = batch.batch_size diff --git a/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py b/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py index f6f589a7e39..82796ae3e09 100644 --- a/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py +++ b/python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py @@ -17,6 +17,9 @@ from sglang.multimodal_gen.runtime.models.vision_utils import resize from sglang.multimodal_gen.utils import calculate_dimensions +CONDITION_IMAGE_SIZE = 384 * 384 +VAE_IMAGE_SIZE = 1024 * 1024 + def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() @@ -220,6 +223,13 @@ class QwenImageEditPipelineConfig(QwenImagePipelineConfig): task_type: ModelTaskType = ModelTaskType.I2I + def _get_condition_image_sizes(self, batch) -> list[tuple[int, int]]: + image_size = batch.original_condition_image_size[0] + edit_width, edit_height, _ = calculate_dimensions( + VAE_IMAGE_SIZE, image_size[0] / image_size[1] + ) + return [(edit_width, edit_height)] + def _prepare_edit_cond_kwargs( self, batch, prompt_embeds, rotary_emb, device, dtype ): @@ -227,25 +237,17 @@ def _prepare_edit_cond_kwargs( assert batch_size == 1 height = batch.height width = batch.width - image_size = batch.original_condition_image_size - edit_width, edit_height, _ = calculate_dimensions( - 1024 * 1024, image_size[0] / image_size[1] - ) + condition_image_sizes = self._get_condition_image_sizes(batch) vae_scale_factor = self.get_vae_scale_factor() img_shapes = [ [ - ( - 1, - height // vae_scale_factor // 2, - width // vae_scale_factor // 2, - ), - ( - 1, - edit_height // vae_scale_factor // 2, - edit_width // vae_scale_factor // 2, - ), - ], + (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2), + *[ + (1, h // vae_scale_factor // 2, w // vae_scale_factor // 2) + for w, h in condition_image_sizes + ], + ] ] * batch_size txt_seq_lens = [prompt_embeds[0].shape[1]] (img_cos, img_sin), (txt_cos, txt_sin) = QwenImagePipelineConfig.get_freqs_cis( @@ -273,8 +275,11 @@ def _prepare_edit_cond_kwargs( "freqs_cis": ((img_cos, img_sin), (txt_cos, txt_sin)), } - def resize_condition_image(self, image, target_width, target_height): - return resize(image, target_height, target_width, resize_mode="default") + def resize_condition_image(self, images, target_width, target_height): + new_images = [] + for img, width, height in zip(images, target_width, target_height): 
+ new_images.append(resize(img, height, width, resize_mode="default")) + return new_images def postprocess_image_latent(self, latent_condition, batch): batch_size = batch.batch_size @@ -313,13 +318,96 @@ def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype): batch, batch.negative_prompt_embeds, rotary_emb, device, dtype ) - def calculate_condition_image_size(self, image, width, height) -> tuple[int, int]: - calculated_width, calculated_height, _ = calculate_dimensions( - 1024 * 1024, width / height - ) - return calculated_width, calculated_height + def calculate_condition_image_size(self, images) -> tuple[int, int]: + width, height = [], [] + for image in images: + image_width, image_height = image.size + calculated_width, calculated_height, _ = calculate_dimensions( + VAE_IMAGE_SIZE, image_width / image_height + ) + width.append(calculated_width) + height.append(calculated_height) + return width, height def slice_noise_pred(self, noise, latents): # remove noise over input image noise = noise[:, : latents.size(1)] return noise + + def preprocess_image(self, image, image_processor): + if not isinstance(image, list): + image = [image] + condition_images = [] + for img in image: + image_width, image_height = img.size + condition_width, condition_height, _ = calculate_dimensions( + CONDITION_IMAGE_SIZE, image_width / image_height + ) + condition_images.append( + image_processor.resize(img, condition_height, condition_width) + ) + + return condition_images + + +class QwenImageEditPlusPipelineConfig(QwenImageEditPipelineConfig): + task_type: ModelTaskType = ModelTaskType.I2I + + def _get_condition_image_sizes(self, batch) -> list[tuple[int, int]]: + image = batch.condition_image + if not isinstance(image, list): + image = [image] + + condition_image_sizes = [] + for img in image: + image_width, image_height = img.size + edit_width, edit_height, _ = calculate_dimensions( + VAE_IMAGE_SIZE, image_width / image_height + ) + condition_image_sizes.append((edit_width, edit_height)) + + return condition_image_sizes + + def prepare_image_processor_kwargs(self, batch) -> dict: + prompt = batch.prompt + prompt_list = [prompt] if isinstance(prompt, str) else prompt + image_list = batch.condition_image + + prompt_template_encode = ( + "<|im_start|>system\nDescribe the key features of the input image " + "(color, shape, size, texture, objects, background), then explain how " + "the user's text instruction should alter or modify the image. 
Generate "
+            "a new image that meets the user's requirements while maintaining "
+            "consistency with the original input where appropriate.<|im_end|>\n"
+            "<|im_start|>user\n{}<|im_end|>\n"
+            "<|im_start|>assistant\n"
+        )
+        img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
+        base_img_prompt = ""
+        if isinstance(image_list, list):
+            for i, img in enumerate(image_list):
+                base_img_prompt += img_prompt_template.format(i + 1)
+        txt = [prompt_template_encode.format(base_img_prompt + p) for p in prompt_list]
+        return dict(text=txt, padding=True)
+
+    def resize_condition_image(self, images, target_width, target_height):
+        if not isinstance(images, list):
+            images = [images]
+        new_images = []
+        for img, width, height in zip(images, target_width, target_height):
+            new_images.append(resize(img, height, width, resize_mode="default"))
+        return new_images
+
+    def calculate_condition_image_size(self, images) -> tuple[int, int]:
+        calculated_widths = []
+        calculated_heights = []
+
+        for img in images:
+            image_width, image_height = img.size
+            calculated_width, calculated_height, _ = calculate_dimensions(
+                VAE_IMAGE_SIZE, image_width / image_height
+            )
+            calculated_widths.append(calculated_width)
+            calculated_heights.append(calculated_height)
+
+        return calculated_widths, calculated_heights
diff --git a/python/sglang/multimodal_gen/configs/sample/qwenimage.py b/python/sglang/multimodal_gen/configs/sample/qwenimage.py
index c3270395559..574f34b55e3 100644
--- a/python/sglang/multimodal_gen/configs/sample/qwenimage.py
+++ b/python/sglang/multimodal_gen/configs/sample/qwenimage.py
@@ -16,3 +16,11 @@ class QwenImageSamplingParams(SamplingParams):
     # Denoising stage
     guidance_scale: float = 4.0
     num_inference_steps: int = 50
+
+
+@dataclass
+class QwenImageEditPlusSamplingParams(QwenImageSamplingParams):
+    # Denoising stage
+    guidance_scale: float = 1.0
+    # "true_cfg_scale": 4.0 TODO(yhyang201): check if this is correct
+    num_inference_steps: int = 40
diff --git a/python/sglang/multimodal_gen/registry.py b/python/sglang/multimodal_gen/registry.py
index 97e6c7c61a0..c809ff95579 100644
--- a/python/sglang/multimodal_gen/registry.py
+++ b/python/sglang/multimodal_gen/registry.py
@@ -30,6 +30,7 @@
 from sglang.multimodal_gen.configs.pipeline_configs.flux import Flux2PipelineConfig
 from sglang.multimodal_gen.configs.pipeline_configs.qwen_image import (
     QwenImageEditPipelineConfig,
+    QwenImageEditPlusPipelineConfig,
     QwenImagePipelineConfig,
 )
 from sglang.multimodal_gen.configs.pipeline_configs.wan import (
@@ -428,5 +429,11 @@ def _register_configs():
         pipeline_config_cls=QwenImageEditPipelineConfig,
     )
 
+    register_configs(
+        model_name="qwen-image-edit-2509",
+        sampling_param_cls=QwenImageSamplingParams,
+        pipeline_config_cls=QwenImageEditPlusPipelineConfig,
+    )
+
 
 _register_configs()
diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py
index 3a0411cd6c4..edc1a8ba849 100644
--- a/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py
+++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py
@@ -79,6 +79,7 @@ def _build_req_from_sampling(s: SamplingParams) -> Req:
         request_id=s.request_id,
         data_type=s.data_type,
         prompt=s.prompt,
+        negative_prompt=s.negative_prompt,
         image_path=s.image_path,
         height=s.height,
         width=s.width,
@@ -160,12 +161,18 @@ async def edits(
     if not images or len(images) == 0:
         raise HTTPException(status_code=422, detail="Field 'image' is required")
 
-    
# Save first input image; additional images or mask are not yet used by the pipeline + # Save all input images; additional images beyond the first are saved for potential future use uploads_dir = os.path.join("outputs", "uploads") os.makedirs(uploads_dir, exist_ok=True) - first_image = images[0] - input_path = os.path.join(uploads_dir, f"{request_id}_{first_image.filename}") - await _save_upload_to_path(first_image, input_path) + if images is not None and not isinstance(images, list): + images = [images] + + input_paths = [] + for idx, img in enumerate(images): + filename = img.filename or f"image_{idx}" + input_path = os.path.join(uploads_dir, f"{request_id}_{idx}_{filename}") + await _save_upload_to_path(img, input_path) + input_paths.append(input_path) sampling = _build_sampling_params_from_request( request_id=request_id, @@ -174,7 +181,7 @@ async def edits( size=size, output_format=output_format, background=background, - image_path=input_path, + image_path=input_paths, ) batch = _build_req_from_sampling(sampling) @@ -186,6 +193,8 @@ async def edits( "id": request_id, "created_at": int(time.time()), "file_path": save_file_path, + "input_image_paths": input_paths, # Store all input image paths + "num_input_images": len(input_paths), }, ) diff --git a/python/sglang/multimodal_gen/runtime/pipelines/qwen_image.py b/python/sglang/multimodal_gen/runtime/pipelines/qwen_image.py index c87b7274013..688ec25118b 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines/qwen_image.py +++ b/python/sglang/multimodal_gen/runtime/pipelines/qwen_image.py @@ -189,4 +189,8 @@ def create_pipeline_stages(self, server_args: ServerArgs): ) -EntryClass = [QwenImagePipeline, QwenImageEditPipeline] +class QwenImageEditPlusPipeline(QwenImageEditPipeline): + pipeline_name = "QwenImageEditPlusPipeline" + + +EntryClass = [QwenImagePipeline, QwenImageEditPipeline, QwenImageEditPlusPipeline] diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py b/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py index 8e2ca11f946..7ff375c1838 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py @@ -57,7 +57,9 @@ class Req: image_embeds: list[torch.Tensor] = field(default_factory=list) original_condition_image_size: tuple[int, int] = None - condition_image: torch.Tensor | PIL.Image.Image | None = None + condition_image: torch.Tensor | PIL.Image.Image | list[PIL.Image.Image] | None = ( + None + ) pixel_values: torch.Tensor | PIL.Image.Image | None = None preprocessed_image: torch.Tensor | None = None diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py index 1ecbd11b562..0484d8795c2 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/image_encoding.py @@ -108,6 +108,10 @@ def forward( server_args.pipeline_config.prepare_image_processor_kwargs(batch) ) + image = server_args.pipeline_config.preprocess_image( + image, self.vae_image_processor + ) + image_inputs = self.image_processor( images=image, return_tensors="pt", **image_processor_kwargs ).to(cuda_device) @@ -124,12 +128,9 @@ def forward( elif self.text_encoder: # if a text encoder is provided, e.g. Qwen-Image-Edit # 1. 
neg prompt embeds - if batch.prompt: - prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" - txt = prompt_template_encode.format(batch.negative_prompt) - neg_image_processor_kwargs = dict(text=[txt], padding=True) - else: - neg_image_processor_kwargs = {} + neg_image_processor_kwargs = ( + server_args.pipeline_config.prepare_image_processor_kwargs(batch) + ) neg_image_inputs = self.image_processor( images=image, return_tensors="pt", **neg_image_processor_kwargs @@ -150,6 +151,8 @@ def forward( image_grid_thw=neg_image_inputs.image_grid_thw, output_hidden_states=True, ) + + # abstract this batch.prompt_embeds.append( self.encoding_qwen_image_edit(outputs, image_inputs) ) @@ -191,37 +194,14 @@ def __init__(self, vae: ParallelTiledVAE, **kwargs) -> None: super().__init__() self.vae: ParallelTiledVAE = vae - def forward( + def _process_single_image( self, + image: PIL.Image.Image, + num_frames: int, batch: Req, server_args: ServerArgs, - ) -> Req: - """ - Encode pixel representations into latent space. - - Args: - batch: The current batch information. - server_args: The inference arguments. - - Returns: - The batch with encoded outputs. - """ - - if batch.condition_image is None: - return batch - - assert batch.condition_image is not None and isinstance( - batch.condition_image, PIL.Image.Image - ) - assert batch.height is not None and isinstance(batch.height, int) - assert batch.width is not None and isinstance(batch.width, int) - assert batch.num_frames is not None and isinstance(batch.num_frames, int) - - num_frames = batch.num_frames - - self.vae = self.vae.to(get_local_torch_device()) + ) -> torch.Tensor: - image = batch.condition_image image = self.preprocess( image, ).to(get_local_torch_device(), dtype=torch.float32) @@ -304,14 +284,56 @@ def forward( latent_condition -= shift_factor latent_condition = latent_condition * scaling_factor - batch.image_latent = server_args.pipeline_config.postprocess_image_latent( + image_latent = server_args.pipeline_config.postprocess_image_latent( latent_condition, batch ) - self.maybe_free_model_hooks() + return image_latent - self.vae.to("cpu") + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Encode pixel representations into latent space. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + The batch with encoded outputs. 
+ """ + + if batch.condition_image is None: + return batch + + assert isinstance(batch.condition_image, list) + assert len(batch.condition_image) > 0 + assert all(isinstance(img, PIL.Image.Image) for img in batch.condition_image) + + assert batch.height is not None and isinstance(batch.height, int) + assert batch.width is not None and isinstance(batch.width, int) + assert batch.num_frames is not None and isinstance(batch.num_frames, int) + + num_frames = batch.num_frames + + self.vae = self.vae.to(get_local_torch_device()) + + image_latents = [] + for image in batch.condition_image: + latent = self._process_single_image( + image=image, num_frames=num_frames, batch=batch, server_args=server_args + ) + image_latents.append(latent) + if server_args.pipeline_config.support_multi_image_input(): + batch.image_latent = image_latents + else: + batch.image_latent = image_latents[0] + self.maybe_free_model_hooks() + self.vae.to("cpu") return batch def retrieve_latents(