Skip to content

Commit a9c9e3f

Browse files
authored
Merge pull request #471 from MODSetter/dev
feat: added top_k in chat Interface.
2 parents e79845b + 7ed159b commit a9c9e3f

File tree

10 files changed

+217
-5
lines changed

10 files changed

+217
-5
lines changed

surfsense_backend/app/agents/researcher/configuration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class Configuration:
2727
search_mode: SearchMode
2828
document_ids_to_add_in_context: list[int]
2929
language: str | None = None
30+
top_k: int = 10
3031

3132
@classmethod
3233
def from_runnable_config(

surfsense_backend/app/agents/researcher/nodes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,8 +1366,8 @@ async def handle_qna_workflow(
13661366
}
13671367
)
13681368

1369-
# Use a reasonable top_k for QNA - not too many documents to avoid overwhelming the LLM
1370-
top_k = 5 if configuration.search_mode == SearchMode.DOCUMENTS else 20
1369+
# Use the top_k value from configuration
1370+
top_k = configuration.top_k
13711371

13721372
relevant_documents = []
13731373
user_selected_documents = []

surfsense_backend/app/routes/chats_routes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
validate_research_mode,
2525
validate_search_mode,
2626
validate_search_space_id,
27+
validate_top_k,
2728
)
2829

2930
router = APIRouter()
@@ -54,6 +55,7 @@ async def handle_chat_data(
5455
request_data.get("document_ids_to_add_in_context")
5556
)
5657
search_mode_str = validate_search_mode(request_data.get("search_mode"))
58+
top_k = validate_top_k(request_data.get("top_k"))
5759
# print("RESQUEST DATA:", request_data)
5860
# print("SELECTED CONNECTORS:", selected_connectors)
5961

@@ -123,6 +125,7 @@ async def handle_chat_data(
123125
search_mode_str,
124126
document_ids_to_add_in_context,
125127
language,
128+
top_k,
126129
)
127130
)
128131

surfsense_backend/app/tasks/stream_connector_search_results.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ async def stream_connector_search_results(
2121
search_mode_str: str,
2222
document_ids_to_add_in_context: list[int],
2323
language: str | None = None,
24+
top_k: int = 10,
2425
) -> AsyncGenerator[str, None]:
2526
"""
2627
Stream connector search results to the client
@@ -56,6 +57,7 @@ async def stream_connector_search_results(
5657
"search_mode": search_mode,
5758
"document_ids_to_add_in_context": document_ids_to_add_in_context,
5859
"language": language, # Add language to the configuration
60+
"top_k": top_k, # Add top_k to the configuration
5961
}
6062
}
6163
# print(f"Researcher configuration: {config['configurable']}") # Debug print

surfsense_backend/app/utils/validators.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,60 @@ def validate_search_mode(search_mode: Any) -> str:
241241
return normalized_mode
242242

243243

244+
def validate_top_k(top_k: Any) -> int:
    """
    Validate and convert top_k to integer.

    Accepts ``None`` (falls back to the default), a plain ``int``, or a string
    holding a positive decimal integer (surrounding whitespace is ignored).
    The resulting value must lie in the inclusive range 1..100.

    Args:
        top_k: The top_k value to validate

    Returns:
        int: Validated top_k value (defaults to 10 if None)

    Raises:
        HTTPException: If validation fails (always status 400)
    """
    default_top_k = 10
    max_top_k = 100

    if top_k is None:
        return default_top_k

    # bool is a subclass of int, so it has to be rejected before the int check.
    if isinstance(top_k, bool):
        raise HTTPException(
            status_code=400, detail="top_k must be an integer, not a boolean"
        )

    if isinstance(top_k, int):
        value = top_k
    elif isinstance(top_k, str):
        candidate = top_k.strip()
        if not candidate:
            raise HTTPException(status_code=400, detail="top_k cannot be empty")

        if not re.match(r"^[1-9]\d*$", candidate):
            raise HTTPException(
                status_code=400, detail="top_k must be a valid positive integer"
            )

        # The regex forbids leading zeros, so more digits than the maximum has
        # means the value is already too large. Rejecting here keeps arbitrarily
        # long digit strings away from int(), which raises ValueError once the
        # interpreter's integer string conversion limit is exceeded.
        if len(candidate) > len(str(max_top_k)):
            raise HTTPException(status_code=400, detail="top_k must not exceed 100")

        value = int(candidate)
    else:
        raise HTTPException(
            status_code=400,
            detail="top_k must be an integer or string representation of an integer",
        )

    # Single range check shared by the int and str paths.
    if value <= 0:
        raise HTTPException(
            status_code=400, detail="top_k must be a positive integer"
        )
    if value > max_top_k:
        raise HTTPException(status_code=400, detail="top_k must not exceed 100")
    return value
296+
297+
244298
def validate_messages(messages: Any) -> list[dict]:
245299
"""
246300
Validate messages structure.

surfsense_web/app/dashboard/[search_space_id]/researcher/[[...chat_id]]/page.tsx

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ export default function ResearcherPage() {
2727
setSelectedConnectors,
2828
selectedDocuments,
2929
setSelectedDocuments,
30+
topK,
31+
setTopK,
3032
} = useChatState({
3133
search_space_id: search_space_id as string,
3234
chat_id: chatIdParam,
@@ -66,6 +68,7 @@ export default function ResearcherPage() {
6668
selectedConnectors: string[];
6769
searchMode: "DOCUMENTS" | "CHUNKS";
6870
researchMode: "QNA"; // Always QNA mode
71+
topK: number;
6972
}
7073

7174
const getChatStateStorageKey = (searchSpaceId: string, chatId: string) =>
@@ -105,6 +108,7 @@ export default function ResearcherPage() {
105108
research_mode: researchMode,
106109
search_mode: searchMode,
107110
document_ids_to_add_in_context: documentIds,
111+
top_k: topK,
108112
},
109113
},
110114
onError: (error) => {
@@ -124,6 +128,7 @@ export default function ResearcherPage() {
124128
selectedConnectors,
125129
searchMode,
126130
researchMode,
131+
topK,
127132
});
128133
router.replace(`/dashboard/${search_space_id}/researcher/${newChatId}`);
129134
}
@@ -145,10 +150,18 @@ export default function ResearcherPage() {
145150
setSelectedDocuments(restoredState.selectedDocuments);
146151
setSelectedConnectors(restoredState.selectedConnectors);
147152
setSearchMode(restoredState.searchMode);
153+
setTopK(restoredState.topK);
148154
// researchMode is always "QNA", no need to restore
149155
}
150156
}
151-
}, [chatIdParam, search_space_id, setSelectedDocuments, setSelectedConnectors, setSearchMode]);
157+
}, [
158+
chatIdParam,
159+
search_space_id,
160+
setSelectedDocuments,
161+
setSelectedConnectors,
162+
setSearchMode,
163+
setTopK,
164+
]);
152165

153166
// Set all sources as default for new chats
154167
useEffect(() => {
@@ -234,6 +247,8 @@ export default function ResearcherPage() {
234247
selectedConnectors={selectedConnectors}
235248
searchMode={searchMode}
236249
onSearchModeChange={setSearchMode}
250+
topK={topK}
251+
onTopKChange={setTopK}
237252
/>
238253
);
239254
}

surfsense_web/components/chat/ChatInputGroup.tsx

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"use client";
22

33
import { ChatInput } from "@llamaindex/chat-ui";
4-
import { Brain, Check, FolderOpen, Zap } from "lucide-react";
4+
import { Brain, Check, FolderOpen, Minus, Plus, Zap } from "lucide-react";
55
import { useParams } from "next/navigation";
66
import React, { Suspense, useCallback, useState } from "react";
77
import { DocumentsDataTable } from "@/components/chat/DocumentsDataTable";
@@ -15,13 +15,15 @@ import {
1515
DialogTitle,
1616
DialogTrigger,
1717
} from "@/components/ui/dialog";
18+
import { Input } from "@/components/ui/input";
1819
import {
1920
Select,
2021
SelectContent,
2122
SelectItem,
2223
SelectTrigger,
2324
SelectValue,
2425
} from "@/components/ui/select";
26+
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
2527
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
2628
import { useDocumentTypes } from "@/hooks/use-document-types";
2729
import type { Document } from "@/hooks/use-documents";
@@ -447,6 +449,119 @@ const SearchModeSelector = React.memo(
447449

448450
SearchModeSelector.displayName = "SearchModeSelector";
449451

452+
/**
 * Compact stepper for the number of results fetched per data source.
 *
 * The committed value lives in the parent (`topK` / `onTopKChange`). The text
 * field additionally keeps its own draft string so the user can clear it or
 * pass through out-of-range intermediate values while typing; only values
 * inside [MIN_VALUE, MAX_VALUE] are pushed to the parent, and the draft is
 * normalised when the field loses focus.
 */
const TopKSelector = React.memo(
	({ topK = 10, onTopKChange }: { topK?: number; onTopKChange?: (topK: number) => void }) => {
		const MIN_VALUE = 1;
		const MAX_VALUE = 100;
		const DEFAULT_VALUE = 10;

		// Raw text shown in the input. Binding the input straight to `topK`
		// would make React snap the field back whenever a keystroke is not
		// accepted, which makes clearing and retyping impossible.
		const [draft, setDraft] = React.useState<string>(String(topK));

		// Follow changes made outside the text field (buttons, restored state).
		React.useEffect(() => {
			setDraft(String(topK));
		}, [topK]);

		const handleIncrement = React.useCallback(() => {
			if (topK < MAX_VALUE) {
				onTopKChange?.(topK + 1);
			}
		}, [topK, onTopKChange]);

		const handleDecrement = React.useCallback(() => {
			if (topK > MIN_VALUE) {
				onTopKChange?.(topK - 1);
			}
		}, [topK, onTopKChange]);

		const handleInputChange = React.useCallback(
			(e: React.ChangeEvent<HTMLInputElement>) => {
				const value = e.target.value;
				// Always reflect what was typed, even if it is not (yet) valid.
				setDraft(value);
				if (value === "") {
					return;
				}
				const numValue = parseInt(value, 10);
				if (!Number.isNaN(numValue) && numValue >= MIN_VALUE && numValue <= MAX_VALUE) {
					onTopKChange?.(numValue);
				}
			},
			[onTopKChange]
		);

		const handleInputBlur = React.useCallback(
			(e: React.FocusEvent<HTMLInputElement>) => {
				const value = e.target.value;
				let committed: number;
				if (value === "") {
					// Reset to default if empty
					committed = DEFAULT_VALUE;
				} else {
					const numValue = parseInt(value, 10);
					committed = Number.isNaN(numValue)
						? MIN_VALUE
						: Math.min(MAX_VALUE, Math.max(MIN_VALUE, numValue));
				}
				// Normalise the visible text (e.g. "150" -> "100", "" -> "10").
				setDraft(String(committed));
				if (committed !== topK) {
					onTopKChange?.(committed);
				}
			},
			[topK, onTopKChange]
		);

		return (
			<TooltipProvider>
				<Tooltip delayDuration={200}>
					<TooltipTrigger asChild>
						<div className="flex items-center h-8 border rounded-md bg-background hover:bg-accent/50 transition-colors">
							<Button
								type="button"
								variant="ghost"
								size="icon"
								className="h-full w-7 rounded-l-md rounded-r-none hover:bg-accent border-r"
								onClick={handleDecrement}
								disabled={topK <= MIN_VALUE}
								aria-label="Decrease results per source"
							>
								<Minus className="h-3.5 w-3.5" />
							</Button>
							<div className="flex flex-col items-center justify-center px-2 min-w-[60px]">
								<Input
									type="number"
									value={draft}
									onChange={handleInputChange}
									onBlur={handleInputBlur}
									min={MIN_VALUE}
									max={MAX_VALUE}
									aria-label="Results per source"
									className="h-5 w-full px-1 text-center text-sm font-semibold border-0 bg-transparent focus-visible:ring-0 focus-visible:ring-offset-0 [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none"
								/>
								<span className="text-[10px] text-muted-foreground leading-none">Results</span>
							</div>
							<Button
								type="button"
								variant="ghost"
								size="icon"
								className="h-full w-7 rounded-r-md rounded-l-none hover:bg-accent border-l"
								onClick={handleIncrement}
								disabled={topK >= MAX_VALUE}
								aria-label="Increase results per source"
							>
								<Plus className="h-3.5 w-3.5" />
							</Button>
						</div>
					</TooltipTrigger>
					<TooltipContent side="top" className="max-w-xs">
						<div className="space-y-2">
							<p className="text-sm font-semibold">Results per Source</p>
							<p className="text-xs text-muted-foreground leading-relaxed">
								Control how many results to fetch from each data source. Set a higher number to get
								more information, or a lower number for faster, more focused results.
							</p>
							<div className="flex items-center gap-2 text-xs text-muted-foreground pt-1 border-t">
								<span>Recommended: 5-20</span>
								<span></span>
								<span>
									Range: {MIN_VALUE}-{MAX_VALUE}
								</span>
							</div>
						</div>
					</TooltipContent>
				</Tooltip>
			</TooltipProvider>
		);
	}
);

TopKSelector.displayName = "TopKSelector";
564+
450565
const LLMSelector = React.memo(() => {
451566
const { search_space_id } = useParams();
452567
const searchSpaceId = Number(search_space_id);
@@ -604,13 +719,17 @@ const CustomChatInputOptions = React.memo(
604719
selectedConnectors,
605720
searchMode,
606721
onSearchModeChange,
722+
topK,
723+
onTopKChange,
607724
}: {
608725
onDocumentSelectionChange?: (documents: Document[]) => void;
609726
selectedDocuments?: Document[];
610727
onConnectorSelectionChange?: (connectorTypes: string[]) => void;
611728
selectedConnectors?: string[];
612729
searchMode?: "DOCUMENTS" | "CHUNKS";
613730
onSearchModeChange?: (mode: "DOCUMENTS" | "CHUNKS") => void;
731+
topK?: number;
732+
onTopKChange?: (topK: number) => void;
614733
}) => {
615734
// Memoize the loading fallback to prevent recreation
616735
const loadingFallback = React.useMemo(
@@ -637,6 +756,8 @@ const CustomChatInputOptions = React.memo(
637756
<div className="h-4 w-px bg-border hidden sm:block" />
638757
<SearchModeSelector searchMode={searchMode} onSearchModeChange={onSearchModeChange} />
639758
<div className="h-4 w-px bg-border hidden sm:block" />
759+
<TopKSelector topK={topK} onTopKChange={onTopKChange} />
760+
<div className="h-4 w-px bg-border hidden sm:block" />
640761
<LLMSelector />
641762
</div>
642763
);
@@ -653,13 +774,17 @@ export const ChatInputUI = React.memo(
653774
selectedConnectors,
654775
searchMode,
655776
onSearchModeChange,
777+
topK,
778+
onTopKChange,
656779
}: {
657780
onDocumentSelectionChange?: (documents: Document[]) => void;
658781
selectedDocuments?: Document[];
659782
onConnectorSelectionChange?: (connectorTypes: string[]) => void;
660783
selectedConnectors?: string[];
661784
searchMode?: "DOCUMENTS" | "CHUNKS";
662785
onSearchModeChange?: (mode: "DOCUMENTS" | "CHUNKS") => void;
786+
topK?: number;
787+
onTopKChange?: (topK: number) => void;
663788
}) => {
664789
return (
665790
<ChatInput>
@@ -674,6 +799,8 @@ export const ChatInputUI = React.memo(
674799
selectedConnectors={selectedConnectors}
675800
searchMode={searchMode}
676801
onSearchModeChange={onSearchModeChange}
802+
topK={topK}
803+
onTopKChange={onTopKChange}
677804
/>
678805
</ChatInput>
679806
);

0 commit comments

Comments
 (0)