diff --git a/faststream/_internal/endpoint/subscriber/specification.py b/faststream/_internal/endpoint/subscriber/specification.py index 369649713c..985cfacf0c 100644 --- a/faststream/_internal/endpoint/subscriber/specification.py +++ b/faststream/_internal/endpoint/subscriber/specification.py @@ -7,7 +7,10 @@ from faststream._internal.configs import BrokerConfig, SubscriberSpecificationConfig from faststream.exceptions import SetupError -from faststream.specification.asyncapi.message import parse_handler_params +from faststream.specification.asyncapi.message import ( + parse_handler_params, + parse_handler_return, +) from faststream.specification.asyncapi.utils import to_camelcase if TYPE_CHECKING: @@ -78,6 +81,34 @@ def get_payloads(self) -> list[tuple["dict[str, Any]", str]]: return payloads + def get_reply_payloads(self) -> list[tuple["dict[str, Any]", str]]: + payloads: list[tuple[dict[str, Any], str]] = [] + + call_name = self.call_name + + for h in self.calls: + if h.dependant is None: + msg = "You should setup `Handler` at first." + raise SetupError(msg) + + reply_body = parse_handler_return( + h.dependant, + prefix=f"{self.config.title_ or call_name}:ReplyMessage", + ) + payloads.append((reply_body, to_camelcase(h.name))) + + if not self.calls: + payloads.append( + ( + { + "title": f"{self.config.title_ or call_name}:ReplyMessage:Payload", + }, + to_camelcase(call_name), + ), + ) + + return payloads + @property @abstractmethod def name(self) -> str: diff --git a/faststream/confluent/subscriber/specification.py b/faststream/confluent/subscriber/specification.py index 95680561a1..86386038e2 100644 --- a/faststream/confluent/subscriber/specification.py +++ b/faststream/confluent/subscriber/specification.py @@ -31,6 +31,7 @@ def name(self) -> str: def get_schema(self) -> dict[str, SubscriberSpec]: payloads = self.get_payloads() + reply_payloads = self.get_reply_payloads() channels = {} for t in self.topics: @@ -43,6 +44,10 @@ def get_schema(self) -> dict[str, SubscriberSpec]: title=f"{handler_name}:Message", payload=resolve_payloads(payloads), ), + reply_message=Message( + title=f"{handler_name}:ReplyMessage", + payload=resolve_payloads(reply_payloads), + ), bindings=None, ), bindings=ChannelBinding( diff --git a/faststream/kafka/subscriber/specification.py b/faststream/kafka/subscriber/specification.py index 5325069e9e..367f509821 100644 --- a/faststream/kafka/subscriber/specification.py +++ b/faststream/kafka/subscriber/specification.py @@ -31,6 +31,7 @@ def name(self) -> str: def get_schema(self) -> dict[str, SubscriberSpec]: payloads = self.get_payloads() + reply_payloads = self.get_reply_payloads() channels = {} for t in self.topics: @@ -43,6 +44,10 @@ def get_schema(self) -> dict[str, SubscriberSpec]: title=f"{handler_name}:Message", payload=resolve_payloads(payloads), ), + reply_message=Message( + title=f"{handler_name}:ReplyMessage", + payload=resolve_payloads(reply_payloads), + ), bindings=None, ), bindings=ChannelBinding( diff --git a/faststream/nats/subscriber/specification.py b/faststream/nats/subscriber/specification.py index 31f3f8b445..25152aa13b 100644 --- a/faststream/nats/subscriber/specification.py +++ b/faststream/nats/subscriber/specification.py @@ -23,6 +23,7 @@ def name(self) -> str: def get_schema(self) -> dict[str, SubscriberSpec]: payloads = self.get_payloads() + reply_payloads = self.get_reply_payloads() return { self.name: SubscriberSpec( @@ -32,6 +33,10 @@ def get_schema(self) -> dict[str, SubscriberSpec]: title=f"{self.name}:Message", payload=resolve_payloads(payloads), ), + reply_message=Message( + title=f"{self.name}:ReplyMessage", + payload=resolve_payloads(reply_payloads), + ), bindings=None, ), bindings=ChannelBinding( diff --git a/faststream/rabbit/subscriber/specification.py b/faststream/rabbit/subscriber/specification.py index 78408c8b5d..737634cb60 100644 --- a/faststream/rabbit/subscriber/specification.py +++ b/faststream/rabbit/subscriber/specification.py @@ -31,6 +31,7 @@ def name(self) -> str: def get_schema(self) -> dict[str, SubscriberSpec]: payloads = self.get_payloads() + reply_payloads = self.get_reply_payloads() queue = self.config.queue.add_prefix(self._outer_config.prefix) @@ -59,6 +60,10 @@ def get_schema(self) -> dict[str, SubscriberSpec]: title=f"{channel_name}:Message", payload=resolve_payloads(payloads), ), + reply_message=Message( + title=f"{channel_name}:ReplyMessage", + payload=resolve_payloads(reply_payloads), + ), ), bindings=ChannelBinding( amqp=amqp.ChannelBinding( diff --git a/faststream/redis/subscriber/specification.py b/faststream/redis/subscriber/specification.py index 1ea1721b1d..83c49f3923 100644 --- a/faststream/redis/subscriber/specification.py +++ b/faststream/redis/subscriber/specification.py @@ -20,6 +20,7 @@ class RedisSubscriberSpecification( ): def get_schema(self) -> dict[str, SubscriberSpec]: payloads = self.get_payloads() + reply_payloads = self.get_reply_payloads() return { self.name: SubscriberSpec( @@ -29,6 +30,10 @@ def get_schema(self) -> dict[str, SubscriberSpec]: title=f"{self.name}:Message", payload=resolve_payloads(payloads), ), + reply_message=Message( + title=f"{self.name}:ReplyMessage", + payload=resolve_payloads(reply_payloads), + ), bindings=None, ), bindings=ChannelBinding( diff --git a/faststream/specification/asyncapi/message.py b/faststream/specification/asyncapi/message.py index 16946a69da..d1a2e06f70 100644 --- a/faststream/specification/asyncapi/message.py +++ b/faststream/specification/asyncapi/message.py @@ -36,6 +36,29 @@ def parse_handler_params(call: "CallModel", prefix: str = "") -> dict[str, Any]: return body +def parse_handler_return(call: "CallModel", prefix: str = "") -> dict[str, Any]: + """Parses the handler parameters.""" + model_container = getattr(call, "serializer", call) + response_option = getattr(model_container, "response_option", None) + if not response_option: + return {"title": "EmptyPayload", "type": "null"} + out = response_option["return"] + + body = get_model_schema( + create_model( + "", + **{out.field_name: (out.field_type, out.default_value)}, # type: ignore[call-overload] + ), + prefix=prefix, + exclude=tuple(call.custom_fields.keys()), + ) + + if body is None: + return {"title": "EmptyPayload", "type": "null"} + + return body + + @overload def get_response_schema(call: None, prefix: str = "") -> None: ... diff --git a/faststream/specification/asyncapi/v3_0_0/generate.py b/faststream/specification/asyncapi/v3_0_0/generate.py index b08b86f823..515a9c94b8 100644 --- a/faststream/specification/asyncapi/v3_0_0/generate.py +++ b/faststream/specification/asyncapi/v3_0_0/generate.py @@ -17,6 +17,7 @@ License, Message, Operation, + OperationReply, Reference, Server, Tag, @@ -25,10 +26,13 @@ OperationBinding, http as http_bindings, ) +from faststream.specification.asyncapi.v3_0_0.schema.operation_reply import ( + OperationReplyAddress, +) from faststream.specification.asyncapi.v3_0_0.schema.operations import Action if TYPE_CHECKING: - from faststream._internal.basic_types import AnyHttpUrl + from faststream._internal.basic_types import AnyCallable, AnyHttpUrl from faststream._internal.broker import BrokerUsecase from faststream._internal.types import ConnectionType, MsgType from faststream.asgi.handlers import HttpHandler @@ -64,6 +68,7 @@ def get_app_schema( channels, operations = get_broker_channels(broker) messages: dict[str, Message] = {} + reply_messages: dict[str, Message] = {} payloads: dict[str, dict[str, Any]] = {} for channel in channels.values(): @@ -89,6 +94,23 @@ def get_app_schema( channel.messages = msgs + for operation_name, operation in operations.items(): + reply_msgs: dict[str, Message | Reference] = {} + if not operation.reply: + continue + for message in operation.reply.messages: + assert isinstance(message, Message) + + reply_msgs["ReplyMessage"] = _resolve_reply_payloads( + f"{operation_name.removesuffix('Subscribe')}:ReplyMessage", + message, + payloads, + reply_messages, + ) + operation.reply.messages = list(reply_msgs.values()) + + messages.update(reply_messages) + return ApplicationSchema( info=ApplicationInfo( title=title, @@ -166,6 +188,7 @@ def get_broker_channels( """Get the broker channels for an application.""" channels = {} operations = {} + operations_by_handler: dict[AnyCallable, Operation] = {} for sub in filter(lambda s: s.specification.include_in_schema, broker.subscribers): for sub_key, sub_channel in sub.schema().items(): @@ -181,7 +204,7 @@ def get_broker_channels( channels[channel_key] = channel_obj - operations[f"{channel_key}Subscribe"] = Operation.from_sub( + operation = Operation.from_sub( messages=[ Reference(**{ "$ref": f"#/channels/{channel_key}/messages/{msg_name}", @@ -190,7 +213,22 @@ def get_broker_channels( ], channel=Reference(**{"$ref": f"#/channels/{channel_key}"}), operation=sub_channel.operation, + reply=OperationReply( + messages=[Message.from_spec(sub_channel.operation.reply_message)] + if sub_channel.operation.reply_message + else [], + address=OperationReplyAddress( + description=None, + location="$message.header#/replyTo", + ), + channel=None, + ) + if not sub._no_reply + else None, ) + operations[f"{channel_key}Subscribe"] = operation + for call in sub.specification.calls: + operations_by_handler[call.handler._original_call] = operation for pub in filter(lambda p: p.specification.include_in_schema, broker.publishers): for pub_key, pub_channel in pub.schema().items(): @@ -215,6 +253,13 @@ def get_broker_channels( channel=Reference(**{"$ref": f"#/channels/{channel_key}"}), operation=pub_channel.operation, ) + for call in pub.specification.calls: + sub_operation = operations_by_handler.get(call) + if sub_operation is None or sub_operation.reply is None: + continue + sub_operation.reply.channel = Reference(**{ + "$ref": f"#/channels/{channel_key}" + }) return channels, operations @@ -257,20 +302,18 @@ def _get_http_binding_method(methods: Sequence[str]) -> str: return next((method for method in methods if method != "HEAD"), "HEAD") -def _resolve_msg_payloads( - message_name: str, - m: Message, - channel_name: str, +def _resolve_payloads_common( + *, + m: "Message", payloads: dict[str, Any], - messages: dict[str, Any], -) -> Reference: + messages_target: dict[str, Any], + message_ref: str, + default_payload_title: str, +) -> "Reference": assert isinstance(m.payload, dict) m.payload = move_pydantic_refs(m.payload, DEF_KEY) - message_name = clear_key(message_name) - channel_name = clear_key(channel_name) - if DEF_KEY in m.payload: payloads.update(m.payload.pop(DEF_KEY)) @@ -285,19 +328,24 @@ def _resolve_msg_payloads( defs = payload.pop(DEF_KEY) or {} for def_name, def_schema in defs.items(): payloads[clear_key(def_name)] = def_schema + processed_payloads[clear_key(name)] = payload - one_of_list.append(Reference(**{"$ref": f"#/components/schemas/{name}"})) + one_of_list.append( + Reference(**{"$ref": f"#/components/schemas/{name}"}) + ) payloads.update(processed_payloads) m.payload["oneOf"] = one_of_list + assert m.title - messages[clear_key(m.title)] = m - return Reference( - **{"$ref": f"#/components/messages/{channel_name}:{message_name}"}, - ) + messages_target[clear_key(m.title)] = m + + return Reference(**{"$ref": message_ref}) + payloads.update(m.payload.pop(DEF_KEY, {})) - payload_name = m.payload.get("title", f"{channel_name}:{message_name}:Payload") + + payload_name = m.payload.get("title", default_payload_title) payload_name = clear_key(payload_name) if payload_name in payloads and payloads[payload_name] != m.payload: @@ -309,8 +357,44 @@ def _resolve_msg_payloads( payloads[payload_name] = m.payload m.payload = {"$ref": f"#/components/schemas/{payload_name}"} + assert m.title - messages[clear_key(m.title)] = m - return Reference( - **{"$ref": f"#/components/messages/{channel_name}:{message_name}"}, + messages_target[clear_key(m.title)] = m + + return Reference(**{"$ref": message_ref}) + + +def _resolve_reply_payloads( + message_name: str, + m: "Message", + payloads: dict[str, Any], + reply_messages: dict[str, Any], +) -> "Reference": + message_name = clear_key(message_name) + + return _resolve_payloads_common( + m=m, + payloads=payloads, + messages_target=reply_messages, + message_ref=f"#/components/messages/{message_name}", + default_payload_title=f"{message_name}:Payload", + ) + + +def _resolve_msg_payloads( + message_name: str, + m: "Message", + channel_name: str, + payloads: dict[str, Any], + messages: dict[str, Any], +) -> "Reference": + message_name = clear_key(message_name) + channel_name = clear_key(channel_name) + + return _resolve_payloads_common( + m=m, + payloads=payloads, + messages_target=messages, + message_ref=f"#/components/messages/{channel_name}:{message_name}", + default_payload_title=f"{channel_name}:{message_name}:Payload", ) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/__init__.py b/faststream/specification/asyncapi/v3_0_0/schema/__init__.py index e0cbcbd7b2..317f57b750 100644 --- a/faststream/specification/asyncapi/v3_0_0/schema/__init__.py +++ b/faststream/specification/asyncapi/v3_0_0/schema/__init__.py @@ -5,6 +5,7 @@ from .info import ApplicationInfo from .license import License from .message import CorrelationId, Message +from .operation_reply import OperationReply from .operations import Operation from .schema import ApplicationSchema from .servers import Server, ServerVariable @@ -23,6 +24,7 @@ "License", "Message", "Operation", + "OperationReply", "Parameter", "Reference", "Server", diff --git a/faststream/specification/asyncapi/v3_0_0/schema/operation_reply.py b/faststream/specification/asyncapi/v3_0_0/schema/operation_reply.py new file mode 100644 index 0000000000..4cc9fb5985 --- /dev/null +++ b/faststream/specification/asyncapi/v3_0_0/schema/operation_reply.py @@ -0,0 +1,24 @@ +from pydantic import BaseModel + +from faststream._internal._compat import PYDANTIC_V2 +from faststream.specification.asyncapi.v3_0_0.schema.message import Message + +from .utils import Reference + + +class OperationReplyAddress(BaseModel): + description: str | None = None + location: str + + +class OperationReply(BaseModel): + messages: list[Message | Reference] + channel: Reference | None = None + address: OperationReplyAddress | None = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + else: + + class Config: + extra = "allow" diff --git a/faststream/specification/asyncapi/v3_0_0/schema/operations.py b/faststream/specification/asyncapi/v3_0_0/schema/operations.py index 05ca29fbf0..7bdfe9b2f4 100644 --- a/faststream/specification/asyncapi/v3_0_0/schema/operations.py +++ b/faststream/specification/asyncapi/v3_0_0/schema/operations.py @@ -9,6 +9,7 @@ from .bindings import OperationBinding from .channels import Channel +from .operation_reply import OperationReply from .tag import Tag from .utils import Reference @@ -43,6 +44,8 @@ class Operation(BaseModel): security: dict[str, list[str]] | None = None + reply: OperationReply | None = None + # TODO # traits @@ -62,12 +65,14 @@ def from_sub( messages: list[Reference], channel: Reference, operation: OperationSpec, + reply: OperationReply | None = None, ) -> Self: return cls( action=Action.RECEIVE, messages=messages, channel=channel, bindings=OperationBinding.from_sub(operation.bindings), + reply=reply, summary=None, description=None, security=None, @@ -86,6 +91,7 @@ def from_pub( messages=messages, channel=channel, bindings=OperationBinding.from_pub(operation.bindings), + reply=None, summary=None, description=None, security=None, diff --git a/faststream/specification/schema/operation/model.py b/faststream/specification/schema/operation/model.py index 58f426dc17..b2780c3cdf 100644 --- a/faststream/specification/schema/operation/model.py +++ b/faststream/specification/schema/operation/model.py @@ -8,3 +8,4 @@ class Operation: message: Message bindings: OperationBinding | None + reply_message: Message | None = None diff --git a/faststream/specification/schema/reply/__init__.py b/faststream/specification/schema/reply/__init__.py new file mode 100644 index 0000000000..09ef2e907f --- /dev/null +++ b/faststream/specification/schema/reply/__init__.py @@ -0,0 +1,3 @@ +from .model import OperationReply, OperationReplyAddress + +__all__ = ("OperationReply", "OperationReplyAddress") diff --git a/faststream/specification/schema/reply/model.py b/faststream/specification/schema/reply/model.py new file mode 100644 index 0000000000..3ec3441cfa --- /dev/null +++ b/faststream/specification/schema/reply/model.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass + +from faststream.specification.schema import Message + + +@dataclass +class OperationReplyAddress: + location: str + description: str | None = None + + +@dataclass +class OperationReply: + message: Message | None + address: OperationReplyAddress | None