
Commit 27ec42d

vibe refactor
1 parent 3dc9359 commit 27ec42d

File tree: 7 files changed, +584 −166 lines changed


README.md

Lines changed: 0 additions & 2 deletions
@@ -78,7 +78,6 @@ output_csv: path/to/processed_data.csv
 model:
   provider: openai
   name: gpt-4o-mini
-query_col: DonorName
 search:
   engine: brightdata_google
 prompt:
@@ -161,7 +160,6 @@ You will need to edit this file to suit your project. Let's break all this down:
 
 - `input_csv` and `output_csv` are the names of the data you want to process and where you want to save the results, respectively.
 - `model`: The LLM you want to use. You can find a list of supported models [here](https://ai.pydantic.dev/models/). Note that you need to provide both a `provider` and model `name` (ie. `anthropic` and `claude-3.5-sonnet`). You will also likely need to set up an API key (see [credentials below](#credentials)).
-- `query_col`: The name of the column in your input CSV that you want to use as the main search term (eg. company name).
 - `search`: The search engine you want to use. You can find a list of supported search engines [here](/docs/search.md). You will also likely need to set up an API key here (see [credentials](#credentials)).
 - `prompt`: LLMs take in a [system prompt](https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts) and a user prompt. Think of the system prompt as explaining to the LLM what its role is, and the user prompt as the instructions you want it to follow. You can use double curly braces (`{{ }}`) to refer to columns in your input CSV. There are some tips on writing good prompts [here](docs/prompt.md).
 - `structure`: The structure of the output data. You can think of this as the columns you want added to your original CSV.
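As a side note on the `prompt` bullet in the hunk above: the double-curly-brace substitution it describes can be pictured as a per-row template fill. The sketch below is purely illustrative and is not Augmenta's actual rendering code; the `render_prompt` helper, the column name, and `input.csv` are all hypothetical.

```python
import csv
import re


def render_prompt(template: str, row: dict) -> str:
    """Fill {{ column }} placeholders with values from one CSV row (illustrative only)."""
    return re.sub(
        r"\{\{\s*(\w+)\s*\}\}",
        lambda m: str(row.get(m.group(1), "")),
        template,
    )


# Hypothetical user prompt referencing a column from the input CSV.
template = "Research the organisation called {{ OrgName }} and summarise what it does."

with open("input.csv", newline="", encoding="utf-8") as f:
    for row in csv.DictReader(f):
        print(render_prompt(template, row))
```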

augmenta/agent.py

Lines changed: 40 additions & 27 deletions
@@ -39,12 +39,16 @@ def __init__(
             verbose: Whether to enable verbose logging with logfire
             system_prompt: Default system prompt for the agent
         """
-        # Create model settings with all available parameters
-        model_settings = {'temperature': temperature}
-        if rate_limit is not None:
-            model_settings['rate_limit'] = rate_limit
-        if max_tokens is not None:
-            model_settings['max_tokens'] = max_tokens
+        # Store parameters for reuse
+        self.model = model
+        self.temperature = temperature
+        self.rate_limit = rate_limit
+        self.max_tokens = max_tokens
+        self.verbose = verbose
+        self.system_prompt = system_prompt
+
+        # Create model settings
+        model_settings = self._create_model_settings(temperature)
 
         # Load MCP servers from config
         try:
@@ -59,13 +63,23 @@ def __init__(
             tools=[search_web, visit_webpages],
             mcp_servers=mcp_servers
         )
-        self.model = model
-        self.temperature = temperature
-        self.rate_limit = rate_limit
-        self.max_tokens = max_tokens
-        self.verbose = verbose
-        self.system_prompt = system_prompt
 
+    def _create_model_settings(self, temperature: float) -> Dict[str, Any]:
+        """Create model settings dictionary with proper parameters.
+
+        Args:
+            temperature: Temperature setting for the model
+
+        Returns:
+            Dictionary with model settings
+        """
+        model_settings = {'temperature': temperature}
+        if self.rate_limit is not None:
+            model_settings['rate_limit'] = self.rate_limit
+        if self.max_tokens is not None:
+            model_settings['max_tokens'] = self.max_tokens
+        return model_settings
+
     @staticmethod
     def create_structure_class(yaml_file_path: Union[str, Path]) -> Type[BaseModel]:
         """Creates a Pydantic model from YAML structure definition.
@@ -77,7 +91,6 @@ def create_structure_class(yaml_file_path: Union[str, Path]) -> Type[BaseModel]:
             A Pydantic model class based on the YAML structure
         """
         yaml_file_path = Path(yaml_file_path)
-
         try:
             with open(yaml_file_path, 'r', encoding='utf-8') as f:
                 yaml_content = yaml.safe_load(f)
@@ -90,9 +103,12 @@ def create_structure_class(yaml_file_path: Union[str, Path]) -> Type[BaseModel]:
                 if not isinstance(field_info, dict):
                     raise ValueError(f"Invalid field definition for {field_name}")
 
-                field_type = (Literal[tuple(str(opt) for opt in field_info['options'])]
-                              if 'options' in field_info
-                              else AugmentaAgent.TYPE_MAPPING.get(field_info.get('type', 'str'), str))
+                # Determine field type based on options or type specification
+                if 'options' in field_info:
+                    field_type = Literal[tuple(str(opt) for opt in field_info['options'])]
+                else:
+                    type_str = field_info.get('type', 'str')
+                    field_type = AugmentaAgent.TYPE_MAPPING.get(type_str, str)
 
                 fields[field_name] = (
                     field_type,
@@ -101,8 +117,9 @@ def create_structure_class(yaml_file_path: Union[str, Path]) -> Type[BaseModel]:
 
             return create_model('Structure', **fields, __base__=BaseModel)
 
-        except (yaml.YAMLError, OSError) as e: raise ValueError(f"Failed to parse YAML: {e}")
-
+        except (yaml.YAMLError, OSError) as e:
+            raise ValueError(f"Failed to parse YAML: {e}")
+
     async def run(
         self,
         prompt: str,
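The branching in the hunks above (a `Literal` for enumerated `options`, a `TYPE_MAPPING` lookup otherwise, then pydantic's `create_model`) is easier to see end to end in isolation. The following sketch mirrors that idea under assumptions: the YAML snippet, the `TYPE_MAPPING` contents, and the use of `Field(description=...)` as the second tuple element are illustrative guesses, not Augmenta's actual definitions.

```python
from typing import Literal

import yaml
from pydantic import BaseModel, Field, create_model

# Hypothetical structure definition, in the spirit of the README's `structure` block.
YAML_STRUCTURE = """
headquarters:
  type: str
  description: City where the organisation is based
is_nonprofit:
  options: ["yes", "no", "unclear"]
  description: Whether the organisation is a nonprofit
"""

# Assumed mapping; the real AugmentaAgent.TYPE_MAPPING may differ.
TYPE_MAPPING = {'str': str, 'int': int, 'float': float, 'bool': bool}

fields = {}
for field_name, field_info in yaml.safe_load(YAML_STRUCTURE).items():
    if 'options' in field_info:
        # A fixed option list becomes a Literal type, as in the hunk above.
        field_type = Literal[tuple(str(opt) for opt in field_info['options'])]
    else:
        field_type = TYPE_MAPPING.get(field_info.get('type', 'str'), str)
    fields[field_name] = (field_type, Field(description=field_info.get('description', '')))

Structure = create_model('Structure', **fields, __base__=BaseModel)
print(Structure(headquarters="Toronto", is_nonprofit="no"))
```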
@@ -122,18 +139,14 @@ async def run(
             The agent's response after researching, either as string, dict or Pydantic model
         """
         try:
-            # Create model_settings for this specific request if temperature is provided
-            model_settings = None
-            if temperature is not None:
-                model_settings = {'temperature': temperature}
-                if self.rate_limit is not None:
-                    model_settings['rate_limit'] = self.rate_limit
-                if self.max_tokens is not None:
-                    model_settings['max_tokens'] = self.max_tokens
-
             # Set the system prompt
             self.agent.system_prompt = system_prompt or self.system_prompt
 
+            # Prepare model settings only if temperature override is provided
+            model_settings = None
+            if temperature is not None and temperature != self.temperature:
+                model_settings = self._create_model_settings(temperature)
+
             async with self.agent.run_mcp_servers():
                 result = await self.agent.run(
                     prompt,
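The reworked guard above only builds per-request settings when the caller passes a temperature that differs from the agent's default; otherwise the settings created in `__init__` keep applying. A usage sketch under assumptions: the import path, the constructor arguments, and any `run()` keywords beyond `prompt`, `system_prompt`, and `temperature` are guesses based on these hunks, not a documented API.

```python
import asyncio

# Assumed import path, based on the file name augmenta/agent.py.
from augmenta.agent import AugmentaAgent


async def main() -> None:
    # Constructor arguments here are illustrative.
    agent = AugmentaAgent(model="openai:gpt-4o-mini", temperature=0.2)

    # Matches the default temperature, so no per-request settings are built.
    summary = await agent.run("Summarise what Example Corp does.")

    # Differs from the default, so run() builds fresh settings for this call only.
    ideas = await agent.run(
        "Suggest angles for further research on Example Corp.",
        temperature=0.9,
    )
    print(summary, ideas)


asyncio.run(main())
```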
