Skip to content

Commit d9b50a8

Browse files
committed
fix(optim/decay): inplace op
1 parent c39bcc9 commit d9b50a8

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

truegrad/optim.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def __call__(self, mod: torch.optim.Optimizer):
2323
idx = 0
2424
for group in mod.param_groups:
2525
for p in group["params"]:
26-
p.data.add(functools.reduce(lambda x, f: f(mod, x, idx), self.operands, p),
27-
alpha=-group["lr"] * group["weight_decay"])
26+
p.data.add_(functools.reduce(lambda x, f: f(mod, x, idx), self.operands, p),
27+
alpha=-group["lr"] * group["weight_decay"])
2828
idx += 1
2929

3030

@@ -354,9 +354,8 @@ class TGAdamW(TrueGrad):
354354

355355
def __init__(self, params, lr: float = 1e-3,
356356
betas: Union[Tuple[float, float], Tuple[float, float, float]] = (0.9, 0.999, 0.999),
357-
eps: float = 1e-12, weight_decay: float = 1e-2, graft: bool = True,
358-
default_to_baseline: bool = None, enforce_baseline: bool = False,
359-
weight_decay_cls: Optional[WeightDecayChain] = None):
357+
eps: float = 1e-12, weight_decay: float = 1e-2, graft: bool = True, default_to_baseline: bool = None,
358+
enforce_baseline: bool = False, weight_decay_cls: Optional[WeightDecayChain] = None):
360359
if default_to_baseline is None:
361360
default_to_baseline = False
362361
super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
@@ -428,9 +427,8 @@ class TGRMSProp(TrueGrad):
428427
base_statistics: List[str] = ["exp_avg_sq"]
429428

430429
def __init__(self, params, lr: float = 1e-3, betas: Union[float, Tuple[float], Tuple[float, float]] = (0.9,),
431-
eps: float = 1e-12, weight_decay: float = 1e-2, graft: bool = True,
432-
default_to_baseline: bool = False, enforce_baseline: bool = False,
433-
weight_decay_cls: Optional[WeightDecayChain] = None):
430+
eps: float = 1e-12, weight_decay: float = 1e-2, graft: bool = True, default_to_baseline: bool = False,
431+
enforce_baseline: bool = False, weight_decay_cls: Optional[WeightDecayChain] = None):
434432
super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
435433
default_to_baseline=default_to_baseline, enforce_baseline=enforce_baseline,
436434
weight_decay_cls=weight_decay_cls)

0 commit comments

Comments
 (0)