Skip to content

Commit 75ef78e

Browse files
belitskiy authored and Google-ML-Automation committed
Testing this quickly.
PiperOrigin-RevId: 844897893
1 parent 6860386 commit 75ef78e

File tree

6 files changed

+47
-201
lines changed

6 files changed

+47
-201
lines changed

.github/workflows/build_artifacts.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,14 @@ jobs:
125125
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
126126
with:
127127
persist-credentials: false
128+
fetch-depth: 0
129+
# - name: Revert the problematic commit
130+
# run: |
131+
# git config --global --add safe.directory /__w/jax/jax
132+
# git config --global user.email "[email protected]"
133+
# git config --global user.name "Vlad Belitskiy"
134+
# git revert --no-edit 548eaa5b53afeba91518d4d9274f7198b55cc308
135+
# echo "Commit 548eaa5b53afeba91518d4d9274f7198b55cc308 reverted for this build."
128136
- name: Configure Build Environment
129137
shell: bash
130138
run: |

.github/workflows/pytest_tpu.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,16 +114,16 @@ jobs:
114114
echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}"
115115
exit 1
116116
fi
117-
# Halt for testing
118-
- name: Wait For Connection
119-
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
120-
with:
121-
halt-dispatch-input: ${{ inputs.halt-for-connection }}
122117
- name: Run Pytest TPU tests
123-
timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 210 }}
118+
timeout-minutes: ${{ github.event_name == 'pull_request' && 210 || 210 }}
124119
run: |
125120
if [[ ${{ inputs.python }} == "3.13-nogil" && ${{ inputs.tpu-type }} == "v5e-8" ]]; then
126121
echo "Uninstalling xprof as it is not compatible with python 3.13t."
127122
$JAXCI_PYTHON -m uv pip uninstall xprof
128123
fi
129124
./ci/run_pytest_tpu.sh
125+
# Halt for testing
126+
- name: Wait For Connection
127+
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
128+
with:
129+
halt-dispatch-input: yes

.github/workflows/wheel_tests_continuous.yml

Lines changed: 12 additions & 192 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,21 @@
2424
# runs Bazel TPU tests with py_import.
2525

2626
name: CI - Wheel Tests (Continuous)
27-
permissions:
28-
contents: read
2927

3028
on:
3129
schedule:
3230
- cron: "0 */3 * * *" # Run once every 3 hours
3331
workflow_dispatch: # allows triggering the workflow run manually
3432

33+
pull_request:
34+
branches:
35+
- main
36+
push:
37+
branches:
38+
- main
39+
- 'release/**'
40+
permissions: {}
41+
3542
concurrency:
3643
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
3744
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
@@ -54,7 +61,7 @@ jobs:
5461
fail-fast: false # don't cancel all jobs on failure
5562
matrix:
5663
# Runner OS and Python values need to match the matrix strategy in the CPU tests job
57-
runner: ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
64+
runner: ["linux-x86-n4-16"]
5865
artifact: ["jaxlib"]
5966
python: ["3.11"]
6067
# Note: For reasons unknown, Github actions groups jobs with the same top-level name in the
@@ -69,164 +76,6 @@ jobs:
6976
upload_artifacts_to_gcs: true
7077
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
7178

72-
build-cuda-artifacts:
73-
uses: ./.github/workflows/build_artifacts.yml
74-
strategy:
75-
fail-fast: false # don't cancel all jobs on failure
76-
matrix:
77-
# Python values need to match the matrix stategy in the CUDA tests job below
78-
runner: ["linux-x86-n4-16"]
79-
artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
80-
python: ["3.11",]
81-
cuda-version: ["12", "13"]
82-
name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
83-
with:
84-
runner: ${{ matrix.runner }}
85-
artifact: ${{ matrix.artifact }}
86-
python: ${{ matrix.python }}
87-
cuda-version: ${{ matrix.cuda-version }}
88-
clone_main_xla: 1
89-
upload_artifacts_to_gcs: true
90-
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
91-
92-
run-pytest-cpu:
93-
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
94-
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
95-
# still want to run the tests for other platforms.
96-
if: ${{ !cancelled() }}
97-
needs: [build-jax-artifact, build-jaxlib-artifact]
98-
uses: ./.github/workflows/pytest_cpu.yml
99-
strategy:
100-
fail-fast: false # don't cancel all jobs on failure
101-
matrix:
102-
# Runner OS and Python values need to match the matrix stategy in the
103-
# build_jaxlib_artifact job above
104-
runner: ["linux-x86-n4-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
105-
python: ["3.11",]
106-
enable-x64: [1, 0]
107-
name: "Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
108-
with:
109-
runner: ${{ matrix.runner }}
110-
python: ${{ matrix.python }}
111-
enable-x64: ${{ matrix.enable-x64 }}
112-
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
113-
114-
run-pytest-cuda:
115-
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
116-
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
117-
# still want to run the tests for other platforms.
118-
if: ${{ !cancelled() }}
119-
needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
120-
uses: ./.github/workflows/pytest_cuda.yml
121-
strategy:
122-
fail-fast: false # don't cancel all jobs on failure
123-
matrix:
124-
# Python values need to match the matrix stategy in the artifact build jobs above
125-
# See exlusions for what is fully tested
126-
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"]
127-
python: ["3.11",]
128-
cuda: [
129-
{version: "12.1", use-nvidia-pip-wheels: false},
130-
{version: "12.9", use-nvidia-pip-wheels: true},
131-
{version: "13", use-nvidia-pip-wheels: true},
132-
]
133-
enable-x64: [1, 0]
134-
exclude:
135-
# H100 runs only a single config, CUDA 12.9 Enable x64 1
136-
- runner: "linux-x86-a3-8g-h100-8gpu"
137-
cuda:
138-
version: "12.1"
139-
- runner: "linux-x86-a3-8g-h100-8gpu"
140-
enable-x64: "0"
141-
# B200 runs only a single config, CUDA 12.9 Enable x64 1
142-
- runner: "linux-x86-a4-224-b200-1gpu"
143-
cuda:
144-
version: "12.1"
145-
- runner: "linux-x86-a4-224-b200-1gpu"
146-
enable-x64: "0"
147-
148-
name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})"
149-
with:
150-
runner: ${{ matrix.runner }}
151-
python: ${{ matrix.python }}
152-
cuda-version: ${{ matrix.cuda.version }}
153-
use-nvidia-pip-wheels: ${{ matrix.cuda.use-nvidia-pip-wheels }}
154-
enable-x64: ${{ matrix.enable-x64 }}
155-
# GCS upload URI is the same for both artifact build jobs
156-
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
157-
158-
run-bazel-test-cpu-py-import:
159-
uses: ./.github/workflows/bazel_cpu.yml
160-
strategy:
161-
fail-fast: false # don't cancel all jobs on failure
162-
matrix:
163-
runner: ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
164-
python: ["3.11",]
165-
enable-x64: [1, 0]
166-
name: "Bazel CPU tests with ${{ format('{0}', 'build_jaxlib=wheel') }}"
167-
with:
168-
runner: ${{ matrix.runner }}
169-
python: ${{ matrix.python }}
170-
enable-x64: ${{ matrix.enable-x64 }}
171-
build_jaxlib: "wheel"
172-
build_jax: "wheel"
173-
174-
run-bazel-test-cuda:
175-
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
176-
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
177-
# still want to run the tests for other platforms.
178-
if: ${{ !cancelled() }}
179-
needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
180-
uses: ./.github/workflows/bazel_cuda.yml
181-
strategy:
182-
fail-fast: false # don't cancel all jobs on failure
183-
matrix:
184-
# Python values need to match the matrix stategy in the build artifacts job above
185-
runner: ["linux-x86-g2-48-l4-4gpu",]
186-
python: ["3.11",]
187-
cuda-version: ["12", "13"]
188-
jaxlib-version: ["head", "pypi_latest"]
189-
enable-x64: [1, 0]
190-
name: "Bazel CUDA Non-RBE with build_jaxlib=false, (jax version = ${{ format('{0}', 'head') }})"
191-
with:
192-
runner: ${{ matrix.runner }}
193-
python: ${{ matrix.python }}
194-
cuda-version: ${{ matrix.cuda-version }}
195-
enable-x64: ${{ matrix.enable-x64 }}
196-
jaxlib-version: ${{ matrix.jaxlib-version }}
197-
# GCS upload URI is the same for both artifact build jobs
198-
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
199-
build_jaxlib: "false"
200-
build_jax: "false"
201-
write_to_bazel_remote_cache: 1
202-
run_multiaccelerator_tests: "true"
203-
204-
run-bazel-test-cuda-py-import:
205-
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
206-
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
207-
# still want to run the tests for other platforms.
208-
if: ${{ !cancelled() }}
209-
uses: ./.github/workflows/bazel_cuda.yml
210-
strategy:
211-
fail-fast: false # don't cancel all jobs on failure
212-
matrix:
213-
# Python values need to match the matrix stategy in the build artifacts job above
214-
runner: ["linux-x86-g2-48-l4-4gpu",]
215-
python: ["3.11"]
216-
cuda-version: ["12", "13"]
217-
enable-x64: [1]
218-
name: "Bazel CUDA Non-RBE with ${{ format('{0}', 'build_jaxlib=wheel') }}"
219-
with:
220-
runner: ${{ matrix.runner }}
221-
python: ${{ matrix.python }}
222-
cuda-version: ${{ matrix.cuda-version }}
223-
enable-x64: ${{ matrix.enable-x64 }}
224-
build_jaxlib: "wheel"
225-
build_jax: "wheel"
226-
jaxlib-version: "head"
227-
write_to_bazel_remote_cache: 1
228-
run_multiaccelerator_tests: "true"
229-
23079
run-pytest-tpu:
23180
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
23281
# build job fails. E.g., the Windows build job fails but everything else succeeds. In this case, we
@@ -241,8 +90,7 @@ jobs:
24190
tpu-specs: [
24291
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
24392
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"},
244-
{type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"},
245-
{type: "v7x-8", cores: "8", runner: "linux-x86-tpu7x-224-4tpu"}
93+
{type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"}
24694
]
24795
libtpu-version-type: ["nightly"]
24896
name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
@@ -253,32 +101,4 @@ jobs:
253101
python: ${{ matrix.python }}
254102
run-full-tpu-test-suite: "1"
255103
libtpu-version-type: ${{ matrix.libtpu-version-type }}
256-
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
257-
258-
run-bazel-test-tpu:
259-
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
260-
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
261-
# still want to run the tests for other platforms.
262-
if: ${{ !cancelled() }}
263-
uses: ./.github/workflows/bazel_test_tpu.yml
264-
strategy:
265-
fail-fast: false # don't cancel all jobs on failure
266-
matrix:
267-
python: ["3.11"]
268-
tpu-specs: [
269-
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
270-
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"},
271-
]
272-
libtpu-version-type: ["nightly"]
273-
name: "Bazel tests TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
274-
with:
275-
runner: ${{ matrix.tpu-specs.runner }}
276-
cores: ${{ matrix.tpu-specs.cores }}
277-
tpu-type: ${{ matrix.tpu-specs.type }}
278-
python: ${{ matrix.python }}
279-
run-full-tpu-test-suite: "1"
280-
libtpu-version-type: ${{ matrix.libtpu-version-type }}
281-
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
282-
build_jaxlib: "wheel"
283-
build_jax: "wheel"
284-
clone_main_xla: 1
104+
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

build/test-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ rich
1212
matplotlib
1313
auditwheel
1414
scipy-stubs
15+
pytest-timeout

ci/run_pytest_tpu.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then
6969
fi
7070

7171
# Run single-accelerator tests in parallel
72-
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
72+
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n=4 --tb=short \
7373
--deselect=tests/pallas/tpu_pallas_call_print_test.py::PallasCallPrintTest \
7474
--deselect=tests/pallas/tpu_sparsecore_pallas_test.py::DebugPrintTest \
7575
--deselect=tests/pallas/tpu_pallas_interpret_thread_map_test.py::InterpretThreadMapTest::test_thread_map \
76-
--maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples
76+
--maxfail=20 --dist=loadfile --timeout=600 -m "not multiaccelerator" $IGNORE_FLAGS tests examples
7777

7878
# Store the return value of the first command.
7979
first_cmd_retval=$?
@@ -86,7 +86,7 @@ if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then
8686
else
8787
# Run single-accelerator tests in parallel
8888
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
89-
--maxfail=20 -m "not multiaccelerator" \
89+
--maxfail=20 --timeout=600 -m "not multiaccelerator" \
9090
tests/pallas/ops_test.py \
9191
tests/pallas/export_back_compat_pallas_test.py \
9292
tests/pallas/export_pallas_test.py \

conftest.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
"""pytest configuration"""
1515

1616
import os
17+
import sys
18+
1719
import pytest
1820

1921

@@ -72,3 +74,18 @@ def pytest_collection() -> None:
7274
os.environ.setdefault(
7375
"CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices)
7476
)
77+
78+
79+
def pytest_runtest_logreport(report):
80+
# Only look at the call phase (setup and teardown reports are skipped)
81+
if report.when == 'call':
82+
# Get the worker ID
83+
worker_id = getattr(report, "node", None)
84+
if worker_id:
85+
worker_id = worker_id.gateway.id
86+
else:
87+
worker_id = "master"
88+
89+
# Log to a file named after the worker
90+
with open(f"test_order_{worker_id}.log", "a") as f:
91+
f.write(f"{report.nodeid}\n")

0 commit comments

Comments
 (0)