diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index b1f3b6c8f545..9249b169b893 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -125,6 +125,14 @@ jobs: - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: persist-credentials: false + fetch-depth: 0 + - name: Revert the problematic commit + run: | + git config --global --add safe.directory /__w/jax/jax + git config --global user.email "belitskiy@google.com" + git config --global user.name "Vlad Belitskiy" + git revert --no-edit 548eaa5b53afeba91518d4d9274f7198b55cc308 + echo "Commit 548eaa5b53afeba91518d4d9274f7198b55cc308 reverted for this build." - name: Configure Build Environment shell: bash run: | diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index 2050de4202de..55257cd61852 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -114,16 +114,16 @@ jobs: echo "Unknown libtpu version type: ${{ inputs.libtpu-version-type }}" exit 1 fi - # Halt for testing - - name: Wait For Connection - uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c - with: - halt-dispatch-input: ${{ inputs.halt-for-connection }} - name: Run Pytest TPU tests - timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 210 }} + timeout-minutes: ${{ github.event_name == 'pull_request' && 210 || 210 }} run: | if [[ ${{ inputs.python }} == "3.13-nogil" && ${{ inputs.tpu-type }} == "v5e-8" ]]; then echo "Uninstalling xprof as it is not compatible with python 3.13t." $JAXCI_PYTHON -m uv pip uninstall xprof fi ./ci/run_pytest_tpu.sh + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + with: + halt-dispatch-input: ${{ inputs.tpu-type == 'v6e-8' && 'yes' || 'no' }} diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 07e2eb4fd978..cc5c31c6f55f 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -24,14 +24,21 @@ # runs Bazel TPU tests with py_import. name: CI - Wheel Tests (Continuous) -permissions: - contents: read on: schedule: - cron: "0 */3 * * *" # Run once every 3 hours workflow_dispatch: # allows triggering the workflow run manually + pull_request: + branches: + - main + push: + branches: + - main + - 'release/**' +permissions: {} + concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }} @@ -54,7 +61,7 @@ jobs: fail-fast: false # don't cancel all jobs on failure matrix: # Runner OS and Python values need to match the matrix stategy in the CPU tests job - runner: ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"] + runner: ["linux-x86-n4-16"] artifact: ["jaxlib"] python: ["3.11"] # Note: For reasons unknown, Github actions groups jobs with the same top-level name in the @@ -69,164 +76,6 @@ jobs: upload_artifacts_to_gcs: true gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - build-cuda-artifacts: - uses: ./.github/workflows/build_artifacts.yml - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - # Python values need to match the matrix stategy in the CUDA tests job below - runner: ["linux-x86-n4-16"] - artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"] - python: ["3.11",] - cuda-version: ["12", "13"] - name: "Build ${{ format('{0}', 'CUDA') }} artifacts" - with: - runner: ${{ matrix.runner }} - artifact: ${{ matrix.artifact }} - python: ${{ matrix.python }} - cuda-version: ${{ matrix.cuda-version }} - clone_main_xla: 1 - upload_artifacts_to_gcs: true - gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' - - run-pytest-cpu: - # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated - # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we - # still want to run the tests for other platforms. - if: ${{ !cancelled() }} - needs: [build-jax-artifact, build-jaxlib-artifact] - uses: ./.github/workflows/pytest_cpu.yml - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - # Runner OS and Python values need to match the matrix stategy in the - # build_jaxlib_artifact job above - runner: ["linux-x86-n4-64", "linux-arm64-t2a-48", "windows-x86-n2-64"] - python: ["3.11",] - enable-x64: [1, 0] - name: "Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})" - with: - runner: ${{ matrix.runner }} - python: ${{ matrix.python }} - enable-x64: ${{ matrix.enable-x64 }} - gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} - - run-pytest-cuda: - # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated - # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we - # still want to run the tests for other platforms. - if: ${{ !cancelled() }} - needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts] - uses: ./.github/workflows/pytest_cuda.yml - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - # Python values need to match the matrix stategy in the artifact build jobs above - # See exlusions for what is fully tested - runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"] - python: ["3.11",] - cuda: [ - {version: "12.1", use-nvidia-pip-wheels: false}, - {version: "12.9", use-nvidia-pip-wheels: true}, - {version: "13", use-nvidia-pip-wheels: true}, - ] - enable-x64: [1, 0] - exclude: - # H100 runs only a single config, CUDA 12.9 Enable x64 1 - - runner: "linux-x86-a3-8g-h100-8gpu" - cuda: - version: "12.1" - - runner: "linux-x86-a3-8g-h100-8gpu" - enable-x64: "0" - # B200 runs only a single config, CUDA 12.9 Enable x64 1 - - runner: "linux-x86-a4-224-b200-1gpu" - cuda: - version: "12.1" - - runner: "linux-x86-a4-224-b200-1gpu" - enable-x64: "0" - - name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})" - with: - runner: ${{ matrix.runner }} - python: ${{ matrix.python }} - cuda-version: ${{ matrix.cuda.version }} - use-nvidia-pip-wheels: ${{ matrix.cuda.use-nvidia-pip-wheels }} - enable-x64: ${{ matrix.enable-x64 }} - # GCS upload URI is the same for both artifact build jobs - gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} - - run-bazel-test-cpu-py-import: - uses: ./.github/workflows/bazel_cpu.yml - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - runner: ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"] - python: ["3.11",] - enable-x64: [1, 0] - name: "Bazel CPU tests with ${{ format('{0}', 'build_jaxlib=wheel') }}" - with: - runner: ${{ matrix.runner }} - python: ${{ matrix.python }} - enable-x64: ${{ matrix.enable-x64 }} - build_jaxlib: "wheel" - build_jax: "wheel" - - run-bazel-test-cuda: - # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated - # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we - # still want to run the tests for other platforms. - if: ${{ !cancelled() }} - needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts] - uses: ./.github/workflows/bazel_cuda.yml - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - # Python values need to match the matrix stategy in the build artifacts job above - runner: ["linux-x86-g2-48-l4-4gpu",] - python: ["3.11",] - cuda-version: ["12", "13"] - jaxlib-version: ["head", "pypi_latest"] - enable-x64: [1, 0] - name: "Bazel CUDA Non-RBE with build_jaxlib=false, (jax version = ${{ format('{0}', 'head') }})" - with: - runner: ${{ matrix.runner }} - python: ${{ matrix.python }} - cuda-version: ${{ matrix.cuda-version }} - enable-x64: ${{ matrix.enable-x64 }} - jaxlib-version: ${{ matrix.jaxlib-version }} - # GCS upload URI is the same for both artifact build jobs - gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} - build_jaxlib: "false" - build_jax: "false" - write_to_bazel_remote_cache: 1 - run_multiaccelerator_tests: "true" - - run-bazel-test-cuda-py-import: - # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated - # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we - # still want to run the tests for other platforms. - if: ${{ !cancelled() }} - uses: ./.github/workflows/bazel_cuda.yml - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - # Python values need to match the matrix stategy in the build artifacts job above - runner: ["linux-x86-g2-48-l4-4gpu",] - python: ["3.11"] - cuda-version: ["12", "13"] - enable-x64: [1] - name: "Bazel CUDA Non-RBE with ${{ format('{0}', 'build_jaxlib=wheel') }}" - with: - runner: ${{ matrix.runner }} - python: ${{ matrix.python }} - cuda-version: ${{ matrix.cuda-version }} - enable-x64: ${{ matrix.enable-x64 }} - build_jaxlib: "wheel" - build_jax: "wheel" - jaxlib-version: "head" - write_to_bazel_remote_cache: 1 - run_multiaccelerator_tests: "true" - run-pytest-tpu: # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we @@ -241,8 +90,7 @@ jobs: tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, - {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"}, - {type: "v7x-8", cores: "8", runner: "linux-x86-tpu7x-224-4tpu"} + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} ] libtpu-version-type: ["nightly"] name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})" @@ -266,8 +114,7 @@ jobs: matrix: python: ["3.11"] tpu-specs: [ - {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} ] libtpu-version-type: ["nightly"] name: "Bazel tests TPU (JAX artifacts version = ${{ format('{0}', 'head') }})" @@ -281,4 +128,4 @@ jobs: gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} build_jaxlib: "wheel" build_jax: "wheel" - clone_main_xla: 1 \ No newline at end of file + clone_main_xla: 1 diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 453ebe4e18ae..d8a5c08daa35 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -12,3 +12,4 @@ rich matplotlib auditwheel scipy-stubs +# pytest-timeout diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh index abb45cbe10e8..6e47068a2f23 100755 --- a/ci/run_pytest_tpu.sh +++ b/ci/run_pytest_tpu.sh @@ -73,7 +73,7 @@ if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then --deselect=tests/pallas/tpu_pallas_call_print_test.py::PallasCallPrintTest \ --deselect=tests/pallas/tpu_sparsecore_pallas_test.py::DebugPrintTest \ --deselect=tests/pallas/tpu_pallas_interpret_thread_map_test.py::InterpretThreadMapTest::test_thread_map \ - --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples + --maxfail=20 --dist=loadfile -m "not multiaccelerator" $IGNORE_FLAGS tests examples # Store the return value of the first command. first_cmd_retval=$? diff --git a/conftest.py b/conftest.py index fa0e6de94346..673f3035c718 100644 --- a/conftest.py +++ b/conftest.py @@ -14,6 +14,8 @@ """pytest configuration""" import os +import sys + import pytest @@ -72,3 +74,18 @@ def pytest_collection() -> None: os.environ.setdefault( "CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices) ) + + +def pytest_runtest_logreport(report): + # Only look at the setup/call phase + if report.when == 'call': + # Get the worker ID + worker_id = getattr(report, "node", None) + if worker_id: + worker_id = worker_id.gateway.id + else: + worker_id = "master" + + # Log to a file named after the worker + with open(f"test_order_{worker_id}.log", "a") as f: + f.write(f"{report.nodeid}\n") diff --git a/tests/pallas/tpu_sparsecore_pallas_debug_check_test.py b/tests/pallas/tpu_sparsecore_pallas_debug_check_test.py index 55cb1ce37c01..d8efcf5be0c5 100644 --- a/tests/pallas/tpu_sparsecore_pallas_debug_check_test.py +++ b/tests/pallas/tpu_sparsecore_pallas_debug_check_test.py @@ -67,6 +67,7 @@ def setUp(self): super().setUp() + @unittest.skip("Failing on all TPU versions: b/436509694") def test_scalar_debug_check(self): if not jtu.is_device_tpu_at_least(7): # TODO: b/469486032 - Figure out why the test gets stuck on v5p, v6e.