Commit 867709c

fix bug
1 parent 1600974 commit 867709c

10 files changed: +97 -68 lines

convert_hf_to_gguf.py

Lines changed: 24 additions & 29 deletions
@@ -1173,8 +1173,6 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5":
             # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3
             res = "deepseek-v3"
-        if chkhsh == "9d70134b369a70e5735009b6de918f7581b5211f7c074d1f89f753aea8248af1":
-            res = "utu-vl"
         if chkhsh == "b3f499bb4255f8ca19fccd664443283318f2fd2414d5e0b040fbdd0cc195d6c5":
             # ref: https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
             res = "deepseek-r1-qwen"
@@ -1232,6 +1230,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665":
             # ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer
             res = "kormo"
+        if chkhsh == "9d70134b369a70e5735009b6de918f7581b5211f7c074d1f89f753aea8248af1":
+            # ref: ./Youtu-VL
+            res = "utu-vl"
 
         if res is None:
             logger.warning("\n")
@@ -3808,15 +3809,10 @@ def set_gguf_parameters(self):
             else:
                 self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL)
                 self.gguf_writer.add_vision_use_silu(True)
-                # find n_wa_pattern (window attention pattern)
+                # save window attention layers (full attention block indexes)
                 fullatt_block_indexes = hparams.get("fullatt_block_indexes")
                 assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for qwen2_5_vl"
-                n_wa_pattern = fullatt_block_indexes[0] + 1
-                # validate n_wa_pattern
-                for i in range(1, len(fullatt_block_indexes)):
-                    if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern:
-                        raise ValueError(f"Invalid fullatt_block_indexes: {fullatt_block_indexes}")
-                self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern)
+                self.gguf_writer.add_vision_wa_layers(fullatt_block_indexes)
         else:
             raise ValueError(f"Unknown QwenVL model type: {self.global_config['model_type']}")
         # default values below are taken from HF tranformers code
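
The old code collapsed fullatt_block_indexes into a single stride (n_wa_pattern) and rejected irregular spacings; the new code writes the indexes verbatim. A minimal sketch, using hypothetical block indexes, contrasting the two encodings:

fullatt_block_indexes = [7, 15, 23, 31]  # hypothetical: blocks that use full attention

# old scheme: derive one stride, valid only for evenly spaced indexes
n_wa_pattern = fullatt_block_indexes[0] + 1  # -> 8
for i in range(1, len(fullatt_block_indexes)):
    if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern:
        raise ValueError(f"Invalid fullatt_block_indexes: {fullatt_block_indexes}")

# new scheme: store the list as-is, so irregular spacing is also representable
print(n_wa_pattern)           # 8
print(fullatt_block_indexes)  # [7, 15, 23, 31]
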
@@ -7214,26 +7210,26 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
         self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
 
-        if hparams.get("moe_intermediate_size") is not None:
-            self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+        if (moe_intermediate_size := hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
         else:
             self.gguf_writer.add_expert_feed_forward_length(hparams.get("intermediate_size", 0))
 
-        if hparams.get("n_routed_experts") is not None:
-            self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
+        if (n_routed_experts := hparams.get("n_routed_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_routed_experts)
 
-        if hparams.get("n_shared_experts") is not None:
-            self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
+        if (n_shared_experts := hparams.get("n_shared_experts")) is not None:
+            self.gguf_writer.add_expert_shared_count(n_shared_experts)
         else:
             self.gguf_writer.add_expert_shared_count(0)
 
-        if hparams.get("routed_scaling_factor") is not None:
-            self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+        if (routed_scaling_factor := hparams.get("routed_scaling_factor")) is not None:
+            self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
         else:
             self.gguf_writer.add_expert_weights_scale(1.0)
 
-        if hparams.get("norm_topk_prob") is not None and hparams["norm_topk_prob"]:
-            self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+        if (norm_topk_prob := hparams.get("norm_topk_prob")) is not None and norm_topk_prob:
+            self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
 
         self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
 
@@ -7244,7 +7240,6 @@ def set_gguf_parameters(self):
             self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_mscale_all)
 
     _experts: list[dict[str, Tensor]] | None = None
-    _token_embd: Tensor | None = None
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # skip vision tensors and remove "language_model." for Kimi-VL
@@ -7257,11 +7252,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         # skip lm_head.weight if tie_word_embeddings is True
         if self.hparams.get("tie_word_embeddings", False):
-            # Save token_embd for potential duplication as output if tie_word_embeddings is True
-            if name == "model.embed_tokens.weight":
-                self._token_embd = data_torch
             if name == "lm_head.weight" or name == "model.lm_head.weight":
-                logger.info("Skipping tied output layer 'lm_head.weight' - will duplicate from token_embd.weight")
+                logger.info("Skipping tied output layer 'lm_head.weight' (will use token_embd.weight)")
                 return []
 
         # rename e_score_correction_bias tensors
@@ -7337,10 +7329,6 @@ def prepare_tensors(self):
             experts = [k for d in self._experts for k in d.keys()]
             if len(experts) > 0:
                 raise ValueError(f"Unprocessed experts: {experts}")
-        if self._token_embd is not None:
-            logger.info("Model has tie_word_embeddings=True but no lm_head.weight found - adding output.weight from token_embd.weight")
-            output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
-            self.gguf_writer.add_tensor(output_name, self._token_embd.numpy())
 
 @ModelBase.register("MiniMaxM2ForCausalLM")
 class MiniMaxM2Model(TextModel):
@@ -10521,7 +10509,14 @@ def set_gguf_parameters(self):
             raise ValueError(f"Unsupported activation function for UTUVL: {hidden_act}")
 
         self.gguf_writer.add_vision_spatial_merge_size(self.hparams.get("spatial_merge_size", 2))
-
+
+        window_size = self.hparams.get("window_size")
+        if window_size is not None:
+            self.gguf_writer.add_vision_window_size(window_size)
+        fullatt_block_indexes = self.hparams.get("fullatt_block_indexes")
+        assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for utuvl"
+        self.gguf_writer.add_vision_wa_layers(layers=fullatt_block_indexes)
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid # unused
 

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
@@ -145,6 +145,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
     {"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
     {"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", },
+    {"name": "utu-vl", "tokt": TOKENIZER_TYPE.BPE, "repo": "./Youtu-VL", },
 ]
 
 # some models are known to be broken upstream, so we will skip them as exceptions

gguf-py/gguf/constants.py

Lines changed: 2 additions & 1 deletion
@@ -293,8 +293,9 @@ class ClipVision:
         SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size"
         USE_GELU = "clip.use_gelu"
         USE_SILU = "clip.use_silu"
-        N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl
+        WA_LAYERS = "clip.vision.wa_layers" # used by qwen2.5vl and utuvl
         IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers"
+        WINDOW_SIZE = "clip.vision.window_size"
 
         class Attention:
             HEAD_COUNT = "clip.vision.attention.head_count"

gguf-py/gguf/gguf_writer.py

Lines changed: 5 additions & 2 deletions
@@ -1128,12 +1128,15 @@ def add_vision_use_silu(self, value: bool) -> None:
     def add_vision_projector_scale_factor(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
 
-    def add_vision_n_wa_pattern(self, value: int) -> None:
-        self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
+    def add_vision_wa_layers(self, layers: Sequence[int]) -> None:
+        self.add_array(Keys.ClipVision.WA_LAYERS, layers)
 
     def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None:
        self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers)
 
+    def add_vision_window_size(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value)
+
     # audio models
 
     def add_audio_projection_dim(self, value: int) -> None:
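
A minimal gguf-py sketch of how a converter could call the two new writer methods; the output path, architecture string, and values here are hypothetical:

import gguf

# hypothetical output file and values; only the two new keys are written
writer = gguf.GGUFWriter("vision-encoder.gguf", "clip")
writer.add_vision_wa_layers([7, 15, 23, 31])  # -> clip.vision.wa_layers (array)
writer.add_vision_window_size(112)            # -> clip.vision.window_size (uint32)

writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()
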

src/llama-model.cpp

Lines changed: 10 additions & 2 deletions
@@ -4699,7 +4699,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+                    // try to load output.weight, if not found, use token_embd (tied embeddings)
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    if (!output) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
 
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
@@ -4762,7 +4766,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+                    // try to load output.weight, if not found, use token_embd (tied embeddings)
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    if (!output) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
 
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
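
With the conversion-time duplication removed, a tied-embedding GGUF simply omits output.weight and the loader falls back to token_embd.weight (TENSOR_DUPLICATED). A gguf-py sketch for checking a converted file; the path is hypothetical:

from gguf import GGUFReader

reader = GGUFReader("kimi-vl-a3b.gguf")  # hypothetical path to a converted model
names = {t.name for t in reader.tensors}

print("token_embd.weight" in names)  # expected: True
print("output.weight" in names)      # expected: False when tie_word_embeddings is set
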

tools/mtmd/clip-impl.h

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@
 #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
 #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
-#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
+#define KEY_WIN_ATTN_LAYERS "clip.vision.wa_layers"
 #define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
 #define KEY_MINICPMV_VERSION "clip.minicpmv_version"
 #define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num"

tools/mtmd/clip-model.h

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ struct clip_hparams {
     int32_t image_crop_resolution;
     std::unordered_set<int32_t> vision_feature_layer;
     int32_t attn_window_size = 0;
-    int32_t n_wa_pattern = 0;
+    std::unordered_set<int32_t> wa_layers; // window attention full layers
 
     // audio
     int32_t n_mel_bins = 0; // whisper preprocessor

tools/mtmd/clip.cpp

Lines changed: 23 additions & 4 deletions
@@ -1151,7 +1151,14 @@ struct clip_model_loader {
                 {
                     hparams.n_merge = 2; // default value for Qwen 2 and 2.5
                     get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
-                    get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern, model.proj_type == PROJECTOR_TYPE_QWEN25VL); // only 2.5 requires it
+                    // load window attention layers (only 2.5 requires it)
+                    if (model.proj_type == PROJECTOR_TYPE_QWEN25VL) {
+                        std::vector<int> wa_layers_vec;
+                        get_arr_int(KEY_WIN_ATTN_LAYERS, wa_layers_vec, true);
+                        for (auto & layer : wa_layers_vec) {
+                            hparams.wa_layers.insert(layer);
+                        }
+                    }
                     // ref: https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct/blob/main/preprocessor_config.json
                     hparams.set_limit_image_tokens(8, 4096);
                     hparams.set_warmup_n_tokens(46*46); // avoid OOM on warmup
@@ -1166,6 +1173,12 @@ struct clip_model_loader {
                 {
                     hparams.n_merge = 2;
                     get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
+                    get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, true);
+                    std::vector<int> wa_layers_vec;
+                    get_arr_int(KEY_WIN_ATTN_LAYERS, wa_layers_vec, true);
+                    for (auto & layer : wa_layers_vec) {
+                        hparams.wa_layers.insert(layer);
+                    }
                     hparams.set_limit_image_tokens(8, 4096);
                     hparams.set_warmup_n_tokens(46*46); // avoid OOM on warmup
                     const int warn_min_pixels = 1024 * hparams.n_merge * hparams.n_merge * hparams.patch_size * hparams.patch_size;
@@ -1240,7 +1253,13 @@ struct clip_model_loader {
         LOG_INF("%s: has_llava_proj: %d\n", __func__, hparams.has_llava_projector);
         LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version);
         LOG_INF("%s: n_merge: %d\n", __func__, hparams.n_merge);
-        LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
+        if (!hparams.wa_layers.empty()) {
+            LOG_INF("%s: wa_layers: ", __func__);
+            for (auto & layer : hparams.wa_layers) {
+                LOG_INF("%d ", layer);
+            }
+            LOG_INF("\n");
+        }
         if (hparams.image_min_pixels > 0) {
             LOG_INF("%s: image_min_pixels: %d%s\n", __func__, hparams.image_min_pixels, hparams.custom_image_min_tokens > 0 ? " (custom value)" : "");
         }
@@ -3346,7 +3365,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             {
                 // pw * ph = number of tokens output by ViT after apply patch merger
                 // ipw * ipw = number of vision token been processed inside ViT
-                const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
+                const bool use_window_attn = !hparams.wa_layers.empty(); // for qwen2.5vl
                 const int merge_ratio = 2;
                 const int pw = image_size_width / patch_size / merge_ratio;
                 const int ph = image_size_height / patch_size / merge_ratio;
@@ -3357,7 +3376,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             std::vector<int> inv_idx(ph * pw);
 
             if (use_window_attn) {
-                const int attn_window_size = ctx->model.proj_type == PROJECTOR_TYPE_QWEN25VL ? 112 : patch_size * 2 * 8;
+                const int attn_window_size = hparams.attn_window_size > 0 ? hparams.attn_window_size : 112;
                 const int grid_window = attn_window_size / patch_size / merge_ratio;
                 int dst = 0;
                 // [num_vision_tokens, num_vision_tokens] attention mask tensor
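
The window grid is derived from the stored window size; a small worked example with values typical for Qwen2.5-VL (112-pixel windows, 14-pixel patches), used here purely as an illustration:

attn_window_size = 112  # clip.vision.window_size, falling back to 112 when absent
patch_size = 14
merge_ratio = 2

grid_window = attn_window_size // patch_size // merge_ratio
print(grid_window)  # 4 -> each attention window covers a 4x4 grid of merged patches
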

tools/mtmd/models/qwen2vl.cpp

Lines changed: 2 additions & 3 deletions
@@ -5,8 +5,7 @@ ggml_cgraph * clip_graph_qwen2vl::build() {
     GGML_ASSERT(model.class_embedding == nullptr);
 
     const int batch_size = 1;
-    const bool use_window_attn = hparams.n_wa_pattern > 0;
-    const int n_wa_pattern = hparams.n_wa_pattern;
+    const bool use_window_attn = !hparams.wa_layers.empty();
     const int n_pos = n_patches;
     const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position
 
@@ -79,7 +78,7 @@ ggml_cgraph * clip_graph_qwen2vl::build() {
     // loop over layers
     for (int il = 0; il < n_layer; il++) {
         const auto & layer = model.layers[il];
-        const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
+        const bool full_attn = use_window_attn ? hparams.wa_layers.count(il) > 0 : true;
 
         ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
 
tools/mtmd/models/utuvl.cpp

Lines changed: 28 additions & 25 deletions
@@ -3,7 +3,7 @@
 ggml_cgraph * clip_graph_utuvl::build() {
     GGML_ASSERT(model.class_embedding == nullptr);
     const int batch_size = 1;
-    const bool use_window_attn = true;
+    const bool use_window_attn = !hparams.wa_layers.empty();
     const int n_pos = n_patches;
     const int num_position_ids = n_pos * 4;
     const int m = 2;
@@ -17,29 +17,32 @@ ggml_cgraph * clip_graph_utuvl::build() {
 
     ggml_tensor * inp = build_inp_raw();
 
-    inp = ggml_reshape_4d(
-        ctx0, inp,
-        Wm * m * patch_size, m * patch_size, Hm, 3);
-    inp = ggml_permute(ctx0, inp, 1, 2, 3, 0);
-    inp = ggml_cont_4d(
-        ctx0, inp,
-        m * patch_size * 3, Wm, m * patch_size, Hm);
-
-    inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
-    inp = ggml_cont_4d(
-        ctx0, inp,
-        m * patch_size * 3, patch_size, m, Hm * Wm);
-
-    inp = ggml_permute(ctx0, inp, 1, 0, 2, 3);
-    inp = ggml_cont_4d(
-        ctx0, inp,
-        patch_size, 3, patch_size, Hm * Wm * m * m);
-
-    inp = ggml_permute(ctx0, inp, 2, 0, 1, 3);
-    inp = ggml_cont_3d(
-        ctx0, inp,
-        3*patch_size* patch_size, Hm * Wm * m * m, 1);
-
+    // change conv3d to linear
+    // reshape and permute to get patches, permute from (patch_size, m, Wm, patch_size, m, Hm, C) to (C, patch_size, patch_size, m, m, Wm, Hm)
+    {
+        inp = ggml_reshape_4d(
+            ctx0, inp,
+            Wm * m * patch_size, m * patch_size, Hm, 3);
+        inp = ggml_permute(ctx0, inp, 1, 2, 3, 0);
+        inp = ggml_cont_4d(
+            ctx0, inp,
+            m * patch_size * 3, Wm, m * patch_size, Hm);
+
+        inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
+        inp = ggml_cont_4d(
+            ctx0, inp,
+            m * patch_size * 3, patch_size, m, Hm * Wm);
+
+        inp = ggml_permute(ctx0, inp, 1, 0, 2, 3);
+        inp = ggml_cont_4d(
+            ctx0, inp,
+            patch_size, 3, patch_size, Hm * Wm * m * m);
+
+        inp = ggml_permute(ctx0, inp, 2, 0, 1, 3);
+        inp = ggml_cont_3d(
+            ctx0, inp,
+            3*patch_size* patch_size, Hm * Wm * m * m, 1);
+    }
     inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
 
     if (model.patch_bias) {
@@ -85,7 +88,7 @@ ggml_cgraph * clip_graph_utuvl::build() {
     // loop over layers
     for (int il = 0; il < n_layer; il++) {
         const auto & layer = model.layers[il];
-        const bool full_attn = (il + 1) % 8 == 0 || il == n_layer - 1;
+        const bool full_attn = use_window_attn ? hparams.wa_layers.count(il) > 0 : true;
 
         ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
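
The hard-coded rule ((il + 1) % 8 == 0 or last layer) is replaced by a lookup into the metadata-driven set; a sketch with illustrative values showing when the two selections agree:

n_layer = 24
wa_layers = {7, 15, 23}  # hypothetical contents of clip.vision.wa_layers

old_full_attn = [il for il in range(n_layer) if (il + 1) % 8 == 0 or il == n_layer - 1]
new_full_attn = [il for il in range(n_layer) if il in wa_layers]

print(old_full_attn)  # [7, 15, 23]
print(new_full_attn)  # [7, 15, 23] -- identical only when the metadata matches the old pattern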
