3131 schedule :
3232 - cron : " 0 */3 * * *" # Run once every 3 hours
3333 workflow_dispatch : # allows triggering the workflow run manually
34+ pull_request :
35+ branches :
36+ - main
37+ push :
38+ branches :
39+ - main
40+ - ' release/**'
41+
3442
3543concurrency :
# Runs of the same workflow on the same PR head branch (or ref) share one concurrency group.
3644 group : ${{ github.workflow }}-${{ github.head_ref || github.ref }}
# Cancel superseded runs, except on release branches and on main. github.ref is the fully
# qualified ref (e.g. refs/heads/main), so it must be compared in that form; comparing against
# the bare branch name 'main' is never equal and would cancel in-progress runs on main as well.
3745 cancel-in-progress : ${{ !contains(github.ref, 'release/') && github.ref != 'refs/heads/main' }}
3846
3947jobs :
40- build-jax-artifact :
41- uses : ./.github/workflows/build_artifacts.yml
42- name : " Build jax artifact"
43- with :
44- # jax is a pure Python package, so the runner OS and Python version used here do not
45- # matter. Cloning main XLA likewise has no effect on this artifact.
46- runner : " linux-x86-n4-16"
47- artifact : " jax"
48- upload_artifacts_to_gcs : true
49- gcs_upload_uri : ' gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
50-
51- build-jaxlib-artifact :
52- uses : ./.github/workflows/build_artifacts.yml
53- strategy :
54- fail-fast : false # don't cancel all jobs on failure
55- matrix :
56- # Runner OS and Python values need to match the matrix strategy in the CPU tests job
57- runner : ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
58- artifact : ["jaxlib"]
59- python : ["3.11"]
60- # Note: For reasons unknown, GitHub Actions groups jobs with the same top-level name in the
61- # dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix
62- # values to the name and creates a separate entry for each matrix combination.
63- name : " Build ${{ format('{0}', 'jaxlib') }} artifacts"
64- with :
65- runner : ${{ matrix.runner }}
66- artifact : ${{ matrix.artifact }}
67- python : ${{ matrix.python }}
68- clone_main_xla : 1
69- upload_artifacts_to_gcs : true
70- gcs_upload_uri : ' gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
71-
72- build-cuda-artifacts :
73- uses : ./.github/workflows/build_artifacts.yml
74- strategy :
75- fail-fast : false # don't cancel all jobs on failure
76- matrix :
77- # Python values need to match the matrix strategy in the CUDA tests job below
78- runner : ["linux-x86-n4-16"]
79- artifact : ["jax-cuda-plugin", "jax-cuda-pjrt"]
80- python : ["3.11",]
81- cuda-version : ["12", "13"]
82- name : " Build ${{ format('{0}', 'CUDA') }} artifacts"
83- with :
84- runner : ${{ matrix.runner }}
85- artifact : ${{ matrix.artifact }}
86- python : ${{ matrix.python }}
87- cuda-version : ${{ matrix.cuda-version }}
88- clone_main_xla : 1
89- upload_artifacts_to_gcs : true
90- gcs_upload_uri : ' gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
91-
92- run-pytest-cpu :
93- # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
94- # build job fails. E.g., the Windows build job fails but everything else succeeds. In this case,
95- # we still want to run the tests for the other platforms.
96- if : ${{ !cancelled() }}
97- needs : [build-jax-artifact, build-jaxlib-artifact]
98- uses : ./.github/workflows/pytest_cpu.yml
99- strategy :
100- fail-fast : false # don't cancel all jobs on failure
101- matrix :
102- # Runner OS and Python values need to match the matrix strategy in the
103- # build-jaxlib-artifact job above
104- runner : ["linux-x86-n4-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
105- python : ["3.11",]
106- enable-x64 : [1, 0]
107- name : " Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
108- with :
109- runner : ${{ matrix.runner }}
110- python : ${{ matrix.python }}
111- enable-x64 : ${{ matrix.enable-x64 }}
112- gcs_download_uri : ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
113-
114- run-pytest-cuda :
115- # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
116- # build job fails. E.g., the Windows build job fails but everything else succeeds. In this case,
117- # we still want to run the tests for the other platforms.
118- if : ${{ !cancelled() }}
119- needs : [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
120- uses : ./.github/workflows/pytest_cuda.yml
121- strategy :
122- fail-fast : false # don't cancel all jobs on failure
123- matrix :
124- # Python values need to match the matrix strategy in the artifact build jobs above
125- # See the exclusions below for what is fully tested
126- runner : ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"]
127- python : ["3.11",]
128- cuda : [
129- {version: "12.1", use-nvidia-pip-wheels: false},
130- {version: "12.9", use-nvidia-pip-wheels: true},
131- {version: "13", use-nvidia-pip-wheels: true},
132- ]
133- enable-x64 : [1, 0]
134- exclude :
135- # H100: CUDA 12.1 and enable-x64 0 are excluded. NOTE(review): CUDA 13 is not excluded, so more than the single "CUDA 12.9, x64 1" config remains — confirm intent.
136- - runner : " linux-x86-a3-8g-h100-8gpu"
137- cuda :
138- version : " 12.1"
139- - runner : " linux-x86-a3-8g-h100-8gpu"
140- enable-x64 : " 0"
141- # B200: CUDA 12.1 and enable-x64 0 are excluded. NOTE(review): CUDA 13 is not excluded, so more than the single "CUDA 12.9, x64 1" config remains — confirm intent.
142- - runner : " linux-x86-a4-224-b200-1gpu"
143- cuda :
144- version : " 12.1"
145- - runner : " linux-x86-a4-224-b200-1gpu"
146- enable-x64 : " 0"
147-
148- name : " Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})"
149- with :
150- runner : ${{ matrix.runner }}
151- python : ${{ matrix.python }}
152- cuda-version : ${{ matrix.cuda.version }}
153- use-nvidia-pip-wheels : ${{ matrix.cuda.use-nvidia-pip-wheels }}
154- enable-x64 : ${{ matrix.enable-x64 }}
155- # The download URI is read from the jaxlib build job's output; the build jobs all upload to the same GCS URI
156- gcs_download_uri : ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
157-
158- run-bazel-test-cpu-py-import :
159- uses : ./.github/workflows/bazel_cpu.yml
160- strategy :
161- fail-fast : false # keep the other matrix jobs running when one of them fails
162- matrix :
163- runner : ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
164- python : ["3.11",]
165- enable-x64 : [1, 0]
166- name : " Bazel CPU tests with ${{ format('{0}', 'build_jaxlib=wheel') }}"
167- with :
168- runner : ${{ matrix.runner }}
169- python : ${{ matrix.python }}
170- enable-x64 : ${{ matrix.enable-x64 }}
171- build_jaxlib : " wheel"
172- build_jax : " wheel"
173-
174- run-bazel-test-cuda :
175- # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
176- # build job fails. E.g., the Windows build job fails but everything else succeeds. In this case,
177- # we still want to run the tests for the other platforms.
178- if : ${{ !cancelled() }}
179- needs : [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
180- uses : ./.github/workflows/bazel_cuda.yml
181- strategy :
182- fail-fast : false # don't cancel all jobs on failure
183- matrix :
184- # Python values need to match the matrix strategy in the build artifacts job above
185- runner : ["linux-x86-g2-48-l4-4gpu",]
186- python : ["3.11",]
187- cuda-version : ["12", "13"]
188- jaxlib-version : ["head", "pypi_latest"]
189- enable-x64 : [1, 0]
190- name : " Bazel CUDA Non-RBE with build_jaxlib=false, (jax version = ${{ format('{0}', 'head') }})"
191- with :
192- runner : ${{ matrix.runner }}
193- python : ${{ matrix.python }}
194- cuda-version : ${{ matrix.cuda-version }}
195- enable-x64 : ${{ matrix.enable-x64 }}
196- jaxlib-version : ${{ matrix.jaxlib-version }}
197- # The download URI is read from the jaxlib build job's output; the build jobs all upload to the same GCS URI
198- gcs_download_uri : ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
199- build_jaxlib : " false"
200- build_jax : " false"
201- write_to_bazel_remote_cache : 1
202- run_multiaccelerator_tests : " true"
20348
20449 run-bazel-test-cuda-py-import :
20550 # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
21156 fail-fast : false # don't cancel all jobs on failure
21257 matrix :
21358 # Python values need to match the matrix strategy in the build artifacts job above
214- runner : ["linux-x86-g2-48-l4-4gpu", ]
59+ runner : ["linux-x86-g2-48-l4-4gpu"]
21560 python : ["3.11"]
216- cuda-version : ["12", "13" ]
61+ cuda-version : ["12"]
21762 enable-x64 : [1]
21863 name : " Bazel CUDA Non-RBE with ${{ format('{0}', 'build_jaxlib=wheel') }}"
21964 with :
@@ -226,59 +71,3 @@ jobs:
22671 jaxlib-version : " head"
22772 write_to_bazel_remote_cache : 1
22873 run_multiaccelerator_tests : " true"
229-
230- run-pytest-tpu :
231- # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
232- # build job fails. E.g., the Windows build job fails but everything else succeeds. In this case,
233- # we still want to run the tests for the other platforms.
234- if : ${{ !cancelled() }}
235- needs : [build-jax-artifact, build-jaxlib-artifact]
236- uses : ./.github/workflows/pytest_tpu.yml
237- strategy :
238- fail-fast : false # don't cancel all jobs on failure
239- matrix :
240- python : ["3.11"]
241- tpu-specs : [
242- # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
243- {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"},
244- {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"},
245- {type: "v7x-8", cores: "8", runner: "linux-x86-tpu7x-224-4tpu"}
246- ]
247- libtpu-version-type : ["nightly"]
248- name : " Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
249- with :
250- runner : ${{ matrix.tpu-specs.runner }}
251- cores : ${{ matrix.tpu-specs.cores }}
252- tpu-type : ${{ matrix.tpu-specs.type }}
253- python : ${{ matrix.python }}
254- run-full-tpu-test-suite : " 1"
255- libtpu-version-type : ${{ matrix.libtpu-version-type }}
256- gcs_download_uri : ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
257-
258- run-bazel-test-tpu :
259- # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
260- # build job fails. E.g., the Windows build job fails but everything else succeeds. In this case,
261- # we still want to run the tests for the other platforms.
262- if : ${{ !cancelled() }}
263- uses : ./.github/workflows/bazel_test_tpu.yml
264- strategy :
265- fail-fast : false # don't cancel all jobs on failure
266- matrix :
267- python : ["3.11"]
268- tpu-specs : [
269- {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
270- {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"},
271- ]
272- libtpu-version-type : ["nightly"]
273- name : " Bazel tests TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
274- with :
275- runner : ${{ matrix.tpu-specs.runner }}
276- cores : ${{ matrix.tpu-specs.cores }}
277- tpu-type : ${{ matrix.tpu-specs.type }}
278- python : ${{ matrix.python }}
279- run-full-tpu-test-suite : " 1"
280- libtpu-version-type : ${{ matrix.libtpu-version-type }}
281- gcs_download_uri : ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
282- build_jaxlib : " wheel"
283- build_jax : " wheel"
284- clone_main_xla : 1
0 commit comments