| Method | Advantages | Disadvantages |
|---|---|---|
|[truegrad.nn](#nn)| * What you see is what you get - modules not in truegrad.nn and truegrad.nn.functional are not supported<br/>* Custom forward/backward for some fused functions<br/>* Optimized backward passes | * Limited applicability - custom modules can't be used<br/>* Requires code modification |
|[truegrad.utils.patch_torch](#patch-torch)| * Uses truegrad.nn under the hood<br/>* Works for many (off-the-shelf!) torch models<br/>* No code modification necessary | * Uncertainty whether a given model is compatible |
|[backpack](#backpack)| * Highest stability<br/>* Loud warnings and errors<br/>* Battle-tested<br/>* Simple to extend further | * High memory usage<br/>* High compute usage<br/>* Sparse support for torch operations |
|[truegrad.utils.patch_model](#patch-custom-models)| * Works with custom models | * Fails silently on fused functions<br/>* ~50% to 100% slower than truegrad.nn |
|[patch_torch + patch_model](#full-patching)| * Best compatibility (sketched below)<br/>* Lower overhead than `patch_model` alone, by falling back to the faster pre-patched `patch_torch` modules where available | * Fails silently on fused functions outside of torch.nn<br/>* Slower than truegrad.nn when truegrad.nn alone would suffice |
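To make the last two rows concrete, here is a minimal sketch of patching a custom model, combined with `patch_torch`. It assumes `patch_model` accepts a module instance and patches its parameters in place; the linked sections give the authoritative usage.

```python
import torch
from truegrad.optim import TGAdamW
from truegrad.utils import patch_model, patch_torch

patch_torch()  # full patching: call before building the model, so standard
               # torch.nn layers use the faster pre-patched truegrad.nn modules

class CustomModel(torch.nn.Module):  # stand-in for any custom architecture
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(10))  # custom parameter
        self.linear = torch.nn.Linear(1, 10)  # covered by patch_torch

    def forward(self, x):
        return (self.linear(x) * self.scale).mean(-1)

model = CustomModel()
patch_model(model)  # assumed to patch the remaining custom parameters in place
optim = TGAdamW(model.parameters())
```

The call order is what buys the reduced overhead: `patch_torch` first, then build the model, then `patch_model` to catch whatever `patch_torch` could not.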
Below, you'll find examples for each of these backends, as well as a [general strategy](#partial-truegrad) allowing partial application of TrueGrad.
### nn

For example, `truegrad.nn` modules drop into a standard training loop. The model and optimizer setup below is a minimal sketch; any model composed of `truegrad.nn` modules works the same way:

```python
import torch
from truegrad import nn  # drop-in replacements for supported torch.nn modules
from truegrad.optim import TGAdamW

model = nn.Linear(1, 1)  # sketch: any model built from truegrad.nn modules
optim = TGAdamW(model.parameters())

while True:
    input = torch.randn((16, 1))
    model(input).mean().backward()
    optim.step()
    optim.zero_grad()
```
### Patch Torch
`patch_torch()` swaps torch.nn modules for their truegrad.nn counterparts, so an unmodified model gains TrueGrad statistics. The toy model and random data below are assumptions for a runnable sketch; the training loop itself is the point:

```python
import torch
from truegrad.optim import TGAdamW
from truegrad.utils import patch_torch

patch_torch()  # patch torch.nn before the model is instantiated

model = torch.nn.Linear(784, 10)  # sketch: any plain torch model
optim = TGAdamW(model.parameters())
inp, tgt = torch.randn((16, 784)), torch.randint(0, 10, (16,))

i = 0
while True:
    loss = torch.nn.functional.cross_entropy(model(inp), tgt)
    loss.backward()
    optim.step()
    optim.zero_grad()
    i += 1
    if i % 5 == 0:
        print(i, loss.item())
```
Similarly, most HuggingFace Transformers models work out of the box:
```python
import torch
import transformers
from torch.nn import functional as F

from truegrad.optim import TGAdamW
from truegrad.utils import patch_torch

patch_torch()  # only added line, to get TrueGrad statistics for TGAdamW

model = transformers.BertModel.from_pretrained("google/bert_uncased_L-2_H-128_A-2")  # any existing model
```
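The backpack backend from the comparison table takes a different route: nothing is patched; instead the model is wrapped with BackPACK's `extend`, and squared-gradient statistics are collected during the backward pass. The sketch below uses BackPACK's public API and assumes TGAdamW consumes the statistics that the `SumGradSquared` extension provides:

```python
import torch
from backpack import backpack, extend
from backpack.extensions import SumGradSquared
from truegrad.optim import TGAdamW

model = extend(torch.nn.Linear(1, 1))  # let BackPACK track this model
optim = TGAdamW(model.parameters())

while True:
    loss = model(torch.randn((16, 1))).mean()
    with backpack(SumGradSquared()):  # collect sums of squared gradients
        loss.backward()
    optim.step()
    optim.zero_grad()
```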