Skip to content
This repository was archived by the owner on May 6, 2025. It is now read-only.

Commit ad47437

Browse files
committed
Bump min JAX version to 0.4.16 + minor no-op linter fixes.
PiperOrigin-RevId: 589663637
1 parent ad3d524 commit ad47437

File tree

16 files changed

+136
-51
lines changed

16 files changed

+136
-51
lines changed

neural_tangents/_src/batching.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,20 @@
4646
import warnings
4747

4848
import jax
49+
4950
from jax import device_put
5051
from jax import devices
5152
from jax import jit
5253
from jax import pmap
5354
from jax import random
55+
5456
import jax.numpy as jnp
57+
5558
from jax.tree_util import tree_all
5659
from jax.tree_util import tree_flatten
5760
from jax.tree_util import tree_map
5861
from jax.tree_util import tree_unflatten
62+
5963
import numpy as np
6064

6165
from .utils import utils

neural_tangents/_src/empirical.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,22 @@
116116
from jax import linear_transpose
117117
from jax import vjp
118118
from jax import vmap
119+
119120
from jax.core import Jaxpr
120121
from jax.core import JaxprEqn
121122
from jax.core import Literal
122123
from jax.core import ShapedArray
123124
from jax.core import Value
124125
from jax.core import Var
126+
127+
from jax.extend import linear_util as lu
128+
125129
from jax.interpreters import ad
126130
from jax.interpreters.ad import UndefinedPrimal
127131
from jax.interpreters.ad import Zero
132+
128133
import jax.numpy as jnp
134+
129135
from jax.tree_util import tree_flatten
130136
from jax.tree_util import tree_map
131137
from jax.tree_util import tree_reduce
@@ -134,6 +140,7 @@
134140
from jax.tree_util import tree_unflatten
135141
from jax.util import safe_map as map
136142
from jax.util import safe_zip as zip
143+
137144
import numpy as np
138145

139146
from .utils import rules
@@ -146,12 +153,6 @@
146153
from .utils.typing import VMapAxes
147154
from .utils.typing import VMapAxisTriple
148155

149-
try:
150-
# jax >=0.4.16
151-
from jax.extend import linear_util as lu
152-
except ImportError:
153-
from jax import linear_util as lu
154-
155156

156157
# LINEARIZATION AND TAYLOR EXPANSION
157158

neural_tangents/_src/monte_carlo.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,34 @@
2626
kernel function is JITted internally.
2727
"""
2828

29-
3029
from functools import partial
3130
import operator
3231
from typing import Generator, Iterable, Optional, Union
3332

34-
from .batching import batch
35-
from .empirical import empirical_kernel_fn, NtkImplementation, DEFAULT_NTK_IMPLEMENTATION, _DEFAULT_NTK_FWD, _DEFAULT_NTK_S_RULES, _DEFAULT_NTK_J_RULES
3633
import jax
3734
from jax import random
3835
import jax.numpy as jnp
3936
from jax.tree_util import tree_map
37+
38+
from .batching import batch
39+
40+
from .empirical import _DEFAULT_NTK_FWD
41+
from .empirical import _DEFAULT_NTK_J_RULES
42+
from .empirical import _DEFAULT_NTK_S_RULES
43+
from .empirical import DEFAULT_NTK_IMPLEMENTATION
44+
from .empirical import empirical_kernel_fn
45+
from .empirical import NtkImplementation
46+
4047
from .utils import utils
41-
from .utils.typing import ApplyFn, Axes, EmpiricalGetKernelFn, Get, InitFn, MonteCarloKernelFn, NTTree, PyTree, VMapAxes
48+
from .utils.typing import ApplyFn
49+
from .utils.typing import Axes
50+
from .utils.typing import EmpiricalGetKernelFn
51+
from .utils.typing import Get
52+
from .utils.typing import InitFn
53+
from .utils.typing import MonteCarloKernelFn
54+
from .utils.typing import NTTree
55+
from .utils.typing import PyTree
56+
from .utils.typing import VMapAxes
4257

4358

4459
def _sample_once_kernel_fn(

neural_tangents/_src/predict.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,25 @@
2727
note that closed-form kernels currently only support a single `channel_axis`).
2828
"""
2929

30-
3130
import collections
3231
from functools import lru_cache
33-
from typing import Callable, Generator, Iterable, NamedTuple, Optional, Any, Union, Protocol
32+
from typing import Any, Callable, Generator, Iterable, NamedTuple, Optional, Protocol, Union
3433

3534
import jax
3635
from jax import grad
3736
from jax.experimental import ode
3837
import jax.numpy as jnp
3938
import jax.scipy as jsp
40-
from jax.tree_util import tree_all, tree_map
39+
from jax.tree_util import tree_all
40+
from jax.tree_util import tree_map
4141
import numpy as np
4242
import scipy as sp
43-
from .utils import dataclasses, utils
44-
from .utils.typing import Axes, Get, KernelFn
43+
44+
from .utils import dataclasses
45+
from .utils import utils
46+
from .utils.typing import Axes
47+
from .utils.typing import Get
48+
from .utils.typing import KernelFn
4549

4650

4751
PyTree = Any

neural_tangents/_src/stax/branching.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,19 @@
1818
several branches into one.
1919
"""
2020

21-
2221
import functools
2322
from typing import Callable, Iterable, Optional, Sequence
2423
import warnings
2524

2625
from jax import numpy as jnp
2726
import jax.example_libraries.stax as ostax
28-
from .requirements import layer, supports_masking
27+
2928
from ..utils.kernel import Kernel
30-
from ..utils.typing import InternalLayer, InternalLayerMasked, Kernels
29+
from ..utils.typing import InternalLayer
30+
from ..utils.typing import InternalLayerMasked
31+
from ..utils.typing import Kernels
32+
from .requirements import layer
33+
from .requirements import supports_masking
3134

3235

3336
@layer

neural_tangents/_src/stax/combinators.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,21 @@
2020

2121
import frozendict
2222
import jax
23-
from jax import random, lax
23+
from jax import lax
24+
from jax import random
2425
import jax.example_libraries.stax as ostax
25-
from .requirements import Diagonal, get_req, layer, requires
26+
2627
from ..utils.kernel import Kernel
27-
from ..utils.typing import InternalLayer, Layer, LayerKernelFn, NTTree, NTTrees, Shapes
28+
from ..utils.typing import InternalLayer
29+
from ..utils.typing import Layer
30+
from ..utils.typing import LayerKernelFn
31+
from ..utils.typing import NTTree
32+
from ..utils.typing import NTTrees
33+
from ..utils.typing import Shapes
34+
from .requirements import Diagonal
35+
from .requirements import get_req
36+
from .requirements import layer
37+
from .requirements import requires
2838

2939

3040
@layer

neural_tangents/_src/stax/elementwise.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,24 @@
2424
import warnings
2525

2626
import jax
27-
from jax import custom_jvp, grad, vmap
27+
from jax import custom_jvp
28+
from jax import grad
2829
from jax import numpy as jnp
30+
from jax import vmap
2931
from jax.scipy.special import erf
3032
import numpy as np
31-
from .requirements import Diagonal, get_diagonal, get_diagonal_outer_prods, layer, requires, supports_masking
3233
import scipy as sp
34+
3335
from ..utils import utils
3436
from ..utils.kernel import Kernel
35-
from ..utils.typing import InternalLayer, LayerKernelFn
37+
from ..utils.typing import InternalLayer
38+
from ..utils.typing import LayerKernelFn
39+
from .requirements import Diagonal
40+
from .requirements import get_diagonal
41+
from .requirements import get_diagonal_outer_prods
42+
from .requirements import layer
43+
from .requirements import requires
44+
from .requirements import supports_masking
3645

3746

3847
@layer

neural_tangents/_src/stax/linear.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,30 @@
2222
import warnings
2323

2424
import jax
25+
from jax import eval_shape
2526
from jax import lax
2627
from jax import numpy as jnp
2728
from jax import ops
2829
from jax import random
29-
from jax import ShapeDtypeStruct, eval_shape, vmap
30+
from jax import ShapeDtypeStruct
31+
from jax import vmap
3032
from jax.core import ShapedArray
3133
import jax.example_libraries.stax as ostax
3234
import numpy as np
33-
from .requirements import Bool, Diagonal, get_diagonal_outer_prods, layer, mean_and_var, requires, supports_masking
35+
3436
from ..utils import utils
3537
from ..utils.kernel import Kernel
36-
from ..utils.typing import Axes, InternalLayer, InternalLayerMasked, PyTree
38+
from ..utils.typing import Axes
39+
from ..utils.typing import InternalLayer
40+
from ..utils.typing import InternalLayerMasked
41+
from ..utils.typing import PyTree
42+
from .requirements import Bool
43+
from .requirements import Diagonal
44+
from .requirements import get_diagonal_outer_prods
45+
from .requirements import layer
46+
from .requirements import mean_and_var
47+
from .requirements import requires
48+
from .requirements import supports_masking
3749

3850

3951
# Enums

neural_tangents/_src/stax/requirements.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,34 @@
1414

1515
"""Requirement management for :obj:`~neural_tangents.stax` layers."""
1616

17+
import dataclasses
1718
import enum
1819
from typing import Callable, Optional, Sequence, Union
1920
import warnings
2021

2122
import frozendict
2223
import jax
24+
from jax import eval_shape
2325
from jax import lax
2426
from jax import numpy as jnp
25-
from jax import eval_shape
2627
from jax.core import ShapedArray
27-
from jax.tree_util import tree_map, tree_all
28-
from ..utils import utils
29-
import dataclasses
28+
from jax.tree_util import tree_all
29+
from jax.tree_util import tree_map
30+
import numpy as np
31+
3032
from ..utils import dataclasses as nt_dataclasses
33+
from ..utils import utils
3134
from ..utils.kernel import Kernel
32-
from ..utils.typing import AnalyticKernelFn, Axes, Get, InitFn, ApplyFn, InternalLayer, Layer, LayerKernelFn, NTTree, PyTree
33-
import numpy as np
35+
from ..utils.typing import AnalyticKernelFn
36+
from ..utils.typing import ApplyFn
37+
from ..utils.typing import Axes
38+
from ..utils.typing import Get
39+
from ..utils.typing import InitFn
40+
from ..utils.typing import InternalLayer
41+
from ..utils.typing import Layer
42+
from ..utils.typing import LayerKernelFn
43+
from ..utils.typing import NTTree
44+
from ..utils.typing import PyTree
3445

3546

3647
# Public decorators

neural_tangents/_src/utils/rules.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,26 @@
1414

1515
"""Structured derivatives rules."""
1616

17-
from .dataclasses import dataclass, field
1817
import functools
19-
from typing import Callable, Optional, Any, Union
18+
from typing import Any, Callable, Optional, Union
2019

21-
from . import utils
2220
import jax
2321
from jax import lax
24-
from jax.core import JaxprEqn, ShapedArray, Primitive, Jaxpr, Var, AbstractValue, Literal
22+
from jax.core import AbstractValue
23+
from jax.core import Jaxpr
24+
from jax.core import JaxprEqn
25+
from jax.core import Literal
26+
from jax.core import Primitive
27+
from jax.core import ShapedArray
28+
from jax.core import Var
2529
from jax.interpreters import ad
2630
import jax.numpy as jnp
2731
import numpy as np
2832

33+
from . import utils
34+
from .dataclasses import dataclass
35+
from .dataclasses import field
36+
2937

3038
# pytype: disable=wrong-keyword-args
3139

0 commit comments

Comments
 (0)