@@ -70,7 +70,6 @@ func runKFTOPyTorchMnistJob(t *testing.T, accelerator Accelerator, image string,
7070 mnist = bytes .Replace (mnist , []byte ("accelerator=\" has to be specified\" " ), []byte ("accelerator=\" cpu\" " ), 1 )
7171 }
7272 config := CreateConfigMap (test , namespace .Name , map [string ][]byte {
73- // MNIST Ray Notebook
7473 "mnist.py" : mnist ,
7574 "download_mnist_datasets.py" : download_mnist_dataset ,
7675 "requirements.txt" : requirementsFileName ,
@@ -350,6 +349,7 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
350349 }
351350 }
352351
352+ // Use storage bucket to download the MNIST datasets if required environment variables are provided, else use default MNIST mirror references as the fallback
353353 if storage_bucket_endpoint_exists && storage_bucket_access_key_id_exists && storage_bucket_secret_key_exists && storage_bucket_name_exists && storage_bucket_mnist_dir_exists {
354354 storage_bucket_env_vars := []corev1.EnvVar {
355355 {
@@ -374,8 +374,13 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
374374 },
375375 }
376376
377- tuningJob .Spec .PyTorchReplicaSpecs [kftov1 .PyTorchJobReplicaTypeMaster ].Template .Spec .Containers [0 ].Env = append (tuningJob .Spec .PyTorchReplicaSpecs [kftov1 .PyTorchJobReplicaTypeMaster ].Template .Spec .Containers [0 ].Env , storage_bucket_env_vars ... )
378- tuningJob .Spec .PyTorchReplicaSpecs [kftov1 .PyTorchJobReplicaTypeWorker ].Template .Spec .Containers [0 ].Env = append (tuningJob .Spec .PyTorchReplicaSpecs [kftov1 .PyTorchJobReplicaTypeWorker ].Template .Spec .Containers [0 ].Env , storage_bucket_env_vars ... )
377+ // Append the list of environment variables for the worker container
378+ for _ , envVar := range storage_bucket_env_vars {
379+ tuningJob .Spec .PyTorchReplicaSpecs [kftov1 .PyTorchJobReplicaTypeMaster ].Template .Spec .Containers [0 ].Env = upsert (tuningJob .Spec .PyTorchReplicaSpecs [kftov1 .PyTorchJobReplicaTypeMaster ].Template .Spec .Containers [0 ].Env , envVar , withEnvVarName (envVar .Name ))
380+ }
381+
382+ } else {
383+ test .T ().Logf ("Skipped usage of S3 storage bucket, because required environment variables aren't provided!\n Required environment variables : AWS_DEFAULT_ENDPOINT, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_STORAGE_BUCKET, AWS_STORAGE_BUCKET_MNIST_DIR" )
379384 }
380385
381386 tuningJob , err := test .Client ().Kubeflow ().KubeflowV1 ().PyTorchJobs (namespace ).Create (test .Ctx (), tuningJob , metav1.CreateOptions {})
0 commit comments