@@ -1075,6 +1075,96 @@ func TestWaitAsync(t *testing.T) {
10751075 assert .Equal (t , replicate .Succeeded , lastStatus )
10761076}
10771077
1078+ func TestRun (t * testing.T ) {
1079+ mockServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1080+ switch r .URL .Path {
1081+ case "/predictions" :
1082+ assert .Equal (t , http .MethodPost , r .Method )
1083+ prediction := replicate.Prediction {
1084+ ID : "gtsllfynndufawqhdngldkdrkq" ,
1085+ Version : "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
1086+ Status : replicate .Starting ,
1087+ }
1088+ json .NewEncoder (w ).Encode (prediction )
1089+ case "/predictions/gtsllfynndufawqhdngldkdrkq" :
1090+ assert .Equal (t , http .MethodGet , r .Method )
1091+ prediction := replicate.Prediction {
1092+ ID : "gtsllfynndufawqhdngldkdrkq" ,
1093+ Version : "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
1094+ Status : replicate .Succeeded ,
1095+ Output : "Hello, world!" ,
1096+ }
1097+ json .NewEncoder (w ).Encode (prediction )
1098+ default :
1099+ t .Fatalf ("Unexpected request to %s" , r .URL .Path )
1100+ }
1101+ }))
1102+ defer mockServer .Close ()
1103+
1104+ client , err := replicate .NewClient (
1105+ replicate .WithToken ("test-token" ),
1106+ replicate .WithBaseURL (mockServer .URL ),
1107+ )
1108+ require .NoError (t , err )
1109+
1110+ ctx := context .Background ()
1111+ input := replicate.PredictionInput {"prompt" : "Hello" }
1112+ output , err := client .Run (ctx , "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" , input , nil )
1113+
1114+ require .NoError (t , err )
1115+ assert .NotNil (t , output )
1116+ assert .Equal (t , "Hello, world!" , output )
1117+ }
1118+
1119+ func TestRunReturningModelError (t * testing.T ) {
1120+ mockServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1121+ switch r .URL .Path {
1122+ case "/predictions" :
1123+ assert .Equal (t , http .MethodPost , r .Method )
1124+ prediction := replicate.Prediction {
1125+ ID : "fynndufawqhdngldkgtslldrkq" ,
1126+ Version : "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
1127+ Status : replicate .Starting ,
1128+ }
1129+ json .NewEncoder (w ).Encode (prediction )
1130+ case "/predictions/fynndufawqhdngldkgtslldrkq" :
1131+ assert .Equal (t , http .MethodGet , r .Method )
1132+
1133+ logs := "Could not say hello"
1134+ prediction := replicate.Prediction {
1135+ ID : "fynndufawqhdngldkgtslldrkq" ,
1136+ Version : "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
1137+ Status : replicate .Failed ,
1138+ Logs : & logs ,
1139+ Error : "Model execution failed" ,
1140+ }
1141+ json .NewEncoder (w ).Encode (prediction )
1142+ default :
1143+ t .Fatalf ("Unexpected request to %s" , r .URL .Path )
1144+ }
1145+ }))
1146+ defer mockServer .Close ()
1147+
1148+ client , err := replicate .NewClient (
1149+ replicate .WithToken ("test-token" ),
1150+ replicate .WithBaseURL (mockServer .URL ),
1151+ )
1152+ require .NoError (t , err )
1153+
1154+ ctx := context .Background ()
1155+ input := replicate.PredictionInput {"prompt" : "Hello" }
1156+ _ , err = client .Run (ctx , "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" , input , nil )
1157+
1158+ require .Error (t , err )
1159+ modelErr , ok := err .(* replicate.ModelError )
1160+ require .True (t , ok , "Expected error to be of type *replicate.ModelError" )
1161+ assert .Equal (t , "model error: Model execution failed" , modelErr .Error ())
1162+ assert .Equal (t , "fynndufawqhdngldkgtslldrkq" , modelErr .Prediction .ID )
1163+ assert .Equal (t , replicate .Failed , modelErr .Prediction .Status )
1164+ assert .Equal (t , "Model execution failed" , modelErr .Prediction .Error )
1165+ assert .Equal (t , "Could not say hello" , * modelErr .Prediction .Logs )
1166+ }
1167+
10781168func TestCreateTraining (t * testing.T ) {
10791169 mockServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
10801170 assert .Equal (t , http .MethodPost , r .Method )
0 commit comments