@@ -23,8 +23,8 @@ def __call__(self, mod: torch.optim.Optimizer):
2323 idx = 0
2424 for group in mod .param_groups :
2525 for p in group ["params" ]:
26- p .data .add (functools .reduce (lambda x , f : f (mod , x , idx ), self .operands , p ),
27- alpha = - group ["lr" ] * group ["weight_decay" ])
26+ p .data .add_ (functools .reduce (lambda x , f : f (mod , x , idx ), self .operands , p ),
27+ alpha = - group ["lr" ] * group ["weight_decay" ])
2828 idx += 1
2929
3030
@@ -354,9 +354,8 @@ class TGAdamW(TrueGrad):
354354
355355 def __init__ (self , params , lr : float = 1e-3 ,
356356 betas : Union [Tuple [float , float ], Tuple [float , float , float ]] = (0.9 , 0.999 , 0.999 ),
357- eps : float = 1e-12 , weight_decay : float = 1e-2 , graft : bool = True ,
358- default_to_baseline : bool = None , enforce_baseline : bool = False ,
359- weight_decay_cls : Optional [WeightDecayChain ] = None ):
357+ eps : float = 1e-12 , weight_decay : float = 1e-2 , graft : bool = True , default_to_baseline : bool = None ,
358+ enforce_baseline : bool = False , weight_decay_cls : Optional [WeightDecayChain ] = None ):
360359 if default_to_baseline is None :
361360 default_to_baseline = False
362361 super ().__init__ (params , lr = lr , betas = betas , eps = eps , weight_decay = weight_decay , graft = graft ,
@@ -428,9 +427,8 @@ class TGRMSProp(TrueGrad):
428427 base_statistics : List [str ] = ["exp_avg_sq" ]
429428
430429 def __init__ (self , params , lr : float = 1e-3 , betas : Union [float , Tuple [float ], Tuple [float , float ]] = (0.9 ,),
431- eps : float = 1e-12 , weight_decay : float = 1e-2 , graft : bool = True ,
432- default_to_baseline : bool = False , enforce_baseline : bool = False ,
433- weight_decay_cls : Optional [WeightDecayChain ] = None ):
430+ eps : float = 1e-12 , weight_decay : float = 1e-2 , graft : bool = True , default_to_baseline : bool = False ,
431+ enforce_baseline : bool = False , weight_decay_cls : Optional [WeightDecayChain ] = None ):
434432 super ().__init__ (params , lr = lr , betas = betas , eps = eps , weight_decay = weight_decay , graft = graft ,
435433 default_to_baseline = default_to_baseline , enforce_baseline = enforce_baseline ,
436434 weight_decay_cls = weight_decay_cls )
0 commit comments