diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 16c5acf346d..e1deb089cd3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1230,6 +1230,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665": # ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer res = "kormo" + if chkhsh == "9d70134b369a70e5735009b6de918f7581b5211f7c074d1f89f753aea8248af1": + # ref: ./Youtu-VL + res = "utu-vl" if res is None: logger.warning("\n") @@ -7133,6 +7136,7 @@ def prepare_tensors(self): "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", "KimiVLForConditionalGeneration", + "UTUVLForCausalLM", ) class DeepseekV2Model(TextModel): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 @@ -7211,11 +7215,26 @@ def set_gguf_parameters(self): self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) - self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) - self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) - self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) - self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) - self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + if (moe_intermediate_size := hparams.get("moe_intermediate_size")) is not None: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + else: + self.gguf_writer.add_expert_feed_forward_length(hparams.get("intermediate_size", 0)) + + if (n_routed_experts := hparams.get("n_routed_experts")) is not None: + self.gguf_writer.add_expert_count(n_routed_experts) + + if (n_shared_experts := hparams.get("n_shared_experts")) is not None: + self.gguf_writer.add_expert_shared_count(n_shared_experts) + else: + self.gguf_writer.add_expert_shared_count(0) + + if (routed_scaling_factor := hparams.get("routed_scaling_factor")) is not None: + self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) + else: + self.gguf_writer.add_expert_weights_scale(1.0) + + if (norm_topk_prob := hparams.get("norm_topk_prob")) is not None and norm_topk_prob: + self.gguf_writer.add_expert_weights_norm(norm_topk_prob) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) @@ -7231,10 +7250,17 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # skip vision tensors and remove "language_model." 
for Kimi-VL if "vision_tower" in name or "multi_modal_projector" in name: return [] - + if name.startswith("siglip2.") or name.startswith("merger."): + return [] if name.startswith("language_model."): name = name.replace("language_model.", "") + # skip lm_head.weight if tie_word_embeddings is True + if self.hparams.get("tie_word_embeddings", False): + if name == "lm_head.weight" or name == "model.lm_head.weight": + logger.info("Skipping tied output layer 'lm_head.weight' (will use token_embd.weight)") + return [] + # rename e_score_correction_bias tensors if name.endswith("e_score_correction_bias"): name = name.replace("e_score_correction_bias", "e_score_correction.bias") @@ -7246,7 +7272,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] # process the experts separately - if name.find("mlp.experts") != -1: + if name.find("mlp.experts") != -1 and self.hparams.get("n_routed_experts") is not None: n_experts = self.hparams["n_routed_experts"] assert bid is not None @@ -7309,7 +7335,6 @@ def prepare_tensors(self): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") - @ModelBase.register("MiniMaxM2ForCausalLM") class MiniMaxM2Model(TextModel): model_arch = gguf.MODEL_ARCH.MINIMAXM2 @@ -10466,7 +10491,53 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] - +@ModelBase.register("UtuVLForConditionalGeneration", "UTUVLForCausalLM") +class UtuVLVisionModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.UTUVL) + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6)) + + # Handle activation function + hidden_act = str(self.hparams.get("hidden_act", "gelu_pytorch_tanh")).lower() + if hidden_act in ("gelu", "gelu_pytorch_tanh", "gelu_fast", "gelu_new", "gelu_accurate"): + self.gguf_writer.add_vision_use_gelu(True) + elif hidden_act == "silu": + self.gguf_writer.add_vision_use_silu(True) + else: + raise ValueError(f"Unsupported activation function for UTUVL: {hidden_act}") + + self.gguf_writer.add_vision_spatial_merge_size(self.hparams.get("spatial_merge_size", 2)) + + window_size = self.hparams.get("window_size") + if window_size is not None: + self.gguf_writer.add_vision_window_size(window_size) + fullatt_block_indexes = self.hparams.get("fullatt_block_indexes") + assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for utuvl" + self.gguf_writer.add_vision_wa_layers(layers=fullatt_block_indexes) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # Skip language model tensors + skip_prefixes = ('lm_head.', 'model.layers.', 'model.embed_tokens.', 'model.norm.') + if name.startswith(skip_prefixes): + return [] + + # Try to map the tensor using TensorNameMap (handles vision encoder and projector) + try: + new_name = self.map_tensor_name(name) + return [(new_name, data_torch)] + except ValueError: + # If mapping fails, log warning and skip + logger.warning(f"Cannot map tensor: {name}") + return [] ###### CONVERSION LOGIC ###### diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 4378378309f..78a2a0168ce 100755 --- 
a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -145,6 +145,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", }, {"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", }, {"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", }, + {"name": "utu-vl", "tokt": TOKENIZER_TYPE.BPE, "repo": "./Youtu-VL", }, ] # some models are known to be broken upstream, so we will skip them as exceptions diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 41d3bd4faf2..4c1741a12d2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -294,7 +294,9 @@ class ClipVision: USE_GELU = "clip.use_gelu" USE_SILU = "clip.use_silu" N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl + WA_LAYERS = "clip.vision.wa_layers" # used by utuvl IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers" + WINDOW_SIZE = "clip.vision.window_size" class Attention: HEAD_COUNT = "clip.vision.attention.head_count" @@ -3432,6 +3434,7 @@ class VisionProjectorType: JANUS_PRO = "janus_pro" LFM2A = "lfm2a" # audio GLM4V = "glm4v" + UTUVL = "utuvl" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 6a4a504f8dc..937550bb53f 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1130,10 +1130,16 @@ def add_vision_projector_scale_factor(self, value: int) -> None: def add_vision_n_wa_pattern(self, value: int) -> None: self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value) + + def add_vision_wa_layers(self, layers: Sequence[int]) -> None: + self.add_array(Keys.ClipVision.WA_LAYERS, layers) def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None: self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers) + def add_vision_window_size(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value) + # audio models def add_audio_projection_dim(self, value: int) -> None: diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 276720fcde9..c9eb4476341 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1218,6 +1218,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ: ( "multi_modal_projector.linear_{bid}", "visual.merger.mlp.{bid}", # qwen2vl + "merger.mlp.{bid}", ), MODEL_TENSOR.V_MMPROJ_FC: ( @@ -1255,6 +1256,7 @@ class TensorNameMap: "visual.patch_embed.proj", # qwen2vl "vision_tower.patch_embed.proj", # kimi-vl "model.vision.patch_embedding.proj", # cogvlm + "siglip2.vision_model.embeddings.patch_embedding", ), MODEL_TENSOR.V_ENC_EMBD_NORM: ( @@ -1288,6 +1290,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral "visual.blocks.{bid}.attn.q", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated + "siglip2.vision_model.encoder.layers.{bid}.self_attn.q_proj", # utuvl ), MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( @@ -1305,6 +1308,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral "visual.blocks.{bid}.attn.k", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated + "siglip2.vision_model.encoder.layers.{bid}.self_attn.k_proj", ), MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( @@ -1322,6 +1326,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral 
"visual.blocks.{bid}.attn.v", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated + "siglip2.vision_model.encoder.layers.{bid}.self_attn.v_proj", ), MODEL_TENSOR.V_ENC_INPUT_NORM: ( @@ -1336,6 +1341,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm1", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.layer_norm1", ), MODEL_TENSOR.V_ENC_ATTN_O: ( @@ -1351,6 +1357,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.proj", # qwen2vl "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # utuvl ), MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( @@ -1365,6 +1372,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm2", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.layer_norm2", ), MODEL_TENSOR.V_ENC_FFN_UP: ( @@ -1380,6 +1388,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.mlp.fc1", ), MODEL_TENSOR.V_ENC_FFN_GATE: ( @@ -1401,6 +1410,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.mlp.fc2", ), MODEL_TENSOR.V_LAYER_SCALE_1: ( @@ -1427,6 +1437,7 @@ class TensorNameMap: "visual.merger.ln_q", # qwen2vl "vision_tower.encoder.final_layernorm", # kimi-vl "visual.post_layernorm", # glm4v + "siglip2.vision_model.post_layernorm", ), MODEL_TENSOR.V_MM_POST_NORM: ( @@ -1443,6 +1454,7 @@ class TensorNameMap: "multi_modal_projector.pre_norm", "pre_mm_projector_norm", "model.vision.linear_proj.norm1", # cogvlm + "merger.ln_q", ), MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ( diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0d5bcc64fe5..1bb98bed09f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4699,7 +4699,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4762,7 +4766,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = 
layers[i]; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index cd4092ca077..b2e148d32f4 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -314,6 +314,12 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_UTU_VL: + regex_exprs = { + "[가-힣ㄱ-ㆎ]+|[!…“”‘’—:;,、-〿︰-﹏]+|[ㄅ-ㄯ]+|[一-龥぀-ゟ゠-ヿ]+", + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: regex_exprs = { "[\r\n]", @@ -1860,6 +1866,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "deepseek-v3") { pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; clean_spaces = false; + } else if ( + tokenizer_pre == "utu-vl") { + pre_type = LLAMA_VOCAB_PRE_TYPE_UTU_VL; + clean_spaces = false; + ignore_merges = true; } else if ( tokenizer_pre == "falcon") { pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON; diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 55f8f3923c9..19ae099a3fb 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -51,6 +51,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, + LLAMA_VOCAB_PRE_TYPE_UTU_VL = 43, }; struct LLM_KV; diff --git a/src/unicode.cpp b/src/unicode.cpp index bb44edfaddf..b47dcbe6198 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -964,6 +964,11 @@ std::vector unicode_regex_split(const std::string & text, const std { "\\p{P}", unicode_cpt_flags::PUNCTUATION }, { "\\p{M}", unicode_cpt_flags::ACCENT_MARK }, { "\\p{S}", unicode_cpt_flags::SYMBOL }, + { "\\p{Lu}", unicode_cpt_flags::LETTER }, // Uppercase letter + { "\\p{Ll}", unicode_cpt_flags::LETTER }, // Lowercase letter + { "\\p{Lt}", unicode_cpt_flags::LETTER }, // Titlecase letter + { "\\p{Lm}", unicode_cpt_flags::LETTER }, // Modifier letter + { "\\p{Lo}", unicode_cpt_flags::LETTER }, // Other letter }; static const std::map k_ucat_cpt = { @@ -1074,22 +1079,26 @@ std::vector unicode_regex_split(const std::string & text, const std continue; } - if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() && + // Match \p{...} Unicode properties of varying lengths + if (regex_expr[i + 0] == '\\' && i + 3 < regex_expr.size() && regex_expr[i + 1] == 'p' && - regex_expr[i + 2] == '{' && - regex_expr[i + 4] == '}') { - const std::string pat = regex_expr.substr(i, 5); - if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { - if (!inside) { - regex_expr_collapsed += '['; + regex_expr[i + 2] == '{') { + // Find the closing brace + size_t closing_brace = regex_expr.find('}', i + 3); + if (closing_brace != std::string::npos && closing_brace <= i + 10) { // reasonable limit + const std::string pat = regex_expr.substr(i, closing_brace - i + 1); + if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { + if (!inside) { + regex_expr_collapsed += '['; + } + regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); + regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); + if (!inside) { + regex_expr_collapsed += ']'; + } + i = closing_brace; + continue; } - 
regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); - regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); - if (!inside) { - regex_expr_collapsed += ']'; - } - i += 4; - continue; } } diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 317d5f19fd9..0e862994d67 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -27,6 +27,7 @@ add_library(mtmd models/qwen3vl.cpp models/siglip.cpp models/whisper-enc.cpp + models/utuvl.cpp ) set_target_properties(mtmd PROPERTIES diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index a0939865e3f..90f53f0cdb8 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -49,6 +49,7 @@ #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" #define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" +#define KEY_WIN_ATTN_LAYERS "clip.vision.wa_layers" #define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" #define KEY_MINICPMV_VERSION "clip.minicpmv_version" #define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" @@ -187,6 +188,7 @@ enum projector_type { PROJECTOR_TYPE_JANUS_PRO, PROJECTOR_TYPE_LFM2A, PROJECTOR_TYPE_GLM4V, + PROJECTOR_TYPE_UTUVL, PROJECTOR_TYPE_UNKNOWN, }; @@ -216,6 +218,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, { PROJECTOR_TYPE_LFM2A, "lfm2a"}, { PROJECTOR_TYPE_GLM4V, "glm4v"}, + { PROJECTOR_TYPE_UTUVL, "utuvl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index b4c31cdde6b..3b17e5a8e82 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -61,6 +61,7 @@ struct clip_hparams { std::unordered_set vision_feature_layer; int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; + std::unordered_set wa_layers; // window attention full layers // audio int32_t n_mel_bins = 0; // whisper preprocessor diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 3ba0823defb..637180f07ac 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -845,6 +845,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_UTUVL: + { + builder = std::make_unique(ctx, img); + } break; default: GGML_ABORT("missing cgraph builder"); } @@ -1158,6 +1162,23 @@ struct clip_model_loader { LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__); } } break; + case PROJECTOR_TYPE_UTUVL: + { + hparams.n_merge = 2; + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); + get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, true); + std::vector wa_layers_vec; + get_arr_int(KEY_WIN_ATTN_LAYERS, wa_layers_vec, true); + for (auto & layer : wa_layers_vec) { + hparams.wa_layers.insert(layer); + } + hparams.set_limit_image_tokens(1, 62500); + hparams.set_warmup_n_tokens(16*16); // avoid OOM on warmup + const int warn_min_pixels = 1 * hparams.n_merge * hparams.n_merge * hparams.patch_size * hparams.patch_size; + if (hparams.image_min_pixels < warn_min_pixels) { + LOG_WRN("%s: Youtu-VL models require at minimum 1 image tokens to function correctly on grounding tasks\n", __func__); + } + } break; case PROJECTOR_TYPE_GLM4V: { hparams.rope_theta = 10000.0f; @@ -1225,7 +1246,14 @@ struct clip_model_loader { LOG_INF("%s: has_llava_proj: %d\n", __func__, hparams.has_llava_projector); LOG_INF("%s: minicpmv_version: %d\n", __func__, 
hparams.minicpmv_version); LOG_INF("%s: n_merge: %d\n", __func__, hparams.n_merge); - LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); + LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); + if (!hparams.wa_layers.empty()) { + LOG_INF("%s: wa_layers: ", __func__); + for (auto & layer : hparams.wa_layers) { + LOG_INF("%d ", layer); + } + LOG_INF("\n"); + } if (hparams.image_min_pixels > 0) { LOG_INF("%s: image_min_pixels: %d%s\n", __func__, hparams.image_min_pixels, hparams.custom_image_min_tokens > 0 ? " (custom value)" : ""); } @@ -1493,6 +1521,14 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_UTUVL: + { + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); // merger.ln_q (RMS norm) + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); // merger.mlp.0 + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); // merger.mlp.2 + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + } break; case PROJECTOR_TYPE_GLM4V: { model.projection = get_tensor(TN_MM_PROJECTOR); @@ -2684,6 +2720,57 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // res_imgs->data[0] = *res; res_imgs->entries.push_back(std::move(img_f32)); } break; + case PROJECTOR_TYPE_UTUVL: + { + const int patch_size = params.patch_size; // typically 16 + const int merge_size = params.n_merge; // typically 2 + const int align_size = patch_size * merge_size; // 32 + + const int max_num_patches = params.image_max_pixels > 0 ? + params.image_max_pixels / (patch_size * patch_size) : 256; + + // Linear search for optimal scale to fit within max_num_patches + float scale = 1.0f; + int target_height = original_size.height; + int target_width = original_size.width; + + auto get_scaled_image_size = [align_size](float scale, int size) -> int { + float scaled_size = size * scale; + // Round up to nearest multiple of align_size + int aligned = static_cast(std::ceil(scaled_size / align_size)) * align_size; + // Ensure at least one patch + return std::max(align_size, aligned); + }; + + // Linear search with 0.02 step size + while (scale > 0.0f) { + target_height = get_scaled_image_size(scale, original_size.height); + target_width = get_scaled_image_size(scale, original_size.width); + + int num_patches_h = target_height / patch_size; + int num_patches_w = target_width / patch_size; + int num_patches = num_patches_h * num_patches_w; + + if (num_patches > max_num_patches) { + scale -= 0.02f; + } else { + break; + } + } + + clip_image_size new_size = {target_width, target_height}; + + // Resize the image + clip_image_u8 resized; + img_tool::resize(*img, resized, new_size, img_tool::RESIZE_ALGO_BILINEAR, false); + + // Normalize to float32 + clip_image_f32_ptr img_f32(clip_image_f32_init()); + normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std); + + // Add to results + res_imgs->entries.push_back(std::move(img_f32)); + } break; case PROJECTOR_TYPE_IDEFICS3: { @@ -2916,6 +3003,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_UTUVL: return (img->nx / params.patch_size) / 2; default: break; @@ -2931,6 +3019,7 @@ int clip_n_output_tokens_y(const struct 
clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_UTUVL: return (img->ny / params.patch_size) / 2; default: break; @@ -2991,6 +3080,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_UTUVL: { // dynamic size (2 conv, so double patch size) int x_patch = img->nx / (params.patch_size * 2); @@ -3117,7 +3207,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int pos_w = image_size_width / patch_size; const int pos_h = image_size_height / patch_size; - const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl auto get_inp_tensor = [&gf](const char * name) { ggml_tensor * inp = ggml_graph_get_tensor(gf, name); @@ -3266,9 +3355,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima set_input_i32("positions", positions); } break; case PROJECTOR_TYPE_QWEN25VL: + case PROJECTOR_TYPE_UTUVL: { // pw * ph = number of tokens output by ViT after apply patch merger // ipw * ipw = number of vision token been processed inside ViT + const bool use_window_attn = ctx->model.proj_type == PROJECTOR_TYPE_QWEN25VL ? hparams.n_wa_pattern > 0 : !hparams.wa_layers.empty(); const int merge_ratio = 2; const int pw = image_size_width / patch_size / merge_ratio; const int ph = image_size_height / patch_size / merge_ratio; @@ -3279,7 +3370,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima std::vector inv_idx(ph * pw); if (use_window_attn) { - const int attn_window_size = 112; + const int attn_window_size = hparams.attn_window_size > 0 ? 
hparams.attn_window_size : 112; const int grid_window = attn_window_size / patch_size / merge_ratio; int dst = 0; // [num_vision_tokens, num_vision_tokens] attention mask tensor @@ -3516,6 +3607,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_JANUS_PRO: + case PROJECTOR_TYPE_UTUVL: return ctx->model.mm_1_b->ne[0]; case PROJECTOR_TYPE_QWEN3VL: // main path + deepstack paths diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 8d6d4ef67be..8360c72d050 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -22,6 +22,11 @@ struct clip_graph_qwen3vl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_utuvl : clip_graph { + clip_graph_utuvl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + struct clip_graph_minicpmv : clip_graph { clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/models/utuvl.cpp b/tools/mtmd/models/utuvl.cpp new file mode 100644 index 00000000000..aa3c9857dc9 --- /dev/null +++ b/tools/mtmd/models/utuvl.cpp @@ -0,0 +1,179 @@ +#include "models.h" + +ggml_cgraph * clip_graph_utuvl::build() { + GGML_ASSERT(model.class_embedding == nullptr); + const int batch_size = 1; + const bool use_window_attn = !hparams.wa_layers.empty(); + const int n_pos = n_patches; + const int num_position_ids = n_pos * 4; + const int m = 2; + const int Wp = n_patches_x; + const int Hp = n_patches_y; + const int Hm = Hp / m; + const int Wm = Wp / m; + norm_type norm_t = NORM_TYPE_NORMAL; + + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + + ggml_tensor * inp = build_inp_raw(); + + // change conv3d to linear + // reshape and permute to get patches, permute from (patch_size, m, Wm, patch_size, m, Hm, C) to (C, patch_size, patch_size, m, m, Wm, Hm) + { + inp = ggml_reshape_4d( + ctx0, inp, + Wm * m * patch_size, m * patch_size, Hm, 3); + inp = ggml_permute(ctx0, inp, 1, 2, 3, 0); + inp = ggml_cont_4d( + ctx0, inp, + m * patch_size * 3, Wm, m * patch_size, Hm); + + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_4d( + ctx0, inp, + m * patch_size * 3, patch_size, m, Hm * Wm); + + inp = ggml_permute(ctx0, inp, 1, 0, 2, 3); + inp = ggml_cont_4d( + ctx0, inp, + patch_size, 3, patch_size, Hm * Wm * m * m); + + inp = ggml_permute(ctx0, inp, 2, 0, 1, 3); + inp = ggml_cont_3d( + ctx0, inp, + 3*patch_size* patch_size, Hm * Wm * m * m, 1); + } + inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp); + + if (model.patch_bias) { + inp = ggml_add(ctx0, inp, model.patch_bias); + } + + inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches); + + ggml_tensor * inpL = inp; + ggml_tensor * window_mask = nullptr; + ggml_tensor * window_idx = nullptr; + ggml_tensor * inv_window_idx = nullptr; + + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + // pre-layernorm + if (model.pre_ln_w) { + inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); + } + if (use_window_attn) { + inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4); + ggml_set_name(inv_window_idx, "inv_window_idx"); + ggml_set_input(inv_window_idx); + // mask for window attention + window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos); + ggml_set_name(window_mask, "window_mask"); + 
ggml_set_input(window_mask); + + // if flash attn is used, we need to pad the mask and cast to f16 + if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { + window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16); + } + + // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size] + GGML_ASSERT(batch_size == 1); + inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4); + inpL = ggml_get_rows(ctx0, inpL, inv_window_idx); + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size); + } + + // loop over layers + for (int il = 0; il < n_layer; il++) { + const auto & layer = model.layers[il]; + const bool full_attn = use_window_attn ? hparams.wa_layers.count(il) > 0 : true; + + ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states + + // layernorm1 + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il); + // self-attention + { + ggml_tensor * Qcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b); + ggml_tensor * Kcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b); + ggml_tensor * Vcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b); + + Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches); + Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches); + Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches); + + Qcur = ggml_rope_multi( + ctx0, Qcur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + Kcur = ggml_rope_multi( + ctx0, Kcur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + + ggml_tensor * attn_mask = full_attn ? nullptr : window_mask; + + cur = build_attn(layer.o_w, layer.o_b, + Qcur, Kcur, Vcur, attn_mask, kq_scale, il); + } + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, inpL); + + inpL = cur; // inpL = residual, cur = hidden_states + + // layernorm2 + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il); + + // ffn + cur = build_ffn(cur, + layer.ff_up_w, layer.ff_up_b, + nullptr, nullptr, + layer.ff_down_w, layer.ff_down_b, + hparams.ffn_op, il); + + // residual 2 + cur = ggml_add(ctx0, inpL, cur); + + inpL = cur; + } + + ggml_tensor * embeddings = inpL; + if (use_window_attn) { + const int spatial_merge_unit = 4; + window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / spatial_merge_unit); + ggml_set_name(window_idx, "window_idx"); + ggml_set_input(window_idx); + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * spatial_merge_unit, n_patches / spatial_merge_unit); + embeddings = ggml_get_rows(ctx0, embeddings, window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, n_patches, batch_size); + cb(embeddings, "window_order_restored", -1); + } + + // post-layernorm (part of Siglip2VisionTransformer, applied after encoder) + if (model.post_ln_w) { + embeddings = build_norm(embeddings, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer); + } + + // Now apply merger (VLPatchMerger): + // 1. Apply RMS norm (ln_q in VLPatchMerger) + embeddings = build_norm(embeddings, model.mm_input_norm_w, nullptr, NORM_TYPE_RMS, 1e-6, -1); + cb(embeddings, "merger_normed", -1); + + // 2. 
First reshape for spatial merge (merge 2x2 patches) + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size); + cb(embeddings, "merger_reshaped", -1); + + embeddings = build_ffn(embeddings, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, + model.mm_1_w, model.mm_1_b, + FFN_GELU, + -1); + ggml_build_forward_expand(gf, embeddings); + + return gf; +} diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index b9c4fa90980..3f8bf53454a 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -283,7 +283,7 @@ struct mtmd_context { // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md img_end = "[IMG_END]"; - } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL) { + } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL || proj == PROJECTOR_TYPE_UTUVL) { // <|vision_start|> ... (image embeddings) ... <|vision_end|> img_beg = "<|vision_start|>"; img_end = "<|vision_end|>"; diff --git a/vendor/minja/minja.hpp b/vendor/minja/minja.hpp index 873ece8c180..905fc47f8a6 100644 --- a/vendor/minja/minja.hpp +++ b/vendor/minja/minja.hpp @@ -1446,7 +1446,7 @@ struct ArgumentsExpression { static std::string strip(const std::string & s, const std::string & chars = "", bool left = true, bool right = true) { auto charset = chars.empty() ? " \t\n\r" : chars; auto start = left ? s.find_first_not_of(charset) : 0; - if (start == std::string::npos) return ""; + if (start == std::string::npos) return ""; auto end = right ? s.find_last_not_of(charset) : s.size() - 1; return s.substr(start, end - start + 1); } @@ -1464,6 +1464,20 @@ static std::vector split(const std::string & s, const std::string & return result; } +static std::vector rsplit(const std::string & s, const std::string & sep) { + std::vector result; + size_t end = s.length(); + size_t pos = s.rfind(sep); + while (pos != std::string::npos) { + result.insert(result.begin(), s.substr(pos + sep.length(), end - pos - sep.length())); + end = pos; + if (pos == 0) break; + pos = s.rfind(sep, pos - 1); + } + result.insert(result.begin(), s.substr(0, end)); + return result; +} + static std::string capitalize(const std::string & s) { if (s.empty()) return s; auto result = s; @@ -1573,6 +1587,15 @@ class MethodCallExpr : public Expression { result.push_back(Value(part)); } return result; + } else if (method->get_name() == "rsplit") { + vargs.expectArgs("rsplit method", {1, 1}, {0, 0}); + auto sep = vargs.args[0].get(); + auto parts = rsplit(str, sep); + Value result = Value::array(); + for (const auto& part : parts) { + result.push_back(Value(part)); + } + return result; } else if (method->get_name() == "capitalize") { vargs.expectArgs("capitalize method", {0, 0}, {0, 0}); return Value(capitalize(str));
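
A few self-contained sketches for reviewers who want to poke at the new pieces outside the tree. With `tie_word_embeddings` set, the converter now drops `lm_head.weight` and `llama_model::load_tensors()` falls back to duplicating `token_embd.weight` for the output head. A quick way to confirm a converted file actually took that path is to list its tensors; this assumes gguf-py's `GGUFReader` API, and the file name is made up:

```python
# Check that a tied-embedding UTU-VL conversion relies on the new fallback:
# token_embd.weight present, output.weight absent (the loader duplicates it).
from gguf import GGUFReader

reader = GGUFReader("Youtu-VL-Q8_0.gguf")     # hypothetical file name
names = {t.name for t in reader.tensors}
print("token_embd.weight" in names)           # expected: True
print("output.weight" in names)               # expected: False
```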
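llama-vocab.cpp registers the new `utu-vl` pre-tokenizer as two regex passes: a first pass that isolates Hangul, CJK punctuation, bopomofo and CJK ideograph runs, and a second GPT-2-style pass over everything else. The passes can be replayed with Python's third-party `regex` package (which understands `\p{...}`) to eyeball the splits; `split_keep` and `pretokenize` are illustrative helpers that only approximate `unicode_regex_split()`, they are not part of the patch:

```python
import regex  # pip install regex -- the stdlib re module has no \p{...} support

# The two passes registered for LLAMA_VOCAB_PRE_TYPE_UTU_VL in llama-vocab.cpp.
PASS_CJK = r"[가-힣ㄱ-ㆎ]+|[!…“”‘’—:;,、-〿︰-﹏]+|[ㄅ-ㄯ]+|[一-龥぀-ゟ゠-ヿ]+"
PASS_GPT = (
    r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+"
    r"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?"
    r"|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*"
    r"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?"
    r"|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+"
)

def split_keep(pattern: str, text: str) -> list[str]:
    """Split text at the matches of pattern, keeping both the matches and the
    gaps between them (a rough stand-in for one pass of unicode_regex_split())."""
    pieces, last = [], 0
    for m in regex.finditer(pattern, text):
        if m.start() > last:
            pieces.append(text[last:m.start()])
        pieces.append(m.group())
        last = m.end()
    if last < len(text):
        pieces.append(text[last:])
    return pieces

def pretokenize(text: str) -> list[str]:
    """Apply both passes in order, the way llm_tokenizer_bpe walks regex_exprs."""
    words = [text]
    for pat in (PASS_CJK, PASS_GPT):
        words = [w for piece in words for w in split_keep(pat, piece)]
    return words

print(pretokenize("Youtu-VL understands 中文 and 한국어, doesn't it?"))
```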
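That second pass is also why unicode.cpp changes: the old scanner only recognised five-character properties written exactly as `\p{X}`, so `\p{Lu}`, `\p{Ll}`, `\p{Lt}`, `\p{Lm}` and `\p{Lo}` would have been passed through unresolved. The patch searches for the closing brace (capped at 10 characters) and maps the new two-letter classes onto the existing LETTER flag. A minimal sketch of that widened scan, ignoring the inside-a-bracket-class bookkeeping the real code does and using the bucket letter as a stand-in for the private codepoint it actually substitutes:

```python
# Buckets mirrored from k_ucat_enum after this patch: the new Lu/Ll/Lt/Lm/Lo
# entries all collapse onto the same LETTER bucket as plain \p{L}.
UCAT = {
    r"\p{L}": "L", r"\p{N}": "N", r"\p{P}": "P", r"\p{M}": "M", r"\p{S}": "S",
    r"\p{Lu}": "L", r"\p{Ll}": "L", r"\p{Lt}": "L", r"\p{Lm}": "L", r"\p{Lo}": "L",
}

def collapse_properties(expr: str) -> str:
    r"""Widened scan from unicode_regex_split(): on "\p{", look for the closing
    brace within 10 characters and replace a known property with a collapsed
    single-character class."""
    out, i = [], 0
    while i < len(expr):
        if expr.startswith(r"\p{", i):
            close = expr.find("}", i + 3)
            if close != -1 and close <= i + 10 and expr[i:close + 1] in UCAT:
                out.append("[" + UCAT[expr[i:close + 1]] + "]")
                i = close + 1
                continue
        out.append(expr[i])
        i += 1
    return "".join(out)

print(collapse_properties(r"\p{Lu}\p{Ll}+|\p{N}"))   # -> [L][L]+|[N]
```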
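On the vision side, the new `PROJECTOR_TYPE_UTUVL` branch in `clip_image_preprocess()` picks the output resolution by walking a scale factor down in 0.02 steps until the image, with each side rounded up to a multiple of `patch_size * merge_size`, fits the patch budget. The same search in Python, useful for predicting how large a given photo ends up; the 256-patch default mirrors the fallback used when `image_max_pixels` is unset, real models will normally supply `image_max_pixels // patch_size**2` instead:

```python
import math

def utuvl_target_size(width: int, height: int, patch_size: int = 16,
                      merge_size: int = 2, max_num_patches: int = 256) -> tuple[int, int]:
    """Mirror of the PROJECTOR_TYPE_UTUVL branch of clip_image_preprocess():
    shrink the scale in 0.02 steps until the aligned image fits the patch
    budget; each side is rounded up to a multiple of patch_size * merge_size
    and never drops below a single merged patch."""
    align = patch_size * merge_size   # 32 with the defaults above

    def aligned(scale: float, size: int) -> int:
        return max(align, math.ceil(size * scale / align) * align)

    scale = 1.0
    tw, th = aligned(scale, width), aligned(scale, height)
    while scale > 0.0:
        tw, th = aligned(scale, width), aligned(scale, height)
        if (tw // patch_size) * (th // patch_size) > max_num_patches:
            scale -= 0.02
        else:
            break
    return tw, th

print(utuvl_target_size(4032, 3024))   # a typical phone photo
```

Each returned side divided by `patch_size * 2` gives the grid that `clip_n_output_tokens()` multiplies to get the number of image embeddings handed to the language model.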
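For the attention layout, qwen2.5vl keeps its periodic `n_wa_pattern`, while utuvl stores its full-attention layer indices explicitly in the new `clip.vision.wa_layers` array; inside `clip_graph_utuvl::build()` a layer gets the window mask exactly when its index is absent from that set, and no masking at all when the set is empty. A small sketch of that per-layer decision plus the window grid derived in `clip_image_batch_encode()` (the 112-pixel fallback comes straight from the patched default):

```python
def utuvl_window_layout(il: int, wa_layers: set[int], attn_window_size: int,
                        patch_size: int = 16, merge_ratio: int = 2) -> tuple[bool, int]:
    """full_attn mirrors `use_window_attn ? wa_layers.count(il) > 0 : true`;
    grid_window is the number of merged patches per window side used when
    building the window attention mask."""
    full_attn = (il in wa_layers) if wa_layers else True
    window = attn_window_size if attn_window_size > 0 else 112
    grid_window = window // patch_size // merge_ratio
    return full_attn, grid_window

print(utuvl_window_layout(5, {7, 15, 23, 31}, 112))   # -> (False, 3)
```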
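Finally, minja gains an `rsplit` string method. Since it takes only a separator and no count, scanning from the right yields the same pieces as a plain split, which is also how Python behaves; a one-line check of the intended semantics:

```python
# Without a max-split count, rsplit and split agree on the resulting pieces.
s = "models/utuvl/weights.bin"
assert s.rsplit("/") == s.split("/") == ["models", "utuvl", "weights.bin"]
```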