@@ -30,13 +30,13 @@ import (
3030)
3131
3232func TestRayFinetuneLlmDeepspeedDemoLlama_2_7b (t * testing.T ) {
33- rayFinetuneLlmDeepspeed (t , 1 , "zero_3_llama_2_7b.json" )
33+ rayFinetuneLlmDeepspeed (t , 1 , "meta-llama/Llama-2-7b-chat-hf" , " zero_3_llama_2_7b.json" )
3434}
3535func TestRayFinetuneLlmDeepspeedDemoLlama_31_8b (t * testing.T ) {
36- rayFinetuneLlmDeepspeed (t , 1 , "zero_3_offload_optim_param.json" )
36+ rayFinetuneLlmDeepspeed (t , 1 , "meta-llama/Meta-Llama-3.1-8B" , " zero_3_offload_optim_param.json" )
3737}
3838
39- func rayFinetuneLlmDeepspeed (t * testing.T , numGpus int , modelConfigFile string ) {
39+ func rayFinetuneLlmDeepspeed (t * testing.T , numGpus int , modelName string , modelConfigFile string ) {
4040 test := With (t )
4141
4242 // Create a namespace
@@ -56,21 +56,22 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string)
5656 "import os" : "import os,time,sys" ,
5757 "import sys" : "!cp /opt/app-root/notebooks/* ./\\ n\" ,\n \t \" !ls" ,
5858 "from codeflare_sdk.cluster.auth import TokenAuthentication" : "from codeflare_sdk.cluster.auth import TokenAuthentication\\ n\" ,\n \t \" from codeflare_sdk.job import RayJobClient" ,
59- "token = ''" : fmt .Sprintf ("token = '%s'" , userToken ),
60- "server = ''" : fmt .Sprintf ("server = '%s'" , GetOpenShiftApiUrl (test )),
61- "namespace='ray-finetune-llm-deepspeed'" : fmt .Sprintf ("namespace='%s'" , namespace .Name ),
62- "head_cpus=16" : "head_cpus=2" ,
63- "head_extended_resource_requests=1" : "head_extended_resource_requests=0" ,
64- "num_workers=7" : "num_workers=1" ,
65- "worker_cpu_requests=16" : "worker_cpu_requests=4" ,
66- "worker_cpu_limits=16" : "worker_cpu_limits=4" ,
67- "worker_memory_requests=128" : "worker_memory_requests=64" ,
68- "worker_memory_limits=256" : "worker_memory_limits=128" ,
69- "head_memory=128" : "head_memory=48" ,
70- "client = cluster.job_client" : "ray_dashboard = cluster.cluster_dashboard_uri()\\ n\" ,\n \t \" header = {\\ \" Authorization\\ \" : \\ \" Bearer " + userToken + "\\ \" }\\ n\" ,\n \t \" client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\ n" ,
71- "--num-devices=8" : fmt .Sprintf ("--num-devices=%d" , numGpus ),
72- "--num-epochs=3" : fmt .Sprintf ("--num-epochs=%d" , 1 ),
73- "--ds-config=./deepspeed_configs/zero_3_offload_optim+param.json" : fmt .Sprintf ("--ds-config=./%s \\ \" \\ n\" ,\n \t \" \\ \" --lora-config=./lora.json \\ \" \\ n\" ,\n \t \" \\ \" --as-test" , modelConfigFile ),
59+ "token = ''" : fmt .Sprintf ("token = '%s'" , userToken ),
60+ "server = ''" : fmt .Sprintf ("server = '%s'" , GetOpenShiftApiUrl (test )),
61+ "namespace='ray-finetune-llm-deepspeed'" : fmt .Sprintf ("namespace='%s'" , namespace .Name ),
62+ "head_cpus=16" : "head_cpus=2" ,
63+ "head_extended_resource_requests=1" : "head_extended_resource_requests=0" ,
64+ "num_workers=7" : "num_workers=1" ,
65+ "worker_cpu_requests=16" : "worker_cpu_requests=4" ,
66+ "worker_cpu_limits=16" : "worker_cpu_limits=4" ,
67+ "worker_memory_requests=128" : "worker_memory_requests=64" ,
68+ "worker_memory_limits=256" : "worker_memory_limits=128" ,
69+ "head_memory=128" : "head_memory=48" ,
70+ "client = cluster.job_client" : "ray_dashboard = cluster.cluster_dashboard_uri()\\ n\" ,\n \t \" header = {\\ \" Authorization\\ \" : \\ \" Bearer " + userToken + "\\ \" }\\ n\" ,\n \t \" client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\ n" ,
71+ "--num-devices=8" : fmt .Sprintf ("--num-devices=%d" , numGpus ),
72+ "--num-epochs=3" : fmt .Sprintf ("--num-epochs=%d" , 1 ),
73+ "--model-name=meta-llama/Meta-Llama-3.1-8B" : fmt .Sprintf ("--model-name=%s" , modelName ),
74+ "--ds-config=./deepspeed_configs/zero_3_offload_optim_param.json" : fmt .Sprintf ("--ds-config=./%s \\ \" \\ n\" ,\n \t \" \\ \" --lora-config=./lora.json \\ \" \\ n\" ,\n \t \" \\ \" --as-test" , modelConfigFile ),
7475 "--batch-size-per-device=32" : "--batch-size-per-device=6" ,
7576 "--eval-batch-size-per-device=32" : "--eval-batch-size-per-device=6" ,
7677 "'pip': 'requirements.txt'" : "'pip': '/opt/app-root/src/requirements.txt'" ,
@@ -83,7 +84,6 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string)
8384 updatedNotebookContent = strings .Replace (updatedNotebookContent , oldValue , newValue , - 1 )
8485 }
8586 updatedNotebook := []byte (updatedNotebookContent )
86- os .WriteFile ("demo.ipynb" , updatedNotebook , 0644 )
8787
8888 // Test configuration
8989 jupyterNotebookConfigMapFileName := "ray_finetune_llm_deepspeed.ipynb"
@@ -117,8 +117,6 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string)
117117 ),
118118 )
119119
120- time .Sleep (30 * time .Second )
121-
122120 // Fetch created raycluster
123121 rayClusterName := "ray"
124122 rayCluster , err := test .Client ().Ray ().RayV1 ().RayClusters (namespace .Name ).Get (test .Ctx (), rayClusterName , metav1.GetOptions {})
@@ -128,37 +126,44 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string)
128126 dashboardUrl := GetDashboardUrl (test , namespace , rayCluster )
129127 rayClusterClientConfig := RayClusterClientConfig {Address : dashboardUrl .String (), Client : nil , InsecureSkipVerify : true }
130128 rayClient , err := NewRayClusterClient (rayClusterClientConfig , test .Config ().BearerToken )
131- if err != nil {
132- test .T ().Errorf ("%s" , err )
133- }
129+ test .Expect (err ).ToNot (HaveOccurred (), fmt .Sprintf ("Failed to create new raycluster client: %s" , err ))
134130
131+ // wait until rayjob exists
132+ test .Eventually (func () []RayJobDetailsResponse {
133+ rayJobs , err := rayClient .GetJobs ()
134+ test .Expect (err ).ToNot (HaveOccurred (), fmt .Sprintf ("Failed to fetch ray-jobs : %s" , err ))
135+ return * rayJobs
136+ }, TestTimeoutMedium , 1 * time .Second ).Should (HaveLen (1 ), "Ray job not found" )
137+
138+ // Get test job-id
135139 jobID := GetTestJobId (test , rayClient , dashboardUrl .Host )
136- test .Expect (jobID ).ToNot (Equal ( nil ))
140+ test .Expect (jobID ).ToNot (BeEmpty ( ))
137141
138142 // Wait for the job to be succeeded or failed
139143 var rayJobStatus string
140- fmt . Printf ("Waiting for job to be Succeeded...\n " )
144+ test . T (). Logf ("Waiting for job to be Succeeded...\n " )
141145 test .Eventually (func () string {
142146 resp , err := rayClient .GetJobDetails (jobID )
143- test .Expect (err ).ToNot (HaveOccurred ())
147+ test .Expect (err ).ToNot (HaveOccurred (), fmt . Sprintf ( "Failed to get job details :%s" , err ) )
144148 rayJobStatusVal := resp .Status
145149 if rayJobStatusVal == "SUCCEEDED" || rayJobStatusVal == "FAILED" {
146- fmt . Printf ( "JobStatus : %s\n " , rayJobStatusVal )
150+ test . T (). Logf ( "JobStatus - %s\n " , rayJobStatusVal )
147151 rayJobStatus = rayJobStatusVal
148- WriteRayJobAPILogs (test , rayClient , jobID )
149152 return rayJobStatus
150153 }
151154 if rayJobStatus != rayJobStatusVal && rayJobStatusVal != "SUCCEEDED" {
152- fmt . Printf ( "JobStatus : %s...\n " , rayJobStatusVal )
155+ test . T (). Logf ( "JobStatus - %s...\n " , rayJobStatusVal )
153156 rayJobStatus = rayJobStatusVal
154157 }
155158 return rayJobStatus
156- }, TestTimeoutDouble , 3 * time .Second ).Should (Or (Equal ("SUCCEEDED" ), Equal ("FAILED" )), "Job did not complete within the expected time" )
159+ }, TestTimeoutDouble , 1 * time .Second ).Should (Or (Equal ("SUCCEEDED" ), Equal ("FAILED" )), "Job did not complete within the expected time" )
157160 // Store job logs in output directory
158161 WriteRayJobAPILogs (test , rayClient , jobID )
162+
163+ // Assert ray-job status after job execution
159164 test .Expect (rayJobStatus ).To (Equal ("SUCCEEDED" ), "RayJob failed !" )
160165
161166 // Make sure the RayCluster finishes and is deleted
162- test .Eventually (RayClusters (test , namespace .Name ), TestTimeoutMedium ).
163- Should (HaveLen ( 0 ))
167+ test .Eventually (RayClusters (test , namespace .Name ), TestTimeoutLong ).
168+ Should (BeEmpty ( ))
164169}
0 commit comments