Skip to content

Commit ca660fe

Browse files
mehmettokgozmattt
andauthored
Add support for /predictions/{id}/cancel endpoint (#46)
* Add cancel prediction endpoint. * Add documentation. * Add test for CancelPrediction * Reorder training methods --------- Co-authored-by: Mattt Zmuda <[email protected]>
1 parent b0335b6 commit ca660fe

File tree

3 files changed

+55
-10
lines changed

3 files changed

+55
-10
lines changed

client_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,41 @@ func TestCreatePredictionWithModel(t *testing.T) {
610610
assert.Equal(t, replicate.Starting, prediction.Status)
611611
}
612612

613+
func TestCancelPrediction(t *testing.T) {
614+
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
615+
assert.Equal(t, "POST", r.Method)
616+
assert.Equal(t, "/predictions/ufawqhfynnddngldkgtslldrkq/cancel", r.URL.Path)
617+
618+
response := replicate.Prediction{
619+
ID: "ufawqhfynnddngldkgtslldrkq",
620+
Status: replicate.Canceled,
621+
}
622+
responseBytes, err := json.Marshal(response)
623+
if err != nil {
624+
t.Fatal(err)
625+
}
626+
627+
w.WriteHeader(http.StatusOK)
628+
w.Write(responseBytes)
629+
}))
630+
defer mockServer.Close()
631+
632+
client, err := replicate.NewClient(
633+
replicate.WithToken("test-token"),
634+
replicate.WithBaseURL(mockServer.URL),
635+
)
636+
require.NotNil(t, client)
637+
require.NoError(t, err)
638+
639+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
640+
defer cancel()
641+
642+
prediction, err := client.CancelPrediction(ctx, "ufawqhfynnddngldkgtslldrkq")
643+
assert.NoError(t, err)
644+
assert.Equal(t, "ufawqhfynnddngldkgtslldrkq", prediction.ID)
645+
assert.Equal(t, replicate.Canceled, prediction.Status)
646+
}
647+
613648
func TestPredictionProgress(t *testing.T) {
614649
prediction := replicate.Prediction{
615650
ID: "ufawqhfynnddngldkgtslldrkq",

prediction.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,13 @@ func (r *Client) GetPrediction(ctx context.Context, id string) (*Prediction, err
141141
}
142142
return prediction, nil
143143
}
144+
145+
// CancelPrediction cancels a running prediction by its ID.
146+
func (r *Client) CancelPrediction(ctx context.Context, id string) (*Prediction, error) {
147+
prediction := &Prediction{}
148+
err := r.fetch(ctx, "POST", fmt.Sprintf("/predictions/%s/cancel", id), nil, prediction)
149+
if err != nil {
150+
return nil, fmt.Errorf("failed to cancel prediction: %w", err)
151+
}
152+
return prediction, nil
153+
}

training.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,16 @@ func (r *Client) CreateTraining(ctx context.Context, model_owner string, model_n
3131
return training, nil
3232
}
3333

34+
// ListTrainings returns a list of trainings.
35+
func (r *Client) ListTrainings(ctx context.Context) (*Page[Training], error) {
36+
response := &Page[Training]{}
37+
err := r.fetch(ctx, "GET", "/trainings", nil, response)
38+
if err != nil {
39+
return nil, fmt.Errorf("failed to list trainings: %w", err)
40+
}
41+
return response, nil
42+
}
43+
3444
// GetTraining sends a request to the Replicate API to get a training.
3545
func (r *Client) GetTraining(ctx context.Context, trainingID string) (*Training, error) {
3646
training := &Training{}
@@ -52,13 +62,3 @@ func (r *Client) CancelTraining(ctx context.Context, trainingID string) (*Traini
5262

5363
return training, nil
5464
}
55-
56-
// ListTrainings returns a list of trainings.
57-
func (r *Client) ListTrainings(ctx context.Context) (*Page[Training], error) {
58-
response := &Page[Training]{}
59-
err := r.fetch(ctx, "GET", "/trainings", nil, response)
60-
if err != nil {
61-
return nil, fmt.Errorf("failed to list trainings: %w", err)
62-
}
63-
return response, nil
64-
}

0 commit comments

Comments
 (0)