Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Use a base image with Python and Git
FROM python:3.10-slim
# Use Python 3.12 as the base image
FROM python:3.12-slim-bullseye

# Install Git
RUN apt-get update && apt-get install -y git
Expand All @@ -12,29 +12,22 @@ RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyri
# Install the Google Cloud SDK
RUN apt-get update && apt-get install -y google-cloud-sdk

# Set the default Python version to 3.10
RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1
# Set the default Python version to 3.12
RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.12 1

# Set environment variables for Google Cloud SDK and Python 3.10
ENV PATH="/usr/local/google-cloud-sdk/bin:/usr/local/bin/python3.10:${PATH}"
# Set environment variables for Google Cloud SDK and Python 3.12
ENV PATH="/usr/local/google-cloud-sdk/bin:/usr/local/bin/python3.12:${PATH}"

# Set the working directory
WORKDIR /app

# Clone the repository
RUN git clone https://github.com/AI-Hypercomputer/accelerator-microbenchmarks.git

# Navigate to the repository directory
WORKDIR /app/accelerator-microbenchmarks

# Clone the local repository
COPY . .

# Install dependencies
RUN pip install --upgrade pip && \
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
pip install --upgrade clu tensorflow tensorflow-datasets && \
pip install jsonlines && \
pip install ray[default]
pip install -r requirements.txt

# Set environment variables
ENV JAX_PLATFORMS=tpu,cpu \
ENABLE_PJRT_COMPATIBILITY=true