Skip to content

Commit 7c08a4a

Browse files
committed
handle file input
1 parent e3be8cc commit 7c08a4a

File tree

5 files changed

+96
-14
lines changed

5 files changed

+96
-14
lines changed

augmenta/agent.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pathlib import Path
55
import yaml
66
from pydantic import BaseModel, Field, create_model
7-
from pydantic_ai import Agent
7+
from pydantic_ai import Agent, BinaryContent
88
import logfire
99
from .tools.mcp import load_mcp_servers
1010
from .tools.search_web import search_web
@@ -119,18 +119,18 @@ def create_structure_class(yaml_file_path: Union[str, Path]) -> Type[BaseModel]:
119119

120120
except (yaml.YAMLError, OSError) as e:
121121
raise ValueError(f"Failed to parse YAML: {e}")
122-
122+
123123
async def run(
124124
self,
125-
prompt: str,
125+
prompt: Union[str, List[Union[str, BinaryContent]]],
126126
response_format: Optional[Type[BaseModel]] = None,
127127
temperature: Optional[float] = None,
128128
system_prompt: Optional[str] = None
129129
) -> Union[str, dict[str, Any], BaseModel]:
130130
"""Run the agent to perform web research.
131131
132132
Args:
133-
prompt: The research query or task
133+
prompt: The research query/task as a string or a list containing text and binary content
134134
response_format: Optional Pydantic model for structured output
135135
temperature: Optional override for model temperature
136136
system_prompt: Optional override for system prompt (defaults to self.system_prompt)

augmenta/augmenta.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
import asyncio
55
import pandas as pd
66
from pathlib import Path
7-
from typing import Optional, Tuple, Dict, Any, Callable, Type, Union
7+
from typing import Optional, Tuple, Dict, Any, Callable, Type, Union, List
88
from dataclasses import dataclass
99

1010
from augmenta.utils.prompt_formatter import format_examples, substitute_template_variables, build_complete_prompt
1111
from augmenta.agent import AugmentaAgent
1212
from augmenta.cache import CacheManager
1313
from augmenta.cache.process import setup_cache_handling, apply_cached_results
1414
from augmenta.config.read_config import load_config, get_config_values
15+
from augmenta.tools.file import load_file
1516
import logfire
1617

1718
@dataclass
@@ -206,12 +207,34 @@ async def process_row(
206207
"""
207208
try:
208209
index = row_data['index']
209-
row = row_data['data']
210-
# Build complete prompt with data from row
210+
row = row_data['data'] # Build complete prompt with data from row
211211
prompt_user = build_complete_prompt(config, row)
212+
# Get the file column name from config (if available)
213+
file_col = config.get("file_col")
212214

213-
# Run prompt using the agent
214-
response = await agent.run(prompt_user, response_format=response_format)
215+
# Check if a file column is specified and the row contains a file path
216+
file_path = None
217+
if file_col and file_col in row:
218+
file_path = row.get(file_col)
219+
logfire.debug(f"Using file from column '{file_col}': {file_path}")
220+
elif file_col:
221+
logfire.debug(f"File column '{file_col}' specified in config but not found in row data")
222+
else:
223+
logfire.debug("No file column specified in config")
224+
225+
try:
226+
binary_content = load_file(file_path) if file_path is not None else None
227+
if binary_content:
228+
# If file exists, create a message list with prompt and binary content
229+
message_contents = [prompt_user, binary_content]
230+
response = await agent.run(message_contents, response_format=response_format)
231+
else:
232+
# If file doesn't exist or couldn't be loaded, just use the text prompt
233+
response = await agent.run(prompt_user, response_format=response_format)
234+
except Exception as e:
235+
logfire.warning(f"Error loading file at row {index}: {str(e)}. Proceeding with text prompt only.")
236+
# Fallback to text-only prompt if file handling fails
237+
response = await agent.run(prompt_user, response_format=response_format)
215238

216239
# Handle caching and progress tracking
217240
handle_result_tracking(

augmenta/config/read_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,15 @@ def get_config_values(config: Dict[str, Any]) -> Dict[str, Any]:
9292

9393
# Construct model ID with provider
9494
model_id = f"{model_config['provider']}:{model_config['name']}"
95-
96-
# Extract commonly used values with defaults
95+
# Extract commonly used values with defaults
9796
return {
9897
"model_id": model_id,
9998
"temperature": model_config.get("temperature", 0.0),
10099
"max_tokens": model_config.get("max_tokens"),
101100
"rate_limit": model_config.get("rate_limit"),
102101
"search_engine": search_config.get("engine"),
103-
"search_results": search_config.get("results", 3)
102+
"search_results": search_config.get("results", 3),
103+
"file_col": config.get("file_col")
104104
}
105105

106106
def load_config(config_path: Union[str, Path]) -> Dict[str, Any]:

augmenta/tools/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
"""Tools and utilities for web interaction."""
1+
"""Tools and utilities for web interaction and file handling."""
22

33
from .search_web import search_web
44
from .visit_webpages import visit_webpages
5+
from .file import load_file
56

67
__all__ = [
78
'search_web',
8-
'visit_webpages'
9+
'visit_webpages',
10+
'load_file'
911
]

augmenta/tools/file.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""
2+
File handling utilities for loading binary content.
3+
"""
4+
import mimetypes
5+
from pathlib import Path
6+
from typing import Optional, Union, Any
7+
8+
from pydantic_ai import BinaryContent
9+
10+
11+
def load_file(file_path: Union[str, Path, Any]) -> Optional[BinaryContent]:
12+
"""
13+
Load a file as binary content with appropriate MIME type detection.
14+
15+
Args:
16+
file_path: Path to the file to load
17+
18+
Returns:
19+
BinaryContent object with binary data and media type if file exists,
20+
None otherwise
21+
"""
22+
# Skip if file path is None, empty, or 'NA' or not a string-like object
23+
if file_path is None:
24+
return None
25+
26+
# Handle non-string types by converting to string first
27+
try:
28+
file_path_str = str(file_path).strip()
29+
if not file_path_str or file_path_str.upper() == 'NA':
30+
return None
31+
except:
32+
# If we can't convert to string, it's not a valid path
33+
return None
34+
35+
try:
36+
path = Path(file_path_str)
37+
38+
# Check if file exists
39+
if not path.exists():
40+
return None
41+
42+
# Read binary content
43+
file_binary = path.read_bytes()
44+
45+
# Determine MIME type
46+
media_type = mimetypes.guess_type(str(path))[0]
47+
48+
# Default to application/octet-stream if type cannot be determined
49+
if not media_type:
50+
media_type = "application/octet-stream"
51+
52+
# Create and return binary content
53+
return BinaryContent(data=file_binary, media_type=media_type)
54+
55+
except Exception as e:
56+
print(f"Error loading file {file_path}: {str(e)}")
57+
return None

0 commit comments

Comments
 (0)