@@ -5,9 +5,14 @@ import (
55 "encoding/json"
66 "fmt"
77 "log/slog"
8+ "sort"
89 "strings"
910
1011 "github.com/openai/openai-go"
12+ "go.opentelemetry.io/otel"
13+ "go.opentelemetry.io/otel/attribute"
14+ "go.opentelemetry.io/otel/codes"
15+ "go.opentelemetry.io/otel/trace"
1116 "maragu.dev/gai"
1217)
1318
@@ -22,6 +27,7 @@ type ChatCompleter struct {
2227 Client openai.Client
2328 log * slog.Logger
2429 model ChatCompleteModel
30+ tracer trace.Tracer
2531}
2632
2733type NewChatCompleterOptions struct {
@@ -33,15 +39,29 @@ func (c *Client) NewChatCompleter(opts NewChatCompleterOptions) *ChatCompleter {
3339 Client : c .Client ,
3440 log : c .log ,
3541 model : opts .Model ,
42+ tracer : otel .Tracer ("maragu.dev/gai-openai" ),
3643 }
3744}
3845
3946// ChatComplete satisfies [gai.ChatCompleter].
4047func (c * ChatCompleter ) ChatComplete (ctx context.Context , req gai.ChatCompleteRequest ) (gai.ChatCompleteResponse , error ) {
48+ ctx , span := c .tracer .Start (ctx , "openai.chat_complete" ,
49+ trace .WithSpanKind (trace .SpanKindClient ),
50+ trace .WithAttributes (
51+ attribute .String ("ai.model" , string (c .model )),
52+ attribute .Int ("ai.message_count" , len (req .Messages )),
53+ ),
54+ )
55+ defer span .End ()
56+
4157 var messages []openai.ChatCompletionMessageParamUnion
4258
4359 if req .System != nil {
4460 messages = append (messages , openai .SystemMessage (* req .System ))
61+ span .SetAttributes (
62+ attribute .Bool ("ai.has_system_prompt" , true ),
63+ attribute .String ("ai.system_prompt" , * req .System ),
64+ )
4565 }
4666
4767 for _ , m := range req .Messages {
@@ -132,6 +152,7 @@ func (c *ChatCompleter) ChatComplete(ctx context.Context, req gai.ChatCompleteRe
132152 }
133153
134154 var tools []openai.ChatCompletionToolParam
155+ var toolNames []string
135156 for _ , tool := range req .Tools {
136157 tools = append (tools , openai.ChatCompletionToolParam {
137158 Function : openai.FunctionDefinitionParam {
@@ -143,21 +164,33 @@ func (c *ChatCompleter) ChatComplete(ctx context.Context, req gai.ChatCompleteRe
143164 },
144165 },
145166 })
167+ toolNames = append (toolNames , tool .Name )
146168 }
169+ sort .Strings (toolNames )
170+ span .SetAttributes (
171+ attribute .Int ("ai.tool_count" , len (tools )),
172+ attribute .StringSlice ("ai.tools" , toolNames ),
173+ )
147174
148175 params := openai.ChatCompletionNewParams {
149176 Messages : messages ,
150177 Model : openai .ChatModel (c .model ),
151178 Tools : tools ,
179+ StreamOptions : openai.ChatCompletionStreamOptionsParam {
180+ IncludeUsage : openai .Bool (true ),
181+ },
152182 }
153183
154184 if req .Temperature != nil {
155185 params .Temperature = openai .Opt (req .Temperature .Float64 ())
186+ span .SetAttributes (attribute .Float64 ("ai.temperature" , req .Temperature .Float64 ()))
156187 }
157188
158189 stream := c .Client .Chat .Completions .NewStreaming (ctx , params )
159190
160- return gai .NewChatCompleteResponse (func (yield func (gai.MessagePart , error ) bool ) {
191+ meta := & gai.ChatCompleteResponseMetadata {}
192+
193+ res := gai .NewChatCompleteResponse (func (yield func (gai.MessagePart , error ) bool ) {
161194 defer func () {
162195 if err := stream .Close (); err != nil {
163196 c .log .Info ("Error closing stream" , "error" , err )
@@ -169,33 +202,54 @@ func (c *ChatCompleter) ChatComplete(ctx context.Context, req gai.ChatCompleteRe
169202 chunk := stream .Current ()
170203 acc .AddChunk (chunk )
171204
172- if _ , ok := acc .JustFinishedContent (); ok {
173- break
174- }
205+ if _ , ok := acc .JustFinishedContent (); ! ok {
206+ if toolCall , ok := acc .JustFinishedToolCall (); ok {
207+ if ! yield (gai .ToolCallPart (toolCall .ID , toolCall .Name , json .RawMessage (toolCall .Arguments )), nil ) {
208+ return
209+ }
210+ continue
211+ }
175212
176- if toolCall , ok := acc .JustFinishedToolCall (); ok {
177- if ! yield (gai .ToolCallPart (toolCall .ID , toolCall .Name , json .RawMessage (toolCall .Arguments )), nil ) {
213+ if refusal , ok := acc .JustFinishedRefusal (); ok {
214+ err := fmt .Errorf ("refusal: %v" , refusal )
215+ span .RecordError (err )
216+ span .SetStatus (codes .Error , "model refused request" )
217+ yield (gai.MessagePart {}, err )
178218 return
179219 }
180- continue
220+
221+ if len (chunk .Choices ) > 0 {
222+ if ! yield (gai .TextMessagePart (chunk .Choices [0 ].Delta .Content ), nil ) {
223+ return
224+ }
225+ }
181226 }
182227
183- if refusal , ok := acc .JustFinishedRefusal (); ok {
184- yield (gai.MessagePart {}, fmt .Errorf ("refusal: %v" , refusal ))
185- return
228+ if chunk .Usage .PromptTokens == 0 {
229+ continue
186230 }
187231
188- if len (chunk .Choices ) > 0 {
189- if ! yield (gai .TextMessagePart (chunk .Choices [0 ].Delta .Content ), nil ) {
190- return
191- }
232+ meta .Usage = gai.ChatCompleteResponseUsage {
233+ PromptTokens : int (chunk .Usage .PromptTokens ),
234+ CompletionTokens : int (chunk .Usage .CompletionTokens ),
192235 }
236+ span .SetAttributes (
237+ attribute .Int ("ai.prompt_tokens" , int (chunk .Usage .PromptTokens )),
238+ attribute .Int ("ai.completion_tokens" , int (chunk .Usage .CompletionTokens )),
239+ attribute .Int ("ai.total_tokens" , int (chunk .Usage .TotalTokens )),
240+ )
193241 }
194242
195243 if err := stream .Err (); err != nil {
244+ span .RecordError (err )
245+ span .SetStatus (codes .Error , "stream error" )
196246 yield (gai.MessagePart {}, err )
197247 }
198- }), nil
248+ })
249+
250+ res .Meta = meta
251+
252+ return res , nil
199253}
200254
201255// normalizeToolSchemaProperties recursively normalizes schema properties for OpenAI compatibility
0 commit comments