Skip to content

Commit d7c458b

Browse files
committed
fix(optim): self-graft correctly
1 parent 9a43be0 commit d7c458b

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
name='truegrad',
1111
license='BSD',
1212
description='PyTorch interface for TrueGrad-AdamW',
13-
version='4.0.2',
13+
version='4.0.3',
1414
long_description=README,
1515
url='https://github.com/clashluke/truegrad',
1616
packages=setuptools.find_packages(),
@@ -26,6 +26,7 @@
2626
'Programming Language :: Python :: 3.7',
2727
'Programming Language :: Python :: 3.8',
2828
'Programming Language :: Python :: 3.9',
29+
'Programming Language :: Python :: 3.10',
2930
'Topic :: Software Development :: Libraries',
3031
'Topic :: Software Development :: Libraries :: Python Modules',
3132
'Intended Audience :: Developers',

truegrad/optim.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,10 @@ def step(self, closure=None):
196196
update = p.double() - o.double()
197197
p.set_(o)
198198
scale = group["lr"]
199+
sign_update = torch.sign(update)
199200
if group["graft_to_self"]:
200-
scale = scale * torch.norm(update)
201-
p.add_(torch.sign(update), alpha=scale)
201+
scale = scale * update.norm() / sign_update.norm().clamp(min=group["eps"])
202+
p.add_(sign_update, alpha=scale)
202203

203204
return loss
204205

0 commit comments

Comments
 (0)