test: initial implementation of SDK e2e #488
```diff
@@ -1 +1,2 @@
 bin/*
+.vscode/*
```
@@ -0,0 +1,30 @@
```go
/*
Copyright 2025.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package trainer

import (
	"testing"

	. "github.com/opendatahub-io/distributed-workloads/tests/common"
	sdktests "github.com/opendatahub-io/distributed-workloads/tests/trainer/sdk_tests"
)

func TestKubeflowSDK_Sanity(t *testing.T) {
	Tags(t, Sanity)
	sdktests.RunFashionMnistCpuDistributedTraining(t)
	// ADD MORE SANITY TESTS HERE
}
```
@@ -0,0 +1,275 @@
```json
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-03T13:19:46.917723Z",
     "iopub.status.busy": "2025-09-03T13:19:46.917308Z",
     "iopub.status.idle": "2025-09-03T13:19:46.935181Z",
     "shell.execute_reply": "2025-09-03T13:19:46.934697Z",
     "shell.execute_reply.started": "2025-09-03T13:19:46.917698Z"
    }
   },
   "outputs": [],
   "source": [
    "def train_fashion_mnist():\n",
    "    import os\n",
    "\n",
    "    import torch\n",
    "    import torch.distributed as dist\n",
    "    import torch.nn.functional as F\n",
    "    from torch import nn\n",
    "    from torch.utils.data import DataLoader, DistributedSampler\n",
    "    from torchvision import datasets, transforms\n",
    "\n",
    "    # Define the PyTorch CNN model to be trained\n",
    "    class Net(nn.Module):\n",
    "        def __init__(self):\n",
    "            super(Net, self).__init__()\n",
    "            self.conv1 = nn.Conv2d(1, 20, 5, 1)\n",
    "            self.conv2 = nn.Conv2d(20, 50, 5, 1)\n",
    "            self.fc1 = nn.Linear(4 * 4 * 50, 500)\n",
    "            self.fc2 = nn.Linear(500, 10)\n",
    "\n",
    "        def forward(self, x):\n",
    "            x = F.relu(self.conv1(x))\n",
    "            x = F.max_pool2d(x, 2, 2)\n",
    "            x = F.relu(self.conv2(x))\n",
    "            x = F.max_pool2d(x, 2, 2)\n",
    "            x = x.view(-1, 4 * 4 * 50)\n",
    "            x = F.relu(self.fc1(x))\n",
    "            x = self.fc2(x)\n",
    "            return F.log_softmax(x, dim=1)\n",
    "\n",
    "    # Use NCCL if a GPU is available, otherwise use Gloo as communication backend.\n",
    "    device, backend = (\"cuda\", \"nccl\") if torch.cuda.is_available() else (\"cpu\", \"gloo\")\n",
    "    print(f\"Using Device: {device}, Backend: {backend}\")\n",
    "\n",
    "    # Setup PyTorch distributed.\n",
    "    local_rank = int(os.getenv(\"LOCAL_RANK\", 0))\n",
    "    dist.init_process_group(backend=backend)\n",
    "    print(\n",
    "        \"Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}\".format(\n",
    "            dist.get_world_size(),\n",
    "            dist.get_rank(),\n",
    "            local_rank,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # Create the model and load it into the device.\n",
    "    device = torch.device(f\"{device}:{local_rank}\")\n",
    "    model = nn.parallel.DistributedDataParallel(Net().to(device))\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n",
    "\n",
    "    # Download FashionMNIST dataset only on local_rank=0 process.\n",
    "    if local_rank == 0:\n",
    "        dataset = datasets.FashionMNIST(\n",
    "            \"./data\",\n",
    "            train=True,\n",
    "            download=True,\n",
    "            transform=transforms.Compose([transforms.ToTensor()]),\n",
    "        )\n",
    "    dist.barrier()\n",
    "    dataset = datasets.FashionMNIST(\n",
    "        \"./data\",\n",
    "        train=True,\n",
    "        download=False,\n",
    "        transform=transforms.Compose([transforms.ToTensor()]),\n",
    "    )\n",
    "\n",
| " # Shard the dataset accross workers.\n", | ||
| " train_loader = DataLoader(\n", | ||
| " dataset,\n", | ||
| " batch_size=100,\n", | ||
| " sampler=DistributedSampler(dataset)\n", | ||
| " )\n", | ||
| "\n", | ||
| " # TODO(astefanutti): add parameters to the training function\n", | ||
| " dist.barrier()\n", | ||
| " for epoch in range(1, 3):\n", | ||
| " model.train()\n", | ||
| "\n", | ||
| " # Iterate over mini-batches from the training set\n", | ||
| " for batch_idx, (inputs, labels) in enumerate(train_loader):\n", | ||
| " # Copy the data to the GPU device if available\n", | ||
| " inputs, labels = inputs.to(device), labels.to(device)\n", | ||
| " # Forward pass\n", | ||
| " outputs = model(inputs)\n", | ||
| " loss = F.nll_loss(outputs, labels)\n", | ||
| " # Backward pass\n", | ||
| " optimizer.zero_grad()\n", | ||
| " loss.backward()\n", | ||
| " optimizer.step()\n", | ||
| "\n", | ||
| " if batch_idx % 10 == 0 and dist.get_rank() == 0:\n", | ||
| " print(\n", | ||
| " \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\n", | ||
| " epoch,\n", | ||
| " batch_idx * len(inputs),\n", | ||
| " len(train_loader.dataset),\n", | ||
| " 100.0 * batch_idx / len(train_loader),\n", | ||
| " loss.item(),\n", | ||
| " )\n", | ||
| " )\n", | ||
| "\n", | ||
| " # Wait for the distributed training to complete\n", | ||
| " dist.barrier()\n", | ||
| " if dist.get_rank() == 0:\n", | ||
| " print(\"Training is finished\")\n", | ||
| "\n", | ||
| " # Finally clean up PyTorch distributed\n", | ||
| " dist.destroy_process_group()" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": { | ||
| "execution": { | ||
| "iopub.execute_input": "2025-09-03T13:19:49.832393Z", | ||
| "iopub.status.busy": "2025-09-03T13:19:49.832117Z", | ||
| "iopub.status.idle": "2025-09-03T13:19:51.924613Z", | ||
| "shell.execute_reply": "2025-09-03T13:19:51.924264Z", | ||
| "shell.execute_reply.started": "2025-09-03T13:19:49.832371Z" | ||
| }, | ||
| "pycharm": { | ||
| "name": "#%%\n" | ||
| } | ||
| }, | ||
| "outputs": [], | ||
| "source": [ | ||
| "from kubeflow.trainer import CustomTrainer, TrainerClient\n", | ||
| "\n", | ||
| "client = TrainerClient()\n" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "for runtime in client.list_runtimes():\n", | ||
| " print(runtime)\n", | ||
| " if runtime.name == \"universal\": # Update to actual universal image runtime once available\n", | ||
| " torch_runtime = runtime" | ||
|
||
| ] | ||
|
||
| }, | ||
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-03T13:19:56.525591Z",
     "iopub.status.busy": "2025-09-03T13:19:56.524936Z",
     "iopub.status.idle": "2025-09-03T13:19:56.721404Z",
     "shell.execute_reply": "2025-09-03T13:19:56.720565Z",
     "shell.execute_reply.started": "2025-09-03T13:19:56.525536Z"
    }
   },
   "outputs": [],
   "source": [
    "job_name = client.train(\n",
    "    trainer=CustomTrainer(\n",
    "        func=train_fashion_mnist,\n",
    "        num_nodes=2,\n",
    "        resources_per_node={\n",
    "            \"cpu\": 2,\n",
    "            \"memory\": \"8Gi\",\n",
    "        },\n",
    "        packages_to_install=[\"torchvision\"],\n",
    "    ),\n",
    "    runtime=torch_runtime,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-03T13:20:01.378158Z",
     "iopub.status.busy": "2025-09-03T13:20:01.377707Z",
     "iopub.status.idle": "2025-09-03T13:20:12.713960Z",
     "shell.execute_reply": "2025-09-03T13:20:12.713295Z",
     "shell.execute_reply.started": "2025-09-03T13:20:01.378130Z"
    }
   },
   "outputs": [],
   "source": [
    "# Wait for the running status.\n",
    "client.wait_for_job_status(name=job_name, status={\"Running\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-03T13:20:24.045774Z",
     "iopub.status.busy": "2025-09-03T13:20:24.045480Z",
     "iopub.status.idle": "2025-09-03T13:20:24.772877Z",
     "shell.execute_reply": "2025-09-03T13:20:24.772178Z",
     "shell.execute_reply.started": "2025-09-03T13:20:24.045755Z"
    }
   },
   "outputs": [],
   "source": [
    "for c in client.get_job(name=job_name).steps:\n",
    "    print(f\"Step: {c.name}, Status: {c.status}, Devices: {c.device} x {c.device_count}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-03T13:20:26.729486Z",
     "iopub.status.busy": "2025-09-03T13:20:26.728951Z",
     "iopub.status.idle": "2025-09-03T13:20:29.596510Z",
     "shell.execute_reply": "2025-09-03T13:20:29.594741Z",
     "shell.execute_reply.started": "2025-09-03T13:20:26.729446Z"
    }
   },
   "outputs": [],
   "source": [
    "for logline in client.get_job_logs(job_name, follow=True):\n",
    "    print(logline)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client.delete_job(job_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
```
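Since the Go harness below runs this notebook through papermill (via `BuildPapermillShellCmd`, currently called with `nil` parameters), the notebook's TODO about parameterizing the training function could eventually be addressed with a cell tagged `parameters`. A hypothetical sketch, where every variable name and default value is an assumption rather than part of this PR:

```python
# Hypothetical papermill "parameters" cell (tag this cell with "parameters"
# so papermill can override the defaults from the Go test harness).
num_nodes = 2            # worker count passed to CustomTrainer
cpu_per_node = 2         # CPU request per training node
memory_per_node = "8Gi"  # memory request per training node
num_epochs = 2           # would address the TODO in train_fashion_mnist
```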
@@ -0,0 +1,86 @@
```go
/*
Copyright 2025.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package sdk_tests

import (
	"fmt"
	"os"
	"testing"

	. "github.com/onsi/gomega"

	corev1 "k8s.io/api/core/v1"

	common "github.com/opendatahub-io/distributed-workloads/tests/common"
	support "github.com/opendatahub-io/distributed-workloads/tests/common/support"
	trainerutils "github.com/opendatahub-io/distributed-workloads/tests/trainer/utils"
)

const (
	notebookName = "mnist.ipynb"
	notebookPath = "resources/" + notebookName
)

// CPU Only - Distributed Training
func RunFashionMnistCpuDistributedTraining(t *testing.T) {
	test := support.With(t)

	// Create a new test namespace
	namespace := test.NewTestNamespace()

	// Ensure prerequisites to run the test are met
	trainerutils.EnsureTrainerClusterReady(t, test)

	// Ensure the Notebook ServiceAccount and RBACs are set for this namespace
	trainerutils.EnsureNotebookRBAC(t, test, namespace.Name)

	// RBAC setup
	userName := common.GetNotebookUserName(test)
	userToken := common.GetNotebookUserToken(test)
	support.CreateUserRoleBindingWithClusterRole(test, userName, namespace.Name, "admin")

	// Read the notebook from the resources directory
	localPath := notebookPath
	nb, err := os.ReadFile(localPath)
	test.Expect(err).NotTo(HaveOccurred(), fmt.Sprintf("failed to read notebook: %s", localPath))

	// Create a ConfigMap holding the notebook
	cm := support.CreateConfigMap(test, namespace.Name, map[string][]byte{notebookName: nb})

	// Build the command executing the notebook
	marker := "/opt/app-root/src/notebook_completion_marker"
	shellCmd := trainerutils.BuildPapermillShellCmd(notebookName, marker, nil)
	command := []string{"/bin/sh", "-c", shellCmd}

	// Create the Notebook CR (with a default 10Gi PVC)
	pvc := support.CreatePersistentVolumeClaim(test, namespace.Name, "10Gi", support.AccessModes(corev1.ReadWriteOnce))
	common.CreateNotebook(test, namespace, userToken, command, cm.Name, notebookName, 0, pvc, common.ContainerSizeSmall)

	// Cleanup
	defer func() {
		common.DeleteNotebook(test, namespace)
		test.Eventually(common.Notebooks(test, namespace), support.TestTimeoutLong).Should(HaveLen(0))
	}()

	// Wait for the Notebook Pod and get the pod/container names
	podName, containerName := trainerutils.WaitForNotebookPodRunning(test, namespace.Name)
```
**Contributor:** Would it make sense to add an assertion checking that the TrainJob is created and successfully finished?

**Contributor (author):** Yes, I think this is very important :)
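A minimal sketch of such an assertion as a final notebook cell, reusing only the `TrainerClient` calls already shown in the notebook above; the `"Complete"`/`"Failed"` status values and the string `status` field on the returned job are assumptions about the SDK surface, not something this PR confirms:

```python
# Hypothetical final notebook cell: block until the TrainJob reaches a
# terminal state, then assert that it succeeded before cleanup runs.
# The "Complete"/"Failed" values and job.status are assumed names.
client.wait_for_job_status(name=job_name, status={"Complete", "Failed"})

job = client.get_job(name=job_name)
assert job.status == "Complete", f"TrainJob {job_name} ended with status {job.status}"
```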
```go
	// Poll the marker file to check whether the notebook execution completed successfully
	if err := trainerutils.PollNotebookCompletionMarker(test, namespace.Name, podName, containerName, marker, support.TestTimeoutDouble); err != nil {
		test.Expect(err).To(Succeed(), "Notebook execution reported FAILURE")
	}
}
```
This will become an issue for running tests on disconnected clusters. In the Trainer v1 tests we uploaded the dataset to AWS S3, and it is downloaded from there when the AWS environment variables are declared: https://github.com/opendatahub-io/distributed-workloads/blob/main/tests/kfto/resources/kfto_sdk_mnist.py#L67
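A hedged sketch of how that approach could carry over to `train_fashion_mnist`: stage the raw FashionMNIST files from S3 when the AWS variables are present, and fall back to torchvision's public download otherwise. The environment variable names mirror the linked `kfto_sdk_mnist.py`, but they and the `FashionMNIST/raw/` bucket layout are assumptions here:

```python
import os

import boto3
from torchvision import datasets, transforms


def load_fashion_mnist(root="./data", train=True):
    # On disconnected clusters, pull the raw dataset files from S3 so that
    # torchvision never needs to reach the public mirror. The variable names
    # (AWS_DEFAULT_ENDPOINT, AWS_STORAGE_BUCKET, ...) and the bucket prefix
    # are assumptions mirroring the linked kfto test resources.
    endpoint = os.getenv("AWS_DEFAULT_ENDPOINT")
    bucket = os.getenv("AWS_STORAGE_BUCKET")
    if endpoint and bucket:
        s3 = boto3.client(
            "s3",
            endpoint_url=endpoint,
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
        )
        raw_dir = os.path.join(root, "FashionMNIST", "raw")
        os.makedirs(raw_dir, exist_ok=True)
        # Stage every raw dataset file locally, then let torchvision read from disk.
        for obj in s3.list_objects_v2(Bucket=bucket, Prefix="FashionMNIST/raw/").get("Contents", []):
            s3.download_file(bucket, obj["Key"], os.path.join(raw_dir, os.path.basename(obj["Key"])))
        download = False  # files are staged locally; skip the public download
    else:
        download = True  # connected clusters can fetch from the public mirror
    return datasets.FashionMNIST(
        root,
        train=train,
        download=download,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
```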