File tree Expand file tree Collapse file tree 2 files changed +5
-3
lines changed
Expand file tree Collapse file tree 2 files changed +5
-3
lines changed Original file line number Diff line number Diff line change 1010 name = 'truegrad' ,
1111 license = 'BSD' ,
1212 description = 'PyTorch interface for TrueGrad-AdamW' ,
13- version = '4.0.2 ' ,
13+ version = '4.0.3 ' ,
1414 long_description = README ,
1515 url = 'https://github.com/clashluke/truegrad' ,
1616 packages = setuptools .find_packages (),
2626 'Programming Language :: Python :: 3.7' ,
2727 'Programming Language :: Python :: 3.8' ,
2828 'Programming Language :: Python :: 3.9' ,
29+ 'Programming Language :: Python :: 3.10' ,
2930 'Topic :: Software Development :: Libraries' ,
3031 'Topic :: Software Development :: Libraries :: Python Modules' ,
3132 'Intended Audience :: Developers' ,
Original file line number Diff line number Diff line change @@ -196,9 +196,10 @@ def step(self, closure=None):
196196 update = p .double () - o .double ()
197197 p .set_ (o )
198198 scale = group ["lr" ]
199+ sign_update = torch .sign (update )
199200 if group ["graft_to_self" ]:
200- scale = scale * torch .norm (update )
201- p .add_ (torch . sign ( update ) , alpha = scale )
201+ scale = scale * update .norm () / sign_update . norm (). clamp ( min = group [ "eps" ] )
202+ p .add_ (sign_update , alpha = scale )
202203
203204 return loss
204205
You can’t perform that action at this time.
0 commit comments