Skip to content

Commit 48b7cbb

Browse files
authored
Add support for deployments.get endpoint (#41)
1 parent 157dd19 commit 48b7cbb

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed

client_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,3 +1653,69 @@ func TestValidateWebhook(t *testing.T) {
16531653
require.NoError(t, err)
16541654
assert.True(t, isValid)
16551655
}
1656+
1657+
func TestGetDeployment(t *testing.T) {
1658+
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1659+
assert.Equal(t, "/deployments/acme/image-upscaler", r.URL.Path)
1660+
assert.Equal(t, http.MethodGet, r.Method)
1661+
1662+
deployment := &replicate.Deployment{
1663+
Owner: "acme",
1664+
Name: "image-upscaler",
1665+
CurrentRelease: replicate.DeploymentRelease{
1666+
Number: 1,
1667+
Model: "acme/esrgan",
1668+
Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1669+
CreatedAt: "2022-01-01T00:00:00Z",
1670+
CreatedBy: replicate.Account{
1671+
Type: "organization",
1672+
Username: "acme",
1673+
Name: "Acme, Inc.",
1674+
},
1675+
Configuration: replicate.DeploymentConfiguration{
1676+
Hardware: "gpu-t4",
1677+
MinInstances: 1,
1678+
MaxInstances: 5,
1679+
},
1680+
},
1681+
}
1682+
1683+
responseBytes, err := json.Marshal(deployment)
1684+
if err != nil {
1685+
t.Fatal(err)
1686+
}
1687+
1688+
w.WriteHeader(http.StatusOK)
1689+
w.Write(responseBytes)
1690+
}))
1691+
defer mockServer.Close()
1692+
1693+
client, err := replicate.NewClient(
1694+
replicate.WithToken("test-token"),
1695+
replicate.WithBaseURL(mockServer.URL),
1696+
)
1697+
require.NotNil(t, client)
1698+
require.NoError(t, err)
1699+
1700+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1701+
defer cancel()
1702+
1703+
deployment, err := client.GetDeployment(ctx, "acme", "image-upscaler")
1704+
if err != nil {
1705+
t.Fatal(err)
1706+
}
1707+
1708+
assert.NotNil(t, deployment)
1709+
assert.Equal(t, "acme", deployment.Owner)
1710+
assert.Equal(t, "image-upscaler", deployment.Name)
1711+
assert.Equal(t, 1, deployment.CurrentRelease.Number)
1712+
assert.Equal(t, "acme/esrgan", deployment.CurrentRelease.Model)
1713+
assert.Equal(t, "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", deployment.CurrentRelease.Version)
1714+
assert.Equal(t, "2022-01-01T00:00:00Z", deployment.CurrentRelease.CreatedAt)
1715+
assert.Equal(t, "organization", deployment.CurrentRelease.CreatedBy.Type)
1716+
assert.Equal(t, "acme", deployment.CurrentRelease.CreatedBy.Username)
1717+
assert.Equal(t, "Acme, Inc.", deployment.CurrentRelease.CreatedBy.Name)
1718+
assert.Equal(t, "gpu-t4", deployment.CurrentRelease.Configuration.Hardware)
1719+
assert.Equal(t, 1, deployment.CurrentRelease.Configuration.MinInstances)
1720+
assert.Equal(t, 5, deployment.CurrentRelease.Configuration.MaxInstances)
1721+
}

deployment.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,61 @@ package replicate
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
)
78

9+
type Deployment struct {
10+
Owner string `json:"owner"`
11+
Name string `json:"name"`
12+
CurrentRelease DeploymentRelease `json:"current_release"`
13+
14+
rawJSON json.RawMessage `json:"-"`
15+
}
16+
17+
type DeploymentRelease struct {
18+
Number int `json:"number"`
19+
Model string `json:"model"`
20+
Version string `json:"version"`
21+
CreatedAt string `json:"created_at"`
22+
CreatedBy Account `json:"created_by"`
23+
Configuration DeploymentConfiguration `json:"configuration"`
24+
}
25+
26+
type DeploymentConfiguration struct {
27+
Hardware string `json:"hardware"`
28+
MinInstances int `json:"min_instances"`
29+
MaxInstances int `json:"max_instances"`
30+
}
31+
32+
func (d Deployment) MarshalJSON() ([]byte, error) {
33+
if d.rawJSON != nil {
34+
return d.rawJSON, nil
35+
} else {
36+
type Alias Deployment
37+
return json.Marshal(&struct{ *Alias }{Alias: (*Alias)(&d)})
38+
}
39+
}
40+
41+
func (d *Deployment) UnmarshalJSON(data []byte) error {
42+
d.rawJSON = data
43+
type Alias Deployment
44+
alias := &struct{ *Alias }{Alias: (*Alias)(d)}
45+
return json.Unmarshal(data, alias)
46+
}
47+
48+
// GetDeployment retrieves the details of a specific deployment.
49+
func (r *Client) GetDeployment(ctx context.Context, deployment_owner string, deployment_name string) (*Deployment, error) {
50+
deployment := &Deployment{}
51+
path := fmt.Sprintf("/deployments/%s/%s", deployment_owner, deployment_name)
52+
err := r.fetch(ctx, "GET", path, nil, deployment)
53+
if err != nil {
54+
return nil, fmt.Errorf("failed to get deployment: %w", err)
55+
}
56+
57+
return deployment, nil
58+
}
59+
860
// CreateDeploymentPrediction sends a request to the Replicate API to create a prediction using the specified deployment.
961
func (r *Client) CreatePredictionWithDeployment(ctx context.Context, deployment_owner string, deployment_name string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error) {
1062
data := map[string]interface{}{

0 commit comments

Comments
 (0)