Skip to content

Commit 7c575c1

Browse files
markomitosevgri243
authored andcommitted
Add linux/aarch64 build configuration
* Added matrix build with both aarch64 and x86_64 to github/actions * Added dynamic platform wheel name generator script to build_python_package.sh
1 parent cca8ab5 commit 7c575c1

File tree

7 files changed

+129
-28
lines changed

7 files changed

+129
-28
lines changed

.bazelrc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,17 @@ build --define=allow_oversize_protos=true
3333
# explicit __init__.py files.
3434
# build --incompatible_default_to_explicit_init_py
3535

36+
# TODO: find a way to automatically actvate os_arch specific flags
37+
# Configuration for Linux on AMD64 (x86_64)
3638
# Haswell processor and later optimizations. This covers most processors deployed
3739
# today, includin Colab CPU runtimes.
38-
build --copt=-march=haswell
39-
build --host_copt=-march=haswell
40+
build:linux_x86_64 --copt=-march=haswell
41+
build:linux_x86_64 --host_copt=-march=haswell
42+
43+
# Configuration for Linux on ARM64 (aarch64)
44+
build:linux_aarch64 --copt=-march=armv8-a
45+
build:linux_aarch64 --host_copt=-march=armv8-a
46+
4047
# Only use level three optimizations for target, not necessarily for host
4148
# since host artifacts don't need to be fast.
4249
build --copt=-O3

.devcontainer/devcontainer.json

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
{
2+
"name": "TensorFlow Federated (DevContainer)",
3+
image: "mcr.microsoft.com/devcontainers/base:ubuntu-22.04",
4+
"features": {
5+
"ghcr.io/devcontainers-community/features/bazel": {
6+
"bazelisk_version": "latest"
7+
},
8+
"ghcr.io/devcontainers/features/python" : {
9+
"version": "3.10",
10+
"toolsToInstall": "black,pytest,pylint,isort,twine"
11+
},
12+
"ghcr.io/devcontainers-community/features/llvm": {
13+
"version": "17"
14+
},
15+
"ghcr.io/devcontainers/features/java": {
16+
"jdkDistro": "open"
17+
}
18+
},
19+
"mounts": [
20+
{ "target": "/root/.cache/bazel", "source": "jetbrains-federated-devcontainer-cache-bazel", "type": "volume" },
21+
{ "target": "/root/.cache/pip/", "source": "jetbrains-federated-devcontainer-cache-pip", "type": "volume" },
22+
],
23+
"customizations": {
24+
"jetbrains": {
25+
"backend": "PyCharm",
26+
"plugins": [
27+
"com.google.idea.bazel.ijwb"
28+
]
29+
}
30+
},
31+
"containerUser": "root",
32+
"postCreateCommand": ".devcontainer/post-create.sh",
33+
"waitFor": "postStartCommand",
34+
}

.devcontainer/post-create.sh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#!/usr/bin/env bash
2+
set -e
3+
exec 2>&1
4+
5+
PROJECT_DIR=$(pwd)
6+
7+
echo "Installing required python packages..."
8+
9+
if nvidia-smi &> /dev/null; then
10+
GPU_EXTRA="tensorflow[and-cuda]"
11+
fi
12+
13+
# shellcheck disable=SC2086
14+
python3 -m pip install --root-user-action=ignore -r $PROJECT_DIR/requirements.txt ${GPU_EXTRA}

.github/workflows/publish.yaml

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ jobs:
3737
# Only if:
3838
# * Repository is not a fork.
3939
# * Branch is `main` (for workflow_dispatch trigger).
40-
if: |
41-
github.repository == 'google-parfait/tensorflow-federated'
42-
&& github.ref == 'refs/heads/main'
43-
runs-on: ubuntu-latest
40+
# if: |
41+
# github.repository == 'google-parfait/tensorflow-federated'
42+
# && github.ref == 'refs/heads/main'
43+
runs-on: ubuntu-22.04
4444
timeout-minutes: 5
4545
permissions:
4646
contents: write # Required to create a release.
@@ -97,8 +97,16 @@ jobs:
9797
build-package:
9898
name: Build Package
9999
needs: [publish-release]
100-
runs-on: ubuntu-20.04
101-
timeout-minutes: 60
100+
strategy:
101+
matrix:
102+
include:
103+
- os: ubuntu-22.04
104+
arch: x86_64
105+
- os: ubuntu-22.04-arm
106+
arch: aarch64
107+
runs-on: ${{ matrix.os }}
108+
timeout-minutes: 360
109+
environment: release
102110
steps:
103111

104112
- name: Checkout repository
@@ -114,11 +122,11 @@ jobs:
114122
create_credentials_file: true
115123

116124
- name: Set up bazel repository cache
117-
uses: actions/cache@v4.0.2
125+
uses: actions/cache@v4.2.0
118126
with:
119127
path: "~/.cache/bazel/"
120-
key: ${{ runner.os }}-bazel-${{ hashFiles('.bazelversion', '.bazelrc', 'WORKSPACE') }}
121-
restore-keys: ${{ runner.os }}-bazel-
128+
key: ${{ runner.os }}-${{ matrix.arch }}-bazel-${{ hashFiles('.bazelversion', '.bazelrc', 'WORKSPACE') }}
129+
restore-keys: ${{ runner.os }}-${{ matrix.arch }}-bazel
122130

123131
- name: Set up Python
124132
uses: actions/[email protected]
@@ -130,26 +138,39 @@ jobs:
130138
run: |
131139
pip install --upgrade "pip"
132140
141+
- name: Install LLVM and Clang
142+
uses: KyleMayes/install-llvm-action@v2
143+
with:
144+
version: "14.0"
145+
133146
- name: Build Python package
134147
run: |
135148
pip install --upgrade "numpy~=1.25"
136149
bazelisk run //tools/python_package:build_python_package \
137150
--build_tag_filters="-nokokoro,-nopresubmit,-requires-gpu-nvidia" \
138151
--google_credentials="${{ steps.auth.outputs.credentials_file_path }}" \
139-
--remote_cache="https://storage.googleapis.com/tensorflow-federated-bazel-cache/${{ github.job }}" \
152+
--remote_cache="https://storage.googleapis.com/${{ vars.BAZEL_CACHE_BUCKET }}/${{ github.job }}" \
153+
--config="linux_${{ matrix.arch }}" \
140154
-- \
141155
--output_dir="${{ github.workspace }}/dist/"
142156
143157
- name: Upload Python package
144158
uses: actions/[email protected]
145159
with:
146-
name: python-package-distributions
160+
name: python-package-distributions-${{ matrix.arch }}
147161
path: dist/*.whl
148162

149163
test-package:
150164
name: Test Package
151165
needs: [build-package]
152-
runs-on: ubuntu-20.04
166+
strategy:
167+
matrix:
168+
include:
169+
- os: ubuntu-22.04
170+
arch: x86_64
171+
- os: ubuntu-22.04-arm
172+
arch: aarch64
173+
runs-on: ${{ matrix.os }}
153174
timeout-minutes: 5
154175
steps:
155176

@@ -161,7 +182,7 @@ jobs:
161182
- name: Download Python package
162183
uses: actions/[email protected]
163184
with:
164-
name: python-package-distributions
185+
name: python-package-distributions-${{ matrix.arch }}
165186
path: dist/
166187

167188
- name: Set up Python
@@ -190,17 +211,24 @@ jobs:
190211
publish-package:
191212
name: Publish Package
192213
needs: [build-package, test-package]
193-
runs-on: ubuntu-latest
214+
runs-on: ubuntu-22.04
194215
timeout-minutes: 5
195216
permissions:
196217
id-token: write # Required for trusted publishing.
218+
environment: release
197219
steps:
198220

199221
- name: Download Python package
200222
uses: actions/[email protected]
201223
with:
202-
name: python-package-distributions
224+
pattern: python-package-distributions-*
225+
merge-multiple: true
203226
path: dist/
204227

205228
- name: Publish Python package
206229
uses: pypa/[email protected]
230+
with:
231+
password: ${{ secrets.PYTHON_REPOSITORY_TOKEN }}
232+
repository-url: ${{ vars.PYTHON_REPOSITORY_URL }}
233+
verbose: 'true'
234+
user: ${{ secrets.PYTHON_REPOSITORY_USER }}

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,8 @@ dependencies = [
3939
'dp-accounting==0.4.3',
4040
'google-vizier==0.1.11',
4141
'grpcio~=1.46',
42-
'jaxlib==0.4.14',
43-
'jax==0.4.14',
44-
'ml_dtypes>=0.2.0,==0.2.*',
42+
'jaxlib==0.4.18',
43+
'jax==0.4.18',
4544
'numpy~=1.25',
4645
'portpicker~=1.6',
4746
'scipy~=1.9.3',

requirements.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,8 @@ dm-tree==0.1.8
3232
dp-accounting==0.4.3
3333
google-vizier==0.1.11
3434
grpcio~=1.46
35-
jaxlib==0.4.14
36-
jax==0.4.14
37-
ml_dtypes>=0.2.0,==0.2.*
35+
jaxlib==0.4.18
36+
jax==0.4.18
3837
numpy~=1.25
3938
portpicker~=1.6
4039
scipy~=1.9.3

tools/python_package/build_python_package.sh

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,30 @@ main() {
4949
fi
5050

5151
# Check the GLIBC version.
52-
local expected_glibc="2.31"
53-
if ! ldd --version | grep --quiet "${expected_glibc}"; then
54-
echo "error: expected GLIBC version to be '${expected_glibc}', found:" 1>&2
55-
ldd --version 1>&2
52+
glibc_version=$(ldd --version 2>&1 | grep "GLIBC" | awk '{print $NF}')
53+
54+
# Error handling if GLIBC version couldn't be determined.
55+
if [[ -z "$glibc_version" ]]; then
56+
echo "error: Could not determine GLIBC version." 1>&2
5657
exit 1
5758
fi
5859

60+
echo "Detected GLIBC version: $glibc_version"
61+
62+
# Extract major and minor version numbers for manylinux tag.
63+
IFS='.' read -r glibc_major glibc_minor <<< "$glibc_version"
64+
manylinux_version="${glibc_major}_${glibc_minor}"
65+
66+
# Detect architecture.
67+
arch=$(uname -m)
68+
case "$arch" in
69+
aarch64|x86_64) ;; # Supported architectures
70+
*) echo "error: Unsupported architecture: $arch" >&2; exit 1 ;;
71+
esac
72+
73+
plat_name="manylinux_${manylinux_version}_${arch}"
74+
75+
5976
# Create a temp directory.
6077
local temp_dir="$(mktemp --directory)"
6178
trap "rm -rf ${temp_dir}" EXIT
@@ -68,7 +85,10 @@ main() {
6885
pip --version
6986

7087
# Build the Python package.
71-
pip install --upgrade "build"
88+
pip install --upgrade "build" "toml-cli"
89+
90+
# Update wheel platform
91+
toml set --toml-path "pyproject.toml" "tool.distutils.bdist_wheel.plat-name" "$plat_name"
7292
pip freeze
7393
python -m build --outdir "${output_dir}"
7494
}

0 commit comments

Comments
 (0)