Skip to content

Commit 6ecda2d

Browse files
sutaakar authored and openshift-merge-bot[bot] committed
Upload models trained using Training operator into s3 bucket
1 parent e1aa8cc commit 6ecda2d

File tree

11 files changed

+203
-51
lines changed

11 files changed

+203
-51
lines changed

README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
* `CODEFLARE_TEST_TIMEOUT_MEDIUM` - Timeout duration for medium tasks
2222
* `CODEFLARE_TEST_TIMEOUT_LONG` - Timeout duration for long tasks
2323
* `CODEFLARE_TEST_RAY_IMAGE` (Optional) - Ray image used for raycluster configuration
24+
* `MINIO_CLI_IMAGE` (Optional) - Minio CLI image used for uploading/downloading data from/into s3 bucket
2425

2526
NOTE: `quay.io/modh/ray:2.35.0-py311-cu121` is the default image used for creating a RayCluster resource. If you have your own custom ray image which suits your purposes, specify it in `CODEFLARE_TEST_RAY_IMAGE` environment variable.
2627

@@ -30,10 +31,17 @@
3031

3132
### Environment variables for Training operator GPU test suite
3233

33-
* `MULTIGPU_NAMESPACE` - Existing namespace where will the Training operator GPU tests be executed
34+
* `TEST_NAMESPACE_NAME` (Optional) - Existing namespace where the Training operator GPU tests will be executed
3435
* `HF_TOKEN` - HuggingFace token used to pull models which has limited access
3536
* `GPTQ_MODEL_PVC_NAME` - Name of PersistenceVolumeClaim containing downloaded GPTQ models
3637

38+
To upload the trained model into S3 compatible storage, use the environment variables mentioned below:
39+
* `AWS_DEFAULT_ENDPOINT` - Storage bucket endpoint to upload the trained model to; if set, the test will upload the model into the s3 bucket
40+
* `AWS_ACCESS_KEY_ID` - Storage bucket access key
41+
* `AWS_SECRET_ACCESS_KEY` - Storage bucket secret key
42+
* `AWS_STORAGE_BUCKET` - Storage bucket name
43+
* `AWS_STORAGE_BUCKET_MODEL_PATH` (Optional) - Path in the storage bucket where the trained model will be stored
44+
3745
### Environment variables for ODH integration test suite
3846

3947
* `ODH_NAMESPACE` - Namespace where ODH components are installed to

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ toolchain go1.21.5
77
require (
88
github.com/kubeflow/training-operator v1.7.0
99
github.com/onsi/gomega v1.31.1
10-
github.com/project-codeflare/codeflare-common v0.0.0-20241009115304-28574c7cd6ad
10+
github.com/project-codeflare/codeflare-common v0.0.0-20241015133940-3e0d9b3a23ad
1111
github.com/prometheus/client_golang v1.20.4
1212
github.com/prometheus/common v0.57.0
1313
github.com/ray-project/kuberay/ray-operator v1.1.0-alpha.0

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
365365
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
366366
github.com/project-codeflare/appwrapper v0.8.0 h1:vWHNtXUtHutN2EzYb6rryLdESnb8iDXsCokXOuNYXvg=
367367
github.com/project-codeflare/appwrapper v0.8.0/go.mod h1:FMQ2lI3fz6LakUVXgN1FTdpsc3BBkNIZZgtMmM9J5UM=
368-
github.com/project-codeflare/codeflare-common v0.0.0-20241009115304-28574c7cd6ad h1:+xWtSKy90q5l6CjABRqTfKGYuNTgD2UIZpKmNV5styY=
369-
github.com/project-codeflare/codeflare-common v0.0.0-20241009115304-28574c7cd6ad/go.mod h1:v7XKwaDoCspsHQlWJNarO7gOpR+iumSS+c1bWs3kJOI=
368+
github.com/project-codeflare/codeflare-common v0.0.0-20241015133940-3e0d9b3a23ad h1:rZEFsEa4VQXw/U2AHt2HJ4+e55CXYElPIhOSr3+3e9o=
369+
github.com/project-codeflare/codeflare-common v0.0.0-20241015133940-3e0d9b3a23ad/go.mod h1:v7XKwaDoCspsHQlWJNarO7gOpR+iumSS+c1bWs3kJOI=
370370
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
371371
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
372372
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=

images/util/mc-cli/Dockerfile

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
FROM registry.access.redhat.com/ubi9:latest
2+
WORKDIR /app
3+
USER 1001
4+
5+
RUN curl https://dl.min.io/client/mc/release/linux-amd64/mc -o /app/mc
6+
RUN chmod +x /app/mc
7+
ENV PATH="$PATH:/app"
8+
ENV MC_CONFIG_DIR="/app"
9+
RUN chown -R $USER:0 /app && \
10+
chmod -R g+rwX /app

tests/kfto/core/config.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
{
22
"model_name_or_path": "/tmp/model/bloom-560m",
33
"training_data_path": "/etc/config/twitter_complaints_small.json",
4-
"output_dir": "/tmp/out",
5-
"save_model_dir": "/tmp/out",
4+
"output_dir": "/mnt/output/model",
5+
"save_model_dir": "/mnt/output/model",
66
"num_train_epochs": 1.0,
77
"per_device_train_batch_size": 4,
88
"per_device_eval_batch_size": 4,

tests/kfto/core/config_lora.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
{
22
"model_name_or_path": "/tmp/model/bloom-560m",
33
"training_data_path": "/etc/config/twitter_complaints_small.json",
4-
"output_dir": "/tmp/out",
5-
"save_model_dir": "/tmp/out",
4+
"output_dir": "/mnt/output/model",
5+
"save_model_dir": "/mnt/output/model",
66
"num_train_epochs": 1.0,
77
"per_device_train_batch_size": 4,
88
"per_device_eval_batch_size": 4,

tests/kfto/core/config_qlora.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
{
22
"model_name_or_path": "TechxGenus/Meta-Llama-3-8B-GPTQ",
33
"training_data_path": "/etc/config/twitter_complaints_small.json",
4-
"output_dir": "/tmp/out",
5-
"save_model_dir": "/tmp/out",
4+
"output_dir": "/mnt/output/model",
5+
"save_model_dir": "/mnt/output/model",
66
"num_train_epochs": 1.0,
77
"per_device_train_batch_size": 4,
88
"per_device_eval_batch_size": 4,

tests/kfto/core/environment.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,18 @@ const (
3030
bloomModelImageEnvVar = "BLOOM_MODEL_IMAGE"
3131
// The environment variable referring to image containing Stanford Alpaca dataset
3232
alpacaDatasetImageEnvVar = "ALPACA_DATASET_IMAGE"
33+
// The environment variable referring to image containing minio CLI
34+
minioCliImageEnvVar = "MINIO_CLI_IMAGE"
3335
// The environment variable for HuggingFace token to download models which require authentication
3436
huggingfaceTokenEnvVar = "HF_TOKEN"
35-
// The environment variable specifying existing namespace to be used for multiGPU tests
36-
multiGpuNamespaceEnvVar = "MULTIGPU_NAMESPACE"
37+
// The environment variable specifying existing namespace name to be used for tests
38+
testNamespaceEnvVar = "TEST_NAMESPACE_NAME"
3739
// The environment variable specifying name of PersistenceVolumeClaim containing GPTQ models
3840
gptqModelPvcNameEnvVar = "GPTQ_MODEL_PVC_NAME"
3941
// The environment variable referring to image simulating sleep condition in container
4042
sleepImageEnvVar = "SLEEP_IMAGE"
43+
// The environment variable specifying s3 bucket folder path used to store model
44+
storageBucketModelPath = "AWS_STORAGE_BUCKET_MODEL_PATH"
4145
)
4246

4347
func GetFmsHfTuningImage(t Test) string {
@@ -56,6 +60,10 @@ func GetAlpacaDatasetImage() string {
5660
return lookupEnvOrDefault(alpacaDatasetImageEnvVar, "quay.io/ksuta/alpaca-dataset@sha256:2e90f631180c7b2c916f9569b914b336b612e8ae86efad82546adc5c9fcbbb8d")
5761
}
5862

63+
func GetMinioCliImage() string {
64+
return lookupEnvOrDefault(minioCliImageEnvVar, "quay.io/ksuta/mc@sha256:e128ce4caee276bcbfe3bd32ebb01c814f6b2eb2fd52d08ef0d4684f68c1e3d6")
65+
}
66+
5967
func GetHuggingFaceToken(t Test) string {
6068
image, ok := os.LookupEnv(huggingfaceTokenEnvVar)
6169
if !ok {
@@ -64,12 +72,8 @@ func GetHuggingFaceToken(t Test) string {
6472
return image
6573
}
6674

67-
func GetMultiGpuNamespace(t Test) string {
68-
image, ok := os.LookupEnv(multiGpuNamespaceEnvVar)
69-
if !ok {
70-
t.T().Fatalf("Expected environment variable %s not found, please use this environment variable to specify namespace to be used for multiGPU tests.", multiGpuNamespaceEnvVar)
71-
}
72-
return image
75+
func GetTestNamespaceName() (namespaceName string, exists bool) {
76+
return os.LookupEnv(testNamespaceEnvVar)
7377
}
7478

7579
func GetGptqModelPvcName() (string, error) {
@@ -84,6 +88,11 @@ func GetSleepImage() string {
8488
return lookupEnvOrDefault(sleepImageEnvVar, "gcr.io/k8s-staging-perf-tests/sleep@sha256:8d91ddf9f145b66475efda1a1b52269be542292891b5de2a7fad944052bab6ea")
8589
}
8690

91+
func GetStorageBucketModelPath() string {
92+
storageBucketModelPath := lookupEnvOrDefault(storageBucketModelPath, "")
93+
return storageBucketModelPath
94+
}
95+
8796
func lookupEnvOrDefault(key, value string) string {
8897
if v, ok := os.LookupEnv(key); ok {
8998
return v

tests/kfto/core/kfto_kueue_sft_GPU_test.go

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func TestMultiGpuPytorchjobMerlinite7b(t *testing.T) {
102102
func runMultiGpuPytorchjob(t *testing.T, modelConfigFile string, numberOfGpus int, options ...Option[*kftov1.PyTorchJob]) {
103103
test := With(t)
104104

105-
namespace := GetMultiGpuNamespace(test)
105+
namespace := GetOrCreateTestNamespace(test)
106106

107107
// Create a ConfigMap with configuration
108108
configData := map[string][]byte{
@@ -111,8 +111,12 @@ func runMultiGpuPytorchjob(t *testing.T, modelConfigFile string, numberOfGpus in
111111
config := CreateConfigMap(test, namespace, configData)
112112
defer test.Client().Core().CoreV1().ConfigMaps(namespace).Delete(test.Ctx(), config.Name, *metav1.NewDeleteOptions(0))
113113

114+
// Create PVC for trained model
115+
outputPvc := CreatePersistentVolumeClaim(test, namespace, "200Gi", corev1.ReadWriteOnce)
116+
defer test.Client().Core().CoreV1().PersistentVolumeClaims(namespace).Delete(test.Ctx(), outputPvc.Name, metav1.DeleteOptions{})
117+
114118
// Create training PyTorch job
115-
tuningJob := createAlpacaPyTorchJob(test, namespace, *config, numberOfGpus, options...)
119+
tuningJob := createAlpacaPyTorchJob(test, namespace, *config, numberOfGpus, outputPvc.Name, options...)
116120
defer test.Client().Kubeflow().KubeflowV1().PyTorchJobs(namespace).Delete(test.Ctx(), tuningJob.Name, *metav1.NewDeleteOptions(0))
117121

118122
// Make sure the PyTorch job is running
@@ -137,9 +141,14 @@ func runMultiGpuPytorchjob(t *testing.T, modelConfigFile string, numberOfGpus in
137141
// Make sure the PyTorch job succeed
138142
test.Eventually(PytorchJob(test, namespace, tuningJob.Name), 60*time.Minute).Should(WithTransform(PytorchJobConditionSucceeded, Equal(corev1.ConditionTrue)))
139143
test.T().Logf("PytorchJob %s/%s ran successfully", tuningJob.Namespace, tuningJob.Name)
144+
145+
_, bucketEndpointSet := GetStorageBucketDefaultEndpoint()
146+
if bucketEndpointSet {
147+
uploadToS3(test, namespace, outputPvc.Name, "model")
148+
}
140149
}
141150

142-
func createAlpacaPyTorchJob(test Test, namespace string, config corev1.ConfigMap, numberOfGpus int, options ...Option[*kftov1.PyTorchJob]) *kftov1.PyTorchJob {
151+
func createAlpacaPyTorchJob(test Test, namespace string, config corev1.ConfigMap, numberOfGpus int, outputPvc string, options ...Option[*kftov1.PyTorchJob]) *kftov1.PyTorchJob {
143152
tuningJob := &kftov1.PyTorchJob{
144153
TypeMeta: metav1.TypeMeta{
145154
APIVersion: corev1.SchemeGroupVersion.String(),
@@ -265,18 +274,8 @@ func createAlpacaPyTorchJob(test Test, namespace string, config corev1.ConfigMap
265274
{
266275
Name: "output-volume",
267276
VolumeSource: corev1.VolumeSource{
268-
Ephemeral: &corev1.EphemeralVolumeSource{
269-
VolumeClaimTemplate: &corev1.PersistentVolumeClaimTemplate{
270-
Spec: corev1.PersistentVolumeClaimSpec{
271-
AccessModes: []corev1.PersistentVolumeAccessMode{corev1.ReadWriteOnce},
272-
Resources: corev1.VolumeResourceRequirements{
273-
Requests: corev1.ResourceList{
274-
corev1.ResourceStorage: resource.MustParse("500Gi"),
275-
},
276-
},
277-
VolumeMode: Ptr(corev1.PersistentVolumeFilesystem),
278-
},
279-
},
277+
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
278+
ClaimName: outputPvc,
280279
},
281280
},
282281
},

tests/kfto/core/kfto_kueue_sft_test.go

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@ func runPytorchjobWithSFTtrainer(t *testing.T, modelConfigFile string, numGpus i
4646
test := With(t)
4747

4848
// Create a namespace
49-
namespace := test.NewTestNamespace()
49+
namespace := GetOrCreateTestNamespace(test)
5050

5151
// Create a ConfigMap with training dataset and configuration
5252
configData := map[string][]byte{
5353
"config.json": ReadFile(test, modelConfigFile),
5454
"twitter_complaints_small.json": ReadFile(test, "twitter_complaints_small.json"),
5555
}
56-
config := CreateConfigMap(test, namespace.Name, configData)
56+
config := CreateConfigMap(test, namespace, configData)
5757

5858
// Create Kueue resources
5959
resourceFlavor := CreateKueueResourceFlavor(test, kueuev1beta1.ResourceFlavorSpec{})
@@ -87,13 +87,19 @@ func runPytorchjobWithSFTtrainer(t *testing.T, modelConfigFile string, numGpus i
8787
}
8888
clusterQueue := CreateKueueClusterQueue(test, cqSpec)
8989
defer test.Client().Kueue().KueueV1beta1().ClusterQueues().Delete(test.Ctx(), clusterQueue.Name, metav1.DeleteOptions{})
90-
localQueue := CreateKueueLocalQueue(test, namespace.Name, clusterQueue.Name, AsDefaultQueue)
90+
localQueue := CreateKueueLocalQueue(test, namespace, clusterQueue.Name, AsDefaultQueue)
91+
defer test.Client().Kueue().KueueV1beta1().LocalQueues(namespace).Delete(test.Ctx(), localQueue.Name, metav1.DeleteOptions{})
92+
93+
// Create PVC for trained model
94+
outputPvc := CreatePersistentVolumeClaim(test, namespace, "10Gi", corev1.ReadWriteOnce)
95+
defer test.Client().Core().CoreV1().PersistentVolumeClaims(namespace).Delete(test.Ctx(), outputPvc.Name, metav1.DeleteOptions{})
9196

9297
// Create training PyTorch job
93-
tuningJob := createPyTorchJob(test, namespace.Name, localQueue.Name, *config, numGpus)
98+
tuningJob := createPyTorchJob(test, namespace, localQueue.Name, *config, numGpus, outputPvc.Name)
99+
defer test.Client().Kubeflow().KubeflowV1().PyTorchJobs(namespace).Delete(test.Ctx(), tuningJob.Name, *metav1.NewDeleteOptions(0))
94100

95101
// Make sure the Kueue Workload is admitted
96-
test.Eventually(KueueWorkloads(test, namespace.Name), TestTimeoutLong).
102+
test.Eventually(KueueWorkloads(test, namespace), TestTimeoutLong).
97103
Should(
98104
And(
99105
HaveLen(1),
@@ -102,26 +108,31 @@ func runPytorchjobWithSFTtrainer(t *testing.T, modelConfigFile string, numGpus i
102108
)
103109

104110
// Make sure the PyTorch job is running
105-
test.Eventually(PytorchJob(test, namespace.Name, tuningJob.Name), TestTimeoutLong).
111+
test.Eventually(PytorchJob(test, namespace, tuningJob.Name), TestTimeoutLong).
106112
Should(WithTransform(PytorchJobConditionRunning, Equal(corev1.ConditionTrue)))
107113

108114
// Make sure the PyTorch job succeed
109-
test.Eventually(PytorchJob(test, namespace.Name, tuningJob.Name), TestTimeoutMedium).Should(WithTransform(PytorchJobConditionSucceeded, Equal(corev1.ConditionTrue)))
115+
test.Eventually(PytorchJob(test, namespace, tuningJob.Name), TestTimeoutMedium).Should(WithTransform(PytorchJobConditionSucceeded, Equal(corev1.ConditionTrue)))
110116
test.T().Logf("PytorchJob %s/%s ran successfully", tuningJob.Namespace, tuningJob.Name)
117+
118+
_, bucketEndpointSet := GetStorageBucketDefaultEndpoint()
119+
if bucketEndpointSet {
120+
uploadToS3(test, namespace, outputPvc.Name, "model")
121+
}
111122
}
112123

113124
func TestPytorchjobUsingKueueQuota(t *testing.T) {
114125
test := With(t)
115126

116127
// Create a namespace
117-
namespace := test.NewTestNamespace()
128+
namespace := GetOrCreateTestNamespace(test)
118129

119130
// Create a ConfigMap with training dataset and configuration
120131
configData := map[string][]byte{
121132
"config.json": ReadFile(test, "config.json"),
122133
"twitter_complaints_small.json": ReadFile(test, "twitter_complaints_small.json"),
123134
}
124-
config := CreateConfigMap(test, namespace.Name, configData)
135+
config := CreateConfigMap(test, namespace, configData)
125136

126137
// Create limited Kueue resources to run just one Pytorchjob at a time
127138
resourceFlavor := CreateKueueResourceFlavor(test, kueuev1beta1.ResourceFlavorSpec{})
@@ -151,36 +162,44 @@ func TestPytorchjobUsingKueueQuota(t *testing.T) {
151162
}
152163
clusterQueue := CreateKueueClusterQueue(test, cqSpec)
153164
defer test.Client().Kueue().KueueV1beta1().ClusterQueues().Delete(test.Ctx(), clusterQueue.Name, metav1.DeleteOptions{})
154-
localQueue := CreateKueueLocalQueue(test, namespace.Name, clusterQueue.Name, AsDefaultQueue)
165+
localQueue := CreateKueueLocalQueue(test, namespace, clusterQueue.Name, AsDefaultQueue)
166+
167+
// Create first PVC for trained model
168+
outputPvc := CreatePersistentVolumeClaim(test, namespace, "10Gi", corev1.ReadWriteOnce)
169+
defer test.Client().Core().CoreV1().PersistentVolumeClaims(namespace).Delete(test.Ctx(), outputPvc.Name, metav1.DeleteOptions{})
155170

156171
// Create first training PyTorch job
157-
tuningJob := createPyTorchJob(test, namespace.Name, localQueue.Name, *config, 0)
172+
tuningJob := createPyTorchJob(test, namespace, localQueue.Name, *config, 0, outputPvc.Name)
158173

159174
// Make sure the PyTorch job is running
160-
test.Eventually(PytorchJob(test, namespace.Name, tuningJob.Name), TestTimeoutLong).
175+
test.Eventually(PytorchJob(test, namespace, tuningJob.Name), TestTimeoutLong).
161176
Should(WithTransform(PytorchJobConditionRunning, Equal(corev1.ConditionTrue)))
162177

178+
// Create second PVC for trained model
179+
secondOutputPvc := CreatePersistentVolumeClaim(test, namespace, "10Gi", corev1.ReadWriteOnce)
180+
defer test.Client().Core().CoreV1().PersistentVolumeClaims(namespace).Delete(test.Ctx(), outputPvc.Name, metav1.DeleteOptions{})
181+
163182
// Create second training PyTorch job
164-
secondTuningJob := createPyTorchJob(test, namespace.Name, localQueue.Name, *config, 0)
183+
secondTuningJob := createPyTorchJob(test, namespace, localQueue.Name, *config, 0, secondOutputPvc.Name)
165184

166185
// Make sure the second PyTorch job is suspended, waiting for first job to finish
167-
test.Eventually(PytorchJob(test, namespace.Name, secondTuningJob.Name), TestTimeoutShort).
186+
test.Eventually(PytorchJob(test, namespace, secondTuningJob.Name), TestTimeoutShort).
168187
Should(WithTransform(PytorchJobConditionSuspended, Equal(corev1.ConditionTrue)))
169188

170189
// Make sure the first PyTorch job succeed
171-
test.Eventually(PytorchJob(test, namespace.Name, tuningJob.Name), TestTimeoutLong).Should(WithTransform(PytorchJobConditionSucceeded, Equal(corev1.ConditionTrue)))
190+
test.Eventually(PytorchJob(test, namespace, tuningJob.Name), TestTimeoutLong).Should(WithTransform(PytorchJobConditionSucceeded, Equal(corev1.ConditionTrue)))
172191
test.T().Logf("PytorchJob %s/%s ran successfully", tuningJob.Namespace, tuningJob.Name)
173192

174193
// Second PyTorch job should be started now
175-
test.Eventually(PytorchJob(test, namespace.Name, secondTuningJob.Name), TestTimeoutShort).
194+
test.Eventually(PytorchJob(test, namespace, secondTuningJob.Name), TestTimeoutShort).
176195
Should(WithTransform(PytorchJobConditionRunning, Equal(corev1.ConditionTrue)))
177196

178197
// Make sure the second PyTorch job succeed
179-
test.Eventually(PytorchJob(test, namespace.Name, secondTuningJob.Name), TestTimeoutLong).Should(WithTransform(PytorchJobConditionSucceeded, Equal(corev1.ConditionTrue)))
198+
test.Eventually(PytorchJob(test, namespace, secondTuningJob.Name), TestTimeoutLong).Should(WithTransform(PytorchJobConditionSucceeded, Equal(corev1.ConditionTrue)))
180199
test.T().Logf("PytorchJob %s/%s ran successfully", secondTuningJob.Namespace, secondTuningJob.Name)
181200
}
182201

183-
func createPyTorchJob(test Test, namespace, localQueueName string, config corev1.ConfigMap, numGpus int) *kftov1.PyTorchJob {
202+
func createPyTorchJob(test Test, namespace, localQueueName string, config corev1.ConfigMap, numGpus int, outputPvcName string) *kftov1.PyTorchJob {
184203
tuningJob := &kftov1.PyTorchJob{
185204
TypeMeta: metav1.TypeMeta{
186205
APIVersion: corev1.SchemeGroupVersion.String(),
@@ -248,6 +267,10 @@ func createPyTorchJob(test Test, namespace, localQueueName string, config corev1
248267
Name: "tmp-volume",
249268
MountPath: "/tmp",
250269
},
270+
{
271+
Name: "output-volume",
272+
MountPath: "/mnt/output",
273+
},
251274
},
252275
Resources: corev1.ResourceRequirements{
253276
Requests: corev1.ResourceList{
@@ -284,6 +307,14 @@ func createPyTorchJob(test Test, namespace, localQueueName string, config corev1
284307
EmptyDir: &corev1.EmptyDirVolumeSource{},
285308
},
286309
},
310+
{
311+
Name: "output-volume",
312+
VolumeSource: corev1.VolumeSource{
313+
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
314+
ClaimName: outputPvcName,
315+
},
316+
},
317+
},
287318
},
288319
},
289320
},

0 commit comments

Comments
 (0)