# Run tests on TPU
#
# Run tests on TPU #1
#
# Workflow file for this run

# name: TPU Tests
# on:
# push:
# branches: [ master ]
# pull_request:
# release:
# types: [created]
# # Only basic permissions are needed now.
# permissions:
# contents: read
# jobs:
# test-in-container:
# name: Test in Custom Container
# runs-on: linux-x86-ct6e-44-1tpu
# # With the correct IAM policies applied to the runner's underlying service accounts,
# # the runner can now pull this private image directly without any in-workflow auth.
# container:
# image: us-central1-docker.pkg.dev/gtech-rmi-dev/keras-docker-images/keras-jax-tpu-amd64:latest
# # Options are still needed for the container to access the host's TPU hardware.
# options: --privileged --network host
# steps:
# - name: Checkout Repository
# uses: actions/checkout@v4
# # This makes your code available inside the container's workspace.
# - name: Run Verification and Tests
# run: |
# echo "Successfully running inside the private container from GAR!"
# echo "Verifying JAX installation..."
# python3 -c "import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices: {jax.devices()}')"
# pip install grain
# pytest keras --ignore keras/src/applications \
# --ignore keras/src/layers/merging/merging_test.py \
# --ignore keras/src/trainers/data_adapters/py_dataset_adapter_test.py \
# --ignore keras/src/backend/jax/distribution_lib_test.py \
# --ignore keras/src/distribution/distribution_lib_test.py \
# --cov=keras \
# --cov-config=pyproject.toml
name: Keras Tests on TPU Runner using JAX Backend

# Triggers: pushes to master, pull requests, and newly created releases.
on:
  push:
    branches:
      - master
  pull_request:
  release:
    types:
      - created

# Read access to repository contents is the only permission requested.
permissions:
  contents: read
jobs:
  # Single job: runs inside a python:3.10-slim container on the TPU runner,
  # installs the JAX/TPU requirements, and verifies that JAX can see a device.
  test-in-container:
    name: Run Keras tests on TPU runner using JAX Backend
    runs-on: linux-x86-ct6e-44-1tpu
    container:
      # Public Python base image; dependencies are installed by the steps below.
      image: python:3.10-slim
      # Options are still needed for the container to access the host's TPU hardware.
      options: --privileged --network host
    steps:
      # NOTE(review): git is only installed in the next step, so it is not
      # available in the container when checkout runs. Confirm that checkout
      # behaves as intended without it, or move the system-dependency step first.
      - name: Checkout Repository
        uses: actions/checkout@v4

      # Installs git and sudo, then clears the apt package lists.
      - name: Install System Dependencies
        run: |
          apt-get update && apt-get install -y --no-install-recommends \
            git \
            sudo \
            && rm -rf /var/lib/apt/lists/*

      # Upgrades pip/setuptools and psutil, installs the JAX TPU requirements,
      # then removes any keras / keras-nightly package from the environment.
      - name: Install Dependencies
        run: |
          pip install --no-cache-dir -U pip setuptools && \
          pip install --no-cache-dir -U psutil && \
          pip install --no-cache-dir -r requirements-jax-tpu.txt && \
          pip uninstall -y keras keras-nightly

      # Appends KERAS_BACKEND=jax to the file named by GITHUB_ENV. The variable
      # is quoted so the redirect target cannot be word-split.
      - name: Set Keras Backend
        run: echo "KERAS_BACKEND=jax" >> "$GITHUB_ENV"

      # Prints the JAX default backend and the kind of the first JAX device.
      # The pytest invocation at the end of the script is commented out, so
      # no Keras tests are executed by this step at present.
      - name: Run Verification and Tests
        run: |
          echo "Successfully running inside the public python container!"
          echo "Verifying JAX installation..."
          python3 -c "import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices: {jax.devices()[0].device_kind}')"
          # pytest keras --ignore keras/src/applications \
          # --ignore keras/src/layers/merging/merging_test.py \
          # --ignore keras/src/trainers/data_adapters/py_dataset_adapter_test.py \
          # --ignore keras/src/backend/jax/distribution_lib_test.py \
          # --ignore keras/src/distribution/distribution_lib_test.py \
          # --cov=keras \
          # --cov-config=pyproject.toml