Skip to content

Commit fd7636e

Browse files
Restructure files a bit
1 parent 089b288 commit fd7636e

File tree

9 files changed

+323
-289
lines changed

9 files changed

+323
-289
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
<img src="logo.jpg" alt="Logo" width="300" align="right">
44

5-
[![CI](https://github.com/maragudk/openai/actions/workflows/ci.yml/badge.svg)](https://github.com/maragudk/openai/actions/workflows/ci.yml)
5+
[![CI](https://github.com/maragudk/gai-openai/actions/workflows/ci.yml/badge.svg)](https://github.com/maragudk/gai-openai/actions/workflows/ci.yml)
66

77
[GAI](https://github.com/maragudk/gai) client for [OpenAI](https://openai.com) models and compatible APIs, such as [LlamaCPP](https://github.com/ggml-org/llama.cpp).
88

chat_complete.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
package openai
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log/slog"
7+
8+
"github.com/openai/openai-go"
9+
"maragu.dev/gai"
10+
)
11+
12+
// ChatCompleteModel is the name of an OpenAI chat completion model.
type ChatCompleteModel string

// Supported chat completion models, aliased from the underlying SDK.
const (
	ChatCompleteModelGPT4o     = ChatCompleteModel(openai.ChatModelGPT4o)
	ChatCompleteModelGPT4oMini = ChatCompleteModel(openai.ChatModelGPT4oMini)
)
18+
19+
// ChatCompleter streams chat completions from the OpenAI API using a fixed model.
// Create one with [Client.NewChatCompleter].
type ChatCompleter struct {
	Client *openai.Client
	log    *slog.Logger
	model  ChatCompleteModel
}
24+
25+
// NewChatCompleterOptions are options for [Client.NewChatCompleter].
type NewChatCompleterOptions struct {
	Model ChatCompleteModel
}
28+
29+
func (c *Client) NewChatCompleter(opts NewChatCompleterOptions) *ChatCompleter {
30+
return &ChatCompleter{
31+
Client: c.Client,
32+
log: c.log,
33+
model: opts.Model,
34+
}
35+
}
36+
37+
// ChatComplete satisfies [gai.ChatCompleter].
// It converts req.Messages to the OpenAI chat completion format, starts a
// streaming completion, and returns a response whose parts iterator yields
// text deltas as they arrive from the API.
//
// Only user messages with text parts are currently supported; any other
// role or part type panics with "not implemented".
func (c *ChatCompleter) ChatComplete(ctx context.Context, req gai.ChatCompleteRequest) (gai.ChatCompleteResponse, error) {
	var messages []openai.ChatCompletionMessageParamUnion

	for _, m := range req.Messages {
		switch m.Role {
		case gai.MessageRoleUser:
			var parts []openai.ChatCompletionContentPartUnionParam
			for _, part := range m.Parts {
				switch part.Type {
				case gai.MessagePartTypeText:
					parts = append(parts, openai.TextPart(part.Text()))
				default:
					panic("not implemented")
				}
			}
			messages = append(messages, openai.UserMessageParts(parts...))

		default:
			panic("not implemented")
		}
	}

	params := openai.ChatCompletionNewParams{
		Messages: openai.F(messages),
		Model:    openai.F(openai.ChatModel(c.model)),
	}

	// Temperature is optional; only forward it when the caller set it.
	if req.Temperature != nil {
		params.Temperature = openai.F(req.Temperature.Float64())
	}

	stream := c.Client.Chat.Completions.NewStreaming(ctx, params)

	return gai.NewChatCompleteResponse(func(yield func(gai.MessagePart, error) bool) {
		defer func() {
			// Closing the stream can fail; log instead of dropping the error,
			// since there is no caller left to return it to.
			if err := stream.Close(); err != nil {
				c.log.Info("Error closing stream", "error", err)
			}
		}()

		// The accumulator aggregates chunks so we can detect when content,
		// a tool call, or a refusal has just finished streaming.
		var acc openai.ChatCompletionAccumulator
		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)

			// Content is complete; stop consuming the stream.
			if _, ok := acc.JustFinishedContent(); ok {
				break
			}

			// Tool calls are not yet surfaced to the consumer; skip the chunk.
			if _, ok := acc.JustFinishedToolCall(); ok {
				continue
				// TODO handle tool call
				// println("Tool call stream finished:", tool.Index, tool.Name, tool.Arguments)
			}

			// A refusal ends the stream with an error for the consumer.
			if refusal, ok := acc.JustFinishedRefusal(); ok {
				yield(gai.MessagePart{}, fmt.Errorf("refusal: %v", refusal))
				return
			}

			// Forward the text delta from the first choice, stopping early
			// if the consumer is no longer interested.
			if len(chunk.Choices) > 0 {
				if !yield(gai.TextMessagePart(chunk.Choices[0].Delta.Content), nil) {
					return
				}
			}
		}

		// Surface any stream error to the consumer as the final yield.
		if err := stream.Err(); err != nil {
			yield(gai.MessagePart{}, err)
		}
	}), nil
}

var _ gai.ChatCompleter = (*ChatCompleter)(nil)

chat_complete_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package openai_test
2+
3+
import (
4+
"testing"
5+
6+
"maragu.dev/gai"
7+
"maragu.dev/is"
8+
9+
openai "maragu.dev/gai-openai"
10+
)
11+
12+
// TestChatCompleter_ChatComplete is an integration test against the live
// OpenAI API (see newClient for how the key is loaded).
func TestChatCompleter_ChatComplete(t *testing.T) {
	t.Run("can chat-complete", func(t *testing.T) {
		c := newClient()

		cc := c.NewChatCompleter(openai.NewChatCompleterOptions{Model: openai.ChatCompleteModelGPT4oMini})

		req := gai.ChatCompleteRequest{
			Messages: []gai.Message{
				gai.NewUserTextMessage("Hi!"),
			},
			// Temperature 0 to make the output as deterministic as possible.
			Temperature: gai.Ptr(gai.Temperature(0)),
		}

		res, err := cc.ChatComplete(t.Context(), req)
		is.NotError(t, err)

		// Concatenate all streamed text parts into the full response.
		var output string
		for part, err := range res.Parts() {
			is.NotError(t, err)
			output += part.Text()
		}
		// NOTE(review): exact-match on model output is brittle — the API does
		// not guarantee determinism even at temperature 0. Confirm this is
		// acceptable, or consider relaxing to a non-empty/contains check.
		is.Equal(t, "Hello! How can I assist you today?", output)
	})
}

client.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package openai
2+
3+
import (
4+
"log/slog"
5+
"strings"
6+
7+
"github.com/openai/openai-go"
8+
"github.com/openai/openai-go/option"
9+
)
10+
11+
// Client wraps the underlying OpenAI API client and is the entry point for
// creating chat completers and embedders. Create one with [NewClient].
type Client struct {
	Client *openai.Client
	log    *slog.Logger
}
15+
16+
// NewClientOptions are options for [NewClient].
type NewClientOptions struct {
	BaseURL string       // optional API base URL; when empty, the SDK default is used
	Key     string       // API key; only applied when non-empty
	Log     *slog.Logger // optional; a discard logger is used when nil
}
21+
22+
func NewClient(opts NewClientOptions) *Client {
23+
if opts.Log == nil {
24+
opts.Log = slog.New(slog.DiscardHandler)
25+
}
26+
27+
var clientOpts []option.RequestOption
28+
29+
if opts.BaseURL != "" {
30+
if !strings.HasSuffix(opts.BaseURL, "/") {
31+
opts.BaseURL += "/"
32+
}
33+
clientOpts = append(clientOpts, option.WithBaseURL(opts.BaseURL))
34+
}
35+
36+
if opts.Key != "" {
37+
clientOpts = append(clientOpts, option.WithAPIKey(opts.Key))
38+
}
39+
40+
return &Client{
41+
Client: openai.NewClient(clientOpts...),
42+
log: opts.Log,
43+
}
44+
}

client_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package openai_test
2+
3+
import (
4+
"testing"
5+
6+
"maragu.dev/env"
7+
"maragu.dev/is"
8+
9+
openai "maragu.dev/gai-openai"
10+
)
11+
12+
func TestNewClient(t *testing.T) {
13+
t.Run("can create a new client with a key", func(t *testing.T) {
14+
client := openai.NewClient(openai.NewClientOptions{Key: "123"})
15+
is.NotNil(t, client)
16+
})
17+
}
18+
19+
func newClient() *openai.Client {
20+
_ = env.Load(".env.test.local")
21+
22+
return openai.NewClient(openai.NewClientOptions{Key: env.GetStringOrDefault("OPENAI_KEY", "")})
23+
}

embed.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package openai
2+
3+
import (
4+
"context"
5+
"log/slog"
6+
7+
"github.com/openai/openai-go"
8+
"github.com/openai/openai-go/shared"
9+
"maragu.dev/errors"
10+
"maragu.dev/gai"
11+
)
12+
13+
// EmbedModel is the name of an OpenAI embedding model.
type EmbedModel string

// Supported embedding models, aliased from the underlying SDK.
const (
	EmbedModelTextEmbedding3Large = EmbedModel(openai.EmbeddingModelTextEmbedding3Large)
	EmbedModelTextEmbedding3Small = EmbedModel(openai.EmbeddingModelTextEmbedding3Small)
)
19+
20+
// Embedder creates embeddings with the OpenAI API using a fixed model and
// dimension count. Create one with [Client.NewEmbedder].
type Embedder struct {
	Client     *openai.Client
	dimensions int
	log        *slog.Logger
	model      EmbedModel
}
26+
27+
// NewEmbedderOptions are options for [Client.NewEmbedder].
type NewEmbedderOptions struct {
	Dimensions int // must be > 0 and within the chosen model's maximum
	Model      EmbedModel
}
31+
32+
func (c *Client) NewEmbedder(opts NewEmbedderOptions) *Embedder {
33+
if opts.Dimensions <= 0 {
34+
panic("dimensions must be greater than 0")
35+
}
36+
37+
switch opts.Model {
38+
case EmbedModelTextEmbedding3Large:
39+
if opts.Dimensions > 3072 {
40+
panic("dimensions must be less than or equal to 3072")
41+
}
42+
case EmbedModelTextEmbedding3Small:
43+
if opts.Dimensions > 1536 {
44+
panic("dimensions must be less than or equal to 1536")
45+
}
46+
}
47+
48+
return &Embedder{
49+
Client: c.Client,
50+
dimensions: opts.Dimensions,
51+
log: c.log,
52+
model: opts.Model,
53+
}
54+
}
55+
56+
// Embed satisfies [gai.Embedder].
// It reads req.Input fully into a string, requests one embedding with the
// configured model and dimension count, and returns it as float64 values.
func (e *Embedder) Embed(ctx context.Context, req gai.EmbedRequest) (gai.EmbedResponse[float64], error) {
	v := gai.ReadAllString(req.Input)

	res, err := e.Client.Embeddings.New(ctx, openai.EmbeddingNewParams{
		Input:          openai.F[openai.EmbeddingNewParamsInputUnion](shared.UnionString(v)),
		Model:          openai.F(openai.EmbeddingModel(e.model)),
		EncodingFormat: openai.F(openai.EmbeddingNewParamsEncodingFormatFloat),
		Dimensions:     openai.F(int64(e.dimensions)),
	})
	if err != nil {
		return gai.EmbedResponse[float64]{}, errors.Wrap(err, "error embedding")
	}
	// Defensive check: only one embedding is requested, but the API returns a list.
	if len(res.Data) == 0 {
		return gai.EmbedResponse[float64]{}, errors.New("no embeddings returned")
	}

	return gai.EmbedResponse[float64]{
		Embedding: res.Data[0].Embedding,
	}, nil
}

var _ gai.Embedder[float64] = (*Embedder)(nil)

embed_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package openai_test
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"maragu.dev/gai"
8+
"maragu.dev/is"
9+
10+
openai "maragu.dev/gai-openai"
11+
)
12+
13+
func TestEmbedder_Embed(t *testing.T) {
14+
t.Run("can embed a text", func(t *testing.T) {
15+
c := newClient()
16+
17+
e := c.NewEmbedder(openai.NewEmbedderOptions{
18+
Model: openai.EmbedModelTextEmbedding3Small,
19+
Dimensions: 1536,
20+
})
21+
22+
req := gai.EmbedRequest{
23+
Input: strings.NewReader("Embed this, please."),
24+
}
25+
26+
res, err := e.Embed(t.Context(), req)
27+
is.NotError(t, err)
28+
29+
is.Equal(t, 1536, len(res.Embedding))
30+
})
31+
}

0 commit comments

Comments
 (0)