28 changes: 23 additions & 5 deletions python/sglang/multimodal_gen/configs/pipeline_configs/base.py
@@ -173,25 +173,40 @@ def postprocess_image(self, image):
# Wan2.2 TI2V parameters
boundary_ratio: float | None = None

# i2i, i2v
multi_image_input: bool = False

# Compilation
# enable_torch_compile: bool = False

# calculate the adjusted size for each condition image
# returns parallel lists of widths and heights, one entry per image
def calculate_condition_image_size(self, image, width, height) -> tuple[int, int]:
vae_scale_factor = self.vae_config.arch_config.spatial_compression_ratio
height, width = get_default_height_width(image, vae_scale_factor, height, width)
def calculate_condition_image_size(self, images) -> tuple[int, int]:
Contributor (medium):

The return type hint for calculate_condition_image_size is incorrect. The function returns two lists of integers (width and height), but the type hint is tuple[int, int]. It should be tuple[list[int], list[int]] to match the actual return value.

Suggested change
def calculate_condition_image_size(self, images) -> tuple[int, int]:
def calculate_condition_image_size(self, images) -> tuple[list[int], list[int]]:

width, height = [], []
for image in images:
vae_scale_factor = self.vae_config.arch_config.spatial_compression_ratio
h, w = get_default_height_width(image, vae_scale_factor)
height.append(h)
width.append(w)
return width, height

def support_multi_image_input(self) -> bool:
return self.multi_image_input

## For timestep preparation stage

def prepare_sigmas(self, sigmas, num_inference_steps):
return sigmas

## For ImageVAEEncodingStage
def resize_condition_image(self, image, target_width, target_height):
return image.resize((target_width, target_height), PIL.Image.Resampling.LANCZOS)
def resize_condition_image(self, images, target_width, target_height):
new_images = []
for image, width, height in zip(images, target_width, target_height):
new_images.append(
image.resize((width, height), PIL.Image.Resampling.LANCZOS)
)
return new_images

def prepare_image_processor_kwargs(self, batch):
return {}
@@ -262,6 +277,9 @@ def get_decode_scale_and_shift(self, device, dtype, vae):
shift_factor = getattr(vae, "shift_factor", None)
return scaling_factor, shift_factor

def preprocess_image(self, image, image_processor):
return image

# called after latents are prepared
def maybe_pack_latents(self, latents, batch_size, batch):
return latents
33 changes: 19 additions & 14 deletions python/sglang/multimodal_gen/configs/pipeline_configs/flux.py
@@ -452,21 +452,26 @@ def get_pos_prompt_embeds(self, batch):
def get_neg_prompt_embeds(self, batch):
return batch.negative_prompt_embeds[0]

def calculate_condition_image_size(
self, image, width, height
) -> Optional[tuple[int, int]]:
def calculate_condition_image_size(self, images) -> Optional[tuple[int, int]]:
Contributor (medium):

The return type hint for calculate_condition_image_size is incorrect. The function now returns two lists of integers, but the hint is Optional[tuple[int, int]]. It should be updated to tuple[list[int], list[int]]. Since the function no longer returns None, the Optional is also unnecessary.

Suggested change
def calculate_condition_image_size(self, images) -> Optional[tuple[int, int]]:
def calculate_condition_image_size(self, images) -> tuple[list[int], list[int]]:

width, height = [], []
target_area: int = 1024 * 1024
if width is not None and height is not None:
if width * height > target_area:
scale = math.sqrt(target_area / (width * height))
width = int(width * scale)
height = int(height * scale)
return width, height

return None

def resize_condition_image(self, image, target_width, target_height):
return image.resize((target_width, target_height), PIL.Image.Resampling.LANCZOS)
for image in images:
w, h = image.size
if w * h > target_area:
scale = math.sqrt(target_area / (w * h))
w = int(w * scale)
h = int(h * scale)
width.append(w)
height.append(h)
return width, height

def resize_condition_image(self, images, target_width, target_height):
new_images = []
for image, width, height in zip(images, target_width, target_height):
new_images.append(
image.resize((width, height), PIL.Image.Resampling.LANCZOS)
)
return new_images
Comment on lines +468 to +474
Contributor (medium):

This resize_condition_image method is an exact copy of the implementation in the base class PipelineConfig. This override is redundant and can be removed to avoid code duplication and improve maintainability.

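For reference, this is the base-class body that FluxPipelineConfig would inherit once the duplicate override is deleted (copied from PipelineConfig in base.py above):

    def resize_condition_image(self, images, target_width, target_height):
        new_images = []
        for image, width, height in zip(images, target_width, target_height):
            new_images.append(
                image.resize((width, height), PIL.Image.Resampling.LANCZOS)
            )
        return new_images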

def postprocess_image_latent(self, latent_condition, batch):
batch_size = batch.batch_size
132 changes: 110 additions & 22 deletions python/sglang/multimodal_gen/configs/pipeline_configs/qwen_image.py
@@ -17,6 +17,9 @@
from sglang.multimodal_gen.runtime.models.vision_utils import resize
from sglang.multimodal_gen.utils import calculate_dimensions

CONDITION_IMAGE_SIZE = 384 * 384
VAE_IMAGE_SIZE = 1024 * 1024


def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
@@ -220,32 +223,31 @@ class QwenImageEditPipelineConfig(QwenImagePipelineConfig):

task_type: ModelTaskType = ModelTaskType.I2I

def _get_condition_image_sizes(self, batch) -> list[tuple[int, int]]:
image_size = batch.original_condition_image_size[0]
edit_width, edit_height, _ = calculate_dimensions(
VAE_IMAGE_SIZE, image_size[0] / image_size[1]
)
return [(edit_width, edit_height)]

def _prepare_edit_cond_kwargs(
self, batch, prompt_embeds, rotary_emb, device, dtype
):
batch_size = batch.latents.shape[0]
assert batch_size == 1
height = batch.height
width = batch.width
image_size = batch.original_condition_image_size
edit_width, edit_height, _ = calculate_dimensions(
1024 * 1024, image_size[0] / image_size[1]
)
condition_image_sizes = self._get_condition_image_sizes(batch)
vae_scale_factor = self.get_vae_scale_factor()

img_shapes = [
[
(
1,
height // vae_scale_factor // 2,
width // vae_scale_factor // 2,
),
(
1,
edit_height // vae_scale_factor // 2,
edit_width // vae_scale_factor // 2,
),
],
(1, height // vae_scale_factor // 2, width // vae_scale_factor // 2),
*[
(1, h // vae_scale_factor // 2, w // vae_scale_factor // 2)
for w, h in condition_image_sizes
],
]
] * batch_size
txt_seq_lens = [prompt_embeds[0].shape[1]]
(img_cos, img_sin), (txt_cos, txt_sin) = QwenImagePipelineConfig.get_freqs_cis(
@@ -273,8 +275,11 @@ def _prepare_edit_cond_kwargs(
"freqs_cis": ((img_cos, img_sin), (txt_cos, txt_sin)),
}

def resize_condition_image(self, image, target_width, target_height):
return resize(image, target_height, target_width, resize_mode="default")
def resize_condition_image(self, images, target_width, target_height):
new_images = []
for img, width, height in zip(images, target_width, target_height):
new_images.append(resize(img, height, width, resize_mode="default"))
return new_images

def postprocess_image_latent(self, latent_condition, batch):
batch_size = batch.batch_size
@@ -313,13 +318,96 @@ def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):
batch, batch.negative_prompt_embeds, rotary_emb, device, dtype
)

def calculate_condition_image_size(self, image, width, height) -> tuple[int, int]:
calculated_width, calculated_height, _ = calculate_dimensions(
1024 * 1024, width / height
)
return calculated_width, calculated_height
def calculate_condition_image_size(self, images) -> tuple[int, int]:
width, height = [], []
for image in images:
image_width, image_height = image.size
calculated_width, calculated_height, _ = calculate_dimensions(
VAE_IMAGE_SIZE, image_width / image_height
)
width.append(calculated_width)
height.append(calculated_height)
return width, height

def slice_noise_pred(self, noise, latents):
# remove noise over input image
noise = noise[:, : latents.size(1)]
return noise

def preprocess_image(self, image, image_processor):
if not isinstance(image, list):
image = [image]
condition_images = []
for img in image:
image_width, image_height = img.size
condition_width, condition_height, _ = calculate_dimensions(
CONDITION_IMAGE_SIZE, image_width / image_height
)
condition_images.append(
image_processor.resize(img, condition_height, condition_width)
)

return condition_images


class QwenImageEditPlusPipelineConfig(QwenImageEditPipelineConfig):
task_type: ModelTaskType = ModelTaskType.I2I

def _get_condition_image_sizes(self, batch) -> list[tuple[int, int]]:
image = batch.condition_image
if not isinstance(image, list):
image = [image]

condition_image_sizes = []
for img in image:
image_width, image_height = img.size
edit_width, edit_height, _ = calculate_dimensions(
VAE_IMAGE_SIZE, image_width / image_height
)
condition_image_sizes.append((edit_width, edit_height))

return condition_image_sizes

def prepare_image_processor_kwargs(self, batch) -> dict:
prompt = batch.prompt
prompt_list = [prompt] if isinstance(prompt, str) else prompt
image_list = batch.condition_image

prompt_template_encode = (
"<|im_start|>system\nDescribe the key features of the input image "
"(color, shape, size, texture, objects, background), then explain how "
"the user's text instruction should alter or modify the image. Generate "
"a new image that meets the user's requirements while maintaining "
"consistency with the original input where appropriate.<|im_end|>\n"
"<|im_start|>user\n{}<|im_end|>\n"
"<|im_start|>assistant\n"
)
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
if isinstance(image_list, list):
base_img_prompt = ""
for i, img in enumerate(image_list):
base_img_prompt += img_prompt_template.format(i + 1)
txt = [prompt_template_encode.format(base_img_prompt + p) for p in prompt_list]
return dict(text=txt, padding=True)
Comment on lines +371 to +391
Contributor (critical):

There is a potential NameError in this method. If batch.condition_image (assigned to image_list) is not a list (e.g., a single image object), the if isinstance(image_list, list): block will be skipped. This means base_img_prompt will not be initialized, causing a NameError on the line txt = [prompt_template_encode.format(base_img_prompt + p) for p in prompt_list]. To fix this, you should initialize base_img_prompt before the if block and also handle the case where image_list is a single image.

    def prepare_image_processor_kwargs(self, batch) -> dict:
        prompt = batch.prompt
        prompt_list = [prompt] if isinstance(prompt, str) else prompt
        image_list = batch.condition_image

        prompt_template_encode = (
            "<|im_start|>system\nDescribe the key features of the input image "
            "(color, shape, size, texture, objects, background), then explain how "
            "the user's text instruction should alter or modify the image. Generate "
            "a new image that meets the user's requirements while maintaining "
            "consistency with the original input where appropriate.<|im_end|>\n"
            "<|im_start|>user\n{}\n<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
        base_img_prompt = ""
        if image_list:
            if not isinstance(image_list, list):
                image_list = [image_list]
            for i, img in enumerate(image_list):
                base_img_prompt += img_prompt_template.format(i + 1)
        txt = [prompt_template_encode.format(base_img_prompt + p) for p in prompt_list]
        return dict(text=txt, padding=True)


def resize_condition_image(self, images, target_width, target_height):
if not isinstance(images, list):
images = [images]
new_images = []
for img, width, height in zip(images, target_width, target_height):
new_images.append(resize(img, height, width, resize_mode="default"))
return new_images

def calculate_condition_image_size(self, images) -> tuple[int, int]:
calculated_widths = []
calculated_heights = []

for img in images:
image_width, image_height = img.size
calculated_width, calculated_height, _ = calculate_dimensions(
VAE_IMAGE_SIZE, image_width / image_height
)
calculated_widths.append(calculated_width)
calculated_heights.append(calculated_height)

return calculated_widths, calculated_heights
Comment on lines +401 to +413
Contributor (medium):

This method has two issues:

  1. The return type hint is tuple[int, int], but it returns two lists of integers. It should be tuple[list[int], list[int]].
  2. This method is a near-identical copy of the implementation in its parent class QwenImageEditPipelineConfig. This creates unnecessary code duplication.

Since the logic is the same, this method can be removed from QwenImageEditPlusPipelineConfig to inherit the implementation from the parent class, which would fix both issues and improve maintainability.
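Taken together, the two points amount to keeping a single corrected implementation on the parent class. A sketch, assuming no caller depends on the old tuple[int, int] annotation:

    # In QwenImageEditPipelineConfig; QwenImageEditPlusPipelineConfig inherits it.
    def calculate_condition_image_size(self, images) -> tuple[list[int], list[int]]:
        width, height = [], []
        for image in images:
            image_width, image_height = image.size
            calculated_width, calculated_height, _ = calculate_dimensions(
                VAE_IMAGE_SIZE, image_width / image_height
            )
            width.append(calculated_width)
            height.append(calculated_height)
        return width, height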

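Aside: the changes above lean on calculate_dimensions from sglang.multimodal_gen.utils, always passing a target pixel area and an aspect ratio and discarding the third return value. Its implementation is not part of this diff; the following is only an assumed sketch of its contract, and the real helper may additionally round to VAE-friendly multiples:

    import math

    # Hypothetical stand-in for sglang.multimodal_gen.utils.calculate_dimensions;
    # NOT the actual implementation. Assumed contract: solve w * h ~= target_area
    # with w / h == aspect_ratio. The third return value is ignored at every call
    # site in this diff, so it is modeled here as the resulting aspect ratio.
    def calculate_dimensions(
        target_area: int, aspect_ratio: float
    ) -> tuple[int, int, float]:
        width = int(math.sqrt(target_area * aspect_ratio))
        height = int(math.sqrt(target_area / aspect_ratio))
        return width, height, width / height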
8 changes: 8 additions & 0 deletions python/sglang/multimodal_gen/configs/sample/qwenimage.py
@@ -16,3 +16,11 @@ class QwenImageSamplingParams(SamplingParams):
# Denoising stage
guidance_scale: float = 4.0
num_inference_steps: int = 50


@dataclass
class QwenImageEditPlusSamplingParams(QwenImageSamplingParams):
# Denoising stage
guidance_scale: float = 1.0
# "true_cfg_scale": 4.0 TODO(yhyang201): check if this is correct
num_inference_steps: int = 40
7 changes: 7 additions & 0 deletions python/sglang/multimodal_gen/registry.py
@@ -30,6 +30,7 @@
from sglang.multimodal_gen.configs.pipeline_configs.flux import Flux2PipelineConfig
from sglang.multimodal_gen.configs.pipeline_configs.qwen_image import (
QwenImageEditPipelineConfig,
QwenImageEditPlusPipelineConfig,
QwenImagePipelineConfig,
)
from sglang.multimodal_gen.configs.pipeline_configs.wan import (
@@ -428,5 +429,11 @@ def _register_configs():
pipeline_config_cls=QwenImageEditPipelineConfig,
)

register_configs(
model_name="qwen-image-edit-2509",
sampling_param_cls=QwenImageSamplingParams,
pipeline_config_cls=QwenImageEditPlusPipelineConfig,
)


_register_configs()
(diff continues in another file; path not shown in this view)
@@ -79,6 +79,7 @@ def _build_req_from_sampling(s: SamplingParams) -> Req:
request_id=s.request_id,
data_type=s.data_type,
prompt=s.prompt,
negative_prompt=s.negative_prompt,
image_path=s.image_path,
height=s.height,
width=s.width,
@@ -160,12 +161,18 @@ async def edits(
if not images or len(images) == 0:
raise HTTPException(status_code=422, detail="Field 'image' is required")

# Save first input image; additional images or mask are not yet used by the pipeline
# Save all input images and pass their saved paths to the pipeline via image_path
uploads_dir = os.path.join("outputs", "uploads")
os.makedirs(uploads_dir, exist_ok=True)
first_image = images[0]
input_path = os.path.join(uploads_dir, f"{request_id}_{first_image.filename}")
await _save_upload_to_path(first_image, input_path)
if images is not None and not isinstance(images, list):
images = [images]

input_paths = []
for idx, img in enumerate(images):
filename = img.filename or f"image_{idx}"
input_path = os.path.join(uploads_dir, f"{request_id}_{idx}_{filename}")
await _save_upload_to_path(img, input_path)
input_paths.append(input_path)

sampling = _build_sampling_params_from_request(
request_id=request_id,
@@ -174,7 +181,7 @@
size=size,
output_format=output_format,
background=background,
image_path=input_path,
image_path=input_paths,
)
batch = _build_req_from_sampling(sampling)

@@ -186,6 +193,8 @@
"id": request_id,
"created_at": int(time.time()),
"file_path": save_file_path,
"input_image_paths": input_paths, # Store all input image paths
"num_input_images": len(input_paths),
},
)

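A hypothetical client call exercising the new multi-image path. The host, port, and route below are assumptions (this diff does not show how the edits handler is mounted); the repeated form field follows the "Field 'image' is required" validation above:

    import requests

    # Two condition images uploaded under the same form field; the handler saves
    # each to outputs/uploads and forwards all saved paths as image_path.
    files = [
        ("image", ("first.png", open("first.png", "rb"), "image/png")),
        ("image", ("second.png", open("second.png", "rb"), "image/png")),
    ]
    data = {"prompt": "Blend the subject of Picture 1 into the scene of Picture 2"}
    resp = requests.post(
        "http://localhost:30000/v1/images/edits", files=files, data=data
    )
    print(resp.json())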
(diff continues in another file; path not shown in this view)
@@ -189,4 +189,8 @@ def create_pipeline_stages(self, server_args: ServerArgs):
)


EntryClass = [QwenImagePipeline, QwenImageEditPipeline]
class QwenImageEditPlusPipeline(QwenImageEditPipeline):
pipeline_name = "QwenImageEditPlusPipeline"


EntryClass = [QwenImagePipeline, QwenImageEditPipeline, QwenImageEditPlusPipeline]
(diff continues in another file; path not shown in this view)
@@ -57,7 +57,9 @@ class Req:
image_embeds: list[torch.Tensor] = field(default_factory=list)

original_condition_image_size: tuple[int, int] = None
condition_image: torch.Tensor | PIL.Image.Image | None = None
condition_image: torch.Tensor | PIL.Image.Image | list[PIL.Image.Image] | None = (
None
)
pixel_values: torch.Tensor | PIL.Image.Image | None = None
preprocessed_image: torch.Tensor | None = None
