Skip to content

Commit 25f24d5

Browse files
Update logic for appending list of environment variables to the existing env variable list and add log statements
1 parent 6852b96 commit 25f24d5

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

tests/kfto/kfto_mnist_training_test.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ func runKFTOPyTorchMnistJob(t *testing.T, accelerator Accelerator, image string,
7070
mnist = bytes.Replace(mnist, []byte("accelerator=\"has to be specified\""), []byte("accelerator=\"cpu\""), 1)
7171
}
7272
config := CreateConfigMap(test, namespace.Name, map[string][]byte{
73-
// MNIST Ray Notebook
7473
"mnist.py": mnist,
7574
"download_mnist_datasets.py": download_mnist_dataset,
7675
"requirements.txt": requirementsFileName,
@@ -350,6 +349,7 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
350349
}
351350
}
352351

352+
// Use storage bucket to download the MNIST datasets if required environment variables are provided, else use default MNIST mirror references as the fallback
353353
if storage_bucket_endpoint_exists && storage_bucket_access_key_id_exists && storage_bucket_secret_key_exists && storage_bucket_name_exists && storage_bucket_mnist_dir_exists {
354354
storage_bucket_env_vars := []corev1.EnvVar{
355355
{
@@ -374,8 +374,13 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
374374
},
375375
}
376376

377-
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Env = append(tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Env, storage_bucket_env_vars...)
378-
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Env = append(tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Env, storage_bucket_env_vars...)
377+
// Append the list of environment variables for the worker container
378+
for _, envVar := range storage_bucket_env_vars {
379+
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Env = upsert(tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Env, envVar, withEnvVarName(envVar.Name))
380+
}
381+
382+
} else {
383+
test.T().Logf("Skipped usage of S3 storage bucket, because required environment variables aren't provided!\nRequired environment variables : AWS_DEFAULT_ENDPOINT, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_STORAGE_BUCKET, AWS_STORAGE_BUCKET_MNIST_DIR")
379384
}
380385

381386
tuningJob, err := test.Client().Kubeflow().KubeflowV1().PyTorchJobs(namespace).Create(test.Ctx(), tuningJob, metav1.CreateOptions{})

tests/kfto/resources/download_mnist_datasets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def main(dataset_path):
6868
download_datasets = False
6969
else:
7070
print("Using default MNIST mirror references to download datasets ...")
71+
print("Skipped usage of S3 storage bucket, because required environment variables aren't provided!\nRequired environment variables : AWS_DEFAULT_ENDPOINT, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_STORAGE_BUCKET, AWS_STORAGE_BUCKET_MNIST_DIR")
7172
download_datasets = True
7273

7374
datasets.MNIST(

tests/kfto/support.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,21 @@ func OpenShiftPrometheusGpuUtil(test Test, pod corev1.Pod, gpu Accelerator) func
7171
return util
7272
}
7373
}
74+
75+
type compare[T any] func(T, T) bool
76+
77+
func upsert[T any](items []T, item T, predicate compare[T]) []T {
78+
for i, t := range items {
79+
if predicate(t, item) {
80+
items[i] = item
81+
return items
82+
}
83+
}
84+
return append(items, item)
85+
}
86+
87+
func withEnvVarName(name string) compare[corev1.EnvVar] {
88+
return func(e1, e2 corev1.EnvVar) bool {
89+
return e1.Name == name
90+
}
91+
}

0 commit comments

Comments
 (0)