Skip to content

Commit 8bbd0c4

Browse files
Update KFTO multinode pytorch training test for disconnected
1 parent 25f24d5 commit 8bbd0c4

File tree

1 file changed

+35
-18
lines changed

1 file changed

+35
-18
lines changed

tests/kfto/kfto_mnist_training_test.go

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
167167
Command: []string{
168168
"/bin/bash", "-c",
169169
fmt.Sprintf(`mkdir -p /tmp/lib /tmp/datasets/mnist && export PYTHONPATH=$PYTHONPATH:/tmp/lib && \
170-
pip install --no-cache-dir -r /mnt/files/requirements.txt --target=/tmp/lib && \
170+
pip install --no-cache-dir -r /mnt/files/requirements.txt --target=/tmp/lib --verbose && \
171171
echo "Downloading MNIST dataset..." && \
172172
python3 /mnt/files/download_mnist_datasets.py --dataset_path "/tmp/datasets/mnist" && \
173173
echo -e "\n\n Dataset downloaded to /tmp/datasets/mnist" && ls -R /tmp/datasets/mnist && \
@@ -251,7 +251,7 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
251251
Command: []string{
252252
"/bin/bash", "-c",
253253
fmt.Sprintf(`mkdir -p /tmp/lib /tmp/datasets/mnist && export PYTHONPATH=$PYTHONPATH:/tmp/lib && \
254-
pip install --no-cache-dir -r /mnt/files/requirements.txt --target=/tmp/lib && \
254+
pip install --no-cache-dir -r /mnt/files/requirements.txt --target=/tmp/lib --verbose && \
255255
echo "Downloading MNIST dataset..." && \
256256
python3 /mnt/files/download_mnist_datasets.py --dataset_path "/tmp/datasets/mnist" && \
257257
echo -e "\n\n Dataset downloaded to /tmp/datasets/mnist" && ls -R /tmp/datasets/mnist && \
@@ -306,14 +306,36 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
306306
},
307307
}
308308

309+
// Add the pip index used to download Python packages; in a disconnected environment, use the provided custom PyPI mirror index URL
310+
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Env = []corev1.EnvVar{
311+
{
312+
Name: "PIP_INDEX_URL",
313+
Value: GetPipIndexURL(),
314+
},
315+
{
316+
Name: "PIP_TRUSTED_HOST",
317+
Value: GetPipTrustedHost(),
318+
},
319+
}
320+
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Env = []corev1.EnvVar{
321+
{
322+
Name: "PIP_INDEX_URL",
323+
Value: GetPipIndexURL(),
324+
},
325+
{
326+
Name: "PIP_TRUSTED_HOST",
327+
Value: GetPipTrustedHost(),
328+
},
329+
}
330+
309331
if accelerator.isGpu() {
310332
// Update resource lists for the GPU (NVIDIA/ROCm) use case
311-
tuningJob.Spec.PyTorchReplicaSpecs["Master"].Template.Spec.Containers[0].Resources.Requests[corev1.ResourceName(accelerator.ResourceLabel)] = resource.MustParse(fmt.Sprint(numProcPerNode))
312-
tuningJob.Spec.PyTorchReplicaSpecs["Master"].Template.Spec.Containers[0].Resources.Limits[corev1.ResourceName(accelerator.ResourceLabel)] = resource.MustParse(fmt.Sprint(numProcPerNode))
313-
tuningJob.Spec.PyTorchReplicaSpecs["Worker"].Template.Spec.Containers[0].Resources.Requests[corev1.ResourceName(accelerator.ResourceLabel)] = resource.MustParse(fmt.Sprint(numProcPerNode))
314-
tuningJob.Spec.PyTorchReplicaSpecs["Worker"].Template.Spec.Containers[0].Resources.Limits[corev1.ResourceName(accelerator.ResourceLabel)] = resource.MustParse(fmt.Sprint(numProcPerNode))
333+
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Resources.Requests[corev1.ResourceName(accelerator.ResourceLabel)] = resource.MustParse(fmt.Sprint(numProcPerNode))
334+
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Resources.Limits[corev1.ResourceName(accelerator.ResourceLabel)] = resource.MustParse(fmt.Sprint(numProcPerNode))
335+
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Resources.Requests[corev1.ResourceName(accelerator.ResourceLabel)] = resource.MustParse(fmt.Sprint(numProcPerNode))
336+
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Resources.Limits[corev1.ResourceName(accelerator.ResourceLabel)] = resource.MustParse(fmt.Sprint(numProcPerNode))
315337

316-
tuningJob.Spec.PyTorchReplicaSpecs["Master"].Template.Spec.Containers[0].Env = []corev1.EnvVar{
338+
torch_distributed_debug_env_vars := []corev1.EnvVar{
317339
{
318340
Name: "NCCL_DEBUG",
319341
Value: "INFO",
@@ -323,25 +345,19 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
323345
Value: "DETAIL",
324346
},
325347
}
326-
tuningJob.Spec.PyTorchReplicaSpecs["Worker"].Template.Spec.Containers[0].Env = []corev1.EnvVar{
327-
{
328-
Name: "NCCL_DEBUG",
329-
Value: "INFO",
330-
},
331-
{
332-
Name: "TORCH_DISTRIBUTED_DEBUG",
333-
Value: "DETAIL",
334-
},
348+
for _, envVar := range torch_distributed_debug_env_vars {
349+
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Env = upsert(tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Env, envVar, withEnvVarName(envVar.Name))
350+
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Env = upsert(tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Env, envVar, withEnvVarName(envVar.Name))
335351
}
336352

337353
// Update tolerations
338-
tuningJob.Spec.PyTorchReplicaSpecs["Master"].Template.Spec.Tolerations = []corev1.Toleration{
354+
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Tolerations = []corev1.Toleration{
339355
{
340356
Key: accelerator.ResourceLabel,
341357
Operator: corev1.TolerationOpExists,
342358
},
343359
}
344-
tuningJob.Spec.PyTorchReplicaSpecs["Worker"].Template.Spec.Tolerations = []corev1.Toleration{
360+
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeWorker].Template.Spec.Tolerations = []corev1.Toleration{
345361
{
346362
Key: accelerator.ResourceLabel,
347363
Operator: corev1.TolerationOpExists,
@@ -377,6 +393,7 @@ func createKFTOPyTorchMnistJob(test Test, namespace string, config corev1.Config
377393
// Append the list of environment variables to the master and worker containers
378394
for _, envVar := range storage_bucket_env_vars {
379395
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Env = upsert(tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Env, envVar, withEnvVarName(envVar.Name))
396+
tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Env = upsert(tuningJob.Spec.PyTorchReplicaSpecs[kftov1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Env, envVar, withEnvVarName(envVar.Name))
380397
}
381398

382399
} else {

0 commit comments

Comments
 (0)