Skip to content

Commit 2c0682f

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Add to_cotangent_aval to HiType and use it by default in bwd pass
PiperOrigin-RevId: 845470915
1 parent 8d4e9c1 commit 2c0682f

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

jax/_src/api.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2332,9 +2332,7 @@ def _vjp_check_ct_avals(cts, primal_avals):
23322332
# TODO(mattjj): improve this error by flattening with keys in the first place
23332333
for ct, aval in zip(cts, primal_avals):
23342334
ct_aval = typeof(ct)
2335-
ct_aval_expected = (
2336-
aval.to_cotangent_aval() if hasattr(aval, 'to_cotangent_aval') else
2337-
aval.to_tangent_aval())
2335+
ct_aval_expected = aval.to_cotangent_aval()
23382336
if (not core.typecompat(ct_aval, ct_aval_expected) and
23392337
not _temporary_dtype_exception(ct_aval, ct_aval_expected)):
23402338
raise ValueError(

jax/_src/core.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,9 @@ class AbstractValue:
16541654
def to_tangent_aval(self):
16551655
raise NotImplementedError("must override")
16561656

1657+
def to_cotangent_aval(self):
1658+
raise NotImplementedError("must override")
1659+
16571660
# TODO(dougalm): deprecate this alias
16581661
def at_least_vspace(self):
16591662
return self.to_tangent_aval()
@@ -2619,6 +2622,7 @@ def accum_grad_in_ref(x):
26192622
class AbstractToken(AbstractValue):
26202623
def str_short(self, short_dtypes=False, mesh_axis_types=False): return 'Tok'
26212624
def to_tangent_aval(self): return self
2625+
def to_cotangent_aval(self): return self
26222626
abstract_token: AbstractToken = AbstractToken()
26232627

26242628
# Singleton shaped array used by all abstract tokens when shape/dtype is needed.

jax/_src/hijax.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,16 @@ def raise_val(self, *lo_vals: LoVal) -> HiVal:
9292
# autodiff interface
9393
def to_tangent_aval(self) -> HiType:
9494
assert False, "must override"
95+
96+
# Subclasses should override if the cotangent type is a function of primal
97+
# type. For example, CT unreduced = reduced and vice-versa.
98+
def to_cotangent_aval(self) -> HiType:
99+
return self.to_tangent_aval()
100+
95101
# the next two are required if this type is itself a tangent type
96102
def vspace_zero(self) -> HiVal:
97103
assert False, "must override"
104+
98105
def vspace_add(self, x: HiVal, y: HiVal) -> HiVal:
99106
assert False, "must override"
100107

@@ -127,6 +134,11 @@ def update_from_loval(self, state: QDD, val: HiVal, *lo_vals: LoVal) -> None:
127134
def to_tangent_aval(self) -> HiType:
128135
assert False, "must override"
129136

137+
# Subclasses should override if the cotangent type is a function of primal
138+
# type. For example, CT unreduced = reduced and vice-versa.
139+
def to_cotangent_aval(self) -> HiType:
140+
return self.to_tangent_aval()
141+
130142
def register_hitype(val_cls, typeof_fn) -> None:
131143
core.pytype_aval_mappings[val_cls] = typeof_fn
132144
dtypes.canonicalize_value_handlers[val_cls] = lambda x: x

0 commit comments

Comments
 (0)