From ec9a82fc3f307a8d1f2e1890b86df002b4923ec4 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 7 Jan 2026 15:31:11 +0800 Subject: [PATCH 01/56] init --- .../en/api/models/glm_image_transformer2d.md | 18 + docs/source/en/api/pipelines/glm_image.md | 31 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/embeddings.py | 31 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_glm_image.py | 753 +++++++++++++++++ src/diffusers/pipelines/glm_image/__init__.py | 47 ++ .../pipelines/glm_image/pipeline_glm_image.py | 786 ++++++++++++++++++ .../pipelines/glm_image/pipeline_output.py | 21 + 9 files changed, 1690 insertions(+) create mode 100644 docs/source/en/api/models/glm_image_transformer2d.md create mode 100644 docs/source/en/api/pipelines/glm_image.md create mode 100644 src/diffusers/models/transformers/transformer_glm_image.py create mode 100644 src/diffusers/pipelines/glm_image/__init__.py create mode 100644 src/diffusers/pipelines/glm_image/pipeline_glm_image.py create mode 100644 src/diffusers/pipelines/glm_image/pipeline_output.py diff --git a/docs/source/en/api/models/glm_image_transformer2d.md b/docs/source/en/api/models/glm_image_transformer2d.md new file mode 100644 index 000000000000..d31557c9f7ff --- /dev/null +++ b/docs/source/en/api/models/glm_image_transformer2d.md @@ -0,0 +1,18 @@ + + +# GlmImageDecoderTransformer2DModel + +A Diffusion Transformer model for 2D data from [GlmImageDecoderTransformer2DModel]() + +## GlmImageDecoderTransformer2DModel + +[[autodoc]] GlmImageDecoderTransformer2DModel diff --git a/docs/source/en/api/pipelines/glm_image.md b/docs/source/en/api/pipelines/glm_image.md new file mode 100644 index 000000000000..24b5e14a1ae7 --- /dev/null +++ b/docs/source/en/api/pipelines/glm_image.md @@ -0,0 +1,31 @@ + + +# GLM-Image + +> [!TIP] +> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + +This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/zai-org). The original weights can be found under [hf.co/zai-org](https://huggingface.co/zai-org). 
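+
+The snippet below is a minimal usage sketch adapted from the example docstring in `pipeline_glm_image.py`; the checkpoint id is a placeholder until the GLM-Image decoder weights are published under [zai-org](https://huggingface.co/zai-org), and the pipeline's default `num_inference_steps` and `guidance_scale` values are passed explicitly for clarity.
+
+```python
+import torch
+from diffusers import GlmImageDecoderPipeline
+
+# Placeholder repo id -- replace with the released GLM-Image decoder checkpoint.
+pipe = GlmImageDecoderPipeline.from_pretrained("zai-org/GLM-Image-decoder", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "A photo of an astronaut riding a horse on mars"
+image = pipe(prompt, num_inference_steps=50, guidance_scale=1.5).images[0]
+image.save("output.png")
+```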
+ +## GlmImageDecoderPipeline + +[[autodoc]] GlmImageDecoderPipeline + - all + - __call__ + +## GlmImageDecoderPipelineOutput + +[[autodoc]] pipelines.cogview4.pipeline_output.GlmImageDecoderPipelineOutput diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index c4664f00cad2..29bf6016deff 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -96,6 +96,7 @@ _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"] + _import_structure["transformers.transformer_glm_image"] = ["GlmImageDecoderTransformer2DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] @@ -203,6 +204,7 @@ EasyAnimateTransformer3DModel, Flux2Transformer2DModel, FluxTransformer2DModel, + GlmImageDecoderTransformer2DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 37fc412adcc3..14fd9e2de974 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1658,6 +1658,37 @@ def forward( return conditioning +class GlmImageDecoderCombinedTimestepSizeEmbeddings(nn.Module): + def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256): + super().__init__() + + self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim) + self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward( + self, + timestep: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + + crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1) + target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1) + + # (B, 2 * condition_dim) + condition_proj = torch.cat([crop_coords_proj, target_size_proj], dim=1) + + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + + conditioning = timesteps_emb + condition_emb + return conditioning + + class HunyuanDiTAttentionPool(nn.Module): # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 40b5d4a0dfc9..ea051624f2dd 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -27,6 +27,7 @@ from .transformer_easyanimate import EasyAnimateTransformer3DModel from .transformer_flux import FluxTransformer2DModel from .transformer_flux2 import 
Flux2Transformer2DModel + from .transformer_glm_image import GlmImageDecoderTransformer2DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py new file mode 100644 index 000000000000..54a1c7cbdb49 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -0,0 +1,753 @@ +# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward +from ..attention_processor import Attention +from ..cache_utils import CacheMixin +from ..embeddings import GlmImageDecoderCombinedTimestepSizeEmbeddings +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import LayerNorm, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class GlmImageDecoderImageProjector(nn.Module): + def __init__( + self, + in_channels: int = 16, + hidden_size: int = 2560, + patch_size: int = 2, + ): + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Linear(in_channels * patch_size**2, hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, channel, height, width = hidden_states.shape + post_patch_height = height // self.patch_size + post_patch_width = width // self.patch_size + + hidden_states = hidden_states.reshape( + batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size + ) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + hidden_states = self.proj(hidden_states) + + return hidden_states + + +class GlmImageDecoderAdaLayerNormZero(nn.Module): + def __init__(self, embedding_dim: int, dim: int) -> None: + super().__init__() + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True) + + def forward( + self, hidden_states: torch.Tensor, glyph_hidden_states: torch.Tensor, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = hidden_states.dtype + norm_hidden_states = self.norm(hidden_states).to(dtype=dtype) + norm_glyph_hidden_states = self.norm_context(glyph_hidden_states).to(dtype=dtype) + + emb = self.linear(temb) + ( + 
shift_msa, + c_shift_msa, + scale_msa, + c_scale_msa, + gate_msa, + c_gate_msa, + shift_mlp, + c_shift_mlp, + scale_mlp, + c_scale_mlp, + gate_mlp, + c_gate_mlp, + ) = emb.chunk(12, dim=1) + + hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + glyph_hidden_states = norm_glyph_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1) + + return ( + hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + glyph_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) + + +class GlmImageDecoderAttenProcessorState(Enum): + ImageGen = "ImageGen" + ImageEditWriteKV = "ImageEditWriteKV" + ImageEditReadKV = "ImageEditReadKV" + ImageEditDontReadKV = "ImageEditNoReadKV" + + +class GlmImageDecoderAttnProcessor: + """ + Processor for implementing scaled dot-product attention for the GlmImageDecoder model. It applies a rotary + embedding on query and key vectors, but does not include spatial normalization. + + The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size, + text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "GlmImageDecoderAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + ) + self.processor_state = GlmImageDecoderAttenProcessorState.ImageGen + self.k_cache = None + self.v_cache = None + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = encoder_hidden_states.dtype + + batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape + batch_size, image_seq_length, embed_dim = hidden_states.shape + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query).to(dtype=dtype) + if attn.norm_k is not None: + key = attn.norm_k(key).to(dtype=dtype) + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + query[:, :, text_seq_length:, :] = apply_rotary_emb( + query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 + ) + key[:, :, text_seq_length:, :] = apply_rotary_emb( + key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 + ) + + if self.processor_state == GlmImageDecoderAttenProcessorState.ImageEditWriteKV: + self.k_cache = key if self.k_cache is None else torch.cat([self.k_cache, key], dim=2) + self.v_cache = value if self.v_cache is None else torch.cat([self.v_cache, value], dim=2) + elif self.processor_state == GlmImageDecoderAttenProcessorState.ImageEditReadKV: + key = torch.cat([self.k_cache, key], dim=2) if self.k_cache is not None else key + value = torch.cat([self.v_cache, value], dim=2) if self.v_cache is not None else value + + # 4. 
Attention + if attention_mask is not None: + text_attn_mask = attention_mask + assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)" + text_attn_mask = text_attn_mask.float().to(query.device) + mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device) + mix_attn_mask[:, :text_seq_length] = text_attn_mask + mix_attn_mask = mix_attn_mask.unsqueeze(2) + attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2) + attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + # 5. Output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + + +@maybe_allow_in_graph +class GlmImageDecoderTransformerBlock(nn.Module): + def __init__( + self, + dim: int = 2560, + num_attention_heads: int = 64, + attention_head_dim: int = 40, + time_embed_dim: int = 512, + ) -> None: + super().__init__() + + # 1. Attention + self.norm1 = GlmImageDecoderAdaLayerNormZero(time_embed_dim, dim) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + out_dim=dim, + bias=True, + qk_norm="layer_norm", + elementwise_affine=False, + eps=1e-5, + processor=GlmImageDecoderAttnProcessor(), + ) + + # 2. Feedforward + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + glyph_hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]] + ] = None, + attention_mask: Optional[Dict[str, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Timestep conditioning + ( + norm_hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + norm_glyph_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = self.norm1(hidden_states, glyph_hidden_states, temb) + + # 2. Attention + if attention_kwargs is None: + attention_kwargs = {} + + attn_hidden_states, attn_glyph_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_glyph_hidden_states, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + **attention_kwargs, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) + glyph_hidden_states = glyph_hidden_states + attn_glyph_hidden_states * c_gate_msa.unsqueeze(1) + + # 3. 
Feedforward + norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + norm_glyph_hidden_states = self.norm2_context(glyph_hidden_states) * ( + 1 + c_scale_mlp.unsqueeze(1) + ) + c_shift_mlp.unsqueeze(1) + + ff_output = self.ff(norm_hidden_states) + ff_output_context = self.ff(norm_glyph_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1) + glyph_hidden_states = glyph_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1) + + return hidden_states, glyph_hidden_states + + +class GlmImageDecoderRotaryPosEmbed(nn.Module): + def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None: + super().__init__() + + self.dim = dim + self.patch_size = patch_size + self.theta = theta + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, num_channels, height, width = hidden_states.shape + height, width = height // self.patch_size, width // self.patch_size + + dim_h, dim_w = self.dim // 2, self.dim // 2 + h_inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h) + ) + w_inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w) + ) + h_seq = torch.arange(height) + w_seq = torch.arange(width) + freqs_h = torch.outer(h_seq, h_inv_freq) + freqs_w = torch.outer(w_seq, w_inv_freq) + + # Create position matrices for height and width + # [height, 1, dim//4] and [1, width, dim//4] + freqs_h = freqs_h.unsqueeze(1) + freqs_w = freqs_w.unsqueeze(0) + # Broadcast freqs_h and freqs_w to [height, width, dim//4] + freqs_h = freqs_h.expand(height, width, -1) + freqs_w = freqs_w.expand(height, width, -1) + + # Concatenate along last dimension to get [height, width, dim//2] + freqs = torch.cat([freqs_h, freqs_w], dim=-1) + freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim] + freqs = freqs.reshape(height * width, -1) + return (freqs.cos(), freqs.sin()) + + +class GlmImageDecoderAdaLayerNormContinuous(nn.Module): + """ + GlmImageDecoder-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** + before the Linear on conditioning embedding. + """ + + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + bias: bool = True, + norm_type: str = "layer_norm", + ): + super().__init__() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # *** NO SiLU here *** + emb = self.linear(conditioning_embedding.to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): + r""" + Args: + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. 
+ attention_head_dim (`int`, defaults to `40`): + The number of channels in each head. + num_attention_heads (`int`, defaults to `64`): + The number of heads to use for multi-head attention. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + condition_dim (`int`, defaults to `256`): + The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size, + crop_coords). + pos_embed_max_size (`int`, defaults to `128`): + The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added + to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128 + means that the maximum supported height and width for image generation is `128 * vae_scale_factor * + patch_size => 128 * 8 * 2 => 2048`. + sample_size (`int`, defaults to `128`): + The base resolution of input latents. If height/width is not provided during generation, this value is used + to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024` + """ + + _supports_gradient_checkpointing = True + _no_split_modules = [ + "GlmImageDecoderTransformerBlock", + "GlmImageDecoderImageProjector", + "GlmImageDecoderImageProjector", + ] + _skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"] + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + out_channels: int = 16, + num_layers: int = 30, + attention_head_dim: int = 40, + num_attention_heads: int = 64, + text_embed_dim: int = 4096, + glyph_embed_dim: int = 1472, + time_embed_dim: int = 512, + condition_dim: int = 256, + pos_embed_max_size: int = 128, + sample_size: int = 128, + prior_vq_quantizer_codebook_size: int = 16384, + ): + super().__init__() + + # GlmImageDecoder uses 2 additional SDXL-like conditions - target_size, crop_coords + # Each of these are sincos embeddings of shape 2 * condition_dim + pooled_projection_dim = 2 * 2 * condition_dim + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels + + # 1. RoPE + self.rope = GlmImageDecoderRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0) + + # 2. Patch & Text-timestep embedding + self.image_projector = GlmImageDecoderImageProjector(in_channels, inner_dim, patch_size) + # 这次没有,未来可能有text_projector + # self.text_projector = FeedForward(text_embed_dim, inner_dim, activation_fn="gelu") + self.glyph_projector = FeedForward(glyph_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu") + + self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim) + self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu") + + self.time_condition_embed = GlmImageDecoderCombinedTimestepSizeEmbeddings( + embedding_dim=time_embed_dim, + condition_dim=condition_dim, + pooled_projection_dim=pooled_projection_dim, + timesteps_dim=time_embed_dim, + ) + + # 3. Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + GlmImageDecoderTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim) + for _ in range(num_layers) + ] + ) + + # 4. 
Output projection + self.norm_out = GlmImageDecoderAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + glyph_hidden_states: torch.Tensor, + prior_token_id: torch.Tensor, + prior_token_drop: torch.Tensor, + timestep: torch.LongTensor, + original_size: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]] + ] = None, + ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: + batch_size, num_channels, height, width = hidden_states.shape + + # 1. RoPE + if image_rotary_emb is None: + image_rotary_emb = self.rope(hidden_states) + + # 2. Patch & Timestep embeddings + p = self.config.patch_size + post_patch_height = height // p + post_patch_width = width // p + + hidden_states = self.image_projector(hidden_states) + glyph_hidden_states = self.glyph_projector(glyph_hidden_states) + prior_embedding = self.prior_token_embedding(prior_token_id) + prior_embedding[prior_token_drop] *= 0.0 + prior_hidden_states = self.prior_projector(prior_embedding) + + hidden_states = hidden_states + prior_hidden_states + + temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype) + temb = F.silu(temb) + + # 3. Transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, glyph_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + glyph_hidden_states, + temb, + image_rotary_emb, + attention_mask, + attention_kwargs, + ) + else: + hidden_states, glyph_hidden_states = block( + hidden_states, + glyph_hidden_states, + temb, + image_rotary_emb, + attention_mask, + attention_kwargs, + ) + + # 4. Output norm & projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p) + output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + def set_attention_processors_state(self, state: GlmImageDecoderAttenProcessorState): + for block in self.transformer_blocks: + block.attn1.processor.processor_state = state + + def clear_attention_processors_cache(self): + for block in self.transformer_blocks: + block.attn1.processor.k_cache = None + block.attn1.processor.v_cache = None + + def repeat_attention_processors_cache(self, repeats: int): + for block in self.transformer_blocks: + if block.attn1.processor.k_cache is None or block.attn1.processor.v_cache is None: + continue + block.attn1.processor.k_cache = torch.repeat_interleave(block.attn1.processor.k_cache, repeats, dim=2) + block.attn1.processor.v_cache = torch.repeat_interleave(block.attn1.processor.v_cache, repeats, dim=2) + + +if __name__ == "__main__": + + def swap_scale_shift(weight, dim): + """ + Swap the scale and shift components in the weight tensor. + + Args: + weight (torch.Tensor): The original weight tensor. + dim (int): The dimension along which to split. 
+ + Returns: + torch.Tensor: The modified weight tensor with scale and shift swapped. + """ + shift, scale = weight.chunk(2, dim=dim) + new_weight = torch.cat([scale, shift], dim=dim) + return new_weight + + def convert_megatron_transformer_checkpoint_to_diffusers( + ckpt_path: str, + num_layers: int, + num_heads: int, + hidden_size: int, + ): + """ + Convert a Megatron Transformer checkpoint to Diffusers format. + + Args: + ckpt_path (str): Path to the Megatron Transformer checkpoint. + num_layers (int): Number of Transformer layers. + num_heads (int): Number of attention heads. + hidden_size (int): Hidden size of the Transformer. + + Returns: + dict: The converted state dictionary compatible with Diffusers. + """ + ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) + mega = ckpt["model"] + used_keys = set() + + def get_mega(key): + used_keys.add(key) + return mega[key] + + new_state_dict = {} + + # Patch Embedding + new_state_dict["image_projector.proj.weight"] = get_mega("encoder_expand_linear.weight").reshape( + hidden_size, 64 + ) + new_state_dict["image_projector.proj.bias"] = get_mega("encoder_expand_linear.bias") + + new_state_dict["glyph_projector.net.0.proj.weight"] = get_mega("glyph_projector.linear_fc1.weight") + new_state_dict["glyph_projector.net.0.proj.bias"] = get_mega("glyph_projector.linear_fc1.bias") + new_state_dict["glyph_projector.net.2.weight"] = get_mega("glyph_projector.linear_fc2.weight") + new_state_dict["glyph_projector.net.2.bias"] = get_mega("glyph_projector.linear_fc2.bias") + + new_state_dict["prior_token_embedding.weight"] = get_mega("xomni_token_id_embedding.weight") + new_state_dict["prior_projector.net.0.proj.weight"] = get_mega("prior_condition_embedding.0.weight") + new_state_dict["prior_projector.net.0.proj.bias"] = get_mega("prior_condition_embedding.0.bias") + new_state_dict["prior_projector.net.2.weight"] = get_mega("prior_condition_embedding.2.weight") + new_state_dict["prior_projector.net.2.bias"] = get_mega("prior_condition_embedding.2.bias") + + # Time Condition Embedding + new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = get_mega( + "time_embedding.time_embed.0.weight" + ) + new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = get_mega( + "time_embedding.time_embed.0.bias" + ) + new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = get_mega( + "time_embedding.time_embed.2.weight" + ) + new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = get_mega( + "time_embedding.time_embed.2.bias" + ) + + new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = get_mega( + "label_embedding.label_embed.0.weight" + ) + new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = get_mega( + "label_embedding.label_embed.0.bias" + ) + new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = get_mega( + "label_embedding.label_embed.2.weight" + ) + new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = get_mega( + "label_embedding.label_embed.2.bias" + ) + + # Convert each Transformer layer + from tqdm import tqdm + + for i in tqdm(range(num_layers), desc="Converting layers (Megatron->Diffusers)"): + block_prefix = f"transformer_blocks.{i}." 
+ + # AdaLayerNorm + new_state_dict[block_prefix + "norm1.linear.weight"] = get_mega(f"decoder.layers.{i}.adaln.weight") + new_state_dict[block_prefix + "norm1.linear.bias"] = get_mega(f"decoder.layers.{i}.adaln.bias") + qkv_weight = get_mega(f"decoder.layers.{i}.self_attention.linear_qkv.weight") + qkv_bias = get_mega(f"decoder.layers.{i}.self_attention.linear_qkv.bias") + + # Reshape to match SAT logic + qkv_weight = qkv_weight.view(num_heads, 3, hidden_size // num_heads, hidden_size) + qkv_weight = qkv_weight.permute(1, 0, 2, 3).reshape(3 * hidden_size, hidden_size) + + qkv_bias = qkv_bias.view(num_heads, 3, hidden_size // num_heads) + qkv_bias = qkv_bias.permute(1, 0, 2).reshape(3 * hidden_size) + + # Assign to Diffusers keys + q, k, v = torch.chunk(qkv_weight, 3, dim=0) + qb, kb, vb = torch.chunk(qkv_bias, 3, dim=0) + + new_state_dict[block_prefix + "attn1.to_q.weight"] = q + new_state_dict[block_prefix + "attn1.to_q.bias"] = qb + new_state_dict[block_prefix + "attn1.to_k.weight"] = k + new_state_dict[block_prefix + "attn1.to_k.bias"] = kb + new_state_dict[block_prefix + "attn1.to_v.weight"] = v + new_state_dict[block_prefix + "attn1.to_v.bias"] = vb + + # Attention Output + new_state_dict[block_prefix + "attn1.to_out.0.weight"] = get_mega( + f"decoder.layers.{i}.self_attention.linear_proj.weight" + ) + new_state_dict[block_prefix + "attn1.to_out.0.bias"] = get_mega( + f"decoder.layers.{i}.self_attention.linear_proj.bias" + ) + + # MLP + new_state_dict[block_prefix + "ff.net.0.proj.weight"] = get_mega( + f"decoder.layers.{i}.mlp.linear_fc1.weight" + ) + new_state_dict[block_prefix + "ff.net.0.proj.bias"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc1.bias") + new_state_dict[block_prefix + "ff.net.2.weight"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc2.weight") + new_state_dict[block_prefix + "ff.net.2.bias"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc2.bias") + + # Final Layers + new_state_dict["norm_out.linear.weight"] = swap_scale_shift(get_mega("adaln_final.weight"), dim=0) + new_state_dict["norm_out.linear.bias"] = swap_scale_shift(get_mega("adaln_final.bias"), dim=0) + new_state_dict["proj_out.weight"] = get_mega("output_projector.weight") + new_state_dict["proj_out.bias"] = get_mega("output_projector.bias") + + # Check for unused keys + all_keys = set(mega.keys()) + unused_keys = all_keys - used_keys + if unused_keys: + print(f"\n[WARNING] The following {len(unused_keys)} keys in mega were NOT used:") + for key in sorted(unused_keys): + print(f" - {key}") + raise ValueError( + f"Found {len(unused_keys)} unused keys in Megatron checkpoint. Please update the conversion script to handle these keys." 
+ ) + else: + print(f"\n[INFO] All {len(all_keys)} keys in mega were successfully used.") + + return new_state_dict + + transformer = GlmImageDecoderTransformer2DModel( + patch_size=2, + in_channels=16, + num_layers=30, + attention_head_dim=128, + num_attention_heads=32, + out_channels=16, + text_embed_dim=4096, + time_embed_dim=512, + glyph_embed_dim=1472, + condition_dim=256, + pos_embed_max_size=128, + ).to(torch.bfloat16) + converted_transformer_state_dict = convert_megatron_transformer_checkpoint_to_diffusers( + ckpt_path="/workspace/ckpt/tjy/Glm-train-dev/examples/cogview/ckpts/merge/1+6_0.5+0.5/iter_0000000/mp_rank_00/model_optim_rng.pt", + num_layers=30, + num_heads=32, + hidden_size=4096, + ) + transformer.load_state_dict(converted_transformer_state_dict) + transformer.cuda() + + latent = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/latent.pt").to(torch.bfloat16) + latent = rearrange(latent, "(b h w) (c p q) -> b c (h p) (w q)", b=8, h=72, w=54, p=2, q=2) + glyph_hidden_states = torch.load( + "/workspace/ckpt/tjy/glm-train-dev/examples/cogview/glyph_condition_embedding.pt" + ).to(torch.bfloat16) + glyph_hidden_states = rearrange(glyph_hidden_states, "(b n) c -> b n c", b=8, n=2) + prior_token_id = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/xomni_token_id.pt") + prior_token_drop = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/xomni_drop.pt") + prior_token_id = rearrange(prior_token_id, "(b n) -> b n", b=8) + prior_token_drop = rearrange(prior_token_drop, "(b n)-> b n", b=8) + + with torch.no_grad(): + output = transformer( + hidden_states=latent, + glyph_hidden_states=glyph_hidden_states, + prior_token_id=prior_token_id, + prior_token_drop=prior_token_drop, + timestep=torch.tensor([999.0] * 8).cuda(), + original_size=torch.tensor([[144, 108]] * 8).cuda(), + target_size=torch.tensor([[144, 108]] * 8).cuda(), + crop_coords=torch.tensor([[0, 0]] * 8).cuda(), + ) diff --git a/src/diffusers/pipelines/glm_image/__init__.py b/src/diffusers/pipelines/glm_image/__init__.py new file mode 100644 index 000000000000..24c20c4ef5ac --- /dev/null +++ b/src/diffusers/pipelines/glm_image/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["GlmImageDecoderPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_glm_image"] = ["GlmImageDecoderPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_glm_image import GlmImageDecoderPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in 
_additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py new file mode 100644 index 000000000000..825952be4b98 --- /dev/null +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -0,0 +1,786 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import re +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...loaders import CogView4LoraLoaderMixin +from ...models import AutoencoderKL, GlmImageDecoderTransformer2DModel +from ...models.transformers.transformer_glm_image import GlmImageDecoderAttenProcessorState +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import GlmImageDecoderPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import GlmImageDecoderPipeline + + >>> pipe = GlmImageDecoderPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + >>> image.save("output.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + base_shift: float = 0.25, + max_shift: float = 0.75, +) -> float: + m = (image_seq_len / base_seq_len) ** 0.5 + mu = m * max_shift + base_shift + return mu + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+ timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + + if timesteps is not None and sigmas is not None: + if not accepts_timesteps and not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep or sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is not None and sigmas is None: + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is None and sigmas is not None: + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using CogView4. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`GlmModel`]): + Frozen text-encoder. CogView4 uses [Glm-4-9b-hf](https://huggingface.co/THUDM/Glm-4-9b-hf). 
+ tokenizer (`PreTrainedTokenizer`): + Tokenizer of class + [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer). + transformer ([`CogView4Transformer2DModel`]): + A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + glyph_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: GlmImageDecoderTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, glyph_encoder=glyph_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") + and self.transformer is not None + and hasattr(self.transformer.config, "sample_size") + else 128 + ) + + def get_glyph_texts( + self, + prompt, + ): + prompt = prompt[0] if isinstance(prompt, list) else prompt + ocr_texts = ( + re.findall(r"'([^']*)'", prompt) + + re.findall(r"“([^“”]*)”", prompt) + + re.findall(r'"([^"]*)"', prompt) + + re.findall(r"「([^「」]*)」", prompt) + ) + return ocr_texts + + def _get_glyph_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 2048, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.glyph_encoder.dtype + + glyph_texts = self.get_glyph_texts(prompt) + input_ids = self.tokenizer( + glyph_texts if len(glyph_texts) > 0 else [""], + max_length=max_sequence_length, + truncation=True, + ).input_ids + input_ids = [ + [self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids + ] + max_length = max(len(input_ids_) for input_ids_ in input_ids) + attention_mask = torch.tensor( + [[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device + ) + input_ids = torch.tensor( + [input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids], + device=device, + ) + outputs = self.glyph_encoder(input_ids, attention_mask=attention_mask) + glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0) + + return glyph_embeds.to(device=device, dtype=dtype) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 2048, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. 
If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + max_sequence_length (`int`, defaults to `2048`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype) + + seq_len = prompt_embeds.size(1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype) + + seq_len = negative_prompt_embeds.size(1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * self.transformer.config.patch_size) != 0 + or width is not None + and width % (self.transformer.config.patch_size) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prior_token_id: Optional[torch.LongTensor] = None, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + condition_images: Optional[ + Union[ + torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray] + ] + ] = None, + condition_images_prior_token_id: Optional[Union[torch.LongTensor, List[torch.LongTensor]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 1.5, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 2048, + ) -> Union[GlmImageDecoderPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. If not provided, it is set to 2048. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. If not provided it is set to 2048. + num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. 
+ timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to `1`): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + original_size (`Tuple[int]`, *optional*, defaults to (2048, 2048)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. 
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to `2048`):
+                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.glm_image.pipeline_output.GlmImageDecoderPipelineOutput`] or `tuple`:
+                [`~pipelines.glm_image.pipeline_output.GlmImageDecoderPipelineOutput`] if `return_dict` is True,
+                otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            negative_prompt,
+            callback_on_step_end_tensor_inputs,
+            prompt_embeds,
+            negative_prompt_embeds,
+        )
+        self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
+        self._interrupt = False
+
+        # 2. Default call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+        assert batch_size == 1, "batch_size must be 1"
+
+        device = self._execution_device
+
+        # 3. Encode input prompt
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt,
+            negative_prompt,
+            self.do_classifier_free_guidance,
+            num_images_per_prompt=num_images_per_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            max_sequence_length=max_sequence_length,
+            device=device,
+            dtype=self.dtype,
+        )
+
+        # 4.
process images + if condition_images is not None and not isinstance(condition_images, list): + condition_images = [condition_images] + condition_images_prior_token_id = [condition_images_prior_token_id] + assert condition_images is None or (len(condition_images) == len(condition_images_prior_token_id)), ( + "image and image_prior_token_id must be the same length" + ) + + if condition_images is not None: + preprocessed_condition_images = [] + for img in condition_images: + image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2] + multiple_of = self.vae_scale_factor * self.transformer.config.patch_size + image_height = (image_height // multiple_of) * multiple_of + image_width = (image_width // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width) + preprocessed_condition_images.append(img) + height = height or image_height + width = width or image_width + condition_images = preprocessed_condition_images + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. Prepare latents and (optional) condition_images kv cache + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=latent_channels, + height=height, + width=width, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + + if condition_images is not None: + self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageEditWriteKV) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(self.vae.device, self.vae.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(self.vae.device, self.vae.dtype) + ) + empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...] + for condition_image, condition_image_prior_token_id in zip( + condition_images, condition_images_prior_token_id + ): + condition_image = condition_image.to(device=device, dtype=self.vae.dtype) + condition_latent = retrieve_latents( + self.vae.encode(condition_image), generator=generator, sample_mode="argmax" + ) + condition_latent = (condition_latent - latents_mean) / latents_std + _ = self.transformer( + hidden_states=condition_latent, + glyph_hidden_states=empty_glyph_hiddens, + prior_token_id=condition_image_prior_token_id, + prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool), + timestep=torch.zeros((1,), device=device), + original_size=torch.tensor([condition_image.shape[-2:]], device=device), + target_size=torch.tensor([condition_image.shape[-2:]], device=device), + crop_coords=torch.zeros((1, 2), device=device), + attention_kwargs=attention_kwargs, + ) + + # 6. 
Prepare additional timestep conditions + original_size = original_size or (height, width) + target_size = (height, width) + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device) + crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device) + + original_size = original_size.repeat(batch_size * num_images_per_prompt, 1) + target_size = target_size.repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1) + + # Prepare timesteps + image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // ( + self.transformer.config.patch_size**2 + ) + timesteps = ( + np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1)[:-1] + if timesteps is None + else np.array(timesteps) + ) + timesteps = timesteps.astype(np.int64).astype(np.float32) + sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("base_shift", 0.25), + self.scheduler.config.get("max_shift", 0.75), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu + ) + self._num_timesteps = len(timesteps) + + # 7. Denoising loop + transformer_dtype = self.transformer.dtype + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + prior_token_drop_cond = torch.full_like(prior_token_id, False, dtype=torch.bool) + prior_token_drop_uncond = torch.full_like(prior_token_id, True, dtype=torch.bool) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) - 1 + + if condition_images is not None: + self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageEditReadKV) + + noise_pred_cond = self.transformer( + hidden_states=latent_model_input, + glyph_hidden_states=prompt_embeds, + prior_token_id=prior_token_id, + prior_token_drop=prior_token_drop_cond, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0].float() + + # perform guidance + if self.do_classifier_free_guidance: + if condition_images is not None: + self.transformer.set_attention_processors_state( + GlmImageDecoderAttenProcessorState.ImageEditDontReadKV + ) + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + glyph_hidden_states=negative_prompt_embeds, + prior_token_id=prior_token_id, + prior_token_drop=prior_token_drop_uncond, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0].float() + + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # call the callback, if provided + if 
callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageGen) + self.transformer.clear_attention_processors_cache() + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = latents * latents_std + latents_mean + condition_images = self.vae.decode(latents, return_dict=False, generator=generator)[0] + else: + condition_images = latents + + condition_images = self.image_processor.postprocess(condition_images, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (condition_images,) + + return GlmImageDecoderPipelineOutput(images=condition_images) diff --git a/src/diffusers/pipelines/glm_image/pipeline_output.py b/src/diffusers/pipelines/glm_image/pipeline_output.py new file mode 100644 index 000000000000..a506e527cb50 --- /dev/null +++ b/src/diffusers/pipelines/glm_image/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class GlmImageDecoderPipelineOutput(BaseOutput): + """ + Output class for CogView3 pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] From b98decfe5fee8bd1be1f5c546b784a5844c8c282 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 7 Jan 2026 15:35:56 +0800 Subject: [PATCH 02/56] add --- src/diffusers/__init__.py | 2 ++ src/diffusers/pipelines/auto_pipeline.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3a50634d82d8..e4095f615e04 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -487,6 +487,7 @@ "FluxKontextPipeline", "FluxPipeline", "FluxPriorReduxPipeline", + "GlmImageDecoderPipeline", "HiDreamImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", @@ -1203,6 +1204,7 @@ FluxKontextPipeline, FluxPipeline, FluxPriorReduxPipeline, + GlmImageDecoderPipeline, HiDreamImagePipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index c14910250b54..eeadeff0327b 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -24,6 +24,7 @@ from .chroma import ChromaPipeline from .cogview3 import CogView3PlusPipeline from .cogview4 import CogView4ControlPipeline, CogView4Pipeline +from .glm_image import GlmImageDecoderPipeline from .controlnet import ( StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, @@ -167,6 +168,7 @@ ("chroma", ChromaPipeline), ("cogview3", CogView3PlusPipeline), ("cogview4", CogView4Pipeline), + ("glm_image", GlmImageDecoderPipeline), ("cogview4-control", CogView4ControlPipeline), ("qwenimage", QwenImagePipeline), ("qwenimage-controlnet", QwenImageControlNetPipeline), From 57fd26d8fe0c73241eccb7d12a044081fa3b4749 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 7 Jan 2026 15:41:55 +0800 Subject: [PATCH 03/56] add 1 --- docs/source/en/_toctree.yml | 4 ++++ src/diffusers/pipelines/auto_pipeline.py | 2 +- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f0cb0164436e..456b52d49498 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -353,6 +353,8 @@ title: Flux2Transformer2DModel - local: api/models/flux_transformer title: FluxTransformer2DModel + - local: api/models/glm_image_transformer2d + title: GlmImageDecoderTransformer2DModel - local: api/models/hidream_image_transformer title: HiDreamImageTransformer2DModel - local: api/models/hunyuan_transformer2d @@ -541,6 +543,8 @@ title: Flux2 - local: api/pipelines/control_flux_inpaint title: FluxControlInpaint + - local: api/pipelines/glm_image + title: GLM-Image - local: api/pipelines/hidream title: HiDream-I1 - local: api/pipelines/hunyuandit diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index eeadeff0327b..28b882dfafdb 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -24,7 +24,6 @@ from .chroma import ChromaPipeline from .cogview3 import CogView3PlusPipeline from .cogview4 import CogView4ControlPipeline, CogView4Pipeline -from .glm_image import GlmImageDecoderPipeline from .controlnet import ( StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, @@ -53,6 +52,7 @@ FluxKontextPipeline, FluxPipeline, ) +from .glm_image import GlmImageDecoderPipeline from .hunyuandit import 
HunyuanDiTPipeline from .kandinsky import ( KandinskyCombinedPipeline, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 62426ffbf65c..b36d7f3ccf06 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -967,6 +967,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class GlmImageDecoderTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HunyuanDiT2DControlNetModel(metaclass=DummyObject): _backends = ["torch"] From bcc9c303f6f4be22db8300a47994018de9b078e2 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 7 Jan 2026 15:59:47 +0800 Subject: [PATCH 04/56] Update __init__.py --- src/diffusers/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e4095f615e04..4b85f2662a32 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -223,6 +223,7 @@ "FluxControlNetModel", "FluxMultiControlNetModel", "FluxTransformer2DModel", + "GlmImageDecoderTransformer2DModel", "HiDreamImageTransformer2DModel", "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", @@ -970,6 +971,7 @@ FluxControlNetModel, FluxMultiControlNetModel, FluxTransformer2DModel, + GlmImageDecoderTransformer2DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, From e13fb7655235416ef1167acc37d6606b03cc4f34 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 7 Jan 2026 16:55:02 +0800 Subject: [PATCH 05/56] rename --- docs/source/en/_toctree.yml | 2 +- .../en/api/models/glm_image_transformer2d.md | 8 +- docs/source/en/api/pipelines/glm_image.md | 8 +- src/diffusers/__init__.py | 8 +- src/diffusers/models/__init__.py | 4 +- src/diffusers/models/embeddings.py | 2 +- src/diffusers/models/transformers/__init__.py | 2 +- .../transformers/transformer_glm_image.py | 319 +++--------------- src/diffusers/pipelines/auto_pipeline.py | 4 +- src/diffusers/pipelines/glm_image/__init__.py | 6 +- .../pipelines/glm_image/pipeline_glm_image.py | 36 +- .../pipelines/glm_image/pipeline_output.py | 2 +- src/diffusers/utils/dummy_pt_objects.py | 2 +- 13 files changed, 93 insertions(+), 310 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 456b52d49498..8e2b9d6c0436 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -354,7 +354,7 @@ - local: api/models/flux_transformer title: FluxTransformer2DModel - local: api/models/glm_image_transformer2d - title: GlmImageDecoderTransformer2DModel + title: GlmImageTransformer2DModel - local: api/models/hidream_image_transformer title: HiDreamImageTransformer2DModel - local: api/models/hunyuan_transformer2d diff --git a/docs/source/en/api/models/glm_image_transformer2d.md b/docs/source/en/api/models/glm_image_transformer2d.md index d31557c9f7ff..8a8b07456046 100644 --- a/docs/source/en/api/models/glm_image_transformer2d.md +++ b/docs/source/en/api/models/glm_image_transformer2d.md @@ -9,10 +9,10 @@ Unless required by applicable law or agreed to in writing, software distributed an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. --> -# GlmImageDecoderTransformer2DModel +# GlmImageTransformer2DModel -A Diffusion Transformer model for 2D data from [GlmImageDecoderTransformer2DModel]() +A Diffusion Transformer model for 2D data from [GlmImageTransformer2DModel]() -## GlmImageDecoderTransformer2DModel +## GlmImageTransformer2DModel -[[autodoc]] GlmImageDecoderTransformer2DModel +[[autodoc]] GlmImageTransformer2DModel diff --git a/docs/source/en/api/pipelines/glm_image.md b/docs/source/en/api/pipelines/glm_image.md index 24b5e14a1ae7..c3787cd77b37 100644 --- a/docs/source/en/api/pipelines/glm_image.md +++ b/docs/source/en/api/pipelines/glm_image.md @@ -20,12 +20,12 @@ This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/zai-org). The original weights can be found under [hf.co/zai-org](https://huggingface.co/zai-org). -## GlmImageDecoderPipeline +## GlmImagePipeline -[[autodoc]] GlmImageDecoderPipeline +[[autodoc]] GlmImagePipeline - all - __call__ -## GlmImageDecoderPipelineOutput +## GlmImagePipelineOutput -[[autodoc]] pipelines.cogview4.pipeline_output.GlmImageDecoderPipelineOutput +[[autodoc]] pipelines.cogview4.pipeline_output.GlmImagePipelineOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 4b85f2662a32..ceb52a74099d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -223,7 +223,7 @@ "FluxControlNetModel", "FluxMultiControlNetModel", "FluxTransformer2DModel", - "GlmImageDecoderTransformer2DModel", + "GlmImageTransformer2DModel", "HiDreamImageTransformer2DModel", "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", @@ -488,7 +488,7 @@ "FluxKontextPipeline", "FluxPipeline", "FluxPriorReduxPipeline", - "GlmImageDecoderPipeline", + "GlmImagePipeline", "HiDreamImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", @@ -971,7 +971,7 @@ FluxControlNetModel, FluxMultiControlNetModel, FluxTransformer2DModel, - GlmImageDecoderTransformer2DModel, + GlmImageTransformer2DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, @@ -1206,7 +1206,7 @@ FluxKontextPipeline, FluxPipeline, FluxPriorReduxPipeline, - GlmImageDecoderPipeline, + GlmImagePipeline, HiDreamImagePipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 29bf6016deff..3851c1d541bd 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -96,7 +96,7 @@ _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"] - _import_structure["transformers.transformer_glm_image"] = ["GlmImageDecoderTransformer2DModel"] + _import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] @@ -204,7 +204,7 @@ EasyAnimateTransformer3DModel, Flux2Transformer2DModel, FluxTransformer2DModel, - GlmImageDecoderTransformer2DModel, + GlmImageTransformer2DModel, 
HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 14fd9e2de974..1947e1e53490 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1658,7 +1658,7 @@ def forward( return conditioning -class GlmImageDecoderCombinedTimestepSizeEmbeddings(nn.Module): +class GlmImageCombinedTimestepSizeEmbeddings(nn.Module): def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256): super().__init__() diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index ea051624f2dd..5f389979b55e 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -27,7 +27,7 @@ from .transformer_easyanimate import EasyAnimateTransformer3DModel from .transformer_flux import FluxTransformer2DModel from .transformer_flux2 import Flux2Transformer2DModel - from .transformer_glm_image import GlmImageDecoderTransformer2DModel + from .transformer_glm_image import GlmImageTransformer2DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index 54a1c7cbdb49..4c296b48aabd 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -27,7 +27,7 @@ from ..attention import FeedForward from ..attention_processor import Attention from ..cache_utils import CacheMixin -from ..embeddings import GlmImageDecoderCombinedTimestepSizeEmbeddings +from ..embeddings import GlmImageCombinedTimestepSizeEmbeddings from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import LayerNorm, RMSNorm @@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class GlmImageDecoderImageProjector(nn.Module): +class GlmImageImageProjector(nn.Module): def __init__( self, in_channels: int = 16, @@ -62,7 +62,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class GlmImageDecoderAdaLayerNormZero(nn.Module): +class GlmImageAdaLayerNormZero(nn.Module): def __init__(self, embedding_dim: int, dim: int) -> None: super().__init__() @@ -71,11 +71,11 @@ def __init__(self, embedding_dim: int, dim: int) -> None: self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True) def forward( - self, hidden_states: torch.Tensor, glyph_hidden_states: torch.Tensor, temb: torch.Tensor + self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: dtype = hidden_states.dtype norm_hidden_states = self.norm(hidden_states).to(dtype=dtype) - norm_glyph_hidden_states = self.norm_context(glyph_hidden_states).to(dtype=dtype) + norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype) emb = self.linear(temb) ( @@ -94,7 +94,7 @@ def forward( ) = emb.chunk(12, dim=1) hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) - glyph_hidden_states = norm_glyph_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1) + encoder_hidden_states = 
norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1) return ( hidden_states, @@ -102,7 +102,7 @@ def forward( shift_mlp, scale_mlp, gate_mlp, - glyph_hidden_states, + encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, @@ -110,17 +110,17 @@ def forward( ) -class GlmImageDecoderAttenProcessorState(Enum): +class GlmImageAttenProcessorState(Enum): ImageGen = "ImageGen" ImageEditWriteKV = "ImageEditWriteKV" ImageEditReadKV = "ImageEditReadKV" ImageEditDontReadKV = "ImageEditNoReadKV" -class GlmImageDecoderAttnProcessor: +class GlmImageAttnProcessor: """ - Processor for implementing scaled dot-product attention for the GlmImageDecoder model. It applies a rotary - embedding on query and key vectors, but does not include spatial normalization. + Processor for implementing scaled dot-product attention for the GlmImage model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size, text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token. @@ -128,10 +128,8 @@ class GlmImageDecoderAttnProcessor: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "GlmImageDecoderAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." - ) - self.processor_state = GlmImageDecoderAttenProcessorState.ImageGen + raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + self.processor_state = GlmImageAttenProcessorState.ImageGen self.k_cache = None self.v_cache = None @@ -175,10 +173,10 @@ def __call__( key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 ) - if self.processor_state == GlmImageDecoderAttenProcessorState.ImageEditWriteKV: + if self.processor_state == GlmImageAttenProcessorState.ImageEditWriteKV: self.k_cache = key if self.k_cache is None else torch.cat([self.k_cache, key], dim=2) self.v_cache = value if self.v_cache is None else torch.cat([self.v_cache, value], dim=2) - elif self.processor_state == GlmImageDecoderAttenProcessorState.ImageEditReadKV: + elif self.processor_state == GlmImageAttenProcessorState.ImageEditReadKV: key = torch.cat([self.k_cache, key], dim=2) if self.k_cache is not None else key value = torch.cat([self.v_cache, value], dim=2) if self.v_cache is not None else value @@ -210,7 +208,7 @@ def __call__( @maybe_allow_in_graph -class GlmImageDecoderTransformerBlock(nn.Module): +class GlmImageTransformerBlock(nn.Module): def __init__( self, dim: int = 2560, @@ -221,7 +219,7 @@ def __init__( super().__init__() # 1. Attention - self.norm1 = GlmImageDecoderAdaLayerNormZero(time_embed_dim, dim) + self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim) self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, @@ -231,7 +229,7 @@ def __init__( qk_norm="layer_norm", elementwise_affine=False, eps=1e-5, - processor=GlmImageDecoderAttnProcessor(), + processor=GlmImageAttnProcessor(), ) # 2. 
Feedforward @@ -242,7 +240,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - glyph_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[ Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]] @@ -257,42 +255,42 @@ def forward( shift_mlp, scale_mlp, gate_mlp, - norm_glyph_hidden_states, + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp, - ) = self.norm1(hidden_states, glyph_hidden_states, temb) + ) = self.norm1(hidden_states, encoder_hidden_states, temb) # 2. Attention if attention_kwargs is None: attention_kwargs = {} - attn_hidden_states, attn_glyph_hidden_states = self.attn1( + attn_hidden_states, attn_encoder_hidden_states = self.attn1( hidden_states=norm_hidden_states, - encoder_hidden_states=norm_glyph_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, attention_mask=attention_mask, **attention_kwargs, ) hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) - glyph_hidden_states = glyph_hidden_states + attn_glyph_hidden_states * c_gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1) # 3. Feedforward norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) - norm_glyph_hidden_states = self.norm2_context(glyph_hidden_states) * ( + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * ( 1 + c_scale_mlp.unsqueeze(1) ) + c_shift_mlp.unsqueeze(1) ff_output = self.ff(norm_hidden_states) - ff_output_context = self.ff(norm_glyph_hidden_states) + ff_output_context = self.ff(norm_encoder_hidden_states) hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1) - glyph_hidden_states = glyph_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1) - return hidden_states, glyph_hidden_states + return hidden_states, encoder_hidden_states -class GlmImageDecoderRotaryPosEmbed(nn.Module): +class GlmImageRotaryPosEmbed(nn.Module): def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None: super().__init__() @@ -331,10 +329,10 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens return (freqs.cos(), freqs.sin()) -class GlmImageDecoderAdaLayerNormContinuous(nn.Module): +class GlmImageAdaLayerNormContinuous(nn.Module): """ - GlmImageDecoder-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** - before the Linear on conditioning embedding. + GlmImage-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the + Linear on conditioning embedding. 
""" def __init__( @@ -363,7 +361,7 @@ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torc return x -class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): +class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): r""" Args: patch_size (`int`, defaults to `2`): @@ -397,9 +395,9 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi _supports_gradient_checkpointing = True _no_split_modules = [ - "GlmImageDecoderTransformerBlock", - "GlmImageDecoderImageProjector", - "GlmImageDecoderImageProjector", + "GlmImageTransformerBlock", + "GlmImageImageProjector", + "GlmImageImageProjector", ] _skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"] @@ -412,35 +410,30 @@ def __init__( num_layers: int = 30, attention_head_dim: int = 40, num_attention_heads: int = 64, - text_embed_dim: int = 4096, - glyph_embed_dim: int = 1472, + text_embed_dim: int = 1472, time_embed_dim: int = 512, condition_dim: int = 256, - pos_embed_max_size: int = 128, - sample_size: int = 128, prior_vq_quantizer_codebook_size: int = 16384, ): super().__init__() - # GlmImageDecoder uses 2 additional SDXL-like conditions - target_size, crop_coords + # GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords # Each of these are sincos embeddings of shape 2 * condition_dim pooled_projection_dim = 2 * 2 * condition_dim inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels # 1. RoPE - self.rope = GlmImageDecoderRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0) + self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0) # 2. Patch & Text-timestep embedding - self.image_projector = GlmImageDecoderImageProjector(in_channels, inner_dim, patch_size) - # 这次没有,未来可能有text_projector - # self.text_projector = FeedForward(text_embed_dim, inner_dim, activation_fn="gelu") - self.glyph_projector = FeedForward(glyph_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu") + self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size) + self.text_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu") self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim) self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu") - self.time_condition_embed = GlmImageDecoderCombinedTimestepSizeEmbeddings( + self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings( embedding_dim=time_embed_dim, condition_dim=condition_dim, pooled_projection_dim=pooled_projection_dim, @@ -450,13 +443,13 @@ def __init__( # 3. Transformer blocks self.transformer_blocks = nn.ModuleList( [ - GlmImageDecoderTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim) + GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim) for _ in range(num_layers) ] ) # 4. 
Output projection - self.norm_out = GlmImageDecoderAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False) + self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True) self.gradient_checkpointing = False @@ -464,11 +457,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - glyph_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, prior_token_id: torch.Tensor, prior_token_drop: torch.Tensor, timestep: torch.LongTensor, - original_size: torch.Tensor, target_size: torch.Tensor, crop_coords: torch.Tensor, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -490,7 +482,7 @@ def forward( post_patch_width = width // p hidden_states = self.image_projector(hidden_states) - glyph_hidden_states = self.glyph_projector(glyph_hidden_states) + encoder_hidden_states = self.text_projector(encoder_hidden_states) prior_embedding = self.prior_token_embedding(prior_token_id) prior_embedding[prior_token_drop] *= 0.0 prior_hidden_states = self.prior_projector(prior_embedding) @@ -503,19 +495,19 @@ def forward( # 3. Transformer blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states, glyph_hidden_states = self._gradient_checkpointing_func( + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( block, hidden_states, - glyph_hidden_states, + encoder_hidden_states, temb, image_rotary_emb, attention_mask, attention_kwargs, ) else: - hidden_states, glyph_hidden_states = block( + hidden_states, encoder_hidden_states = block( hidden_states, - glyph_hidden_states, + encoder_hidden_states, temb, image_rotary_emb, attention_mask, @@ -534,7 +526,7 @@ def forward( return (output,) return Transformer2DModelOutput(sample=output) - def set_attention_processors_state(self, state: GlmImageDecoderAttenProcessorState): + def set_attention_processors_state(self, state: GlmImageAttenProcessorState): for block in self.transformer_blocks: block.attn1.processor.processor_state = state @@ -542,212 +534,3 @@ def clear_attention_processors_cache(self): for block in self.transformer_blocks: block.attn1.processor.k_cache = None block.attn1.processor.v_cache = None - - def repeat_attention_processors_cache(self, repeats: int): - for block in self.transformer_blocks: - if block.attn1.processor.k_cache is None or block.attn1.processor.v_cache is None: - continue - block.attn1.processor.k_cache = torch.repeat_interleave(block.attn1.processor.k_cache, repeats, dim=2) - block.attn1.processor.v_cache = torch.repeat_interleave(block.attn1.processor.v_cache, repeats, dim=2) - - -if __name__ == "__main__": - - def swap_scale_shift(weight, dim): - """ - Swap the scale and shift components in the weight tensor. - - Args: - weight (torch.Tensor): The original weight tensor. - dim (int): The dimension along which to split. - - Returns: - torch.Tensor: The modified weight tensor with scale and shift swapped. - """ - shift, scale = weight.chunk(2, dim=dim) - new_weight = torch.cat([scale, shift], dim=dim) - return new_weight - - def convert_megatron_transformer_checkpoint_to_diffusers( - ckpt_path: str, - num_layers: int, - num_heads: int, - hidden_size: int, - ): - """ - Convert a Megatron Transformer checkpoint to Diffusers format. - - Args: - ckpt_path (str): Path to the Megatron Transformer checkpoint. - num_layers (int): Number of Transformer layers. 
- num_heads (int): Number of attention heads. - hidden_size (int): Hidden size of the Transformer. - - Returns: - dict: The converted state dictionary compatible with Diffusers. - """ - ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) - mega = ckpt["model"] - used_keys = set() - - def get_mega(key): - used_keys.add(key) - return mega[key] - - new_state_dict = {} - - # Patch Embedding - new_state_dict["image_projector.proj.weight"] = get_mega("encoder_expand_linear.weight").reshape( - hidden_size, 64 - ) - new_state_dict["image_projector.proj.bias"] = get_mega("encoder_expand_linear.bias") - - new_state_dict["glyph_projector.net.0.proj.weight"] = get_mega("glyph_projector.linear_fc1.weight") - new_state_dict["glyph_projector.net.0.proj.bias"] = get_mega("glyph_projector.linear_fc1.bias") - new_state_dict["glyph_projector.net.2.weight"] = get_mega("glyph_projector.linear_fc2.weight") - new_state_dict["glyph_projector.net.2.bias"] = get_mega("glyph_projector.linear_fc2.bias") - - new_state_dict["prior_token_embedding.weight"] = get_mega("xomni_token_id_embedding.weight") - new_state_dict["prior_projector.net.0.proj.weight"] = get_mega("prior_condition_embedding.0.weight") - new_state_dict["prior_projector.net.0.proj.bias"] = get_mega("prior_condition_embedding.0.bias") - new_state_dict["prior_projector.net.2.weight"] = get_mega("prior_condition_embedding.2.weight") - new_state_dict["prior_projector.net.2.bias"] = get_mega("prior_condition_embedding.2.bias") - - # Time Condition Embedding - new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = get_mega( - "time_embedding.time_embed.0.weight" - ) - new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = get_mega( - "time_embedding.time_embed.0.bias" - ) - new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = get_mega( - "time_embedding.time_embed.2.weight" - ) - new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = get_mega( - "time_embedding.time_embed.2.bias" - ) - - new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = get_mega( - "label_embedding.label_embed.0.weight" - ) - new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = get_mega( - "label_embedding.label_embed.0.bias" - ) - new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = get_mega( - "label_embedding.label_embed.2.weight" - ) - new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = get_mega( - "label_embedding.label_embed.2.bias" - ) - - # Convert each Transformer layer - from tqdm import tqdm - - for i in tqdm(range(num_layers), desc="Converting layers (Megatron->Diffusers)"): - block_prefix = f"transformer_blocks.{i}." 
- - # AdaLayerNorm - new_state_dict[block_prefix + "norm1.linear.weight"] = get_mega(f"decoder.layers.{i}.adaln.weight") - new_state_dict[block_prefix + "norm1.linear.bias"] = get_mega(f"decoder.layers.{i}.adaln.bias") - qkv_weight = get_mega(f"decoder.layers.{i}.self_attention.linear_qkv.weight") - qkv_bias = get_mega(f"decoder.layers.{i}.self_attention.linear_qkv.bias") - - # Reshape to match SAT logic - qkv_weight = qkv_weight.view(num_heads, 3, hidden_size // num_heads, hidden_size) - qkv_weight = qkv_weight.permute(1, 0, 2, 3).reshape(3 * hidden_size, hidden_size) - - qkv_bias = qkv_bias.view(num_heads, 3, hidden_size // num_heads) - qkv_bias = qkv_bias.permute(1, 0, 2).reshape(3 * hidden_size) - - # Assign to Diffusers keys - q, k, v = torch.chunk(qkv_weight, 3, dim=0) - qb, kb, vb = torch.chunk(qkv_bias, 3, dim=0) - - new_state_dict[block_prefix + "attn1.to_q.weight"] = q - new_state_dict[block_prefix + "attn1.to_q.bias"] = qb - new_state_dict[block_prefix + "attn1.to_k.weight"] = k - new_state_dict[block_prefix + "attn1.to_k.bias"] = kb - new_state_dict[block_prefix + "attn1.to_v.weight"] = v - new_state_dict[block_prefix + "attn1.to_v.bias"] = vb - - # Attention Output - new_state_dict[block_prefix + "attn1.to_out.0.weight"] = get_mega( - f"decoder.layers.{i}.self_attention.linear_proj.weight" - ) - new_state_dict[block_prefix + "attn1.to_out.0.bias"] = get_mega( - f"decoder.layers.{i}.self_attention.linear_proj.bias" - ) - - # MLP - new_state_dict[block_prefix + "ff.net.0.proj.weight"] = get_mega( - f"decoder.layers.{i}.mlp.linear_fc1.weight" - ) - new_state_dict[block_prefix + "ff.net.0.proj.bias"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc1.bias") - new_state_dict[block_prefix + "ff.net.2.weight"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc2.weight") - new_state_dict[block_prefix + "ff.net.2.bias"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc2.bias") - - # Final Layers - new_state_dict["norm_out.linear.weight"] = swap_scale_shift(get_mega("adaln_final.weight"), dim=0) - new_state_dict["norm_out.linear.bias"] = swap_scale_shift(get_mega("adaln_final.bias"), dim=0) - new_state_dict["proj_out.weight"] = get_mega("output_projector.weight") - new_state_dict["proj_out.bias"] = get_mega("output_projector.bias") - - # Check for unused keys - all_keys = set(mega.keys()) - unused_keys = all_keys - used_keys - if unused_keys: - print(f"\n[WARNING] The following {len(unused_keys)} keys in mega were NOT used:") - for key in sorted(unused_keys): - print(f" - {key}") - raise ValueError( - f"Found {len(unused_keys)} unused keys in Megatron checkpoint. Please update the conversion script to handle these keys." 
- ) - else: - print(f"\n[INFO] All {len(all_keys)} keys in mega were successfully used.") - - return new_state_dict - - transformer = GlmImageDecoderTransformer2DModel( - patch_size=2, - in_channels=16, - num_layers=30, - attention_head_dim=128, - num_attention_heads=32, - out_channels=16, - text_embed_dim=4096, - time_embed_dim=512, - glyph_embed_dim=1472, - condition_dim=256, - pos_embed_max_size=128, - ).to(torch.bfloat16) - converted_transformer_state_dict = convert_megatron_transformer_checkpoint_to_diffusers( - ckpt_path="/workspace/ckpt/tjy/Glm-train-dev/examples/cogview/ckpts/merge/1+6_0.5+0.5/iter_0000000/mp_rank_00/model_optim_rng.pt", - num_layers=30, - num_heads=32, - hidden_size=4096, - ) - transformer.load_state_dict(converted_transformer_state_dict) - transformer.cuda() - - latent = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/latent.pt").to(torch.bfloat16) - latent = rearrange(latent, "(b h w) (c p q) -> b c (h p) (w q)", b=8, h=72, w=54, p=2, q=2) - glyph_hidden_states = torch.load( - "/workspace/ckpt/tjy/glm-train-dev/examples/cogview/glyph_condition_embedding.pt" - ).to(torch.bfloat16) - glyph_hidden_states = rearrange(glyph_hidden_states, "(b n) c -> b n c", b=8, n=2) - prior_token_id = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/xomni_token_id.pt") - prior_token_drop = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/xomni_drop.pt") - prior_token_id = rearrange(prior_token_id, "(b n) -> b n", b=8) - prior_token_drop = rearrange(prior_token_drop, "(b n)-> b n", b=8) - - with torch.no_grad(): - output = transformer( - hidden_states=latent, - glyph_hidden_states=glyph_hidden_states, - prior_token_id=prior_token_id, - prior_token_drop=prior_token_drop, - timestep=torch.tensor([999.0] * 8).cuda(), - original_size=torch.tensor([[144, 108]] * 8).cuda(), - target_size=torch.tensor([[144, 108]] * 8).cuda(), - crop_coords=torch.tensor([[0, 0]] * 8).cuda(), - ) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 28b882dfafdb..6f583385de7a 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -52,7 +52,7 @@ FluxKontextPipeline, FluxPipeline, ) -from .glm_image import GlmImageDecoderPipeline +from .glm_image import GlmImagePipeline from .hunyuandit import HunyuanDiTPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -168,7 +168,7 @@ ("chroma", ChromaPipeline), ("cogview3", CogView3PlusPipeline), ("cogview4", CogView4Pipeline), - ("glm_image", GlmImageDecoderPipeline), + ("glm_image", GlmImagePipeline), ("cogview4-control", CogView4ControlPipeline), ("qwenimage", QwenImagePipeline), ("qwenimage-controlnet", QwenImageControlNetPipeline), diff --git a/src/diffusers/pipelines/glm_image/__init__.py b/src/diffusers/pipelines/glm_image/__init__.py index 24c20c4ef5ac..9df31b0b1734 100644 --- a/src/diffusers/pipelines/glm_image/__init__.py +++ b/src/diffusers/pipelines/glm_image/__init__.py @@ -12,7 +12,7 @@ _dummy_objects = {} _additional_imports = {} -_import_structure = {"pipeline_output": ["GlmImageDecoderPipelineOutput"]} +_import_structure = {"pipeline_output": ["GlmImagePipelineOutput"]} try: if not (is_transformers_available() and is_torch_available()): @@ -22,7 +22,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_glm_image"] = ["GlmImageDecoderPipeline"] + _import_structure["pipeline_glm_image"] = ["GlmImagePipeline"] if TYPE_CHECKING or 
DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -30,7 +30,7 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .pipeline_glm_image import GlmImageDecoderPipeline + from .pipeline_glm_image import GlmImagePipeline else: import sys diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 825952be4b98..03ecb868ea67 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -25,13 +25,13 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor from ...loaders import CogView4LoraLoaderMixin -from ...models import AutoencoderKL, GlmImageDecoderTransformer2DModel -from ...models.transformers.transformer_glm_image import GlmImageDecoderAttenProcessorState +from ...models import AutoencoderKL, GlmImageTransformer2DModel +from ...models.transformers.transformer_glm_image import GlmImageAttenProcessorState from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor -from .pipeline_output import GlmImageDecoderPipelineOutput +from .pipeline_output import GlmImagePipelineOutput if is_torch_xla_available(): @@ -47,9 +47,9 @@ Examples: ```python >>> import torch - >>> from diffusers import GlmImageDecoderPipeline + >>> from diffusers import GlmImagePipeline - >>> pipe = GlmImageDecoderPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) + >>> pipe = GlmImagePipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> prompt = "A photo of an astronaut riding a horse on mars" @@ -151,7 +151,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): +class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin): r""" Pipeline for text-to-image generation using CogView4. @@ -162,7 +162,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`GlmModel`]): - Frozen text-encoder. CogView4 uses [Glm-4-9b-hf](https://huggingface.co/THUDM/Glm-4-9b-hf). + Frozen text-encoder. CogView4 uses [GLM-Image](https://huggingface.co/zai-org/GLM-Image). tokenizer (`PreTrainedTokenizer`): Tokenizer of class [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer). 
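Beyond the text-to-image call in `EXAMPLE_DOC_STRING`, the renamed `GlmImagePipeline` also accepts reference images through `condition_images` / `condition_images_prior_token_id`. The sketch below only strings those call parameters together: the `zai-org/GLM-Image` repo id is copied from the example docstring, and the prior token id tensors are assumed to be precomputed elsewhere (later commits in this series add a `GlmImageForConditionalGeneration` vision-language encoder for that), so treat it as an illustration rather than a tested invocation.

```python
import torch

from diffusers import GlmImagePipeline
from diffusers.utils import load_image

# Repo id taken from the example docstring; the published weights may live under a different name.
pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

reference = load_image("reference.png")  # PIL image used as the condition image

# Both id tensors are placeholders: in this series they come from the autoregressive prior,
# they are not values defined by this patch.
prior_token_id = ...  # torch.LongTensor with ids for the image being generated
reference_prior_token_id = ...  # torch.LongTensor with ids for the reference image

image = pipe(
    prompt="Turn the sky into a warm sunset",
    prior_token_id=prior_token_id,
    condition_images=[reference],
    condition_images_prior_token_id=[reference_prior_token_id],
    guidance_scale=1.5,  # classifier-free guidance is active whenever guidance_scale > 1
    num_inference_steps=50,
).images[0]
image.save("edited.png")
```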
@@ -179,15 +179,15 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): def __init__( self, tokenizer: AutoTokenizer, - glyph_encoder: T5EncoderModel, + text_encoder: T5EncoderModel, vae: AutoencoderKL, - transformer: GlmImageDecoderTransformer2DModel, + transformer: GlmImageTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, ): super().__init__() self.register_modules( - tokenizer=tokenizer, glyph_encoder=glyph_encoder, vae=vae, transformer=transformer, scheduler=scheduler + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -221,7 +221,7 @@ def _get_glyph_embeds( dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device - dtype = dtype or self.glyph_encoder.dtype + dtype = dtype or self.text_encoder.dtype glyph_texts = self.get_glyph_texts(prompt) input_ids = self.tokenizer( @@ -240,7 +240,7 @@ def _get_glyph_embeds( [input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device, ) - outputs = self.glyph_encoder(input_ids, attention_mask=attention_mask) + outputs = self.text_encoder(input_ids, attention_mask=attention_mask) glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0) return glyph_embeds.to(device=device, dtype=dtype) @@ -442,7 +442,7 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 2048, - ) -> Union[GlmImageDecoderPipelineOutput, Tuple]: + ) -> Union[GlmImagePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. 
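The `GlmImageAttenProcessorState` renames in the next hunks are easier to follow with the state transitions the pipeline drives around those processors spelled out. A minimal sketch, assuming `transformer` is an already-loaded `GlmImageTransformer2DModel` and with the forward passes from the pipeline code elided:

```python
from diffusers.models.transformers.transformer_glm_image import GlmImageAttenProcessorState

# 1. Encode each reference image once so every attention block caches its keys/values.
transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV)
# ... one transformer forward per condition image (timestep 0, empty glyph states) ...

# 2. During denoising, the conditional branch attends to the cached reference tokens ...
transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditReadKV)
# ... conditional forward pass ...

# 3. ... while the unconditional (CFG) branch ignores the cache.
transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditDontReadKV)
# ... unconditional forward pass ...

# 4. After sampling, reset to plain generation and drop the cached keys/values.
transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageGen)
transformer.clear_attention_processors_cache()
```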
@@ -615,7 +615,7 @@ def __call__( ) if condition_images is not None: - self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageEditWriteKV) + self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV) latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.latent_channels, 1, 1) @@ -698,7 +698,7 @@ def __call__( timestep = t.expand(latents.shape[0]) - 1 if condition_images is not None: - self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageEditReadKV) + self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditReadKV) noise_pred_cond = self.transformer( hidden_states=latent_model_input, @@ -717,7 +717,7 @@ def __call__( if self.do_classifier_free_guidance: if condition_images is not None: self.transformer.set_attention_processors_state( - GlmImageDecoderAttenProcessorState.ImageEditDontReadKV + GlmImageAttenProcessorState.ImageEditDontReadKV ) noise_pred_uncond = self.transformer( hidden_states=latent_model_input, @@ -755,7 +755,7 @@ def __call__( xm.mark_step() self._current_timestep = None - self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageGen) + self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageGen) self.transformer.clear_attention_processors_cache() if not output_type == "latent": @@ -783,4 +783,4 @@ def __call__( if not return_dict: return (condition_images,) - return GlmImageDecoderPipelineOutput(images=condition_images) + return GlmImagePipelineOutput(images=condition_images) diff --git a/src/diffusers/pipelines/glm_image/pipeline_output.py b/src/diffusers/pipelines/glm_image/pipeline_output.py index a506e527cb50..aec5a5454ea8 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_output.py +++ b/src/diffusers/pipelines/glm_image/pipeline_output.py @@ -8,7 +8,7 @@ @dataclass -class GlmImageDecoderPipelineOutput(BaseOutput): +class GlmImagePipelineOutput(BaseOutput): """ Output class for CogView3 pipelines. diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index b36d7f3ccf06..d2355d473711 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -967,7 +967,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class GlmImageDecoderTransformer2DModel(metaclass=DummyObject): +class GlmImageTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): From adcc53206bcb75b045daeed889ded027c639f1e5 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 7 Jan 2026 17:07:13 +0800 Subject: [PATCH 06/56] 2 --- .../models/transformers/transformer_glm_image.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index 4c296b48aabd..6eb9b5e80308 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -18,7 +18,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin @@ -376,7 +375,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach The number of heads to use for multi-head attention. 
out_channels (`int`, defaults to `16`): The number of channels in the output. - text_embed_dim (`int`, defaults to `4096`): + text_embed_dim (`int`, defaults to `1472`): Input dimension of text embeddings from the text encoder. time_embed_dim (`int`, defaults to `512`): Output dimension of timestep embeddings. @@ -428,8 +427,7 @@ def __init__( # 2. Patch & Text-timestep embedding self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size) - self.text_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu") - + self.glyph_projector = FeedForward(text_embed_dim, inner_dim, activation_fn="gelu") self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim) self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu") @@ -482,7 +480,7 @@ def forward( post_patch_width = width // p hidden_states = self.image_projector(hidden_states) - encoder_hidden_states = self.text_projector(encoder_hidden_states) + encoder_hidden_states = self.glyph_projector(encoder_hidden_states) prior_embedding = self.prior_token_embedding(prior_token_id) prior_embedding[prior_token_drop] *= 0.0 prior_hidden_states = self.prior_projector(prior_embedding) From ec678a1fb7b44a4dc3005f98652c0204fe6decd7 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 7 Jan 2026 17:16:39 +0800 Subject: [PATCH 07/56] update --- .../transformers/transformer_glm_image.py | 2 +- .../pipelines/glm_image/pipeline_glm_image.py | 18 +++--------------- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index 6eb9b5e80308..6ff45ac9e8ea 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -427,7 +427,7 @@ def __init__( # 2. Patch & Text-timestep embedding self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size) - self.glyph_projector = FeedForward(text_embed_dim, inner_dim, activation_fn="gelu") + self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu") self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim) self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu") diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 03ecb868ea67..caf8ed0e70b0 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -432,7 +432,6 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - original_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), output_type: str = "pil", return_dict: bool = True, @@ -496,11 +495,6 @@ def __call__( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - original_size (`Tuple[int]`, *optional*, defaults to (2048, 2048)): - If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. 
- `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as - explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting @@ -637,24 +631,20 @@ def __call__( condition_latent = (condition_latent - latents_mean) / latents_std _ = self.transformer( hidden_states=condition_latent, - glyph_hidden_states=empty_glyph_hiddens, + encoder_hidden_states=empty_glyph_hiddens, prior_token_id=condition_image_prior_token_id, prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool), timestep=torch.zeros((1,), device=device), - original_size=torch.tensor([condition_image.shape[-2:]], device=device), target_size=torch.tensor([condition_image.shape[-2:]], device=device), crop_coords=torch.zeros((1, 2), device=device), attention_kwargs=attention_kwargs, ) # 6. Prepare additional timestep conditions - original_size = original_size or (height, width) target_size = (height, width) - original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device) target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device) crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device) - original_size = original_size.repeat(batch_size * num_images_per_prompt, 1) target_size = target_size.repeat(batch_size * num_images_per_prompt, 1) crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1) @@ -702,11 +692,10 @@ def __call__( noise_pred_cond = self.transformer( hidden_states=latent_model_input, - glyph_hidden_states=prompt_embeds, + encoder_hidden_states=prompt_embeds, prior_token_id=prior_token_id, prior_token_drop=prior_token_drop_cond, timestep=timestep, - original_size=original_size, target_size=target_size, crop_coords=crops_coords_top_left, attention_kwargs=attention_kwargs, @@ -721,11 +710,10 @@ def __call__( ) noise_pred_uncond = self.transformer( hidden_states=latent_model_input, - glyph_hidden_states=negative_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, prior_token_id=prior_token_id, prior_token_drop=prior_token_drop_uncond, timestep=timestep, - original_size=original_size, target_size=target_size, crop_coords=crops_coords_top_left, attention_kwargs=attention_kwargs, From 22fe6c9023dfedff76bb30baebf4e953fb6ee091 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 7 Jan 2026 18:15:01 +0800 Subject: [PATCH 08/56] init with encoder --- .../pipelines/glm_image/pipeline_glm_image.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index caf8ed0e70b0..69062bfcf1c3 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -20,7 +20,7 @@ import numpy as np import PIL import torch -from transformers import AutoTokenizer, T5EncoderModel +from transformers import AutoTokenizer, T5EncoderModel, GlmImageForConditionalGeneration from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor @@ -180,6 +180,7 @@ 
def __init__( self, tokenizer: AutoTokenizer, text_encoder: T5EncoderModel, + vision_language_encoder: GlmImageForConditionalGeneration, vae: AutoencoderKL, transformer: GlmImageTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, @@ -187,7 +188,12 @@ def __init__( super().__init__() self.register_modules( - tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + tokenizer=tokenizer, + text_encoder=text_encoder, + vision_language_encoder=vision_language_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) From b3d1b5547b7e39090a425887d90c10782e35b857 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 7 Jan 2026 18:59:30 +0800 Subject: [PATCH 09/56] merge2pipeline --- .../pipelines/glm_image/pipeline_glm_image.py | 413 +++++++++++------- 1 file changed, 261 insertions(+), 152 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 69062bfcf1c3..d755dde11c98 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -15,12 +15,13 @@ import inspect import re +from math import sqrt from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL import torch -from transformers import AutoTokenizer, T5EncoderModel, GlmImageForConditionalGeneration +from transformers import ByT5Tokenizer, GlmImageForConditionalGeneration, GlmImageProcessor, T5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor @@ -49,10 +50,10 @@ >>> import torch >>> from diffusers import GlmImagePipeline - >>> pipe = GlmImagePipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) + >>> pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") - >>> prompt = "A photo of an astronaut riding a horse on mars" + >>> prompt = "A photo of an astronaut riding a horse on mars36 24" >>> image = pipe(prompt).images[0] >>> image.save("output.png") ``` @@ -81,25 +82,6 @@ def retrieve_timesteps( r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. 
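# Editor's sketch (not part of the patch): the `retrieve_timesteps` helper described above
# treats `timesteps` and `sigmas` as mutually exclusive schedule overrides. A hypothetical
# guard that mirrors part of that contract, for illustration only:
def _validate_schedule_overrides(timesteps=None, sigmas=None):
    # Only one custom schedule (or neither) may be supplied at a time.
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
    return timesteps, sigmas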
- - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. """ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) @@ -153,32 +135,36 @@ def retrieve_latents( class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin): r""" - Pipeline for text-to-image generation using CogView4. + Pipeline for text-to-image generation using GLM-Image. - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + This pipeline integrates both the AR (autoregressive) model for token generation and the DiT (diffusion + transformer) model for image decoding. Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`GlmModel`]): - Frozen text-encoder. CogView4 uses [GLM-Image](https://huggingface.co/zai-org/GLM-Image). + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder for glyph embeddings. tokenizer (`PreTrainedTokenizer`): - Tokenizer of class - [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer). - transformer ([`CogView4Transformer2DModel`]): - A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents. + Tokenizer for the text encoder. + processor (`AutoProcessor`): + Processor for the AR model to handle chat templates and tokenization. + vision_language_encoder ([`GlmImageForConditionalGeneration`]): + The AR model that generates image tokens from text prompts. + transformer ([`GlmImageTransformer2DModel`]): + A text conditioned transformer to denoise the encoded image latents (DiT). scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. """ _optional_components = [] model_cpu_offload_seq = "transformer->vae" - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( self, - tokenizer: AutoTokenizer, + tokenizer: ByT5Tokenizer, + processor: GlmImageProcessor, text_encoder: T5EncoderModel, vision_language_encoder: GlmImageForConditionalGeneration, vae: AutoencoderKL, @@ -189,11 +175,12 @@ def __init__( self.register_modules( tokenizer=tokenizer, + processor=processor, text_encoder=text_encoder, vision_language_encoder=vision_language_encoder, vae=vae, transformer=transformer, - scheduler=scheduler + scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -206,10 +193,197 @@ def __init__( else 128 ) - def get_glyph_texts( + def _parse_and_expand_shape_info(self, prompt: str) -> Tuple[str, int, int, int, int]: + """ + Parse the shape info from prompt and expand it for AR model. 
+ + Args: + prompt: The prompt containing H W shape specification + + Returns: + Tuple of (expanded_prompt, token_h, token_w, prev_token_h, prev_token_w) + """ + match = re.search(r"(\d+)\s+(\d+)", prompt) + if match is None: + raise ValueError(f"Prompt must contain shape info in format 'H W', got: {prompt}") + + token_h, token_w = int(match.group(1)), int(match.group(2)) + ratio = token_h / token_w + prev_token_h = int(sqrt(ratio) * 16) + prev_token_w = int(sqrt(1 / ratio) * 16) + + old_shape = f"{token_h} {token_w}" + new_shape = f"{token_h} {token_w}{prev_token_h} {prev_token_w}" + expanded_prompt = prompt.replace(old_shape, new_shape) + + return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w + + def _build_image_grid_thw( self, - prompt, - ): + token_h: int, + token_w: int, + prev_token_h: int, + prev_token_w: int, + existing_grid: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Build image grid tensor for AR model. + + For text-to-image: creates grid for large image + small image For image-to-image: appends new image to existing + grid + """ + if existing_grid is None or existing_grid.numel() == 0: + # Text-to-image: large image + small image + return torch.tensor( + [ + [1, token_h, token_w], + [1, prev_token_h, prev_token_w], + ], + device=device, + ) + else: + # Image-to-image: append to existing + return torch.cat([existing_grid, torch.tensor([[1, token_h, token_w]], device=device)], dim=0) + + def _calculate_ar_generation_params( + self, token_h: int, token_w: int, prev_token_h: int, prev_token_w: int, is_text_to_image: bool + ) -> Tuple[int, int]: + """ + Calculate max_new_tokens and large_image_start_offset for AR generation. + """ + large_image_tokens = token_h * token_w + small_image_tokens = prev_token_h * prev_token_w + + if is_text_to_image: + max_new_tokens = small_image_tokens + large_image_tokens + 1 + large_image_start_offset = small_image_tokens + else: + max_new_tokens = large_image_tokens + 1 + large_image_start_offset = 0 + + return max_new_tokens, large_image_start_offset + + def _extract_large_image_tokens( + self, outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int + ) -> torch.Tensor: + """ + Extract the large image tokens from AR model output. + """ + generated_tokens = outputs[0][input_length:] + large_image_start = large_image_start_offset + large_image_end = large_image_start + large_image_tokens + return generated_tokens[large_image_start:large_image_end] + + def _upsample_d32_to_d16(self, token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor: + """ + Upsample token IDs from d32 format to d16 format. + + AR model generates tokens at d32 resolution (each token = 32x32 pixels). DiT expects tokens at d16 resolution + (each token = 16x16 pixels). This function performs 2x nearest-neighbor upsampling. 
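# Editor's illustration (standalone, not part of the patch): the 2x nearest-neighbor
# upsampling described above turns each d32 token into a 2x2 block of d16 tokens.
# The 36x24 grid is an assumed example shape (a 1152x768 image at 32 px per token).
import torch
import torch.nn.functional as F

token_h, token_w = 36, 24
dummy_ids = torch.arange(token_h * token_w, dtype=torch.long)       # stand-in AR tokens, shape [864]
spatial = dummy_ids.view(1, 1, token_h, token_w).float()            # [1, 1, 36, 24]
upsampled = F.interpolate(spatial, scale_factor=2, mode="nearest").long()
assert upsampled.shape == (1, 1, 2 * token_h, 2 * token_w)          # every token repeated 2x2
assert upsampled.view(1, -1).shape[1] == token_h * token_w * 4      # 3456 d16 tokens for the DiT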
+ + Args: + token_ids: Token IDs of shape [N] where N = token_h * token_w + token_h: Height in d32 token units + token_w: Width in d32 token units + + Returns: + Upsampled token IDs of shape [1, N*4] where N*4 = (token_h*2) * (token_w*2) + """ + # Reshape to spatial format: [1, 1, H, W] + token_ids = token_ids.view(1, 1, token_h, token_w) + + # 2x nearest-neighbor upsampling + token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to( + dtype=torch.long + ) + + # Flatten back to [1, H*W*4] + token_ids = token_ids.view(1, -1) + + return token_ids + + def generate_prior_tokens( + self, + prompt: str, + condition_images: Optional[List[PIL.Image.Image]] = None, + ) -> Tuple[torch.Tensor, int, int]: + """ + Generate prior tokens using the AR (vision_language_encoder) model. + + Args: + prompt: The text prompt with shape info (e.g., "description36 24") + condition_images: Optional list of condition images for i2i + + Returns: + Tuple of (prior_token_ids, pixel_height, pixel_width) + - prior_token_ids: Upsampled to d16 format, shape [1, token_h*token_w*4] + - pixel_height: Image height in pixels + - pixel_width: Image width in pixels + """ + device = self.vision_language_encoder.device + + # Parse and expand shape info + expanded_prompt, token_h, token_w, prev_h, prev_w = self._parse_and_expand_shape_info(prompt) + + # Build messages for processor + content = [] + if condition_images is not None: + for img in condition_images: + content.append({"type": "image", "image": img}) + content.append({"type": "text", "text": expanded_prompt}) + messages = [{"role": "user", "content": content}] + + # Process inputs + inputs = self.processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + # Determine if text-to-image or image-to-image + existing_grid = inputs.get("image_grid_thw") + is_text_to_image = existing_grid is None or existing_grid.numel() == 0 + + # Build image grid + inputs["image_grid_thw"] = self._build_image_grid_thw( + token_h, + token_w, + prev_h, + prev_w, + existing_grid=existing_grid if not is_text_to_image else None, + device=device, + ) + + # Calculate generation parameters + max_new_tokens, large_image_offset = self._calculate_ar_generation_params( + token_h, token_w, prev_h, prev_w, is_text_to_image + ) + large_image_tokens = token_h * token_w + + # Move inputs to device and generate + inputs = inputs.to(device) + input_length = inputs["input_ids"].shape[-1] + + outputs = self.vision_language_encoder.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=True, + ) + + prior_token_ids_d32 = self._extract_large_image_tokens( + outputs, input_length, large_image_offset, large_image_tokens + ) + prior_token_ids = self._upsample_d32_to_d16(prior_token_ids_d32, token_h, token_w) + + pixel_height = token_h * 32 + pixel_width = token_w * 32 + + return prior_token_ids, pixel_height, pixel_width + + def get_glyph_texts(self, prompt): prompt = prompt[0] if isinstance(prompt, list) else prompt ocr_texts = ( re.findall(r"'([^']*)'", prompt) @@ -254,11 +428,9 @@ def _get_glyph_embeds( def encode_prompt( self, prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, do_classifier_free_guidance: bool = True, num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, max_sequence_length: 
int = 2048, @@ -269,10 +441,6 @@ def encode_prompt( Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): Whether to use classifier free guidance or not. num_images_per_prompt (`int`, *optional*, defaults to 1): @@ -280,10 +448,6 @@ def encode_prompt( prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. device: (`torch.device`, *optional*): torch device dtype: (`torch.dtype`, *optional*): @@ -306,8 +470,9 @@ def encode_prompt( prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" + negative_prompt_embeds = None + if do_classifier_free_guidance: + negative_prompt = "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt if prompt is not None and type(prompt) is not type(negative_prompt): @@ -353,10 +518,8 @@ def check_inputs( prompt, height, width, - negative_prompt, callback_on_step_end_tensor_inputs, prompt_embeds=None, - negative_prompt_embeds=None, ): if ( height is not None @@ -391,9 +554,6 @@ def check_inputs( def guidance_scale(self): return self._guidance_scale - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 @@ -418,15 +578,12 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prior_token_id: Optional[torch.LongTensor] = None, prompt: Optional[Union[str, List[str]]] = None, - negative_prompt: Optional[Union[str, List[str]]] = None, condition_images: Optional[ Union[ torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray] ] ] = None, - condition_images_prior_token_id: Optional[Union[torch.LongTensor, List[torch.LongTensor]]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -437,7 +594,6 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), output_type: str = "pil", return_dict: bool = True, @@ -452,109 +608,48 @@ def __call__( Function invoked when calling the pipeline for generation. 
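# Editor's note (illustration only): with the `do_classifier_free_guidance` property shown
# above, guidance is active only when `guidance_scale > 1`, so a scale of 1.0 runs just the
# conditional branch each step while 1.5 also runs the unconditional pass.
def _cfg_enabled(guidance_scale: float) -> bool:
    # Hypothetical helper for illustration; the pipeline uses the property instead.
    return guidance_scale > 1

assert _cfg_enabled(1.0) is False and _cfg_enabled(1.5) is True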
Args: - image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): - `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both - numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list - or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a - list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image - latents as `image`, but if passing latents directly it is not encoded again. prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. If not provided, it is set to 2048. - width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. If not provided it is set to 2048. + The prompt or prompts to guide the image generation. Must contain shape info in the format 'H + W' where H and W are token dimensions (d32). Example: "A beautiful sunset36 24" + generates a 1152x768 image. + condition_images: Optional condition images for image-to-image generation. + height (`int`, *optional*): + The height in pixels. If not provided, derived from prompt shape info. + width (`int`, *optional*): + The width in pixels. If not provided, derived from prompt shape info. num_inference_steps (`int`, *optional*, defaults to `50`): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. - sigmas (`List[float]`, *optional*): - Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in - their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed - will be used. - guidance_scale (`float`, *optional*, defaults to `5.0`): - Guidance scale as defined in [Classifier-Free Diffusion - Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. - of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting - `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + The number of denoising steps for DiT. + guidance_scale (`float`, *optional*, defaults to `1.5`): + Guidance scale for classifier-free guidance. num_images_per_prompt (`int`, *optional*, defaults to `1`): The number of images to generate per prompt. 
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): - `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position - `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting - `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of - [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + generator (`torch.Generator`, *optional*): + Random generator for reproducibility. output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead - of a plain tuple. - attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int`, defaults to `224`): - Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + Output format: "pil", "np", or "latent". Examples: Returns: - [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`: - [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is a list with the generated images. + [`GlmImagePipelineOutput`] or `tuple`: Generated images. 
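        Editor's note (illustrative sketch, not verified against released weights): the
        image-to-image path only differs in passing condition images via `image`, assuming
        a loaded `pipe` and a PIL `init_image`:

            >>> edited = pipe(
            ...     prompt="make the sky pink",
            ...     image=[init_image],  # one or more PIL condition images
            ...     num_inference_steps=50,
            ...     guidance_scale=1.5,
            ... ).images[0]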
""" if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - # 1. Check inputs. Raise error if not correct + # 1. Check inputs self.check_inputs( prompt, height, width, - negative_prompt, callback_on_step_end_tensor_inputs, prompt_embeds, - negative_prompt_embeds, ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False - # 2. Default call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -565,29 +660,50 @@ def __call__( device = self._execution_device + ar_condition_images = None + if condition_images is not None: + if not isinstance(condition_images, list): + condition_images = [condition_images] + ar_condition_images = [] + for img in condition_images: + if isinstance(img, PIL.Image.Image): + ar_condition_images.append(img) + elif isinstance(img, torch.Tensor): + img_np = img.cpu().numpy() + if img_np.ndim == 4: + img_np = img_np[0] + if img_np.shape[0] in [1, 3, 4]: + img_np = img_np.transpose(1, 2, 0) + if img_np.max() <= 1.0: + img_np = (img_np * 255).astype(np.uint8) + ar_condition_images.append(PIL.Image.fromarray(img_np)) + else: + ar_condition_images.append(PIL.Image.fromarray(img)) + + prior_token_id, ar_height, ar_width = self.generate_prior_tokens( + prompt=prompt[0] if isinstance(prompt, list) else prompt, + condition_images=ar_condition_images, + ) + + height = height or ar_height + width = width or ar_width + # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, - negative_prompt, self.do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, max_sequence_length=max_sequence_length, device=device, dtype=self.dtype, ) # 4. process images - if condition_images is not None and not isinstance(condition_images, list): - condition_images = [condition_images] - condition_images_prior_token_id = [condition_images_prior_token_id] - assert condition_images is None or (len(condition_images) == len(condition_images_prior_token_id)), ( - "image and image_prior_token_id must be the same length" - ) - + condition_images_prior_token_id = None if condition_images is not None: preprocessed_condition_images = [] + condition_images_prior_token_id = [] for img in condition_images: image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2] multiple_of = self.vae_scale_factor * self.transformer.config.patch_size @@ -595,11 +711,7 @@ def __call__( image_width = (image_width // multiple_of) * multiple_of img = self.image_processor.preprocess(img, height=image_height, width=image_width) preprocessed_condition_images.append(img) - height = height or image_height - width = width or image_width condition_images = preprocessed_condition_images - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor # 5. 
Prepare latents and (optional) condition_images kv cache latent_channels = self.transformer.config.in_channels @@ -614,7 +726,7 @@ def __call__( latents=latents, ) - if condition_images is not None: + if condition_images is not None and condition_images_prior_token_id is not None: self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV) latents_mean = ( torch.tensor(self.vae.config.latents_mean) @@ -690,7 +802,6 @@ def __call__( self._current_timestep = t latent_model_input = latents.to(transformer_dtype) - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]) - 1 if condition_images is not None: @@ -732,7 +843,6 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - # call the callback, if provided if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: @@ -740,7 +850,6 @@ def __call__( callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() From e2b31f8b153311e5bd355b2f3bb8f75fc31da978 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 7 Jan 2026 22:55:16 +0800 Subject: [PATCH 10/56] Update pipeline_glm_image.py --- .../pipelines/glm_image/pipeline_glm_image.py | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index d755dde11c98..7286f9ab273a 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -306,14 +306,14 @@ def _upsample_d32_to_d16(self, token_ids: torch.Tensor, token_h: int, token_w: i def generate_prior_tokens( self, prompt: str, - condition_images: Optional[List[PIL.Image.Image]] = None, + image: Optional[List[PIL.Image.Image]] = None, ) -> Tuple[torch.Tensor, int, int]: """ Generate prior tokens using the AR (vision_language_encoder) model. Args: prompt: The text prompt with shape info (e.g., "description36 24") - condition_images: Optional list of condition images for i2i + image: Optional list of condition images for i2i Returns: Tuple of (prior_token_ids, pixel_height, pixel_width) @@ -328,8 +328,8 @@ def generate_prior_tokens( # Build messages for processor content = [] - if condition_images is not None: - for img in condition_images: + if image is not None: + for img in image: content.append({"type": "image", "image": img}) content.append({"type": "text", "text": expanded_prompt}) messages = [{"role": "user", "content": content}] @@ -579,7 +579,7 @@ def interrupt(self): def __call__( self, prompt: Optional[Union[str, List[str]]] = None, - condition_images: Optional[ + image: Optional[ Union[ torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray] ] @@ -612,7 +612,7 @@ def __call__( The prompt or prompts to guide the image generation. Must contain shape info in the format 'H W' where H and W are token dimensions (d32). Example: "A beautiful sunset36 24" generates a 1152x768 image. 
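# Editor's note (worked example of the shape tag mentioned above): "36 24" means a 36x24
# grid of d32 tokens, i.e. 36 * 32 = 1152 px tall and 24 * 32 = 768 px wide.
token_h, token_w = 36, 24
assert (token_h * 32, token_w * 32) == (1152, 768)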
- condition_images: Optional condition images for image-to-image generation. + image: Optional condition images for image-to-image generation. height (`int`, *optional*): The height in pixels. If not provided, derived from prompt shape info. width (`int`, *optional*): @@ -661,11 +661,11 @@ def __call__( device = self._execution_device ar_condition_images = None - if condition_images is not None: - if not isinstance(condition_images, list): - condition_images = [condition_images] + if image is not None: + if not isinstance(image, list): + image = [image] ar_condition_images = [] - for img in condition_images: + for img in image: if isinstance(img, PIL.Image.Image): ar_condition_images.append(img) elif isinstance(img, torch.Tensor): @@ -682,7 +682,7 @@ def __call__( prior_token_id, ar_height, ar_width = self.generate_prior_tokens( prompt=prompt[0] if isinstance(prompt, list) else prompt, - condition_images=ar_condition_images, + image=ar_condition_images, ) height = height or ar_height @@ -701,19 +701,19 @@ def __call__( # 4. process images condition_images_prior_token_id = None - if condition_images is not None: + if image is not None: preprocessed_condition_images = [] condition_images_prior_token_id = [] - for img in condition_images: + for img in image: image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2] multiple_of = self.vae_scale_factor * self.transformer.config.patch_size image_height = (image_height // multiple_of) * multiple_of image_width = (image_width // multiple_of) * multiple_of img = self.image_processor.preprocess(img, height=image_height, width=image_width) preprocessed_condition_images.append(img) - condition_images = preprocessed_condition_images + image = preprocessed_condition_images - # 5. Prepare latents and (optional) condition_images kv cache + # 5. Prepare latents and (optional) image kv cache latent_channels = self.transformer.config.in_channels latents = self.prepare_latents( batch_size=batch_size * num_images_per_prompt, @@ -726,7 +726,7 @@ def __call__( latents=latents, ) - if condition_images is not None and condition_images_prior_token_id is not None: + if image is not None and condition_images_prior_token_id is not None: self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV) latents_mean = ( torch.tensor(self.vae.config.latents_mean) @@ -740,7 +740,7 @@ def __call__( ) empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...] 
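# Editor's illustration (standalone): the `[:1, :0, ...]` slice above builds an "empty"
# glyph embedding, keeping batch size 1 but cutting the sequence length to zero for the
# condition-image (KV write) pass. The 2x77x1472 shape is only an assumed example.
import torch

prompt_embeds = torch.randn(2, 77, 1472)
empty_glyph = torch.zeros_like(prompt_embeds)[:1, :0, ...]
assert empty_glyph.shape == (1, 0, 1472)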
for condition_image, condition_image_prior_token_id in zip( - condition_images, condition_images_prior_token_id + image, condition_images_prior_token_id ): condition_image = condition_image.to(device=device, dtype=self.vae.dtype) condition_latent = retrieve_latents( @@ -804,7 +804,7 @@ def __call__( timestep = t.expand(latents.shape[0]) - 1 - if condition_images is not None: + if image is not None: self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditReadKV) noise_pred_cond = self.transformer( @@ -821,7 +821,7 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: - if condition_images is not None: + if image is not None: self.transformer.set_attention_processors_state( GlmImageAttenProcessorState.ImageEditDontReadKV ) @@ -874,16 +874,16 @@ def __call__( .to(latents.device, latents.dtype) ) latents = latents * latents_std + latents_mean - condition_images = self.vae.decode(latents, return_dict=False, generator=generator)[0] + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] else: - condition_images = latents + image = latents - condition_images = self.image_processor.postprocess(condition_images, output_type=output_type) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: - return (condition_images,) + return (image,) - return GlmImagePipelineOutput(images=condition_images) + return GlmImagePipelineOutput(images=image) From 1cf277d36d55e3942e34726b5f3e982006d2a615 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 7 Jan 2026 23:51:40 +0800 Subject: [PATCH 11/56] remove sop --- .../pipelines/glm_image/pipeline_glm_image.py | 84 +++++++++++++------ 1 file changed, 58 insertions(+), 26 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 7286f9ab273a..e4dd7de80938 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -290,64 +290,96 @@ def _upsample_d32_to_d16(self, token_ids: torch.Tensor, token_h: int, token_w: i Returns: Upsampled token IDs of shape [1, N*4] where N*4 = (token_h*2) * (token_w*2) """ - # Reshape to spatial format: [1, 1, H, W] token_ids = token_ids.view(1, 1, token_h, token_w) - - # 2x nearest-neighbor upsampling token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to( dtype=torch.long ) - # Flatten back to [1, H*W*4] token_ids = token_ids.view(1, -1) - return token_ids + def _build_prompt_with_shape( + self, + prompt: str, + height: int, + width: int, + is_text_to_image: bool, + factor: int = 32, + ) -> Tuple[str, int, int, int, int]: + """ + Build prompt with shape info (H W) based on height and width. 
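# Editor's note (worked example, assumed 1152x768 target): the helper above derives the
# d32 grid from height/width plus a small "preview" grid with the same aspect ratio and
# roughly 16x16 tokens of area, via int(sqrt(ratio) * (factor // 2)).
from math import sqrt

height, width, factor = 1152, 768, 32
token_h, token_w = height // factor, width // factor          # 36, 24
ratio = token_h / token_w
prev_token_h = int(sqrt(ratio) * (factor // 2))                # 19
prev_token_w = int(sqrt(1 / ratio) * (factor // 2))            # 13
assert (token_h, token_w, prev_token_h, prev_token_w) == (36, 24, 19, 13)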
+ + Args: + prompt: The raw text prompt without shape info + height: Target image height in pixels + width: Target image width in pixels + is_text_to_image: Whether this is text-to-image (True) or image-to-image (False) + + Returns: + Tuple of (expanded_prompt, token_h, token_w, prev_token_h, prev_token_w) + """ + token_h = height // factor + token_w = width // factor + ratio = token_h / token_w + prev_token_h = int(sqrt(ratio) * (factor // 2)) + prev_token_w = int(sqrt(1 / ratio) * (factor // 2)) + + if is_text_to_image: + expanded_prompt = f"{prompt}{token_h} {token_w}{prev_token_h} {prev_token_w}" + else: + expanded_prompt = f"{prompt}{token_h} {token_w}" + + return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w + def generate_prior_tokens( self, prompt: str, + height: int, + width: int, image: Optional[List[PIL.Image.Image]] = None, + factor: int = 32, ) -> Tuple[torch.Tensor, int, int]: """ Generate prior tokens using the AR (vision_language_encoder) model. + Automatically builds the prompt with shape info based on height/width. Users only need to provide the raw text + prompt without ... tags. + Args: - prompt: The text prompt with shape info (e.g., "description36 24") - image: Optional list of condition images for i2i + prompt: The raw text prompt (without shape info) + height: Target image height in pixels (must be divisible by 32) + width: Target image width in pixels (must be divisible by 32) + image: Optional list of condition images for image-to-image generation Returns: Tuple of (prior_token_ids, pixel_height, pixel_width) - prior_token_ids: Upsampled to d16 format, shape [1, token_h*token_w*4] - - pixel_height: Image height in pixels - - pixel_width: Image width in pixels + - pixel_height: Image height in pixels (aligned to 32) + - pixel_width: Image width in pixels (aligned to 32) + """ device = self.vision_language_encoder.device - - # Parse and expand shape info - expanded_prompt, token_h, token_w, prev_h, prev_w = self._parse_and_expand_shape_info(prompt) - - # Build messages for processor + height = (height // factor) * factor + width = (width // factor) * factor + is_text_to_image = image is None or len(image) == 0 + expanded_prompt, token_h, token_w, prev_h, prev_w = self._build_prompt_with_shape( + prompt, height, width, is_text_to_image + ) content = [] if image is not None: for img in image: content.append({"type": "image", "image": img}) content.append({"type": "text", "text": expanded_prompt}) messages = [{"role": "user", "content": content}] - - # Process inputs inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, - return_tensors="pt", + return_tensors="pt" ) - # Determine if text-to-image or image-to-image existing_grid = inputs.get("image_grid_thw") - is_text_to_image = existing_grid is None or existing_grid.numel() == 0 - - # Build image grid inputs["image_grid_thw"] = self._build_image_grid_thw( token_h, token_w, @@ -378,8 +410,8 @@ def generate_prior_tokens( ) prior_token_ids = self._upsample_d32_to_d16(prior_token_ids_d32, token_h, token_w) - pixel_height = token_h * 32 - pixel_width = token_w * 32 + pixel_height = token_h * factor + pixel_width = token_w * factor return prior_token_ids, pixel_height, pixel_width @@ -683,6 +715,8 @@ def __call__( prior_token_id, ar_height, ar_width = self.generate_prior_tokens( prompt=prompt[0] if isinstance(prompt, list) else prompt, image=ar_condition_images, + height=height, + width=width, ) height = height or ar_height @@ -739,9 +773,7 @@ def 
__call__( .to(self.vae.device, self.vae.dtype) ) empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...] - for condition_image, condition_image_prior_token_id in zip( - image, condition_images_prior_token_id - ): + for condition_image, condition_image_prior_token_id in zip(image, condition_images_prior_token_id): condition_image = condition_image.to(device=device, dtype=self.vae.dtype) condition_latent = retrieve_latents( self.vae.encode(condition_image), generator=generator, sample_mode="argmax" From 170d0ba1603aeaa385cea483d8c7f1892338c5a4 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 00:04:24 +0800 Subject: [PATCH 12/56] remove useless func --- .../pipelines/glm_image/pipeline_glm_image.py | 55 +++---------------- 1 file changed, 7 insertions(+), 48 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index e4dd7de80938..0de4e5db54ec 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -193,31 +193,6 @@ def __init__( else 128 ) - def _parse_and_expand_shape_info(self, prompt: str) -> Tuple[str, int, int, int, int]: - """ - Parse the shape info from prompt and expand it for AR model. - - Args: - prompt: The prompt containing H W shape specification - - Returns: - Tuple of (expanded_prompt, token_h, token_w, prev_token_h, prev_token_w) - """ - match = re.search(r"(\d+)\s+(\d+)", prompt) - if match is None: - raise ValueError(f"Prompt must contain shape info in format 'H W', got: {prompt}") - - token_h, token_w = int(match.group(1)), int(match.group(2)) - ratio = token_h / token_w - prev_token_h = int(sqrt(ratio) * 16) - prev_token_w = int(sqrt(1 / ratio) * 16) - - old_shape = f"{token_h} {token_w}" - new_shape = f"{token_h} {token_w}{prev_token_h} {prev_token_w}" - expanded_prompt = prompt.replace(old_shape, new_shape) - - return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w - def _build_image_grid_thw( self, token_h: int, @@ -227,14 +202,7 @@ def _build_image_grid_thw( existing_grid: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, ) -> torch.Tensor: - """ - Build image grid tensor for AR model. - - For text-to-image: creates grid for large image + small image For image-to-image: appends new image to existing - grid - """ if existing_grid is None or existing_grid.numel() == 0: - # Text-to-image: large image + small image return torch.tensor( [ [1, token_h, token_w], @@ -243,8 +211,7 @@ def _build_image_grid_thw( device=device, ) else: - # Image-to-image: append to existing - return torch.cat([existing_grid, torch.tensor([[1, token_h, token_w]], device=device)], dim=0) + return torch.cat([existing_grid.to(device), torch.tensor([[1, token_h, token_w]], device=device)], dim=0) def _calculate_ar_generation_params( self, token_h: int, token_w: int, prev_token_h: int, prev_token_w: int, is_text_to_image: bool @@ -267,9 +234,6 @@ def _calculate_ar_generation_params( def _extract_large_image_tokens( self, outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int ) -> torch.Tensor: - """ - Extract the large image tokens from AR model output. 
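# Editor's note (worked example, assumed 36x24 / 19x13 grids): in the AR budget above,
# text-to-image generates preview tokens, then the large image, then one extra trailing
# token, and the large image starts right after the preview; image-to-image generates
# only the large image plus the trailing token, so its offset is 0.
token_h, token_w, prev_h, prev_w = 36, 24, 19, 13
large, small = token_h * token_w, prev_h * prev_w              # 864 and 247
assert (small + large + 1, small) == (1112, 247)               # text-to-image budget and offset
assert (large + 1, 0) == (865, 0)                              # image-to-image budget and offset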
- """ generated_tokens = outputs[0][input_length:] large_image_start = large_image_start_offset large_image_end = large_image_start + large_image_tokens @@ -347,15 +311,16 @@ def generate_prior_tokens( Args: prompt: The raw text prompt (without shape info) - height: Target image height in pixels (must be divisible by 32) - width: Target image width in pixels (must be divisible by 32) + height: Target image height in pixels (must be divisible by factor) + width: Target image width in pixels (must be divisible by factor) image: Optional list of condition images for image-to-image generation + factor: Token size factor (32 for d32 tokens) Returns: Tuple of (prior_token_ids, pixel_height, pixel_width) - prior_token_ids: Upsampled to d16 format, shape [1, token_h*token_w*4] - - pixel_height: Image height in pixels (aligned to 32) - - pixel_width: Image width in pixels (aligned to 32) + - pixel_height: Image height in pixels (aligned to factor) + - pixel_width: Image width in pixels (aligned to factor) """ device = self.vision_language_encoder.device @@ -372,11 +337,7 @@ def generate_prior_tokens( content.append({"type": "text", "text": expanded_prompt}) messages = [{"role": "user", "content": content}] inputs = self.processor.apply_chat_template( - messages, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - return_tensors="pt" + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" ) existing_grid = inputs.get("image_grid_thw") @@ -389,13 +350,11 @@ def generate_prior_tokens( device=device, ) - # Calculate generation parameters max_new_tokens, large_image_offset = self._calculate_ar_generation_params( token_h, token_w, prev_h, prev_w, is_text_to_image ) large_image_tokens = token_h * token_w - # Move inputs to device and generate inputs = inputs.to(device) input_length = inputs["input_ids"].shape[-1] From 144c07565937d17578a9d1a96edc0a1e85ed6210 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 15:43:52 +0800 Subject: [PATCH 13/56] Update pipeline_glm_image.py --- .../pipelines/glm_image/pipeline_glm_image.py | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 0de4e5db54ec..4cece9609da7 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -302,7 +302,7 @@ def generate_prior_tokens( width: int, image: Optional[List[PIL.Image.Image]] = None, factor: int = 32, - ) -> Tuple[torch.Tensor, int, int]: + ) -> Tuple[torch.Tensor, torch.Tensor, int, int]: """ Generate prior tokens using the AR (vision_language_encoder) model. 
@@ -364,6 +364,16 @@ def generate_prior_tokens( do_sample=True, ) + ## TODO: transformers not return image_ids so need run again to get image_ids, need optimize + prior_token_image_ids = None + if image is not None: + prior_token_image_embed = self.vision_language_encoder.get_image_features( + inputs["pixel_values"], existing_grid + ) + prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0) + prior_token_image_ids = self.vision_language_encoder.get_image_tokens( + prior_token_image_embed, existing_grid + ) prior_token_ids_d32 = self._extract_large_image_tokens( outputs, input_length, large_image_offset, large_image_tokens ) @@ -372,7 +382,7 @@ def generate_prior_tokens( pixel_height = token_h * factor pixel_width = token_w * factor - return prior_token_ids, pixel_height, pixel_width + return prior_token_ids, prior_token_image_ids, pixel_height, pixel_width def get_glyph_texts(self, prompt): prompt = prompt[0] if isinstance(prompt, list) else prompt @@ -651,29 +661,9 @@ def __call__( device = self._execution_device - ar_condition_images = None - if image is not None: - if not isinstance(image, list): - image = [image] - ar_condition_images = [] - for img in image: - if isinstance(img, PIL.Image.Image): - ar_condition_images.append(img) - elif isinstance(img, torch.Tensor): - img_np = img.cpu().numpy() - if img_np.ndim == 4: - img_np = img_np[0] - if img_np.shape[0] in [1, 3, 4]: - img_np = img_np.transpose(1, 2, 0) - if img_np.max() <= 1.0: - img_np = (img_np * 255).astype(np.uint8) - ar_condition_images.append(PIL.Image.fromarray(img_np)) - else: - ar_condition_images.append(PIL.Image.fromarray(img)) - - prior_token_id, ar_height, ar_width = self.generate_prior_tokens( + prior_token_id, prior_token_image_ids, ar_height, ar_width = self.generate_prior_tokens( prompt=prompt[0] if isinstance(prompt, list) else prompt, - image=ar_condition_images, + image=image, height=height, width=width, ) @@ -681,7 +671,7 @@ def __call__( height = height or ar_height width = width or ar_width - # 3. Encode input prompt + # 4. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, self.do_classifier_free_guidance, @@ -693,10 +683,8 @@ def __call__( ) # 4. process images - condition_images_prior_token_id = None if image is not None: preprocessed_condition_images = [] - condition_images_prior_token_id = [] for img in image: image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2] multiple_of = self.vae_scale_factor * self.transformer.config.patch_size @@ -704,7 +692,11 @@ def __call__( image_width = (image_width // multiple_of) * multiple_of img = self.image_processor.preprocess(img, height=image_height, width=image_width) preprocessed_condition_images.append(img) + height = height or image_height + width = width or image_width image = preprocessed_condition_images + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor # 5. 
Prepare latents and (optional) image kv cache latent_channels = self.transformer.config.in_channels @@ -719,7 +711,7 @@ def __call__( latents=latents, ) - if image is not None and condition_images_prior_token_id is not None: + if image is not None: self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV) latents_mean = ( torch.tensor(self.vae.config.latents_mean) @@ -732,7 +724,7 @@ def __call__( .to(self.vae.device, self.vae.dtype) ) empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...] - for condition_image, condition_image_prior_token_id in zip(image, condition_images_prior_token_id): + for condition_image, condition_image_prior_token_id in zip(image, prior_token_image_ids): condition_image = condition_image.to(device=device, dtype=self.vae.dtype) condition_latent = retrieve_latents( self.vae.encode(condition_image), generator=generator, sample_mode="argmax" From 86f5ce4ebe2e840ddb57afee528f8401ee93ebbf Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 8 Jan 2026 03:15:17 +0100 Subject: [PATCH 14/56] up (cherry picked from commit cfe19a31b9bc14c090c7259c09f3532dfafcd059) --- .../transformers/transformer_glm_image.py | 85 +++++++++++++------ .../pipelines/glm_image/pipeline_glm_image.py | 38 ++++----- 2 files changed, 76 insertions(+), 47 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index 6ff45ac9e8ea..997c19ded935 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -109,11 +108,50 @@ def forward( ) -class GlmImageAttenProcessorState(Enum): - ImageGen = "ImageGen" - ImageEditWriteKV = "ImageEditWriteKV" - ImageEditReadKV = "ImageEditReadKV" - ImageEditDontReadKV = "ImageEditNoReadKV" +class GlmImageLayerKVCache: + """KV cache for GlmImage model.""" + + def __init__(self): + self.k_cache = None + self.v_cache = None + self.mode: Optional[str] = None # "write", "read", "skip" + + def store(self, k: torch.Tensor, v: torch.Tensor): + if self.k_cache is None: + self.k_cache = k + self.v_cache = v + else: + self.k_cache = torch.cat([self.k_cache, k], dim=2) + self.v_cache = torch.cat([self.v_cache, v], dim=2) + + def get(self): + return self.k_cache, self.v_cache + + def clear(self): + self.k_cache = None + self.v_cache = None + self.mode = None + + +class GlmImageKVCache: + """Container for all layers' KV caches.""" + + def __init__(self, num_layers: int): + self.num_layers = num_layers + self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)] + + def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache: + return self.caches[layer_idx] + + def set_mode(self, mode: Optional[str]): + if mode is not None and mode not in ["write", "read", "skip"]: + raise ValueError(f"Invalid mode: {mode}, must be one of 'write', 'read', 'skip'") + for cache in self.caches: + cache.mode = mode + + def clear(self): + for cache in self.caches: + cache.clear() class GlmImageAttnProcessor: @@ -128,9 +166,6 @@ class GlmImageAttnProcessor: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. 
To use it, please upgrade PyTorch to 2.0.") - self.processor_state = GlmImageAttenProcessorState.ImageGen - self.k_cache = None - self.v_cache = None def __call__( self, @@ -139,6 +174,7 @@ def __call__( encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache: Optional[GlmImageLayerKVCache] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: dtype = encoder_hidden_states.dtype @@ -172,12 +208,15 @@ def __call__( key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 ) - if self.processor_state == GlmImageAttenProcessorState.ImageEditWriteKV: - self.k_cache = key if self.k_cache is None else torch.cat([self.k_cache, key], dim=2) - self.v_cache = value if self.v_cache is None else torch.cat([self.v_cache, value], dim=2) - elif self.processor_state == GlmImageAttenProcessorState.ImageEditReadKV: - key = torch.cat([self.k_cache, key], dim=2) if self.k_cache is not None else key - value = torch.cat([self.v_cache, value], dim=2) if self.v_cache is not None else value + if kv_cache is not None: + if kv_cache.mode == "write": + kv_cache.store(key, value) + elif kv_cache.mode == "read": + k_cache, v_cache = kv_cache.get() + key = torch.cat([k_cache, key], dim=2) if k_cache is not None else key + value = torch.cat([v_cache, value], dim=2) if v_cache is not None else value + elif kv_cache.mode == "skip": + pass # 4. Attention if attention_mask is not None: @@ -246,6 +285,7 @@ def forward( ] = None, attention_mask: Optional[Dict[str, torch.Tensor]] = None, attention_kwargs: Optional[Dict[str, Any]] = None, + kv_cache: Optional[GlmImageLayerKVCache] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # 1. Timestep conditioning ( @@ -270,6 +310,7 @@ def forward( encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, attention_mask=attention_mask, + kv_cache=kv_cache, **attention_kwargs, ) hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) @@ -464,6 +505,7 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, attention_mask: Optional[torch.Tensor] = None, + kv_caches: Optional[GlmImageKVCache] = None, image_rotary_emb: Optional[ Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]] ] = None, @@ -491,7 +533,7 @@ def forward( temb = F.silu(temb) # 3. Transformer blocks - for block in self.transformer_blocks: + for idx, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( block, @@ -501,6 +543,7 @@ def forward( image_rotary_emb, attention_mask, attention_kwargs, + kv_caches[idx] if kv_caches is not None else None, ) else: hidden_states, encoder_hidden_states = block( @@ -510,6 +553,7 @@ def forward( image_rotary_emb, attention_mask, attention_kwargs, + kv_cache=kv_caches[idx] if kv_caches is not None else None, ) # 4. 
Output norm & projection @@ -523,12 +567,3 @@ def forward( if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) - - def set_attention_processors_state(self, state: GlmImageAttenProcessorState): - for block in self.transformer_blocks: - block.attn1.processor.processor_state = state - - def clear_attention_processors_cache(self): - for block in self.transformer_blocks: - block.attn1.processor.k_cache = None - block.attn1.processor.v_cache = None diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 4cece9609da7..f933e24af331 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -27,7 +27,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import CogView4LoraLoaderMixin from ...models import AutoencoderKL, GlmImageTransformer2DModel -from ...models.transformers.transformer_glm_image import GlmImageAttenProcessorState +from ...models.transformers.transformer_glm_image import GlmImageKVCache from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring @@ -302,7 +302,7 @@ def generate_prior_tokens( width: int, image: Optional[List[PIL.Image.Image]] = None, factor: int = 32, - ) -> Tuple[torch.Tensor, torch.Tensor, int, int]: + ) -> Tuple[torch.Tensor, int, int]: """ Generate prior tokens using the AR (vision_language_encoder) model. @@ -364,16 +364,6 @@ def generate_prior_tokens( do_sample=True, ) - ## TODO: transformers not return image_ids so need run again to get image_ids, need optimize - prior_token_image_ids = None - if image is not None: - prior_token_image_embed = self.vision_language_encoder.get_image_features( - inputs["pixel_values"], existing_grid - ) - prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0) - prior_token_image_ids = self.vision_language_encoder.get_image_tokens( - prior_token_image_embed, existing_grid - ) prior_token_ids_d32 = self._extract_large_image_tokens( outputs, input_length, large_image_offset, large_image_tokens ) @@ -382,7 +372,7 @@ def generate_prior_tokens( pixel_height = token_h * factor pixel_width = token_w * factor - return prior_token_ids, prior_token_image_ids, pixel_height, pixel_width + return prior_token_ids, pixel_height, pixel_width def get_glyph_texts(self, prompt): prompt = prompt[0] if isinstance(prompt, list) else prompt @@ -671,7 +661,7 @@ def __call__( height = height or ar_height width = width or ar_width - # 4. Encode input prompt + # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, self.do_classifier_free_guidance, @@ -683,8 +673,10 @@ def __call__( ) # 4. 
process images + condition_images_prior_token_id = None if image is not None: preprocessed_condition_images = [] + condition_images_prior_token_id = [] for img in image: image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2] multiple_of = self.vae_scale_factor * self.transformer.config.patch_size @@ -711,8 +703,10 @@ def __call__( latents=latents, ) + kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers) + if image is not None: - self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV) + kv_caches.set_mode("write") latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.latent_channels, 1, 1) @@ -724,7 +718,7 @@ def __call__( .to(self.vae.device, self.vae.dtype) ) empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...] - for condition_image, condition_image_prior_token_id in zip(image, prior_token_image_ids): + for condition_image, condition_image_prior_token_id in zip(image, condition_images_prior_token_id): condition_image = condition_image.to(device=device, dtype=self.vae.dtype) condition_latent = retrieve_latents( self.vae.encode(condition_image), generator=generator, sample_mode="argmax" @@ -739,6 +733,7 @@ def __call__( target_size=torch.tensor([condition_image.shape[-2:]], device=device), crop_coords=torch.zeros((1, 2), device=device), attention_kwargs=attention_kwargs, + kv_caches=kv_caches, ) # 6. Prepare additional timestep conditions @@ -788,7 +783,7 @@ def __call__( timestep = t.expand(latents.shape[0]) - 1 if image is not None: - self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditReadKV) + kv_caches.set_mode("read") noise_pred_cond = self.transformer( hidden_states=latent_model_input, @@ -800,14 +795,13 @@ def __call__( crop_coords=crops_coords_top_left, attention_kwargs=attention_kwargs, return_dict=False, + kv_caches=kv_caches, )[0].float() # perform guidance if self.do_classifier_free_guidance: if image is not None: - self.transformer.set_attention_processors_state( - GlmImageAttenProcessorState.ImageEditDontReadKV - ) + kv_caches.set_mode("skip") noise_pred_uncond = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=negative_prompt_embeds, @@ -818,6 +812,7 @@ def __call__( crop_coords=crops_coords_top_left, attention_kwargs=attention_kwargs, return_dict=False, + kv_caches=kv_caches, )[0].float() noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) @@ -841,8 +836,7 @@ def __call__( xm.mark_step() self._current_timestep = None - self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageGen) - self.transformer.clear_attention_processors_cache() + kv_caches.clear() if not output_type == "latent": latents = latents.to(self.vae.dtype) From c65f2249a106f4c4d62dbc591f52a6c9609d3573 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 16:47:37 +0800 Subject: [PATCH 15/56] review for work only --- .../pipelines/glm_image/pipeline_glm_image.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index f933e24af331..4f65a4943402 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -358,6 +358,16 @@ def generate_prior_tokens( inputs = inputs.to(device) input_length = 
inputs["input_ids"].shape[-1] + prior_token_image_ids = None + if image is not None: + prior_token_image_embed = self.vision_language_encoder.get_image_features( + inputs["pixel_values"], existing_grid + ) + prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0) + prior_token_image_ids = self.vision_language_encoder.get_image_tokens( + prior_token_image_embed, existing_grid + ) + inputs["image_ids"] = prior_token_image_ids outputs = self.vision_language_encoder.generate( **inputs, max_new_tokens=max_new_tokens, @@ -372,7 +382,7 @@ def generate_prior_tokens( pixel_height = token_h * factor pixel_width = token_w * factor - return prior_token_ids, pixel_height, pixel_width + return prior_token_ids, prior_token_image_ids, pixel_height, pixel_width def get_glyph_texts(self, prompt): prompt = prompt[0] if isinstance(prompt, list) else prompt From e70ebc026d97427df0ca05f8ce395292bab573c2 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 17:58:59 +0800 Subject: [PATCH 16/56] change place --- src/diffusers/models/embeddings.py | 31 ---------------- .../transformers/transformer_glm_image.py | 36 +++++++++++++++++-- .../pipelines/glm_image/pipeline_glm_image.py | 23 ++---------- 3 files changed, 36 insertions(+), 54 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 1947e1e53490..37fc412adcc3 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1658,37 +1658,6 @@ def forward( return conditioning -class GlmImageCombinedTimestepSizeEmbeddings(nn.Module): - def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256): - super().__init__() - - self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0) - self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim) - self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") - - def forward( - self, - timestep: torch.Tensor, - target_size: torch.Tensor, - crop_coords: torch.Tensor, - hidden_dtype: torch.dtype, - ) -> torch.Tensor: - timesteps_proj = self.time_proj(timestep) - - crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1) - target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1) - - # (B, 2 * condition_dim) - condition_proj = torch.cat([crop_coords_proj, target_size_proj], dim=1) - - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) - condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) - - conditioning = timesteps_emb + condition_emb - return conditioning - - class HunyuanDiTAttentionPool(nn.Module): # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index 997c19ded935..fea2f8c142d4 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -25,7 +25,7 @@ from ..attention import FeedForward from ..attention_processor import Attention from ..cache_utils import CacheMixin 
-from ..embeddings import GlmImageCombinedTimestepSizeEmbeddings +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import LayerNorm, RMSNorm @@ -34,6 +34,37 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class GlmImageCombinedTimestepSizeEmbeddings(nn.Module): + def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256): + super().__init__() + + self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim) + self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward( + self, + timestep: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + + crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1) + target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1) + + # (B, 2 * condition_dim) + condition_proj = torch.cat([crop_coords_proj, target_size_proj], dim=1) + + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + + conditioning = timesteps_emb + condition_emb + return conditioning + + class GlmImageImageProjector(nn.Module): def __init__( self, @@ -302,8 +333,7 @@ def forward( ) = self.norm1(hidden_states, encoder_hidden_states, temb) # 2. Attention - if attention_kwargs is None: - attention_kwargs = {} + attention_kwargs = attention_kwargs or {} attn_hidden_states, attn_encoder_hidden_states = self.attn1( hidden_states=norm_hidden_states, diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 4f65a4943402..c9830622ae46 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -239,26 +239,11 @@ def _extract_large_image_tokens( large_image_end = large_image_start + large_image_tokens return generated_tokens[large_image_start:large_image_end] - def _upsample_d32_to_d16(self, token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor: - """ - Upsample token IDs from d32 format to d16 format. - - AR model generates tokens at d32 resolution (each token = 32x32 pixels). DiT expects tokens at d16 resolution - (each token = 16x16 pixels). This function performs 2x nearest-neighbor upsampling. 
- - Args: - token_ids: Token IDs of shape [N] where N = token_h * token_w - token_h: Height in d32 token units - token_w: Width in d32 token units - - Returns: - Upsampled token IDs of shape [1, N*4] where N*4 = (token_h*2) * (token_w*2) - """ + def _upsample_token_ids(self, token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor: token_ids = token_ids.view(1, 1, token_h, token_w) token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to( dtype=torch.long ) - token_ids = token_ids.view(1, -1) return token_ids @@ -377,7 +362,7 @@ def generate_prior_tokens( prior_token_ids_d32 = self._extract_large_image_tokens( outputs, input_length, large_image_offset, large_image_tokens ) - prior_token_ids = self._upsample_d32_to_d16(prior_token_ids_d32, token_h, token_w) + prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w) pixel_height = token_h * factor pixel_width = token_w * factor @@ -683,10 +668,8 @@ def __call__( ) # 4. process images - condition_images_prior_token_id = None if image is not None: preprocessed_condition_images = [] - condition_images_prior_token_id = [] for img in image: image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2] multiple_of = self.vae_scale_factor * self.transformer.config.patch_size @@ -728,7 +711,7 @@ def __call__( .to(self.vae.device, self.vae.dtype) ) empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...] - for condition_image, condition_image_prior_token_id in zip(image, condition_images_prior_token_id): + for condition_image, condition_image_prior_token_id in zip(image, prior_token_image_ids): condition_image = condition_image.to(device=device, dtype=self.vae.dtype) condition_latent = retrieve_latents( self.vae.encode(condition_image), generator=generator, sample_mode="argmax" From 762f9a303042772d17b9613d89f1775ee06e4d87 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 18:09:17 +0800 Subject: [PATCH 17/56] Update pipeline_glm_image.py --- src/diffusers/pipelines/glm_image/pipeline_glm_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index c9830622ae46..0ff7b507163b 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -352,7 +352,7 @@ def generate_prior_tokens( prior_token_image_ids = self.vision_language_encoder.get_image_tokens( prior_token_image_embed, existing_grid ) - inputs["image_ids"] = prior_token_image_ids + outputs = self.vision_language_encoder.generate( **inputs, max_new_tokens=max_new_tokens, From 5a0a9fa2e2a4948781ec5ca92e0c576d168947be Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 18:19:46 +0800 Subject: [PATCH 18/56] update --- .../transformers/transformer_glm_image.py | 3 +- .../pipelines/glm_image/pipeline_glm_image.py | 40 ++++++++++++------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index fea2f8c142d4..b73d658ba608 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -62,6 +62,8 @@ def forward( condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, 
embedding_dim) conditioning = timesteps_emb + condition_emb + conditioning = F.silu(conditioning) + return conditioning @@ -560,7 +562,6 @@ def forward( hidden_states = hidden_states + prior_hidden_states temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype) - temb = F.silu(temb) # 3. Transformer blocks for idx, block in enumerate(self.transformer_blocks): diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 0ff7b507163b..aa6689dd4c8b 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -53,7 +53,7 @@ >>> pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") - >>> prompt = "A photo of an astronaut riding a horse on mars36 24" + >>> prompt = "A photo of an astronaut riding a horse on mars" >>> image = pipe(prompt).images[0] >>> image.save("output.png") ``` @@ -71,6 +71,7 @@ def calculate_shift( return mu +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, @@ -158,7 +159,7 @@ class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin): """ _optional_components = [] - model_cpu_offload_seq = "transformer->vae" + model_cpu_offload_seq = "text_encoder->vision_language_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( @@ -460,19 +461,6 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype) seq_len = negative_prompt_embeds.size(1) @@ -505,6 +493,7 @@ def check_inputs( height, width, callback_on_step_end_tensor_inputs, + do_classifier_free_guidance, prompt_embeds=None, ): if ( @@ -536,6 +525,26 @@ def check_inputs( elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if do_classifier_free_guidance: + negative_prompt = "" + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + @property def guidance_scale(self): return self._guidance_scale @@ -629,6 +638,7 @@ def __call__( height, width, callback_on_step_end_tensor_inputs, + self.do_classifier_free_guidance, prompt_embeds, ) self._guidance_scale = guidance_scale From 2ae574aee86a4bf976da296d6b510dd4305e2ea8 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 18:22:58 +0800 Subject: [PATCH 19/56] Update transformer_glm_image.py --- src/diffusers/models/transformers/transformer_glm_image.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index b73d658ba608..8f3c748dc340 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -593,6 +593,8 @@ def forward( # 5. Unpatchify hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p) + + # Rearrange tensor from (B, H_p, W_p, C, p, p) to (B, C, H_p * p, W_p * p) output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) if not return_dict: From 264f930201d236a1e294412d0afd62f4a903e30c Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 18:29:06 +0800 Subject: [PATCH 20/56] 1 --- src/diffusers/pipelines/glm_image/pipeline_glm_image.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index aa6689dd4c8b..b1cb5bfd497e 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -526,11 +526,12 @@ def check_inputs( raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if do_classifier_free_guidance: - negative_prompt = "" if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] + + negative_prompt = None # Not used in GLM-Image negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt if prompt is not None and type(prompt) is not type(negative_prompt): @@ -633,6 +634,8 @@ def __call__( callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 1. 
Check inputs + self._guidance_scale = guidance_scale + self.check_inputs( prompt, height, From e9b2c89b6f9b1cd6f60fe7423dd13694868d74de Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 18:33:09 +0800 Subject: [PATCH 21/56] no negative_prompt for GLM-Image --- .../pipelines/glm_image/pipeline_glm_image.py | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index b1cb5bfd497e..77c88851fd27 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -493,7 +493,6 @@ def check_inputs( height, width, callback_on_step_end_tensor_inputs, - do_classifier_free_guidance, prompt_embeds=None, ): if ( @@ -525,27 +524,6 @@ def check_inputs( elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if do_classifier_free_guidance: - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - negative_prompt = None # Not used in GLM-Image - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - @property def guidance_scale(self): return self._guidance_scale @@ -634,14 +612,11 @@ def __call__( callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 1. Check inputs - self._guidance_scale = guidance_scale - self.check_inputs( prompt, height, width, callback_on_step_end_tensor_inputs, - self.do_classifier_free_guidance, prompt_embeds, ) self._guidance_scale = guidance_scale From e4f6549fca392417b63e0903969781cf64783456 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 18:38:34 +0800 Subject: [PATCH 22/56] remove CogView4LoraLoaderMixin --- src/diffusers/pipelines/glm_image/pipeline_glm_image.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 77c88851fd27..f99c36e867c0 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -25,7 +25,6 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor -from ...loaders import CogView4LoraLoaderMixin from ...models import AutoencoderKL, GlmImageTransformer2DModel from ...models.transformers.transformer_glm_image import GlmImageKVCache from ...pipelines.pipeline_utils import DiffusionPipeline @@ -134,7 +133,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin): +class GlmImagePipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using GLM-Image. 
@@ -705,6 +704,10 @@ def __call__( self.vae.encode(condition_image), generator=generator, sample_mode="argmax" ) condition_latent = (condition_latent - latents_mean) / latents_std + + # Do not remove. + # It would be use to run the reference image through a + # forward pass at timestep 0 and keep the KV cache. _ = self.transformer( hidden_states=condition_latent, encoder_hidden_states=empty_glyph_hiddens, From 51f801504627177c18e37ae30f25774c007cffc9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 8 Jan 2026 17:29:08 +0530 Subject: [PATCH 23/56] refactor attention processor. --- .../transformers/transformer_glm_image.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index 8f3c748dc340..bacc0e154f5a 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -24,6 +24,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import Attention +from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput @@ -196,6 +197,9 @@ class GlmImageAttnProcessor: text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token. """ + _attention_backend = None + _parallel_config = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") @@ -262,11 +266,16 @@ def __call__( attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2) attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - hidden_states = hidden_states.type_as(query) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) # 5. Output projection hidden_states = attn.to_out[0](hidden_states) From 075b6a952eb8bbb9af67c8d8f882505fd93b11dc Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 20:32:01 +0800 Subject: [PATCH 24/56] update --- docs/source/en/api/pipelines/glm_image.md | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/glm_image.md b/docs/source/en/api/pipelines/glm_image.md index c3787cd77b37..d2b0777fee6b 100644 --- a/docs/source/en/api/pipelines/glm_image.md +++ b/docs/source/en/api/pipelines/glm_image.md @@ -15,10 +15,26 @@ # GLM-Image -> [!TIP] -> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. +## Overview -This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). 
The original codebase can be found [here](https://huggingface.co/zai-org). The original weights can be found under [hf.co/zai-org](https://huggingface.co/zai-org). +GLM-Image is an image generation model adopts a hybrid autoregressive + diffusion decoder architecture, effectively pushing the upper bound of visual fidelity and fine-grained details. In general image generation quality, it aligns with industry-standard LDM-based approaches, while demonstrating significant advantages in knowledge-intensive image generation scenarios. + +Model architecture: a hybrid autoregressive + diffusion decoder design、 + ++ Autoregressive generator: a 9B-parameter model initialized from [GLM-4-9B-0414](https://huggingface.co/zai-org/GLM-4-9B-0414), with an expanded vocabulary to incorporate visual tokens. The model first generates a compact encoding of approximately 256 tokens, then expands to 1K–4K tokens, corresponding to 1K–2K high-resolution image outputs. You can check AR model in class `GlmImageForConditionalGeneration` of transformers library. ++ Diffusion Decoder: a 7B-parameter decoder based on a single-stream DiT architecture for latent-space image decoding. It is equipped with a Glyph Encoder text module, significantly improving accurate text rendering within images. + +Post-training with decoupled reinforcement learning: the model introduces a fine-grained, modular feedback strategy using the GRPO algorithm, substantially enhancing both semantic understanding and visual detail quality. + ++ Autoregressive module: provides low-frequency feedback signals focused on aesthetics and semantic alignment, improving instruction following and artistic expressiveness. ++ Decoder module: delivers high-frequency feedback targeting detail fidelity and text accuracy, resulting in highly realistic textures, lighting, and color reproduction, as well as more precise text rendering. + +GLM-Image supports both text-to-image and image-to-image generation within a single model + ++ Text-to-image: generates high-detail images from textual descriptions, with particularly strong performance in information-dense scenarios. ++ Image-to-image: supports a wide range of tasks, including image editing, style transfer, multi-subject consistency, and identity-preserving generation for people and objects. + +This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The codebase can be found [here](https://huggingface.co/zai-org/GLM-Image). ## GlmImagePipeline From e2d4bda5c54b7cb5c11907a17dbfb1e3cb6d0e99 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 8 Jan 2026 18:05:10 +0530 Subject: [PATCH 25/56] fix --- src/diffusers/models/transformers/transformer_glm_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index bacc0e154f5a..441124de6ed8 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -274,7 +274,7 @@ def __call__( backend=self._attention_backend, parallel_config=self._parallel_config, ) - hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.to(query.dtype) # 5. 
Output projection From 854e861c393b620348604bd88c7c8e0aea5a8e25 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 20:42:46 +0800 Subject: [PATCH 26/56] use staticmethod --- .../pipelines/glm_image/pipeline_glm_image.py | 56 ++++--------------- 1 file changed, 12 insertions(+), 44 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index f99c36e867c0..5f738e656f3f 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -141,16 +141,16 @@ class GlmImagePipeline(DiffusionPipeline): transformer) model for image decoding. Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`T5EncoderModel`]): - Frozen text-encoder for glyph embeddings. tokenizer (`PreTrainedTokenizer`): Tokenizer for the text encoder. processor (`AutoProcessor`): Processor for the AR model to handle chat templates and tokenization. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder for glyph embeddings. vision_language_encoder ([`GlmImageForConditionalGeneration`]): The AR model that generates image tokens from text prompts. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. transformer ([`GlmImageTransformer2DModel`]): A text conditioned transformer to denoise the encoded image latents (DiT). scheduler ([`SchedulerMixin`]): @@ -193,8 +193,8 @@ def __init__( else 128 ) + @staticmethod def _build_image_grid_thw( - self, token_h: int, token_w: int, prev_token_h: int, @@ -213,12 +213,10 @@ def _build_image_grid_thw( else: return torch.cat([existing_grid.to(device), torch.tensor([[1, token_h, token_w]], device=device)], dim=0) + @staticmethod def _calculate_ar_generation_params( - self, token_h: int, token_w: int, prev_token_h: int, prev_token_w: int, is_text_to_image: bool + token_h: int, token_w: int, prev_token_h: int, prev_token_w: int, is_text_to_image: bool ) -> Tuple[int, int]: - """ - Calculate max_new_tokens and large_image_start_offset for AR generation. 
- """ large_image_tokens = token_h * token_w small_image_tokens = prev_token_h * prev_token_w @@ -231,15 +229,17 @@ def _calculate_ar_generation_params( return max_new_tokens, large_image_start_offset + @staticmethod def _extract_large_image_tokens( - self, outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int + outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int ) -> torch.Tensor: generated_tokens = outputs[0][input_length:] large_image_start = large_image_start_offset large_image_end = large_image_start + large_image_tokens return generated_tokens[large_image_start:large_image_end] - def _upsample_token_ids(self, token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor: + @staticmethod + def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor: token_ids = token_ids.view(1, 1, token_h, token_w) token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to( dtype=torch.long @@ -247,26 +247,14 @@ def _upsample_token_ids(self, token_ids: torch.Tensor, token_h: int, token_w: in token_ids = token_ids.view(1, -1) return token_ids + @staticmethod def _build_prompt_with_shape( - self, prompt: str, height: int, width: int, is_text_to_image: bool, factor: int = 32, ) -> Tuple[str, int, int, int, int]: - """ - Build prompt with shape info (H W) based on height and width. - - Args: - prompt: The raw text prompt without shape info - height: Target image height in pixels - width: Target image width in pixels - is_text_to_image: Whether this is text-to-image (True) or image-to-image (False) - - Returns: - Tuple of (expanded_prompt, token_h, token_w, prev_token_h, prev_token_w) - """ token_h = height // factor token_w = width // factor ratio = token_h / token_w @@ -288,26 +276,6 @@ def generate_prior_tokens( image: Optional[List[PIL.Image.Image]] = None, factor: int = 32, ) -> Tuple[torch.Tensor, int, int]: - """ - Generate prior tokens using the AR (vision_language_encoder) model. - - Automatically builds the prompt with shape info based on height/width. Users only need to provide the raw text - prompt without ... tags. 
- - Args: - prompt: The raw text prompt (without shape info) - height: Target image height in pixels (must be divisible by factor) - width: Target image width in pixels (must be divisible by factor) - image: Optional list of condition images for image-to-image generation - factor: Token size factor (32 for d32 tokens) - - Returns: - Tuple of (prior_token_ids, pixel_height, pixel_width) - - prior_token_ids: Upsampled to d16 format, shape [1, token_h*token_w*4] - - pixel_height: Image height in pixels (aligned to factor) - - pixel_width: Image width in pixels (aligned to factor) - - """ device = self.vision_language_encoder.device height = (height // factor) * factor width = (width // factor) * factor From 786221729e4d47400ccc931f49cd588f2c3bbce8 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 20:52:40 +0800 Subject: [PATCH 27/56] update --- .../pipelines/glm_image/pipeline_glm_image.py | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 5f738e656f3f..e27fe4ef130c 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -645,7 +645,7 @@ def __call__( num_channels_latents=latent_channels, height=height, width=width, - dtype=torch.float32, + dtype=prompt_embeds.dtype, device=device, generator=generator, latents=latents, @@ -655,19 +655,14 @@ def __call__( if image is not None: kv_caches.set_mode("write") - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.latent_channels, 1, 1) - .to(self.vae.device, self.vae.dtype) - ) - latents_std = ( - torch.tensor(self.vae.config.latents_std) - .view(1, self.vae.config.latent_channels, 1, 1) - .to(self.vae.device, self.vae.dtype) - ) - empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...] + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.latent_channels, 1, 1) + latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.latent_channels, 1, 1) + + latents_mean = latents_mean.to(device=device, dtype=prompt_embeds.dtype) + latents_std = latents_std.to(device=device, dtype=prompt_embeds.dtype) + for condition_image, condition_image_prior_token_id in zip(image, prior_token_image_ids): - condition_image = condition_image.to(device=device, dtype=self.vae.dtype) + condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype) condition_latent = retrieve_latents( self.vae.encode(condition_image), generator=generator, sample_mode="argmax" ) @@ -678,7 +673,7 @@ def __call__( # forward pass at timestep 0 and keep the KV cache. 
_ = self.transformer( hidden_states=condition_latent, - encoder_hidden_states=empty_glyph_hiddens, + encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...], prior_token_id=condition_image_prior_token_id, prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool), timestep=torch.zeros((1,), device=device), From 1226fcbb81ceac0fe9a3295e8c1373eff8576890 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 8 Jan 2026 18:51:11 +0530 Subject: [PATCH 28/56] up --- .../transformers/transformer_glm_image.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index 441124de6ed8..1a53fac8b58c 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -224,9 +224,9 @@ def __call__( key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) - query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) - key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) - value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) # 2. QK normalization if attn.norm_q is not None: @@ -238,11 +238,11 @@ def __call__( if image_rotary_emb is not None: from ..embeddings import apply_rotary_emb - query[:, :, text_seq_length:, :] = apply_rotary_emb( - query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 + query[:, text_seq_length:, :, :] = apply_rotary_emb( + query[:, text_seq_length:, :, :], image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2 ) - key[:, :, text_seq_length:, :] = apply_rotary_emb( - key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 + key[:, text_seq_length:, :, :] = apply_rotary_emb( + key[:, text_seq_length:, :, :], image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2 ) if kv_cache is not None: @@ -271,11 +271,13 @@ def __call__( key, value, attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, backend=self._attention_backend, parallel_config=self._parallel_config, ) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) # 5. 
Output projection hidden_states = attn.to_out[0](hidden_states) From 68ebb42ca80619b7717b959c81e8413cda87c7a3 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 8 Jan 2026 19:05:27 +0530 Subject: [PATCH 29/56] up --- src/diffusers/models/transformers/transformer_glm_image.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index 1a53fac8b58c..726410dbe1f0 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -250,8 +250,8 @@ def __call__( kv_cache.store(key, value) elif kv_cache.mode == "read": k_cache, v_cache = kv_cache.get() - key = torch.cat([k_cache, key], dim=2) if k_cache is not None else key - value = torch.cat([v_cache, value], dim=2) if v_cache is not None else value + key = torch.cat([k_cache, key], dim=1) if k_cache is not None else key + value = torch.cat([v_cache, value], dim=1) if v_cache is not None else value elif kv_cache.mode == "skip": pass @@ -277,7 +277,7 @@ def __call__( parallel_config=self._parallel_config, ) hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.type_as(query) + hidden_states = hidden_states.to(query.dtype) # 5. Output projection hidden_states = attn.to_out[0](hidden_states) From 40559ca49396c540cdd9aa1a0806030155d1b75a Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 22:14:20 +0800 Subject: [PATCH 30/56] update --- docs/source/en/api/pipelines/glm_image.md | 55 +++++++++++++++++++ .../transformers/transformer_glm_image.py | 2 +- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/glm_image.md b/docs/source/en/api/pipelines/glm_image.md index d2b0777fee6b..a394ba59e607 100644 --- a/docs/source/en/api/pipelines/glm_image.md +++ b/docs/source/en/api/pipelines/glm_image.md @@ -36,6 +36,61 @@ GLM-Image supports both text-to-image and image-to-image generation within a sin This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The codebase can be found [here](https://huggingface.co/zai-org/GLM-Image). 
+## Usage examples + +### Text to Image Generation + +```python +import torch +from diffusers.pipelines.glm_image import GlmImagePipeline +def main(): + pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda") + prompt = "现代美食杂志风格的甜点制作教程图,主题为覆盆子慕斯蛋糕。整体布局干净明亮,分为四个主要区域:顶部左侧是黑色粗体标题“覆盆子慕斯蛋糕制作指南”,右侧搭配光线柔和的成品蛋糕特写照片,蛋糕呈淡粉色,表面点缀新鲜覆盆子与薄荷叶;左下方为配料清单区域,标题“配料”使用简洁字体,下方列有“面粉 150g”“鸡蛋 3个”“细砂糖 120g”“覆盆子果泥 200g”“明胶片 10g”“淡奶油 300ml”“新鲜覆盆子”等配料,每种配料旁配有简约线图标(如面粉袋、鸡蛋、糖罐等);右下方是四个等大的步骤方框,每个方框内含高清微距实拍图及对应操作说明,从上到下依次为:步骤1展示打蛋器打发白色泡沫(对应说明“打发蛋白至干性发泡”),步骤2展示红白相间的混合物被刮刀翻拌(对应说明“轻柔翻拌果泥与面糊”),步骤3展示粉色液体被倒入圆形模具(对应说明“倒入模具并冷藏4小时”),步骤4展示成品蛋糕表面装饰覆盆子与薄荷叶(对应说明“用覆盆子和薄荷装饰”);底部边缘设浅棕色信息条,左侧图标分别代表“准备时间:30分钟”“烹饪时间:20分钟”“份量:8人份”。整体色调以奶油白、淡粉色为主,背景带轻微纸质纹理,图文排版紧凑有序,信息层级分明。" + image = pipe( + prompt=prompt, + height=89 * 32, + width=45 * 32, + num_inference_steps=30, + guidance_scale=1.5, + generator=torch.Generator(device="cuda").manual_seed(42), + ).images[0] + + image.save("output_t2i.png") + +if __name__ == "__main__": + main() +``` + +### Image to Image Generation + +```python +import torch +from diffusers.pipelines.glm_image import GlmImagePipeline +from PIL import Image + +def main(): + pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda") + image_path = "cond.jpg" + prompt = "Replace the background of the snow forest with an underground station featuring an automatic escalator." + image = Image.open(image_path).convert("RGB") + image = pipe( + prompt=prompt, + image=[image], # can input multiple images for multi-image-to-image generation such as [image, image1] + height=33 * 32, + width=32 * 32, + num_inference_steps=30, + guidance_scale=1.5, + generator=torch.Generator(device="cuda").manual_seed(42), + ).images[0] + + image.save("output_t2i.png") + +if __name__ == "__main__": + main() +``` + ++ Since the AR model used in GLM-Image is configured with `do_sample=True` and a temperature of `0.95` by default, the generated images can vary significantly across runs. We do not recommend setting do_sample=False, as this may lead to incorrect or degenerate outputs from the AR model. 
+ ## GlmImagePipeline [[autodoc]] GlmImagePipeline diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index 726410dbe1f0..093472ac6608 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -23,8 +23,8 @@ from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward -from ..attention_processor import Attention from ..attention_dispatch import dispatch_attention_fn +from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from ..modeling_outputs import Transformer2DModelOutput From 19fc76b70a780bca04b62eea125025457d56c52f Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 8 Jan 2026 23:13:36 +0800 Subject: [PATCH 31/56] Update glm_image.md --- docs/source/en/api/pipelines/glm_image.md | 63 ++++++++++------------- 1 file changed, 28 insertions(+), 35 deletions(-) diff --git a/docs/source/en/api/pipelines/glm_image.md b/docs/source/en/api/pipelines/glm_image.md index a394ba59e607..98e0145b6807 100644 --- a/docs/source/en/api/pipelines/glm_image.md +++ b/docs/source/en/api/pipelines/glm_image.md @@ -43,22 +43,19 @@ This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRz ```python import torch from diffusers.pipelines.glm_image import GlmImagePipeline -def main(): - pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda") - prompt = "现代美食杂志风格的甜点制作教程图,主题为覆盆子慕斯蛋糕。整体布局干净明亮,分为四个主要区域:顶部左侧是黑色粗体标题“覆盆子慕斯蛋糕制作指南”,右侧搭配光线柔和的成品蛋糕特写照片,蛋糕呈淡粉色,表面点缀新鲜覆盆子与薄荷叶;左下方为配料清单区域,标题“配料”使用简洁字体,下方列有“面粉 150g”“鸡蛋 3个”“细砂糖 120g”“覆盆子果泥 200g”“明胶片 10g”“淡奶油 300ml”“新鲜覆盆子”等配料,每种配料旁配有简约线图标(如面粉袋、鸡蛋、糖罐等);右下方是四个等大的步骤方框,每个方框内含高清微距实拍图及对应操作说明,从上到下依次为:步骤1展示打蛋器打发白色泡沫(对应说明“打发蛋白至干性发泡”),步骤2展示红白相间的混合物被刮刀翻拌(对应说明“轻柔翻拌果泥与面糊”),步骤3展示粉色液体被倒入圆形模具(对应说明“倒入模具并冷藏4小时”),步骤4展示成品蛋糕表面装饰覆盆子与薄荷叶(对应说明“用覆盆子和薄荷装饰”);底部边缘设浅棕色信息条,左侧图标分别代表“准备时间:30分钟”“烹饪时间:20分钟”“份量:8人份”。整体色调以奶油白、淡粉色为主,背景带轻微纸质纹理,图文排版紧凑有序,信息层级分明。" - image = pipe( - prompt=prompt, - height=89 * 32, - width=45 * 32, - num_inference_steps=30, - guidance_scale=1.5, - generator=torch.Generator(device="cuda").manual_seed(42), - ).images[0] - - image.save("output_t2i.png") - -if __name__ == "__main__": - main() + +pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda") +prompt = "A beautifully designed modern food magazine style dessert recipe illustration, themed around a raspberry mousse cake. 
The overall layout is clean and bright, divided into four main areas: the top left features a bold black title 'Raspberry Mousse Cake Recipe Guide', with a soft-lit close-up photo of the finished cake on the right, showcasing a light pink cake adorned with fresh raspberries and mint leaves; the bottom left contains an ingredient list section, titled 'Ingredients' in a simple font, listing 'Flour 150g', 'Eggs 3', 'Sugar 120g', 'Raspberry puree 200g', 'Gelatin sheets 10g', 'Whipping cream 300ml', and 'Fresh raspberries', each accompanied by minimalist line icons (like a flour bag, eggs, sugar jar, etc.); the bottom right displays four equally sized step boxes, each containing high-definition macro photos and corresponding instructions, arranged from top to bottom as follows: Step 1 shows a whisk whipping white foam (with the instruction 'Whip egg whites to stiff peaks'), Step 2 shows a red-and-white mixture being folded with a spatula (with the instruction 'Gently fold in the puree and batter'), Step 3 shows pink liquid being poured into a round mold (with the instruction 'Pour into mold and chill for 4 hours'), Step 4 shows the finished cake decorated with raspberries and mint leaves (with the instruction 'Decorate with raspberries and mint'); a light brown information bar runs along the bottom edge, with icons on the left representing 'Preparation time: 30 minutes', 'Cooking time: 20 minutes', and 'Servings: 8'. The overall color scheme is dominated by creamy white and light pink, with a subtle paper texture in the background, featuring compact and orderly text and image layout with clear information hierarchy." +image = pipe( + prompt=prompt, + height=32 * 32, + width=36 * 32, + num_inference_steps=30, + guidance_scale=1.5, + generator=torch.Generator(device="cuda").manual_seed(42), +).images[0] + +image.save("output_t2i.png") ``` ### Image to Image Generation @@ -68,25 +65,21 @@ import torch from diffusers.pipelines.glm_image import GlmImagePipeline from PIL import Image -def main(): - pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda") - image_path = "cond.jpg" - prompt = "Replace the background of the snow forest with an underground station featuring an automatic escalator." - image = Image.open(image_path).convert("RGB") - image = pipe( - prompt=prompt, - image=[image], # can input multiple images for multi-image-to-image generation such as [image, image1] - height=33 * 32, - width=32 * 32, - num_inference_steps=30, - guidance_scale=1.5, - generator=torch.Generator(device="cuda").manual_seed(42), - ).images[0] - - image.save("output_t2i.png") - -if __name__ == "__main__": - main() +pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda") +image_path = "cond.jpg" +prompt = "Replace the background of the snow forest with an underground station featuring an automatic escalator." +image = Image.open(image_path).convert("RGB") +image = pipe( + prompt=prompt, + image=[image], # can input multiple images for multi-image-to-image generation such as [image, image1] + height=33 * 32, + width=32 * 32, + num_inference_steps=30, + guidance_scale=1.5, + generator=torch.Generator(device="cuda").manual_seed(42), +).images[0] + +image.save("output_i2i.png") ``` + Since the AR model used in GLM-Image is configured with `do_sample=True` and a temperature of `0.95` by default, the generated images can vary significantly across runs. 
We do not recommend setting do_sample=False, as this may lead to incorrect or degenerate outputs from the AR model. From d2a514631ab490b2f3d7c7b40f10e3c5f75c9413 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Fri, 9 Jan 2026 14:45:49 +0800 Subject: [PATCH 32/56] 1 --- .../models/transformers/transformer_glm_image.py | 1 + .../pipelines/glm_image/pipeline_glm_image.py | 15 +++++---------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index 093472ac6608..f4b5b819ede5 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -483,6 +483,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach "GlmImageImageProjector", ] _skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"] + _skip_keys = ["kv_caches"] @register_to_config def __init__( diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index e27fe4ef130c..0568685ca286 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -158,7 +158,7 @@ class GlmImagePipeline(DiffusionPipeline): """ _optional_components = [] - model_cpu_offload_seq = "text_encoder->vision_language_encoder->transformer->vae" + model_cpu_offload_seq = "vision_language_encoder->text_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( @@ -332,10 +332,7 @@ def generate_prior_tokens( ) prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w) - pixel_height = token_h * factor - pixel_width = token_w * factor - - return prior_token_ids, prior_token_image_ids, pixel_height, pixel_width + return prior_token_ids, prior_token_image_ids def get_glyph_texts(self, prompt): prompt = prompt[0] if isinstance(prompt, list) else prompt @@ -597,20 +594,18 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - assert batch_size == 1, "batch_size must be 1" + if batch_size != 1: + raise ValueError(f"batch_size must be 1 due to AR model limitations, got {batch_size}") device = self._execution_device - prior_token_id, prior_token_image_ids, ar_height, ar_width = self.generate_prior_tokens( + prior_token_id, prior_token_image_ids = self.generate_prior_tokens( prompt=prompt[0] if isinstance(prompt, list) else prompt, image=image, height=height, width=width, ) - height = height or ar_height - width = width or ar_width - # 3. 
Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, From 6cfc83b4abc5b083fef56a18ec4700f48ba3aaba Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Fri, 9 Jan 2026 15:10:20 +0800 Subject: [PATCH 33/56] Update pipeline_glm_image.py --- src/diffusers/pipelines/glm_image/pipeline_glm_image.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 0568685ca286..1c706208d3de 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -291,8 +291,7 @@ def generate_prior_tokens( messages = [{"role": "user", "content": content}] inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" - ) - + ).to(device) existing_grid = inputs.get("image_grid_thw") inputs["image_grid_thw"] = self._build_image_grid_thw( token_h, @@ -307,8 +306,6 @@ def generate_prior_tokens( token_h, token_w, prev_h, prev_w, is_text_to_image ) large_image_tokens = token_h * token_w - - inputs = inputs.to(device) input_length = inputs["input_ids"].shape[-1] prior_token_image_ids = None @@ -320,7 +317,6 @@ def generate_prior_tokens( prior_token_image_ids = self.vision_language_encoder.get_image_tokens( prior_token_image_embed, existing_grid ) - outputs = self.vision_language_encoder.generate( **inputs, max_new_tokens=max_new_tokens, From d1373a9e4f7cdc5b060efa445f0fb9bff09c93b0 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Mon, 12 Jan 2026 15:55:00 +0800 Subject: [PATCH 34/56] Update transformer_glm_image.py --- .../models/transformers/transformer_glm_image.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index f4b5b819ede5..30db3dc82968 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -155,11 +155,13 @@ def store(self, k: torch.Tensor, v: torch.Tensor): self.k_cache = k self.v_cache = v else: - self.k_cache = torch.cat([self.k_cache, k], dim=2) - self.v_cache = torch.cat([self.v_cache, v], dim=2) + self.k_cache = torch.cat([self.k_cache, k], dim=1) + self.v_cache = torch.cat([self.v_cache, v], dim=1) - def get(self): - return self.k_cache, self.v_cache + def get(self, k: torch.Tensor, v: torch.Tensor): + k_cache = torch.cat([self.k_cache, k], dim=1) + v_cache = torch.cat([self.v_cache, v], dim=1) + return k_cache, v_cache def clear(self): self.k_cache = None @@ -249,9 +251,7 @@ def __call__( if kv_cache.mode == "write": kv_cache.store(key, value) elif kv_cache.mode == "read": - k_cache, v_cache = kv_cache.get() - key = torch.cat([k_cache, key], dim=1) if k_cache is not None else key - value = torch.cat([v_cache, value], dim=1) if v_cache is not None else value + key, value = kv_cache.get(key, value) elif kv_cache.mode == "skip": pass From 961fd795b22c40e54771009e8393c14cfaa1090e Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Mon, 12 Jan 2026 20:19:26 +0800 Subject: [PATCH 35/56] using new transformers impl --- .../pipelines/glm_image/pipeline_glm_image.py | 114 ++++++------------ 1 file changed, 35 insertions(+), 79 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py 
b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 1c706208d3de..cafa610834ac 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -15,7 +15,6 @@ import inspect import re -from math import sqrt from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -194,40 +193,28 @@ def __init__( ) @staticmethod - def _build_image_grid_thw( - token_h: int, - token_w: int, - prev_token_h: int, - prev_token_w: int, - existing_grid: Optional[torch.Tensor] = None, - device: Optional[torch.device] = None, - ) -> torch.Tensor: - if existing_grid is None or existing_grid.numel() == 0: - return torch.tensor( - [ - [1, token_h, token_w], - [1, prev_token_h, prev_token_w], - ], - device=device, - ) - else: - return torch.cat([existing_grid.to(device), torch.tensor([[1, token_h, token_w]], device=device)], dim=0) + def _compute_generation_params( + image_grid_thw, + is_text_to_image: bool, + ): + grid_sizes = [] + grid_hw = [] - @staticmethod - def _calculate_ar_generation_params( - token_h: int, token_w: int, prev_token_h: int, prev_token_w: int, is_text_to_image: bool - ) -> Tuple[int, int]: - large_image_tokens = token_h * token_w - small_image_tokens = prev_token_h * prev_token_w - - if is_text_to_image: - max_new_tokens = small_image_tokens + large_image_tokens + 1 - large_image_start_offset = small_image_tokens - else: - max_new_tokens = large_image_tokens + 1 - large_image_start_offset = 0 + for i in range(image_grid_thw.shape[0]): + t, h, w = image_grid_thw[i].tolist() + grid_sizes.append(int(h * w)) + grid_hw.append((int(h), int(w))) - return max_new_tokens, large_image_start_offset + if not is_text_to_image: + max_new_tokens = grid_sizes[-1] + 1 + large_image_start_offset = 0 + target_grid_h, target_grid_w = grid_hw[-1] + else: + total_tokens = sum(grid_sizes) + max_new_tokens = total_tokens + 1 + large_image_start_offset = sum(grid_sizes[1:]) + target_grid_h, target_grid_w = grid_hw[0] + return max_new_tokens, large_image_start_offset, target_grid_h, target_grid_w @staticmethod def _extract_large_image_tokens( @@ -247,75 +234,44 @@ def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> token_ids = token_ids.view(1, -1) return token_ids - @staticmethod - def _build_prompt_with_shape( - prompt: str, - height: int, - width: int, - is_text_to_image: bool, - factor: int = 32, - ) -> Tuple[str, int, int, int, int]: - token_h = height // factor - token_w = width // factor - ratio = token_h / token_w - prev_token_h = int(sqrt(ratio) * (factor // 2)) - prev_token_w = int(sqrt(1 / ratio) * (factor // 2)) - - if is_text_to_image: - expanded_prompt = f"{prompt}{token_h} {token_w}{prev_token_h} {prev_token_w}" - else: - expanded_prompt = f"{prompt}{token_h} {token_w}" - - return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w - def generate_prior_tokens( self, prompt: str, height: int, width: int, image: Optional[List[PIL.Image.Image]] = None, - factor: int = 32, - ) -> Tuple[torch.Tensor, int, int]: + ): device = self.vision_language_encoder.device - height = (height // factor) * factor - width = (width // factor) * factor is_text_to_image = image is None or len(image) == 0 - expanded_prompt, token_h, token_w, prev_h, prev_w = self._build_prompt_with_shape( - prompt, height, width, is_text_to_image - ) content = [] if image is not None: for img in image: content.append({"type": "image", "image": img}) - content.append({"type": "text", "text": 
expanded_prompt}) + content.append({"type": "text", "text": prompt}) messages = [{"role": "user", "content": content}] inputs = self.processor.apply_chat_template( - messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + messages, + add_generation_prompt=True, + tokenize=True, + target_h=height, + target_w=width, + return_dict=True, + return_tensors="pt", ).to(device) - existing_grid = inputs.get("image_grid_thw") - inputs["image_grid_thw"] = self._build_image_grid_thw( - token_h, - token_w, - prev_h, - prev_w, - existing_grid=existing_grid if not is_text_to_image else None, - device=device, - ) - max_new_tokens, large_image_offset = self._calculate_ar_generation_params( - token_h, token_w, prev_h, prev_w, is_text_to_image + image_grid_thw = inputs.get("image_grid_thw") + max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params( + image_grid_thw, is_text_to_image ) - large_image_tokens = token_h * token_w - input_length = inputs["input_ids"].shape[-1] prior_token_image_ids = None if image is not None: prior_token_image_embed = self.vision_language_encoder.get_image_features( - inputs["pixel_values"], existing_grid + inputs["pixel_values"], image_grid_thw[:-1] ) prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0) prior_token_image_ids = self.vision_language_encoder.get_image_tokens( - prior_token_image_embed, existing_grid + prior_token_image_embed, image_grid_thw[:-1] ) outputs = self.vision_language_encoder.generate( **inputs, @@ -324,7 +280,7 @@ def generate_prior_tokens( ) prior_token_ids_d32 = self._extract_large_image_tokens( - outputs, input_length, large_image_offset, large_image_tokens + outputs, inputs["input_ids"].shape[-1], large_image_offset, token_h * token_w ) prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w) From 6c39bcfff5b075df157e4c73eb5a531dbe6738b8 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Mon, 12 Jan 2026 21:05:10 +0800 Subject: [PATCH 36/56] support --- .../pipelines/glm_image/pipeline_glm_image.py | 51 +++++++++++++++---- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index cafa610834ac..2b054f105464 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -374,7 +374,7 @@ def encode_prompt( prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) negative_prompt_embeds = None - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype) @@ -406,10 +406,14 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype def check_inputs( self, prompt, + negative_prompt, height, width, callback_on_step_end_tensor_inputs, prompt_embeds=None, + negative_prompt_embeds=None, + prior_token_ids=None, + prior_image_token_ids=None, ): if ( height is not None @@ -427,7 +431,6 @@ def check_inputs( raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) - if prompt is not None 
and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -439,6 +442,26 @@ def check_inputs( ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if (prior_token_ids is None and prior_image_token_ids is not None) or ( + prior_token_ids is not None and prior_image_token_ids is None + ): + raise ValueError( + f"Cannot forward only one `prior_token_ids`: {negative_prompt} or `prior_image_token_ids`:" + f" {negative_prompt_embeds} provided. Please make sure both are provided or neither." + ) @property def guidance_scale(self): @@ -469,6 +492,7 @@ def interrupt(self): def __call__( self, prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, image: Optional[ Union[ torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray] @@ -483,7 +507,10 @@ def __call__( num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prior_token_ids: Optional[torch.FloatTensor] = None, + prior_image_token_ids: Optional[torch.Tensor] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), output_type: str = "pil", return_dict: bool = True, @@ -530,10 +557,14 @@ def __call__( # 1. Check inputs self.check_inputs( prompt, + negative_prompt, height, width, callback_on_step_end_tensor_inputs, prompt_embeds, + negative_prompt_embeds, + prior_token_ids, + prior_image_token_ids, ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs @@ -550,13 +581,13 @@ def __call__( raise ValueError(f"batch_size must be 1 due to AR model limitations, got {batch_size}") device = self._execution_device - - prior_token_id, prior_token_image_ids = self.generate_prior_tokens( - prompt=prompt[0] if isinstance(prompt, list) else prompt, - image=image, - height=height, - width=width, - ) + if prior_token_ids is None: + prior_token_id, prior_token_image_ids = self.generate_prior_tokens( + prompt=prompt[0] if isinstance(prompt, list) else prompt, + image=image, + height=height, + width=width, + ) # 3. 
Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( From c77dee98f963a64fdfce828a575a6d2139b96683 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Mon, 12 Jan 2026 22:39:40 +0800 Subject: [PATCH 37/56] resolution change --- src/diffusers/pipelines/glm_image/pipeline_glm_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 2b054f105464..2b85b970bc99 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -251,7 +251,6 @@ def generate_prior_tokens( messages = [{"role": "user", "content": content}] inputs = self.processor.apply_chat_template( messages, - add_generation_prompt=True, tokenize=True, target_h=height, target_w=width, From 95b88f9db03430122f60854f62ccf6fd6d2b1ef0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 13 Jan 2026 10:30:22 +0530 Subject: [PATCH 38/56] fix-copies --- .../pipelines/glm_image/pipeline_glm_image.py | 39 ++++++++++++------- src/diffusers/utils/dummy_pt_objects.py | 4 +- .../dummy_torch_and_transformers_objects.py | 15 +++++++ 3 files changed, 42 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 2b85b970bc99..460310f35e6f 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -81,20 +81,30 @@ def retrieve_timesteps( r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - """ - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ if timesteps is not None and sigmas is not None: - if not accepts_timesteps and not accepts_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep or sigma schedules. Please check whether you are using the correct scheduler." 
- ) - scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif timesteps is not None and sigmas is None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -103,8 +113,9 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) - elif timesteps is None and sigmas is not None: - if not accepts_sigmas: + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 35feca0e346d..7120ff1f6257 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -982,7 +982,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class HiDreamImageTransformer2DModel(metaclass=DummyObject): +class GlmImageTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): @@ -997,7 +997,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class GlmImageTransformer2DModel(metaclass=DummyObject): +class HiDreamImageTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index da32b7ad8df0..5e1fed304ce7 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1142,6 +1142,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class GlmImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class HiDreamImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From c252ba1a83373d9dbf01d39eca4ec8df198d420b Mon Sep 17 00:00:00 2001 From: Yuxuan Zhang <2448370773@qq.com> Date: Tue, 13 Jan 2026 13:29:06 +0800 Subject: [PATCH 39/56] Update src/diffusers/pipelines/glm_image/pipeline_glm_image.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/glm_image/pipeline_glm_image.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 2b85b970bc99..2bc273adf780 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ 
b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -446,7 +446,14 @@ def check_inputs( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) - + if prompt is not None and prior_token_ids is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prior_token_ids`: {prior_token_ids}. Please make sure to" + " only forward one of the two." + elif prompt is None and prior_token_ids is None: + raise ValueError( + "Provide either `prompt` or `prior_token_ids`. Cannot leave both `prompt` and `prior_token_ids` undefined." + ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( From 404689bb1486f166fb8b00a4809795825c494fd7 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 13 Jan 2026 13:33:47 +0800 Subject: [PATCH 40/56] Update pipeline_glm_image.py --- src/diffusers/pipelines/glm_image/pipeline_glm_image.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 9339bcbfcb6d..c59390ec236a 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -461,7 +461,8 @@ def check_inputs( raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prior_token_ids`: {prior_token_ids}. Please make sure to" " only forward one of the two." - elif prompt is None and prior_token_ids is None: + ) + elif prompt is None and prior_token_ids is None: raise ValueError( "Provide either `prompt` or `prior_token_ids`. Cannot leave both `prompt` and `prior_token_ids` undefined." ) From 908030f4c0c98b977aed183062ca6523ca6a93ba Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 13 Jan 2026 13:44:11 +0800 Subject: [PATCH 41/56] use cogview4 --- .../pipelines/glm_image/pipeline_glm_image.py | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index c59390ec236a..13502374f1e1 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -69,7 +69,7 @@ def calculate_shift( return mu -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, @@ -101,10 +101,19 @@ def retrieve_timesteps( `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps and not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep or sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is not None and sigmas is None: if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" @@ -113,9 +122,8 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: + elif timesteps is None and sigmas is not None: + if not accepts_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." @@ -129,20 +137,6 @@ def retrieve_timesteps( return timesteps, num_inference_steps -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - class GlmImagePipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using GLM-Image. 
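Editor's note: the hunk above restores the CogView4 variant of `retrieve_timesteps`, which lets callers override the scheduler's default spacing with either custom `timesteps` or custom `sigmas`. The sketch below only illustrates that contract — the sigma values are arbitrary and nothing here is prescribed by the patch; it assumes the helper remains importable from the CogView4 pipeline module named in the `Copied from` comment.

```python
# Illustrative only: exercising the retrieve_timesteps contract restored above.
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.cogview4.pipeline_cogview4 import retrieve_timesteps

scheduler = FlowMatchEulerDiscreteScheduler()

# Default path: the spacing is derived from num_inference_steps.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=8, device="cpu")
assert num_steps == 8

# Custom path: a user-supplied sigma schedule overrides the default spacing and the
# returned step count comes from the schedule length, not num_inference_steps.
custom_sigmas = [1.0, 0.8, 0.6, 0.4, 0.2]  # arbitrary example schedule
timesteps, num_steps = retrieve_timesteps(scheduler, device="cpu", sigmas=custom_sigmas)
assert num_steps == len(custom_sigmas)
```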
From 97f85ffefba073bd522c9a22ee0823ac84cb218c Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 13 Jan 2026 14:05:30 +0800 Subject: [PATCH 42/56] Update pipeline_glm_image.py --- .../pipelines/glm_image/pipeline_glm_image.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 13502374f1e1..cfd5ff815d31 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -137,6 +137,20 @@ def retrieve_timesteps( return timesteps, num_inference_steps +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + class GlmImagePipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using GLM-Image. @@ -282,7 +296,7 @@ def generate_prior_tokens( max_new_tokens=max_new_tokens, do_sample=True, ) - + print(outputs) prior_token_ids_d32 = self._extract_large_image_tokens( outputs, inputs["input_ids"].shape[-1], large_image_offset, token_h * token_w ) From 906792fed2b2ca823d21bc399e4f473809f92051 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 13 Jan 2026 15:46:30 +0800 Subject: [PATCH 43/56] Update pipeline_glm_image.py --- .../pipelines/glm_image/pipeline_glm_image.py | 65 ++++++++++++------- 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index cfd5ff815d31..d90777c881f6 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -14,6 +14,7 @@ # limitations under the License. import inspect +import math import re from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -151,6 +152,19 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = width if width % 32 == 0 else (width // 32 + 1) * 32 + height = height if height % 32 == 0 else (height // 32 + 1) * 32 + + width = int(width) + height = int(height) + + return width, height + + class GlmImagePipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using GLM-Image. 
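Editor's note: the `calculate_dimensions` helper added in the hunk above snaps an arbitrary aspect ratio onto a roughly 2048×2048 target area with 32-pixel-aligned sides (GLM-Image works on a 32× downsampled grid). The snippet below is a worked example, with the helper restated verbatim so it runs standalone; the inputs are illustrative, and note that the following "revert" commit removes this preprocessing path again.

```python
import math


def calculate_dimensions(target_area, ratio):
    # Restatement of the helper above: choose width/height that match the requested
    # aspect ratio and target area, then round each side up to the next multiple of 32.
    width = math.sqrt(target_area * ratio)
    height = width / ratio

    width = width if width % 32 == 0 else (width // 32 + 1) * 32
    height = height if height % 32 == 0 else (height // 32 + 1) * 32

    return int(width), int(height)


print(calculate_dimensions(2048 * 2048, 4 / 3))  # (2368, 1792) -- rounded up to 32-multiples
print(calculate_dimensions(2048 * 2048, 1.0))    # (2048, 2048) -- already aligned
```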
@@ -279,7 +293,7 @@ def generate_prior_tokens( image_grid_thw = inputs.get("image_grid_thw") max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params( - image_grid_thw, is_text_to_image + image_grid_thw=image_grid_thw, is_text_to_image=is_text_to_image ) prior_token_image_ids = None @@ -291,12 +305,15 @@ def generate_prior_tokens( prior_token_image_ids = self.vision_language_encoder.get_image_tokens( prior_token_image_embed, image_grid_thw[:-1] ) + + # For GLM-Image, greedy decoding is not allowed; it may cause repetitive outputs. + # max_new_tokens must be exactly grid_h * grid_w + 1 (the +1 is for EOS). outputs = self.vision_language_encoder.generate( **inputs, max_new_tokens=max_new_tokens, do_sample=True, ) - print(outputs) + prior_token_ids_d32 = self._extract_large_image_tokens( outputs, inputs["input_ids"].shape[-1], large_image_offset, token_h * token_w ) @@ -391,6 +408,7 @@ def encode_prompt( prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + # For GLM-Image, negative_prompt must be "" instead of None negative_prompt_embeds = None if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = "" @@ -435,12 +453,13 @@ def check_inputs( ): if ( height is not None - and height % (self.vae_scale_factor * self.transformer.config.patch_size) != 0 + and height % (self.vae_scale_factor * self.transformer.config.patch_size * 2) != 0 or width is not None - and width % (self.transformer.config.patch_size) != 0 + and width % (self.transformer.config.patch_size * 2) != 0 ): - logger.warning( - f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + # GLM-Image uses 32× downsampling, so the image dimensions must be multiples of 32. + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 4} but are {height} and {width}." ) if callback_on_step_end_tensor_inputs is not None and not all( @@ -607,6 +626,15 @@ def __call__( raise ValueError(f"batch_size must be 1 due to AR model limitations, got {batch_size}") device = self._execution_device + + # 2. Preprocess image for AR: + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions( + 2048 * 2048, image_size[0] * 1.0 / image_size[1] + ) + image = self.image_processor.resize(image, calculated_height, calculated_width) + if prior_token_ids is None: prior_token_id, prior_token_image_ids = self.generate_prior_tokens( prompt=prompt[0] if isinstance(prompt, list) else prompt, @@ -615,7 +643,12 @@ def __call__( width=width, ) - # 3. Encode input prompt + # 4. Preprocess image for DIT + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.preprocess(image, calculated_height, calculated_width) + image = image.unsqueeze(2) + + # 5. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, self.do_classifier_free_guidance, @@ -626,23 +659,7 @@ def __call__( dtype=self.dtype, ) - # 4. 
process images - if image is not None: - preprocessed_condition_images = [] - for img in image: - image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2] - multiple_of = self.vae_scale_factor * self.transformer.config.patch_size - image_height = (image_height // multiple_of) * multiple_of - image_width = (image_width // multiple_of) * multiple_of - img = self.image_processor.preprocess(img, height=image_height, width=image_width) - preprocessed_condition_images.append(img) - height = height or image_height - width = width or image_width - image = preprocessed_condition_images - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor - - # 5. Prepare latents and (optional) image kv cache + # 4. Prepare latents and (optional) image kv cache latent_channels = self.transformer.config.in_channels latents = self.prepare_latents( batch_size=batch_size * num_images_per_prompt, From 2f4c8d9b607ad2b7eebf550c2f607b13f62fe810 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 13 Jan 2026 15:51:33 +0800 Subject: [PATCH 44/56] revert --- .../pipelines/glm_image/pipeline_glm_image.py | 39 +++++++------------ 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index d90777c881f6..0e15485050a1 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -152,19 +152,6 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -def calculate_dimensions(target_area, ratio): - width = math.sqrt(target_area * ratio) - height = width / ratio - - width = width if width % 32 == 0 else (width // 32 + 1) * 32 - height = height if height % 32 == 0 else (height // 32 + 1) * 32 - - width = int(width) - height = int(height) - - return width, height - - class GlmImagePipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using GLM-Image. @@ -627,14 +614,7 @@ def __call__( device = self._execution_device - # 2. Preprocess image for AR: - if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): - image_size = image[0].size if isinstance(image, list) else image.size - calculated_width, calculated_height = calculate_dimensions( - 2048 * 2048, image_size[0] * 1.0 / image_size[1] - ) - image = self.image_processor.resize(image, calculated_height, calculated_width) - + # 2. Preprocess image tokens and prompt tokens if prior_token_ids is None: prior_token_id, prior_token_image_ids = self.generate_prior_tokens( prompt=prompt[0] if isinstance(prompt, list) else prompt, @@ -643,10 +623,19 @@ def __call__( width=width, ) - # 4. Preprocess image for DIT - if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): - image = self.image_processor.preprocess(image, calculated_height, calculated_width) - image = image.unsqueeze(2) + # 3. 
Preprocess image + if image is not None: + preprocessed_condition_images = [] + for img in image: + image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2] + multiple_of = self.vae_scale_factor * self.transformer.config.patch_size + image_height = (image_height // multiple_of) * multiple_of + image_width = (image_width // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width) + preprocessed_condition_images.append(img) + height = height or image_height + width = width or image_width + image = preprocessed_condition_images # 5. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( From 1520728141b66c977d825540626450d0b2533ee0 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 13 Jan 2026 16:28:10 +0800 Subject: [PATCH 45/56] update --- .../pipelines/glm_image/pipeline_glm_image.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 0e15485050a1..ea1cb373c522 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -616,7 +616,7 @@ def __call__( # 2. Preprocess image tokens and prompt tokens if prior_token_ids is None: - prior_token_id, prior_token_image_ids = self.generate_prior_tokens( + prior_token_ids, prior_token_image_ids = self.generate_prior_tokens( prompt=prompt[0] if isinstance(prompt, list) else prompt, image=image, height=height, @@ -660,7 +660,6 @@ def __call__( generator=generator, latents=latents, ) - kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers) if image is not None: @@ -727,8 +726,8 @@ def __call__( transformer_dtype = self.transformer.dtype num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - prior_token_drop_cond = torch.full_like(prior_token_id, False, dtype=torch.bool) - prior_token_drop_uncond = torch.full_like(prior_token_id, True, dtype=torch.bool) + prior_token_drop_cond = torch.full_like(prior_token_ids, False, dtype=torch.bool) + prior_token_drop_uncond = torch.full_like(prior_token_ids, True, dtype=torch.bool) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -745,7 +744,7 @@ def __call__( noise_pred_cond = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, - prior_token_id=prior_token_id, + prior_token_id=prior_token_ids, prior_token_drop=prior_token_drop_cond, timestep=timestep, target_size=target_size, @@ -762,7 +761,7 @@ def __call__( noise_pred_uncond = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=negative_prompt_embeds, - prior_token_id=prior_token_id, + prior_token_id=prior_token_ids, prior_token_drop=prior_token_drop_uncond, timestep=timestep, target_size=target_size, From 9631b682aac256b1df8d2949f6e6a278144f6d96 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 13 Jan 2026 16:57:57 +0800 Subject: [PATCH 46/56] batch support --- .../models/transformers/transformer_glm_image.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index 30db3dc82968..b7b3aa391ce4 100644 --- 
a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -159,8 +159,15 @@ def store(self, k: torch.Tensor, v: torch.Tensor): self.v_cache = torch.cat([self.v_cache, v], dim=1) def get(self, k: torch.Tensor, v: torch.Tensor): - k_cache = torch.cat([self.k_cache, k], dim=1) - v_cache = torch.cat([self.v_cache, v], dim=1) + if self.k_cache.shape[0] != k.shape[0]: + k_cache_expanded = self.k_cache.expand(k.shape[0], -1, -1, -1) + v_cache_expanded = self.v_cache.expand(v.shape[0], -1, -1, -1) + else: + k_cache_expanded = self.k_cache + v_cache_expanded = self.v_cache + + k_cache = torch.cat([k_cache_expanded, k], dim=1) + v_cache = torch.cat([v_cache_expanded, v], dim=1) return k_cache, v_cache def clear(self): From 008e5ea37f469cd02be9eaee794a4c21a0ff5151 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 13 Jan 2026 17:14:57 +0800 Subject: [PATCH 47/56] update --- .../pipelines/glm_image/pipeline_glm_image.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index ea1cb373c522..92389b33d10b 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -14,7 +14,6 @@ # limitations under the License. import inspect -import math import re from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -356,6 +355,7 @@ def encode_prompt( do_classifier_free_guidance: bool = True, num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, max_sequence_length: int = 2048, @@ -396,7 +396,6 @@ def encode_prompt( prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For GLM-Image, negative_prompt must be "" instead of None - negative_prompt_embeds = None if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt @@ -429,7 +428,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype def check_inputs( self, prompt, - negative_prompt, height, width, callback_on_step_end_tensor_inputs, @@ -466,11 +464,6 @@ def check_inputs( ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) if prompt is not None and prior_token_ids is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prior_token_ids`: {prior_token_ids}. Please make sure to" @@ -491,8 +484,8 @@ def check_inputs( prior_token_ids is not None and prior_image_token_ids is None ): raise ValueError( - f"Cannot forward only one `prior_token_ids`: {negative_prompt} or `prior_image_token_ids`:" - f" {negative_prompt_embeds} provided. Please make sure both are provided or neither." 
+ f"Cannot forward only one `prior_token_ids`: {prior_token_ids} or `prior_image_token_ids`:" + f" {prior_image_token_ids} provided. Please make sure both are provided or neither." ) @property @@ -524,7 +517,6 @@ def interrupt(self): def __call__( self, prompt: Optional[Union[str, List[str]]] = None, - negative_prompt: Optional[Union[str, List[str]]] = None, image: Optional[ Union[ torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray] @@ -589,7 +581,6 @@ def __call__( # 1. Check inputs self.check_inputs( prompt, - negative_prompt, height, width, callback_on_step_end_tensor_inputs, @@ -643,6 +634,7 @@ def __call__( self.do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, max_sequence_length=max_sequence_length, device=device, dtype=self.dtype, From 5bb959fba861f264d61a87b1fa3ecda4d96ec204 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 13 Jan 2026 17:29:59 +0530 Subject: [PATCH 48/56] version guard glm image pipeline --- src/diffusers/__init__.py | 8 +++-- src/diffusers/pipelines/__init__.py | 6 ++++ src/diffusers/pipelines/glm_image/__init__.py | 16 ++++++++-- .../dummy_torch_and_transformers_objects.py | 30 +++++++++---------- 4 files changed, 41 insertions(+), 19 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 4051a26fe954..ab34927e70b8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -23,6 +23,7 @@ is_torchao_available, is_torchsde_available, is_transformers_available, + is_transformers_version, ) @@ -493,7 +494,6 @@ "FluxKontextPipeline", "FluxPipeline", "FluxPriorReduxPipeline", - "GlmImagePipeline", "HiDreamImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", @@ -691,6 +691,8 @@ "ZImagePipeline", ] ) + if is_transformers_version(">=", "4.57.4"): + _import_structure["pipelines"].extend(["GlmImagePipeline"]) try: @@ -1219,7 +1221,6 @@ FluxKontextPipeline, FluxPipeline, FluxPriorReduxPipeline, - GlmImagePipeline, HiDreamImagePipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, @@ -1415,6 +1416,9 @@ ZImagePipeline, ) + if is_transformers_version(">=", "4.57.4"): + from .pipelines import GlmImagePipeline + try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b94319ffcbdc..5e291aef4dc4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -15,6 +15,7 @@ is_torch_available, is_torch_npu_available, is_transformers_available, + is_transformers_version, ) @@ -434,6 +435,8 @@ "QwenImageLayeredPipeline", ] _import_structure["chronoedit"] = ["ChronoEditPipeline"] + if is_transformers_version(">=", "4.57.4"): + _import_structure["glm_image"] = ["GlmImagePipeline"] try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() @@ -870,6 +873,9 @@ ZImagePipeline, ) + if is_transformers_version(">=", "4.57.4"): + from .glm_image import GlmImagePipeline + try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/pipelines/glm_image/__init__.py b/src/diffusers/pipelines/glm_image/__init__.py index 9df31b0b1734..8cea72390ade 100644 --- a/src/diffusers/pipelines/glm_image/__init__.py +++ b/src/diffusers/pipelines/glm_image/__init__.py @@ -7,6 +7,7 @@ get_objects_from_module, is_torch_available, 
is_transformers_available, + is_transformers_version, ) @@ -14,8 +15,19 @@ _additional_imports = {} _import_structure = {"pipeline_output": ["GlmImagePipelineOutput"]} +# Import transformers components so they can be resolved during pipeline loading + +if is_transformers_available() and is_transformers_version(">=", "4.57.4"): + try: + from transformers import GlmImageForConditionalGeneration, GlmImageProcessor + + _additional_imports["GlmImageForConditionalGeneration"] = GlmImageForConditionalGeneration + _additional_imports["GlmImageProcessor"] = GlmImageProcessor + except ImportError: + pass + try: - if not (is_transformers_available() and is_torch_available()): + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.57.4")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 @@ -25,7 +37,7 @@ _import_structure["pipeline_glm_image"] = ["GlmImagePipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: - if not (is_transformers_available() and is_torch_available()): + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.57.4")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 5e1fed304ce7..7522cfa6d625 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1142,21 +1142,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class GlmImagePipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class HiDreamImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -4050,3 +4035,18 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) + + +class GlmImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) From 512529f0c74e26fbb7e565ba4ba99998ba9bd396 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 13 Jan 2026 17:34:21 +0530 Subject: [PATCH 49/56] validate prompt_embeds and prior_token_ids --- src/diffusers/pipelines/glm_image/pipeline_glm_image.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 92389b33d10b..c5948ce1bd8e 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -488,6 +488,9 @@ def check_inputs( f" {prior_image_token_ids} provided. 
Please make sure both are provided or neither." ) + if prior_token_ids is not None and prompt_embeds is None: + raise ValueError("`prompt_embeds` must also be provided with `prior_token_ids`.") + @property def guidance_scale(self): return self._guidance_scale From 707a29adf7704e5dc38620e1d20f6c8e02ab4fe6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 13 Jan 2026 17:44:37 +0530 Subject: [PATCH 50/56] try docs. --- docs/source/en/api/models/glm_image_transformer2d.md | 2 +- docs/source/en/api/pipelines/glm_image.md | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/api/models/glm_image_transformer2d.md b/docs/source/en/api/models/glm_image_transformer2d.md index 8a8b07456046..7a18d1050075 100644 --- a/docs/source/en/api/models/glm_image_transformer2d.md +++ b/docs/source/en/api/models/glm_image_transformer2d.md @@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. --> # GlmImageTransformer2DModel -A Diffusion Transformer model for 2D data from [GlmImageTransformer2DModel]() +A Diffusion Transformer model for 2D data from [GlmImageTransformer2DModel] (TODO). ## GlmImageTransformer2DModel diff --git a/docs/source/en/api/pipelines/glm_image.md b/docs/source/en/api/pipelines/glm_image.md index 98e0145b6807..a99832787847 100644 --- a/docs/source/en/api/pipelines/glm_image.md +++ b/docs/source/en/api/pipelines/glm_image.md @@ -21,7 +21,7 @@ GLM-Image is an image generation model adopts a hybrid autoregressive + diffusio Model architecture: a hybrid autoregressive + diffusion decoder design、 -+ Autoregressive generator: a 9B-parameter model initialized from [GLM-4-9B-0414](https://huggingface.co/zai-org/GLM-4-9B-0414), with an expanded vocabulary to incorporate visual tokens. The model first generates a compact encoding of approximately 256 tokens, then expands to 1K–4K tokens, corresponding to 1K–2K high-resolution image outputs. You can check AR model in class `GlmImageForConditionalGeneration` of transformers library. ++ Autoregressive generator: a 9B-parameter model initialized from [GLM-4-9B-0414](https://huggingface.co/zai-org/GLM-4-9B-0414), with an expanded vocabulary to incorporate visual tokens. The model first generates a compact encoding of approximately 256 tokens, then expands to 1K–4K tokens, corresponding to 1K–2K high-resolution image outputs. You can check AR model in class `GlmImageForConditionalGeneration` of `transformers` library. + Diffusion Decoder: a 7B-parameter decoder based on a single-stream DiT architecture for latent-space image decoding. It is equipped with a Glyph Encoder text module, significantly improving accurate text rendering within images. Post-training with decoupled reinforcement learning: the model introduces a fine-grained, modular feedback strategy using the GRPO algorithm, substantially enhancing both semantic understanding and visual detail quality. 
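Editor's note: to make the hybrid AR + diffusion design described above concrete, here is a minimal text-to-image sketch against the `GlmImagePipeline` this series adds. It is not taken from the PR's own docs: the checkpoint id is a placeholder (the patches only point at the `zai-org` organization), the call arguments are illustrative, and `transformers` >= 4.57.4 is required per the version guard added later in the series.

```python
import torch
from diffusers import GlmImagePipeline

# Placeholder repo id -- the exact checkpoint name is not fixed anywhere in this series.
pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# The AR prior first samples the compact visual tokens (sampling stays enabled), then
# the DiT decoder denoises the high-resolution latents conditioned on them.
image = pipe(
    prompt="A red bird perched on a snowy branch",
    height=1024,  # must be a multiple of 32
    width=1024,
    num_inference_steps=50,
    guidance_scale=1.5,
).images[0]
image.save("glm_image_t2i.png")
```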
@@ -86,10 +86,10 @@ image.save("output_i2i.png") ## GlmImagePipeline -[[autodoc]] GlmImagePipeline +[[autodoc]] pipelines.glm_image.pipeline_glm_image.GlmImagePipeline - all - __call__ ## GlmImagePipelineOutput -[[autodoc]] pipelines.cogview4.pipeline_output.GlmImagePipelineOutput +[[autodoc]] pipelines.glm_image.pipeline_output.GlmImagePipelineOutput From b8febaab7f0d7f4ffe8d6591fe842b0df98dd91a Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 13 Jan 2026 21:34:48 +0800 Subject: [PATCH 51/56] 4 --- tests/pipelines/glm_image/__init__.py | 0 tests/pipelines/glm_image/test_glm_image.py | 186 ++++++++++++++++++++ 2 files changed, 186 insertions(+) create mode 100644 tests/pipelines/glm_image/__init__.py create mode 100644 tests/pipelines/glm_image/test_glm_image.py diff --git a/tests/pipelines/glm_image/__init__.py b/tests/pipelines/glm_image/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/glm_image/test_glm_image.py b/tests/pipelines/glm_image/test_glm_image.py new file mode 100644 index 000000000000..b7a42fb4503e --- /dev/null +++ b/tests/pipelines/glm_image/test_glm_image.py @@ -0,0 +1,186 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import unittest + + +# Add transformers path for GlmImage models +sys.path.insert(0, "/fsx/sayak/transformers/src") + +import torch +from transformers import ( + AutoTokenizer, + GlmImageConfig, + GlmImageForConditionalGeneration, + GlmImageProcessor, + T5EncoderModel, +) + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, GlmImagePipeline, GlmImageTransformer2DModel + +from ...testing_utils import enable_full_determinism +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = GlmImagePipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + test_attention_slicing = False + + def get_dummy_components(self): + torch.manual_seed(0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + glm_config = GlmImageConfig( + text_config={ + "vocab_size": 168064, + "hidden_size": 32, + "intermediate_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "max_position_embeddings": 512, + "vision_vocab_size": 128, + "rope_parameters": {"mrope_section": (4, 2, 2)}, + }, + vision_config={ + "depth": 2, + "hidden_size": 32, + "num_heads": 2, + "image_size": 32, + "patch_size": 8, + "intermediate_size": 32, + }, + vq_config={ + "embed_dim": 32, + "num_embeddings": 128, + "latent_channels": 32, + }, + ) + + torch.manual_seed(0) + vision_language_encoder = GlmImageForConditionalGeneration(glm_config) + + # TODO: move to a public checkpoint + processor = GlmImageProcessor.from_pretrained("ZP2Test/GLM-Image", subfolder="processor") + + torch.manual_seed(0) + # For GLM-Image, the relationship between components must satisfy: + # patch_size × vae_scale_factor = 16 (since AR tokens are upsampled 2× from d32) + transformer = GlmImageTransformer2DModel( + patch_size=2, + in_channels=4, + out_channels=4, + num_layers=2, + attention_head_dim=8, + num_attention_heads=2, + text_embed_dim=text_encoder.config.hidden_size, + time_embed_dim=16, + condition_dim=8, + prior_vq_quantizer_codebook_size=128, + ) + + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=(4, 8, 16, 16), + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + norm_num_groups=4, + sample_size=128, + latents_mean=[0.0] * 4, + latents_std=[1.0] * 4, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + components = { + "tokenizer": tokenizer, + "processor": processor, + "text_encoder": text_encoder, + "vision_language_encoder": vision_language_encoder, + "vae": vae, + "transformer": transformer, + "scheduler": scheduler, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator 
= torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + # Use 128×128 to get more tokens + # - AR d32 tokens: 128/32 = 4×4 = 16 tokens + # - After 2x upsample: 8×8 = 64 tokens + # - VAE latent: 128/2 = 64x64 + # - Transformer patches: 64/8 = 8x8 = 64 patches + height, width = 128, 128 + + inputs = { + "prompt": "A photo of a cat", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.5, + "height": height, + "width": width, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images[0] + + self.assertEqual(image.shape, (3, 128, 128)) + # TODO: slices + + @unittest.skip("Not supported.") + def test_inference_batch_single_identical(self): + # GLM-Image has batch_size=1 constraint due to AR model + # Skip this test or modify it + pass From 16333086bd899d46c8511ea11bde2067cb5ac423 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 13 Jan 2026 14:03:03 +0000 Subject: [PATCH 52/56] up --- tests/pipelines/glm_image/test_glm_image.py | 71 +++++++++++++++------ 1 file changed, 52 insertions(+), 19 deletions(-) diff --git a/tests/pipelines/glm_image/test_glm_image.py b/tests/pipelines/glm_image/test_glm_image.py index b7a42fb4503e..74ac28e8b571 100644 --- a/tests/pipelines/glm_image/test_glm_image.py +++ b/tests/pipelines/glm_image/test_glm_image.py @@ -15,10 +15,7 @@ import sys import unittest - -# Add transformers path for GlmImage models -sys.path.insert(0, "/fsx/sayak/transformers/src") - +import numpy as np import torch from transformers import ( AutoTokenizer, @@ -30,7 +27,7 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, GlmImagePipeline, GlmImageTransformer2DModel -from ...testing_utils import enable_full_determinism +from ...testing_utils import enable_full_determinism, require_transformers_version_greater, require_torch_accelerator from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin @@ -38,6 +35,8 @@ enable_full_determinism() +@require_transformers_version_greater("4.57.4") +@require_torch_accelerator class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = GlmImagePipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} @@ -56,6 +55,7 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False test_attention_slicing = False + supports_dduf = False def get_dummy_components(self): torch.manual_seed(0) @@ -82,11 +82,7 @@ def get_dummy_components(self): "patch_size": 8, "intermediate_size": 32, }, - vq_config={ - "embed_dim": 32, - "num_embeddings": 128, - "latent_channels": 32, - }, + vq_config={"embed_dim": 32, "num_embeddings": 128, "latent_channels": 32}, ) torch.manual_seed(0) @@ -145,12 +141,7 @@ def get_dummy_inputs(self, device, seed=0): else: generator = torch.Generator(device=device).manual_seed(seed) - # Use 128×128 to get more tokens - # - AR d32 tokens: 128/32 = 4×4 = 16 tokens - # - After 2x upsample: 8×8 = 64 tokens - # - VAE latent: 128/2 = 64x64 - # - Transformer patches: 64/8 = 8x8 = 64 patches - height, width = 128, 128 + height, width = 32, 32 inputs = { "prompt": "A photo of a cat", @@ -175,12 
+166,54 @@ def test_inference(self): inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images[0] + generated_slice = image.flatten() + generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]]) + + # fmt: off + expected_slice = np.array( + [ + 0.5796329, 0.5005878, 0.45881274, 0.45331675, 0.43688118, 0.4899527, 0.54017603, 0.50983673, 0.3387968, 0.38074082, 0.29942477, 0.33733928, 0.3672544, 0.38462338, 0.40991822, 0.46641728 + ] + ) + # fmt: on - self.assertEqual(image.shape, (3, 128, 128)) - # TODO: slices + self.assertEqual(image.shape, (3, 32, 32)) + self.assertTrue(np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4)) @unittest.skip("Not supported.") def test_inference_batch_single_identical(self): # GLM-Image has batch_size=1 constraint due to AR model - # Skip this test or modify it + pass + + @unittest.skip("Not supported.") + def test_inference_batch_consisten(self): + # GLM-Image has batch_size=1 constraint due to AR model + pass + + @unittest.skip("Needs to be revisited.") + def test_encode_prompt_works_in_isolation(self): + pass + + @unittest.skip("Needs to be revisited.") + def test_pipeline_level_group_offloading_inference(self): + pass + + @unittest.skip("Follow set of tests are relaxed because this pipeline doesn't guarantee same outputs for the same inputs in consecutive runs.") + def test_dict_tuple_outputs_equivalent(self): + pass + + @unittest.skip("Skipped") + def test_float16_inference(self): + pass + + @unittest.skip("Skipped") + def test_float16_inference(self): + pass + + @unittest.skip("Skipped") + def test_save_load_float16(self): + pass + + @unittest.skip("Skipped") + def test_save_load_local(self): pass From c87389bdbf1913cf5b96286c2ca99ec7ccde7df9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 13 Jan 2026 19:34:36 +0530 Subject: [PATCH 53/56] up --- tests/pipelines/glm_image/test_glm_image.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/pipelines/glm_image/test_glm_image.py b/tests/pipelines/glm_image/test_glm_image.py index 74ac28e8b571..69a0dbaa77f7 100644 --- a/tests/pipelines/glm_image/test_glm_image.py +++ b/tests/pipelines/glm_image/test_glm_image.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import sys import unittest import numpy as np @@ -27,7 +26,7 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, GlmImagePipeline, GlmImageTransformer2DModel -from ...testing_utils import enable_full_determinism, require_transformers_version_greater, require_torch_accelerator +from ...testing_utils import enable_full_determinism, require_torch_accelerator, require_transformers_version_greater from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin @@ -168,7 +167,7 @@ def test_inference(self): image = pipe(**inputs).images[0] generated_slice = image.flatten() generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]]) - + # fmt: off expected_slice = np.array( [ @@ -197,13 +196,11 @@ def test_encode_prompt_works_in_isolation(self): @unittest.skip("Needs to be revisited.") def test_pipeline_level_group_offloading_inference(self): pass - - @unittest.skip("Follow set of tests are relaxed because this pipeline doesn't guarantee same outputs for the same inputs in consecutive runs.") - def test_dict_tuple_outputs_equivalent(self): - pass - @unittest.skip("Skipped") - def test_float16_inference(self): + @unittest.skip( + "Follow set of tests are relaxed because this pipeline doesn't guarantee same outputs for the same inputs in consecutive runs." + ) + def test_dict_tuple_outputs_equivalent(self): pass @unittest.skip("Skipped") From 1cd6d2d76afae052d5ef9ee123b123039e83cf65 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 13 Jan 2026 20:44:37 +0530 Subject: [PATCH 54/56] skip properly --- tests/pipelines/glm_image/test_glm_image.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/pipelines/glm_image/test_glm_image.py b/tests/pipelines/glm_image/test_glm_image.py index 69a0dbaa77f7..f9c62ada2fbd 100644 --- a/tests/pipelines/glm_image/test_glm_image.py +++ b/tests/pipelines/glm_image/test_glm_image.py @@ -16,21 +16,20 @@ import numpy as np import torch -from transformers import ( - AutoTokenizer, - GlmImageConfig, - GlmImageForConditionalGeneration, - GlmImageProcessor, - T5EncoderModel, -) +from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, GlmImagePipeline, GlmImageTransformer2DModel +from diffusers.utils import is_transformers_version from ...testing_utils import enable_full_determinism, require_torch_accelerator, require_transformers_version_greater from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin +if is_transformers_version(">=", "4.57.4"): + from transformers import GlmImageConfig, GlmImageForConditionalGeneration, GlmImageProcessor + + enable_full_determinism() @@ -38,7 +37,7 @@ @require_torch_accelerator class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = GlmImagePipeline - params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "negative_prompt"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS @@ -185,7 +184,7 @@ def test_inference_batch_single_identical(self): pass @unittest.skip("Not supported.") - def test_inference_batch_consisten(self): + def test_inference_batch_consistent(self): # GLM-Image has batch_size=1 constraint due to AR 
model pass From 8f62cac31ef31a1b2234493d11a78d7753cbab15 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 13 Jan 2026 21:13:51 +0530 Subject: [PATCH 55/56] fix tests --- .../pipelines/glm_image/pipeline_glm_image.py | 7 ++++--- tests/pipelines/glm_image/test_glm_image.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index c5948ce1bd8e..42a1816b1c8c 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -259,8 +259,9 @@ def generate_prior_tokens( height: int, width: int, image: Optional[List[PIL.Image.Image]] = None, + device: Optional[torch.device] = None, ): - device = self.vision_language_encoder.device + device = device or self._execution_device is_text_to_image = image is None or len(image) == 0 content = [] if image is not None: @@ -615,6 +616,7 @@ def __call__( image=image, height=height, width=width, + device=device, ) # 3. Preprocess image @@ -803,11 +805,10 @@ def __call__( ) latents = latents * latents_std + latents_mean image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + image = self.image_processor.postprocess(image, output_type=output_type) else: image = latents - image = self.image_processor.postprocess(image, output_type=output_type) - # Offload all models self.maybe_free_model_hooks() diff --git a/tests/pipelines/glm_image/test_glm_image.py b/tests/pipelines/glm_image/test_glm_image.py index f9c62ada2fbd..cebb71d602bb 100644 --- a/tests/pipelines/glm_image/test_glm_image.py +++ b/tests/pipelines/glm_image/test_glm_image.py @@ -188,6 +188,11 @@ def test_inference_batch_consistent(self): # GLM-Image has batch_size=1 constraint due to AR model pass + @unittest.skip("Not supported.") + def test_num_images_per_prompt(self): + # GLM-Image has batch_size=1 constraint due to AR model + pass + @unittest.skip("Needs to be revisited.") def test_encode_prompt_works_in_isolation(self): pass @@ -202,6 +207,14 @@ def test_pipeline_level_group_offloading_inference(self): def test_dict_tuple_outputs_equivalent(self): pass + @unittest.skip("Skipped") + def test_cpu_offload_forward_pass_twice(self): + pass + + @unittest.skip("Skipped") + def test_sequential_offload_forward_pass_twice(self): + pass + @unittest.skip("Skipped") def test_float16_inference(self): pass From 1636b438099946348373fddb0adcf46c4113d956 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 13 Jan 2026 21:20:24 +0530 Subject: [PATCH 56/56] up --- src/diffusers/__init__.py | 7 ++--- src/diffusers/pipelines/__init__.py | 8 ++--- src/diffusers/pipelines/glm_image/__init__.py | 4 +-- .../pipelines/glm_image/pipeline_glm_image.py | 9 ++++-- .../dummy_torch_and_transformers_objects.py | 30 +++++++++---------- 5 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ab34927e70b8..8f3368b96329 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -494,6 +494,7 @@ "FluxKontextPipeline", "FluxPipeline", "FluxPriorReduxPipeline", + "GlmImagePipeline", "HiDreamImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", @@ -691,8 +692,6 @@ "ZImagePipeline", ] ) - if is_transformers_version(">=", "4.57.4"): - _import_structure["pipelines"].extend(["GlmImagePipeline"]) try: @@ -1221,6 +1220,7 @@ FluxKontextPipeline, FluxPipeline, FluxPriorReduxPipeline, + GlmImagePipeline, 
HiDreamImagePipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, @@ -1416,9 +1416,6 @@ ZImagePipeline, ) - if is_transformers_version(">=", "4.57.4"): - from .pipelines import GlmImagePipeline - try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 5e291aef4dc4..9bc7dc4bf99a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -435,8 +435,8 @@ "QwenImageLayeredPipeline", ] _import_structure["chronoedit"] = ["ChronoEditPipeline"] - if is_transformers_version(">=", "4.57.4"): - _import_structure["glm_image"] = ["GlmImagePipeline"] + _import_structure["glm_image"] = ["GlmImagePipeline"] + try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() @@ -679,6 +679,7 @@ ReduxImageEncoder, ) from .flux2 import Flux2Pipeline + from .glm_image import GlmImagePipeline from .hidream_image import HiDreamImagePipeline from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline from .hunyuan_video import ( @@ -873,9 +874,6 @@ ZImagePipeline, ) - if is_transformers_version(">=", "4.57.4"): - from .glm_image import GlmImagePipeline - try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/pipelines/glm_image/__init__.py b/src/diffusers/pipelines/glm_image/__init__.py index 8cea72390ade..140b9cc760cc 100644 --- a/src/diffusers/pipelines/glm_image/__init__.py +++ b/src/diffusers/pipelines/glm_image/__init__.py @@ -27,7 +27,7 @@ pass try: - if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.57.4")): + if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 @@ -37,7 +37,7 @@ _import_structure["pipeline_glm_image"] = ["GlmImagePipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: - if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.57.4")): + if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 42a1816b1c8c..bfb3966a69d7 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -20,7 +20,7 @@ import numpy as np import PIL import torch -from transformers import ByT5Tokenizer, GlmImageForConditionalGeneration, GlmImageProcessor, T5EncoderModel +from transformers import ByT5Tokenizer, T5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor @@ -28,11 +28,16 @@ from ...models.transformers.transformer_glm_image import GlmImageKVCache from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils import is_torch_xla_available, is_transformers_version, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from .pipeline_output import GlmImagePipelineOutput +# Because 
it's not released in stable as of 13/01/2026. So this is just a proxy. +if is_transformers_version(">=", "4.57.4"): + from transformers import GlmImageForConditionalGeneration, GlmImageProcessor + + if is_torch_xla_available(): import torch_xla.core.xla_model as xm diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 7522cfa6d625..5e1fed304ce7 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1142,6 +1142,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class GlmImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class HiDreamImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -4035,18 +4050,3 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) - - -class GlmImagePipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"])
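
Aside (not part of the patches above): the last commit keeps `GlmImagePipeline` importable from `diffusers` unconditionally and instead gates the transformers-side classes inside the pipeline module behind `is_transformers_version(">=", "4.57.4")`, falling back to dummy objects otherwise. The sketch below illustrates that gating pattern in isolation under stated assumptions; it does not reproduce diffusers' `is_transformers_version` or DummyObject machinery, and `_has_required_transformers` / `load_vision_language_encoder` are hypothetical helper names used only for illustration.

```python
# Minimal sketch of version-gated imports for the GLM-Image AR encoder.
# Assumes only that `transformers` and `packaging` are installed; the helper
# names below are illustrative, not diffusers API.
from packaging import version

import transformers

_MIN_TRANSFORMERS = "4.57.4"


def _has_required_transformers() -> bool:
    # Compare the installed transformers version against the minimum that
    # ships GlmImageForConditionalGeneration and GlmImageProcessor.
    return version.parse(transformers.__version__) >= version.parse(_MIN_TRANSFORMERS)


if _has_required_transformers():
    from transformers import GlmImageForConditionalGeneration, GlmImageProcessor
else:
    # Keep the module importable on older transformers; only actual use fails.
    GlmImageForConditionalGeneration = None
    GlmImageProcessor = None


def load_vision_language_encoder(repo_id: str):
    # Fail with an explicit message instead of an ImportError deep inside the
    # pipeline when the installed transformers is too old.
    if GlmImageForConditionalGeneration is None:
        raise ImportError(
            f"GLM-Image requires transformers>={_MIN_TRANSFORMERS}; "
            f"found {transformers.__version__}."
        )
    return GlmImageForConditionalGeneration.from_pretrained(repo_id)
```

The design choice this mirrors: the import error is raised lazily at use time rather than at `import diffusers` time, so environments on an older transformers can still import and use every other pipeline.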