
Commit 4279d61

feat(functional): improve grad accum, fix einsum backwd, allow full patching
1 parent bb868cc commit 4279d61

File tree

7 files changed: +428, -286 lines

README.md

Lines changed: 111 additions & 15 deletions
@@ -14,12 +14,13 @@ python3 -m pip install truegrad

 TrueGrad supports various backends, each with their own tradeoffs:

 | Name | Advantages | Disadvantages |
 |------|------------|---------------|
 | [truegrad.nn](#nn) | * What you see is what you get - Modules not in truegrad.nn and truegrad.nn.functional are not supported<br/>* Custom forward/backward for some fused functions<br/>* Optimized backward passes | * Limited applicability - custom modules can't be used<br/>* Requires code modification |
 | [truegrad.utils.patch_torch](#patch-torch) | * Uses truegrad.nn under the hood<br/>* Works for many (off-the-shelf!) torch models<br/>* No code modification necessary | * Uncertainty if model is compatible |
 | [backpack](#backpack) | * Highest stability<br/>* Loud warnings and errors<br/>* Battle-tested<br/>* Simple to extend further | * High memory usage<br/>* High compute usage<br/>* Sparse support for torch operations |
-| [truegrad.utils.patch_model](#patch-custom-models) | * Best compatibility | * Fails silently on fused functions<br/>* More costly than truegrad.nn |
+| [truegrad.utils.patch_model](#patch-custom-models) | * Works with custom models | * Fails silently on fused functions<br/>* ~50% to 100% slower than truegrad.nn |
+| [patch_torch + patch_model](#full-patching) | * Best compatibility<br/>* Reduced overheads compared to `patch_model` (by falling back to the faster pre-patched `patch_torch` where available) | * Fails silently on fused functions outside of torch.nn<br/>* Slower than truegrad.nn when truegrad.nn would've been enough |

 Below, you'll find examples for each of these backends, as well as a [general strategy](#partial-truegrad) allowing
 partial application of TrueGrad.
@@ -47,6 +48,7 @@ while True:
     input = torch.randn((16, 1))
     model(input).mean().backward()
     optim.step()
+    optim.zero_grad()
 ```

 ### Patch Torch
@@ -77,11 +79,46 @@
     loss = torch.nn.functional.cross_entropy(model(inp), tgt)
     loss.backward()
     optim.step()
+    optim.zero_grad()
     i += 1
     if i % 5 == 0:
         print(i, loss.item())
 ```

+Similarly, most HuggingFace transformers work out of the box:
+
+```PYTHON
+import torch
+import transformers
+from torch.nn import functional as F
+
+from truegrad.optim import TGAdamW
+from truegrad.utils import patch_torch
+
+patch_torch()  # the only line added to get truegrad statistics for TGAdamW
+
+model = transformers.BertModel.from_pretrained("google/bert_uncased_L-2_H-128_A-2")  # any existing model
+tokenizer = transformers.BertTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2")
+
+optim = TGAdamW(model.parameters())
+
+# constant input to overfit
+input = tokenizer(["Hello World!"], return_tensors="pt")
+
+# training loop as normal
+while True:
+    out = model(**input)
+    loss = F.l1_loss(out[0], torch.ones_like(out[0]))
+    loss.backward()
+    optim.step()
+    optim.zero_grad()
+    print(loss.item())
+```
+
+Note that this works even though transformers contain custom modules, which could otherwise cause issues. The key
+factor is that all parameters come from `torch.nn.Module`s, which are patched by `patch_torch()`, so truegrad handles
+every parameter usage. Any composition of `torch.nn.Module`s therefore makes for a truegrad-compatible model.
+
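To illustrate that last point, here is a minimal sketch (hypothetical code, not from the repository): a custom block whose parameters are all owned by `torch.nn` submodules, trained with `TGAdamW` after `patch_torch()`.

```PYTHON
import torch

from truegrad.optim import TGAdamW
from truegrad.utils import patch_torch

patch_torch()  # patch torch.nn before the modules below are instantiated


class TinyBlock(torch.nn.Module):
    # hypothetical custom module: every parameter belongs to a torch.nn submodule
    def __init__(self, features: int):
        super().__init__()
        self.norm = torch.nn.LayerNorm(features)
        self.linear = torch.nn.Linear(features, features)

    def forward(self, x):
        return torch.relu(self.linear(self.norm(x)))


model = torch.nn.Sequential(TinyBlock(16), TinyBlock(16), torch.nn.Linear(16, 1))
optim = TGAdamW(model.parameters())

# constant input to overfit
inp = torch.randn((4, 16))

for _ in range(10):
    loss = model(inp).square().mean()
    loss.backward()
    optim.step()
    optim.zero_grad()
```
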
 ### BackPack

 The most stable, although also memory-hungry, method to compute TrueGrad statistics is to use
@@ -119,6 +156,7 @@ while True:
     loss = lossfunc(model(inp), tgt)
     loss.backward()
     optim.step()
+    optim.zero_grad()
     i += 1
     if i % 5 == 0:
         print(i, loss.item())
@@ -141,21 +179,78 @@ and `torch.nn.MultiheadAttention`. However, unfused functions which directly access a parameter
 work well. Therefore, torch.nn.Linear and HuggingFace's attention work as expected.
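As a hypothetical sketch of that distinction (not from the repository), before the full example below: a module whose forward uses its parameter directly in an elementwise expression is exactly the kind of unfused usage `patch_model` can track, in contrast to fused calls such as LayerNorm.

```PYTHON
import torch

from truegrad.optim import TGAdamW
from truegrad.utils import patch_model


class Scale(torch.nn.Module):
    # hypothetical module: the parameter is used directly in an unfused, elementwise op
    def __init__(self, features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(features))

    def forward(self, x):
        return x * self.weight  # direct parameter access, no fused kernel involved


model = torch.nn.Sequential(torch.nn.Linear(8, 8), Scale(8))
optim = TGAdamW(model.parameters())
patch_model(model)  # patch the existing parameters so truegrad can collect statistics

inp = torch.randn((4, 8))
loss = model(inp).square().mean()
loss.backward()
optim.step()
optim.zero_grad()
```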

 ```PYTHON
-import transformers
-from truegrad.utils import patch_model
+import torch
 from truegrad.optim import TGAdamW
+from truegrad.utils import patch_model
+from torchvision.models import alexnet

-model = transformers.BertModel.from_pretrained("google/bert_uncased_L-2_H-128_A-2")  # any existing model
-tokenizer = transformers.BertTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2")
+model = alexnet()  # patch_model can't handle fused ops like VGG's and ResNet's BatchNorm
+optim = TGAdamW(model.parameters())
+
+# replace inplace ops like nn.ReLU(inplace=True) where possible
+for mod in model.modules():
+    if hasattr(mod, "inplace"):
+        mod.inplace = False

 patch_model(model)  # replace torch.nn.Parameter with truegrad.nn.Parameter
-optim = TGAdamW(model.parameters())  # truegrad.optim.TGAdamW instead of torch.optim.AdamW

-# training loop as normal
-for sample in ["Hello", "World", "!"]:
-    out = model(**tokenizer([sample], return_tensors="pt"))
-    out[0].mean().backward()
+# constant input/output to overfit
+inp = torch.randn((2, 3, 224, 224))
+tgt = torch.randint(0, 1000, (2,))
+
+# standard training loop
+i = 0
+while True:
+    loss = torch.nn.functional.cross_entropy(model(inp), tgt)
+    loss.backward()
     optim.step()
+    optim.zero_grad()
+    i += 1
+    if i % 5 == 0:
+        print(i, loss.item())
+```
+
+### Full Patching
+
+One way of avoiding [truegrad.utils.patch_model](#patch-custom-models)'s downsides when working with off-the-shelf
+models containing custom parameters, such as [lucidrains' ViTs](https://github.com/lucidrains/vit-pytorch/), is to also
+call `patch_torch`. This takes care of many fused functions, such as LayerNorm, while still allowing full flexibility in
+model design.
+
+```PYTHON
+import torch
+from vit_pytorch.levit import LeViT
+from truegrad.utils import patch_torch, patch_model
+from truegrad.optim import TGAdamW
+
+patch_torch()  # before model instantiation
+
+levit = LeViT(
+    image_size=224,
+    num_classes=1000,
+    stages=3,  # number of stages
+    dim=(256, 384, 512),  # dimensions at each stage
+    depth=4,  # transformer of depth 4 at each stage
+    heads=(4, 6, 8),  # heads at each stage
+    mlp_mult=2,
+    dropout=0.1
+)
+
+opt = TGAdamW(levit.parameters())
+
+patch_model(levit)  # replace torch.nn.Parameter with truegrad.nn.TrueGradParameter
+
+# constant input to overfit
+img = torch.randn(1, 3, 224, 224)
+
+# standard training loop
+while True:
+    loss = levit(img).square().mean()
+    loss.backward()
+    opt.step()
+    opt.zero_grad()
+    print(loss.item())
 ```

 ### Partial TrueGrad
@@ -186,6 +281,7 @@ while True:
     loss = model(input).mean()
     loss.backward()
     optim.step()
+    optim.zero_grad()
     i += 1
     if i % 5 == 0:
         print(i, loss.item())

setup.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
     name='truegrad',
     license='BSD',
     description='PyTorch interface for TrueGrad-AdamW',
-    version='2.0.0',
+    version='2.1.0',
     long_description=README,
     url='https://github.com/clashluke/truegrad',
     packages=setuptools.find_packages(),
