100 changes: 82 additions & 18 deletions tinytroupe/agent/tiny_person.py
@@ -14,6 +14,8 @@
from typing import Any
from rich import print

import logging



#######################################################################################################################
@@ -226,12 +228,9 @@ def generate_agent_system_prompt(self):
return chevron.render(agent_prompt_template, template_variables)

def reset_prompt(self):

# render the template with the current configuration
self._init_system_message = self.generate_agent_system_prompt()

# TODO actually, figure out another way to update agent state without "changing history"

# reset system message
self.current_messages = [
{"role": "system", "content": self._init_system_message}
@@ -240,14 +239,16 @@ def reset_prompt(self):
# sets up the actual interaction messages to use for prompting
self.current_messages += self.retrieve_recent_memories()

# add a final user message, which is neither stimuli or action, to instigate the agent to act properly
# add a final user message to instigate the agent to act properly
self.current_messages.append({"role": "user",
"content": "Now you **must** generate a sequence of actions following your interaction directives, " +\
"and complying with **all** instructions and contraints related to the action you use." +\
"DO NOT repeat the exact same action more than once in a row!" +\
"DO NOT keep saying or doing very similar things, but instead try to adapt and make the interactions look natural." +\
"These actions **MUST** be rendered following the JSON specification perfectly, including all required keys (even if their value is empty), **ALWAYS**."
})
"content": "Now you **must** generate a single action following your interaction directives, "
"and complying with **all** instructions and constraints related to the action you use. "
"DO NOT repeat the exact same action more than once in a row! "
"DO NOT keep saying or doing very similar things, but instead try to adapt and make the interactions look natural. "
"The action **MUST** be rendered as a single JSON object with the following structure: "
"{\"cognitive_state\": {\"goals\": \"...\", \"attention\": \"...\", \"emotions\": \"...\"}, "
"\"action\": {\"type\": \"...\", \"content\": \"...\", \"target\": \"...\"}}, "
"with no extra text or additional JSON objects."})

def get(self, key):
"""
@@ -771,25 +772,88 @@ def make_all_agents_inaccessible(self):

@transactional
def _produce_message(self):
# logger.debug(f"Current messages: {self.current_messages}")

# ensure we have the latest prompt (initial system message + selected messages from memory)
"""
Produces the next message by sending the current conversation to the Groq API
and parsing the returned response. Handles multiple JSON objects and ensures required fields.
"""
# Refresh the prompt (integrate the latest system message and memory)
self.reset_prompt()

# Prepare the messages by serializing each message's content as JSON
messages = [
{"role": msg["role"], "content": json.dumps(msg["content"])}
for msg in self.current_messages
]

logger.debug(f"[{self.name}] Sending messages to OpenAI API")
logger.debug(f"[{self.name}] Last interaction: {messages[-1]}")
logger.debug(f"[{self.name}] Sending messages to Groq API")
if messages:
logger.debug(f"[{self.name}] Last interaction: {messages[-1]}")

next_message = openai_utils.client().send_message(messages, response_format=CognitiveActionModel)
try:
# Send messages using GroqClient via openai_utils
response = openai_utils.client().send_message(
messages,
temperature=1.0,
max_tokens=1024,
response_format=CognitiveActionModel # Expected response format
)
logger.debug(f"[{self.name}] Received response: {response}")

if response is None:
logger.error("API response is None.")
return "assistant", {"error": "No response received from API."}

logger.debug(f"[{self.name}] Received message: {next_message}")
# Extract content from response
if isinstance(response, dict):
role = response.get("role", "assistant")
content = response.get("content", "")
else:
logger.error("Unexpected response format from API.")
return "assistant", {"error": "Unexpected API response format."}

if not content:
logger.error("No content found in the API response.")
return "assistant", {"error": "Empty content received from API."}

return next_message["role"], utils.extract_json(next_message["content"])
# Log raw content for debugging
logger.info(f"[{self.name}] Raw API content: {content}")

# Extract all JSON objects
import re
json_objects = []
for match in re.finditer(r'\{.*?\}(?=\s*(?:\{|$))', content, re.DOTALL):
json_str = match.group(0)
try:
parsed_content = json.loads(json_str)
json_objects.append(parsed_content)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON object: {e}. Raw content: {content}")
return "assistant", {"error": f"JSON parsing error: {str(e)}"}

if not json_objects:
logger.error(f"No valid JSON found in content: {content}")
return "assistant", {"error": "No valid JSON found in API response."}

# Use the last action in the sequence (assuming act() processes one at a time)
parsed_content = json_objects[-1]

# Ensure required fields are present
if "cognitive_state" not in parsed_content:
logger.warning("cognitive_state missing in response. Adding default value.")
parsed_content["cognitive_state"] = {
"goals": self._mental_state.get("goals", []),
"attention": self._mental_state.get("attention", "unknown"),
"emotions": self._mental_state.get("emotions", "neutral")
}
if "action" not in parsed_content:
logger.warning("action missing in response. Adding default value.")
parsed_content["action"] = {"type": "NONE", "content": "No action specified"}

return role, parsed_content

except Exception as e:
logger.error(f"Error while communicating with Groq API: {e}")
return "assistant", {"error": f"API communication error: {str(e)}"}
###########################################################
# Internal cognitive state changes
###########################################################
37 changes: 9 additions & 28 deletions tinytroupe/config.ini
@@ -1,20 +1,5 @@
[OpenAI]
#
# OpenAI or Azure OpenAI Service
#

# Default options: openai, azure
API_TYPE=openai

# Check Azure's documentation for updates here:
# https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart?tabs=command-line&pivots=programming-language-python
AZURE_API_VERSION=2023-05-15

#
# Model parameters
#

MODEL=gpt-4o-mini
API_TYPE=groq
MAX_TOKENS=4000
TEMPERATURE=1.2
FREQ_PENALTY=0.0
@@ -23,22 +8,18 @@ TIMEOUT=60
MAX_ATTEMPTS=5
WAITING_TIME=1
EXPONENTIAL_BACKOFF_FACTOR=5

EMBEDDING_MODEL=text-embedding-3-small

CACHE_API_CALLS=False
CACHE_FILE_NAME=openai_api_cache.pickle

MAX_CONTENT_DISPLAY_LENGTH=1024
[Groq]
MODEL=llama3-70b-8192
; put your Groq API key here
API_KEY= ...


[Logging]
LOGLEVEL=INFO

[Simulation]
RAI_HARMFUL_CONTENT_PREVENTION=True
RAI_COPYRIGHT_INFRINGEMENT_PREVENTION=True


[Logging]
LOGLEVEL=ERROR
# ERROR
# WARNING
# INFO
# DEBUG
114 changes: 110 additions & 4 deletions tinytroupe/openai_utils.py
Expand Up @@ -14,6 +14,14 @@
from tinytroupe import utils
from tinytroupe.control import transactional

from langchain_openai import ChatOpenAI
from tenacity import retry, stop_after_attempt, wait_exponential, RetryError
import time

from openai import OpenAI
from groq import Groq


logger = logging.getLogger("tinytroupe")

# We'll use various configuration elements below
@@ -129,7 +137,11 @@ def call(self, **rendering_configs):
#
# call the LLM model
#
self.model_output = client().send_message(self.messages, **self.model_params)
self.model_output = client().send_message(
self.messages,
temperature=self.model_params.get("temperature", 0.2),
max_tokens=self.model_params.get("max_tokens", 128)
)

if 'content' in self.model_output:
self.response_raw = self.response_value = self.model_output['content']
@@ -630,6 +642,92 @@ def _setup_from_config(self):
api_version = config["OpenAI"]["AZURE_API_VERSION"],
api_key = os.getenv("AZURE_OPENAI_KEY"))


# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
from groq import Groq

class GroqClient(OpenAIClient):
"""
A utility class for interacting with the Groq API, inheriting from OpenAIClient.
"""

def __init__(self, model_name, api_key=None, cache_api_calls=False, cache_file_name="groq_cache.pickle"):
"""
Initializes the Groq client.

Args:
model_name (str): The Groq model to use (e.g., "qwen-2.5-32b").
api_key (str, optional): API key for Groq authentication.
cache_api_calls (bool): Whether to cache API calls.
cache_file_name (str): File name for caching responses.
"""
super().__init__(cache_api_calls=cache_api_calls, cache_file_name=cache_file_name)
self.model_name = model_name
self.client = Groq(api_key=api_key) # Initialize Groq client with API key

def _raw_model_call(self, model, chat_api_params):
"""
Calls the Groq API with the provided parameters, ignoring the passed model parameter.

Args:
model (str): Ignored; uses self.model_name instead.
chat_api_params (dict): Parameters for the API call.

Returns:
The raw completion object from Groq.

Raises:
NonTerminalError: For retryable errors like rate limits.
InvalidRequestError: For non-retryable errors.
"""
# Ignore the passed model parameter and use self.model_name
messages = chat_api_params["messages"]
temperature = chat_api_params.get("temperature", 1.0)
max_tokens = chat_api_params.get("max_tokens", 1024)
top_p = chat_api_params.get("top_p", 1.0)
stop = chat_api_params.get("stop", None)

try:
completion = self.client.chat.completions.create(
model=self.model_name, # Always use initialized model_name
messages=messages,
temperature=temperature,
max_completion_tokens=max_tokens,
top_p=top_p,
stream=True, # Groq uses streaming
stop=stop
)
return completion
except Exception as e:
error_msg = str(e).lower()
if "rate limit" in error_msg:
raise NonTerminalError(f"Rate limit error: {e}")
else:
raise InvalidRequestError(f"Error calling Groq API: {e}")

def _raw_model_response_extractor(self, response):
"""
Extracts the response content from the Groq API's streaming completion.

Args:
response: The raw completion object from Groq.

Returns:
dict: A dictionary with role and content, e.g., {"role": "assistant", "content": "..."}.
"""
response_text = ""
for chunk in response:
delta = chunk.choices[0].delta.content or ""
response_text += delta
return {"role": "assistant", "content": response_text}

def _setup_from_config(self):
"""
Overrides the base method. Not needed for Groq as setup is handled in __init__.
"""
pass # API key and client are set in __init__
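As a quick sanity check, the class can be exercised on its own. A minimal sketch, assuming the llama3-70b-8192 model from config.ini, an invented prompt, and a GROQ_API_KEY environment variable:

import os

client = GroqClient(model_name="llama3-70b-8192",
                    api_key=os.environ["GROQ_API_KEY"])
completion = client._raw_model_call(
    model=None,  # ignored; the client always uses self.model_name
    chat_api_params={
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "temperature": 1.0,
        "max_tokens": 64,
    },
)
# The raw call returns a stream; the extractor joins the chunks into one message.
print(client._raw_model_response_extractor(completion)["content"])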


###########################################################################
# Exceptions
@@ -721,6 +819,14 @@ def force_api_cache(cache_api_calls, cache_file_name=default["cache_file_name"])
# default client
register_client("openai", OpenAIClient())
register_client("azure", AzureClient())



# register_client("avalai", AvalAiClient(
# base_url=config["AvalAi"]["BASE_URL"],
# api_key=config["AvalAi"]["API_KEY"],
# model_name=config["AvalAi"]["MODEL"]
# ))
register_client("groq", GroqClient(
model_name=config["Groq"]["MODEL"],
api_key=config["Groq"]["API_KEY"],
cache_api_calls=config["OpenAI"].getboolean("CACHE_API_CALLS", False),
cache_file_name=config["OpenAI"].get("CACHE_FILE_NAME", "openai_api_cache.pickle")
))
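With the client registered, the rest of the codebase picks it up through the API_TYPE=groq switch set in config.ini above. A minimal sketch of the resulting call path (the prompt is invented, and this assumes client() resolves the registered client for the configured API_TYPE, as elsewhere in openai_utils):

from tinytroupe import openai_utils

# client() returns the GroqClient registered above because API_TYPE=groq,
# so existing call sites such as TinyPerson._produce_message need no changes.
message = openai_utils.client().send_message(
    [{"role": "user", "content": "Introduce yourself briefly."}],
    temperature=1.0,
    max_tokens=256,
)
print(message["content"])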