Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 130 additions & 11 deletions example/app/(tabs)/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
ActivityIndicator,
LayoutAnimation,
Platform,
ScrollView,
StyleSheet,
Text,
TextInput,
Expand Down Expand Up @@ -242,6 +243,8 @@ export default function ChatScreen() {
const [isReady, setIsReady] = useState(false)
const [prompt, setPrompt] = useState('')
const [isGenerating, setIsGenerating] = useState(false)
const [isRunningTrimDebug, setIsRunningTrimDebug] = useState(false)
const [trimDebugTurn, setTrimDebugTurn] = useState(0)
const colorScheme = useColorScheme()
const textColor = colorScheme === 'dark' ? 'white' : 'black'
const bgColor = colorScheme === 'dark' ? 'black' : 'white'
Expand Down Expand Up @@ -350,7 +353,7 @@ export default function ChatScreen() {
}

loadModel()
}, [isDownloaded, isReady])
}, [isDownloaded, isReady, refreshManifest])

const sendPrompt = async () => {
if (!isReady || !prompt.trim() || isGenerating) return
Expand Down Expand Up @@ -562,6 +565,66 @@ export default function ChatScreen() {
}
}

const runHistoryTrimDebugTest = async () => {
if (!isDownloaded || isGenerating || isRunningTrimDebug) return

setIsRunningTrimDebug(true)
setTrimDebugTurn(0)
setIsLoading(true)
setLoadProgress(0)
setIsReady(false)
setMessages([])
isLoadingRef.current = true

try {
console.log('[HistoryTrimDebug] Starting managed-history trim test')
LLM.unload()
LLM.systemPrompt = 'You are a concise assistant.'
await LLM.load(MODEL_ID, {
onProgress: setLoadProgress,
manageHistory: true,
tools: [weatherTool],
generationConfig: {
maxTokens: 8,
},
contextConfig: {
maxContextTokens: 512,
keepLastMessages: 4,
},
})

setIsReady(true)
await refreshManifest()

for (let index = 0; index < 10; index += 1) {
setTrimDebugTurn(index + 1)
const promptText = [
`History trim debug turn ${index + 1}.`,
'Reply with only the turn number.',
'Padding:',
'alpha beta gamma delta epsilon zeta eta theta iota kappa '.repeat(80),
].join(' ')

await LLM.generate(promptText)
const history = LLM.getHistory()
console.log(
`[HistoryTrimDebug] turn ${index + 1}: ${history.length} managed message(s)`,
)
}

const history = LLM.getHistory()
console.log('[HistoryTrimDebug] Final managed history:', history)
syncFromHistory()
} catch (error) {
console.error('[HistoryTrimDebug] Failed:', error)
} finally {
setIsLoading(false)
setIsRunningTrimDebug(false)
setTrimDebugTurn(0)
isLoadingRef.current = false
}
}

useEffect(() => {
if (isReady) {
syncFromHistory()
Expand Down Expand Up @@ -597,7 +660,11 @@ export default function ChatScreen() {
<SafeAreaView style={[styles.centered, { backgroundColor: bgColor }]}>
<ActivityIndicator size="large" />
<Text style={[styles.statusText, { color: textColor }]}>
Loading model... {(loadProgress * 100).toFixed(0)}%
{isRunningTrimDebug
? trimDebugTurn > 0
? `Running trim test... turn ${trimDebugTurn} of 10`
: `Preparing trim test... ${(loadProgress * 100).toFixed(0)}%`
: `Loading model... ${(loadProgress * 100).toFixed(0)}%`}
</Text>
</SafeAreaView>
)
Expand All @@ -619,14 +686,36 @@ export default function ChatScreen() {
{ borderBottomColor: colorScheme === 'dark' ? '#333' : '#eee' },
]}
>
<TouchableOpacity onPress={openSettings}>
<Text style={[styles.headerButton, { color: '#007AFF' }]}>Benchmark</Text>
</TouchableOpacity>
<Text style={[styles.headerTitle, { color: textColor }]}>MLX Chat</Text>
<View style={styles.headerButtons}>
<View style={styles.headerTopRow}>
<TouchableOpacity style={styles.benchmarkLink} onPress={openSettings}>
<Text style={[styles.headerButton, { color: '#007AFF' }]}>Benchmark</Text>
</TouchableOpacity>
<Text numberOfLines={1} style={[styles.headerTitle, { color: textColor }]}>
MLX Chat
</Text>
<View style={styles.headerTopSpacer} />
</View>
<ScrollView
horizontal
showsHorizontalScrollIndicator={false}
contentContainerStyle={styles.headerButtons}
style={styles.headerActionsRail}
>
<TouchableOpacity style={styles.historyButton} onPress={logHistory}>
<Text style={styles.historyButtonText}>Log</Text>
</TouchableOpacity>
<TouchableOpacity
style={[
styles.trimDebugButton,
isRunningTrimDebug && styles.headerActionDisabled,
]}
onPress={runHistoryTrimDebugTest}
disabled={isRunningTrimDebug}
>
<Text style={styles.trimDebugButtonText}>
{isRunningTrimDebug ? '...' : 'Trim'}
</Text>
</TouchableOpacity>
<TouchableOpacity style={styles.manifestButton} onPress={refreshManifest}>
<Text style={styles.manifestButtonText}>Manifest</Text>
</TouchableOpacity>
Expand All @@ -636,7 +725,7 @@ export default function ChatScreen() {
<TouchableOpacity style={styles.deleteButton} onPress={deleteModel}>
<Text style={styles.deleteButtonText}>Delete</Text>
</TouchableOpacity>
</View>
</ScrollView>
</View>

<View
Expand Down Expand Up @@ -706,21 +795,37 @@ const styles = StyleSheet.create({
padding: 20,
},
header: {
padding: 16,
paddingHorizontal: 16,
paddingVertical: 12,
borderBottomWidth: 1,
alignItems: 'center',
gap: 10,
},
headerTopRow: {
flexDirection: 'row',
justifyContent: 'space-between',
alignItems: 'center',
gap: 12,
},
headerTitle: {
flex: 1,
fontSize: 18,
fontWeight: '600',
textAlign: 'center',
},
headerButton: {
fontSize: 14,
fontWeight: '500',
},
benchmarkLink: {
minWidth: 80,
},
headerTopSpacer: {
width: 80,
},
headerActionsRail: {
marginHorizontal: -16,
},
headerButtons: {
paddingHorizontal: 16,
flexDirection: 'row',
gap: 6,
},
Expand All @@ -746,6 +851,20 @@ const styles = StyleSheet.create({
fontSize: 12,
fontWeight: '600',
},
trimDebugButton: {
paddingHorizontal: 10,
paddingVertical: 6,
borderRadius: 8,
backgroundColor: '#5856D6',
},
trimDebugButtonText: {
color: 'white',
fontSize: 12,
fontWeight: '600',
},
headerActionDisabled: {
opacity: 0.5,
},
clearButton: {
paddingHorizontal: 10,
paddingVertical: 6,
Expand Down
2 changes: 1 addition & 1 deletion example/app/download-modal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ export default function DownloadModal() {
}

downloadModel()
}, [])
}, [MODEL_ID])

return (
<View style={[styles.container, { backgroundColor: bgColor }]}>
Expand Down
2 changes: 1 addition & 1 deletion example/metro.config.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
const { getDefaultConfig } = require('expo/metro-config')
const path = require('path')
const path = require('node:path')

const projectRoot = __dirname
const monorepoRoot = path.resolve(projectRoot, '../..')
Expand Down
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"postinstall": "tsc -p ./package --noEmit || exit 0;",
"typescript": "bun tsc -p ./package --noEmit",
"test": "bun --cwd ./package test",
"test:ios-history-trim": "cd package && bun run test:ios-history-trim",
"clean": "rm -rf package/tsconfig.tsbuildinfo node_modules example/node_modules example/ios package/node_modules package/lib example/.expo",
"specs": "bun --cwd ./package specs",
"specs:pod": "bun --cwd ./package specs && cd example/ios && pod install && cd ../../",
Expand Down
39 changes: 25 additions & 14 deletions package/ios/Sources/HybridLLM.swift
Original file line number Diff line number Diff line change
Expand Up @@ -405,41 +405,52 @@ private final class HybridLLMCore {
minimum: 0
) ?? defaultKeepLastMessages

var tokenizationPasses = 0

func tokenCount(for history: [LLMMessage]) async throws -> Int {
tokenizationPasses += 1
let input = try await container.prepare(
input: makeUserInput(history: history, prompt: upcomingPrompt)
)
return input.text.tokens.size
}

var trimmedHistory = messageHistory
let initialTokenCount = try await tokenCount(for: trimmedHistory)
let originalHistory = messageHistory
let initialTokenCount = try await tokenCount(for: originalHistory)

guard initialTokenCount > maxContextTokens else { return }

while trimmedHistory.count > keepLastMessages {
trimmedHistory.removeFirst()

if try await tokenCount(for: trimmedHistory) <= maxContextTokens {
break
}
}

guard trimmedHistory.count != messageHistory.count else {
let maxRemovableMessages = max(0, originalHistory.count - keepLastMessages)
guard maxRemovableMessages > 0 else {
log(
"Context remains above the configured limit (\(maxContextTokens) tokens); pinned and recent messages were preserved"
)
return
}

let removedCount = messageHistory.count - trimmedHistory.count
guard let trimPlan = try await ManagedHistoryTrimPlanner.plan(
initialTokenCount: initialTokenCount,
maxContextTokens: maxContextTokens,
maxRemovableMessages: maxRemovableMessages,
tokenCountAfterRemoving: { removalCount in
try await tokenCount(
for: Array(originalHistory.dropFirst(removalCount))
)
}
) else {
return
}

let removedCount = trimPlan.removalCount
let trimmedHistory = Array(originalHistory.dropFirst(removedCount))

messageHistory = trimmedHistory
log(
"Trimmed \(removedCount) message(s) from managed history to stay within \(maxContextTokens) prompt tokens"
"Trimmed \(removedCount) message(s) from managed history to stay within \(maxContextTokens) prompt tokens after \(tokenizationPasses) tokenization pass(es)"
)
rebuildManagedSession()

if try await tokenCount(for: trimmedHistory) > maxContextTokens {
if !trimPlan.fitsBudget {
log(
"Context still exceeds \(maxContextTokens) tokens after trimming because preserved messages alone are larger than the budget"
)
Expand Down
64 changes: 64 additions & 0 deletions package/ios/Sources/ManagedHistoryTrimPlanner.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import Foundation

struct ManagedHistoryTrimPlan {
let removalCount: Int
let tokenCount: Int
let fitsBudget: Bool
}

enum ManagedHistoryTrimPlanner {
static func plan(
initialTokenCount: Int,
maxContextTokens: Int,
maxRemovableMessages: Int,
tokenCountAfterRemoving: (Int) async throws -> Int
) async throws -> ManagedHistoryTrimPlan? {
guard initialTokenCount > maxContextTokens else { return nil }
guard maxRemovableMessages > 0 else { return nil }

var tokenCountCache: [Int: Int] = [0: initialTokenCount]

func tokenCount(afterRemoving removalCount: Int) async throws -> Int {
if let cached = tokenCountCache[removalCount] {
return cached
}

let count = try await tokenCountAfterRemoving(removalCount)
tokenCountCache[removalCount] = count
return count
}

var lowerBound = 1
var upperBound = maxRemovableMessages
var fittingRemovalCount: Int?
var fittingTokenCount: Int?

while lowerBound <= upperBound {
let removalCount = lowerBound + (upperBound - lowerBound) / 2
let count = try await tokenCount(afterRemoving: removalCount)

if count <= maxContextTokens {
fittingRemovalCount = removalCount
fittingTokenCount = count
upperBound = removalCount - 1
} else {
lowerBound = removalCount + 1
}
}

if let fittingRemovalCount, let fittingTokenCount {
return ManagedHistoryTrimPlan(
removalCount: fittingRemovalCount,
tokenCount: fittingTokenCount,
fitsBudget: true
)
}

let finalTokenCount = try await tokenCount(afterRemoving: maxRemovableMessages)
return ManagedHistoryTrimPlan(
removalCount: maxRemovableMessages,
tokenCount: finalTokenCount,
fitsBudget: false
)
}
}
Loading
Loading