Commit 4c81706

feat: Add query-specific LLM configuration
- Add QUERY_BINDING, QUERY_MODEL, QUERY_BINDING_HOST, QUERY_BINDING_API_KEY env vars
- Create separate LLM function for query operations
- Support all LLM providers (openai, azure_openai, ollama, lollms, aws_bedrock)
- Enable cost optimization: powerful model for queries, economical for extraction
1 parent 965d8b1 commit 4c81706

File tree

5 files changed: +133 −4 lines changed

- env.example
- lightrag/api/config.py
- lightrag/api/lightrag_server.py
- lightrag/api/routers/ollama_api.py
- lightrag/api/routers/query_routes.py

env.example

Lines changed: 9 additions & 0 deletions
@@ -164,6 +164,15 @@ LLM_MODEL=gpt-4o
 LLM_BINDING_HOST=https://api.openai.com/v1
 LLM_BINDING_API_KEY=your_api_key
 
+#####################################
+### Query-specific LLM Configuration
+### Use a more powerful model for queries while keeping extraction economical
+#####################################
+QUERY_BINDING=openai
+QUERY_MODEL=gpt-5
+QUERY_BINDING_HOST=https://api.openai.com/v1
+QUERY_BINDING_API_KEY=your_api_key
+
 ### Optional for Azure
 # AZURE_OPENAI_API_VERSION=2024-08-01-preview
 # AZURE_OPENAI_DEPLOYMENT=gpt-4o
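
A typical cost-optimized pairing keeps extraction on a cheap local model and routes queries to a stronger hosted one. The values below are illustrative placeholders, not defaults shipped by this commit:

LLM_BINDING=ollama
LLM_MODEL=mistral-nemo:latest
LLM_BINDING_HOST=http://localhost:11434

QUERY_BINDING=openai
QUERY_MODEL=gpt-5
QUERY_BINDING_HOST=https://api.openai.com/v1
QUERY_BINDING_API_KEY=your_api_key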

lightrag/api/config.py

Lines changed: 11 additions & 0 deletions
@@ -326,6 +326,17 @@ def parse_args() -> argparse.Namespace:
 
     # Inject model configuration
     args.llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")
+    args.query_binding = get_env_value("QUERY_BINDING", args.llm_binding)
+    args.query_model = get_env_value("QUERY_MODEL", args.llm_model)
+    args.query_binding_host = get_env_value(
+        "QUERY_BINDING_HOST",
+        get_default_host(args.query_binding)
+        if args.query_binding != args.llm_binding
+        else args.llm_binding_host,
+    )
+    args.query_binding_api_key = get_env_value(
+        "QUERY_BINDING_API_KEY", args.llm_binding_api_key
+    )
     args.embedding_model = get_env_value("EMBEDDING_MODEL", "bge-m3:latest")
     args.embedding_dim = get_env_value("EMBEDDING_DIM", 1024, int)
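
Each QUERY_* setting falls back to its extraction-LLM counterpart when unset, so existing deployments keep working unchanged. A minimal sketch of that fallback, with get_env_value simplified to a plain os.environ lookup (the real helper lives in lightrag/api/config.py):

import os

def get_env_value(name, default, cast=str):
    # Simplified stand-in: return the env var if set, else the default.
    raw = os.environ.get(name)
    return cast(raw) if raw is not None else default

llm_binding = get_env_value("LLM_BINDING", "ollama")
llm_model = get_env_value("LLM_MODEL", "mistral-nemo:latest")

# With no QUERY_* variables set, queries inherit the extraction settings.
query_binding = get_env_value("QUERY_BINDING", llm_binding)
query_model = get_env_value("QUERY_MODEL", llm_model)
print(query_binding, query_model)  # -> ollama mistral-nemo:latest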

lightrag/api/lightrag_server.py

Lines changed: 79 additions & 0 deletions
@@ -325,6 +325,10 @@ async def lifespan(app: FastAPI):
     # Store background tasks
     app.state.background_tasks = set()
 
+    # Store query LLM function in app state
+    app.state.query_llm_func = query_llm_func
+    app.state.query_llm_kwargs = query_llm_kwargs
+
     try:
         # Initialize database connections
         await rag.initialize_storages()
@@ -497,6 +501,11 @@ async def optimized_azure_openai_model_complete(
 
         return optimized_azure_openai_model_complete
 
+    llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
+    embedding_timeout = get_env_value(
+        "EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT, int
+    )
+
     def create_llm_model_func(binding: str):
         """
         Create LLM model function based on binding type.
@@ -524,6 +533,76 @@ def create_llm_model_func(binding: str):
         except ImportError as e:
             raise Exception(f"Failed to import {binding} LLM binding: {e}")
 
+    def create_query_llm_func(args, config_cache, llm_timeout):
+        """
+        Create query-specific LLM function if QUERY_BINDING or QUERY_MODEL is different.
+        Returns tuple of (query_llm_func, query_llm_kwargs).
+        """
+        # Check if query-specific LLM is configured
+        # Only skip if BOTH binding AND model are the same
+        if not hasattr(args, "query_binding") or (
+            args.query_binding == args.llm_binding
+            and args.query_model == args.llm_model
+        ):
+            logger.info("Using same LLM for both extraction and query")
+            return None, {}
+
+        logger.info(
+            f"Creating separate query LLM: {args.query_binding}/{args.query_model}"
+        )
+
+        # Create a temporary args object for query LLM
+        class QueryArgs:
+            pass
+
+        query_args = QueryArgs()
+        query_args.llm_binding = args.query_binding
+        query_args.llm_model = args.query_model
+        query_args.llm_binding_host = args.query_binding_host
+        query_args.llm_binding_api_key = args.query_binding_api_key
+
+        # Create query-specific LLM function based on binding type
+        query_llm_func = None
+        query_llm_kwargs = {}
+
+        try:
+            if args.query_binding == "openai":
+                query_llm_func = create_optimized_openai_llm_func(
+                    config_cache, query_args, llm_timeout
+                )
+            elif args.query_binding == "azure_openai":
+                query_llm_func = create_optimized_azure_openai_llm_func(
+                    config_cache, query_args, llm_timeout
+                )
+            elif args.query_binding == "ollama":
+                from lightrag.llm.ollama import ollama_model_complete
+
+                query_llm_func = ollama_model_complete
+                query_llm_kwargs = create_llm_model_kwargs(
+                    args.query_binding, query_args, llm_timeout
+                )
+            elif args.query_binding == "lollms":
+                from lightrag.llm.lollms import lollms_model_complete
+
+                query_llm_func = lollms_model_complete
+                query_llm_kwargs = create_llm_model_kwargs(
+                    args.query_binding, query_args, llm_timeout
+                )
+            elif args.query_binding == "aws_bedrock":
+                query_llm_func = bedrock_model_complete
+            else:
+                raise ValueError(f"Unsupported query binding: {args.query_binding}")
+
+            logger.info(f"Query LLM configured: {args.query_model}")
+            return query_llm_func, query_llm_kwargs
+
+        except ImportError as e:
+            raise Exception(f"Failed to import {args.query_binding} LLM binding: {e}")
+
+    query_llm_func, query_llm_kwargs = create_query_llm_func(
+        args, config_cache, llm_timeout
+    )
+
     def create_llm_model_kwargs(binding: str, args, llm_timeout: int) -> dict:
         """
         Create LLM model kwargs based on binding type.
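
create_query_llm_func only builds a second LLM when the query configuration actually differs: it returns (None, {}) if args lacks query_binding, or if both the binding and the model match the extraction LLM. A minimal sketch of that activation rule:

def needs_query_llm(llm_binding, llm_model, query_binding, query_model):
    # A separate query LLM is created unless BOTH binding and model match.
    return not (query_binding == llm_binding and query_model == llm_model)

assert not needs_query_llm("openai", "gpt-4o-mini", "openai", "gpt-4o-mini")
assert needs_query_llm("openai", "gpt-4o-mini", "openai", "gpt-5")  # same binding, stronger model
assert needs_query_llm("ollama", "qwen2.5:7b", "openai", "gpt-5")   # different provider entirely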

lightrag/api/routers/ollama_api.py

Lines changed: 8 additions & 0 deletions
@@ -683,6 +683,14 @@ async def stream_generator():
             else:
                 first_chunk_time = time.time_ns()
 
+        # Check if query-specific LLM is configured
+        if (
+            hasattr(raw_request.app.state, "query_llm_func")
+            and raw_request.app.state.query_llm_func
+        ):
+            query_param.model_func = raw_request.app.state.query_llm_func
+            logger.debug("Using query-specific LLM for Ollama chat")
+
         # Determine if the request is prefixed with "/bypass" or from Open WebUI's session title and session keyword generation task
         match_result = re.search(
             r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE

lightrag/api/routers/query_routes.py

Lines changed: 26 additions & 4 deletions
@@ -6,7 +6,7 @@
 import logging
 from typing import Any, Dict, List, Literal, Optional
 
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends, HTTPException, Request
 from lightrag.base import QueryParam
 from lightrag.api.utils_api import get_combined_auth_dependency
 from pydantic import BaseModel, Field, field_validator
@@ -267,7 +267,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
             },
         },
     )
-    async def query_text(request: QueryRequest):
+    async def query_text(request: QueryRequest, req: Request):
         """
         Comprehensive RAG query endpoint with non-streaming response. Parameter "stream" is ignored.
 
@@ -343,6 +343,13 @@ async def query_text(request: QueryRequest):
         # Force stream=False for /query endpoint regardless of include_references setting
         param.stream = False
 
+        # Use query-specific LLM if available
+        if (
+            hasattr(req.app.state, "query_llm_func")
+            and req.app.state.query_llm_func
+        ):
+            param.model_func = req.app.state.query_llm_func
+
         # Unified approach: always use aquery_llm for both cases
         result = await rag.aquery_llm(request.query, param=param)
 
@@ -438,7 +445,7 @@ async def query_text(request: QueryRequest):
             },
         },
     )
-    async def query_text_stream(request: QueryRequest):
+    async def query_text_stream(request: QueryRequest, req: Request):
         """
         Advanced RAG query endpoint with flexible streaming response.
 
@@ -560,6 +567,13 @@ async def query_text_stream(request: QueryRequest):
         stream_mode = request.stream if request.stream is not None else True
         param = request.to_query_params(stream_mode)
 
+        # Use query-specific LLM if available
+        if (
+            hasattr(req.app.state, "query_llm_func")
+            and req.app.state.query_llm_func
+        ):
+            param.model_func = req.app.state.query_llm_func
+
         from fastapi.responses import StreamingResponse
 
         # Unified approach: always use aquery_llm for all cases
@@ -907,7 +921,7 @@ async def stream_generator():
             },
         },
     )
-    async def query_data(request: QueryRequest):
+    async def query_data(request: QueryRequest, req: Request):
         """
         Advanced data retrieval endpoint for structured RAG analysis.
 
@@ -1002,6 +1016,14 @@ async def query_data(request: QueryRequest):
         """
         try:
             param = request.to_query_params(False)  # No streaming for data endpoint
+
+            # Use query-specific LLM if available (for keyword extraction)
+            if (
+                hasattr(req.app.state, "query_llm_func")
+                and req.app.state.query_llm_func
+            ):
+                param.model_func = req.app.state.query_llm_func
+
             response = await rag.aquery_data(request.query, param=param)
 
             # aquery_data returns the new format with status, message, data, and metadata
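
To verify the override end to end, one can set QUERY_MODEL and hit /query; answer generation should then use the query LLM while document extraction keeps the economical model. The base URL and port below are assumptions (9621 is LightRAG's usual API default), not something this diff defines:

import httpx

resp = httpx.post(
    "http://localhost:9621/query",
    json={"query": "What entities appear in the uploaded documents?"},
    timeout=60,
)
print(resp.json())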
