
Commit 0a14e38

feat: add backpack backend
1 parent 90ed2d4 commit 0a14e38

5 files changed (+104, -61 lines)


README.md

Lines changed: 67 additions & 30 deletions
@@ -12,13 +12,62 @@ python3 -m pip install truegrad

## Examples

+### BackPack
+
+The preferred method to integrate TrueGrad is using [BackPack](https://github.com/f-dangel/backpack). BackPack is a
+third-party library that automatically computes the sum of gradient squares and works for most models by implementing
+custom backward rules for many `torch.nn.Module`'s.
+
+```PYTHON
+import backpack
+import torch
+from torch.nn import CrossEntropyLoss
+from truegrad.optim import TGAdamW
+from torchvision.models import alexnet
+
+model = alexnet()
+optim = TGAdamW(model.parameters(), lr=1e-7, weight_decay=0)
+
+# backpack can't handle inplace ops like nn.ReLU(inplace=True) and `x += y`
+for mod in model.modules():
+    if hasattr(mod, "inplace"):
+        mod.inplace = False
+
+# backpack relies on module-level pytorch hooks
+model = backpack.extend(model)
+lossfunc = backpack.extend(CrossEntropyLoss())
+
+# constant input/output to overfit
+inp = torch.randn((2, 3, 224, 224))
+tgt = torch.randint(0, 1000, (2,))
+
+# standard training loop
+i = 0
+while True:
+    # "SumGradSquared" computes the sum of the squared gradient
+    with backpack.backpack(backpack.extensions.SumGradSquared()):
+        loss = lossfunc(model(inp), tgt)
+        loss.backward()
+    optim.step()
+    i += 1
+    if i % 5 == 0:
+        print(i, loss.item())
+```
+
+If you're using custom modules with self-defined parameters, this method will not work. Additionally, note that, if
+your model has any layer called `.output` or you're using PyTorch >= 1.13, you will need to install
+[BackPack-HF](https://github.com/ClashLuke/backpack-hf) via
+`python3 -m pip install git+https://github.com/ClashLuke/backpack-hf`.
+
### Patch Custom Models

-The easiest way to integrate TrueGrad into existing models is to patch them using `truegrad.utils.patch_model()`.
+Another option to integrate TrueGrad into existing models is to patch them using `truegrad.utils.patch_model()`.
`patch_model()` will go through all `torch.nn.Module`'s in a PyTorch model and convert their `torch.nn.Parameter`'s to
`truegrad.nn.TrueGradParameter`'s. A `TrueGradParameter` acts largely the same as a `torch.nn.Parameter`, but adds
required operations into the model's backward pass.\
-Patching an existing
+Importantly, be aware that this does not work for fused functions, such as `torch.nn.LayerNorm`
+and `torch.nn.MultiheadAttention`. However, unfused functions which directly access a parameter, such as multiplication,
+work well. Therefore, `torch.nn.Linear` and HuggingFace's attention work as expected.

```PYTHON
import transformers
@@ -40,11 +89,13 @@ for sample in ["Hello", "World", "!"]:

### nn

-Patching existing PyTorch computation graphs on the fly might add unnecessary memory and computation. That's why a
-pre-patched alternative of `torch.nn` with hand-crafted gradients exists alongside the `truegrad.utils` module. Compared
-to `truegrad.utils.patch_model()`, `truegrad.nn` offers higher speeds and lower memory usage, although it might require
-code alterations and doesn't support all models. You cannot (currently) use `truegrad.nn` with `truegrad.utils`, as both
-use different ways to arrive at the same value.
+Patching existing PyTorch computation graphs on the fly might add unnecessary memory and computation or even fail
+unexpectedly. That's why a pre-patched alternative of `torch.nn` with hand-crafted gradients exists alongside the
+`truegrad.utils` module. Compared to `truegrad.utils.patch_model()`, `truegrad.nn` offers higher speeds and lower
+memory usage, although it might require code alterations and doesn't support all models. You cannot (currently) use
+`truegrad.nn` with `truegrad.utils`, as both use different ways to arrive at the same value. However, you can
+combine `torch.nn.Modules` and `truegrad.nn.Modules` and use the truegrad information only where it is available (
+see [Partial TrueGrad](#Partial-TrueGrad)).

```PYTHON
import torch
@@ -58,7 +109,7 @@ model = torch.nn.Sequential(nn.Linear(1, 10),
                            nn.Linear(10, 1))
optim = TGAdamW(model.parameters())  # truegrad.optim.TGAdamW instead of torch.optim.AdamW

-# training loop as normal
+# standard training loop
while True:
    input = torch.randn((16, 1))
    model(input).mean().backward()
@@ -67,10 +118,11 @@ while True:

### Partial TrueGrad

-Unfortunately, it's not always sensible to apply TrueGrad, as some backward passes are too slow to do them twice.
-Therefore, it can be an option to use TGAdamW only on specific subsections of the model. To do so, you can either check
-which parameters are of type `truegrad.nn.TrueGradParameter` when using `truegrad.utils.patch_model()` or which
-parameters belong to a module listed in `truegrad.nn.modules`.
+Unfortunately, it's not always sensible to apply TrueGrad, as some backward passes are too slow, and sometimes it's
+impossible to avoid a fused function.
+Therefore, it can be an option to use TGAdamW only on specific subsections of the model. To do so, you can
+pass `default_to_adam=True` to TGAdamW. Adding this option allows TGAdamW to fall back to AdamW if there is
+no `sum_grad_squared` attribute available.
For example, the code from [#nn](#nn) could be extended in the following way:

```PYTHON
@@ -83,26 +135,11 @@ model = torch.nn.Sequential(nn.Linear(1, 10), # Weights coming from truegrad.nn
                            torch.nn.ReLU(),
                            torch.nn.Linear(10, 1))  # Weights coming from torch.nn

-truegrad_parameters = []
-normal_parameters = []
-
-
-def get_parameters(mod: torch.nn.Module):
-    if isinstance(mod, nn.modules):
-        truegrad_parameters.extend(list(mod.parameters(recurse=False)))
-    else:
-        # you could do truegrad.utils.patch_model(mod, recurse=False) here!
-        normal_parameters.extend(list(mod.parameters(recurse=False)))
-
-
-model = model.apply(get_parameters)
-
-optim0 = TGAdamW(truegrad_parameters)
-optim1 = torch.optim.AdamW(normal_parameters)
+optim = TGAdamW(model.parameters(), default_to_adam=True)

+# standard training loop
while True:
    input = torch.randn((16, 1))
    model(input).mean().backward()
-    optim0.step()  # update both parameter sets separately
-    optim1.step()
+    optim.step()
```
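
Note on the attribute name: BackPack's `SumGradSquared` extension stores its result in a `sum_grad_squared` attribute on each extended parameter, which is the attribute `TGAdamW` now reads (see the `truegrad/optim.py` diff below). A minimal sketch of that contract, not part of the commit:

```PYTHON
import backpack
import torch

model = backpack.extend(torch.nn.Linear(4, 2))
lossfunc = backpack.extend(torch.nn.MSELoss())

with backpack.backpack(backpack.extensions.SumGradSquared()):
    loss = lossfunc(model(torch.randn(8, 4)), torch.randn(8, 2))
    loss.backward()

for name, p in model.named_parameters():
    # TGAdamW.step() reads p.sum_grad_squared (and resets it to None) when it is present
    print(name, p.grad.shape, p.sum_grad_squared.shape)
```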

setup.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
    name='truegrad',
    license='BSD',
    description='PyTorch interface for TrueGrad-AdamW',
-    version='0.0.9',
+    version='0.1.0',
    long_description=README,
    url='https://github.com/clashluke/truegrad',
    packages=setuptools.find_packages(),

truegrad/functional.py

Lines changed: 12 additions & 12 deletions
@@ -18,11 +18,11 @@ def backward(ctx, dy: torch.Tensor):
        diff = inp.ndim - weight.ndim
        summed = list(range(diff)) + [i for i, dim in enumerate(weight.shape, diff) if dim == 1]
        weight_grad = dy * inp
-        weight.square_grad = weight_grad.square()
+        weight.sum_grad_squared = weight_grad.square()
        if summed:
            weight_grad = weight_grad.sum(summed)
-            weight.square_grad = weight.square_grad.sum(summed)
-        weight.square_grad = weight.square_grad.reshape(weight.size()) * dy.size(0)
+            weight.sum_grad_squared = weight.sum_grad_squared.sum(summed)
+        weight.sum_grad_squared = weight.sum_grad_squared.reshape(weight.size()) * dy.size(0)
        return dy * weight, weight_grad.reshape(weight.size())

@@ -42,11 +42,11 @@ def backward(ctx, dy: torch.Tensor):
            return None, None
        weight, = ctx.saved_tensors
        weight_grad = dy
-        weight.square_grad = dy.square()
+        weight.sum_grad_squared = dy.square()
        if ctx.summed:
            weight_grad = weight_grad.sum(ctx.summed)
-            weight.square_grad = weight.square_grad.sum(ctx.summed)
-        weight.square_grad = weight.square_grad.reshape(weight.size()) * dy.size(0)
+            weight.sum_grad_squared = weight.sum_grad_squared.sum(ctx.summed)
+        weight.sum_grad_squared = weight.sum_grad_squared.reshape(weight.size()) * dy.size(0)
        return dy, weight_grad.reshape(weight.size())

@@ -67,7 +67,7 @@ def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor, torch.Tensor]:
        lhs, rhs = inputs.split(',')

        d_wgt = torch.einsum(f'{lhs},{output}->{rhs}', inp, dy)
-        wgt.square_grad = torch.einsum(f'{lhs},{output}->{rhs}', inp.square(), dy.square() * inp.size(0))
+        wgt.sum_grad_squared = torch.einsum(f'{lhs},{output}->{rhs}', inp.square(), dy.square() * inp.size(0))
        d_inp = torch.einsum(f"{rhs},{output}->{lhs}", wgt, dy)
        return None, d_inp, d_wgt

@@ -85,7 +85,7 @@ def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
            return None, None
        inp, wgt = ctx.saved_tensors
        wgt_grad = torch.zeros_like(wgt)
-        wgt.square_grad = wgt_grad.scatter_add(0, inp, dy.square())
+        wgt.sum_grad_squared = wgt_grad.scatter_add(0, inp, dy.square())
        wgt_grad.scatter_add_(0, inp, dy)
        return None, wgt_grad

@@ -103,8 +103,8 @@ def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
        if not ctx.saved_tensors:
            return None
        wgt, = ctx.saved_tensors
-        if hasattr(wgt, "square_grad"):
-            wgt.square_grad = wgt.square_grad.reshape(ctx.original_shape)
+        if hasattr(wgt, "sum_grad_squared"):
+            wgt.sum_grad_squared = wgt.sum_grad_squared.reshape(ctx.original_shape)
        return dy.reshape(ctx.original_shape)

@@ -121,8 +121,8 @@ def backward(ctx, dy: torch.Tensor) -> Tuple[None, torch.Tensor]:
        if not ctx.saved_tensors:
            return None
        wgt, = ctx.saved_tensors
-        if hasattr(wgt, "square_grad") and ctx.summed:
-            wgt.square_grad = wgt.square_grad.sum(ctx.summed)
+        if hasattr(wgt, "sum_grad_squared") and ctx.summed:
+            wgt.sum_grad_squared = wgt.sum_grad_squared.sum(ctx.summed)
        return dy.sum(ctx.summed)
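
The hunks above mechanically rename the side-channel attribute these custom backward rules populate from `square_grad` to `sum_grad_squared`. For orientation, here is a stripped-down sketch of the pattern for an elementwise `out = inp * weight`, using a hypothetical `MulFn` rather than the actual classes in `truegrad/functional.py`:

```PYTHON
import torch


class MulFn(torch.autograd.Function):  # hypothetical name; mirrors the pattern in the diff above
    @staticmethod
    def forward(ctx, inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(inp, weight)
        return inp * weight

    @staticmethod
    def backward(ctx, dy: torch.Tensor):
        inp, weight = ctx.saved_tensors
        weight_grad = dy * inp
        # dimensions over which the (broadcast) weight was expanded
        diff = inp.ndim - weight.ndim
        summed = list(range(diff)) + [i for i, dim in enumerate(weight.shape, diff) if dim == 1]
        # side channel read by truegrad.optim.TGAdamW
        weight.sum_grad_squared = weight_grad.square()
        if summed:
            weight_grad = weight_grad.sum(summed)
            weight.sum_grad_squared = weight.sum_grad_squared.sum(summed)
        weight.sum_grad_squared = weight.sum_grad_squared.reshape(weight.size()) * dy.size(0)
        return dy * weight, weight_grad.reshape(weight.size())


w = torch.nn.Parameter(torch.randn(3))
x = torch.randn(8, 3)
MulFn.apply(x, w).sum().backward()
print(w.grad.shape, w.sum_grad_squared.shape)  # torch.Size([3]) torch.Size([3])
```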

truegrad/nn.py

Lines changed: 3 additions & 3 deletions
@@ -159,10 +159,10 @@ def _square(x: Union[torch.Tensor, None]):
        for p, a in zip(list(ctx.args) + list(ctx.kwargs.values()), list(args) + list(kwargs.values())):
            if not isinstance(p, torch.nn.Parameter):
                continue
-            if hasattr(p, "square_grad") and p.square_grad is not None:
-                p.square_grad = p.square_grad + a.grad
+            if hasattr(p, "sum_grad_squared") and p.sum_grad_squared is not None:
+                p.sum_grad_squared = p.sum_grad_squared + a.grad
            else:
-                p.square_grad = a.grad
+                p.sum_grad_squared = a.grad
        return None, None, None, None
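
The `truegrad/nn.py` hunk is the same rename; the surrounding loop adds to `sum_grad_squared` when a parameter is used more than once instead of overwriting it. Schematically (hypothetical helper, with `new_sq` standing in for the freshly computed squared-gradient term):

```PYTHON
import torch


def accumulate_sum_grad_squared(p: torch.nn.Parameter, new_sq: torch.Tensor) -> None:
    # keep the contribution from an earlier use of the same parameter, if any
    if getattr(p, "sum_grad_squared", None) is not None:
        p.sum_grad_squared = p.sum_grad_squared + new_sq
    else:
        p.sum_grad_squared = new_sq
```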

truegrad/optim.py

Lines changed: 21 additions & 15 deletions
@@ -9,9 +9,10 @@ def __init__(self, params, lr: float = 1e-3,
                 eps: float = 1e-12,
                 weight_decay: float = 1e-2,
                 graft: bool = True,
-                 decay_to_init: bool = False):
+                 decay_to_init: bool = False,
+                 default_to_adam: bool = False):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, graft=graft,
-                        decay_to_init=decay_to_init)
+                        decay_to_init=decay_to_init, default_to_adam=default_to_adam)
        super(TGAdamW, self).__init__(params, defaults)

    @torch.no_grad()
@@ -31,18 +32,19 @@ def step(self, closure=None):
            for p in group['params']:
                if p.grad is None:
                    continue
-                if not hasattr(p, "square_grad") or p.square_grad is None:
-                    raise ValueError(f"Parameter of shape {list(p.size())} doesn't have `square_grad` attribute. "
-                                     f"Make sure to use truegrad.utils.patch_model() or truegrad.nn for all optimized "
-                                     f"parameters.")
+                do_adam = not hasattr(p, "sum_grad_squared") or p.sum_grad_squared is None
+                if not group["default_to_adam"] and do_adam:
+                    raise ValueError(f"Parameter of shape {list(p.size())} doesn't have `sum_grad_squared` attribute. "
+                                     f"Make sure to use backpack.")

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = torch.tensor(0.)
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    state['exp_avg_true_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    if group["graft"]:
+                    if not do_adam:
+                        state['exp_avg_true_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if do_adam or group["graft"]:
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if group["decay_to_init"]:
                        state["init"] = torch.clone(p.detach())
@@ -61,22 +63,26 @@ def step(self, closure=None):
                else:
                    p.mul_(1 - decay)

-                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
-                exp_avg_true_sq.mul_(beta3).add_(p.square_grad, alpha=1 - beta3)
-                p.square_grad = None

                step = step_t.item()
-
-                denom = (exp_avg_true_sq / (1 - beta3 ** step)).sqrt().add_(group['eps'])
-                update = exp_avg / denom
                alpha = -group['lr'] / (1 - beta1 ** step)

-                if group["graft"]:
+                if not do_adam:
+                    exp_avg_true_sq.mul_(beta3).add_(p.sum_grad_squared, alpha=1 - beta3)
+                    p.sum_grad_squared = None
+                    denom = (exp_avg_true_sq / (1 - beta3 ** step)).sqrt().add_(group['eps'])
+                    update = exp_avg / denom
+
+                if group["graft"] or do_adam:
                    exp_avg_sq = state['exp_avg_sq']
                    exp_avg_sq.mul_(beta2).add_(p.grad.square(), alpha=1 - beta2)
                    adam_update = exp_avg / (exp_avg_sq / (1 - beta2 ** step)).sqrt().add_(group['eps'])
+
+                if group["graft"] and not do_adam:
                    alpha = alpha * adam_update.norm() / update.norm()
+                elif do_adam:
+                    update = adam_update

                p.add_(update, alpha=alpha)
        return loss
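
Taken together, the new branches give `step()` three per-parameter modes. A schematic summary as a hypothetical standalone function; the real code updates the optimizer state tensors rather than returning labels:

```PYTHON
def update_mode(has_sum_grad_squared: bool, graft: bool, default_to_adam: bool) -> str:
    """Which update TGAdamW.step() applies to a single parameter (schematic)."""
    if not has_sum_grad_squared:  # `do_adam` in the diff above
        if not default_to_adam:
            raise ValueError("parameter has no `sum_grad_squared` attribute")
        return "plain AdamW update"
    if graft:
        return "TrueGrad direction, rescaled to the AdamW update's norm"
    return "pure TrueGrad update"
```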
