 
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
 import modules.textual_inversion.textual_inversion as textual_inversion
+import modules.models.sd3.mmdit
 
 from lora_logger import logger
 
@@ -166,12 +167,26 @@ def load_network(name, network_on_disk):
 
     keys_failed_to_match = {}
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
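+    # models that expose a diffusers-to-network key mapping (e.g. SD3) build it once and cache it on the model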
+    if hasattr(shared.sd_model, 'diffusers_weight_map'):
+        diffusers_weight_map = shared.sd_model.diffusers_weight_map
+    elif hasattr(shared.sd_model, 'diffusers_weight_mapping'):
+        diffusers_weight_map = {}
+        for k, v in shared.sd_model.diffusers_weight_mapping():
+            diffusers_weight_map[k] = v
+        shared.sd_model.diffusers_weight_map = diffusers_weight_map
+    else:
+        diffusers_weight_map = None
 
     matched_networks = {}
     bundle_embeddings = {}
 
     for key_network, weight in sd.items():
-        key_network_without_network_parts, _, network_part = key_network.partition(".")
+
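+        # with a diffusers mapping the key keeps its full dotted module path; only the last two components (module name and weight name) form the network part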
+        if diffusers_weight_map:
+            key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
+            network_part = network_name + '.' + network_weight
+        else:
+            key_network_without_network_parts, _, network_part = key_network.partition(".")
 
         if key_network_without_network_parts == "bundle_emb":
             emb_name, vec_name = network_part.split(".", 1)
@@ -183,7 +198,11 @@ def load_network(name, network_on_disk):
                 emb_dict[vec_name] = weight
             bundle_embeddings[emb_name] = emb_dict
 
-        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+        if diffusers_weight_map:
+            key = diffusers_weight_map.get(key_network_without_network_parts, key_network_without_network_parts)
+        else:
+            key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+
         sd_module = shared.sd_model.network_layer_mapping.get(key, None)
 
         if sd_module is None:
@@ -347,6 +366,28 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
     purge_networks_from_memory()
 
 
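+# a LayerNorm created with elementwise_affine=False has no weight or bias parameters, so it is allowed to have no backup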
+def allowed_layer_without_weight(layer):
+    if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
+        return True
+
+    return False
+
+
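+# copy a weight tensor to CPU for backup; None (e.g. a missing bias) passes through unchanged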
+def store_weights_backup(weight):
+    if weight is None:
+        return None
+
+    return weight.to(devices.cpu, copy=True)
+
+
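+# restore a backed-up tensor into obj.field in place; a None backup means the field (typically a bias) was absent and is reset to None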
+def restore_weights_backup(obj, field, weight):
+    if weight is None:
+        setattr(obj, field, None)
+        return
+
+    getattr(obj, field).copy_(weight)
+
+
 def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     weights_backup = getattr(self, "network_weights_backup", None)
     bias_backup = getattr(self, "network_bias_backup", None)
@@ -356,21 +397,15 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
 
     if weights_backup is not None:
         if isinstance(self, torch.nn.MultiheadAttention):
-            self.in_proj_weight.copy_(weights_backup[0])
-            self.out_proj.weight.copy_(weights_backup[1])
+            restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
+            restore_weights_backup(self.out_proj, 'weight', weights_backup[1])
         else:
-            self.weight.copy_(weights_backup)
+            restore_weights_backup(self, 'weight', weights_backup)
 
-    if bias_backup is not None:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias.copy_(bias_backup)
-        else:
-            self.bias.copy_(bias_backup)
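+    # the bias branch needs no None check: restore_weights_backup() clears the attribute when the backup is None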
+    if isinstance(self, torch.nn.MultiheadAttention):
+        restore_weights_backup(self.out_proj, 'bias', bias_backup)
     else:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias = None
-        else:
-            self.bias = None
+        restore_weights_backup(self, 'bias', bias_backup)
 
 
 def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
@@ -389,37 +424,38 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
 
     weights_backup = getattr(self, "network_weights_backup", None)
     if weights_backup is None and wanted_names != ():
-        if current_names != ():
-            raise RuntimeError("no backup weights found and current weights are not unchanged")
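+        # layers that never have a weight tensor (see allowed_layer_without_weight) are exempt from this check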
+        if current_names != () and not allowed_layer_without_weight(self):
+            raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")
 
         if isinstance(self, torch.nn.MultiheadAttention):
-            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
+            weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
         else:
-            weights_backup = self.weight.to(devices.cpu, copy=True)
+            weights_backup = store_weights_backup(self.weight)
 
         self.network_weights_backup = weights_backup
 
     bias_backup = getattr(self, "network_bias_backup", None)
     if bias_backup is None and wanted_names != ():
         if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
-            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+            bias_backup = store_weights_backup(self.out_proj.bias)
         elif getattr(self, 'bias', None) is not None:
-            bias_backup = self.bias.to(devices.cpu, copy=True)
+            bias_backup = store_weights_backup(self.bias)
         else:
             bias_backup = None
 
         # Unlike weight which always has value, some modules don't have bias.
         # Only report if bias is not None and current bias are not unchanged.
         if bias_backup is not None and current_names != ():
             raise RuntimeError("no backup bias found and current bias are not unchanged")
+
         self.network_bias_backup = bias_backup
 
     if current_names != wanted_names:
         network_restore_weights_from_backup(self)
 
     for net in loaded_networks:
         module = net.modules.get(network_layer_name, None)
-        if module is not None and hasattr(self, 'weight'):
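+        # QkvLinear modules are excluded from the generic path; they are applied to the fused qkv weight further below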
+        if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
             try:
                 with torch.no_grad():
                     if getattr(self, 'fp16_weight', None) is None:
@@ -479,6 +515,24 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
 
             continue
 
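+        # for SD3's fused qkv projection, split the weight into q/k/v thirds, compute each LoRA delta against its slice, then stack the deltas and add them back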
+        if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
+            try:
+                with torch.no_grad():
+                    # Send "real" orig_weight into MHA's lora module
+                    qw, kw, vw = self.weight.chunk(3, 0)
+                    updown_q, _ = module_q.calc_updown(qw)
+                    updown_k, _ = module_k.calc_updown(kw)
+                    updown_v, _ = module_v.calc_updown(vw)
+                    del qw, kw, vw
+                    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+                    self.weight += updown_qkv
+
+            except RuntimeError as e:
+                logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+                extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
+            continue
+
         if module is None:
             continue
 