Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/agentstack-sdk-py/examples/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def execute(
agent = beeai_framework.agents.react.ReActAgent(
llm=beeai_framework.adapters.openai.backend.chat.OpenAIChatModel(
model_id=self.context_llm[context.context_id]["default"].api_model,
api_key=self.context_llm[context.context_id]["default"].api_key,
api_key=self.context_llm[context.context_id]["default"].api_key.get_secret_value(),
base_url=self.context_llm[context.context_id]["default"].api_base,
),
tools=[
Expand Down
4 changes: 2 additions & 2 deletions apps/agentstack-sdk-py/examples/secrets_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ async def secrets_agent(
):
"""Agent that uses request a secret that can be provided during runtime"""
if secrets and secrets.data and secrets.data.secret_fulfillments:
yield f"IBM Cloud API key: {secrets.data.secret_fulfillments['ibm_cloud'].secret}"
yield f"IBM Cloud API key: {secrets.data.secret_fulfillments['ibm_cloud'].secret.get_secret_value()}"
else:
runtime_provided_secrets = await secrets.request_secrets(
params=SecretsServiceExtensionParams(
secret_demands={"ibm_cloud": SecretDemand(description="I really need IBM Cloud Key", name="IBM Cloud")}
)
)
if runtime_provided_secrets and runtime_provided_secrets.secret_fulfillments:
yield f"IBM Cloud API key: {runtime_provided_secrets.secret_fulfillments['ibm_cloud'].secret}"
yield f"IBM Cloud API key: {runtime_provided_secrets.secret_fulfillments['ibm_cloud'].secret.get_secret_value()}"
else:
yield "No IBM Cloud API key provided"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,23 @@
from agentstack_sdk.a2a.extensions.auth.oauth.storage import MemoryTokenStorageFactory, TokenStorageFactory
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
from agentstack_sdk.a2a.types import AgentMessage, AuthRequired, RunYieldResume
from agentstack_sdk.util.pydantic import REVEAL_SECRETS, SecureBaseModel

if TYPE_CHECKING:
from agentstack_sdk.server.context import RunContext

_DEFAULT_DEMAND_NAME = "default"


class AuthRequest(pydantic.BaseModel):
class AuthRequest(SecureBaseModel):
authorization_endpoint_url: pydantic.AnyUrl


class AuthResponse(pydantic.BaseModel):
class AuthResponse(SecureBaseModel):
redirect_uri: pydantic.AnyUrl


class OAuthFulfillment(pydantic.BaseModel):
class OAuthFulfillment(SecureBaseModel):
redirect_uri: pydantic.AnyUrl


Expand Down Expand Up @@ -122,7 +123,10 @@ async def handle_callback() -> tuple[str, str | None]:

def create_auth_request(self, *, authorization_endpoint_url: pydantic.AnyUrl):
data = AuthRequest(authorization_endpoint_url=authorization_endpoint_url)
return AgentMessage(text="Authorization required", metadata={self.spec.URI: data.model_dump(mode="json")})
return AgentMessage(
text="Authorization required",
metadata={self.spec.URI: data.model_dump(mode="json", context={REVEAL_SECRETS: True})},
)

def parse_auth_response(self, *, message: A2AMessage):
if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)):
Expand All @@ -147,5 +151,5 @@ def create_auth_response(self, *, task_id: str, redirect_uri: pydantic.AnyUrl):
role=Role.user,
parts=[TextPart(text="Authorization completed")],
task_id=task_id,
metadata={self.spec.URI: data.model_dump(mode="json")},
metadata={self.spec.URI: data.model_dump(mode="json", context={REVEAL_SECRETS: True})},
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,28 @@
import pydantic
from a2a.server.agent_execution.context import RequestContext
from a2a.types import Message as A2AMessage
from opentelemetry import trace
from typing_extensions import override

from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
from agentstack_sdk.a2a.types import AgentMessage, AuthRequired
from agentstack_sdk.util.pydantic import REDACT_SECRETS, REVEAL_SECRETS, SecureBaseModel
from agentstack_sdk.util.telemetry import flatten_dict

if TYPE_CHECKING:
from agentstack_sdk.server.context import RunContext

A2A_EXTENSION_SECRETS_REQUESTED = "a2a_extension.secrets.requested"
A2A_EXTENSION_SECRETS_RESOLVED = "a2a_extension.secrets.resolved"


class SecretDemand(pydantic.BaseModel):
name: str
description: str | None = None


class SecretFulfillment(pydantic.BaseModel):
secret: str
class SecretFulfillment(SecureBaseModel):
secret: pydantic.SecretStr


class SecretsServiceExtensionParams(pydantic.BaseModel):
Expand Down Expand Up @@ -61,15 +67,25 @@ def parse_secret_response(self, message: A2AMessage) -> SecretsServiceExtensionM
return SecretsServiceExtensionMetadata.model_validate(data)

async def request_secrets(self, params: SecretsServiceExtensionParams) -> SecretsServiceExtensionMetadata:
span = trace.get_current_span()
span.add_event(
A2A_EXTENSION_SECRETS_REQUESTED,
attributes=flatten_dict(params.model_dump(context={REDACT_SECRETS: True})),
)
resume = await self.context.yield_async(
AuthRequired(
message=AgentMessage(
metadata={self.spec.URI: params.model_dump(mode="json")},
metadata={self.spec.URI: params.model_dump(mode="json", context={REVEAL_SECRETS: True})},
)
)
)
if isinstance(resume, A2AMessage):
return self.parse_secret_response(message=resume)
response = self.parse_secret_response(message=resume)
span.add_event(
A2A_EXTENSION_SECRETS_RESOLVED,
attributes=flatten_dict(response.model_dump(context={REDACT_SECRETS: True})),
)
return response
else:
raise ValueError("Secrets has not been provided in response.")

Expand Down
32 changes: 30 additions & 2 deletions apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,17 @@
from a2a.server.agent_execution.context import RequestContext
from a2a.types import AgentCard, AgentExtension
from a2a.types import Message as A2AMessage
from opentelemetry import trace
from opentelemetry.trace import SpanKind
from pydantic import BaseModel
from typing_extensions import override

from agentstack_sdk.util.pydantic import REDACT_SECRETS
from agentstack_sdk.util.telemetry import (
flatten_dict,
trace_class,
)

ParamsT = typing.TypeVar("ParamsT")
MetadataFromClientT = typing.TypeVar("MetadataFromClientT")
MetadataFromServerT = typing.TypeVar("MetadataFromServerT")
Expand All @@ -25,6 +34,10 @@
from agentstack_sdk.server.dependencies import Dependency


A2A_EXTENSION_URI = "a2a_extension.uri"
A2A_EXTENSION_METADATA_RECEIVED_EVENT = "a2a_extension.metadata.received"


def _get_generic_args(cls: type, base_class: type) -> tuple[typing.Any, ...]:
for base in getattr(cls, "__orig_bases__", ()):
if typing.get_origin(base) is base_class and (args := typing.get_args(base)):
Expand Down Expand Up @@ -121,7 +134,14 @@ class BaseExtensionServer(abc.ABC, typing.Generic[ExtensionSpecT, MetadataFromCl

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.MetadataFromClient = _get_generic_args(cls, BaseExtensionServer)[1]

generic_args = _get_generic_args(cls, BaseExtensionServer)
trace_class(
kind=SpanKind.SERVER,
exclude_list=["lifespan", "_fork"],
attributes={A2A_EXTENSION_URI: generic_args[0].URI},
)(cls)
cls.MetadataFromClient = generic_args[1]

_metadata_from_client: MetadataFromClientT | None = None
_dependencies: dict[str, Dependency] = {} # noqa: RUF012
Expand Down Expand Up @@ -151,6 +171,11 @@ def parse_client_metadata(self, message: A2AMessage) -> MetadataFromClientT | No
def handle_incoming_message(self, message: A2AMessage, run_context: RunContext, request_context: RequestContext):
if self._metadata_from_client is None:
self._metadata_from_client = self.parse_client_metadata(message)
if isinstance(self._metadata_from_client, BaseModel):
trace.get_current_span().add_event(
A2A_EXTENSION_METADATA_RECEIVED_EVENT,
attributes=flatten_dict(self._metadata_from_client.model_dump(context={REDACT_SECRETS: True})),
)

def _fork(self) -> typing.Self:
"""Creates a clone of this instance with the same arguments as the original"""
Expand Down Expand Up @@ -182,7 +207,10 @@ class BaseExtensionClient(abc.ABC, typing.Generic[ExtensionSpecT, MetadataFromSe

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.MetadataFromServer = _get_generic_args(cls, BaseExtensionClient)[1]

generic_args = _get_generic_args(cls, BaseExtensionClient)
trace_class(kind=SpanKind.CLIENT, attributes={A2A_EXTENSION_URI: generic_args[0].URI})(cls)
cls.MetadataFromServer = generic_args[1]

def __init__(self, spec: ExtensionSpecT) -> None:
self.spec = spec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,40 @@

import a2a.types
from mcp import Implementation, Tool
from opentelemetry import trace
from pydantic import BaseModel, Discriminator, Field, TypeAdapter

from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
from agentstack_sdk.a2a.types import AgentMessage, InputRequired
from agentstack_sdk.util.pydantic import REDACT_SECRETS, REVEAL_SECRETS, SecureBaseModel
from agentstack_sdk.util.telemetry import flatten_dict

if TYPE_CHECKING:
from agentstack_sdk.server.context import RunContext


A2A_EXTENSION_APPROVAL_REQUESTED = "a2a_extension.approval.requested"
A2A_EXTENSION_APPROVAL_RESOLVED = "a2a_extension.approval.resolved"


class ApprovalRejectionError(RuntimeError):
pass


class GenericApprovalRequest(BaseModel):
class GenericApprovalRequest(SecureBaseModel):
action: Literal["generic"] = "generic"

title: str | None = Field(None, description="A human-readable title for the action being approved.")
description: str | None = Field(None, description="A human-readable description of the action being approved.")


class ToolCallServer(BaseModel):
class ToolCallServer(SecureBaseModel):
name: str = Field(description="The programmatic name of the server.")
title: str | None = Field(description="A human-readable title for the server.")
version: str = Field(description="The version of the server.")


class ToolCallApprovalRequest(BaseModel):
class ToolCallApprovalRequest(SecureBaseModel):
action: Literal["tool-call"] = "tool-call"

title: str | None = Field(None, description="A human-readable title of the tool.")
Expand All @@ -60,7 +67,7 @@ def from_mcp_tool(
ApprovalRequest = Annotated[GenericApprovalRequest | ToolCallApprovalRequest, Discriminator("action")]


class ApprovalResponse(BaseModel):
class ApprovalResponse(SecureBaseModel):
decision: Literal["approve", "reject"]

@property
Expand All @@ -86,7 +93,10 @@ class ApprovalExtensionMetadata(BaseModel):

class ApprovalExtensionServer(BaseExtensionServer[ApprovalExtensionSpec, ApprovalExtensionMetadata]):
def create_request_message(self, *, request: ApprovalRequest):
return AgentMessage(text="Approval requested", metadata={self.spec.URI: request.model_dump(mode="json")})
return AgentMessage(
text="Approval requested",
metadata={self.spec.URI: request.model_dump(mode="json", context={REVEAL_SECRETS: True})},
)

def parse_response(self, *, message: a2a.types.Message):
if not message.metadata or not (data := message.metadata.get(self.spec.URI)):
Expand All @@ -99,11 +109,21 @@ async def request_approval(
*,
context: RunContext,
) -> ApprovalResponse:
span = trace.get_current_span()
span.add_event(
A2A_EXTENSION_APPROVAL_REQUESTED,
attributes=flatten_dict(request.model_dump(context={REDACT_SECRETS: True})),
)
message = self.create_request_message(request=request)
message = await context.yield_async(InputRequired(message=message))
if not message:
raise RuntimeError("Yield did not return a message")
return self.parse_response(message=message)
response = self.parse_response(message=message)
span.add_event(
A2A_EXTENSION_APPROVAL_RESOLVED,
attributes=flatten_dict(response.model_dump(context={REDACT_SECRETS: True})),
)
return response


class ApprovalExtensionClient(BaseExtensionClient[ApprovalExtensionSpec, NoneType]):
Expand All @@ -113,7 +133,7 @@ def create_response_message(self, *, response: ApprovalResponse, task_id: str |
role=a2a.types.Role.user,
parts=[],
task_id=task_id,
metadata={self.spec.URI: response.model_dump(mode="json")},
metadata={self.spec.URI: response.model_dump(mode="json", context={REVEAL_SECRETS: True})},
)

def parse_request(self, *, message: a2a.types.Message):
Expand All @@ -122,4 +142,4 @@ def parse_request(self, *, message: a2a.types.Message):
return TypeAdapter(ApprovalRequest).validate_python(data)

def metadata(self) -> dict[str, Any]:
return {self.spec.URI: ApprovalExtensionMetadata().model_dump(mode="json")}
return {self.spec.URI: ApprovalExtensionMetadata().model_dump(mode="json", context={REVEAL_SECRETS: True})}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
from typing_extensions import override

from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
from agentstack_sdk.util.pydantic import REVEAL_SECRETS, SecureBaseModel

if TYPE_CHECKING:
from agentstack_sdk.server.context import RunContext


class EmbeddingFulfillment(pydantic.BaseModel):
class EmbeddingFulfillment(SecureBaseModel):
identifier: str | None = None
"""
Name of the model for identification and optimization purposes. Usually corresponds to LiteLLM identifiers.
Expand All @@ -31,7 +32,7 @@ class EmbeddingFulfillment(pydantic.BaseModel):
Base URL for an OpenAI-compatible API. It should provide at least /v1/chat/completions
"""

api_key: str
api_key: pydantic.SecretStr
"""
API key to attach as a `Authorization: Bearer $api_key` header.
"""
Expand Down Expand Up @@ -101,6 +102,6 @@ class EmbeddingServiceExtensionClient(BaseExtensionClient[EmbeddingServiceExtens
def fulfillment_metadata(self, *, embedding_fulfillments: dict[str, EmbeddingFulfillment]) -> dict[str, Any]:
return {
self.spec.URI: EmbeddingServiceExtensionMetadata(embedding_fulfillments=embedding_fulfillments).model_dump(
mode="json"
mode="json", context={REVEAL_SECRETS: True}
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
from typing_extensions import override

from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
from agentstack_sdk.util.pydantic import REVEAL_SECRETS, SecureBaseModel

if TYPE_CHECKING:
from agentstack_sdk.server.context import RunContext


class LLMFulfillment(pydantic.BaseModel):
class LLMFulfillment(SecureBaseModel):
identifier: str | None = None
"""
Name of the model for identification and optimization purposes. Usually corresponds to LiteLLM identifiers.
Expand All @@ -31,7 +32,7 @@ class LLMFulfillment(pydantic.BaseModel):
Base URL for an OpenAI-compatible API. It should provide at least /v1/chat/completions
"""

api_key: str
api_key: pydantic.SecretStr
"""
API key to attach as a `Authorization: Bearer $api_key` header.
"""
Expand Down Expand Up @@ -97,4 +98,8 @@ def handle_incoming_message(self, message: A2AMessage, run_context: RunContext,

class LLMServiceExtensionClient(BaseExtensionClient[LLMServiceExtensionSpec, NoneType]):
def fulfillment_metadata(self, *, llm_fulfillments: dict[str, LLMFulfillment]) -> dict[str, Any]:
return {self.spec.URI: LLMServiceExtensionMetadata(llm_fulfillments=llm_fulfillments).model_dump(mode="json")}
return {
self.spec.URI: LLMServiceExtensionMetadata(llm_fulfillments=llm_fulfillments).model_dump(
mode="json", context={REVEAL_SECRETS: True}
)
}
Loading
Loading