@@ -12,13 +12,62 @@ python3 -m pip install truegrad

## Examples

+### BackPack
+
+The preferred method to integrate TrueGrad is using [BackPack](https://github.com/f-dangel/backpack). BackPack is a
+third-party library that automatically computes the sum of gradient squares and works for most models by implementing
+custom backward rules for many `torch.nn.Module`'s.
+
+```PYTHON
+import backpack
+import torch
+from torch.nn import CrossEntropyLoss
+from truegrad.optim import TGAdamW
+from torchvision.models import alexnet
+
+model = alexnet()
+optim = TGAdamW(model.parameters(), lr=1e-7, weight_decay=0)
+
+# backpack can't handle inplace ops like nn.ReLU(inplace=True) and `x += y`
+for mod in model.modules():
+    if hasattr(mod, "inplace"):
+        mod.inplace = False
+
+# backpack relies on module-level pytorch hooks
+model = backpack.extend(model)
+lossfunc = backpack.extend(CrossEntropyLoss())
+
+# constant input/output to overfit
+inp = torch.randn((2, 3, 224, 224))
+tgt = torch.randint(0, 1000, (2,))
+
+# standard training loop
+i = 0
+while True:
+    # "SumGradSquared" computes the sum of the squared gradient
+    with backpack.backpack(backpack.extensions.SumGradSquared()):
+        loss = lossfunc(model(inp), tgt)
+        loss.backward()
+    optim.step()
+    i += 1
+    if i % 5 == 0:
+        print(i, loss.item())
+```
+
+If you're using custom modules with self-defined parameters, this method will not work. Additionally, note that if
+your model has any layer called `.output` or you're using PyTorch >= 1.13, you will need to install
+[BackPack-HF](https://github.com/ClashLuke/backpack-hf) via
+`python3 -m pip install git+https://github.com/ClashLuke/backpack-hf`.
+
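+To make this limitation concrete, here is a minimal sketch of a custom module with a self-defined parameter that
+BackPack's built-in extensions do not cover (the `Scale` layer is a hypothetical example, not part of TrueGrad or
+BackPack):
+
+```PYTHON
+import torch
+
+
+# A custom layer with a self-defined parameter. BackPack ships no backward
+# extension for it, so SumGradSquared cannot provide the squared-gradient
+# statistics TGAdamW needs for `weight`.
+class Scale(torch.nn.Module):
+    def __init__(self, features: int):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(features))
+
+    def forward(self, x):
+        return x * self.weight
+```
+
+For modules like this, use the patching approach described next or the `truegrad.nn` layers below.
+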
### Patch Custom Models

-The easiest way to integrate TrueGrad into existing models is to patch them using `truegrad.utils.patch_model()`.
+Another option to integrate TrueGrad into existing models is to patch them using `truegrad.utils.patch_model()`.
`patch_model()` will go through all `torch.nn.Module`'s in a PyTorch model and convert their `torch.nn.Parameter`'s to
`truegrad.nn.TrueGradParameter`'s. A `TrueGradParameter` acts largely the same as a `torch.nn.Parameter`, but adds
required operations into the model's backward pass.\
-Patching an existing
+Importantly, be aware that this does not work for fused functions, such as `torch.nn.LayerNorm`
+and `torch.nn.MultiheadAttention`. However, unfused functions which directly access a parameter, such as multiplication
+and addition, work well. Therefore, `torch.nn.Linear` and HuggingFace's attention work as expected.
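+
+As a rough illustration (a sketch, not from the original README): the hypothetical `Scale` layer from the BackPack
+section only multiplies its input by a self-defined parameter, i.e. an unfused op that accesses the parameter
+directly, so `patch_model()` should be able to handle it. The full example below then applies the same patching to a
+HuggingFace model.
+
+```PYTHON
+import torch
+from truegrad.optim import TGAdamW
+from truegrad.utils import patch_model
+
+
+class Scale(torch.nn.Module):
+    def __init__(self, features: int):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(features))
+
+    def forward(self, x):
+        return x * self.weight  # unfused multiplication with direct parameter access
+
+
+model = torch.nn.Sequential(torch.nn.Linear(4, 4), Scale(4))
+patch_model(model)  # convert the torch.nn.Parameter's to TrueGradParameter's
+optim = TGAdamW(model.parameters())
+
+optim.zero_grad()
+model(torch.randn((2, 4))).mean().backward()
+optim.step()
+```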

```PYTHON
import transformers
@@ -40,11 +89,13 @@ for sample in ["Hello", "World", "!"]:

### nn

-Patching existing PyTorch computation graphs on the fly might add unnecessary memory and computation. That's why a
-pre-patched alternative of `torch.nn` with hand-crafted gradients exists alongside the `truegrad.utils` module. Compared
-to `truegrad.utils.patch_model()`, `truegrad.nn` offers higher speeds and lower memory usage, although it might require
-code alterations and doesn't support all models. You cannot (currently) use `truegrad.nn` with `truegrad.utils`, as both
-use different ways to arrive at the same value.
+Patching existing PyTorch computation graphs on the fly might add unnecessary memory and computation or even fail
+unexpectedly. That's why a pre-patched alternative of `torch.nn` with hand-crafted gradients exists alongside the
+`truegrad.utils` module. Compared to `truegrad.utils.patch_model()`, `truegrad.nn` offers higher speeds and lower
+memory usage, although it might require code alterations and doesn't support all models. You cannot (currently) use
+`truegrad.nn` with `truegrad.utils`, as both use different ways to arrive at the same value. However, you can
+combine `torch.nn.Module`'s and `truegrad.nn.Module`'s and use the TrueGrad information only where it is available
+(see [Partial TrueGrad](#Partial-TrueGrad)).

```PYTHON
import torch
@@ -58,7 +109,7 @@ model = torch.nn.Sequential(nn.Linear(1, 10),
                             nn.Linear(10, 1))
optim = TGAdamW(model.parameters())  # truegrad.optim.TGAdamW instead of torch.optim.AdamW

-# training loop as normal
+# standard training loop
while True:
    input = torch.randn((16, 1))
    model(input).mean().backward()
@@ -67,10 +118,11 @@ while True:

### Partial TrueGrad

-Unfortunately, it's not always sensible to apply TrueGrad, as some backward passes are too slow to do them twice.
-Therefore, it can be an option to use TGAdamW only on specific subsections of the model. To do so, you can either check
-which parameters are of type `truegrad.nn.TrueGradParameter` when using `truegrad.utils.patch_model()` or which
-parameters belong to a module listed in `truegrad.nn.modules`.
+Unfortunately, it's not always sensible to apply TrueGrad, as some backward passes are too slow, and sometimes it's
+impossible to avoid a fused function.
+Therefore, it can be an option to use TGAdamW only on specific subsections of the model. To do so, you can
+pass `default_to_adam=True` to TGAdamW. Adding this option allows TGAdamW to fall back to AdamW whenever a
+parameter has no `sum_grad_squared` attribute available.
For example, the code from [#nn](#nn) could be extended in the following way:

```PYTHON
@@ -83,26 +135,11 @@ model = torch.nn.Sequential(nn.Linear(1, 10), # Weights coming from truegrad.nn
                             torch.nn.ReLU(),
                             torch.nn.Linear(10, 1))  # Weights coming from torch.nn

-truegrad_parameters = []
-normal_parameters = []
-
-
-def get_parameters(mod: torch.nn.Module):
-    if isinstance(mod, nn.modules):
-        truegrad_parameters.extend(list(mod.parameters(recurse=False)))
-    else:
-        # you could do truegrad.utils.patch_model(mod, recurse=False) here!
-        normal_parameters.extend(list(mod.parameters(recurse=False)))
-
-
-model = model.apply(get_parameters)
-
-optim0 = TGAdamW(truegrad_parameters)
-optim1 = torch.optim.AdamW(normal_parameters)
+optim = TGAdamW(model.parameters(), default_to_adam=True)

+# standard training loop
while True:
    input = torch.randn((16, 1))
    model(input).mean().backward()
-    optim0.step()  # update both parameter sets separately
-    optim1.step()
+    optim.step()
```
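+
+As a quick sanity check for partial TrueGrad, you can inspect which parameters actually received TrueGrad statistics
+after a backward pass. This is a minimal sketch, assuming (as the fallback behaviour above suggests) that the
+statistic is stored as a `sum_grad_squared` attribute on the individual parameters:
+
+```PYTHON
+# Run once after `.backward()`: parameters without `sum_grad_squared` are the
+# ones TGAdamW handles with the plain AdamW fallback when default_to_adam=True.
+for name, param in model.named_parameters():
+    has_stats = getattr(param, "sum_grad_squared", None) is not None
+    print(name, "truegrad" if has_stats else "adamw fallback")
+```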