2424# runs Bazel TPU tests with py_import.
2525
2626name : CI - Wheel Tests (Continuous)
27- permissions :
28- contents : read
2927
3028on :
3129 schedule :
3230 - cron : " 0 */3 * * *" # Run once every 3 hours
3331 workflow_dispatch : # allows triggering the workflow run manually
3432
33+ pull_request :
34+ branches :
35+ - main
36+ push :
37+ branches :
38+ - main
39+ - ' release/**'
40+ permissions : {}
41+
3542concurrency :
3643 group : ${{ github.workflow }}-${{ github.head_ref || github.ref }}
3744 cancel-in-progress : ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
5461 fail-fast : false # don't cancel all jobs on failure
5562 matrix :
5663 # Runner OS and Python values need to match the matrix stategy in the CPU tests job
57- runner : ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16" ]
64+ runner : ["linux-x86-n4-16"]
5865 artifact : ["jaxlib"]
5966 python : ["3.11"]
6067 # Note: For reasons unknown, Github actions groups jobs with the same top-level name in the
@@ -69,164 +76,6 @@ jobs:
6976 upload_artifacts_to_gcs : true
7077 gcs_upload_uri : ' gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
7178
72- build-cuda-artifacts :
73- uses : ./.github/workflows/build_artifacts.yml
74- strategy :
75- fail-fast : false # don't cancel all jobs on failure
76- matrix :
77- # Python values need to match the matrix stategy in the CUDA tests job below
78- runner : ["linux-x86-n4-16"]
79- artifact : ["jax-cuda-plugin", "jax-cuda-pjrt"]
80- python : ["3.11",]
81- cuda-version : ["12", "13"]
82- name : " Build ${{ format('{0}', 'CUDA') }} artifacts"
83- with :
84- runner : ${{ matrix.runner }}
85- artifact : ${{ matrix.artifact }}
86- python : ${{ matrix.python }}
87- cuda-version : ${{ matrix.cuda-version }}
88- clone_main_xla : 1
89- upload_artifacts_to_gcs : true
90- gcs_upload_uri : ' gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
91-
92- run-pytest-cpu :
93- # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
94- # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
95- # still want to run the tests for other platforms.
96- if : ${{ !cancelled() }}
97- needs : [build-jax-artifact, build-jaxlib-artifact]
98- uses : ./.github/workflows/pytest_cpu.yml
99- strategy :
100- fail-fast : false # don't cancel all jobs on failure
101- matrix :
102- # Runner OS and Python values need to match the matrix stategy in the
103- # build_jaxlib_artifact job above
104- runner : ["linux-x86-n4-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
105- python : ["3.11",]
106- enable-x64 : [1, 0]
107- name : " Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
108- with :
109- runner : ${{ matrix.runner }}
110- python : ${{ matrix.python }}
111- enable-x64 : ${{ matrix.enable-x64 }}
112- gcs_download_uri : ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
113-
114- run-pytest-cuda :
115- # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
116- # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
117- # still want to run the tests for other platforms.
118- if : ${{ !cancelled() }}
119- needs : [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
120- uses : ./.github/workflows/pytest_cuda.yml
121- strategy :
122- fail-fast : false # don't cancel all jobs on failure
123- matrix :
124- # Python values need to match the matrix stategy in the artifact build jobs above
125- # See exlusions for what is fully tested
126- runner : ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"]
127- python : ["3.11",]
128- cuda : [
129- {version: "12.1", use-nvidia-pip-wheels: false},
130- {version: "12.9", use-nvidia-pip-wheels: true},
131- {version: "13", use-nvidia-pip-wheels: true},
132- ]
133- enable-x64 : [1, 0]
134- exclude :
135- # H100 runs only a single config, CUDA 12.9 Enable x64 1
136- - runner : " linux-x86-a3-8g-h100-8gpu"
137- cuda :
138- version : " 12.1"
139- - runner : " linux-x86-a3-8g-h100-8gpu"
140- enable-x64 : " 0"
141- # B200 runs only a single config, CUDA 12.9 Enable x64 1
142- - runner : " linux-x86-a4-224-b200-1gpu"
143- cuda :
144- version : " 12.1"
145- - runner : " linux-x86-a4-224-b200-1gpu"
146- enable-x64 : " 0"
147-
148- name : " Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})"
149- with :
150- runner : ${{ matrix.runner }}
151- python : ${{ matrix.python }}
152- cuda-version : ${{ matrix.cuda.version }}
153- use-nvidia-pip-wheels : ${{ matrix.cuda.use-nvidia-pip-wheels }}
154- enable-x64 : ${{ matrix.enable-x64 }}
155- # GCS upload URI is the same for both artifact build jobs
156- gcs_download_uri : ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
157-
158- run-bazel-test-cpu-py-import :
159- uses : ./.github/workflows/bazel_cpu.yml
160- strategy :
161- fail-fast : false # don't cancel all jobs on failure
162- matrix :
163- runner : ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
164- python : ["3.11",]
165- enable-x64 : [1, 0]
166- name : " Bazel CPU tests with ${{ format('{0}', 'build_jaxlib=wheel') }}"
167- with :
168- runner : ${{ matrix.runner }}
169- python : ${{ matrix.python }}
170- enable-x64 : ${{ matrix.enable-x64 }}
171- build_jaxlib : " wheel"
172- build_jax : " wheel"
173-
174- run-bazel-test-cuda :
175- # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
176- # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
177- # still want to run the tests for other platforms.
178- if : ${{ !cancelled() }}
179- needs : [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
180- uses : ./.github/workflows/bazel_cuda.yml
181- strategy :
182- fail-fast : false # don't cancel all jobs on failure
183- matrix :
184- # Python values need to match the matrix stategy in the build artifacts job above
185- runner : ["linux-x86-g2-48-l4-4gpu",]
186- python : ["3.11",]
187- cuda-version : ["12", "13"]
188- jaxlib-version : ["head", "pypi_latest"]
189- enable-x64 : [1, 0]
190- name : " Bazel CUDA Non-RBE with build_jaxlib=false, (jax version = ${{ format('{0}', 'head') }})"
191- with :
192- runner : ${{ matrix.runner }}
193- python : ${{ matrix.python }}
194- cuda-version : ${{ matrix.cuda-version }}
195- enable-x64 : ${{ matrix.enable-x64 }}
196- jaxlib-version : ${{ matrix.jaxlib-version }}
197- # GCS upload URI is the same for both artifact build jobs
198- gcs_download_uri : ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
199- build_jaxlib : " false"
200- build_jax : " false"
201- write_to_bazel_remote_cache : 1
202- run_multiaccelerator_tests : " true"
203-
204- run-bazel-test-cuda-py-import :
205- # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
206- # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
207- # still want to run the tests for other platforms.
208- if : ${{ !cancelled() }}
209- uses : ./.github/workflows/bazel_cuda.yml
210- strategy :
211- fail-fast : false # don't cancel all jobs on failure
212- matrix :
213- # Python values need to match the matrix stategy in the build artifacts job above
214- runner : ["linux-x86-g2-48-l4-4gpu",]
215- python : ["3.11"]
216- cuda-version : ["12", "13"]
217- enable-x64 : [1]
218- name : " Bazel CUDA Non-RBE with ${{ format('{0}', 'build_jaxlib=wheel') }}"
219- with :
220- runner : ${{ matrix.runner }}
221- python : ${{ matrix.python }}
222- cuda-version : ${{ matrix.cuda-version }}
223- enable-x64 : ${{ matrix.enable-x64 }}
224- build_jaxlib : " wheel"
225- build_jax : " wheel"
226- jaxlib-version : " head"
227- write_to_bazel_remote_cache : 1
228- run_multiaccelerator_tests : " true"
229-
23079 run-pytest-tpu :
23180 # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
23281 # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
24190 tpu-specs : [
24291 # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
24392 {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"},
244- {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"},
245- {type: "v7x-8", cores: "8", runner: "linux-x86-tpu7x-224-4tpu"}
93+ {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"}
24694 ]
24795 libtpu-version-type : ["nightly"]
24896 name : " Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
@@ -253,32 +101,4 @@ jobs:
253101 python : ${{ matrix.python }}
254102 run-full-tpu-test-suite : " 1"
255103 libtpu-version-type : ${{ matrix.libtpu-version-type }}
256- gcs_download_uri : ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
257-
258- run-bazel-test-tpu :
259- # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
260- # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
261- # still want to run the tests for other platforms.
262- if : ${{ !cancelled() }}
263- uses : ./.github/workflows/bazel_test_tpu.yml
264- strategy :
265- fail-fast : false # don't cancel all jobs on failure
266- matrix :
267- python : ["3.11"]
268- tpu-specs : [
269- {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
270- {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"},
271- ]
272- libtpu-version-type : ["nightly"]
273- name : " Bazel tests TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
274- with :
275- runner : ${{ matrix.tpu-specs.runner }}
276- cores : ${{ matrix.tpu-specs.cores }}
277- tpu-type : ${{ matrix.tpu-specs.type }}
278- python : ${{ matrix.python }}
279- run-full-tpu-test-suite : " 1"
280- libtpu-version-type : ${{ matrix.libtpu-version-type }}
281- gcs_download_uri : ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
282- build_jaxlib : " wheel"
283- build_jax : " wheel"
284- clone_main_xla : 1
104+ gcs_download_uri : ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
0 commit comments