diff --git a/example/app/(tabs)/index.tsx b/example/app/(tabs)/index.tsx index 285089a..9528c20 100644 --- a/example/app/(tabs)/index.tsx +++ b/example/app/(tabs)/index.tsx @@ -6,6 +6,7 @@ import { ActivityIndicator, LayoutAnimation, Platform, + ScrollView, StyleSheet, Text, TextInput, @@ -242,6 +243,8 @@ export default function ChatScreen() { const [isReady, setIsReady] = useState(false) const [prompt, setPrompt] = useState('') const [isGenerating, setIsGenerating] = useState(false) + const [isRunningTrimDebug, setIsRunningTrimDebug] = useState(false) + const [trimDebugTurn, setTrimDebugTurn] = useState(0) const colorScheme = useColorScheme() const textColor = colorScheme === 'dark' ? 'white' : 'black' const bgColor = colorScheme === 'dark' ? 'black' : 'white' @@ -350,7 +353,7 @@ export default function ChatScreen() { } loadModel() - }, [isDownloaded, isReady]) + }, [isDownloaded, isReady, refreshManifest]) const sendPrompt = async () => { if (!isReady || !prompt.trim() || isGenerating) return @@ -562,6 +565,66 @@ export default function ChatScreen() { } } + const runHistoryTrimDebugTest = async () => { + if (!isDownloaded || isGenerating || isRunningTrimDebug) return + + setIsRunningTrimDebug(true) + setTrimDebugTurn(0) + setIsLoading(true) + setLoadProgress(0) + setIsReady(false) + setMessages([]) + isLoadingRef.current = true + + try { + console.log('[HistoryTrimDebug] Starting managed-history trim test') + LLM.unload() + LLM.systemPrompt = 'You are a concise assistant.' + await LLM.load(MODEL_ID, { + onProgress: setLoadProgress, + manageHistory: true, + tools: [weatherTool], + generationConfig: { + maxTokens: 8, + }, + contextConfig: { + maxContextTokens: 512, + keepLastMessages: 4, + }, + }) + + setIsReady(true) + await refreshManifest() + + for (let index = 0; index < 10; index += 1) { + setTrimDebugTurn(index + 1) + const promptText = [ + `History trim debug turn ${index + 1}.`, + 'Reply with only the turn number.', + 'Padding:', + 'alpha beta gamma delta epsilon zeta eta theta iota kappa '.repeat(80), + ].join(' ') + + await LLM.generate(promptText) + const history = LLM.getHistory() + console.log( + `[HistoryTrimDebug] turn ${index + 1}: ${history.length} managed message(s)`, + ) + } + + const history = LLM.getHistory() + console.log('[HistoryTrimDebug] Final managed history:', history) + syncFromHistory() + } catch (error) { + console.error('[HistoryTrimDebug] Failed:', error) + } finally { + setIsLoading(false) + setIsRunningTrimDebug(false) + setTrimDebugTurn(0) + isLoadingRef.current = false + } + } + useEffect(() => { if (isReady) { syncFromHistory() @@ -597,7 +660,11 @@ export default function ChatScreen() { - Loading model... {(loadProgress * 100).toFixed(0)}% + {isRunningTrimDebug + ? trimDebugTurn > 0 + ? `Running trim test... turn ${trimDebugTurn} of 10` + : `Preparing trim test... ${(loadProgress * 100).toFixed(0)}%` + : `Loading model... ${(loadProgress * 100).toFixed(0)}%`} ) @@ -619,14 +686,36 @@ export default function ChatScreen() { { borderBottomColor: colorScheme === 'dark' ? '#333' : '#eee' }, ]} > - - Benchmark - - MLX Chat - + + + Benchmark + + + MLX Chat + + + + Log + + + {isRunningTrimDebug ? '...' : 'Trim'} + + Manifest @@ -636,7 +725,7 @@ export default function ChatScreen() { Delete - + diff --git a/example/metro.config.ts b/example/metro.config.ts index bd52243..9b68f6b 100644 --- a/example/metro.config.ts +++ b/example/metro.config.ts @@ -1,5 +1,5 @@ const { getDefaultConfig } = require('expo/metro-config') -const path = require('path') +const path = require('node:path') const projectRoot = __dirname const monorepoRoot = path.resolve(projectRoot, '../..') diff --git a/package.json b/package.json index 4d3a090..8e056f3 100644 --- a/package.json +++ b/package.json @@ -6,6 +6,7 @@ "postinstall": "tsc -p ./package --noEmit || exit 0;", "typescript": "bun tsc -p ./package --noEmit", "test": "bun --cwd ./package test", + "test:ios-history-trim": "cd package && bun run test:ios-history-trim", "clean": "rm -rf package/tsconfig.tsbuildinfo node_modules example/node_modules example/ios package/node_modules package/lib example/.expo", "specs": "bun --cwd ./package specs", "specs:pod": "bun --cwd ./package specs && cd example/ios && pod install && cd ../../", diff --git a/package/ios/Sources/HybridLLM.swift b/package/ios/Sources/HybridLLM.swift index 1d9f868..c4fee34 100644 --- a/package/ios/Sources/HybridLLM.swift +++ b/package/ios/Sources/HybridLLM.swift @@ -405,41 +405,52 @@ private final class HybridLLMCore { minimum: 0 ) ?? defaultKeepLastMessages + var tokenizationPasses = 0 + func tokenCount(for history: [LLMMessage]) async throws -> Int { + tokenizationPasses += 1 let input = try await container.prepare( input: makeUserInput(history: history, prompt: upcomingPrompt) ) return input.text.tokens.size } - var trimmedHistory = messageHistory - let initialTokenCount = try await tokenCount(for: trimmedHistory) + let originalHistory = messageHistory + let initialTokenCount = try await tokenCount(for: originalHistory) guard initialTokenCount > maxContextTokens else { return } - while trimmedHistory.count > keepLastMessages { - trimmedHistory.removeFirst() - - if try await tokenCount(for: trimmedHistory) <= maxContextTokens { - break - } - } - - guard trimmedHistory.count != messageHistory.count else { + let maxRemovableMessages = max(0, originalHistory.count - keepLastMessages) + guard maxRemovableMessages > 0 else { log( "Context remains above the configured limit (\(maxContextTokens) tokens); pinned and recent messages were preserved" ) return } - let removedCount = messageHistory.count - trimmedHistory.count + guard let trimPlan = try await ManagedHistoryTrimPlanner.plan( + initialTokenCount: initialTokenCount, + maxContextTokens: maxContextTokens, + maxRemovableMessages: maxRemovableMessages, + tokenCountAfterRemoving: { removalCount in + try await tokenCount( + for: Array(originalHistory.dropFirst(removalCount)) + ) + } + ) else { + return + } + + let removedCount = trimPlan.removalCount + let trimmedHistory = Array(originalHistory.dropFirst(removedCount)) + messageHistory = trimmedHistory log( - "Trimmed \(removedCount) message(s) from managed history to stay within \(maxContextTokens) prompt tokens" + "Trimmed \(removedCount) message(s) from managed history to stay within \(maxContextTokens) prompt tokens after \(tokenizationPasses) tokenization pass(es)" ) rebuildManagedSession() - if try await tokenCount(for: trimmedHistory) > maxContextTokens { + if !trimPlan.fitsBudget { log( "Context still exceeds \(maxContextTokens) tokens after trimming because preserved messages alone are larger than the budget" ) diff --git a/package/ios/Sources/ManagedHistoryTrimPlanner.swift b/package/ios/Sources/ManagedHistoryTrimPlanner.swift new file mode 100644 index 0000000..ae5bbb8 --- /dev/null +++ b/package/ios/Sources/ManagedHistoryTrimPlanner.swift @@ -0,0 +1,64 @@ +import Foundation + +struct ManagedHistoryTrimPlan { + let removalCount: Int + let tokenCount: Int + let fitsBudget: Bool +} + +enum ManagedHistoryTrimPlanner { + static func plan( + initialTokenCount: Int, + maxContextTokens: Int, + maxRemovableMessages: Int, + tokenCountAfterRemoving: (Int) async throws -> Int + ) async throws -> ManagedHistoryTrimPlan? { + guard initialTokenCount > maxContextTokens else { return nil } + guard maxRemovableMessages > 0 else { return nil } + + var tokenCountCache: [Int: Int] = [0: initialTokenCount] + + func tokenCount(afterRemoving removalCount: Int) async throws -> Int { + if let cached = tokenCountCache[removalCount] { + return cached + } + + let count = try await tokenCountAfterRemoving(removalCount) + tokenCountCache[removalCount] = count + return count + } + + var lowerBound = 1 + var upperBound = maxRemovableMessages + var fittingRemovalCount: Int? + var fittingTokenCount: Int? + + while lowerBound <= upperBound { + let removalCount = lowerBound + (upperBound - lowerBound) / 2 + let count = try await tokenCount(afterRemoving: removalCount) + + if count <= maxContextTokens { + fittingRemovalCount = removalCount + fittingTokenCount = count + upperBound = removalCount - 1 + } else { + lowerBound = removalCount + 1 + } + } + + if let fittingRemovalCount, let fittingTokenCount { + return ManagedHistoryTrimPlan( + removalCount: fittingRemovalCount, + tokenCount: fittingTokenCount, + fitsBudget: true + ) + } + + let finalTokenCount = try await tokenCount(afterRemoving: maxRemovableMessages) + return ManagedHistoryTrimPlan( + removalCount: maxRemovableMessages, + tokenCount: finalTokenCount, + fitsBudget: false + ) + } +} diff --git a/package/ios/Tests/ManagedHistoryTrimPlannerSpyTests.swift b/package/ios/Tests/ManagedHistoryTrimPlannerSpyTests.swift new file mode 100644 index 0000000..3099af8 --- /dev/null +++ b/package/ios/Tests/ManagedHistoryTrimPlannerSpyTests.swift @@ -0,0 +1,90 @@ +import Foundation + +enum TestFailure: Error, CustomStringConvertible { + case failed(String) + + var description: String { + switch self { + case .failed(let message): + return message + } + } +} + +func expect(_ condition: @autoclosure () -> Bool, _ message: String) throws { + if !condition() { + throw TestFailure.failed(message) + } +} + +@main +struct ManagedHistoryTrimPlannerSpyTests { + static func main() async throws { + try await findsSmallestFittingRemovalWithLogarithmicTokenProbes() + try await trimsToMaxRemovableWhenBudgetStillCannotFit() + try await skipsWorkWhenInitialPromptAlreadyFits() + print("ManagedHistoryTrimPlannerSpyTests passed") + } + + private static func findsSmallestFittingRemovalWithLogarithmicTokenProbes() async throws { + var probedRemovalCounts: [Int] = [] + + let plan = try await ManagedHistoryTrimPlanner.plan( + initialTokenCount: 220, + maxContextTokens: 100, + maxRemovableMessages: 16, + tokenCountAfterRemoving: { removalCount in + probedRemovalCounts.append(removalCount) + return 220 - removalCount * 10 + } + ) + + try expect(plan?.removalCount == 12, "expected to remove the smallest fitting prefix") + try expect(plan?.tokenCount == 100, "expected final token count at the budget") + try expect(plan?.fitsBudget == true, "expected plan to fit the token budget") + try expect(probedRemovalCounts.count <= 5, "expected logarithmic probe count") + try expect( + Set(probedRemovalCounts).count == probedRemovalCounts.count, + "expected token-count cache to avoid duplicate probes" + ) + } + + private static func trimsToMaxRemovableWhenBudgetStillCannotFit() async throws { + var probedRemovalCounts: [Int] = [] + + let plan = try await ManagedHistoryTrimPlanner.plan( + initialTokenCount: 500, + maxContextTokens: 100, + maxRemovableMessages: 4, + tokenCountAfterRemoving: { removalCount in + probedRemovalCounts.append(removalCount) + return 500 - removalCount * 20 + } + ) + + try expect(plan?.removalCount == 4, "expected to preserve pinned/recent messages") + try expect(plan?.tokenCount == 420, "expected final count after max removal") + try expect(plan?.fitsBudget == false, "expected budget to remain exceeded") + try expect( + probedRemovalCounts.last == 4, + "expected final max-removal count to be measured for warning state" + ) + } + + private static func skipsWorkWhenInitialPromptAlreadyFits() async throws { + var prepareCalls = 0 + + let plan = try await ManagedHistoryTrimPlanner.plan( + initialTokenCount: 80, + maxContextTokens: 100, + maxRemovableMessages: 16, + tokenCountAfterRemoving: { _ in + prepareCalls += 1 + return 0 + } + ) + + try expect(plan == nil, "expected no trim plan when prompt already fits") + try expect(prepareCalls == 0, "expected no extra tokenization when already in budget") + } +} diff --git a/package/package.json b/package/package.json index e0fc778..0bb82df 100644 --- a/package/package.json +++ b/package/package.json @@ -11,6 +11,7 @@ "build": "rm -rf lib && bun typecheck && bob build", "typecheck": "tsc --noEmit", "test": "bun test src/runtime.test.ts", + "test:ios-history-trim": "swiftc ios/Sources/ManagedHistoryTrimPlanner.swift ios/Tests/ManagedHistoryTrimPlannerSpyTests.swift -o /tmp/ManagedHistoryTrimPlannerSpyTests && /tmp/ManagedHistoryTrimPlannerSpyTests", "clean": "rm -rf android/build node_modules/**/android/build lib android/.cxx node_modules/**/android/.cxx", "release": "release-it", "specs": "bun typecheck && nitrogen --logLevel=\\\"debug\\\" && bun run build",