Skip to content

Commit 6852b96

Browse files
Added a separate MNIST dataset download script that supports fetching the dataset from a specified storage bucket; it runs as a prerequisite step for distributed KFTO training.
1 parent 9fe0af4 commit 6852b96

File tree

6 files changed

+139
-14
lines changed

6 files changed

+139
-14
lines changed

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ require (
77
github.com/matoous/go-nanoid/v2 v2.1.0
88
github.com/onsi/gomega v1.32.0
99
github.com/project-codeflare/appwrapper v0.8.0
10-
github.com/project-codeflare/codeflare-common v0.0.0-20241211130338-efe4f3e6f904
10+
github.com/project-codeflare/codeflare-common v0.0.0-20250128135036-f501cd31fe8b
1111
github.com/prometheus/client_golang v1.20.4
1212
github.com/prometheus/common v0.57.0
13-
github.com/ray-project/kuberay/ray-operator v1.1.0-alpha.0
13+
github.com/ray-project/kuberay/ray-operator v1.1.1
1414
k8s.io/api v0.30.8
1515
k8s.io/apimachinery v0.30.8
1616
k8s.io/client-go v0.30.8

go.sum

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
366366
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
367367
github.com/project-codeflare/appwrapper v0.8.0 h1:vWHNtXUtHutN2EzYb6rryLdESnb8iDXsCokXOuNYXvg=
368368
github.com/project-codeflare/appwrapper v0.8.0/go.mod h1:FMQ2lI3fz6LakUVXgN1FTdpsc3BBkNIZZgtMmM9J5UM=
369-
github.com/project-codeflare/codeflare-common v0.0.0-20241211130338-efe4f3e6f904 h1:brU4j1V4o+z/sw0TGi360Wdjk1TEQ313ynBRGqSTaNU=
370-
github.com/project-codeflare/codeflare-common v0.0.0-20241211130338-efe4f3e6f904/go.mod h1:v7XKwaDoCspsHQlWJNarO7gOpR+iumSS+c1bWs3kJOI=
369+
github.com/project-codeflare/codeflare-common v0.0.0-20250128135036-f501cd31fe8b h1:MOmv/aLx/kcHd7PBErx8XNSTW180s8Slf/uVM0uV4rw=
370+
github.com/project-codeflare/codeflare-common v0.0.0-20250128135036-f501cd31fe8b/go.mod h1:DPSv5khRiRDFUD43SF8da+MrVQTWmxNhuKJmwSLOyO0=
371371
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
372372
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
373373
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
@@ -394,8 +394,8 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
394394
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
395395
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
396396
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
397-
github.com/ray-project/kuberay/ray-operator v1.1.0-alpha.0 h1:m3knC3mjkQEmMj61DY73210mKVSWEGtFKn0uQ6RLwao=
398-
github.com/ray-project/kuberay/ray-operator v1.1.0-alpha.0/go.mod h1:ZqyKKvMP5nKDldQoKmur+Wcx7wVlV9Q98phFqHzr+KY=
397+
github.com/ray-project/kuberay/ray-operator v1.1.1 h1:mVOA1ddS9aAsPvhhHrpf0ZXgTzccIAyTbeYeDqtcfAk=
398+
github.com/ray-project/kuberay/ray-operator v1.1.1/go.mod h1:ZqyKKvMP5nKDldQoKmur+Wcx7wVlV9Q98phFqHzr+KY=
399399
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
400400
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
401401
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=

tests/kfto/kfto_mnist_training_test.go

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ func runKFTOPyTorchMnistJob(t *testing.T, accelerator Accelerator, image string,
6161
namespace := test.NewTestNamespace()
6262

6363
mnist := ReadFile(test, "resources/mnist.py")
64+
download_mnist_dataset := ReadFile(test, "resources/download_mnist_datasets.py")
6465
requirementsFileName := ReadFile(test, requirementsFile)
6566

6667
if accelerator.isGpu() {
@@ -70,8 +71,9 @@ func runKFTOPyTorchMnistJob(t *testing.T, accelerator Accelerator, image string,
7071
}
7172
config := CreateConfigMap(test, namespace.Name, map[string][]byte{
7273
// MNIST Ray Notebook
73-
"mnist.py": mnist,
74-
"requirements.txt": requirementsFileName,
74+
"mnist.py": mnist,
75+
"download_mnist_datasets.py": download_mnist_dataset,
76+
"requirements.txt": requirementsFileName,
7577
})
7678

7779
// Create training PyTorch job
@@ -117,6 +119,12 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
117119
backend = "gloo"
118120
}
119121

122+
storage_bucket_endpoint, storage_bucket_endpoint_exists := GetStorageBucketDefaultEndpoint()
123+
storage_bucket_access_key_id, storage_bucket_access_key_id_exists := GetStorageBucketAccessKeyId()
124+
storage_bucket_secret_key, storage_bucket_secret_key_exists := GetStorageBucketSecretKey()
125+
storage_bucket_name, storage_bucket_name_exists := GetStorageBucketName()
126+
storage_bucket_mnist_dir, storage_bucket_mnist_dir_exists := GetStorageBucketMnistDir()
127+
120128
tuningJob := &kftov1.PyTorchJob{
121129
TypeMeta: metav1.TypeMeta{
122130
APIVersion: corev1.SchemeGroupVersion.String(),
@@ -162,8 +170,7 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
162170
fmt.Sprintf(`mkdir -p /tmp/lib /tmp/datasets/mnist && export PYTHONPATH=$PYTHONPATH:/tmp/lib && \
163171
pip install --no-cache-dir -r /mnt/files/requirements.txt --target=/tmp/lib && \
164172
echo "Downloading MNIST dataset..." && \
165-
python3 -c "from torchvision.datasets import MNIST; from torchvision.transforms import Compose, ToTensor; \
166-
MNIST('/tmp/datasets/mnist', train=False, download=True, transform=Compose([ToTensor()]))" && \
173+
python3 /mnt/files/download_mnist_datasets.py --dataset_path "/tmp/datasets/mnist" && \
167174
echo -e "\n\n Dataset downloaded to /tmp/datasets/mnist" && ls -R /tmp/datasets/mnist && \
168175
echo -e "\n\n Starting training..." && \
169176
torchrun --nproc_per_node=%d /mnt/files/mnist.py --dataset_path "/tmp/datasets/mnist" --epochs 7 --save_every 2 --batch_size 128 --lr 0.001 --snapshot_path "mnist_snapshot.pt" --backend %s`, numProcPerNode, backend),
@@ -247,8 +254,7 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
247254
fmt.Sprintf(`mkdir -p /tmp/lib /tmp/datasets/mnist && export PYTHONPATH=$PYTHONPATH:/tmp/lib && \
248255
pip install --no-cache-dir -r /mnt/files/requirements.txt --target=/tmp/lib && \
249256
echo "Downloading MNIST dataset..." && \
250-
python3 -c "from torchvision.datasets import MNIST; from torchvision.transforms import Compose, ToTensor; \
251-
MNIST('/tmp/datasets/mnist', train=False, download=True, transform=Compose([ToTensor()]))" && \
257+
python3 /mnt/files/download_mnist_datasets.py --dataset_path "/tmp/datasets/mnist" && \
252258
echo -e "\n\n Dataset downloaded to /tmp/datasets/mnist" && ls -R /tmp/datasets/mnist && \
253259
echo -e "\n\n Starting training..." && \
254260
torchrun --nproc_per_node=%d /mnt/files/mnist.py --dataset_path "/tmp/datasets/mnist" --epochs 7 --save_every 2 --batch_size 128 --lr 0.001 --snapshot_path "mnist_snapshot.pt" --backend %s`, numProcPerNode, backend),
@@ -344,6 +350,34 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
344350
}
345351
}
346352

353+
if storage_bucket_endpoint_exists && storage_bucket_access_key_id_exists && storage_bucket_secret_key_exists && storage_bucket_name_exists && storage_bucket_mnist_dir_exists {
354+
storage_bucket_env_vars := []corev1.EnvVar{
355+
{
356+
Name: "AWS_DEFAULT_ENDPOINT",
357+
Value: storage_bucket_endpoint,
358+
},
359+
{
360+
Name: "AWS_ACCESS_KEY_ID",
361+
Value: storage_bucket_access_key_id,
362+
},
363+
{
364+
Name: "AWS_SECRET_ACCESS_KEY",
365+
Value: storage_bucket_secret_key,
366+
},
367+
{
368+
Name: "AWS_STORAGE_BUCKET",
369+
Value: storage_bucket_name,
370+
},
371+
{
372+
Name: "AWS_STORAGE_BUCKET_MNIST_DIR",
373+
Value: storage_bucket_mnist_dir,
374+
},
375+
}
376+
377+
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Env = append(tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Env, storage_bucket_env_vars...)
378+
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Env = append(tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Env, storage_bucket_env_vars...)
379+
}
380+
347381
tuningJob, err := test.Client().Kubeflow().KubeflowV1().PyTorchJobs(namespace).Create(test.Ctx(), tuningJob, metav1.CreateOptions{})
348382
test.Expect(err).NotTo(HaveOccurred())
349383
test.T().Logf("Created PytorchJob %s/%s successfully", tuningJob.Namespace, tuningJob.Name)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import os, gzip, shutil
2+
from minio import Minio
3+
from torchvision import datasets
4+
from torchvision.transforms import Compose, ToTensor
5+
6+
def main(dataset_path):
    """Download the MNIST dataset into ``dataset_path``.

    If all five AWS_* environment variables are set, the gzipped raw MNIST
    files are fetched from the configured S3-compatible bucket (via the
    MinIO client) and extracted into ``<dataset_path>/MNIST/raw``.
    Otherwise torchvision downloads the dataset from its default mirrors.

    :param dataset_path: Root directory the dataset is placed under.
    """
    required_vars = [
        "AWS_DEFAULT_ENDPOINT",
        "AWS_ACCESS_KEY_ID",
        "AWS_SECRET_ACCESS_KEY",
        "AWS_STORAGE_BUCKET",
        "AWS_STORAGE_BUCKET_MNIST_DIR",
    ]
    if all(var in os.environ for var in required_vars):
        print("Using provided storage bucket to download datasets...")
        dataset_dir = os.path.join(dataset_path, "MNIST/raw")
        endpoint = os.environ.get("AWS_DEFAULT_ENDPOINT")
        access_key = os.environ.get("AWS_ACCESS_KEY_ID")
        secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
        bucket_name = os.environ.get("AWS_STORAGE_BUCKET")
        print(f"Storage bucket endpoint: {endpoint}")
        print(f"Storage bucket name: {bucket_name}\n")

        # Strip the scheme prefix if present in the endpoint URL;
        # a plain http endpoint implies a non-TLS (insecure) connection.
        secure = True
        if endpoint.startswith("https://"):
            endpoint = endpoint[len("https://"):]
        elif endpoint.startswith("http://"):
            endpoint = endpoint[len("http://"):]
            secure = False

        client = Minio(
            endpoint,
            access_key=access_key,
            secret_key=secret_key,
            cert_check=False,
            secure=secure,
        )

        # Create the target directory (no-op if it already exists).
        os.makedirs(dataset_dir, exist_ok=True)

        # Bucket directory (object-key prefix) that holds the MNIST archives.
        prefix = os.environ.get("AWS_STORAGE_BUCKET_MNIST_DIR")
        print(f"Storage bucket MNIST directory prefix: {prefix}\n")

        # Download every object under the prefix folder recursively.
        for item in client.list_objects(bucket_name, prefix=prefix, recursive=True):
            # Object name relative to the prefix; lstrip("/") tolerates a
            # prefix given either with or without a trailing slash (the old
            # `[len(prefix)+1:]` slice chopped a filename character when the
            # prefix already ended in "/").
            file_name = item.object_name[len(prefix):].lstrip("/")
            dataset_file_path = os.path.join(dataset_dir, file_name)
            print(f"Downloading dataset file {file_name} to {dataset_file_path}..")
            if not os.path.exists(dataset_file_path):
                client.fget_object(bucket_name, item.object_name, dataset_file_path)
                # Extract the gzip archive next to it, e.g.
                # .../MNIST/raw/t10k-images-idx3-ubyte.gz
                #   -> .../MNIST/raw/t10k-images-idx3-ubyte
                full_file_path = os.path.splitext(dataset_file_path)[0]
                print(f"Extracting {dataset_file_path} to {os.path.dirname(dataset_file_path)}..")
                with gzip.open(dataset_file_path, "rb") as f_in:
                    with open(full_file_path, "wb") as f_out:
                        shutil.copyfileobj(f_in, f_out)
                print(f"Dataset file downloaded : {full_file_path}\n")
                # Remove the archive once its contents are extracted.
                os.remove(dataset_file_path)
            else:
                print(f"File-path '{dataset_file_path}' already exists")
        # Raw files are already in place; tell torchvision not to download.
        download_datasets = False
    else:
        print("Using default MNIST mirror references to download datasets ...")
        download_datasets = True

    # Instantiate the dataset so any missing/corrupt files fail here,
    # as a pre-requisite step, rather than later during training.
    datasets.MNIST(
        dataset_path,
        train=False,
        download=download_datasets,
        transform=Compose([ToTensor()]),
    )
79+
80+
if __name__ == "__main__":
    import argparse

    # Command-line entry point: accept an optional dataset location and
    # delegate the actual download work to main().
    parser = argparse.ArgumentParser(description="MNIST dataset download")
    parser.add_argument(
        '--dataset_path',
        type=str,
        default="../data",
        help='Path to MNIST datasets (default: ../data)',
    )
    cli_args = parser.parse_args()

    main(dataset_path=cli_args.dataset_path)

tests/kfto/resources/requirements-rocm.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
torchvision==0.19.0
33
tensorboard==2.18.0
44
fsspec[http]==2024.6.1
5-
numpy==2.0.2
5+
numpy==2.0.2
6+
minio==7.2.13
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
torchvision==0.19.0
22
tensorboard==2.18.0
33
fsspec[http]==2024.6.1
4-
numpy==2.0.2
4+
numpy==2.0.2
5+
minio==7.2.13

0 commit comments

Comments
 (0)