From 096a1ef9d5aa9c4cf12c17ec425022ea7844c43b Mon Sep 17 00:00:00 2001 From: Henry Paulino Date: Fri, 24 Apr 2026 16:01:48 +0100 Subject: [PATCH] refactor(ios): unify generation paths via GenerationSink protocol Collapse performGeneration and performGenerationWithEvents into a single engine parameterized by a GenerationSink. StringGenerationSink and EventGenerationSink each own their streaming specifics (token batching, thinking state machine, tool-call events, zero-width-space force-flush) while sharing the generation loop, tool-execution task group, and history management. Public API is unchanged. Behavioral note: the ThinkingStateMachine now lives on EventGenerationSink and is therefore shared across tool-call continuations of a request; previously performGenerationWithEvents created a fresh one on every (recursive) call. --- package/ios/Sources/HybridLLM.swift | 394 +++++++++++----------------- 1 file changed, 156 insertions(+), 238 deletions(-) diff --git a/package/ios/Sources/HybridLLM.swift b/package/ios/Sources/HybridLLM.swift index c4fee34..58ddd18 100644 --- a/package/ios/Sources/HybridLLM.swift +++ b/package/ios/Sources/HybridLLM.swift @@ -179,6 +179,134 @@ private final class HybridLLMCore { } } + private protocol GenerationSink: AnyObject { + var firstTokenTime: Date? { get } + func ingest(chunk: String) -> String + func flush() + func finalizeStream() -> String + func registerToolCall(name: String, arguments: String) -> String + func willExecuteTool(id: String) + func didCompleteTool(id: String, result: String) + func didFailTool(id: String, error: String) + func willContinueAfterTools() + } + + private final class StringGenerationSink: GenerationSink { + private let batcher: TokenBatcher + private let onToolCall: (String, String) -> Void + private(set) var firstTokenTime: Date? 
+ + init(batcher: TokenBatcher, onToolCall: @escaping (String, String) -> Void) { + self.batcher = batcher + self.onToolCall = onToolCall + } + + func ingest(chunk: String) -> String { + if !chunk.isEmpty && firstTokenTime == nil { + firstTokenTime = Date() + } + batcher.append(chunk) + return chunk + } + + func flush() { + batcher.flush() + } + + func finalizeStream() -> String { "" } + + func registerToolCall(name: String, arguments: String) -> String { + onToolCall(name, arguments) + return UUID().uuidString + } + + func willExecuteTool(id: String) {} + func didCompleteTool(id: String, result: String) {} + func didFailTool(id: String, error: String) {} + + func willContinueAfterTools() { + batcher.flush() + if firstTokenTime == nil { + firstTokenTime = Date() + } + batcher.append("\u{200B}") + } + } + + private final class EventGenerationSink: GenerationSink { + private let emitter: StreamEventEmitter + private let batcher: TokenBatcher + private var thinkingMachine = ThinkingStateMachine() + private(set) var firstTokenTime: Date? 
+ + init(emitter: StreamEventEmitter, batcher: TokenBatcher) { + self.emitter = emitter + self.batcher = batcher + } + + func ingest(chunk: String) -> String { + var result = "" + for out in thinkingMachine.process(token: chunk) { + result += emit(out) + } + return result + } + + func flush() { + batcher.flush() + } + + func finalizeStream() -> String { + var result = "" + for out in thinkingMachine.flush() { + result += emit(out) + } + batcher.flush() + return result + } + + func registerToolCall(name: String, arguments: String) -> String { + let id = UUID().uuidString + emitter.emitToolCallStart(id: id, name: name, arguments: arguments) + return id + } + + func willExecuteTool(id: String) { + emitter.emitToolCallExecuting(id: id) + } + + func didCompleteTool(id: String, result: String) { + emitter.emitToolCallCompleted(id: id, result: result) + } + + func didFailTool(id: String, error: String) { + emitter.emitToolCallFailed(id: id, error: error) + } + + func willContinueAfterTools() {} + + private func emit(_ output: ThinkingStateMachine.Output) -> String { + switch output { + case .token(let token): + if !token.isEmpty && firstTokenTime == nil { + firstTokenTime = Date() + } + batcher.append(token) + return token + case .thinkingStart: + batcher.flush() + emitter.emitThinkingStart() + case .thinkingChunk(let chunk): + batcher.flush() + emitter.emitThinkingChunk(chunk) + case .thinkingEnd(let content): + batcher.flush() + emitter.emitThinkingEnd(content) + } + return "" + } + } + private struct ManagedSessionResult { let output: String let generationTokenCount: Int @@ -683,10 +811,11 @@ private final class HybridLLMCore { } var history = messageHistory - var firstTokenTime: Date? 
var generationTokenCount = 0 var generationTimeMs: Double = 0 var toolExecutionTime: Double = 0 + let batcher = TokenBatcher(batchSize: tokenBatchSize, emit: { _ in }) + let sink = StringGenerationSink(batcher: batcher, onToolCall: { _, _ in }) let result = try await performGeneration( container: container, @@ -694,17 +823,11 @@ private final class HybridLLMCore { prompt: prompt, toolResults: nil, depth: 0, - onToken: { token in - if !token.isEmpty && firstTokenTime == nil { - firstTokenTime = Date() - } - }, - flushOutput: {}, + sink: sink, onGenerationInfo: { tokens, time in generationTokenCount += tokens generationTimeMs += time }, - onToolCall: { _, _ in }, toolExecutionTime: &toolExecutionTime ) @@ -712,7 +835,7 @@ private final class HybridLLMCore { lastStats = makeStats( startTime: startTime, - firstTokenTime: firstTokenTime, + firstTokenTime: sink.firstTokenTime, generationTokenCount: generationTokenCount, generationTimeMs: generationTimeMs, toolExecutionTimeMs: toolExecutionTime @@ -765,10 +888,13 @@ private final class HybridLLMCore { } var history = messageHistory - var firstTokenTime: Date? var generationTokenCount = 0 var generationTimeMs: Double = 0 var toolExecutionTime: Double = 0 + let sink = StringGenerationSink( + batcher: batcher, + onToolCall: onToolCall ?? { _, _ in } + ) let result = try await performGeneration( container: container, @@ -776,20 +902,11 @@ private final class HybridLLMCore { prompt: prompt, toolResults: nil, depth: 0, - onToken: { token in - if !token.isEmpty && firstTokenTime == nil { - firstTokenTime = Date() - } - batcher.append(token) - }, - flushOutput: { - batcher.flush() - }, + sink: sink, onGenerationInfo: { tokens, time in generationTokenCount += tokens generationTimeMs += time }, - onToolCall: onToolCall ?? 
{ _, _ in }, toolExecutionTime: &toolExecutionTime ) @@ -798,7 +915,7 @@ private final class HybridLLMCore { let stats = makeStats( startTime: startTime, - firstTokenTime: firstTokenTime, + firstTokenTime: sink.firstTokenTime, generationTokenCount: generationTokenCount, generationTimeMs: generationTimeMs, toolExecutionTimeMs: toolExecutionTime @@ -859,30 +976,21 @@ private final class HybridLLMCore { } var history = messageHistory - var firstTokenTime: Date? var generationTokenCount = 0 var generationTimeMs: Double = 0 var toolExecutionTime: Double = 0 let tokenBatcher = TokenBatcher(batchSize: tokenBatchSize) { token in emitter.emitToken(token) } + let sink = EventGenerationSink(emitter: emitter, batcher: tokenBatcher) - let result = try await performGenerationWithEvents( + let result = try await performGeneration( container: container, history: &history, prompt: prompt, toolResults: nil, depth: 0, - emitter: emitter, - emitToken: { token in - if !token.isEmpty && firstTokenTime == nil { - firstTokenTime = Date() - } - tokenBatcher.append(token) - }, - flushTokenBatch: { - tokenBatcher.flush() - }, + sink: sink, onGenerationInfo: { tokens, time in generationTokenCount += tokens generationTimeMs += time @@ -895,7 +1003,7 @@ private final class HybridLLMCore { let stats = makeStats( startTime: startTime, - firstTokenTime: firstTokenTime, + firstTokenTime: sink.firstTokenTime, generationTokenCount: generationTokenCount, generationTimeMs: generationTimeMs, toolExecutionTimeMs: toolExecutionTime @@ -914,15 +1022,13 @@ private final class HybridLLMCore { return try await task.value } - private func performGenerationWithEvents( + private func performGeneration( container: ModelContainer, history: inout [LLMMessage], prompt: String, toolResults: [String]?, depth: Int, - emitter: StreamEventEmitter, - emitToken: @escaping (String) -> Void, - flushTokenBatch: @escaping () -> Void, + sink: GenerationSink, onGenerationInfo: @escaping (Int, Double) -> Void, toolExecutionTime: 
inout Double ) async throws -> String { @@ -932,8 +1038,7 @@ private final class HybridLLMCore { } var output = "" - var thinkingMachine = ThinkingStateMachine() - var pendingToolCalls: [(id: String, tool: ToolDefinition, args: [String: Any], argsJson: String)] = [] + var pendingToolCalls: [(id: String, tool: ToolDefinition, args: [String: Any])] = [] let chat = buildChatMessages( history: history, @@ -958,30 +1063,10 @@ private final class HybridLLMCore { switch generation { case .chunk(let text): - let outputs = thinkingMachine.process(token: text) - - for machineOutput in outputs { - switch machineOutput { - case .token(let token): - output += token - emitToken(token) - - case .thinkingStart: - flushTokenBatch() - emitter.emitThinkingStart() - - case .thinkingChunk(let chunk): - flushTokenBatch() - emitter.emitThinkingChunk(chunk) - - case .thinkingEnd(let content): - flushTokenBatch() - emitter.emitThinkingEnd(content) - } - } + output += sink.ingest(chunk: text) case .toolCall(let toolCall): - flushTokenBatch() + sink.flush() log("Tool call detected: \(toolCall.function.name)") guard let tool = tools.first(where: { $0.name == toolCall.function.name }) else { @@ -989,19 +1074,14 @@ private final class HybridLLMCore { continue } - let toolCallId = UUID().uuidString let argsDict = convertToolCallArguments(toolCall.function.arguments) let argsJson = dictionaryToJson(argsDict) + let id = sink.registerToolCall(name: toolCall.function.name, arguments: argsJson) - emitter.emitToolCallStart( - id: toolCallId, - name: toolCall.function.name, - arguments: argsJson - ) - pendingToolCalls.append((id: toolCallId, tool: tool, args: argsDict, argsJson: argsJson)) + pendingToolCalls.append((id: id, tool: tool, args: argsDict)) case .info(let info): - flushTokenBatch() + sink.flush() log( "Generation info: \(info.generationTokenCount) tokens, \(String(format: "%.1f", info.tokensPerSecond)) tokens/s" ) @@ -1012,176 +1092,15 @@ private final class HybridLLMCore { } } - let 
flushOutputs = thinkingMachine.flush() - for machineOutput in flushOutputs { - switch machineOutput { - case .token(let token): - output += token - emitToken(token) - case .thinkingStart: - flushTokenBatch() - emitter.emitThinkingStart() - case .thinkingChunk(let chunk): - flushTokenBatch() - emitter.emitThinkingChunk(chunk) - case .thinkingEnd(let content): - flushTokenBatch() - emitter.emitThinkingEnd(content) - } - } - - flushTokenBatch() + output += sink.finalizeStream() if !pendingToolCalls.isEmpty { log("Executing \(pendingToolCalls.count) tool call(s)") let toolStartTime = Date() for call in pendingToolCalls { - emitter.emitToolCallExecuting(id: call.id) - } - - let allToolResults: [String] = await withTaskGroup(of: (Int, String).self) { group in - for (index, call) in pendingToolCalls.enumerated() { - group.addTask { [self] in - do { - let resultJson = try await executeToolCall( - tool: call.tool, - argsDict: call.args - ) - await log("Tool result for \(call.tool.name): \(resultJson.prefix(100))...") - emitter.emitToolCallCompleted(id: call.id, result: resultJson) - return (index, resultJson) - } catch { - await log("Tool execution error for \(call.tool.name): \(error)") - emitter.emitToolCallFailed(id: call.id, error: error.localizedDescription) - return (index, "{\"error\": \"Tool execution failed\"}") - } - } - } - - var results = Array(repeating: "", count: pendingToolCalls.count) - for await (index, result) in group { - results[index] = result - } - return results - } - - toolExecutionTime += Date().timeIntervalSince(toolStartTime) * 1000 - - if depth == 0 { - history.append(LLMMessage(role: "user", content: prompt)) - } - if !output.isEmpty { - history.append(LLMMessage(role: "assistant", content: output)) - } - for result in allToolResults { - history.append(LLMMessage(role: "tool", content: result)) - } - - let continuation = try await performGenerationWithEvents( - container: container, - history: &history, - prompt: prompt, - toolResults: 
allToolResults, - depth: depth + 1, - emitter: emitter, - emitToken: emitToken, - flushTokenBatch: flushTokenBatch, - onGenerationInfo: onGenerationInfo, - toolExecutionTime: &toolExecutionTime - ) - - return output + continuation - } - - if manageHistory { - if depth == 0 { - history.append(LLMMessage(role: "user", content: prompt)) - } - if !output.isEmpty { - history.append(LLMMessage(role: "assistant", content: output)) + sink.willExecuteTool(id: call.id) } - } - - return output - } - - private func performGeneration( - container: ModelContainer, - history: inout [LLMMessage], - prompt: String, - toolResults: [String]?, - depth: Int, - onToken: @escaping (String) -> Void, - flushOutput: @escaping () -> Void, - onGenerationInfo: @escaping (Int, Double) -> Void, - onToolCall: @escaping (String, String) -> Void, - toolExecutionTime: inout Double - ) async throws -> String { - if depth >= maxToolCallDepth { - log("Max tool call depth reached (\(maxToolCallDepth))") - return "" - } - - var output = "" - var pendingToolCalls: [(tool: ToolDefinition, args: [String: Any], argsJson: String)] = [] - - let chat = buildChatMessages( - history: history, - prompt: prompt, - toolResults: toolResults, - depth: depth - ) - let userInput = UserInput(chat: chat, tools: configuredToolSchemas()) - let lmInput = try await container.prepare(input: userInput) - let parameters = generationParameters - - let stream = try await container.perform { context in - try MLXLMCommon.generate( - input: lmInput, - parameters: parameters, - context: context - ) - } - - for await generation in stream { - if Task.isCancelled { break } - - switch generation { - case .chunk(let text): - output += text - onToken(text) - - case .toolCall(let toolCall): - flushOutput() - log("Tool call detected: \(toolCall.function.name)") - - guard let tool = tools.first(where: { $0.name == toolCall.function.name }) else { - log("Unknown tool: \(toolCall.function.name)") - continue - } - - let argsDict = 
convertToolCallArguments(toolCall.function.arguments) - let argsJson = dictionaryToJson(argsDict) - - pendingToolCalls.append((tool: tool, args: argsDict, argsJson: argsJson)) - onToolCall(toolCall.function.name, argsJson) - - case .info(let info): - flushOutput() - log( - "Generation info: \(info.generationTokenCount) tokens, \(String(format: "%.1f", info.tokensPerSecond)) tokens/s" - ) - let generationTime = info.tokensPerSecond > 0 - ? Double(info.generationTokenCount) / info.tokensPerSecond * 1000 - : 0 - onGenerationInfo(info.generationTokenCount, generationTime) - } - } - - if !pendingToolCalls.isEmpty { - log("Executing \(pendingToolCalls.count) tool call(s)") - let toolStartTime = Date() let allToolResults: [String] = await withTaskGroup(of: (Int, String).self) { group in for (index, call) in pendingToolCalls.enumerated() { @@ -1192,9 +1111,11 @@ private final class HybridLLMCore { argsDict: call.args ) await log("Tool result for \(call.tool.name): \(resultJson.prefix(100))...") + sink.didCompleteTool(id: call.id, result: resultJson) return (index, resultJson) } catch { await log("Tool execution error for \(call.tool.name): \(error)") + sink.didFailTool(id: call.id, error: error.localizedDescription) return (index, "{\"error\": \"Tool execution failed\"}") } } @@ -1219,8 +1140,7 @@ private final class HybridLLMCore { history.append(LLMMessage(role: "tool", content: result)) } - flushOutput() - onToken("\u{200B}") + sink.willContinueAfterTools() let continuation = try await performGeneration( container: container, @@ -1228,10 +1148,8 @@ private final class HybridLLMCore { prompt: prompt, toolResults: allToolResults, depth: depth + 1, - onToken: onToken, - flushOutput: flushOutput, + sink: sink, onGenerationInfo: onGenerationInfo, - onToolCall: onToolCall, toolExecutionTime: &toolExecutionTime )