Skip to content

Commit e997d87

Browse files
belitskiyGoogle-ML-Automation
authored andcommitted
Bisect failing callback test
PiperOrigin-RevId: 845350990
1 parent ed4d825 commit e997d87

File tree

2 files changed

+12
-223
lines changed

2 files changed

+12
-223
lines changed

.github/workflows/bazel_cuda.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ jobs:
119119
- name: Wait For Connection
120120
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
121121
with:
122-
halt-dispatch-input: ${{ inputs.halt-for-connection }}
122+
halt-dispatch-input: "yes"
123123
- name: "Bazel CUDA tests with build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}"
124-
timeout-minutes: 60
124+
timeout-minutes: 180
125125
run: ${{ ((inputs.run_multiaccelerator_tests == 'false') && './ci/run_bazel_test_cuda_rbe.sh') || './ci/run_bazel_test_cuda_non_rbe.sh' }}

.github/workflows/wheel_tests_continuous.yml

Lines changed: 10 additions & 221 deletions
Original file line numberDiff line numberDiff line change
@@ -31,175 +31,20 @@ on:
3131
schedule:
3232
- cron: "0 */3 * * *" # Run once every 3 hours
3333
workflow_dispatch: # allows triggering the workflow run manually
34+
pull_request:
35+
branches:
36+
- main
37+
push:
38+
branches:
39+
- main
40+
- 'release/**'
41+
3442

3543
concurrency:
3644
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
3745
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
3846

3947
jobs:
40-
build-jax-artifact:
41-
uses: ./.github/workflows/build_artifacts.yml
42-
name: "Build jax artifact"
43-
with:
44-
# Note that since jax is a pure python package, the runner OS and Python values do not
45-
# matter. In addition, cloning main XLA also has no effect.
46-
runner: "linux-x86-n4-16"
47-
artifact: "jax"
48-
upload_artifacts_to_gcs: true
49-
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
50-
51-
build-jaxlib-artifact:
52-
uses: ./.github/workflows/build_artifacts.yml
53-
strategy:
54-
fail-fast: false # don't cancel all jobs on failure
55-
matrix:
56-
# Runner OS and Python values need to match the matrix stategy in the CPU tests job
57-
runner: ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
58-
artifact: ["jaxlib"]
59-
python: ["3.11"]
60-
# Note: For reasons unknown, Github actions groups jobs with the same top-level name in the
61-
# dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix
62-
# values to the name and creates a separate entry for each matrix combination.
63-
name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
64-
with:
65-
runner: ${{ matrix.runner }}
66-
artifact: ${{ matrix.artifact }}
67-
python: ${{ matrix.python }}
68-
clone_main_xla: 1
69-
upload_artifacts_to_gcs: true
70-
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
71-
72-
build-cuda-artifacts:
73-
uses: ./.github/workflows/build_artifacts.yml
74-
strategy:
75-
fail-fast: false # don't cancel all jobs on failure
76-
matrix:
77-
# Python values need to match the matrix stategy in the CUDA tests job below
78-
runner: ["linux-x86-n4-16"]
79-
artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
80-
python: ["3.11",]
81-
cuda-version: ["12", "13"]
82-
name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
83-
with:
84-
runner: ${{ matrix.runner }}
85-
artifact: ${{ matrix.artifact }}
86-
python: ${{ matrix.python }}
87-
cuda-version: ${{ matrix.cuda-version }}
88-
clone_main_xla: 1
89-
upload_artifacts_to_gcs: true
90-
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
91-
92-
run-pytest-cpu:
93-
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
94-
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
95-
# still want to run the tests for other platforms.
96-
if: ${{ !cancelled() }}
97-
needs: [build-jax-artifact, build-jaxlib-artifact]
98-
uses: ./.github/workflows/pytest_cpu.yml
99-
strategy:
100-
fail-fast: false # don't cancel all jobs on failure
101-
matrix:
102-
# Runner OS and Python values need to match the matrix stategy in the
103-
# build_jaxlib_artifact job above
104-
runner: ["linux-x86-n4-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
105-
python: ["3.11",]
106-
enable-x64: [1, 0]
107-
name: "Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
108-
with:
109-
runner: ${{ matrix.runner }}
110-
python: ${{ matrix.python }}
111-
enable-x64: ${{ matrix.enable-x64 }}
112-
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
113-
114-
run-pytest-cuda:
115-
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
116-
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
117-
# still want to run the tests for other platforms.
118-
if: ${{ !cancelled() }}
119-
needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
120-
uses: ./.github/workflows/pytest_cuda.yml
121-
strategy:
122-
fail-fast: false # don't cancel all jobs on failure
123-
matrix:
124-
# Python values need to match the matrix stategy in the artifact build jobs above
125-
# See exlusions for what is fully tested
126-
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"]
127-
python: ["3.11",]
128-
cuda: [
129-
{version: "12.1", use-nvidia-pip-wheels: false},
130-
{version: "12.9", use-nvidia-pip-wheels: true},
131-
{version: "13", use-nvidia-pip-wheels: true},
132-
]
133-
enable-x64: [1, 0]
134-
exclude:
135-
# H100 runs only a single config, CUDA 12.9 Enable x64 1
136-
- runner: "linux-x86-a3-8g-h100-8gpu"
137-
cuda:
138-
version: "12.1"
139-
- runner: "linux-x86-a3-8g-h100-8gpu"
140-
enable-x64: "0"
141-
# B200 runs only a single config, CUDA 12.9 Enable x64 1
142-
- runner: "linux-x86-a4-224-b200-1gpu"
143-
cuda:
144-
version: "12.1"
145-
- runner: "linux-x86-a4-224-b200-1gpu"
146-
enable-x64: "0"
147-
148-
name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})"
149-
with:
150-
runner: ${{ matrix.runner }}
151-
python: ${{ matrix.python }}
152-
cuda-version: ${{ matrix.cuda.version }}
153-
use-nvidia-pip-wheels: ${{ matrix.cuda.use-nvidia-pip-wheels }}
154-
enable-x64: ${{ matrix.enable-x64 }}
155-
# GCS upload URI is the same for both artifact build jobs
156-
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
157-
158-
run-bazel-test-cpu-py-import:
159-
uses: ./.github/workflows/bazel_cpu.yml
160-
strategy:
161-
fail-fast: false # don't cancel all jobs on failure
162-
matrix:
163-
runner: ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
164-
python: ["3.11",]
165-
enable-x64: [1, 0]
166-
name: "Bazel CPU tests with ${{ format('{0}', 'build_jaxlib=wheel') }}"
167-
with:
168-
runner: ${{ matrix.runner }}
169-
python: ${{ matrix.python }}
170-
enable-x64: ${{ matrix.enable-x64 }}
171-
build_jaxlib: "wheel"
172-
build_jax: "wheel"
173-
174-
run-bazel-test-cuda:
175-
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
176-
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
177-
# still want to run the tests for other platforms.
178-
if: ${{ !cancelled() }}
179-
needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
180-
uses: ./.github/workflows/bazel_cuda.yml
181-
strategy:
182-
fail-fast: false # don't cancel all jobs on failure
183-
matrix:
184-
# Python values need to match the matrix stategy in the build artifacts job above
185-
runner: ["linux-x86-g2-48-l4-4gpu",]
186-
python: ["3.11",]
187-
cuda-version: ["12", "13"]
188-
jaxlib-version: ["head", "pypi_latest"]
189-
enable-x64: [1, 0]
190-
name: "Bazel CUDA Non-RBE with build_jaxlib=false, (jax version = ${{ format('{0}', 'head') }})"
191-
with:
192-
runner: ${{ matrix.runner }}
193-
python: ${{ matrix.python }}
194-
cuda-version: ${{ matrix.cuda-version }}
195-
enable-x64: ${{ matrix.enable-x64 }}
196-
jaxlib-version: ${{ matrix.jaxlib-version }}
197-
# GCS upload URI is the same for both artifact build jobs
198-
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
199-
build_jaxlib: "false"
200-
build_jax: "false"
201-
write_to_bazel_remote_cache: 1
202-
run_multiaccelerator_tests: "true"
20348

20449
run-bazel-test-cuda-py-import:
20550
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
@@ -211,9 +56,9 @@ jobs:
21156
fail-fast: false # don't cancel all jobs on failure
21257
matrix:
21358
# Python values need to match the matrix stategy in the build artifacts job above
214-
runner: ["linux-x86-g2-48-l4-4gpu",]
59+
runner: ["linux-x86-g2-48-l4-4gpu"]
21560
python: ["3.11"]
216-
cuda-version: ["12", "13"]
61+
cuda-version: ["12"]
21762
enable-x64: [1]
21863
name: "Bazel CUDA Non-RBE with ${{ format('{0}', 'build_jaxlib=wheel') }}"
21964
with:
@@ -226,59 +71,3 @@ jobs:
22671
jaxlib-version: "head"
22772
write_to_bazel_remote_cache: 1
22873
run_multiaccelerator_tests: "true"
229-
230-
run-pytest-tpu:
231-
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
232-
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
233-
# still want to run the tests for other platforms.
234-
if: ${{ !cancelled() }}
235-
needs: [build-jax-artifact, build-jaxlib-artifact]
236-
uses: ./.github/workflows/pytest_tpu.yml
237-
strategy:
238-
fail-fast: false # don't cancel all jobs on failure
239-
matrix:
240-
python: ["3.11"]
241-
tpu-specs: [
242-
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
243-
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"},
244-
{type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"},
245-
{type: "v7x-8", cores: "8", runner: "linux-x86-tpu7x-224-4tpu"}
246-
]
247-
libtpu-version-type: ["nightly"]
248-
name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
249-
with:
250-
runner: ${{ matrix.tpu-specs.runner }}
251-
cores: ${{ matrix.tpu-specs.cores }}
252-
tpu-type: ${{ matrix.tpu-specs.type }}
253-
python: ${{ matrix.python }}
254-
run-full-tpu-test-suite: "1"
255-
libtpu-version-type: ${{ matrix.libtpu-version-type }}
256-
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
257-
258-
run-bazel-test-tpu:
259-
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
260-
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
261-
# still want to run the tests for other platforms.
262-
if: ${{ !cancelled() }}
263-
uses: ./.github/workflows/bazel_test_tpu.yml
264-
strategy:
265-
fail-fast: false # don't cancel all jobs on failure
266-
matrix:
267-
python: ["3.11"]
268-
tpu-specs: [
269-
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
270-
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"},
271-
]
272-
libtpu-version-type: ["nightly"]
273-
name: "Bazel tests TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
274-
with:
275-
runner: ${{ matrix.tpu-specs.runner }}
276-
cores: ${{ matrix.tpu-specs.cores }}
277-
tpu-type: ${{ matrix.tpu-specs.type }}
278-
python: ${{ matrix.python }}
279-
run-full-tpu-test-suite: "1"
280-
libtpu-version-type: ${{ matrix.libtpu-version-type }}
281-
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
282-
build_jaxlib: "wheel"
283-
build_jax: "wheel"
284-
clone_main_xla: 1

0 commit comments

Comments
 (0)