33 changes: 11 additions & 22 deletions templates/components/agents/python/blog/app/agents/researcher.py
@@ -1,39 +1,28 @@
import os
from textwrap import dedent
from typing import List
from typing import List, Optional, Dict, Any

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.workflows.single import FunctionCallingAgent
from llama_index.core.chat_engine.types import ChatMessage
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from llama_index.core.tools import QueryEngineTool
from app.engine.tools.query_engine import get_query_engine_tool


def _create_query_engine_tool(params=None) -> QueryEngineTool:
"""
Provide an agent worker that can be used to query the index.
"""
def _create_query_engine_tool(
params: Optional[Dict[str, Any]] = None, **kwargs
) -> QueryEngineTool:
if params is None:
params = {}
# Add query tool if index exists
index_config = IndexConfig(**(params or {}))
index_config = IndexConfig(**params)
index = get_index(index_config)
if index is None:
return None
top_k = int(os.getenv("TOP_K", 0))
query_engine = index.as_query_engine(
**({"similarity_top_k": top_k} if top_k != 0 else {})
)
return QueryEngineTool(
query_engine=query_engine,
metadata=ToolMetadata(
name="query_index",
description="""
Use this tool to retrieve information about the text corpus from the index.
""",
),
)
return get_query_engine_tool(index=index, **kwargs)


def _get_research_tools(**kwargs) -> QueryEngineTool:
def _get_research_tools(**kwargs):
"""
The researcher is responsible for retrieving information.
Tries to initialize the wikipedia or duckduckgo tools if available.
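
The new imports above pull in get_query_engine_tool from app/engine/tools/query_engine.py, whose implementation is not part of this diff. A minimal sketch of what such a helper could look like, assuming it keeps the TOP_K environment variable and the "query_index" metadata of the code it replaces (the real implementation may differ):

# Hypothetical sketch of get_query_engine_tool() in app/engine/tools/query_engine.py.
# Assumption: it preserves the TOP_K env var and the "query_index" tool metadata
# used by the code this PR removes.
import os
from typing import Any, Optional

from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.tools import QueryEngineTool, ToolMetadata


def get_query_engine_tool(
    index: VectorStoreIndex,
    name: Optional[str] = None,
    description: Optional[str] = None,
    **kwargs: Any,
) -> QueryEngineTool:
    """Build a QueryEngineTool over the given index with env-driven defaults."""
    top_k = int(os.getenv("TOP_K", 0))
    if top_k > 0 and "similarity_top_k" not in kwargs:
        kwargs["similarity_top_k"] = top_k
    # Extra keyword arguments (e.g. filters) are forwarded to the query engine.
    query_engine = index.as_query_engine(**kwargs)
    return QueryEngineTool(
        query_engine=query_engine,
        metadata=ToolMetadata(
            name=name or "query_index",
            description=description
            or "Use this tool to retrieve information about the text corpus from the index.",
        ),
    )
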
@@ -1,16 +1,15 @@
import os
from typing import Any, Dict, List, Optional

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection
@@ -24,19 +23,23 @@
)


def _create_query_engine_tool(params=None, **kwargs) -> QueryEngineTool:
if params is None:
params = {}
# Add query tool if index exists
index_config = IndexConfig(**params)
index = get_index(index_config)
if index is None:
return None
return get_query_engine_tool(index=index, **kwargs)


def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
filters: Optional[List[Any]] = None,
**kwargs,
) -> Workflow:
index_config = IndexConfig(**params)
index: VectorStoreIndex = get_index(config=index_config)
if index is None:
query_engine_tool = None
else:
top_k = int(os.getenv("TOP_K", 10))
query_engine = index.as_query_engine(similarity_top_k=top_k, filters=filters)
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
query_engine_tool = _create_query_engine_tool(params, **kwargs)

configured_tools: Dict[str, FunctionTool] = ToolFactory.from_env(map_result=True) # type: ignore
code_interpreter_tool = configured_tools.get("interpret")
@@ -1,16 +1,7 @@
import os
from typing import Any, Dict, List, Optional

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection
@@ -23,25 +14,35 @@
step,
)

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)


def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
filters: Optional[List[Any]] = None,
) -> Workflow:
def _create_query_engine_tool(
params: Optional[Dict[str, Any]] = None, **kwargs
) -> QueryEngineTool:
if params is None:
params = {}
if filters is None:
filters = []
# Add query tool if index exists
index_config = IndexConfig(**params)
index: VectorStoreIndex = get_index(config=index_config)
index = get_index(index_config)
if index is None:
query_engine_tool = None
else:
top_k = int(os.getenv("TOP_K", 10))
query_engine = index.as_query_engine(similarity_top_k=top_k, filters=filters)
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
return None
return get_query_engine_tool(index=index, **kwargs)


def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Workflow:
query_engine_tool = _create_query_engine_tool(params, **kwargs)
configured_tools = ToolFactory.from_env(map_result=True)
extractor_tool = configured_tools.get("extract_questions") # type: ignore
filling_tool = configured_tools.get("fill_form") # type: ignore
@@ -1,15 +1,15 @@
import { ChatMessage } from "llamaindex";
import { getTool } from "../engine/tools";
import { FunctionCallingAgent } from "./single-agent";
import { getQueryEngineTools } from "./tools";
import { getQueryEngineTool } from "./tools";

export const createResearcher = async (chatHistory: ChatMessage[]) => {
const queryEngineTools = await getQueryEngineTools();
const queryEngineTool = await getQueryEngineTool();
const tools = [
await getTool("wikipedia_tool"),
await getTool("duckduckgo_search"),
await getTool("image_generator"),
...(queryEngineTools ? queryEngineTools : []),
queryEngineTool,
].filter((tool) => tool !== undefined);

return new FunctionCallingAgent({
@@ -1,7 +1,7 @@
import { ChatMessage, ToolCallLLM } from "llamaindex";
import { getTool } from "../engine/tools";
import { FinancialReportWorkflow } from "./fin-report";
import { getQueryEngineTools } from "./tools";
import { getQueryEngineTool } from "./tools";

const TIMEOUT = 360 * 1000;

@@ -11,7 +11,7 @@ export async function createWorkflow(options: {
}) {
return new FinancialReportWorkflow({
chatHistory: options.chatHistory,
queryEngineTools: (await getQueryEngineTools()) || [],
queryEngineTool: (await getQueryEngineTool())!,
codeInterpreterTool: (await getTool("interpreter"))!,
documentGeneratorTool: (await getTool("document_generator"))!,
llm: options.llm,
@@ -45,15 +45,15 @@ export class FinancialReportWorkflow extends Workflow<
> {
llm: ToolCallLLM;
memory: ChatMemoryBuffer;
queryEngineTools: BaseToolWithCall[];
queryEngineTool: BaseToolWithCall[];
codeInterpreterTool: BaseToolWithCall;
documentGeneratorTool: BaseToolWithCall;
systemPrompt?: string;

constructor(options: {
llm?: ToolCallLLM;
chatHistory: ChatMessage[];
queryEngineTools: BaseToolWithCall[];
queryEngineTool: BaseToolWithCall;
codeInterpreterTool: BaseToolWithCall;
documentGeneratorTool: BaseToolWithCall;
systemPrompt?: string;
@@ -70,7 +70,7 @@ export class FinancialReportWorkflow extends Workflow<
throw new Error("LLM is not a ToolCallLLM");
}
this.systemPrompt = options.systemPrompt ?? DEFAULT_SYSTEM_PROMPT;
this.queryEngineTools = options.queryEngineTools;
this.queryEngineTool = options.queryEngineTool;
this.codeInterpreterTool = options.codeInterpreterTool;

this.documentGeneratorTool = options.documentGeneratorTool;
@@ -154,8 +154,8 @@ export class FinancialReportWorkflow extends Workflow<
const chatHistory = ev.data.input;

const tools = [this.codeInterpreterTool, this.documentGeneratorTool];
if (this.queryEngineTools) {
tools.push(...this.queryEngineTools);
if (this.queryEngineTool) {
tools.push(this.queryEngineTool);
}

const toolCallResponse = await chatWithTools(this.llm, tools, chatHistory);
@@ -190,8 +190,8 @@ export class FinancialReportWorkflow extends Workflow<
});
default:
if (
this.queryEngineTools &&
this.queryEngineTools.some((tool) => tool.metadata.name === toolName)
this.queryEngineTool &&
this.queryEngineTool.metadata.name === toolName
) {
return new ResearchEvent({
toolCalls: toolCallResponse.toolCalls,
@@ -1,7 +1,7 @@
import { ChatMessage, ToolCallLLM } from "llamaindex";
import { getTool } from "../engine/tools";
import { FormFillingWorkflow } from "./form-filling";
import { getQueryEngineTools } from "./tools";
import { getQueryEngineTool } from "./tools";

const TIMEOUT = 360 * 1000;

@@ -11,7 +11,7 @@ export async function createWorkflow(options: {
}) {
return new FormFillingWorkflow({
chatHistory: options.chatHistory,
queryEngineTools: (await getQueryEngineTools()) || [],
queryEngineTool: (await getQueryEngineTool())!,
extractorTool: (await getTool("extract_missing_cells"))!,
fillMissingCellsTool: (await getTool("fill_missing_cells"))!,
llm: options.llm,
@@ -48,15 +48,15 @@ export class FormFillingWorkflow extends Workflow<
llm: ToolCallLLM;
memory: ChatMemoryBuffer;
extractorTool: BaseToolWithCall;
queryEngineTools?: BaseToolWithCall[];
queryEngineTool?: BaseToolWithCall;
fillMissingCellsTool: BaseToolWithCall;
systemPrompt?: string;

constructor(options: {
llm?: ToolCallLLM;
chatHistory: ChatMessage[];
extractorTool: BaseToolWithCall;
queryEngineTools?: BaseToolWithCall[];
queryEngineTool: BaseToolWithCall;
fillMissingCellsTool: BaseToolWithCall;
systemPrompt?: string;
verbose?: boolean;
@@ -73,7 +73,7 @@ export class FormFillingWorkflow extends Workflow<
}
this.systemPrompt = options.systemPrompt ?? DEFAULT_SYSTEM_PROMPT;
this.extractorTool = options.extractorTool;
this.queryEngineTools = options.queryEngineTools;
this.queryEngineTool = options.queryEngineTool;
this.fillMissingCellsTool = options.fillMissingCellsTool;

this.memory = new ChatMemoryBuffer({
@@ -156,8 +156,8 @@ export class FormFillingWorkflow extends Workflow<
const chatHistory = ev.data.input;

const tools = [this.extractorTool, this.fillMissingCellsTool];
if (this.queryEngineTools) {
tools.push(...this.queryEngineTools);
if (this.queryEngineTool) {
tools.push(this.queryEngineTool);
}

const toolCallResponse = await chatWithTools(this.llm, tools, chatHistory);
@@ -192,8 +192,8 @@ export class FormFillingWorkflow extends Workflow<
});
default:
if (
this.queryEngineTools &&
this.queryEngineTools.some((tool) => tool.metadata.name === toolName)
this.queryEngineTool &&
this.queryEngineTool.metadata.name === toolName
) {
return new FindAnswersEvent({
toolCalls: toolCallResponse.toolCalls,
@@ -232,7 +232,7 @@ export class FormFillingWorkflow extends Workflow<
ev: FindAnswersEvent,
): Promise<InputEvent> => {
const { toolCalls } = ev.data;
if (!this.queryEngineTools) {
if (!this.queryEngineTool) {
throw new Error("Query engine tool is not available");
}
ctx.sendEvent(
Expand All @@ -243,7 +243,7 @@ export class FormFillingWorkflow extends Workflow<
}),
);
const toolMsgs = await callTools({
tools: this.queryEngineTools,
tools: [this.queryEngineTool],
toolCalls,
ctx,
agentName: "Researcher",
15 changes: 6 additions & 9 deletions templates/components/engines/python/agent/engine.py
@@ -1,29 +1,26 @@
import os
from typing import List

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from llama_index.core.agent import AgentRunner
from llama_index.core.callbacks import CallbackManager
from llama_index.core.settings import Settings
from llama_index.core.tools import BaseTool
from llama_index.core.tools.query_engine import QueryEngineTool

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool


def get_chat_engine(filters=None, params=None, event_handlers=None, **kwargs):
def get_chat_engine(params=None, event_handlers=None, **kwargs):
system_prompt = os.getenv("SYSTEM_PROMPT")
top_k = int(os.getenv("TOP_K", 0))
tools: List[BaseTool] = []
callback_manager = CallbackManager(handlers=event_handlers or [])

# Add query tool if index exists
index_config = IndexConfig(callback_manager=callback_manager, **(params or {}))
index = get_index(index_config)
if index is not None:
query_engine = index.as_query_engine(
filters=filters, **({"similarity_top_k": top_k} if top_k != 0 else {})
)
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
query_engine_tool = get_query_engine_tool(index, **kwargs)
tools.append(query_engine_tool)

# Add additional tools
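
For reference, a hedged usage sketch of the consolidated call path: since get_chat_engine now takes **kwargs and passes them to get_query_engine_tool, metadata filters are presumably forwarded through keyword arguments instead of the removed filters= parameter. The import path and filter values below are hypothetical, not part of this diff.

# Hypothetical call; assumes extra keyword arguments (e.g. filters) flow from
# get_chat_engine through get_query_engine_tool to index.as_query_engine().
from llama_index.core.vector_stores import MetadataFilter, MetadataFilters

from app.engine.engine import get_chat_engine  # module path assumed from the template layout

filters = MetadataFilters(filters=[MetadataFilter(key="private", value="false")])
chat_engine = get_chat_engine(params={}, filters=filters)
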