@@ -61,6 +61,7 @@ func runKFTOPyTorchMnistJob(t *testing.T, accelerator Accelerator, image string,
     namespace := test.NewTestNamespace()
 
     mnist := ReadFile(test, "resources/mnist.py")
+    download_mnist_dataset := ReadFile(test, "resources/download_mnist_datasets.py")
     requirementsFileName := ReadFile(test, requirementsFile)
 
     if accelerator.isGpu() {
@@ -70,8 +71,9 @@ func runKFTOPyTorchMnistJob(t *testing.T, accelerator Accelerator, image string,
     }
     config := CreateConfigMap(test, namespace.Name, map[string][]byte{
         // MNIST Ray Notebook
-        "mnist.py":         mnist,
-        "requirements.txt": requirementsFileName,
+        "mnist.py":                   mnist,
+        "download_mnist_datasets.py": download_mnist_dataset,
+        "requirements.txt":           requirementsFileName,
     })
 
     // Create training PyTorch job
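The training containers later reference these files as /mnt/files/mnist.py, /mnt/files/download_mnist_datasets.py and /mnt/files/requirements.txt, which suggests the ConfigMap created above is mounted into the job pods at /mnt/files. That wiring is not part of this diff; a minimal sketch of how it could look follows, with the function and volume names being illustrative assumptions:

// Sketch only — not taken from this diff. Assumes corev1 "k8s.io/api/core/v1".
package support // assumed package name

import corev1 "k8s.io/api/core/v1"

// configMapVolume exposes the ConfigMap so that each key ("mnist.py",
// "download_mnist_datasets.py", "requirements.txt") appears as /mnt/files/<key>.
func configMapVolume(configMapName string) (corev1.Volume, corev1.VolumeMount) {
    volume := corev1.Volume{
        Name: "config-volume", // volume name is illustrative
        VolumeSource: corev1.VolumeSource{
            ConfigMap: &corev1.ConfigMapVolumeSource{
                LocalObjectReference: corev1.LocalObjectReference{Name: configMapName},
            },
        },
    }
    mount := corev1.VolumeMount{
        Name:      "config-volume",
        MountPath: "/mnt/files",
        ReadOnly:  true,
    }
    return volume, mount
}

Keeping the scripts in a ConfigMap avoids baking test assets into the training image; the containers only need the mount path.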
@@ -117,6 +119,12 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
         backend = "gloo"
     }
 
+    storage_bucket_endpoint, storage_bucket_endpoint_exists := GetStorageBucketDefaultEndpoint()
+    storage_bucket_access_key_id, storage_bucket_access_key_id_exists := GetStorageBucketAccessKeyId()
+    storage_bucket_secret_key, storage_bucket_secret_key_exists := GetStorageBucketSecretKey()
+    storage_bucket_name, storage_bucket_name_exists := GetStorageBucketName()
+    storage_bucket_mnist_dir, storage_bucket_mnist_dir_exists := GetStorageBucketMnistDir()
+
     tuningJob := &kftov1.PyTorchJob{
         TypeMeta: metav1.TypeMeta{
             APIVersion: corev1.SchemeGroupVersion.String(),
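The GetStorageBucket* helpers are not shown in this diff; they presumably come from the shared test support code and return a configuration value plus a flag indicating whether it was set. A minimal sketch under that assumption, where each helper wraps os.LookupEnv and the environment variable names (mirroring the AWS_* variables injected further below) are not confirmed by this diff:

// Sketch only — the real helpers live in the shared test support package.
package support // assumed package name

import "os"

// Assumption: each helper is a thin wrapper over os.LookupEnv.
func GetStorageBucketDefaultEndpoint() (string, bool) { return os.LookupEnv("AWS_DEFAULT_ENDPOINT") }
func GetStorageBucketAccessKeyId() (string, bool)     { return os.LookupEnv("AWS_ACCESS_KEY_ID") }
func GetStorageBucketSecretKey() (string, bool)       { return os.LookupEnv("AWS_SECRET_ACCESS_KEY") }
func GetStorageBucketName() (string, bool)            { return os.LookupEnv("AWS_STORAGE_BUCKET") }
func GetStorageBucketMnistDir() (string, bool)        { return os.LookupEnv("AWS_STORAGE_BUCKET_MNIST_DIR") }

Returning the bool alongside the value lets the job-creation code below skip the bucket-backed configuration cleanly when the environment is not set up for it.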
@@ -162,8 +170,7 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
 fmt.Sprintf(`mkdir -p /tmp/lib /tmp/datasets/mnist && export PYTHONPATH=$PYTHONPATH:/tmp/lib && \
 pip install --no-cache-dir -r /mnt/files/requirements.txt --target=/tmp/lib && \
 echo "Downloading MNIST dataset..." && \
-python3 -c "from torchvision.datasets import MNIST; from torchvision.transforms import Compose, ToTensor; \
-MNIST('/tmp/datasets/mnist', train=False, download=True, transform=Compose([ToTensor()]))" && \
+python3 /mnt/files/download_mnist_datasets.py --dataset_path "/tmp/datasets/mnist" && \
 echo -e "\n\n Dataset downloaded to /tmp/datasets/mnist" && ls -R /tmp/datasets/mnist && \
 echo -e "\n\n Starting training..." && \
 torchrun --nproc_per_node=%d /mnt/files/mnist.py --dataset_path "/tmp/datasets/mnist" --epochs 7 --save_every 2 --batch_size 128 --lr 0.001 --snapshot_path "mnist_snapshot.pt" --backend %s`, numProcPerNode, backend),
@@ -247,8 +254,7 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
 fmt.Sprintf(`mkdir -p /tmp/lib /tmp/datasets/mnist && export PYTHONPATH=$PYTHONPATH:/tmp/lib && \
 pip install --no-cache-dir -r /mnt/files/requirements.txt --target=/tmp/lib && \
 echo "Downloading MNIST dataset..." && \
-python3 -c "from torchvision.datasets import MNIST; from torchvision.transforms import Compose, ToTensor; \
-MNIST('/tmp/datasets/mnist', train=False, download=True, transform=Compose([ToTensor()]))" && \
+python3 /mnt/files/download_mnist_datasets.py --dataset_path "/tmp/datasets/mnist" && \
 echo -e "\n\n Dataset downloaded to /tmp/datasets/mnist" && ls -R /tmp/datasets/mnist && \
 echo -e "\n\n Starting training..." && \
 torchrun --nproc_per_node=%d /mnt/files/mnist.py --dataset_path "/tmp/datasets/mnist" --epochs 7 --save_every 2 --batch_size 128 --lr 0.001 --snapshot_path "mnist_snapshot.pt" --backend %s`, numProcPerNode, backend),
@@ -344,6 +350,34 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
         }
     }
 
+    if storage_bucket_endpoint_exists && storage_bucket_access_key_id_exists && storage_bucket_secret_key_exists && storage_bucket_name_exists && storage_bucket_mnist_dir_exists {
+        storage_bucket_env_vars := []corev1.EnvVar{
+            {
+                Name:  "AWS_DEFAULT_ENDPOINT",
+                Value: storage_bucket_endpoint,
+            },
+            {
+                Name:  "AWS_ACCESS_KEY_ID",
+                Value: storage_bucket_access_key_id,
+            },
+            {
+                Name:  "AWS_SECRET_ACCESS_KEY",
+                Value: storage_bucket_secret_key,
+            },
+            {
+                Name:  "AWS_STORAGE_BUCKET",
+                Value: storage_bucket_name,
+            },
+            {
+                Name:  "AWS_STORAGE_BUCKET_MNIST_DIR",
+                Value: storage_bucket_mnist_dir,
+            },
+        }
+
+        tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Env = append(tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Env, storage_bucket_env_vars...)
+        tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Env = append(tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Env, storage_bucket_env_vars...)
+    }
+
     tuningJob, err := test.Client().Kubeflow().KubeflowV1().PyTorchJobs(namespace).Create(test.Ctx(), tuningJob, metav1.CreateOptions{})
     test.Expect(err).NotTo(HaveOccurred())
     test.T().Logf("Created PytorchJob %s/%s successfully", tuningJob.Namespace, tuningJob.Name)