@@ -19,30 +19,40 @@ package kfto
1919import (
2020 "testing"
2121
22- "github.com/onsi/gomega"
2322 . "github.com/onsi/gomega"
2423 . "github.com/project-codeflare/codeflare-common/support"
24+ kueuev1beta1 "sigs.k8s.io/kueue/apis/kueue/v1beta1"
2525
2626 corev1 "k8s.io/api/core/v1"
27+ "k8s.io/apimachinery/pkg/api/resource"
2728 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2829
2930 kftov1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
3031)
3132
32- func PytorchJob (t Test , namespace , name string ) func (g gomega. Gomega ) * kftov1.PyTorchJob {
33- return func (g gomega. Gomega ) * kftov1.PyTorchJob {
33+ func PytorchJob (t Test , namespace , name string ) func (g Gomega ) * kftov1.PyTorchJob {
34+ return func (g Gomega ) * kftov1.PyTorchJob {
3435 job , err := t .Client ().Kubeflow ().KubeflowV1 ().PyTorchJobs (namespace ).Get (t .Ctx (), name , metav1.GetOptions {})
35- g .Expect (err ).NotTo (gomega . HaveOccurred ())
36+ g .Expect (err ).NotTo (HaveOccurred ())
3637 return job
3738 }
3839}
3940
40- // s
41- func PytorchJobCondition (job * kftov1.PyTorchJob ) string {
42- if len (job .Status .Conditions ) == 0 {
43- return ""
41+ func PytorchJobConditionRunning (job * kftov1.PyTorchJob ) corev1.ConditionStatus {
42+ return PytorchJobCondition (job , kftov1 .JobRunning )
43+ }
44+
45+ func PytorchJobConditionSucceeded (job * kftov1.PyTorchJob ) corev1.ConditionStatus {
46+ return PytorchJobCondition (job , kftov1 .JobSucceeded )
47+ }
48+
49+ func PytorchJobCondition (job * kftov1.PyTorchJob , conditionType kftov1.JobConditionType ) corev1.ConditionStatus {
50+ for _ , condition := range job .Status .Conditions {
51+ if condition .Type == conditionType {
52+ return condition .Status
53+ }
4454 }
45- return job . Status . Conditions [ len ( job . Status . Conditions ) - 1 ]. Reason
55+ return corev1 . ConditionUnknown
4656}
4757
4858func TestPytorchjobWithSFTtrainer (t * testing.T ) {
@@ -51,29 +61,45 @@ func TestPytorchjobWithSFTtrainer(t *testing.T) {
5161
5262 // Create a namespace
5363 namespace := test .NewTestNamespace ()
54- config := & corev1.ConfigMap {
55- TypeMeta : metav1.TypeMeta {
56- APIVersion : corev1 .SchemeGroupVersion .String (),
57- Kind : "ConfigMap" ,
58- },
59- ObjectMeta : metav1.ObjectMeta {
60- Name : "my-config" ,
61- Namespace : namespace .Name ,
62- Labels : map [string ]string {
63- "kueue.x-k8s.io/queue-name" : "lq-trainer" ,
64+
65+ // Create a ConfigMap with training dataset and configuration
66+ configData := map [string ][]byte {
67+ "config.json" : ReadFile (test , "config.json" ),
68+ "twitter_complaints_small.json" : ReadFile (test , "twitter_complaints_small.json" ),
69+ }
70+ config := CreateConfigMap (test , namespace .Name , configData )
71+
72+ // Create Kueue resources
73+ resourceFlavor := CreateKueueResourceFlavor (test , kueuev1beta1.ResourceFlavorSpec {})
74+ defer test .Client ().Kueue ().KueueV1beta1 ().ResourceFlavors ().Delete (test .Ctx (), resourceFlavor .Name , metav1.DeleteOptions {})
75+ cqSpec := kueuev1beta1.ClusterQueueSpec {
76+ NamespaceSelector : & metav1.LabelSelector {},
77+ ResourceGroups : []kueuev1beta1.ResourceGroup {
78+ {
79+ CoveredResources : []corev1.ResourceName {corev1 .ResourceName ("cpu" ), corev1 .ResourceName ("memory" )},
80+ Flavors : []kueuev1beta1.FlavorQuotas {
81+ {
82+ Name : kueuev1beta1 .ResourceFlavorReference (resourceFlavor .Name ),
83+ Resources : []kueuev1beta1.ResourceQuota {
84+ {
85+ Name : corev1 .ResourceCPU ,
86+ NominalQuota : resource .MustParse ("8" ),
87+ },
88+ {
89+ Name : corev1 .ResourceMemory ,
90+ NominalQuota : resource .MustParse ("12Gi" ),
91+ },
92+ },
93+ },
94+ },
6495 },
6596 },
66- BinaryData : map [string ][]byte {
67- "config.json" : ReadFile (test , "config.json" ),
68- "twitter_complaints_small.json" : ReadFile (test , "twitter_complaints_small.json" ),
69- },
70- Immutable : Ptr (true ),
7197 }
98+ clusterQueue := CreateKueueClusterQueue (test , cqSpec )
99+ defer test .Client ().Kueue ().KueueV1beta1 ().ClusterQueues ().Delete (test .Ctx (), clusterQueue .Name , metav1.DeleteOptions {})
100+ localQueue := CreateKueueLocalQueue (test , namespace .Name , clusterQueue .Name )
72101
73- config , err := test .Client ().Core ().CoreV1 ().ConfigMaps (namespace .Name ).Create (test .Ctx (), config , metav1.CreateOptions {})
74- test .Expect (err ).NotTo (HaveOccurred ())
75- test .T ().Logf ("Created ConfigMap %s/%s successfully" , config .Namespace , config .Name )
76-
102+ // Run training PyTorch job
77103 tuningJob := & kftov1.PyTorchJob {
78104 TypeMeta : metav1.TypeMeta {
79105 APIVersion : corev1 .SchemeGroupVersion .String (),
@@ -82,18 +108,21 @@ func TestPytorchjobWithSFTtrainer(t *testing.T) {
82108 ObjectMeta : metav1.ObjectMeta {
83109 Name : "kfto-sft" ,
84110 Namespace : namespace .Name ,
111+ Labels : map [string ]string {
112+ "kueue.x-k8s.io/queue-name" : localQueue .Name ,
113+ },
85114 },
86115 Spec : kftov1.PyTorchJobSpec {
87116 PyTorchReplicaSpecs : map [kftov1.ReplicaType ]* kftov1.ReplicaSpec {
88- "Master" : & kftov1. ReplicaSpec {
117+ "Master" : {
89118 Replicas : Ptr (int32 (1 )),
90119 RestartPolicy : "Never" ,
91120 Template : corev1.PodTemplateSpec {
92121 Spec : corev1.PodSpec {
93122 Containers : []corev1.Container {
94123 {
95124 Name : "pytorch" ,
96- Image : "quay.io/tedchang/sft-trainer:dev" ,
125+ Image : GetFmsHfTuningImage () ,
97126 ImagePullPolicy : corev1 .PullIfNotPresent ,
98127 Command : []string {"python" , "/app/launch_training.py" },
99128 Env : []corev1.EnvVar {
@@ -108,6 +137,12 @@ func TestPytorchjobWithSFTtrainer(t *testing.T) {
108137 MountPath : "/etc/config" ,
109138 },
110139 },
140+ Resources : corev1.ResourceRequirements {
141+ Requests : corev1.ResourceList {
142+ corev1 .ResourceCPU : resource .MustParse ("2" ),
143+ corev1 .ResourceMemory : resource .MustParse ("5Gi" ),
144+ },
145+ },
111146 },
112147 },
113148 Volumes : []corev1.Volume {
@@ -116,7 +151,7 @@ func TestPytorchjobWithSFTtrainer(t *testing.T) {
116151 VolumeSource : corev1.VolumeSource {
117152 ConfigMap : & corev1.ConfigMapVolumeSource {
118153 LocalObjectReference : corev1.LocalObjectReference {
119- Name : "my- config" ,
154+ Name : config . Name ,
120155 },
121156 Items : []corev1.KeyToPath {
122157 {
@@ -139,10 +174,24 @@ func TestPytorchjobWithSFTtrainer(t *testing.T) {
139174 },
140175 }
141176
142- tuningJob , err = test .Client ().Kubeflow ().KubeflowV1 ().PyTorchJobs (namespace .Name ).Create (test .Ctx (), tuningJob , metav1.CreateOptions {})
177+ tuningJob , err : = test .Client ().Kubeflow ().KubeflowV1 ().PyTorchJobs (namespace .Name ).Create (test .Ctx (), tuningJob , metav1.CreateOptions {})
143178 test .Expect (err ).NotTo (HaveOccurred ())
144179 test .T ().Logf ("Created PytorchJob %s/%s successfully" , tuningJob .Namespace , tuningJob .Name )
145180
146- test .Eventually (PytorchJob (test , namespace .Name , tuningJob .Name ), TestTimeoutLong ).Should (WithTransform (PytorchJobCondition , Equal ("PyTorchJobSucceeded" )))
181+ // Make sure the Kueue Workload is admitted
182+ test .Eventually (KueueWorkloads (test , namespace .Name ), TestTimeoutLong ).
183+ Should (
184+ And (
185+ HaveLen (1 ),
186+ ContainElement (WithTransform (KueueWorkloadAdmitted , BeTrueBecause ("Workload failed to be admitted" ))),
187+ ),
188+ )
189+
190+ // Make sure the PyTorch job is running
191+ test .Eventually (PytorchJob (test , namespace .Name , tuningJob .Name ), TestTimeoutShort ).
192+ Should (WithTransform (PytorchJobConditionRunning , Equal (corev1 .ConditionTrue )))
193+
194+ // Make sure the PyTorch job succeeds
195+ test .Eventually (PytorchJob (test , namespace .Name , tuningJob .Name ), TestTimeoutLong ).Should (WithTransform (PytorchJobConditionSucceeded , Equal (corev1 .ConditionTrue )))
147196 test .T ().Logf ("PytorchJob %s/%s ran successfully" , tuningJob .Namespace , tuningJob .Name )
148197}
0 commit comments