diff --git a/src/ember/examples/__pycache__/__init__.cpython-312.pyc b/src/ember/examples/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..162f6ed7 Binary files /dev/null and b/src/ember/examples/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/ember/examples/multi_agent_systems/MULTI_AGENT_SYSTEMS_README.md b/src/ember/examples/multi_agent_systems/MULTI_AGENT_SYSTEMS_README.md new file mode 100644 index 00000000..00931d8c --- /dev/null +++ b/src/ember/examples/multi_agent_systems/MULTI_AGENT_SYSTEMS_README.md @@ -0,0 +1,117 @@ +# Multi-Agent Systems with Ember + +This repository contains implementations of various multi-agent systems using the Ember framework. Each system demonstrates how specialized agents can work together to accomplish complex tasks, with dedicated styling and formatting agents to ensure high-quality output. + +## Overview + +All systems use OpenAI models exclusively, with different temperature settings to specialize each agent for its specific role. The systems demonstrate how multiple agents can collaborate on complex tasks, with each agent focusing on a specific aspect of the task. + +## Multi-Agent Systems + +### 1. Content Creation Studio + +A multi-agent system for creating high-quality content, from planning to final formatting. + +**Agents:** +- **Content Planner**: Creates detailed content plans +- **Researcher**: Conducts comprehensive research on topics +- **Content Creator**: Generates initial drafts based on research +- **Editor**: Improves clarity, coherence, and impact +- **Stylist**: Ensures adherence to style guidelines +- **Formatter**: Optimizes formatting for different platforms + +**File:** `content_creation_studio.py` + +### 2. Code Development Pipeline + +A multi-agent system for developing software components, from requirements analysis to documentation. 
+ +**Agents:** +- **Requirements Analyzer**: Formalizes project requirements +- **Architect**: Designs system architecture +- **Code Generator**: Writes code based on specifications +- **Code Stylist**: Ensures code follows style guidelines +- **Code Formatter**: Handles proper formatting and organization +- **Code Reviewer**: Reviews code for quality and security +- **Test Writer**: Creates comprehensive tests +- **Documentation Writer**: Creates technical documentation + +**File:** `code_development_pipeline.py` + +### 3. Educational Content Generator + +A multi-agent system for creating educational materials tailored to different learning styles and needs. + +**Agents:** +- **Curriculum Designer**: Designs comprehensive curricula +- **Subject Expert**: Creates expert-level content +- **Content Creator**: Develops engaging learning materials +- **Assessment Designer**: Creates comprehensive assessments +- **Educational Stylist**: Applies pedagogical approaches +- **Content Formatter**: Formats content for delivery platforms +- **Accessibility Specialist**: Ensures materials are accessible + +**File:** `educational_content_generator.py` + +### 4. Marketing Campaign Generator + +A multi-agent system for creating comprehensive marketing campaigns. 
+ +**Agents:** +- **Market Researcher**: Analyzes target audience and competitors +- **Campaign Strategist**: Develops marketing strategies +- **Content Creator**: Creates campaign content +- **Copywriter**: Writes marketing copy +- **Brand Stylist**: Ensures adherence to brand guidelines +- **Channel Formatter**: Optimizes content for different channels +- **Campaign Analyzer**: Analyzes campaign effectiveness + +**File:** `marketing_campaign_generator.py` + +## Running the Systems + +You can run all systems or individual systems using the provided script: + +```bash +# Run all systems +python run_multi_agent_systems.py all + +# Run individual systems +python run_multi_agent_systems.py content +python run_multi_agent_systems.py code +python run_multi_agent_systems.py education +python run_multi_agent_systems.py marketing +``` + +## Requirements + +- Ember framework +- OpenAI API key (set as environment variable `OPENAI_API_KEY`) + +## Implementation Details + +Each multi-agent system follows a similar pattern: + +1. **Specialized Agents**: Each agent is implemented as an instance of a model with specific temperature settings to optimize for its role. +2. **Sequential Processing**: Agents work in sequence, with each agent building on the output of previous agents. +3. **Styling and Formatting**: Dedicated styling and formatting agents ensure the output adheres to guidelines and is properly formatted. +4. **Quality Control**: Review and analysis agents ensure the quality of the final output. 
+ +## Example Outputs + +Each system produces high-quality outputs in its domain: + +- **Content Creation Studio**: Well-structured, styled, and formatted blog posts +- **Code Development Pipeline**: Clean, well-documented code with comprehensive tests +- **Educational Content Generator**: Engaging, accessible educational materials +- **Marketing Campaign Generator**: Effective marketing content optimized for different channels + +## Future Enhancements + +Potential enhancements for these multi-agent systems include: + +1. **Parallel Processing**: Allow agents to work in parallel when possible +2. **Feedback Loops**: Implement feedback loops between agents +3. **User Interaction**: Add capabilities for user feedback during the process +4. **Additional Specializations**: Add more specialized agents for specific tasks +5. **Cross-System Integration**: Enable different systems to work together diff --git a/src/ember/examples/multi_agent_systems/code_development_pipeline.py b/src/ember/examples/multi_agent_systems/code_development_pipeline.py new file mode 100644 index 00000000..00a048b0 --- /dev/null +++ b/src/ember/examples/multi_agent_systems/code_development_pipeline.py @@ -0,0 +1,381 @@ +"""Multi-Agent Code Development Pipeline. + +This module demonstrates a multi-agent system for developing software components. +""" + +import logging +from typing import Any, Dict, List, Optional, Tuple + +from ember.api import models + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class CodeDevelopmentPipeline: + """A multi-agent system for collaborative code development. 
class CodeDevelopmentPipeline:
    """A multi-agent system for collaborative code development.

    This pipeline uses specialized agents to handle different aspects of code development:
    - Requirements Analyst
    - System Architect
    - Code Generator
    - Style Guide Expert
    - Code Formatter
    - Code Reviewer
    - Test Writer
    - Documentation Writer

    Each ``*_agent`` attribute is a callable model; every pipeline stage sends
    one prompt to one agent and wraps the raw response in a small dict.
    """

    # Class-level logger so the pipeline does not depend on a module global.
    # logging.getLogger(__name__) returns the same logger instance as the
    # module-level ``logger``, so behavior is unchanged.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        """Initialize the specialized agents with different model configurations."""
        # All agents using OpenAI models with different specializations.
        # Precision-oriented roles (style, formatting, review) get the lowest
        # temperatures; design and documentation roles get slightly higher ones.
        self.requirements_agent = models.model("openai:gpt-4-turbo", temperature=0.2)
        self.architect_agent = models.model("openai:gpt-4-turbo", temperature=0.3)
        self.code_generator = models.model("openai:gpt-4-turbo", temperature=0.2)
        self.style_guide_agent = models.model("openai:gpt-4-turbo", temperature=0.1)
        self.code_formatter = models.model("openai:gpt-4-turbo", temperature=0.1)
        self.code_reviewer = models.model("openai:gpt-4-turbo", temperature=0.1)
        self.test_writer = models.model("openai:gpt-4-turbo", temperature=0.2)
        self.documentation_writer = models.model("openai:gpt-4-turbo", temperature=0.3)

    def analyze_requirements(self, project_description: str) -> Dict[str, Any]:
        """Analyze project requirements from a project description.

        Args:
            project_description: A description of the project

        Returns:
            Dictionary with the agent's response under "requirements_analysis"
        """
        self._logger.info("Analyzing requirements...")
        prompt = f"""Analyze the following project description and extract key requirements:

Project Description:
{project_description}

Your task:
1. Identify all functional requirements
2. Identify all non-functional requirements
3. Identify key user stories
4. List potential technical constraints
5. Highlight any ambiguities that need clarification

Format your response as JSON with these categories.
"""
        result = self.requirements_agent(prompt)
        return {"requirements_analysis": result}

    def design_architecture(self, requirements: Dict[str, Any]) -> Dict[str, Any]:
        """Design a system architecture based on the analyzed requirements.

        Args:
            requirements: Dict containing analyzed requirements

        Returns:
            Dictionary with the agent's response under "architecture_design"
        """
        self._logger.info("Designing architecture...")
        prompt = f"""Based on the following requirements analysis, design a high-level architecture:

Requirements Analysis:
{requirements["requirements_analysis"]}

Your task:
1. Identify key components/modules
2. Define interfaces between components
3. Specify data models/structures
4. Recommend technologies/frameworks
5. Create a diagram description (ASCII or text-based)

Format your response as JSON with these categories.
"""
        result = self.architect_agent(prompt)
        return {"architecture_design": result}

    def generate_code(self, requirements: Dict[str, Any], architecture: Dict[str, Any], component_name: str) -> Dict[str, Any]:
        """Generate code for a specific component based on requirements and architecture.

        Args:
            requirements: Dict containing analyzed requirements
            architecture: Dict containing architectural design
            component_name: Name of the component to generate

        Returns:
            Dictionary with the generated code under "initial_code"
        """
        self._logger.info(f"Generating code for {component_name}...")
        prompt = f"""Generate code for the {component_name} component based on:

Requirements Analysis:
{requirements["requirements_analysis"]}

Architecture Design:
{architecture["architecture_design"]}

Your task:
1. Write production-quality code for the {component_name} component
2. Include necessary imports
3. Add brief inline comments for complex logic
4. Implement all required methods/functions
5. Handle basic error cases
6. Focus on readability and maintainability

Format your response as valid code with no additional text.
"""
        result = self.code_generator(prompt)
        return {"initial_code": result}

    def apply_code_style(self, code: Dict[str, Any], language: str, style_guide: str) -> Dict[str, Any]:
        """Apply code style guidelines to the generated code.

        Args:
            code: Dict containing initial code
            language: Programming language of the code
            style_guide: Style guide to follow (e.g., PEP 8 for Python)

        Returns:
            Dictionary with styled code under "styled_code"
        """
        self._logger.info(f"Applying {style_guide} style to code...")
        prompt = f"""Apply the {style_guide} style guide to the following {language} code:

```{language}
{code["initial_code"]}
```

Your task:
1. Follow {style_guide} conventions
2. Fix any style issues (naming, spacing, indentation, etc.)
3. Improve code organization if needed
4. Add or improve docstrings/comments as needed
5. Don't change the functionality

Return only the updated code with no additional text.
"""
        result = self.style_guide_agent(prompt)
        return {"styled_code": result}

    def format_code(self, styled_code: Dict[str, Any], language: str) -> Dict[str, Any]:
        """Format the code for consistency and readability.

        Args:
            styled_code: Dict containing styled code
            language: Programming language of the code

        Returns:
            Dictionary with formatted code under "formatted_code"
        """
        self._logger.info("Formatting code for consistency...")
        prompt = f"""Format the following {language} code for optimal readability and consistency:

```{language}
{styled_code["styled_code"]}
```

Your task:
1. Apply consistent indentation
2. Apply consistent line spacing between methods/functions
3. Apply consistent spacing around operators
4. Apply consistent braces/block style
5. Ensure line length follows best practices
6. Keep the existing functionality intact

Return only the formatted code with no additional text.
"""
        result = self.code_formatter(prompt)
        return {"formatted_code": result}

    def review_code(self, formatted_code: Dict[str, Any], requirements: Dict[str, Any], language: str) -> Dict[str, Any]:
        """Review the code for issues and suggest improvements.

        Args:
            formatted_code: Dict containing formatted code
            requirements: Dict containing analyzed requirements
            language: Programming language of the code

        Returns:
            Dictionary with the reviewer's response under "code_review".
            The prompt asks for JSON with "review_comments" and
            "improved_code" fields; see _extract_improved_code.
        """
        self._logger.info("Reviewing code...")
        prompt = f"""Review the following {language} code against the requirements and best practices:

```{language}
{formatted_code["formatted_code"]}
```

Requirements:
{requirements["requirements_analysis"]}

Your task:
1. Identify any bugs or logical errors
2. Check for security vulnerabilities
3. Evaluate performance issues
4. Verify the code meets requirements
5. Suggest specific improvements
6. Provide an improved version of the code

Format your response as JSON with "review_comments" and "improved_code" fields.
"""
        result = self.code_reviewer(prompt)
        return {"code_review": result}

    def write_tests(self, formatted_code: Dict[str, Any], requirements: Dict[str, Any], language: str) -> Dict[str, Any]:
        """Write tests for the code.

        Args:
            formatted_code: Dict containing formatted code
            requirements: Dict containing analyzed requirements
            language: Programming language of the code

        Returns:
            Dictionary with test code under "tests"
        """
        self._logger.info("Writing tests...")
        prompt = f"""Write comprehensive tests for the following {language} code:

```{language}
{formatted_code["formatted_code"]}
```

Requirements:
{requirements["requirements_analysis"]}

Your task:
1. Write unit tests covering all functions/methods
2. Include both positive and negative test cases
3. Test edge cases
4. Ensure test code follows {language} best practices
5. Include brief comments explaining the purpose of each test
6. Use appropriate testing framework for {language}

Return only the test code with no additional text.
"""
        result = self.test_writer(prompt)
        return {"tests": result}

    def write_documentation(self, formatted_code: Dict[str, Any], architecture: Dict[str, Any], language: str) -> Dict[str, Any]:
        """Write documentation for the code.

        Args:
            formatted_code: Dict containing formatted code
            architecture: Dict containing architectural design
            language: Programming language of the code

        Returns:
            Dictionary with documentation under "documentation"
        """
        self._logger.info("Writing documentation...")
        prompt = f"""Write comprehensive documentation for the following {language} code:

```{language}
{formatted_code["formatted_code"]}
```

Architecture Context:
{architecture["architecture_design"]}

Your task:
1. Create a README with overview and usage examples
2. Document the component's purpose and integration with other components
3. Document all public APIs, functions, and classes
4. Include installation and configuration instructions
5. Document any dependencies
6. Provide troubleshooting information

Return the documentation in Markdown format.
"""
        result = self.documentation_writer(prompt)
        return {"documentation": result}

    @staticmethod
    def _extract_improved_code(review_text: Any) -> Optional[str]:
        """Extract the "improved_code" field from the reviewer's JSON response.

        The review prompt asks the model for JSON with "review_comments" and
        "improved_code" fields, optionally wrapped in a markdown code fence.

        Args:
            review_text: The raw reviewer response (coerced to str).

        Returns:
            The improved code string, or None when the response is not
            parseable JSON or lacks a non-empty "improved_code" string.
        """
        import json  # local import: only needed for review parsing

        text = str(review_text).strip()
        # Models frequently wrap JSON in ``` / ```json fences; strip them.
        if text.startswith("```"):
            lines = text.splitlines()
            if lines and lines[0].startswith("```"):
                lines = lines[1:]
            if lines and lines[-1].strip() == "```":
                lines = lines[:-1]
            text = "\n".join(lines)
        try:
            parsed = json.loads(text)
        except (ValueError, TypeError):
            return None
        if isinstance(parsed, dict):
            improved = parsed.get("improved_code")
            if isinstance(improved, str) and improved.strip():
                return improved
        return None

    def develop_component(self, project_description: str, component_name: str, language: str, style_guide: str) -> Dict[str, Any]:
        """Execute the full code development pipeline.

        Args:
            project_description: A description of the project
            component_name: Name of the component to develop
            language: Programming language to use
            style_guide: Style guide to follow

        Returns:
            Dictionary with all artifacts from the development process.
            result["code"]["final"] holds the reviewer's improved code when it
            can be parsed from the review, otherwise the formatted code.
        """
        # Analyze requirements
        requirements = self.analyze_requirements(project_description)

        # Design architecture
        architecture = self.design_architecture(requirements)

        # Generate initial code
        code = self.generate_code(requirements, architecture, component_name)

        # Apply style guide
        styled_code = self.apply_code_style(code, language, style_guide)

        # Format code
        formatted_code = self.format_code(styled_code, language)

        # Review code and get improvements
        code_review = self.review_code(formatted_code, requirements, language)

        # Write tests
        tests = self.write_tests(formatted_code, requirements, language)

        # Write documentation
        documentation = self.write_documentation(formatted_code, architecture, language)

        # Bug fix: "final" previously held the reviewer's raw JSON response
        # (the same value already returned under "review"), not actual code.
        # Parse the improved code out of the review; fall back to the
        # formatted code when the reviewer's response is not parseable JSON.
        improved_code = self._extract_improved_code(code_review["code_review"])
        final_code = improved_code if improved_code is not None else formatted_code["formatted_code"]

        # Return all artifacts
        return {
            "requirements": requirements,
            "architecture": architecture,
            "code": {
                "initial": code["initial_code"],
                "styled": styled_code["styled_code"],
                "formatted": formatted_code["formatted_code"],
                "final": final_code
            },
            "review": code_review["code_review"],
            "tests": tests["tests"],
            "documentation": documentation["documentation"]
        }
Allows users to save favorite locations + 4. Sends notifications for severe weather alerts + 5. Works on both mobile and desktop browsers + """ + + component_name = "WeatherDataService" + language = "python" + style_guide = "PEP 8" + + # Develop the component + print("Developing the WeatherDataService component...") + result = pipeline.develop_component( + project_description, + component_name, + language, + style_guide + ) + + # Print the final code + print("\n\n=== FINAL CODE ===\n\n") + print(result["code"]["final"]) + + # Print the tests + print("\n\n=== TESTS ===\n\n") + print(result["tests"]) diff --git a/src/ember/examples/multi_agent_systems/content_creation_studio.py b/src/ember/examples/multi_agent_systems/content_creation_studio.py new file mode 100644 index 00000000..51c6ad38 --- /dev/null +++ b/src/ember/examples/multi_agent_systems/content_creation_studio.py @@ -0,0 +1,305 @@ +"""Multi-Agent Content Creation Studio. + +This module demonstrates a multi-agent system for creating content. +""" + +import logging +from typing import Any, Dict, List, Optional + +from ember.api import models + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class ContentCreationStudio: + """A multi-agent system for collaborative content creation. 
class ContentCreationStudio:
    """A multi-agent system for collaborative content creation.

    Six specialized agents cooperate on a single piece of content, each owning
    one stage of the pipeline:
    - Content Planner
    - Researcher
    - Content Drafter
    - Editor
    - Style Specialist
    - Formatter

    Every public method drives exactly one agent and wraps its raw response in
    a one-key dict; create_complete_content chains all six stages.
    """

    def __init__(self):
        """Initialize the specialized agents with different model configurations."""
        # Every agent runs on the same OpenAI model; temperature alone
        # specializes each one for its role in the pipeline.
        def make_agent(temperature):
            return models.model("openai:gpt-4-turbo", temperature=temperature)

        self.content_planner = make_agent(0.7)   # More creative for ideation
        self.researcher = make_agent(0.2)        # More factual for research
        self.content_drafter = make_agent(0.5)   # Balanced for drafting
        self.editor = make_agent(0.3)            # More critical for editing
        self.style_specialist = make_agent(0.4)  # Creative but controlled for style
        self.formatter = make_agent(0.1)         # More precise for formatting

    def create_content_plan(self, topic: str, target_audience: str, content_type: str) -> Dict[str, Any]:
        """Ask the planner agent for a content plan.

        Args:
            topic: The main topic of the content
            target_audience: Description of the target audience
            content_type: Type of content (e.g., blog post, social media, whitepaper)

        Returns:
            Dictionary with the planner's response under "content_plan"
        """
        logger.info(f"Creating content plan for {content_type} on {topic}...")
        planning_request = f"""Create a detailed content plan for a {content_type} about {topic} aimed at {target_audience}.

Your plan should include:
1. A compelling headline/title
2. Content objectives and key messages
3. Detailed outline with section headings
4. Key points to cover in each section
5. Recommended tone and style
6. Suggested content length

Format your response as JSON with these categories.
"""
        return {"content_plan": self.content_planner(planning_request)}

    def conduct_research(self, topic: str, plan: Dict[str, Any]) -> Dict[str, Any]:
        """Ask the researcher agent for findings supporting the plan.

        Args:
            topic: The main topic to research
            plan: Dictionary containing the content plan

        Returns:
            Dictionary with the researcher's response under "research"
        """
        logger.info(f"Conducting research on {topic}...")
        plan_text = plan["content_plan"]
        research_request = f"""Conduct thorough research on {topic} to support this content plan:

{plan_text}

Your research should:
1. Identify key facts, statistics, and data points
2. Find relevant examples and case studies
3. Identify expert opinions and quotations (with attributions)
4. Uncover common questions and misconceptions
5. Find trending topics related to {topic}

Format your response as JSON with these categories and include sources where applicable.
"""
        return {"research": self.researcher(research_request)}

    def generate_draft(self, plan: Dict[str, Any], research: Dict[str, Any]) -> Dict[str, Any]:
        """Ask the drafter agent for a first draft built from plan and research.

        Args:
            plan: Dictionary containing the content plan
            research: Dictionary containing research findings

        Returns:
            Dictionary with the drafter's response under "draft"
        """
        logger.info("Generating content draft...")
        plan_text = plan["content_plan"]
        research_text = research["research"]
        drafting_request = f"""Create a comprehensive first draft based on this content plan and research:

Content Plan:
{plan_text}

Research:
{research_text}

Write a complete draft that:
1. Follows the structure in the content plan
2. Incorporates key facts and insights from the research
3. Uses a conversational but informative tone
4. Includes an engaging introduction and conclusion
5. Weaves in examples and data points naturally

Return only the draft content with no additional commentary.
"""
        return {"draft": self.content_drafter(drafting_request)}

    def edit_content(self, draft: Dict[str, Any]) -> Dict[str, Any]:
        """Ask the editor agent to tighten the draft.

        Args:
            draft: Dictionary containing the draft content

        Returns:
            Dictionary with the editor's response under "edited_content"
        """
        logger.info("Editing content...")
        draft_text = draft["draft"]
        editing_request = f"""Edit this content draft for clarity, flow, and impact:

{draft_text}

Improve the content by:
1. Enhancing clarity and readability
2. Improving paragraph and sentence structure
3. Eliminating redundancies and filler content
4. Strengthening transitions between sections
5. Ensuring logical flow of ideas
6. Maintaining a consistent voice throughout

Return only the edited content with no additional commentary.
"""
        return {"edited_content": self.editor(editing_request)}

    def apply_style(self, edited_content: Dict[str, Any], style_guide: str) -> Dict[str, Any]:
        """Ask the style specialist to apply a style guide to the edited text.

        Args:
            edited_content: Dictionary containing the edited content
            style_guide: Style guide to apply

        Returns:
            Dictionary with the stylist's response under "styled_content"
        """
        logger.info("Applying style guidelines...")
        edited_text = edited_content["edited_content"]
        styling_request = f"""Apply the following style guide to this edited content:

Style Guide:
{style_guide}

Content:
{edited_text}

Apply the style by:
1. Adjusting the tone and voice to match the style guide
2. Using appropriate terminology and phrasing
3. Applying brand language elements where relevant
4. Ensuring sentence structure aligns with the style
5. Maintaining consistency with the specified style throughout

Return only the styled content with no additional commentary.
"""
        return {"styled_content": self.style_specialist(styling_request)}

    def format_content(self, styled_content: Dict[str, Any], format_requirements: str, platform: str) -> Dict[str, Any]:
        """Ask the formatter agent to lay the styled content out for a platform.

        Args:
            styled_content: Dictionary containing the styled content
            format_requirements: Formatting requirements specification
            platform: Target platform (e.g., WordPress, Medium, LinkedIn)

        Returns:
            Dictionary with the formatter's response under "final_formatted"
        """
        logger.info(f"Formatting content for {platform}...")
        styled_text = styled_content["styled_content"]
        formatting_request = f"""Format this content according to these requirements for {platform}:

Format Requirements:
{format_requirements}

Content:
{styled_text}

Apply formatting by:
1. Structuring the content with proper headings and subheadings
2. Adding appropriate formatting tags/markdown for {platform}
3. Breaking up text with lists, callouts, and spacing as needed
4. Including any platform-specific elements
5. Ensuring the format enhances readability and engagement

Return the formatted content in a format appropriate for {platform}.
"""
        return {"final_formatted": self.formatter(formatting_request)}

    def create_complete_content(self, topic: str, target_audience: str, content_type: str, style_guide: str, format_requirements: str, platform: str) -> Dict[str, Any]:
        """Run every stage of the content pipeline in order.

        Args:
            topic: The main topic of the content
            target_audience: Description of the target audience
            content_type: Type of content (e.g., blog post, social media, whitepaper)
            style_guide: Style guide to apply
            format_requirements: Formatting requirements specification
            platform: Target platform (e.g., WordPress, Medium, LinkedIn)

        Returns:
            Dictionary with all artifacts from the content creation process
        """
        # Stage 1: plan, then research against the plan.
        content_plan = self.create_content_plan(topic, target_audience, content_type)
        findings = self.conduct_research(topic, content_plan)

        # Stage 2: draft from plan + research, then edit the draft.
        first_draft = self.generate_draft(content_plan, findings)
        polished = self.edit_content(first_draft)

        # Stage 3: apply the style guide, then format for the platform.
        styled = self.apply_style(polished, style_guide)
        finished = self.format_content(styled, format_requirements, platform)

        # Expose every intermediate artifact alongside the final output.
        return {
            "plan": content_plan,
            "research": findings,
            "draft": first_draft["draft"],
            "edited_content": polished["edited_content"],
            "styled_content": styled["styled_content"],
            "final_formatted": finished["final_formatted"]
        }
complex structures + Brand elements: Emphasize sustainability, community, and practical solutions + """ + + format_requirements = """ + - 1500-2000 words + - H1 main title + - H2 for main sections + - H3 for subsections + - Short paragraphs (3-4 sentences max) + - Include bulleted lists where appropriate + - Bold key points and important terms + - Include a "Quick Tips" section + - End with a call-to-action + """ + + platform = "WordPress blog" + + # Generate the content + print("Generating complete content for 'Sustainable Urban Gardening'...") + content = studio.create_complete_content( + topic, + target_audience, + content_type, + style_guide, + format_requirements, + platform + ) + + # Print the final formatted content + print("\n\n=== FINAL FORMATTED CONTENT ===\n\n") + print(content["final_formatted"]) diff --git a/src/ember/examples/multi_agent_systems/educational_content_generator.py b/src/ember/examples/multi_agent_systems/educational_content_generator.py new file mode 100644 index 00000000..9a1f459f --- /dev/null +++ b/src/ember/examples/multi_agent_systems/educational_content_generator.py @@ -0,0 +1,252 @@ +"""Educational Content Generator. + +This module demonstrates a multi-agent system for creating educational content. 
+""" + +import logging +from typing import Any, Dict, List, Optional + +from ember.api import models + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class EducationalContentGenerator: + def __init__(self): + # All agents using OpenAI models with different specializations + self.curriculum_designer = models.model("openai:gpt-4o", temperature=0.4) + self.subject_expert = models.model("openai:gpt-4o", temperature=0.3) + self.content_creator = models.model("openai:gpt-4o", temperature=0.6) + self.assessment_designer = models.model("openai:gpt-4o", temperature=0.4) + self.educational_stylist = models.model("openai:gpt-4o", temperature=0.3) # Specialized for educational styling + self.content_formatter = models.model("openai:gpt-4o", temperature=0.2) # Specialized for formatting + self.accessibility_specialist = models.model("openai:gpt-4o", temperature=0.3) + + def design_curriculum(self, subject, grade_level, learning_objectives): + curriculum = self.curriculum_designer( + f"""Design a comprehensive curriculum for {subject} at {grade_level} grade level. 
+ + Learning Objectives: + {learning_objectives} + + Include: + - Unit structure and sequence + - Key concepts for each unit + - Skill progression throughout the curriculum + - Suggested timeframes + - Prerequisites and connections to other subjects + """ + ) + return curriculum + + def create_expert_content(self, curriculum, unit_name): + expert_content = self.subject_expert( + f"""Create expert-level content for the {unit_name} unit in this curriculum: + + {curriculum} + + Provide: + - Comprehensive explanation of all concepts + - Accurate and up-to-date information + - Common misconceptions and clarifications + - Advanced examples and applications + - Historical context and future directions + """ + ) + return expert_content + + def create_learning_materials(self, expert_content, grade_level, learning_styles): + learning_materials = self.content_creator( + f"""Create engaging learning materials based on this expert content for {grade_level} grade level: + + Expert Content: + {expert_content} + + Learning Styles to Address: + {learning_styles} + + Create: + - Lesson plans with clear objectives + - Engaging explanations and examples + - Interactive activities and exercises + - Visual aids and diagrams + - Real-world applications and scenarios + """ + ) + return learning_materials + + def create_assessments(self, learning_materials, learning_objectives): + assessments = self.assessment_designer( + f"""Create comprehensive assessments for these learning materials: + + Learning Materials: + {learning_materials} + + Learning Objectives: + {learning_objectives} + + Include: + - Formative assessments for ongoing feedback + - Summative assessments for unit completion + - Various question types (multiple choice, short answer, essay, etc.) 
+ - Performance tasks and projects + - Rubrics and scoring guidelines + """ + ) + return assessments + + def apply_educational_style(self, learning_materials, pedagogical_approach, grade_level): + styled_materials = self.educational_stylist( + f"""Apply the {pedagogical_approach} pedagogical approach to these learning materials for {grade_level} grade level: + + Learning Materials: + {learning_materials} + + Ensure the materials: + - Use age-appropriate language and examples + - Follow best practices for the specified pedagogical approach + - Maintain consistent terminology and presentation + - Use effective instructional techniques + - Incorporate appropriate scaffolding + """ + ) + return styled_materials + + def format_educational_content(self, styled_materials, delivery_format): + formatted_materials = self.content_formatter( + f"""Format these educational materials for {delivery_format}: + + Materials: + {styled_materials} + + Apply formatting appropriate for {delivery_format}, including: + - Proper headings and structure + - Visual organization of information + - Consistent layout and design + - Appropriate use of space and breaks + - Format-specific elements and features + """ + ) + return formatted_materials + + def ensure_accessibility(self, formatted_materials, accessibility_requirements): + accessible_materials = self.accessibility_specialist( + f"""Enhance these educational materials to meet these accessibility requirements: + + Materials: + {formatted_materials} + + Accessibility Requirements: + {accessibility_requirements} + + Ensure the materials: + - Are accessible to students with various disabilities + - Include alternative representations of visual content + - Use accessible language and structure + - Can be navigated with assistive technologies + - Meet specified accessibility standards + """ + ) + return accessible_materials + + def generate_educational_unit(self, subject, unit_name, grade_level, learning_objectives, + learning_styles, 
if __name__ == "__main__":
    # Demo driver: generate one complete educational unit and print the
    # final, accessibility-enhanced materials.
    import os

    # The generator's agents read the OpenAI key from the environment; warn
    # up front so a failed model call later is explainable.
    if os.environ.get("OPENAI_API_KEY") is None:
        print("Warning: OPENAI_API_KEY environment variable not set.")
        print("Please set your API key using: export OPENAI_API_KEY='your-key-here'")

    generator = EducationalContentGenerator()

    # Parameters for the sample unit.
    subject = "Environmental Science"
    unit_name = "Ecosystems and Biodiversity"
    grade_level = "8th"

    learning_objectives = """
    1. Understand the components and interactions within ecosystems
    2. Explain the importance of biodiversity for ecosystem health
    3. Identify human impacts on ecosystems and biodiversity
    4. Analyze solutions for preserving biodiversity
    5. Design a plan to protect a local ecosystem
    """

    learning_styles = """
    - Visual learners: diagrams, charts, videos
    - Auditory learners: discussions, audio explanations
    - Kinesthetic learners: hands-on activities, experiments
    - Reading/writing learners: text-based materials, note-taking activities
    """

    pedagogical_approach = "inquiry-based learning"
    delivery_format = "digital learning management system"

    accessibility_requirements = """
    - Screen reader compatibility
    - Alternative text for images
    - Transcripts for audio content
    - Color contrast for visually impaired students
    - Multiple means of engagement and expression
    """

    print("Generating educational unit for 'Ecosystems and Biodiversity'...")
    generated_unit = generator.generate_educational_unit(
        subject=subject,
        unit_name=unit_name,
        grade_level=grade_level,
        learning_objectives=learning_objectives,
        learning_styles=learning_styles,
        pedagogical_approach=pedagogical_approach,
        delivery_format=delivery_format,
        accessibility_requirements=accessibility_requirements,
    )

    print("\n\n=== ACCESSIBLE LEARNING MATERIALS ===\n\n")
    print(generated_unit["learning_materials"]["accessible"])
+""" + +import logging +from typing import Any, Dict, List, Optional + +from ember.api import models + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class MarketingCampaignGenerator: + def __init__(self): + # All agents using OpenAI models with different specializations + self.market_researcher = models.model("openai:gpt-4o", temperature=0.3) + self.campaign_strategist = models.model("openai:gpt-4o", temperature=0.5) + self.content_creator = models.model("openai:gpt-4o", temperature=0.7) + self.copywriter = models.model("openai:gpt-4o", temperature=0.6) + self.brand_stylist = models.model("openai:gpt-4o", temperature=0.3) # Specialized for brand styling + self.channel_formatter = models.model("openai:gpt-4o", temperature=0.2) # Specialized for channel-specific formatting + self.campaign_analyzer = models.model("openai:gpt-4o", temperature=0.2) + + def conduct_market_research(self, product, target_audience, competitors): + research = self.market_researcher( + f"""Conduct comprehensive market research for {product} targeting {target_audience}. 
+ + Competitors to analyze: + {competitors} + + Provide: + - Target audience demographics and psychographics + - Market trends and opportunities + - Competitor analysis and positioning + - Customer pain points and needs + - Unique selling propositions for the product + """ + ) + return research + + def develop_campaign_strategy(self, research, campaign_objectives, budget_range): + strategy = self.campaign_strategist( + f"""Develop a marketing campaign strategy based on this research: + + Research: + {research} + + Campaign Objectives: + {campaign_objectives} + + Budget Range: + {budget_range} + + Include: + - Campaign theme and key messages + - Marketing channels and tactics + - Budget allocation + - Timeline and key milestones + - KPIs and success metrics + """ + ) + return strategy + + def create_campaign_content(self, strategy, brand_guidelines): + content = self.content_creator( + f"""Create marketing campaign content based on this strategy: + + Strategy: + {strategy} + + Brand Guidelines: + {brand_guidelines} + + Develop: + - Campaign tagline and slogans + - Key messaging points + - Content themes and concepts + - Visual direction recommendations + - Campaign narrative and storytelling elements + """ + ) + return content + + def write_marketing_copy(self, content, target_audience, channels): + copy = self.copywriter( + f"""Write marketing copy based on this campaign content: + + Campaign Content: + {content} + + Target Audience: + {target_audience} + + Marketing Channels: + {channels} + + Create copy for each specified channel that: + - Resonates with the target audience + - Communicates key messages effectively + - Drives desired actions + - Maintains consistent voice across channels + - Adapts to channel-specific requirements + """ + ) + return copy + + def apply_brand_style(self, copy, brand_guidelines): + styled_copy = self.brand_stylist( + f"""Apply these brand guidelines to the marketing copy: + + Marketing Copy: + {copy} + + Brand Guidelines: + 
{brand_guidelines} + + Ensure the copy: + - Adheres to brand voice and tone + - Uses approved terminology and language + - Follows brand messaging hierarchy + - Incorporates brand values and personality + - Maintains brand consistency + """ + ) + return styled_copy + + def format_for_channels(self, styled_copy, channels): + formatted_content = {} + + for channel in channels: + formatted = self.channel_formatter( + f"""Format this marketing copy for {channel}: + + Styled Copy: + {styled_copy} + + Apply formatting specific to {channel}, including: + - Character/word count limitations + - Platform-specific features and capabilities + - Best practices for engagement on this channel + - Technical requirements and constraints + - Optimal content structure for this channel + """ + ) + formatted_content[channel] = formatted + + return formatted_content + + def analyze_campaign_effectiveness(self, strategy, formatted_content): + analysis = self.campaign_analyzer( + f"""Analyze the potential effectiveness of this marketing campaign: + + Campaign Strategy: + {strategy} + + Campaign Content: + {formatted_content} + + Provide: + - Strengths and weaknesses of the campaign + - Alignment with campaign objectives + - Potential challenges and risks + - Recommendations for optimization + - Expected outcomes and impact + """ + ) + return analysis + + def generate_marketing_campaign(self, product, target_audience, competitors, campaign_objectives, + budget_range, brand_guidelines, channels): + # Conduct market research + research = self.conduct_market_research(product, target_audience, competitors) + research_text = str(research) + + # Develop campaign strategy + strategy = self.develop_campaign_strategy(research_text, campaign_objectives, budget_range) + strategy_text = str(strategy) + + # Create campaign content + content = self.create_campaign_content(strategy_text, brand_guidelines) + content_text = str(content) + + # Write marketing copy + copy = 
self.write_marketing_copy(content_text, target_audience, channels) + copy_text = str(copy) + + # Apply brand styling + styled_copy = self.apply_brand_style(copy_text, brand_guidelines) + styled_copy_text = str(styled_copy) + + # Format for different channels + formatted_content = {} + for channel in channels: + channel_content = self.channel_formatter( + f"""Format this marketing copy for {channel}: + + Styled Copy: + {styled_copy_text} + + Apply formatting specific to {channel}, including: + - Character/word count limitations + - Platform-specific features and capabilities + - Best practices for engagement on this channel + - Technical requirements and constraints + - Optimal content structure for this channel + """ + ) + formatted_content[channel] = str(channel_content) + + # Analyze campaign effectiveness + analysis = self.analyze_campaign_effectiveness(strategy_text, str(formatted_content)) + analysis_text = str(analysis) + + return { + "research": research_text, + "strategy": strategy_text, + "content": content_text, + "copy": copy_text, + "styled_copy": styled_copy_text, + "formatted_content": formatted_content, + "analysis": analysis_text + } + +if __name__ == "__main__": + # Set up environment variables for API keys + import os + if "OPENAI_API_KEY" not in os.environ: + print("Warning: OPENAI_API_KEY environment variable not set.") + print("Please set your API key using: export OPENAI_API_KEY='your-key-here'") + + # Create the generator + generator = MarketingCampaignGenerator() + + # Define parameters for marketing campaign generation + product = "EcoCharge - Solar-powered portable charger for mobile devices" + + target_audience = "Environmentally conscious millennials and Gen Z who enjoy outdoor activities" + + competitors = """ + 1. SolarJuice - Premium solar chargers with high price point + 2. PowerGreen - Budget solar chargers with lower efficiency + 3. 
NaturePower - Well-established brand with wide product range + """ + + campaign_objectives = """ + 1. Increase brand awareness by 30% among target audience + 2. Generate 10,000 website visits within the first month + 3. Achieve 2,000 product sales in the first quarter + 4. Establish EcoCharge as an eco-friendly tech leader + """ + + budget_range = "$50,000 - $75,000" + + brand_guidelines = """ + Voice: Friendly, enthusiastic, and environmentally conscious + Tone: Inspirational, educational, and slightly playful + Colors: Green (#2E8B57), Blue (#1E90FF), White (#FFFFFF) + Typography: Clean, modern sans-serif fonts + Values: Sustainability, innovation, quality, adventure + Messaging: Focus on environmental impact, convenience, and reliability + """ + + channels = ["Instagram", "TikTok", "Email Newsletter", "Google Ads", "Outdoor Retailer Partnerships"] + + # Generate the marketing campaign + print("Generating marketing campaign for 'EcoCharge'...") + campaign = generator.generate_marketing_campaign( + product, + target_audience, + competitors, + campaign_objectives, + budget_range, + brand_guidelines, + channels + ) + + # Print the formatted content for Instagram + print("\n\n=== INSTAGRAM CONTENT ===\n\n") + print(campaign["formatted_content"]["Instagram"]) diff --git a/src/ember/examples/multi_agent_systems/run_multi_agent_systems.py b/src/ember/examples/multi_agent_systems/run_multi_agent_systems.py new file mode 100644 index 00000000..f95603b3 --- /dev/null +++ b/src/ember/examples/multi_agent_systems/run_multi_agent_systems.py @@ -0,0 +1,243 @@ +import os +import sys +from content_creation_studio import ContentCreationStudio +from code_development_pipeline import CodeDevelopmentPipeline +from educational_content_generator import EducationalContentGenerator +from marketing_campaign_generator import MarketingCampaignGenerator + +def run_content_creation_studio(): + print("\n" + "="*80) + print("RUNNING CONTENT CREATION STUDIO") + print("="*80 + "\n") + + # Create 
def run_code_development_pipeline():
    """Demo driver: develop one component with the code pipeline and print
    previews of the generated code and tests.

    Returns:
        The full result dict produced by ``CodeDevelopmentPipeline.develop_component``.
    """
    print("\n" + "="*80)
    print("RUNNING CODE DEVELOPMENT PIPELINE")
    print("="*80 + "\n")

    pipeline = CodeDevelopmentPipeline()

    # Sample project and component to develop.
    project_description = """
    Create a weather forecast application that:
    1. Fetches weather data from a public API
    2. Displays current conditions and 5-day forecast
    3. Allows users to save favorite locations
    4. Sends notifications for severe weather alerts
    5. Works on both mobile and desktop browsers
    """

    component_name = "WeatherDataService"
    language = "python"
    style_guide = "PEP 8"

    print(f"Developing the {component_name} component...")
    outcome = pipeline.develop_component(
        project_description,
        component_name,
        language,
        style_guide,
    )

    # Show only the first 1000 characters of each artifact.
    print("\n--- FINAL CODE ---\n")
    print(outcome["code"]["final"][:1000] + "...\n[Code truncated for brevity]")

    print("\n--- TESTS ---\n")
    print(outcome["tests"][:1000] + "...\n[Tests truncated for brevity]")

    return outcome
def run_marketing_campaign_generator():
    """Demo driver: build a complete marketing campaign for the EcoCharge
    sample product and print the Instagram-formatted content.

    Returns:
        The full campaign dict produced by
        ``MarketingCampaignGenerator.generate_marketing_campaign``.
    """
    print("\n" + "="*80)
    print("RUNNING MARKETING CAMPAIGN GENERATOR")
    print("="*80 + "\n")

    # Create the generator
    generator = MarketingCampaignGenerator()

    # Define parameters for marketing campaign generation
    product = "EcoCharge - Solar-powered portable charger for mobile devices"

    target_audience = "Environmentally conscious millennials and Gen Z who enjoy outdoor activities"

    competitors = """
    1. SolarJuice - Premium solar chargers with high price point
    2. PowerGreen - Budget solar chargers with lower efficiency
    3. NaturePower - Well-established brand with wide product range
    """

    campaign_objectives = """
    1. Increase brand awareness by 30% among target audience
    2. Generate 10,000 website visits within the first month
    3. Achieve 2,000 product sales in the first quarter
    4. Establish EcoCharge as an eco-friendly tech leader
    """

    budget_range = "$50,000 - $75,000"

    brand_guidelines = """
    Voice: Friendly, enthusiastic, and environmentally conscious
    Tone: Inspirational, educational, and slightly playful
    Colors: Green (#2E8B57), Blue (#1E90FF), White (#FFFFFF)
    Typography: Clean, modern sans-serif fonts
    Values: Sustainability, innovation, quality, adventure
    Messaging: Focus on environmental impact, convenience, and reliability
    """

    channels = ["Instagram", "TikTok", "Email Newsletter", "Google Ads", "Outdoor Retailer Partnerships"]

    # FIX: plain string literal — the previous f-string had no placeholders
    # (ruff rule F541).
    print("Generating marketing campaign for 'EcoCharge'...")
    campaign = generator.generate_marketing_campaign(
        product,
        target_audience,
        competitors,
        campaign_objectives,
        budget_range,
        brand_guidelines,
        channels
    )

    # Print the formatted content for Instagram
    print("\n--- INSTAGRAM CONTENT ---\n")
    print(campaign["formatted_content"]["Instagram"])

    return campaign
a/src/ember/examples/nft_assistant/NFT_ASSISTANT_README.md b/src/ember/examples/nft_assistant/NFT_ASSISTANT_README.md new file mode 100644 index 00000000..e52fd828 --- /dev/null +++ b/src/ember/examples/nft_assistant/NFT_ASSISTANT_README.md @@ -0,0 +1,107 @@ +# NFT Education Assistant + +This project implements an NFT Education Assistant using Ember's model API. The assistant provides personalized explanations of NFT and poker concepts based on the user's expertise level. + +## Features + +- Uses state-of-the-art language models for accurate explanations +- Adapts explanations based on user expertise level (beginner, intermediate, expert) +- Focuses on NFT and poker-related concepts +- Supports multiple model providers (OpenAI, Anthropic, Google/Deepmind) + +## Implementation + +The assistant uses Ember's model API to generate personalized explanations based on the user's query and expertise level. + +Supported models include: +- OpenAI: GPT-4o, GPT-4o-mini, GPT-4, GPT-4-turbo, GPT-3.5-turbo +- Anthropic: Claude 3 Opus, Claude 3.5 Sonnet, Claude 3.5 Haiku, Claude 3.7 Sonnet +- Deepmind: Gemini 1.5 Pro, Gemini 1.5 Flash, Gemini 2.0 Pro + +## Setup + +Before using the NFT Education Assistant, you need to set up your API keys for the providers you want to use: + +1. Get API keys from the providers you want to use: + - [OpenAI](https://platform.openai.com/) + - [Anthropic](https://www.anthropic.com/) + - [Google AI](https://ai.google.dev/) + +2. 
Set them as environment variables: + +```bash +export OPENAI_API_KEY="your-openai-api-key" +export ANTHROPIC_API_KEY="your-anthropic-api-key" +export GOOGLE_API_KEY="your-google-api-key" +``` + +## Usage + +```python +from nft_education_assistant import NFTEducationAssistant + +# Create an instance of the assistant with a specific model +assistant = NFTEducationAssistant(model_name="openai:gpt-4o") + +# Or use a different model provider +# assistant = NFTEducationAssistant(model_name="deepmind:gemini-1.5-pro") +# assistant = NFTEducationAssistant(model_name="anthropic:claude-3-opus") + +# Get an explanation for a beginner +beginner_explanation = assistant.explain_concept( + query="What is a non-fungible token?", + user_expertise_level="beginner" +) + +# Get an explanation for an expert +expert_explanation = assistant.explain_concept( + query="What are the implications of ERC-721 vs ERC-1155 for poker-based NFTs?", + user_expertise_level="expert" +) +``` + +## Running the Test Script + +To test the assistant with sample queries (using environment variables for API keys): + +```bash +python test_nft_assistant.py +``` + +Or use the provided script with environment variables already set: + +```bash +python run_nft_assistant.py +``` + +The script will use the Google Gemini model by default, but you can modify it to use a different model by changing the `model_name` parameter in the script. + +## Requirements + +- Ember framework +- Access to at least one of the following: + - OpenAI API key (for GPT models) + - Anthropic API key (for Claude models) + - Google AI API key (for Gemini models) + +## Example Outputs + +### Beginner Level Explanation + +``` +Imagine you have a trading card of your favorite basketball player. It's special, right? It's not exactly the same as any other card, even another card of the same player. It might have a different number, a different condition, or maybe even a unique autograph. 
You can trade it, sell it, or keep it as part of your collection. + +A non-fungible token, or NFT, is kind of like a digital version of that special trading card. "Non-fungible" just means it's unique and can't be replaced with something exactly the same. Think of a dollar bill – that's *fungible*. You can trade one dollar for another, and they're essentially identical. But your trading card, or an NFT, is one-of-a-kind. + +NFTs use blockchain technology (think of it like a super-secure digital ledger) to prove ownership and authenticity. So, even though an image or video might be copied easily online, the NFT is like a certificate of ownership saying you own the *original* digital item. +``` + +### Expert Level Explanation + +``` +A non-fungible token (NFT) is a cryptographic token representing ownership of a unique digital or physical asset recorded on a distributed ledger, typically a blockchain. It's crucial to understand the nuances beyond the basic definition, particularly at an expert level. + +Uniqueness and Indivisibility: NFTs are inherently non-fungible, meaning they are not interchangeable on a 1:1 basis like cryptocurrencies such as Bitcoin or Ethereum. This uniqueness derives from distinct metadata embedded within the token, effectively acting as a digital fingerprint. While an NFT can represent fractional ownership, the token itself remains indivisible, pointing to a specific entry on the blockchain. + +Metadata and Provenance: The core value proposition of an NFT often lies in its metadata. This data, structured usually as JSON, defines the asset the NFT represents and can include various information: creator information, creation date, a hash of the underlying asset (image, video, audio, etc.), ownership history, and even embedded smart contract functionality. 
+``` diff --git a/src/ember/examples/nft_assistant/nft_education_assistant.py b/src/ember/examples/nft_assistant/nft_education_assistant.py new file mode 100644 index 00000000..eb870bfe --- /dev/null +++ b/src/ember/examples/nft_assistant/nft_education_assistant.py @@ -0,0 +1,59 @@ +from ember.api import models + +class NFTEducationAssistant: + def __init__(self, model_name="openai:gpt-4o"): + """ + Initialize the NFT Education Assistant with a customizable model. + + Args: + model_name (str, optional): Model to use for explanations. Default is "openai:gpt-4o". + Options include: + - OpenAI: "openai:gpt-4o", "openai:gpt-4o-mini", "openai:gpt-4", "openai:gpt-4-turbo", "openai:gpt-3.5-turbo" + - Anthropic: "anthropic:claude-3-opus", "anthropic:claude-3-5-sonnet", "anthropic:claude-3-5-haiku", "anthropic:claude-3-7-sonnet" + - Deepmind: "deepmind:gemini-1.5-pro", "deepmind:gemini-1.5-flash", "deepmind:gemini-2.0-pro" + """ + # Note: API keys should be set as environment variables before running this code: + # - OPENAI_API_KEY for OpenAI models + # - ANTHROPIC_API_KEY for Anthropic models + # - GOOGLE_API_KEY for Deepmind models + + # Create the model + self.model = models.model(model_name, temperature=0.7) + + def explain_concept(self, query, user_expertise_level): + """Provides personalized explanations of NFT/poker concepts + + Args: + query (str): The NFT or poker concept to explain + user_expertise_level (str): The user's expertise level (e.g., "beginner", "intermediate", "expert") + + Returns: + str: The personalized explanation + """ + # Format the prompt to include the expertise level + prompt = f""" + You are an NFT and Poker Education Assistant. Your task is to explain concepts related to NFTs and poker + in a way that matches the user's expertise level. 
+ + EXPERTISE LEVEL: {user_expertise_level} + + CONCEPT TO EXPLAIN: {query} + + Please provide a clear, accurate explanation of this concept that is appropriate for someone at the + {user_expertise_level} level. Include relevant examples and analogies where helpful. + """ + + # Call the model with the formatted prompt + response = self.model(prompt) + + # Return the response text + return response + +# Example usage +if __name__ == "__main__": + assistant = NFTEducationAssistant() + result = assistant.explain_concept( + query="What is a non-fungible token?", + user_expertise_level="beginner" + ) + print(f"Explanation: {result}") diff --git a/src/ember/examples/nft_assistant/run_nft_assistant.py b/src/ember/examples/nft_assistant/run_nft_assistant.py new file mode 100644 index 00000000..b1767c76 --- /dev/null +++ b/src/ember/examples/nft_assistant/run_nft_assistant.py @@ -0,0 +1,95 @@ +"""NFT Assistant Runner. + +This script demonstrates how to run the NFT Education Assistant. +""" + +import os +import sys +from nft_education_assistant import NFTEducationAssistant + +def main() -> None: + """Run the NFT Education Assistant example.""" + # Check if OPENAI_API_KEY is set in the environment + if not os.environ.get("OPENAI_API_KEY"): + print("Warning: OPENAI_API_KEY environment variable not set.") + print("Please set your API key using: export OPENAI_API_KEY='your-key-here'") + return + + # Check for GOOGLE_API_KEY environment variable + if not os.environ.get("GOOGLE_API_KEY"): + print("Warning: GOOGLE_API_KEY environment variable not set.") + print("Please set your API key using: export GOOGLE_API_KEY='your-key-here'") + print("This is required for Deepmind models.") + + # Check if Anthropic API key is provided as command line argument + if len(sys.argv) > 1: + os.environ["ANTHROPIC_API_KEY"] = sys.argv[1] + print("Set ANTHROPIC_API_KEY from command line argument.") + elif not os.environ.get("ANTHROPIC_API_KEY"): + print("Note: ANTHROPIC_API_KEY not set. 
Required for Anthropic models.") + + # Get model names from command line arguments or use recommended configuration + ensemble_model = sys.argv[2] if len(sys.argv) > 2 else "deepmind:gemini-1.5-pro" + judge_model = sys.argv[3] if len(sys.argv) > 3 else "openai:gpt-4o" + + print(f"Using ensemble model: {ensemble_model}") + print(f"Using judge model: {judge_model}") + + # Create an instance of the NFT Education Assistant + # Default to OpenAI model, but use Deepmind if specified + model_name = "deepmind:gemini-1.5-pro" if "deepmind" in ensemble_model else "openai:gpt-4o" + assistant = NFTEducationAssistant(model_name=model_name) + + # Example questions to demonstrate the assistant + example_questions = [ + "What is an NFT?", + "How do I create my first NFT?", + "What is the environmental impact of NFTs?", + "How can artists make money with NFTs?", + "What are the most popular NFT marketplaces?" + ] + + # User expertise levels for examples + expertise_level = "beginner" + + # Run the example questions + for i, question in enumerate(example_questions, 1): + print(f"\n--- Example {i} ---") + print(f"Question: {question}") + + # Process the question using explain_concept method + response = assistant.explain_concept(question, expertise_level) + + # Display the results + print("\nAnswer:") + print(response) + print("\n" + "-" * 80) + + # Interactive mode + print("\n--- Interactive Mode ---") + print("Type your questions about NFTs (or 'exit' to quit)") + print("Default expertise level is 'beginner'. To change, type 'expertise:level'") + + current_expertise = "beginner" + + while True: + question = input("\nYour question: ") + if question.lower() in ("exit", "quit", "q"): + break + + # Check if user wants to change expertise level + if question.lower().startswith("expertise:"): + try: + current_expertise = question.split(":")[1].strip().lower() + print(f"Expertise level set to: {current_expertise}") + continue + except IndexError: + print("Invalid format. 
import os

from nft_education_assistant import NFTEducationAssistant


def main():
    """Exercise the NFT Education Assistant with sample queries at three
    expertise levels and print each explanation."""
    # The assistant reads provider API keys from the environment; warn early
    # so a failed model call later is explainable.
    if not os.environ.get("OPENAI_API_KEY"):
        print("Warning: OPENAI_API_KEY environment variable not set")
    if not os.environ.get("ANTHROPIC_API_KEY"):
        print("Warning: ANTHROPIC_API_KEY environment variable not set")

    # BUG FIX: NFTEducationAssistant.__init__ accepts only `model_name`.
    # The previous call passed openai_api_key/anthropic_api_key keyword
    # arguments, which raised TypeError (unexpected keyword argument) —
    # keys are picked up from the environment instead.
    assistant = NFTEducationAssistant(model_name="openai:gpt-4o")

    # Example queries at different expertise levels
    test_cases = [
        {
            "query": "What is a non-fungible token?",
            "expertise_level": "beginner"
        },
        {
            "query": "How do smart contracts work with NFTs?",
            "expertise_level": "intermediate"
        },
        {
            "query": "What are the implications of ERC-721 vs ERC-1155 for poker-based NFTs?",
            "expertise_level": "expert"
        }
    ]

    # Test each query
    for i, test_case in enumerate(test_cases, 1):
        print(f"\n--- Test Case {i} ---")
        print(f"Query: {test_case['query']}")
        print(f"Expertise Level: {test_case['expertise_level']}")

        # Get explanation
        result = assistant.explain_concept(
            query=test_case["query"],
            user_expertise_level=test_case["expertise_level"]
        )

        # Print result. explain_concept returns text, but keep the defensive
        # branch for older ensemble-style responses that wrapped the answer
        # in a dict under "synthesized_response".
        print("\nResult:")
        if isinstance(result, dict) and "synthesized_response" in result:
            print(result["synthesized_response"])
        else:
            print(result)


if __name__ == "__main__":
    main()
Enhanced Answer Generation +- Generates clear, definitive answers based on query results and database schema +- Includes specific details from query results in the answers +- Adapts to any type of question without relying on predefined patterns +- Distills SQL query results and adds them to the LLM context + +## Usage + +### Command-line Interface +```python +from ember.examples.sql_agent.sql_agent import SQLAgent + +# Initialize the agent with a database connection +agent = SQLAgent(database_url="sqlite:///your_database.db") + +# Ask a question about your data +result = agent.query("What was the average value in the last month?") +print(result["answer"]) +``` + +### Streamlit Web Interface +The SQL Agent also comes with a Streamlit web interface for interactive querying: + +```bash +# First, set up your environment with the required API key(s) +export OPENAI_API_KEY='your-api-key' # or ANTHROPIC_API_KEY, DEEPMIND_API_KEY + +# Then run the Streamlit app +streamlit run src/ember/examples/sql_agent/app.py +``` + +This launches a web interface where you can: +- Ask questions about your data in natural language +- See the generated SQL queries and their results +- Browse sample queries +- Select different LLM models +- Export your chat history + +## Example Data +This example includes a Formula 1 dataset with tables for: +- Drivers championship +- Constructors championship +- Race results +- Race wins +- Fastest laps + +## Technical Details +- Uses SQLAlchemy's `text()` function for safer SQL query execution +- Implements dynamic schema mapping to understand any database structure +- Adds query validation and correction mechanisms +- Enhances the answer generation process to provide more definitive responses + +## PR Enhancement Notes +This SQL Agent example was created as part of a PR to enhance the Ember framework with a dynamic SQL Agent that can work with any database schema without hardcoded patterns. 
The implementation follows best practices for Python development with proper typing, documentation, and test coverage. \ No newline at end of file diff --git a/src/ember/examples/sql_agent/__init__.py b/src/ember/examples/sql_agent/__init__.py new file mode 100644 index 00000000..3a55e6cb --- /dev/null +++ b/src/ember/examples/sql_agent/__init__.py @@ -0,0 +1,4 @@ +"""SQL Agent Example Package. + +This package contains a dynamic SQL agent that can analyze any database schema. +""" \ No newline at end of file diff --git a/src/ember/examples/sql_agent/__pycache__/__init__.cpython-312.pyc b/src/ember/examples/sql_agent/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..f0b2b49b Binary files /dev/null and b/src/ember/examples/sql_agent/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/ember/examples/sql_agent/__pycache__/agent.cpython-312.pyc b/src/ember/examples/sql_agent/__pycache__/agent.cpython-312.pyc new file mode 100644 index 00000000..1352ae29 Binary files /dev/null and b/src/ember/examples/sql_agent/__pycache__/agent.cpython-312.pyc differ diff --git a/src/ember/examples/sql_agent/__pycache__/load_f1_data.cpython-312.pyc b/src/ember/examples/sql_agent/__pycache__/load_f1_data.cpython-312.pyc new file mode 100644 index 00000000..11b62caf Binary files /dev/null and b/src/ember/examples/sql_agent/__pycache__/load_f1_data.cpython-312.pyc differ diff --git a/src/ember/examples/sql_agent/__pycache__/load_knowledge.cpython-312.pyc b/src/ember/examples/sql_agent/__pycache__/load_knowledge.cpython-312.pyc new file mode 100644 index 00000000..bde8251e Binary files /dev/null and b/src/ember/examples/sql_agent/__pycache__/load_knowledge.cpython-312.pyc differ diff --git a/src/ember/examples/sql_agent/__pycache__/sql_agent.cpython-312.pyc b/src/ember/examples/sql_agent/__pycache__/sql_agent.cpython-312.pyc new file mode 100644 index 00000000..4fb23a5d Binary files /dev/null and 
b/src/ember/examples/sql_agent/__pycache__/sql_agent.cpython-312.pyc differ diff --git a/src/ember/examples/sql_agent/__pycache__/utils.cpython-312.pyc b/src/ember/examples/sql_agent/__pycache__/utils.cpython-312.pyc new file mode 100644 index 00000000..454556c2 Binary files /dev/null and b/src/ember/examples/sql_agent/__pycache__/utils.cpython-312.pyc differ diff --git a/src/ember/examples/sql_agent/agent.py b/src/ember/examples/sql_agent/agent.py new file mode 100644 index 00000000..ad7bfd64 --- /dev/null +++ b/src/ember/examples/sql_agent/agent.py @@ -0,0 +1,618 @@ +import json +import os +from pathlib import Path +from textwrap import dedent +import logging +from typing import Dict, ClassVar, Optional + +from ember.api import models +from ember.api.operators import Operator, Specification, EmberModel, Field +from sqlalchemy import create_engine, text +import pandas as pd + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Database connection string +DB_URL = "sqlite:///f1_data.db" + +# Paths +CWD = Path(__file__).parent +KNOWLEDGE_DIR = CWD.joinpath("knowledge") +OUTPUT_DIR = CWD.joinpath("output") + +# Create the output directory if it does not exist +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +# Define input and output models +class SQLQueryInput(EmberModel): + query: str = Field(..., description="The natural language query to process") + +class SQLQueryOutput(EmberModel): + response: str = Field(..., description="The generated response") + sql_query: Optional[str] = Field(None, description="The SQL query that was executed") + execution_time: Optional[float] = Field(None, description="Time taken to execute the query in seconds") + +# Define specification +class SQLAgentSpec(Specification): + input_model: type[EmberModel] = SQLQueryInput + structured_output: type[EmberModel] = SQLQueryOutput + +class SQLAgent(Operator[SQLQueryInput, SQLQueryOutput]): 
+ """SQL Agent that converts natural language to SQL and executes queries.""" + + specification: ClassVar[Specification] = SQLAgentSpec() + + def __init__(self, model_name: str = "openai:gpt-4o"): + """Initialize the SQL Agent. + + Args: + model_name: Model identifier in format 'provider:model_name' + Options include: + - OpenAI: "openai:gpt-4o", "openai:gpt-4o-mini", "openai:gpt-4", "openai:gpt-4-turbo", "openai:gpt-3.5-turbo" + - Anthropic: "anthropic:claude-3-opus", "anthropic:claude-3-5-sonnet", "anthropic:claude-3-5-haiku", "anthropic:claude-3-7-sonnet" + - Deepmind: "deepmind:gemini-1.5-pro", "deepmind:gemini-1.5-flash", "deepmind:gemini-2.0-pro" + """ + super().__init__() + self.model_id = model_name + self.llm = models.model(model_name, temperature=0.2) + self.db_engine = create_engine(DB_URL) + self.knowledge_base = self._load_knowledge_base() + self.semantic_model = self._create_semantic_model() + self.session_name = "New Session" + self.messages = [] + self.tool_calls = [] + + def _load_knowledge_base(self) -> Dict: + """Load the knowledge base from the knowledge directory.""" + knowledge_base = { + "tables": {}, + "sample_queries": [] + } + + # Load table metadata + for file_path in KNOWLEDGE_DIR.glob("*.json"): + with open(file_path, "r") as f: + table_data = json.load(f) + table_name = table_data.get("table_name") + if table_name: + knowledge_base["tables"][table_name] = table_data + + # Load sample queries + sample_queries_path = KNOWLEDGE_DIR / "sample_queries.sql" + if sample_queries_path.exists(): + with open(sample_queries_path, "r") as f: + content = f.read() + + # Parse the sample queries + query_blocks = [] + current_block = {"description": "", "query": ""} + in_description = False + in_query = False + + for line in content.split("\n"): + if line.strip() == "-- ": + in_description = True + in_query = False + if current_block["query"]: + query_blocks.append(current_block) + current_block = {"description": "", "query": ""} + elif line.strip() == 
"-- ": + in_description = False + elif line.strip() == "-- ": + in_query = True + in_description = False + elif line.strip() == "-- ": + in_query = False + elif in_description: + current_block["description"] += line.replace("-- ", "") + "\n" + elif in_query: + current_block["query"] += line + "\n" + + if current_block["query"]: + query_blocks.append(current_block) + + knowledge_base["sample_queries"] = query_blocks + + return knowledge_base + + def _create_semantic_model(self) -> Dict: + """Create a semantic model from the knowledge base.""" + semantic_model = { + "tables": [] + } + + for table_name, table_data in self.knowledge_base["tables"].items(): + semantic_model["tables"].append({ + "table_name": table_name, + "table_description": table_data.get("table_description", ""), + "Use Case": f"Use this table to get data on {table_name.replace('_', ' ')}." + }) + + return semantic_model + + def search_knowledge_base(self, table_name: str) -> str: + """Search the knowledge base for information about a table.""" + # Log the tool call + self.tool_calls.append({ + "tool_name": "search_knowledge_base", + "tool_args": {"table_name": table_name}, + "content": None + }) + + if table_name in self.knowledge_base["tables"]: + table_data = self.knowledge_base["tables"][table_name] + result = { + "table_name": table_data.get("table_name", ""), + "table_description": table_data.get("table_description", ""), + "table_columns": table_data.get("table_columns", []), + "table_rules": table_data.get("table_rules", []) + } + result_str = json.dumps(result, indent=2) + self.tool_calls[-1]["content"] = result_str + return result_str + else: + result = f"Table '{table_name}' not found in the knowledge base." 
+ self.tool_calls[-1]["content"] = result + return result + + def describe_table(self, table_name: str) -> str: + """Get the schema of a table from the database.""" + # Log the tool call + self.tool_calls.append({ + "tool_name": "describe_table", + "tool_args": {"table_name": table_name}, + "content": None + }) + + try: + # For SQLite, we need to use a different approach to get table schema + # since it doesn't support information_schema + query = f"PRAGMA table_info({table_name});" + + with self.db_engine.connect() as conn: + result = conn.execute(text(query)) + columns = result.fetchall() + + if not columns: + result = f"Table '{table_name}' not found in the database." + self.tool_calls[-1]["content"] = result + return result + + schema_info = f"Schema for table '{table_name}':\n\n" + schema_info += "| Column Name | Data Type | Nullable | Primary Key |\n" + schema_info += "|-------------|-----------|----------|------------|\n" + + for column in columns: + # PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk + col_name = column[1] + col_type = column[2] + col_nullable = "NO" if column[3] == 1 else "YES" + col_pk = "YES" if column[5] == 1 else "NO" + + schema_info += f"| {col_name} | {col_type} | {col_nullable} | {col_pk} |\n" + + # Also get a sample of data to help understand the table + sample_query = f"SELECT * FROM {table_name} LIMIT 3;" + try: + with self.db_engine.connect() as conn: + sample_result = conn.execute(text(sample_query)) + sample_rows = sample_result.fetchall() + sample_columns = sample_result.keys() + + if sample_rows: + schema_info += f"\n\nSample data from '{table_name}':\n\n" + # Create header row + schema_info += "| " + " | ".join(sample_columns) + " |\n" + # Create separator row + schema_info += "| " + " | ".join(["---" for _ in sample_columns]) + " |\n" + # Add data rows + for row in sample_rows: + schema_info += "| " + " | ".join([str(cell) for cell in row]) + " |\n" + except Exception as e: + schema_info += f"\n\nCould 
not retrieve sample data: {str(e)}" + + self.tool_calls[-1]["content"] = schema_info + return schema_info + + except Exception as e: + error_message = f"Error describing table: {str(e)}" + logger.error(error_message) + self.tool_calls[-1]["content"] = error_message + return error_message + + def run_sql_query(self, query: str) -> str: + """Run a SQL query and return the results as a formatted string.""" + try: + # Log the query + logger.info(f"Running SQL query: {query}") + self.tool_calls.append({ + "tool_name": "run_sql_query", + "tool_args": {"query": query}, + "content": None + }) + + # Execute the query + with self.db_engine.connect() as conn: + df = pd.read_sql(query, conn) + + # Convert to markdown table + markdown_table = df.to_markdown(index=False) + + # Store the result in the tool call + self.tool_calls[-1]["content"] = json.dumps(df.to_dict(orient="records")) + + return markdown_table + + except Exception as e: + error_message = f"Error executing SQL query: {str(e)}" + logger.error(error_message) + self.tool_calls[-1]["content"] = json.dumps({"error": error_message}) + return error_message + + def get_tool_call_history(self, num_calls: int = 3) -> str: + """Get the history of tool calls.""" + if not self.tool_calls: + return "No tool calls in history." + + recent_calls = self.tool_calls[-num_calls:] + result = "Recent tool calls:\n\n" + + for i, call in enumerate(recent_calls): + result += f"Tool Call {i+1}:\n" + result += f"Tool: {call['tool_name']}\n" + + if call['tool_name'] == 'run_sql_query': + result += f"Query:\n```sql\n{call['tool_args']['query']}\n```\n" + else: + result += f"Arguments: {json.dumps(call['tool_args'], indent=2)}\n" + + result += "\n" + + return result + + def execute_sql_query(self, query: str) -> pd.DataFrame: + """Execute a SQL query and return the results as a DataFrame. 
+ + Args: + query: The SQL query to execute + + Returns: + A pandas DataFrame with the query results + """ + logger.info(f"Executing SQL query: {query}") + # Add this to the tool calls for tracking + self.tool_calls.append({ + "tool_name": "execute_sql_query", + "tool_args": {"query": query}, + "content": None + }) + + try: + with self.db_engine.connect() as conn: + result = conn.execute(text(query)) + rows = result.fetchall() + columns = result.keys() + df = pd.DataFrame(rows, columns=columns) + + # Store the result in the tool call + self.tool_calls[-1]["content"] = json.dumps(df.to_dict(orient="records")) + + return df + except Exception as e: + error_message = f"Error executing SQL query: {str(e)}" + logger.error(error_message) + self.tool_calls[-1]["content"] = json.dumps({"error": error_message}) + raise + + def forward(self, *, inputs: SQLQueryInput) -> SQLQueryOutput: + """Process a natural language query and return a response. + + Args: + inputs: The input query + + Returns: + The generated response with SQL query information + """ + query = inputs.query + + # Add the user message to the conversation history + self.messages.append({"role": "user", "content": query}) + + # Prepare the system prompt + system_prompt = self._create_system_prompt() + + # Prepare the conversation history as a formatted string + conversation_history = "" + for message in self.messages[:-1]: # Exclude the current query + role = "User" if message["role"] == "user" else "Assistant" + conversation_history += f"{role}: {message['content']}\n\n" + + # Format the prompt + prompt = f""" + {system_prompt} + + CONVERSATION HISTORY: + {conversation_history} + + USER QUERY: {query} + + IMPORTANT INSTRUCTIONS: + 1. ALWAYS use the describe_table() function to check the actual schema of any table before writing a query. + 2. You MUST include a SQL query in your response to answer the user's question. + 3. Format the SQL query as a code block with ```sql at the beginning and ``` at the end. + 4. 
Make sure your SQL query is correct and will run successfully. + 5. DO NOT include placeholders or comments in your SQL query that would prevent it from executing. + 6. DO NOT include multiple SQL queries - just provide ONE complete query that answers the question. + 7. ALWAYS provide a direct, clear answer at the beginning of your response. + 8. For questions about "most", "top", "longest", etc., clearly state the answer with specific values. + 9. ALWAYS execute the SQL query and include the results in your response. + 10. Be dynamic and flexible - handle ANY type of question about the data without relying on predefined patterns. + 11. NEVER stop after just retrieving schema information - ALWAYS proceed to generate and execute a SQL query. + 12. Your response MUST include a SQL query in a code block - this is critical for the system to work. + 13. CAREFULLY check the schema information to use the CORRECT column names in your query. + 14. Pay close attention to the sample data to understand the structure and content of each table. + 15. Generate SQL queries dynamically based on the question - don't rely on templates or patterns. + 16. Be creative in your SQL query construction to answer complex questions accurately. + 17. ALWAYS provide a clear, definitive answer based on the query results. + + For example, if the user asks "Who are the top 5 drivers with the most race wins?", your response should include: + + ```sql + SELECT name, COUNT(*) as wins + FROM race_wins + GROUP BY name + ORDER BY wins DESC + LIMIT 5; + ``` + + Notice how the query uses the exact column name 'name' from the race_wins table schema, not 'driver' or 'winner'. + + Another example, if the user asks "Show me the number of races per year", your response should include: + + ```sql + SELECT year, COUNT(DISTINCT venue) as num_races + FROM race_results + GROUP BY year + ORDER BY year; + ``` + + Notice how this query uses the exact column names 'year' and 'venue' from the race_results table schema. 
+ + For a question like "Tell me the driver with the longest racing career", your response should include: + + ```sql + SELECT name, MIN(year) as first_year, MAX(year) as last_year, MAX(year) - MIN(year) as career_length + FROM race_results + GROUP BY name + ORDER BY career_length DESC + LIMIT 1; + ``` + + This query finds the driver with the longest span between their first and last race appearance. + + I will execute this query for you and provide the results. Make sure your response is clear and concise. + + IMPORTANT: Your response MUST include a SQL query in a code block with ```sql at the beginning and ``` at the end. If you don't include a SQL query, the system will not be able to execute it and provide results. + """ + + # Call the model with the formatted prompt + import time + start_time = time.time() + response_obj = self.llm(prompt) + execution_time = time.time() - start_time + + # Convert response to string if it's not already + response_text = str(response_obj) + + # Extract SQL query if present in the response + sql_query = None + sql_result = None + + # Look for SQL code blocks in the response + import re + sql_matches = re.findall(r'```sql\s*([\s\S]*?)\s*```', response_text) + + if sql_matches: + # Use the first SQL query found + sql_query = sql_matches[0].strip() + + try: + # Execute the SQL query + logger.info(f"Executing SQL query: {sql_query}") + with self.db_engine.connect() as conn: + result = conn.execute(text(sql_query)) + rows = result.fetchall() + columns = result.keys() + + # Convert to DataFrame and then to markdown table + df = pd.DataFrame(rows, columns=columns) + sql_result = df.to_markdown(index=False) if not df.empty else "No results found." 
+ + # Always add a direct answer at the beginning for clarity + if len(rows) > 0: + # For all queries, extract key information from the first row or summarize results + first_row = df.iloc[0].to_dict() + + # Check if the response already starts with a clear answer + has_clear_answer = False + answer_patterns = ["**Answer:**", "**Results Summary:**", "The driver with", "The team with", + "The most", "The top", "The longest", "The highest", "The best"] + + for pattern in answer_patterns: + if pattern.lower() in response_text[:200].lower(): + has_clear_answer = True + break + + # If no clear answer is found, create one based on the query and results + if not has_clear_answer: + # For superlative queries, highlight the top result + if any(term in query.lower() for term in ["most", "top", "best", "highest", "longest", "greatest"]): + # Extract the key column names for a better answer + key_cols = [col for col in df.columns if col.lower() not in ['count', 'sum', 'avg', 'min', 'max', 'index']] + value_cols = [col for col in df.columns if col.lower() in ['count', 'wins', 'championships', 'points', 'races']] + + if key_cols and value_cols: + key_val = first_row[key_cols[0]] + metric_val = first_row[value_cols[0]] + metric_name = value_cols[0] + answer = f"**Answer:** {key_val} has the most {metric_name} with {metric_val}.\n\n" + else: + # Generic answer with the first row data + answer = f"**Answer:** The top result is {first_row}.\n\n" + else: + # For other queries, provide a summary + answer = f"**Results Summary:** Found {len(rows)} records.\n\n" + + # Add the answer to the beginning of the response + response_text = answer + response_text + + # Add the SQL result to the response if not already present + if "Query Results" not in response_text and "SQL Query Result" not in response_text: + response_text += f"\n\n### Query Results\n\n{sql_result}" + + except Exception as e: + # If there's an error executing the query, add it to the response + error_message = f"Error 
executing SQL query: {str(e)}" + logger.error(error_message) + response_text += f"\n\n### Error Executing Query\n\n```\n{error_message}\n```" + + # Add the assistant's response to the conversation history + self.messages.append({"role": "assistant", "content": response_text}) + + return SQLQueryOutput( + response=response_text, + sql_query=sql_query, + execution_time=execution_time + ) + + def run(self, query: str) -> str: + """Legacy method for backward compatibility.""" + result = self.forward(inputs=SQLQueryInput(query=query)) + return result.response + + def _create_system_prompt(self) -> str: + """Create the system prompt for the agent.""" + semantic_model_str = json.dumps(self.semantic_model, indent=2) + + return dedent(f""" + You are SQL Agent-X, an elite SQL Data Scientist specializing in: + - Historical race analysis + - Driver performance metrics + - Team championship insights + - Track statistics and records + - Performance trend analysis + - Race strategy evaluation + + You combine deep F1 knowledge with advanced SQL expertise to uncover insights from decades of racing data. + You are dynamic and flexible, able to handle any type of question about the F1 data without relying on predefined patterns. + You can generate SQL queries for any question, no matter how complex, by understanding the schema and relationships between tables. + + When a user messages you, determine if you need to query the database or can respond directly. + + If you can respond directly, do so. + + If you need to query the database to answer the user's question, follow these steps: + + 1. First identify the tables you need to query from the semantic model. + + 2. Then, ALWAYS use the `search_knowledge_base(table_name)` function to get table metadata, rules and sample queries. + + 3. ALWAYS use the `describe_table(table_name)` function to get the actual schema and sample data from the database. This is critical to ensure you use the correct column names. + + 4. 
If table rules are provided, ALWAYS follow them. + + 5. Then, think step-by-step about query construction, don't rush this step. + + 6. Follow a chain of thought approach before writing SQL, ask clarifying questions where needed. + + 7. Then, using all the information available, create one single syntactically correct SQL query to accomplish your task. + - Be creative and flexible in your query construction + - Don't rely on predefined patterns - adapt to the specific question + - Use appropriate SQL functions and operations based on the question + - Consider different ways to interpret the question and choose the most appropriate + - ALWAYS include a SQL query in your response - this is critical + - NEVER stop after just retrieving schema information - always proceed to generate and execute a query + + 8. If you need to join tables, check the `semantic_model` for the relationships between the tables. + + 9. If you cannot find relevant tables, columns or relationships, stop and ask the user for more information. + + 10. Once you have a syntactically correct query, run it using the `run_sql_query` function. + + 11. When running a query: + - Do not add a `;` at the end of the query. + - Always provide a limit unless the user explicitly asks for all results. + + 12. After you run the query, analyse the results and return the answer in markdown format. + + 13. ALWAYS start your response with a direct answer to the user's question. For example: + - "The driver with the most race wins is Michael Schumacher with 91 wins." + - "Ferrari has won the most Constructor Championships with 16 titles." + - "The driver with the longest career is Kimi Räikkönen who raced from 2001 to 2021." + + 14. Always show the user the SQL you ran to get the answer. + + 15. Continue till you have accomplished the task. + + 16. Show results as a table or a chart if possible. 
+ + After finishing your task, ask the user relevant followup questions like "was the result okay, would you like me to fix any problems?" + + If the user says yes, get the previous query using the `get_tool_call_history(num_calls=3)` function and fix the problems. + + If the user wants to see the SQL, get it using the `get_tool_call_history(num_calls=3)` function. + + Finally, here are the set of rules that you MUST follow: + + - Use the `search_knowledge_base(table_name)` function to get table information before writing a query. + - Do not use phrases like "based on the information provided" or "from the knowledge base". + - Always show the SQL queries you use to get the answer. + - Make sure your query accounts for duplicate records. + - Make sure your query accounts for null values. + - If you run a query, explain why you ran it. + - NEVER, EVER RUN CODE TO DELETE DATA OR ABUSE THE LOCAL SYSTEM + - ALWAYS FOLLOW THE `table rules` if provided. NEVER IGNORE THEM. + - Be dynamic and flexible - don't rely on hardcoded patterns for different question types. + - Adapt your query approach based on the specific question being asked. + - Always include the query results in your response. + + The `semantic_model` contains information about tables and the relationships between them. + + If the users asks about the tables you have access to, simply share the table names from the `semantic_model`. + + + {semantic_model_str} + + + You have the following functions available: + + 1. search_knowledge_base(table_name: str) -> str + - Get metadata, rules, and sample queries for a table + + 2. describe_table(table_name: str) -> str + - Get the schema of a table from the database + + 3. run_sql_query(query: str) -> str + - Run a SQL query and return the results + + 4. 
get_tool_call_history(num_calls: int = 3) -> str + - Get the history of recent tool calls + """) + + def rename_session(self, new_name: str) -> None: + """Rename the current session.""" + self.session_name = new_name + + +def get_sql_agent(model_name: str = "openai:gpt-4o") -> SQLAgent: + """Get an instance of the SQL Agent. + + Args: + model_name: The model to use for the agent + + Returns: + An initialized SQL Agent + """ + return SQLAgent(model_name=model_name) diff --git a/src/ember/examples/sql_agent/app.py b/src/ember/examples/sql_agent/app.py new file mode 100644 index 00000000..9b0b97ea --- /dev/null +++ b/src/ember/examples/sql_agent/app.py @@ -0,0 +1,672 @@ +import nest_asyncio +import streamlit as st +import logging +import os +import pandas as pd +import re +from sqlalchemy import text + +from agent import get_sql_agent +from utils import CUSTOM_CSS, add_message, display_tool_calls, export_chat_history +from load_f1_data import load_f1_data +from load_knowledge import load_knowledge + +# Set the OpenAI API key from environment variable +# Make sure to set your OPENAI_API_KEY environment variable before running this app +# Example: export OPENAI_API_KEY="your-api-key-here" +if "OPENAI_API_KEY" not in os.environ: + print("Warning: OPENAI_API_KEY environment variable not set. 
Some features may not work properly.") + +# Apply nest_asyncio to allow nested event loops +nest_asyncio.apply() + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Set up Streamlit page +st.set_page_config( + page_title="SQL Agent with Ember", + page_icon="💎", + layout="wide", + initial_sidebar_state="expanded", +) + +# Load custom CSS with dark mode support +st.markdown(CUSTOM_CSS, unsafe_allow_html=True) + +def sidebar_widget() -> None: + """Display a sidebar with sample user queries""" + with st.sidebar: + st.markdown("#### 🏆 Sample Queries") + + if st.button("📋 Show Tables"): + if "messages" not in st.session_state: + st.session_state["messages"] = [] + st.session_state["messages"] = add_message( + st.session_state["messages"], + "user", + "Which tables do you have access to?" + ) + + if st.button("🥇 Most Race Wins"): + if "messages" not in st.session_state: + st.session_state["messages"] = [] + st.session_state["messages"] = add_message( + st.session_state["messages"], + "user", + "Which driver has the most race wins?" + ) + + if st.button("🏆 Constructor Champs"): + if "messages" not in st.session_state: + st.session_state["messages"] = [] + st.session_state["messages"] = add_message( + st.session_state["messages"], + "user", + "Which team won the most Constructors Championships?" + ) + + if st.button("⏳ Longest Career"): + if "messages" not in st.session_state: + st.session_state["messages"] = [] + st.session_state["messages"] = add_message( + st.session_state["messages"], + "user", + "Tell me the name of the driver with the longest racing career? Also tell me when they started and when they retired." + ) + + if st.button("📈 Races per Year"): + if "messages" not in st.session_state: + st.session_state["messages"] = [] + st.session_state["messages"] = add_message( + st.session_state["messages"], + "user", + "Show me the number of races per year." 
+ ) + + if st.button("🔍 Team Performance"): + if "messages" not in st.session_state: + st.session_state["messages"] = [] + st.session_state["messages"] = add_message( + st.session_state["messages"], + "user", + "Write a query to identify the drivers that won the most races per year from 2010 onwards and the position of their team that year." + ) + + st.markdown("---") + + st.markdown("#### 🛠️ Utilities") + + col1, col2 = st.columns(2) + + with col1: + if st.button("🔄 New Chat"): + st.session_state["sql_agent"] = None + st.session_state["messages"] = [] + st.rerun() + + with col2: + fn = "sql_agent_chat_history.md" + + if st.download_button( + "💾 Export Chat", + export_chat_history(st.session_state.get("messages", [])), + file_name=fn, + mime="text/markdown", + ): + st.sidebar.success("Chat history exported!") + + if st.sidebar.button("🚀 Load Data & Knowledge"): + with st.spinner("🔄 Loading data into database..."): + load_f1_data() + with st.spinner("📚 Loading knowledge base..."): + load_knowledge() + st.success("✅ Data and knowledge loaded successfully!") + +def about_widget() -> None: + """Display an about section in the sidebar""" + st.sidebar.markdown("---") + st.sidebar.markdown("### ℹ️ About") + st.sidebar.markdown(""" + This SQL Assistant helps you analyze Formula 1 data from 1950 to 2020 using natural language queries. 
+
+    Built with:
+    - 🚀 Ember
+    - 💫 Streamlit
+    """)
+
+def rename_session_widget(agent) -> None:
+    """Rename the current session of the agent.
+
+    Renders an inline edit widget in the sidebar: the session name with a
+    pencil button, which toggles into a text input plus a save button.
+
+    Args:
+        agent: The SQL agent; must expose ``session_name`` and
+            ``rename_session(name)``.
+    """
+    container = st.sidebar.container()
+    # Two-column row: name display/input on the left, edit/save button right.
+    session_row = container.columns([3, 1], vertical_alignment="center")
+
+    # Initialize session_edit_mode if needed
+    if "session_edit_mode" not in st.session_state:
+        st.session_state.session_edit_mode = False
+
+    with session_row[0]:
+        if st.session_state.session_edit_mode:
+            # Edit mode: free-text input pre-filled with the current name.
+            new_session_name = st.text_input(
+                "Session Name",
+                value=agent.session_name,
+                key="session_name_input",
+                label_visibility="collapsed",
+            )
+        else:
+            st.markdown(f"Session Name: **{agent.session_name}**")
+
+    with session_row[1]:
+        if st.session_state.session_edit_mode:
+            # Save button: only commits when a non-empty name was entered.
+            # NOTE: new_session_name is always bound here because both columns
+            # see the same session_edit_mode value within a single rerun.
+            if st.button("✓", key="save_session_name", type="primary"):
+                if new_session_name:
+                    agent.rename_session(new_session_name)
+                    st.session_state.session_edit_mode = False
+                    container.success("Renamed!")
+        else:
+            # Pencil button switches into edit mode on the next rerun.
+            if st.button("✎", key="edit_session_name"):
+                st.session_state.session_edit_mode = True
+
+def main() -> None:
+    ####################################################################
+    # App header
+    ####################################################################
+    st.markdown(
+        "

SQL Agent with Ember

", unsafe_allow_html=True + ) + st.markdown( + "

Your intelligent SQL Agent that can think, analyze and reason, powered by Ember

", + unsafe_allow_html=True, + ) + + #################################################################### + # Auto-load data and knowledge if not already loaded + #################################################################### + if "data_loaded" not in st.session_state or not st.session_state["data_loaded"]: + with st.spinner("🔄 Loading data into database..."): + try: + load_f1_data() + st.session_state["data_loaded"] = True + st.success("✅ F1 data loaded successfully!") + except Exception as e: + st.error(f"❌ Error loading F1 data: {str(e)}") + st.session_state["data_loaded"] = False + + with st.spinner("📚 Loading knowledge base..."): + try: + load_knowledge() + st.session_state["knowledge_loaded"] = True + st.success("✅ Knowledge base loaded successfully!") + except Exception as e: + st.error(f"❌ Error loading knowledge base: {str(e)}") + st.session_state["knowledge_loaded"] = False + + #################################################################### + # Model selector + #################################################################### + model_options = { + "GPT-4o": "openai:gpt-4o", + "GPT-4o-mini": "openai:gpt-4o-mini", + "GPT-4-turbo": "openai:gpt-4-turbo", + "GPT-3.5-turbo": "openai:gpt-3.5-turbo", + } + + selected_model = st.sidebar.selectbox( + "Select a model", + options=list(model_options.keys()), + index=0, + key="model_selector", + ) + + model_id = model_options[selected_model] + + #################################################################### + # Initialize Agent + #################################################################### + sql_agent = None + + if ( + "sql_agent" not in st.session_state + or st.session_state["sql_agent"] is None + or st.session_state.get("current_model") != model_id + ): + logger.info("---*--- Creating new SQL agent ---*---") + sql_agent = get_sql_agent(model_name=model_id) + st.session_state["sql_agent"] = sql_agent + st.session_state["current_model"] = model_id + else: + sql_agent = 
st.session_state["sql_agent"] + + #################################################################### + # Initialize messages if not already done + #################################################################### + if "messages" not in st.session_state: + st.session_state["messages"] = [] + + #################################################################### + # Sidebar + #################################################################### + sidebar_widget() + + #################################################################### + # Get user input + #################################################################### + if prompt := st.chat_input("👋 Ask me about F1 data from 1950 to 2020!"): + st.session_state["messages"] = add_message(st.session_state["messages"], "user", prompt) + + #################################################################### + # Display chat history + #################################################################### + for message in st.session_state["messages"]: + if message["role"] in ["user", "assistant"]: + content = message["content"] + if content is not None: + with st.chat_message(message["role"]): + # Display tool calls if they exist in the message + if "tool_calls" in message and message["tool_calls"]: + display_tool_calls(st.empty(), message["tool_calls"]) + st.markdown(content) + + #################################################################### + # Generate response for user message + #################################################################### + last_message = ( + st.session_state["messages"][-1] if st.session_state["messages"] else None + ) + + if last_message and last_message.get("role") == "user": + question = last_message["content"] + + with st.chat_message("assistant"): + # Create container for tool calls + tool_calls_container = st.empty() + resp_container = st.empty() + + with st.spinner("🤔 Thinking..."): + try: + # Run the agent + from agent import SQLQueryInput + result = 
sql_agent.forward(inputs=SQLQueryInput(query=question)) + + # Display tool calls if available + if hasattr(sql_agent, 'tool_calls') and sql_agent.tool_calls: + display_tool_calls(tool_calls_container, sql_agent.tool_calls) + + # Extract SQL query and results if present + response_text = result.response + sql_query = result.sql_query + + # Check if we need to extract and execute a query from the response + import re + + # First, check if there's a describe_table call in the response + describe_matches = re.findall(r'describe_table\("([^"]+)"\)', response_text) + + if describe_matches: + # Execute the describe_table function for each match + for table_name in describe_matches: + try: + # Call the describe_table function + table_info = sql_agent.describe_table(table_name) + + # Add the table info to the response + response_text += f"\n\n{table_info}\n\n" + + # Let the LLM generate the appropriate query in its response + # We don't need hardcoded queries anymore as the agent is now more dynamic + + except Exception as e: + error_msg = f"Error executing describe_table for {table_name}: {str(e)}" + response_text += f"\n\n{error_msg}\n\n" + st.error(error_msg) + + # Display the full response + resp_container.markdown(response_text) + + # If there's a SQL query in the response, execute it + sql_matches = re.findall(r'```sql\s*([\s\S]*?)\s*```', response_text) + if sql_matches and not describe_matches: # Only if we haven't already executed a query above + sql_query = sql_matches[0].strip() + try: + # Execute the SQL query directly + with sql_agent.db_engine.connect() as conn: + df = pd.read_sql(sql_query, conn) + + # Display the results in a dataframe for better visualization + st.markdown("### Interactive Results") + st.dataframe(df, use_container_width=True) + + # If the response doesn't already contain the results, add them + if "Query Results" not in response_text and "SQL Query Result" not in response_text: + # Update the response container with the new content + 
updated_response = response_text + f"\n\n### Query Results\n\n{df.to_markdown(index=False)}" + resp_container.markdown(updated_response) + # Update the response text for the session state + response_text = updated_response + except Exception as e: + st.error(f"Error executing SQL query: {str(e)}") + elif not sql_matches and describe_matches: # If we have describe_table calls but no SQL query + # The agent retrieved schema information but didn't generate a SQL query + # Let's generate a simple query based on the table and question + note = "\n\n**Note: The agent retrieved schema information but didn't generate a SQL query. Generating a simple query based on the table...**" + response_text += note + resp_container.markdown(response_text) + + # Generate a simple query based on the table name and question + table_name = describe_matches[0] # Use the first table mentioned + + # Dynamically explore and map the database schema + db_schema_info = "" + all_tables = [] + table_columns = {} + relationships = [] + + try: + with sql_agent.db_engine.connect() as conn: + # Get all tables in the database + result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table'")) + all_tables = [row[0] for row in result.fetchall()] + + # For each table, get its columns + for table in all_tables: + result = conn.execute(text(f"PRAGMA table_info({table})")) + columns = result.fetchall() + table_columns[table] = [(col[1], col[2]) for col in columns] + + # Get sample data for each table + try: + sample = conn.execute(text(f"SELECT * FROM {table} LIMIT 3")).fetchall() + if sample: + # Add sample data to the schema info + table_columns[table].append(("sample_data", sample)) + except: + pass + + # Try to infer relationships between tables + for table1 in all_tables: + for table2 in all_tables: + if table1 != table2: + # Get column names for both tables + cols1 = [col[0] for col in table_columns[table1] if isinstance(col, tuple) and len(col) == 2] + cols2 = [col[0] for col in 
table_columns[table2] if isinstance(col, tuple) and len(col) == 2] + + # Look for common column names that might indicate relationships + for col1 in cols1: + if col1 in cols2 or f"{table1[:-1]}_id" == col1 or f"{table2[:-1]}_id" == col1: + relationships.append((table1, table2, col1)) + except Exception as e: + logger.error(f"Error exploring database schema: {str(e)}") + + # Format the schema information + db_schema_info = "Database Schema:\n" + for table in all_tables: + db_schema_info += f"\nTable: {table}\n" + db_schema_info += "Columns:\n" + for col in table_columns[table]: + if isinstance(col, tuple) and len(col) == 2 and col[0] != "sample_data": + db_schema_info += f"- {col[0]}: {col[1]}\n" + + if relationships: + db_schema_info += "\nPossible Relationships:\n" + for rel in relationships: + db_schema_info += f"- {rel[0]} may be related to {rel[1]} through column {rel[2]}\n" + + # Get specific information about the table mentioned in the question + table_info = "" + if table_name in table_columns: + table_info = f"\nFocused Table: {table_name}\n" + table_info += "Columns:\n" + for col in table_columns[table_name]: + if isinstance(col, tuple) and len(col) == 2 and col[0] != "sample_data": + table_info += f"- {col[0]}: {col[1]}\n" + + # Add sample data if available + for col in table_columns[table_name]: + if isinstance(col, tuple) and len(col) == 2 and col[0] == "sample_data": + table_info += "\nSample Data:\n" + for row in col[1][:3]: # Show up to 3 rows + table_info += f"{row}\n" + + # Send a follow-up request to the LLM to generate a SQL query + follow_up_prompt = f"""Based on the following database schema, generate a SQL query to answer: {question} + + {db_schema_info} + + {table_info} + + IMPORTANT INSTRUCTIONS: + 1. Use ONLY the exact table names and column names listed above. Do not use placeholder names like 'table_name' or 'column1'. + 2. For the current question, you should use the actual table name '{table_name}' in your query, not a placeholder. 
+ 3. Consider the relationships between tables if they are relevant to the question. + 4. For questions about time spans or careers, use appropriate SQL functions like MIN(), MAX(), etc. + 5. Be creative and flexible in your approach, but ensure the query will execute correctly. + 6. Format your response as a SQL query inside a code block with ```sql at the beginning and ``` at the end. + 7. Double-check that all table and column names in your query exactly match those in the schema. + """ + + # Call the LLM to generate a SQL query + try: + # Use the same model as the agent + llm = sql_agent.llm + response = llm(follow_up_prompt) + + # Extract SQL query from the response + import re + sql_matches = re.findall(r'```sql\s*([\s\S]*?)\s*```', str(response)) + + if sql_matches: + simple_query = sql_matches[0].strip() + + # Validate that the query only uses existing tables and columns + valid_columns = [] + valid_tables = list(table_columns.keys()) + + for table in table_columns: + for col in table_columns[table]: + if isinstance(col, tuple) and len(col) == 2 and col[0] != "sample_data": + valid_columns.append(col[0].lower()) + + # Add SQL keywords and functions to the valid list + sql_keywords = ['select', 'from', 'where', 'group', 'by', 'order', 'limit', 'as', 'min', 'max', + 'count', 'sum', 'avg', 'and', 'or', 'not', 'distinct', 'having', 'desc', 'asc', + 'join', 'inner', 'outer', 'left', 'right', 'on', 'case', 'when', 'then', 'else', 'end', + 'in', 'between', 'like', 'is', 'null', 'cast', 'coalesce', 'nullif', 'ifnull', + 'date', 'datetime', 'time', 'strftime', 'julianday', 'unixepoch', 'localtime', + 'year', 'month', 'day', 'hour', 'minute', 'second'] + + # Check if the query contains the table name + if table_name.lower() not in simple_query.lower(): + logger.warning(f"Query does not contain the table name: {table_name}") + # Replace 'table_name' with the actual table name + simple_query = simple_query.replace('table_name', table_name) + # Also try to replace 
'tablename' with the actual table name + simple_query = simple_query.replace('tablename', table_name) + + # Simple validation to catch obvious errors + invalid_columns = [] + for col_name in re.findall(r'\b([a-zA-Z_][a-zA-Z0-9_]*)\b', simple_query): + if (col_name.lower() not in valid_columns and + col_name.lower() not in valid_tables and + col_name.lower() not in sql_keywords and + not col_name.lower().startswith('sqlite_') and + not col_name.isdigit()): + invalid_columns.append(col_name) + + if invalid_columns: + # If invalid columns found, try to fix the query + logger.warning(f"Invalid column names in query: {invalid_columns}") + + # Try to generate a corrected query + correction_prompt = f"""The following SQL query contains invalid column or table names: {invalid_columns} + + Query: {simple_query} + + Valid table names are: {valid_tables} + Valid column names are: {valid_columns} + + For this specific question, you should use the table '{table_name}' and its columns. + + Please correct the query to use only valid table names, column names, and SQL keywords. + Do not use placeholder names like 'table_name' or 'column1' - use the actual names from the schema. + Format your response as a SQL query inside a code block with ```sql at the beginning and ``` at the end. 
+ """ + + try: + correction_response = llm(correction_prompt) + correction_matches = re.findall(r'```sql\s*([\s\S]*?)\s*```', str(correction_response)) + + if correction_matches: + simple_query = correction_matches[0].strip() + logger.info(f"Corrected query: {simple_query}") + else: + # If no corrected query found, use a generic query + simple_query = f"SELECT * FROM {table_name} LIMIT 10" + except Exception as e: + logger.error(f"Error correcting query: {str(e)}") + simple_query = f"SELECT * FROM {table_name} LIMIT 10" + else: + # If no SQL query found, use a generic query + simple_query = f"SELECT * FROM {table_name} LIMIT 10" + except Exception as e: + # If there's an error, use a generic query + logger.error(f"Error generating SQL query: {str(e)}") + simple_query = f"SELECT * FROM {table_name} LIMIT 10" + + try: + # Execute the simple query + import time + start_time = time.time() + with sql_agent.db_engine.connect() as conn: + df = pd.read_sql(simple_query, conn) + execution_time = time.time() - start_time + + # Add the query and results to the response + response_text += f"\n\n```sql\n{simple_query}\n```\n\n### Query Results\n\n{df.to_markdown(index=False)}" + + # Generate a clear, definitive answer based on the query results + # Instead of hardcoded patterns, use the LLM to generate a dynamic answer + if not df.empty: + # Convert the dataframe to a string representation + df_str = df.to_string(index=False) + + # Create a prompt for the LLM to generate an answer + answer_prompt = f"""Based on the following SQL query and its results, provide a clear, direct answer to the question: '{question}' + + SQL Query: + {simple_query} + + Query Results: + {df_str} + + Database Schema Summary: + {db_schema_info} + + IMPORTANT INSTRUCTIONS: + 1. Give a concise, definitive answer that directly addresses the question. + 2. Start with 'Answer:' and focus on the key insights from the data. + 3. Be specific and include actual numbers, names, and values from the query results. 
+ 4. If the question asks for a specific piece of information (like who has the most wins or longest career), clearly state that information. + 5. If the query results show multiple records, summarize the most important findings. + 6. Make your answer complete enough that someone could understand it without seeing the query or results. + 7. Do not say things like 'Based on the query results' or 'According to the data' - just state the facts directly. + 8. If the results are empty or don't answer the question, say so clearly. + """ + + try: + # Use the same model as the agent + llm = sql_agent.llm + answer_response = llm(answer_prompt) + + # Extract the answer + answer_text = str(answer_response).strip() + + # If the answer doesn't start with "Answer:", add it + if not answer_text.startswith("Answer:"): + answer_text = "Answer: " + answer_text + + # Format the answer + direct_answer = f"\n\n**{answer_text}**" + response_text = direct_answer + response_text + except Exception as e: + # If there's an error, create a generic answer + logger.error(f"Error generating answer: {str(e)}") + + # Create a generic answer based on the first row + if len(df) > 0: + # Extract key columns for a meaningful answer + key_cols = [col for col in df.columns if col.lower() not in ['index', 'id']] + if key_cols: + # Create a summary of the first row + first_row_summary = ", ".join([f"{col}: {df.iloc[0][col]}" for col in key_cols[:3]]) + direct_answer = f"\n\n**Answer: Based on the data, {first_row_summary}. See the full results below for more details.**" + else: + direct_answer = f"\n\n**Answer: Found {len(df)} records matching your query. 
See the full results below.**" + response_text = direct_answer + response_text + else: + direct_answer = "\n\n**Answer: No records found matching your query.**" + response_text = direct_answer + response_text + + resp_container.markdown(response_text) + + # Display the results in a dataframe + st.markdown("### Interactive Results") + st.dataframe(df, use_container_width=True) + + # Set the SQL query for display in the expander + sql_query = simple_query + + except Exception as e: + error_msg = f"Error executing simple query: {str(e)}" + response_text += f"\n\n{error_msg}" + resp_container.markdown(response_text) + st.error(error_msg) + + # Display SQL query if available (in a collapsible section) + if sql_query: + with st.expander("View SQL Query"): + st.code(sql_query, language="sql") + + # Display execution time if available (as a small note) + try: + if hasattr(result, 'execution_time') and result.execution_time: + st.caption(f"Execution time: {result.execution_time:.2f} seconds") + elif 'execution_time' in locals(): + st.caption(f"Execution time: {execution_time:.2f} seconds") + except Exception as e: + logger.warning(f"Could not display execution time: {str(e)}") + + # Add the response to the messages + tool_calls_to_add = sql_agent.tool_calls if hasattr(sql_agent, 'tool_calls') else None + st.session_state["messages"] = add_message( + st.session_state["messages"], + "assistant", + result.response, + tool_calls_to_add + ) + + except Exception as e: + logger.exception(e) + error_message = f"Sorry, I encountered an error: {str(e)}" + + # Display error message + resp_container.error(error_message) + + # Add error message to conversation history + st.session_state["messages"] = add_message( + st.session_state["messages"], + "assistant", + error_message + ) + + #################################################################### + # Rename session widget + #################################################################### + rename_session_widget(sql_agent) + + 
#################################################################### + # About section + #################################################################### + about_widget() + +if __name__ == "__main__": + main() diff --git a/src/ember/examples/sql_agent/f1_data.db b/src/ember/examples/sql_agent/f1_data.db new file mode 100644 index 00000000..1dad7174 Binary files /dev/null and b/src/ember/examples/sql_agent/f1_data.db differ diff --git a/src/ember/examples/sql_agent/knowledge/__init__.py b/src/ember/examples/sql_agent/knowledge/__init__.py new file mode 100644 index 00000000..6e8316b9 --- /dev/null +++ b/src/ember/examples/sql_agent/knowledge/__init__.py @@ -0,0 +1,6 @@ +"""Knowledge package for SQL Agent examples. + +This package contains knowledge data for the SQL Agent examples. +""" + +# Knowledge base for the SQL Agent diff --git a/src/ember/examples/sql_agent/knowledge/constructors_championship.json b/src/ember/examples/sql_agent/knowledge/constructors_championship.json new file mode 100644 index 00000000..f4b5748e --- /dev/null +++ b/src/ember/examples/sql_agent/knowledge/constructors_championship.json @@ -0,0 +1,36 @@ +{ + "table_name": "constructors_championship", + "table_description": "Contains data for the constructor's championship from 1958 to 2020, capturing championship standings from when it was introduced.", + "table_columns": [ + { + "name": "index", + "type": "int", + "description": "Unique index for each entry." + }, + { + "name": "year", + "type": "int", + "description": "The year of the championship." + }, + { + "name": "position", + "type": "int", + "description": "The finishing position of the constructor in the championship." + }, + { + "name": "team", + "type": "text", + "description": "The name of the constructor/team." + }, + { + "name": "points", + "type": "float", + "description": "Points earned in the championship." 
+ } + ], + "table_rules": [ + "When comparing teams across years, account for team name changes (e.g., Brawn GP became Mercedes).", + "Points systems have changed over the years, so direct points comparisons across eras may not be meaningful.", + "Constructor's Championship only started in 1958, so no data exists before that year." + ] +} diff --git a/src/ember/examples/sql_agent/knowledge/drivers_championship.json b/src/ember/examples/sql_agent/knowledge/drivers_championship.json new file mode 100644 index 00000000..aa5dc2ae --- /dev/null +++ b/src/ember/examples/sql_agent/knowledge/drivers_championship.json @@ -0,0 +1,46 @@ +{ + "table_name": "drivers_championship", + "table_description": "Contains data for driver's championship standings from 1950-2020, detailing driver positions, teams, and points.", + "table_columns": [ + { + "name": "index", + "type": "int", + "description": "Unique index for each entry." + }, + { + "name": "year", + "type": "int", + "description": "The year of the championship." + }, + { + "name": "position", + "type": "int", + "description": "The finishing position of the driver in the championship." + }, + { + "name": "name", + "type": "text", + "description": "The name of the driver." + }, + { + "name": "team", + "type": "text", + "description": "The team the driver raced for." + }, + { + "name": "points", + "type": "float", + "description": "Points earned in the championship." + }, + { + "name": "wins", + "type": "int", + "description": "Number of race wins in the season." + } + ], + "table_rules": [ + "Points systems have changed over the years, so direct points comparisons across eras may not be meaningful.", + "Driver names should be used consistently (e.g., 'Lewis Hamilton' not 'Hamilton, Lewis').", + "When analyzing driver careers, join with race_results or race_wins tables for detailed race information." 
+ ] +} diff --git a/src/ember/examples/sql_agent/knowledge/fastest_laps.json b/src/ember/examples/sql_agent/knowledge/fastest_laps.json new file mode 100644 index 00000000..f2ecf25f --- /dev/null +++ b/src/ember/examples/sql_agent/knowledge/fastest_laps.json @@ -0,0 +1,51 @@ +{ + "table_name": "fastest_laps", + "table_description": "Contains data for the fastest laps recorded in races from 1950-2020, including driver and team details.", + "table_columns": [ + { + "name": "index", + "type": "int", + "description": "Unique index for each entry." + }, + { + "name": "year", + "type": "int", + "description": "The year of the race." + }, + { + "name": "venue", + "type": "text", + "description": "Location/track of the race." + }, + { + "name": "name", + "type": "text", + "description": "Name of the driver who set the fastest lap." + }, + { + "name": "team", + "type": "text", + "description": "The team of the driver." + }, + { + "name": "time", + "type": "text", + "description": "The fastest lap time, in the format 'M:SS.mmm'." + }, + { + "name": "lap", + "type": "int", + "description": "The lap number on which the fastest lap was set." + }, + { + "name": "speed", + "type": "float", + "description": "The average speed during the fastest lap in km/h." + } + ], + "table_rules": [ + "Lap times are stored as text in the format 'M:SS.mmm', use appropriate string functions for comparison.", + "Track configurations have changed over time, so direct lap time comparisons for the same venue across years may not be valid.", + "Speed is measured in km/h and should be treated as a float for calculations." 
+ ] +} diff --git a/src/ember/examples/sql_agent/knowledge/race_results.json b/src/ember/examples/sql_agent/knowledge/race_results.json new file mode 100644 index 00000000..dc64d4e7 --- /dev/null +++ b/src/ember/examples/sql_agent/knowledge/race_results.json @@ -0,0 +1,66 @@ +{ + "table_name": "race_results", + "table_description": "Race data for each Formula 1 race from 1950-2020, including positions, drivers, teams, and points.", + "table_columns": [ + { + "name": "index", + "type": "int", + "description": "Unique index for each entry." + }, + { + "name": "year", + "type": "int", + "description": "The year of the race." + }, + { + "name": "position", + "type": "text", + "description": "The finishing position of the driver." + }, + { + "name": "driver_no", + "type": "int", + "description": "Driver number." + }, + { + "name": "venue", + "type": "text", + "description": "Location of the race." + }, + { + "name": "name", + "type": "text", + "description": "Name of the driver." + }, + { + "name": "name_tag", + "type": "text", + "description": "Abbreviated tag of the driver's name." + }, + { + "name": "team", + "type": "text", + "description": "The racing team of the driver." + }, + { + "name": "laps", + "type": "float", + "description": "Number of laps completed." + }, + { + "name": "time", + "type": "text", + "description": "Finishing time or gap to the leader." + }, + { + "name": "points", + "type": "float", + "description": "Points earned in the race." + } + ], + "table_rules": [ + "Position can contain values like 'R' for retired or 'D' for disqualified, not just numeric positions.", + "Time can be the actual race time for the winner, or the gap to the leader for other finishers.", + "Points systems have changed over the years, so direct points comparisons across eras may not be meaningful." 
"Date is stored as text in the format 'DD Mon YYYY'; SQLite has no TO_DATE, so extract date parts with substr() or strftime().",
"When joining with other tables, derive the year from the date using CAST(substr(date, -4) AS INTEGER).",
--
-- NOTE: all queries target SQLite (see load_f1_data.py). The date column in
-- race_wins is text 'DD Mon YYYY'; the year is always its last four characters,
-- so CAST(substr(date, -4) AS INTEGER) replaces the PostgreSQL-only
-- EXTRACT(YEAR FROM TO_DATE(...)) used previously.
--
SELECT
    dc.year,
    dc.name AS champion_name,
    COUNT(rw.name) AS race_wins
FROM
    drivers_championship dc
JOIN
    race_wins rw
ON
    dc.name = rw.name AND dc.year = CAST(substr(rw.date, -4) AS INTEGER)
WHERE
    dc.position = 1
GROUP BY
    dc.year, dc.name
ORDER BY
    dc.year;
--

--
-- Compare the number of race wins vs championship positions for constructors in 2019
--
--
WITH race_wins_2019 AS (
    SELECT team, COUNT(*) AS wins
    FROM race_wins
    WHERE CAST(substr(date, -4) AS INTEGER) = 2019
    GROUP BY team
),
constructors_positions_2019 AS (
    SELECT team, position
    FROM constructors_championship
    WHERE year = 2019
)
SELECT cp.team, cp.position, COALESCE(rw.wins, 0) AS wins
FROM constructors_positions_2019 cp
LEFT JOIN race_wins_2019 rw ON cp.team = rw.team
ORDER BY cp.position;
--

--
-- Most race wins by a driver
--
--
SELECT name, COUNT(*) AS win_count
FROM race_wins
GROUP BY name
ORDER BY win_count DESC
LIMIT 1;
--

--
-- Which team won the most Constructors Championships?
--
--
SELECT team, COUNT(*) AS championship_wins
FROM constructors_championship
WHERE position = 1
GROUP BY team
ORDER BY championship_wins DESC
LIMIT 1;
--

--
-- Show me Lewis Hamilton's win percentage by season
--
--
WITH hamilton_races AS (
    SELECT
        year,
        COUNT(*) AS total_races
    FROM
        race_results
    WHERE
        name = 'Lewis Hamilton'
    GROUP BY
        year
),
hamilton_wins AS (
    SELECT
        CAST(substr(date, -4) AS INTEGER) AS year,
        COUNT(*) AS wins
    FROM
        race_wins
    WHERE
        name = 'Lewis Hamilton'
    GROUP BY
        year
)
SELECT
    hr.year,
    hr.total_races,
    COALESCE(hw.wins, 0) AS wins,
    -- SQLite: CAST AS REAL replaces the PostgreSQL ::float cast.
    ROUND((CAST(COALESCE(hw.wins, 0) AS REAL) / hr.total_races) * 100, 2) AS win_percentage
FROM
    hamilton_races hr
LEFT JOIN
    hamilton_wins hw ON hr.year = hw.year
ORDER BY
    hr.year;
--

--
-- Which drivers have won championships with multiple teams?
--
--
WITH champion_teams AS (
    SELECT
        name,
        team,
        COUNT(*) AS championships
    FROM
        drivers_championship
    WHERE
        position = 1
    GROUP BY
        name, team
)
SELECT
    name,
    COUNT(DISTINCT team) AS different_teams,
    -- SQLite: GROUP_CONCAT replaces the PostgreSQL-only STRING_AGG.
    GROUP_CONCAT(team || ' (' || championships || ')', ', ') AS teams_with_championships
FROM
    champion_teams
GROUP BY
    name
HAVING
    COUNT(DISTINCT team) > 1
ORDER BY
    different_teams DESC, name;
--

--
-- What tracks have hosted the most races?
--
--
SELECT
    venue,
    COUNT(DISTINCT date) AS race_count
FROM
    race_wins
GROUP BY
    venue
ORDER BY
    race_count DESC
LIMIT 10;
--

--
-- Compare Mercedes vs Ferrari performance in constructors championships
--
--
SELECT
    year,
    MAX(CASE WHEN team = 'Mercedes' THEN position ELSE NULL END) AS mercedes_position,
    MAX(CASE WHEN team = 'Mercedes' THEN points ELSE NULL END) AS mercedes_points,
    MAX(CASE WHEN team = 'Ferrari' THEN position ELSE NULL END) AS ferrari_position,
    MAX(CASE WHEN team = 'Ferrari' THEN points ELSE NULL END) AS ferrari_points
FROM
    constructors_championship
WHERE
    team IN ('Mercedes', 'Ferrari')
    AND year >= 2010
GROUP BY
    year
ORDER BY
    year;
--

--
-- Show me the progression of fastest lap times at Monza
--
--
SELECT
    year,
    name,
    team,
    time,
    speed
FROM
    fastest_laps
WHERE
    venue = 'Monza'
ORDER BY
    year;
--
diff --git a/src/ember/examples/sql_agent/load_f1_data.py b/src/ember/examples/sql_agent/load_f1_data.py
new file mode 100644
index 00000000..85541dbb
--- /dev/null
+++ b/src/ember/examples/sql_agent/load_f1_data.py
@@ -0,0 +1,80 @@
+"""Formula 1 Data Loader.
+
+This module loads Formula 1 data from remote sources into a SQLite database.
+""" + +from io import StringIO +import logging +import os +from typing import Dict, Optional + +import pandas as pd +import requests +import urllib3 +from sqlalchemy import create_engine + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# S3 URI for F1 data +s3_uri = "https://agno-public.s3.amazonaws.com/f1" + +# List of files and their corresponding table names +files_to_tables = { + f"{s3_uri}/constructors_championship_1958_2020.csv": "constructors_championship", + f"{s3_uri}/drivers_championship_1950_2020.csv": "drivers_championship", + f"{s3_uri}/fastest_laps_1950_to_2020.csv": "fastest_laps", + f"{s3_uri}/race_results_1950_to_2020.csv": "race_results", + f"{s3_uri}/race_wins_1950_to_2020.csv": "race_wins", +} + +def load_f1_data(db_path: Optional[str] = None) -> None: + """Load Formula 1 data into a SQLite database. + + Downloads F1 data from S3 and loads it into tables in a SQLite database. + + Args: + db_path: Optional path to the database file. If not provided, + defaults to 'f1_data.db' in the current directory. 
+ """ + # Set default database path if not provided + if db_path is None: + db_path = "f1_data.db" + + # Database connection string + db_url = f"sqlite:///{db_path}" + + logger.info(f"Loading database to {db_path}") + engine = create_engine(db_url) + + # Load each CSV file into the corresponding SQLite table + for file_path, table_name in files_to_tables.items(): + logger.info(f"Loading {file_path} into {table_name} table") + + try: + # Download the file using requests + response = requests.get(file_path, verify=False) + response.raise_for_status() # Raise an exception for bad status codes + + # Read the CSV data from the response content + csv_data = StringIO(response.text) + df = pd.read_csv(csv_data) + + df.to_sql(table_name, engine, if_exists="replace", index=False) + logger.info(f"Successfully loaded {len(df)} rows into {table_name} table") + + except Exception as e: + logger.error(f"Error loading {file_path}: {str(e)}") + + logger.info(f"Database loaded to {db_path}") + +if __name__ == "__main__": + # Disable SSL verification warnings + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + # Get database path from command line argument if provided + import sys + db_path = sys.argv[1] if len(sys.argv) > 1 else None + + load_f1_data(db_path) \ No newline at end of file diff --git a/src/ember/examples/sql_agent/load_knowledge.py b/src/ember/examples/sql_agent/load_knowledge.py new file mode 100644 index 00000000..e1052b86 --- /dev/null +++ b/src/ember/examples/sql_agent/load_knowledge.py @@ -0,0 +1,82 @@ +import json +import os +from pathlib import Path +import logging + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# Paths +CWD = Path(__file__).parent +KNOWLEDGE_DIR = CWD.joinpath("knowledge") + +def load_knowledge(): + """Load the knowledge base and print a summary.""" + logger.info("Loading SQL agent knowledge.") + + # Load table 
metadata + tables = {} + for file_path in KNOWLEDGE_DIR.glob("*.json"): + with open(file_path, "r") as f: + table_data = json.load(f) + table_name = table_data.get("table_name") + if table_name: + tables[table_name] = table_data + + logger.info(f"Loaded metadata for {len(tables)} tables.") + + # Load sample queries + sample_queries_path = KNOWLEDGE_DIR / "sample_queries.sql" + sample_queries = [] + + if sample_queries_path.exists(): + with open(sample_queries_path, "r") as f: + content = f.read() + + # Parse the sample queries + query_blocks = [] + current_block = {"description": "", "query": ""} + in_description = False + in_query = False + + for line in content.split("\n"): + if line.strip() == "-- ": + in_description = True + in_query = False + if current_block["query"]: + query_blocks.append(current_block) + current_block = {"description": "", "query": ""} + elif line.strip() == "-- ": + in_description = False + elif line.strip() == "-- ": + in_query = True + in_description = False + elif line.strip() == "-- ": + in_query = False + elif in_description: + current_block["description"] += line.replace("-- ", "") + "\n" + elif in_query: + current_block["query"] += line + "\n" + + if current_block["query"]: + query_blocks.append(current_block) + + sample_queries = query_blocks + + logger.info(f"Loaded {len(sample_queries)} sample queries.") + + # Print a summary of the knowledge base + logger.info("Knowledge base summary:") + logger.info(f"Tables: {', '.join(tables.keys())}") + logger.info(f"Sample queries: {len(sample_queries)}") + + logger.info("SQL agent knowledge loaded.") + + return { + "tables": tables, + "sample_queries": sample_queries + } + +if __name__ == "__main__": + load_knowledge() diff --git a/src/ember/examples/sql_agent/requirements.txt b/src/ember/examples/sql_agent/requirements.txt new file mode 100644 index 00000000..614ac62c --- /dev/null +++ b/src/ember/examples/sql_agent/requirements.txt @@ -0,0 +1,7 @@ +openai +streamlit +pandas +sqlalchemy 
+requests +nest_asyncio +tabulate \ No newline at end of file diff --git a/src/ember/examples/sql_agent/run_sql_agent.py b/src/ember/examples/sql_agent/run_sql_agent.py new file mode 100644 index 00000000..ca0aabbb --- /dev/null +++ b/src/ember/examples/sql_agent/run_sql_agent.py @@ -0,0 +1,93 @@ +"""SQL Agent Example Runner. + +This script demonstrates how to use the SQL Agent to query a Formula 1 database. +""" + +import os +import sys +from ember.examples.sql_agent.sql_agent import SQLAgent +from ember.examples.sql_agent.load_f1_data import load_f1_data + +def main() -> None: + """Run the SQL Agent example.""" + # Check if OPENAI_API_KEY is set in the environment + if not os.environ.get("OPENAI_API_KEY"): + print("Warning: OPENAI_API_KEY environment variable not set.") + print("Please set your API key using: export OPENAI_API_KEY='your-key-here'") + + # Check if a database path is provided + database_path = sys.argv[1] if len(sys.argv) > 1 else "f1_data.db" + database_url = f"sqlite:///{database_path}" + + # Check if the database exists; if not, create it + if not os.path.exists(database_path): + print(f"Database {database_path} not found. 
Creating and loading data...") + load_f1_data(database_path) + + # Get model name from command line arguments or use default + model_name = sys.argv[2] if len(sys.argv) > 2 else "openai:gpt-4o" + + print(f"Using database: {database_path}") + print(f"Using model: {model_name}") + + # Create an instance of the SQL Agent + agent = SQLAgent( + database_url=database_url, + model_name=model_name, + temperature=0.0 + ) + + # Example questions to demonstrate the agent + example_questions = [ + "Who won the most races in the 2019 season?", + "Which team had the most points in the constructors championship in 2018?", + "Who had the fastest lap in Monaco in 2019?", + "Compare the performance of Hamilton and Vettel in 2018" + ] + + # Run the example questions + for i, question in enumerate(example_questions, 1): + print(f"\n--- Example {i} ---") + print(f"Question: {question}") + + # Process the question + result = agent.query(question) + + # Display the results + print("\nGenerated SQL Query:") + print(result["sql_query"]) + + print("\nQuery Results:") + if result["query_result"]["success"]: + if result["query_result"]["record_count"] > 0: + import pandas as pd + df = pd.DataFrame(result["query_result"]["results"]) + print(df.to_string(index=False)) + else: + print("No results found.") + else: + print(f"Query failed: {result['query_result']['error']}") + + print("\nAnswer:") + print(result["answer"]) + print("\n" + "-" * 80) + + # Interactive mode + print("\n--- Interactive Mode ---") + print("Type your questions about Formula 1 data (or 'exit' to quit)") + + while True: + question = input("\nYour question: ") + if question.lower() in ("exit", "quit", "q"): + break + + result = agent.query(question) + + print("\nSQL Query:") + print(result["sql_query"]) + + print("\nAnswer:") + print(result["answer"]) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/ember/examples/sql_agent/sql_agent.py b/src/ember/examples/sql_agent/sql_agent.py new file 
mode 100644 index 00000000..8c6a7dd5 --- /dev/null +++ b/src/ember/examples/sql_agent/sql_agent.py @@ -0,0 +1,356 @@ +"""SQL Agent Module. + +This module provides a dynamic SQL Agent capable of understanding any database schema +and answering questions about the data. +""" + +import json +import logging +import time +from typing import Any, Dict, List, Optional, Union + +import pandas as pd +from sqlalchemy import create_engine, inspect, text + +from ember.api import models +from ember.examples.sql_agent.utils import add_message + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class SQLAgent: + """A dynamic SQL Agent that can work with any database schema. + + This agent can dynamically explore database schemas, generate appropriate SQL queries, + and provide definitive answers for any question about the data. + + Attributes: + database_url: The SQLAlchemy connection URL for the database. + model_name: The LLM model to use for query generation and answer synthesis. + temperature: The temperature setting for the LLM. + """ + + def __init__(self, + database_url: str = "sqlite:///f1_data.db", + model_name: str = "openai:gpt-4o", + temperature: float = 0.0) -> None: + """Initialize the SQL Agent. + + Args: + database_url: SQLAlchemy connection URL string. + model_name: Model identifier to use for SQL generation and answer synthesis. + Options include: + - OpenAI: "openai:gpt-4o", "openai:gpt-4", "openai:gpt-3.5-turbo" + - Anthropic: "anthropic:claude-3-opus", "anthropic:claude-3-sonnet" + - Deepmind: "deepmind:gemini-1.5-pro" + temperature: Temperature setting for the model (0.0 for deterministic outputs). 
+ """ + self.database_url = database_url + self.model_name = model_name + self.temperature = temperature + + # Create database engine + self.engine = create_engine(database_url) + + # Create the model + self.model = models.model(model_name, temperature=temperature) + + # Cache for database schema + self._schema_cache = None + + def _get_database_schema(self) -> Dict[str, Any]: + """Dynamically retrieve and cache the database schema. + + Returns: + A dictionary containing the database schema information. + """ + if self._schema_cache: + return self._schema_cache + + logger.info("Retrieving database schema") + inspector = inspect(self.engine) + + schema = { + "tables": {} + } + + # Get all tables + for table_name in inspector.get_table_names(): + table_info = { + "columns": [], + "sample_data": None, + "row_count": 0 + } + + # Get column information + for column in inspector.get_columns(table_name): + table_info["columns"].append({ + "name": column["name"], + "type": str(column["type"]), + "nullable": column.get("nullable", True) + }) + + # Get sample data + try: + query = f"SELECT * FROM {table_name} LIMIT 5" + sample_df = pd.read_sql(query, self.engine) + table_info["sample_data"] = sample_df.to_dict(orient="records") + + # Get row count + count_query = f"SELECT COUNT(*) as count FROM {table_name}" + count_df = pd.read_sql(count_query, self.engine) + table_info["row_count"] = int(count_df["count"].iloc[0]) + except Exception as e: + logger.error(f"Error getting sample data for {table_name}: {str(e)}") + + schema["tables"][table_name] = table_info + + # Cache the schema + self._schema_cache = schema + return schema + + def _generate_sql_query(self, question: str) -> str: + """Generate a SQL query to answer the given question. + + Args: + question: The question to generate SQL for. + + Returns: + A SQL query string. 
+ """ + schema = self._get_database_schema() + + # Format schema info for prompt + schema_info = "Database Schema:\n" + for table_name, table_info in schema["tables"].items(): + schema_info += f"Table: {table_name} ({table_info['row_count']} rows)\n" + schema_info += "Columns:\n" + + for column in table_info["columns"]: + nullable = "NULL" if column["nullable"] else "NOT NULL" + schema_info += f" - {column['name']} ({column['type']}) {nullable}\n" + + # Include sample data + if table_info["sample_data"]: + schema_info += "Sample Data:\n" + sample_df = pd.DataFrame(table_info["sample_data"]) + schema_info += f"{sample_df.head().to_string()}\n\n" + + # Prompt for SQL generation + prompt = f"""You are an expert SQL developer. Given the following database schema, write a SQL query to answer this question: "{question}" + +{schema_info} + +Rules: +1. Use only the tables and columns that exist in the schema +2. Write a valid SQL query that will run on SQLite +3. Use table and column names exactly as they appear in the schema +4. Do not use any placeholders like or +5. Always return a query that will provide a meaningful answer +6. Format complex queries for readability + +SQL query:""" + + # Generate the SQL query + start_time = time.time() + generated_sql = self.model(prompt).strip() + end_time = time.time() + logger.info(f"Query generation took {end_time - start_time:.2f} seconds") + + return generated_sql + + def _validate_and_correct_query(self, query: str) -> str: + """Validate and correct SQL queries before execution. + + Args: + query: The SQL query to validate. + + Returns: + A corrected SQL query string. 
+ """ + schema = self._get_database_schema() + + # Check for table names + tables = schema["tables"].keys() + for table in tables: + # Very basic check - could be enhanced + if table in query: + return query + + # If no valid tables found, try to correct + prompt = f"""The following SQL query doesn't seem to reference any existing tables in our database: + +{query} + +Our database has the following tables: {', '.join(tables)} + +Please rewrite the query to use the correct table names:""" + + corrected_query = self.model(prompt).strip() + return corrected_query + + def _execute_query(self, query: str) -> Dict[str, Any]: + """Execute a SQL query and return the results. + + Args: + query: The SQL query to execute. + + Returns: + A dictionary with the query execution results and metadata. + """ + try: + # Validate and correct query + validated_query = self._validate_and_correct_query(query) + + # Execute query + start_time = time.time() + result_df = pd.read_sql(text(validated_query), self.engine) + end_time = time.time() + execution_time = end_time - start_time + + # Convert results to dictionary + records = result_df.to_dict(orient="records") + + return { + "success": True, + "query": validated_query, + "results": records, + "record_count": len(records), + "execution_time": f"{execution_time:.2f} seconds", + "columns": list(result_df.columns), + "error": None + } + except Exception as e: + logger.error(f"Error executing query: {str(e)}") + return { + "success": False, + "query": query, + "results": [], + "record_count": 0, + "execution_time": "0.00 seconds", + "columns": [], + "error": str(e) + } + + def _generate_answer(self, question: str, query_result: Dict[str, Any]) -> str: + """Generate a natural language answer based on the query results. + + Args: + question: The original question. + query_result: The results from the query execution. + + Returns: + A natural language answer to the question. 
+ """ + # Format the results for the prompt + formatted_results = "No results found." + if query_result["success"] and query_result["results"]: + result_df = pd.DataFrame(query_result["results"]) + formatted_results = result_df.to_string() + + # Prompt for answer generation + prompt = f"""You are an analytics assistant. Answer the following question based on the database query results. + +Question: {question} + +SQL Query: {query_result["query"]} + +Query Results: +{formatted_results} + +Execution Information: +- Record Count: {query_result["record_count"]} +- Execution Time: {query_result["execution_time"]} +- Success: {query_result["success"]} +{f'- Error: {query_result["error"]}' if not query_result["success"] else ''} + +Instructions: +1. Provide a clear, definitive answer to the question +2. Include specific details and numbers from the query results +3. If the query failed, explain why and suggest a better approach +4. Be concise but thorough +5. Format the answer for readability with markdown where appropriate + +Answer:""" + + # Generate the answer + start_time = time.time() + answer = self.model(prompt).strip() + end_time = time.time() + logger.info(f"Answer generation took {end_time - start_time:.2f} seconds") + + return answer + + def query(self, question: str) -> Dict[str, Any]: + """Process a natural language question about the database. + + Args: + question: The question to answer. + + Returns: + A dictionary containing the question, SQL query, results, and answer. 
+ """ + logger.info(f"Processing question: {question}") + + # Generate SQL query + sql_query = self._generate_sql_query(question) + logger.info(f"Generated SQL query: {sql_query}") + + # Execute the query + query_result = self._execute_query(sql_query) + + # Generate answer + answer = self._generate_answer(question, query_result) + + # Return complete response + return { + "question": question, + "sql_query": sql_query, + "query_result": query_result, + "answer": answer + } + + def chat(self, question: str, messages: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]: + """Interactive chat interface with message history. + + Args: + question: The question to answer. + messages: Optional list of previous messages. + + Returns: + A dictionary with the updated message history and latest response. + """ + if messages is None: + messages = [] + + # Add user question to messages + messages = add_message(messages, "user", question) + + # Process the question + response = self.query(question) + + # Create assistant message content + content = response["answer"] + + # Create tool calls list for transparency + tool_calls = [ + { + "tool_name": "generate_sql_query", + "tool_args": {"question": question}, + "content": response["sql_query"] + }, + { + "tool_name": "execute_sql_query", + "tool_args": {"query": response["sql_query"]}, + "content": json.dumps(response["query_result"], indent=2) + } + ] + + # Add assistant response to messages + messages = add_message(messages, "assistant", content, tool_calls) + + return { + "messages": messages, + "response": response + } \ No newline at end of file diff --git a/src/ember/examples/sql_agent/utils.py b/src/ember/examples/sql_agent/utils.py new file mode 100644 index 00000000..1c4fd1ba --- /dev/null +++ b/src/ember/examples/sql_agent/utils.py @@ -0,0 +1,201 @@ +"""SQL Agent Utilities Module. + +This module provides utility functions for the SQL Agent. 
+""" + +import json +import logging +from typing import Any, Dict, List, Optional + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +def is_json(myjson: Any) -> bool: + """Check if a string is valid JSON. + + Args: + myjson: The string to check. + + Returns: + bool: True if the string is valid JSON, False otherwise. + """ + try: + json.loads(myjson) + return True + except (ValueError, TypeError): + return False + +def add_message( + messages: List[Dict[str, Any]], + role: str, + content: str, + tool_calls: Optional[List[Dict[str, Any]]] = None +) -> List[Dict[str, Any]]: + """Safely add a message to the messages list. + + Args: + messages: The list of messages to add to. + role: The role of the message sender (user, assistant, system). + content: The content of the message. + tool_calls: Optional list of tool calls. + + Returns: + The updated list of messages. + """ + if messages is None: + messages = [] + + messages.append({ + "role": role, + "content": content, + "tool_calls": tool_calls + }) + + return messages + +def export_chat_history(messages: List[Dict[str, Any]]) -> str: + """Export chat history as markdown. + + Args: + messages: The list of messages to export. + + Returns: + str: The chat history formatted as markdown. + """ + if not messages: + return "# SQL Agent - Chat History\n\nNo messages in history." + + chat_text = "# SQL Agent - Chat History\n\n" + + for msg in messages: + role = "🤖 Assistant" if msg["role"] == "assistant" else "👤 User" + chat_text += f"### {role}\n{msg['content']}\n\n" + + return chat_text + +def display_tool_calls(tool_calls_container: Any, tools: List[Dict[str, Any]]) -> None: + """Display tool calls in a streamlit container with expandable sections. + + Args: + tool_calls_container: The streamlit container to display in. + tools: The list of tools to display. 
+ """ + try: + if not tools: + return + + with tool_calls_container.container(): + for tool_call in tools: + tool_name = tool_call.get("tool_name", "Unknown Tool") + tool_args = tool_call.get("tool_args", {}) + content = tool_call.get("content", None) + + with tool_calls_container.expander( + f"🛠️ {tool_name.replace('_', ' ').title()}", + expanded=False, + ): + # Show query with syntax highlighting + if isinstance(tool_args, dict) and "query" in tool_args: + tool_calls_container.code(tool_args["query"], language="sql") + + # Display arguments in a more readable format + if tool_args and tool_args != {"query": None}: + tool_calls_container.markdown("**Arguments:**") + tool_calls_container.json(tool_args) + + if content is not None: + try: + if is_json(content): + try: + parsed_content = json.loads(content) + tool_calls_container.markdown("**Results:**") + tool_calls_container.json(parsed_content) + except: + tool_calls_container.markdown("**Results:**") + tool_calls_container.markdown(content) + else: + tool_calls_container.markdown("**Results:**") + tool_calls_container.markdown(content) + except Exception as e: + logger.debug(f"Skipped tool call content: {e}") + + except Exception as e: + logger.error(f"Error displaying tool calls: {str(e)}") + tool_calls_container.error("Failed to display tool results") + +# Custom CSS for the Streamlit app +CUSTOM_CSS = """ + +""" \ No newline at end of file