Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/bazel_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ jobs:
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
halt-dispatch-input: "yes"
- name: "Bazel CUDA tests with build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}"
timeout-minutes: 60
timeout-minutes: 180
run: ${{ ((inputs.run_multiaccelerator_tests == 'false') && './ci/run_bazel_test_cuda_rbe.sh') || './ci/run_bazel_test_cuda_non_rbe.sh' }}
231 changes: 10 additions & 221 deletions .github/workflows/wheel_tests_continuous.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,175 +31,20 @@ on:
schedule:
- cron: "0 */3 * * *" # Run once every 3 hours
workflow_dispatch: # allows triggering the workflow run manually
pull_request:
branches:
- main
push:
branches:
- main
- 'release/**'


concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}

jobs:
build-jax-artifact:
uses: ./.github/workflows/build_artifacts.yml
name: "Build jax artifact"
with:
# Note that since jax is a pure python package, the runner OS and Python values do not
# matter. In addition, cloning main XLA also has no effect.
runner: "linux-x86-n4-16"
artifact: "jax"
upload_artifacts_to_gcs: true
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

build-jaxlib-artifact:
uses: ./.github/workflows/build_artifacts.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Runner OS and Python values need to match the matrix stategy in the CPU tests job
runner: ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
artifact: ["jaxlib"]
python: ["3.11"]
# Note: For reasons unknown, Github actions groups jobs with the same top-level name in the
# dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix
# values to the name and creates a separate entry for each matrix combination.
name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
with:
runner: ${{ matrix.runner }}
artifact: ${{ matrix.artifact }}
python: ${{ matrix.python }}
clone_main_xla: 1
upload_artifacts_to_gcs: true
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

build-cuda-artifacts:
uses: ./.github/workflows/build_artifacts.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Python values need to match the matrix stategy in the CUDA tests job below
runner: ["linux-x86-n4-16"]
artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
python: ["3.11",]
cuda-version: ["12", "13"]
name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
with:
runner: ${{ matrix.runner }}
artifact: ${{ matrix.artifact }}
python: ${{ matrix.python }}
cuda-version: ${{ matrix.cuda-version }}
clone_main_xla: 1
upload_artifacts_to_gcs: true
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

run-pytest-cpu:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
needs: [build-jax-artifact, build-jaxlib-artifact]
uses: ./.github/workflows/pytest_cpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Runner OS and Python values need to match the matrix stategy in the
# build_jaxlib_artifact job above
runner: ["linux-x86-n4-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
python: ["3.11",]
enable-x64: [1, 0]
name: "Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
with:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
enable-x64: ${{ matrix.enable-x64 }}
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

run-pytest-cuda:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
uses: ./.github/workflows/pytest_cuda.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Python values need to match the matrix stategy in the artifact build jobs above
# See exlusions for what is fully tested
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"]
python: ["3.11",]
cuda: [
{version: "12.1", use-nvidia-pip-wheels: false},
{version: "12.9", use-nvidia-pip-wheels: true},
{version: "13", use-nvidia-pip-wheels: true},
]
enable-x64: [1, 0]
exclude:
# H100 runs only a single config, CUDA 12.9 Enable x64 1
- runner: "linux-x86-a3-8g-h100-8gpu"
cuda:
version: "12.1"
- runner: "linux-x86-a3-8g-h100-8gpu"
enable-x64: "0"
# B200 runs only a single config, CUDA 12.9 Enable x64 1
- runner: "linux-x86-a4-224-b200-1gpu"
cuda:
version: "12.1"
- runner: "linux-x86-a4-224-b200-1gpu"
enable-x64: "0"

name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})"
with:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
cuda-version: ${{ matrix.cuda.version }}
use-nvidia-pip-wheels: ${{ matrix.cuda.use-nvidia-pip-wheels }}
enable-x64: ${{ matrix.enable-x64 }}
# GCS upload URI is the same for both artifact build jobs
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

run-bazel-test-cpu-py-import:
uses: ./.github/workflows/bazel_cpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
runner: ["linux-x86-n4-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
python: ["3.11",]
enable-x64: [1, 0]
name: "Bazel CPU tests with ${{ format('{0}', 'build_jaxlib=wheel') }}"
with:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
enable-x64: ${{ matrix.enable-x64 }}
build_jaxlib: "wheel"
build_jax: "wheel"

run-bazel-test-cuda:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
uses: ./.github/workflows/bazel_cuda.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Python values need to match the matrix stategy in the build artifacts job above
runner: ["linux-x86-g2-48-l4-4gpu",]
python: ["3.11",]
cuda-version: ["12", "13"]
jaxlib-version: ["head", "pypi_latest"]
enable-x64: [1, 0]
name: "Bazel CUDA Non-RBE with build_jaxlib=false, (jax version = ${{ format('{0}', 'head') }})"
with:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
cuda-version: ${{ matrix.cuda-version }}
enable-x64: ${{ matrix.enable-x64 }}
jaxlib-version: ${{ matrix.jaxlib-version }}
# GCS upload URI is the same for both artifact build jobs
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
build_jaxlib: "false"
build_jax: "false"
write_to_bazel_remote_cache: 1
run_multiaccelerator_tests: "true"

run-bazel-test-cuda-py-import:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
Expand All @@ -211,9 +56,9 @@ jobs:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Python values need to match the matrix stategy in the build artifacts job above
runner: ["linux-x86-g2-48-l4-4gpu",]
runner: ["linux-x86-g2-48-l4-4gpu"]
python: ["3.11"]
cuda-version: ["12", "13"]
cuda-version: ["12"]
enable-x64: [1]
name: "Bazel CUDA Non-RBE with ${{ format('{0}', 'build_jaxlib=wheel') }}"
with:
Expand All @@ -226,59 +71,3 @@ jobs:
jaxlib-version: "head"
write_to_bazel_remote_cache: 1
run_multiaccelerator_tests: "true"

run-pytest-tpu:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
needs: [build-jax-artifact, build-jaxlib-artifact]
uses: ./.github/workflows/pytest_tpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
python: ["3.11"]
tpu-specs: [
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"},
{type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"},
{type: "v7x-8", cores: "8", runner: "linux-x86-tpu7x-224-4tpu"}
]
libtpu-version-type: ["nightly"]
name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
with:
runner: ${{ matrix.tpu-specs.runner }}
cores: ${{ matrix.tpu-specs.cores }}
tpu-type: ${{ matrix.tpu-specs.type }}
python: ${{ matrix.python }}
run-full-tpu-test-suite: "1"
libtpu-version-type: ${{ matrix.libtpu-version-type }}
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

run-bazel-test-tpu:
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# still want to run the tests for other platforms.
if: ${{ !cancelled() }}
uses: ./.github/workflows/bazel_test_tpu.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
python: ["3.11"]
tpu-specs: [
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"},
]
libtpu-version-type: ["nightly"]
name: "Bazel tests TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
with:
runner: ${{ matrix.tpu-specs.runner }}
cores: ${{ matrix.tpu-specs.cores }}
tpu-type: ${{ matrix.tpu-specs.type }}
python: ${{ matrix.python }}
run-full-tpu-test-suite: "1"
libtpu-version-type: ${{ matrix.libtpu-version-type }}
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
build_jaxlib: "wheel"
build_jax: "wheel"
clone_main_xla: 1
Loading