Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 191 additions & 5 deletions src/strands/multiagent/a2a/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,28 @@
streamed requests to the A2AServer.
"""

import json
import logging
from typing import Any
from typing import Any, Literal

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.tasks import TaskUpdater
from a2a.types import InternalError, Part, TaskState, TextPart, UnsupportedOperationError
from a2a.types import DataPart, FilePart, InternalError, Part, TaskState, TextPart, UnsupportedOperationError
from a2a.utils import new_agent_text_message, new_task
from a2a.utils.errors import ServerError

from ...agent.agent import Agent as SAAgent
from ...agent.agent import AgentResult as SAAgentResult
from ...types.content import ContentBlock
from ...types.media import (
DocumentContent,
DocumentSource,
ImageContent,
ImageSource,
VideoContent,
VideoSource,
)

logger = logging.getLogger(__name__)

Expand All @@ -31,6 +41,26 @@ class StrandsA2AExecutor(AgentExecutor):
and converts Strands Agent responses to A2A protocol events.
"""

# File format mappings for different content types
IMAGE_FORMAT_MAPPINGS = {"jpeg": "jpeg", "jpg": "jpeg", "png": "png", "gif": "gif", "webp": "webp"}

VIDEO_FORMAT_MAPPINGS = {
"mp4": "mp4",
"mpeg": "mpeg",
"mpg": "mpg",
"webm": "webm",
"mov": "mov",
"mkv": "mkv",
"flv": "flv",
"wmv": "wmv",
"3gpp": "three_gp",
}

DOCUMENT_FORMAT_MAPPINGS = {"pdf": "pdf", "csv": "csv", "html": "html", "plain": "txt", "markdown": "md"}

# Default formats for each file type when MIME type is unavailable
DEFAULT_FORMATS = {"document": "txt", "image": "png", "video": "mp4", "unknown": "txt"}

def __init__(self, agent: SAAgent):
"""Initialize a StrandsA2AExecutor.

Expand Down Expand Up @@ -78,10 +108,15 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater
context: The A2A request context, containing the user's input and other metadata.
updater: The task updater for managing task state and sending updates.
"""
logger.info("Executing request in streaming mode")
user_input = context.get_user_input()
# Convert A2A message parts to Strands ContentBlocks
if context.message and hasattr(context.message, "parts"):
content_blocks = self._convert_a2a_parts_to_content_blocks(context.message.parts)
else:
# Fallback to original text extraction if no parts available
user_input = context.get_user_input()
content_blocks = [ContentBlock(text=user_input)]
try:
async for event in self.agent.stream_async(user_input):
async for event in self.agent.stream_async(content_blocks):
await self._handle_streaming_event(event, updater)
except Exception:
logger.exception("Error in streaming execution")
Expand Down Expand Up @@ -146,3 +181,154 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None
"""
logger.warning("Cancellation requested but not supported")
raise ServerError(error=UnsupportedOperationError())

def _get_file_type_from_mime_type(self, mime_type: str | None) -> Literal["document", "image", "video", "unknown"]:
"""Classify file type based on MIME type.

Args:
mime_type: The MIME type of the file

Returns:
The classified file type
"""
if not mime_type:
return "unknown"

mime_type = mime_type.lower()

if mime_type.startswith("image/"):
return "image"
elif mime_type.startswith("video/"):
return "video"
elif (
mime_type.startswith("text/")
or mime_type.startswith("application/")
or mime_type in ["application/pdf", "application/json", "application/xml"]
):
return "document"
else:
return "unknown"

def _get_file_format_from_mime_type(self, mime_type: str | None, file_type: str) -> str:
"""Extract file format from MIME type.

Args:
mime_type: The MIME type of the file
file_type: The classified file type (image, video, document, txt)

Returns:
The file format string
"""
if not mime_type:
return self.DEFAULT_FORMATS.get(file_type, "txt")

mime_type = mime_type.lower()

# Extract format from MIME type
if "/" in mime_type:
format_part = mime_type.split("/")[1]

# Handle common MIME type mappings with validation
if file_type == "image":
return self.IMAGE_FORMAT_MAPPINGS.get(format_part, "png")
elif file_type == "video":
return self.VIDEO_FORMAT_MAPPINGS.get(format_part, "mp4")
else: # document
return self.DOCUMENT_FORMAT_MAPPINGS.get(format_part, "txt")

# Fallback defaults
return self.DEFAULT_FORMATS.get(file_type, "txt")

def _strip_file_extension(self, file_name: str) -> str:
"""Strip the file extension from a file name.

Args:
file_name: The original file name with extension

Returns:
The file name without extension
"""
if "." in file_name:
return file_name.rsplit(".", 1)[0]
return file_name

def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[ContentBlock]:
"""Convert A2A message parts to Strands ContentBlocks.

Args:
parts: List of A2A Part objects

Returns:
List of Strands ContentBlock objects
"""
content_blocks: list[ContentBlock] = []

for part in parts:
try:
part_root = part.root

if isinstance(part_root, TextPart):
# Handle TextPart
content_blocks.append(ContentBlock(text=part_root.text))

elif isinstance(part_root, FilePart):
# Handle FilePart
file_obj = part_root.file
mime_type = getattr(file_obj, "mime_type", None)
raw_file_name = getattr(file_obj, "name", "FileNameNotProvided")
file_name = self._strip_file_extension(raw_file_name)
file_type = self._get_file_type_from_mime_type(mime_type)
file_format = self._get_file_format_from_mime_type(mime_type, file_type)

# Handle FileWithBytes vs FileWithUri
bytes_data = getattr(file_obj, "bytes", None)
uri_data = getattr(file_obj, "uri", None)

if bytes_data:
if file_type == "image":
content_blocks.append(
ContentBlock(
image=ImageContent(
format=file_format, # type: ignore
source=ImageSource(bytes=bytes_data),
)
)
)
elif file_type == "video":
content_blocks.append(
ContentBlock(
video=VideoContent(
format=file_format, # type: ignore
source=VideoSource(bytes=bytes_data),
)
)
)
else: # document or unknown
content_blocks.append(
ContentBlock(
document=DocumentContent(
format=file_format, # type: ignore
name=file_name,
source=DocumentSource(bytes=bytes_data),
)
)
)
# Handle FileWithUri
elif uri_data:
# For URI files, create a text representation since Strands ContentBlocks expect bytes
content_blocks.append(
ContentBlock(
text="[File: %s (%s)] - Referenced file at: %s" % (file_name, mime_type, uri_data)
)
)
elif isinstance(part_root, DataPart):
# Handle DataPart - convert structured data to JSON text
try:
data_text = json.dumps(part_root.data, indent=2)
content_blocks.append(ContentBlock(text="[Structured Data]\n%s" % data_text))
except Exception:
logger.exception("Failed to serialize data part")
except Exception:
logger.exception("Error processing part")

return content_blocks
5 changes: 2 additions & 3 deletions src/strands/multiagent/a2a/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
agent: SAAgent,
*,
# AgentCard
host: str = "0.0.0.0",
host: str = "127.0.0.1",
port: int = 9000,
http_url: str | None = None,
serve_at_root: bool = False,
Expand All @@ -42,13 +42,12 @@ def __init__(
queue_manager: QueueManager | None = None,
push_config_store: PushNotificationConfigStore | None = None,
push_sender: PushNotificationSender | None = None,

):
"""Initialize an A2A-compatible server from a Strands agent.

Args:
agent: The Strands Agent to wrap with A2A compatibility.
host: The hostname or IP address to bind the A2A server to. Defaults to "0.0.0.0".
host: The hostname or IP address to bind the A2A server to. Defaults to "127.0.0.1".
port: The port to bind the A2A server to. Defaults to 9000.
http_url: The public HTTP URL where this agent will be accessible. If provided,
this overrides the generated URL from host/port and enables automatic
Expand Down
Loading
Loading