Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion faststream/_internal/endpoint/subscriber/specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

from faststream._internal.configs import BrokerConfig, SubscriberSpecificationConfig
from faststream.exceptions import SetupError
from faststream.specification.asyncapi.message import parse_handler_params
from faststream.specification.asyncapi.message import (
parse_handler_params,
parse_handler_return,
)
from faststream.specification.asyncapi.utils import to_camelcase

if TYPE_CHECKING:
Expand Down Expand Up @@ -78,6 +81,34 @@ def get_payloads(self) -> list[tuple["dict[str, Any]", str]]:

return payloads

def get_reply_payloads(self) -> list[tuple["dict[str, Any]", str]]:
payloads: list[tuple[dict[str, Any], str]] = []

call_name = self.call_name

for h in self.calls:
if h.dependant is None:
msg = "You should setup `Handler` at first."
raise SetupError(msg)

reply_body = parse_handler_return(
h.dependant,
prefix=f"{self.config.title_ or call_name}:ReplyMessage",
)
payloads.append((reply_body, to_camelcase(h.name)))

if not self.calls:
payloads.append(
(
{
"title": f"{self.config.title_ or call_name}:ReplyMessage:Payload",
},
to_camelcase(call_name),
),
)

return payloads

@property
@abstractmethod
def name(self) -> str:
Expand Down
5 changes: 5 additions & 0 deletions faststream/confluent/subscriber/specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def name(self) -> str:

def get_schema(self) -> dict[str, SubscriberSpec]:
payloads = self.get_payloads()
reply_payloads = self.get_reply_payloads()

channels = {}
for t in self.topics:
Expand All @@ -43,6 +44,10 @@ def get_schema(self) -> dict[str, SubscriberSpec]:
title=f"{handler_name}:Message",
payload=resolve_payloads(payloads),
),
reply_message=Message(
title=f"{handler_name}:ReplyMessage",
payload=resolve_payloads(reply_payloads),
),
bindings=None,
),
bindings=ChannelBinding(
Expand Down
5 changes: 5 additions & 0 deletions faststream/kafka/subscriber/specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def name(self) -> str:

def get_schema(self) -> dict[str, SubscriberSpec]:
payloads = self.get_payloads()
reply_payloads = self.get_reply_payloads()

channels = {}
for t in self.topics:
Expand All @@ -43,6 +44,10 @@ def get_schema(self) -> dict[str, SubscriberSpec]:
title=f"{handler_name}:Message",
payload=resolve_payloads(payloads),
),
reply_message=Message(
title=f"{handler_name}:ReplyMessage",
payload=resolve_payloads(reply_payloads),
),
bindings=None,
),
bindings=ChannelBinding(
Expand Down
5 changes: 5 additions & 0 deletions faststream/nats/subscriber/specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def name(self) -> str:

def get_schema(self) -> dict[str, SubscriberSpec]:
payloads = self.get_payloads()
reply_payloads = self.get_reply_payloads()

return {
self.name: SubscriberSpec(
Expand All @@ -32,6 +33,10 @@ def get_schema(self) -> dict[str, SubscriberSpec]:
title=f"{self.name}:Message",
payload=resolve_payloads(payloads),
),
reply_message=Message(
title=f"{self.name}:ReplyMessage",
payload=resolve_payloads(reply_payloads),
),
bindings=None,
),
bindings=ChannelBinding(
Expand Down
5 changes: 5 additions & 0 deletions faststream/rabbit/subscriber/specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def name(self) -> str:

def get_schema(self) -> dict[str, SubscriberSpec]:
payloads = self.get_payloads()
reply_payloads = self.get_reply_payloads()

queue = self.config.queue.add_prefix(self._outer_config.prefix)

Expand Down Expand Up @@ -59,6 +60,10 @@ def get_schema(self) -> dict[str, SubscriberSpec]:
title=f"{channel_name}:Message",
payload=resolve_payloads(payloads),
),
reply_message=Message(
title=f"{channel_name}:ReplyMessage",
payload=resolve_payloads(reply_payloads),
),
),
bindings=ChannelBinding(
amqp=amqp.ChannelBinding(
Expand Down
5 changes: 5 additions & 0 deletions faststream/redis/subscriber/specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class RedisSubscriberSpecification(
):
def get_schema(self) -> dict[str, SubscriberSpec]:
payloads = self.get_payloads()
reply_payloads = self.get_reply_payloads()

return {
self.name: SubscriberSpec(
Expand All @@ -29,6 +30,10 @@ def get_schema(self) -> dict[str, SubscriberSpec]:
title=f"{self.name}:Message",
payload=resolve_payloads(payloads),
),
reply_message=Message(
title=f"{self.name}:ReplyMessage",
payload=resolve_payloads(reply_payloads),
),
bindings=None,
),
bindings=ChannelBinding(
Expand Down
23 changes: 23 additions & 0 deletions faststream/specification/asyncapi/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,29 @@ def parse_handler_params(call: "CallModel", prefix: str = "") -> dict[str, Any]:
return body


def parse_handler_return(call: "CallModel", prefix: str = "") -> dict[str, Any]:
"""Parses the handler parameters."""
model_container = getattr(call, "serializer", call)
response_option = getattr(model_container, "response_option", None)
if not response_option:
return {"title": "EmptyPayload", "type": "null"}
out = response_option["return"]

body = get_model_schema(
create_model(
"",
**{out.field_name: (out.field_type, out.default_value)}, # type: ignore[call-overload]
),
prefix=prefix,
exclude=tuple(call.custom_fields.keys()),
)

if body is None:
return {"title": "EmptyPayload", "type": "null"}

return body


@overload
def get_response_schema(call: None, prefix: str = "") -> None: ...

Expand Down
124 changes: 104 additions & 20 deletions faststream/specification/asyncapi/v3_0_0/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
License,
Message,
Operation,
OperationReply,
Reference,
Server,
Tag,
Expand All @@ -25,10 +26,13 @@
OperationBinding,
http as http_bindings,
)
from faststream.specification.asyncapi.v3_0_0.schema.operation_reply import (
OperationReplyAddress,
)
from faststream.specification.asyncapi.v3_0_0.schema.operations import Action

if TYPE_CHECKING:
from faststream._internal.basic_types import AnyHttpUrl
from faststream._internal.basic_types import AnyCallable, AnyHttpUrl
from faststream._internal.broker import BrokerUsecase
from faststream._internal.types import ConnectionType, MsgType
from faststream.asgi.handlers import HttpHandler
Expand Down Expand Up @@ -64,6 +68,7 @@ def get_app_schema(
channels, operations = get_broker_channels(broker)

messages: dict[str, Message] = {}
reply_messages: dict[str, Message] = {}
payloads: dict[str, dict[str, Any]] = {}

for channel in channels.values():
Expand All @@ -89,6 +94,23 @@ def get_app_schema(

channel.messages = msgs

for operation_name, operation in operations.items():
reply_msgs: dict[str, Message | Reference] = {}
if not operation.reply:
continue
for message in operation.reply.messages:
assert isinstance(message, Message)

reply_msgs["ReplyMessage"] = _resolve_reply_payloads(
f"{operation_name.removesuffix('Subscribe')}:ReplyMessage",
message,
payloads,
reply_messages,
)
operation.reply.messages = list(reply_msgs.values())

messages.update(reply_messages)

return ApplicationSchema(
info=ApplicationInfo(
title=title,
Expand Down Expand Up @@ -166,6 +188,7 @@ def get_broker_channels(
"""Get the broker channels for an application."""
channels = {}
operations = {}
operations_by_handler: dict[AnyCallable, Operation] = {}

for sub in filter(lambda s: s.specification.include_in_schema, broker.subscribers):
for sub_key, sub_channel in sub.schema().items():
Expand All @@ -181,7 +204,7 @@ def get_broker_channels(

channels[channel_key] = channel_obj

operations[f"{channel_key}Subscribe"] = Operation.from_sub(
operation = Operation.from_sub(
messages=[
Reference(**{
"$ref": f"#/channels/{channel_key}/messages/{msg_name}",
Expand All @@ -190,7 +213,22 @@ def get_broker_channels(
],
channel=Reference(**{"$ref": f"#/channels/{channel_key}"}),
operation=sub_channel.operation,
reply=OperationReply(
messages=[Message.from_spec(sub_channel.operation.reply_message)]
if sub_channel.operation.reply_message
else [],
address=OperationReplyAddress(
description=None,
location="$message.header#/replyTo",
),
channel=None,
)
if not sub._no_reply
else None,
)
operations[f"{channel_key}Subscribe"] = operation
for call in sub.specification.calls:
operations_by_handler[call.handler._original_call] = operation

for pub in filter(lambda p: p.specification.include_in_schema, broker.publishers):
for pub_key, pub_channel in pub.schema().items():
Expand All @@ -215,6 +253,13 @@ def get_broker_channels(
channel=Reference(**{"$ref": f"#/channels/{channel_key}"}),
operation=pub_channel.operation,
)
for call in pub.specification.calls:
sub_operation = operations_by_handler.get(call)
if sub_operation is None or sub_operation.reply is None:
continue
sub_operation.reply.channel = Reference(**{
"$ref": f"#/channels/{channel_key}"
})

return channels, operations

Expand Down Expand Up @@ -257,20 +302,18 @@ def _get_http_binding_method(methods: Sequence[str]) -> str:
return next((method for method in methods if method != "HEAD"), "HEAD")


def _resolve_msg_payloads(
message_name: str,
m: Message,
channel_name: str,
def _resolve_payloads_common(
*,
m: "Message",
payloads: dict[str, Any],
messages: dict[str, Any],
) -> Reference:
messages_target: dict[str, Any],
message_ref: str,
default_payload_title: str,
) -> "Reference":
assert isinstance(m.payload, dict)

m.payload = move_pydantic_refs(m.payload, DEF_KEY)

message_name = clear_key(message_name)
channel_name = clear_key(channel_name)

if DEF_KEY in m.payload:
payloads.update(m.payload.pop(DEF_KEY))

Expand All @@ -285,19 +328,24 @@ def _resolve_msg_payloads(
defs = payload.pop(DEF_KEY) or {}
for def_name, def_schema in defs.items():
payloads[clear_key(def_name)] = def_schema

processed_payloads[clear_key(name)] = payload
one_of_list.append(Reference(**{"$ref": f"#/components/schemas/{name}"}))
one_of_list.append(
Reference(**{"$ref": f"#/components/schemas/{name}"})
)

payloads.update(processed_payloads)
m.payload["oneOf"] = one_of_list

assert m.title
messages[clear_key(m.title)] = m
return Reference(
**{"$ref": f"#/components/messages/{channel_name}:{message_name}"},
)
messages_target[clear_key(m.title)] = m

return Reference(**{"$ref": message_ref})


payloads.update(m.payload.pop(DEF_KEY, {}))
payload_name = m.payload.get("title", f"{channel_name}:{message_name}:Payload")

payload_name = m.payload.get("title", default_payload_title)
payload_name = clear_key(payload_name)

if payload_name in payloads and payloads[payload_name] != m.payload:
Expand All @@ -309,8 +357,44 @@ def _resolve_msg_payloads(

payloads[payload_name] = m.payload
m.payload = {"$ref": f"#/components/schemas/{payload_name}"}

assert m.title
messages[clear_key(m.title)] = m
return Reference(
**{"$ref": f"#/components/messages/{channel_name}:{message_name}"},
messages_target[clear_key(m.title)] = m

return Reference(**{"$ref": message_ref})


def _resolve_reply_payloads(
message_name: str,
m: "Message",
payloads: dict[str, Any],
reply_messages: dict[str, Any],
) -> "Reference":
message_name = clear_key(message_name)

return _resolve_payloads_common(
m=m,
payloads=payloads,
messages_target=reply_messages,
message_ref=f"#/components/messages/{message_name}",
default_payload_title=f"{message_name}:Payload",
)


def _resolve_msg_payloads(
message_name: str,
m: "Message",
channel_name: str,
payloads: dict[str, Any],
messages: dict[str, Any],
) -> "Reference":
message_name = clear_key(message_name)
channel_name = clear_key(channel_name)

return _resolve_payloads_common(
m=m,
payloads=payloads,
messages_target=messages,
message_ref=f"#/components/messages/{channel_name}:{message_name}",
default_payload_title=f"{channel_name}:{message_name}:Payload",
)
Loading