Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions homeassistant/components/mqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from . import debug_info, discovery
from .client import (
MQTT,
async_await_subscription,
async_publish,
async_subscribe,
async_subscribe_internal,
Expand Down Expand Up @@ -159,6 +160,7 @@
"PublishPayloadType",
"ReceiveMessage",
"SetupPhases",
"async_await_subscription",
"async_check_config_schema",
"async_create_certificate_temp_files",
"async_forward_entry_setup_and_setup_discovery",
Expand Down
110 changes: 107 additions & 3 deletions homeassistant/components/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@
get_hassjob_callable_job_type,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.importlib import async_import_module
from homeassistant.helpers.start import async_at_started
from homeassistant.helpers.typing import ConfigType
Expand Down Expand Up @@ -71,6 +74,7 @@
DEFAULT_WS_PATH,
DOMAIN,
MQTT_CONNECTION_STATE,
MQTT_PROCESSED_SUBSCRIPTIONS,
PROTOCOL_5,
PROTOCOL_31,
TRANSPORT_WEBSOCKETS,
Expand Down Expand Up @@ -109,6 +113,7 @@
SUBSCRIBE_COOLDOWN = 0.1
UNSUBSCRIBE_COOLDOWN = 0.1
TIMEOUT_ACK = 10
SUBSCRIBE_TIMEOUT = 10
RECONNECT_INTERVAL_SECONDS = 10

MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1
Expand Down Expand Up @@ -184,18 +189,112 @@ async def async_publish(
)


async def async_await_subscription(
hass: HomeAssistant,
topic: str,
qos: int = DEFAULT_QOS,
) -> None:
"""Wait for an MQTT subscription to be completed."""
subscription_complete: asyncio.Future[None]

async def _sync_mqtt_subscribe(subscriptions: list[tuple[str, int]]) -> None:
if (topic, qos) not in subscriptions:
return
subscription_complete.set_result(None)

def _async_timeout_subscribe() -> None:
if not subscription_complete.done():
subscription_complete.set_exception(TimeoutError)

try:
mqtt_data = hass.data[DATA_MQTT]
except KeyError as exc:
raise HomeAssistantError(
f"Cannot wait for subscription to topic '{topic}' QoS {qos}, "
"make sure MQTT is set up correctly",
translation_key="mqtt_not_setup_cannot_wait_for_subscribe",
translation_domain=DOMAIN,
translation_placeholders={"topic": topic, "qos": str(qos)},
) from exc
client = mqtt_data.client

if not client.is_active_subscription(topic):
raise HomeAssistantError(
f"Cannot find subscription to topic '{topic}' and QoS {qos}, "
"make sure the subscription is successful",
translation_key="mqtt_not_setup_cannot_find_subscription",
translation_domain=DOMAIN,
translation_placeholders={"topic": topic, "qos": str(qos)},
)
if not client.is_pending_subscription(topic):
# Existing non pending subscription are assumed to be completed already
return

subscription_complete = hass.loop.create_future()
dispatcher = async_dispatcher_connect(
hass, MQTT_PROCESSED_SUBSCRIPTIONS, _sync_mqtt_subscribe
)
try:
hass.loop.call_later(SUBSCRIBE_TIMEOUT, _async_timeout_subscribe)
await subscription_complete
except TimeoutError as exc:
raise HomeAssistantError(
translation_domain=DOMAIN,
translation_key="subscribe_timeout",
) from exc
finally:
dispatcher()
return


@bind_hass
async def async_subscribe(
hass: HomeAssistant,
topic: str,
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING,
wait: bool = False,
) -> CALLBACK_TYPE:
"""Subscribe to an MQTT topic.

Call the return value to unsubscribe.
"""
subscription_complete: asyncio.Future[None]

async def _sync_mqtt_subscribe(subscriptions: list[tuple[str, int]]) -> None:
if (topic, qos) not in subscriptions:
return
subscription_complete.set_result(None)

def _async_timeout_subscribe() -> None:
if not subscription_complete.done():
subscription_complete.set_exception(TimeoutError)

if (
wait
and DATA_MQTT in hass.data
and not hass.data[DATA_MQTT].client.is_active_subscription(topic)
):
subscription_complete = hass.loop.create_future()
dispatcher = async_dispatcher_connect(
hass, MQTT_PROCESSED_SUBSCRIPTIONS, _sync_mqtt_subscribe
)
subscribe_callback = async_subscribe_internal(
hass, topic, msg_callback, qos, encoding
)
try:
hass.loop.call_later(SUBSCRIBE_TIMEOUT, _async_timeout_subscribe)
await subscription_complete
except TimeoutError as exc:
raise HomeAssistantError(
translation_domain=DOMAIN,
translation_key="subscribe_timeout",
) from exc
finally:
dispatcher()
return subscribe_callback

return async_subscribe_internal(hass, topic, msg_callback, qos, encoding)


Expand Down Expand Up @@ -640,12 +739,16 @@ def _async_on_socket_unregister_write(
if fileno > -1:
self.loop.remove_writer(sock)

def _is_active_subscription(self, topic: str) -> bool:
def is_active_subscription(self, topic: str) -> bool:
"""Check if a topic has an active subscription."""
return topic in self._simple_subscriptions or any(
other.topic == topic for other in self._wildcard_subscriptions
)

def is_pending_subscription(self, topic: str) -> bool:
"""Check if a topic has a pending subscription."""
return topic in self._pending_subscriptions

async def async_publish(
self, topic: str, payload: PublishPayloadType, qos: int, retain: bool
) -> None:
Expand Down Expand Up @@ -899,7 +1002,7 @@ def _async_remove(self, subscription: Subscription) -> None:
@callback
def _async_unsubscribe(self, topic: str) -> None:
"""Unsubscribe from a topic."""
if self._is_active_subscription(topic):
if self.is_active_subscription(topic):
if self._max_qos[topic] == 0:
return
subs = self._matching_subscriptions(topic)
Expand Down Expand Up @@ -963,6 +1066,7 @@ async def _async_perform_subscriptions(self) -> None:
self._last_subscribe = time.monotonic()

await self._async_wait_for_mid_or_raise(mid, result)
async_dispatcher_send(self.hass, MQTT_PROCESSED_SUBSCRIPTIONS, chunk_list)

async def _async_perform_unsubscribes(self) -> None:
"""Perform pending MQTT client unsubscribes."""
Expand Down
1 change: 1 addition & 0 deletions homeassistant/components/mqtt/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@
LOGGER = logging.getLogger(__package__)

MQTT_CONNECTION_STATE = "mqtt_connection_state"
MQTT_PROCESSED_SUBSCRIPTIONS = "mqtt_processed_subscriptions"

PAYLOAD_EMPTY_JSON = "{}"
PAYLOAD_NONE = "None"
Expand Down
6 changes: 6 additions & 0 deletions homeassistant/components/mqtt/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -1456,6 +1456,12 @@
"mqtt_not_setup_cannot_subscribe": {
"message": "Cannot subscribe to topic \"{topic}\", make sure MQTT is set up correctly."
},
"mqtt_not_setup_cannot_wait_for_subscribe": {
"message": "Cannot wait for subscription to topic \"{topic}\" and QoS {qos}, make sure MQTT is set up correctly."
},
"mqtt_not_setup_cannot_find_subscription": {
"message": "Cannot find subscription to topic \"{topic}\" and QoS {qos}, make sure the subscription is successful."
},
"mqtt_not_setup_cannot_publish": {
"message": "Cannot publish to topic \"{topic}\", make sure MQTT is set up correctly."
},
Expand Down
146 changes: 146 additions & 0 deletions tests/components/mqtt/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,143 @@ async def test_subscribe_topic(
unsub()


async def test_await_subscription(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry: MqttMockHAClientGenerator,
recorded_calls: list[ReceiveMessage],
record_calls: MessageCallbackType,
) -> None:
"""Test the subscription of a topic."""
mock_mqtt = await mqtt_mock_entry()
# Fail when no subscription is queued
with pytest.raises(HomeAssistantError) as exc:
await mqtt.async_await_subscription(hass, "test-topic")
assert exc.value.args[0] == (
"Cannot find subscription to topic 'test-topic' "
"and QoS 0, make sure the subscription is successful"
)
assert exc.value.translation_key == "mqtt_not_setup_cannot_find_subscription"
assert exc.value.translation_placeholders == {"topic": "test-topic", "qos": "0"}

# Test awaiting pending subscription
with (
patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.5),
patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.5),
):
unsub_no_wait = await mqtt.async_subscribe(hass, "test-topic", record_calls)
# assert ("test-topic", 0) not in help_all_subscribe_calls(mqtt_client_mock)
await mqtt.async_await_subscription(hass, "test-topic")
assert ("test-topic", 0) in help_all_subscribe_calls(mqtt_client_mock)

async_fire_mqtt_message(hass, "test-topic", "test-payload")

await hass.async_block_till_done()
assert len(recorded_calls) == 1
assert recorded_calls[0].topic == "test-topic"
assert recorded_calls[0].payload == "test-payload"
recorded_calls.clear()

unsub_no_wait()
await hass.async_block_till_done()

# Test existing subscription, should return immediately
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls, wait=True)
await mqtt.async_await_subscription(hass, "test-topic")

async_fire_mqtt_message(hass, "test-topic", "test-payload")

await hass.async_block_till_done()
assert len(recorded_calls) == 1
assert recorded_calls[0].topic == "test-topic"
assert recorded_calls[0].payload == "test-payload"

assert mock_mqtt.is_active_subscription("test-topic")

unsub()
assert not mock_mqtt.is_active_subscription("test-topic")

recorded_calls.clear()


async def test_subscribe_topic_and_wait(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
recorded_calls: list[ReceiveMessage],
record_calls: MessageCallbackType,
) -> None:
"""Test the subscription of a topic."""
await mqtt_mock_entry()
unsub_no_wait = await mqtt.async_subscribe(hass, "other-test-topic/#", record_calls)
unsub_wait = await mqtt.async_subscribe(hass, "test-topic", record_calls, wait=True)

async_fire_mqtt_message(hass, "test-topic", "test-payload")
async_fire_mqtt_message(hass, "other-test-topic/test", "other-test-payload")

await hass.async_block_till_done()
assert len(recorded_calls) == 2
assert recorded_calls[0].topic == "test-topic"
assert recorded_calls[0].payload == "test-payload"
assert recorded_calls[1].topic == "other-test-topic/test"
assert recorded_calls[1].payload == "other-test-payload"

unsub_no_wait()
unsub_wait()

async_fire_mqtt_message(hass, "test-topic", "test-payload")

await hass.async_block_till_done()
assert len(recorded_calls) == 2

# Cannot unsubscribe twice
with pytest.raises(HomeAssistantError):
unsub_no_wait()

with pytest.raises(HomeAssistantError):
unsub_wait()


async def test_subscribe_topic_and_wait_timeout(
hass: HomeAssistant,
mqtt_mock_entry: MqttMockHAClientGenerator,
recorded_calls: list[ReceiveMessage],
record_calls: MessageCallbackType,
) -> None:
"""Test the subscription of a topic."""
await mqtt_mock_entry()
with (
patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.5),
patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.5),
patch("homeassistant.components.mqtt.client.SUBSCRIBE_TIMEOUT", 0),
pytest.raises(HomeAssistantError) as exc,
):
await mqtt.async_subscribe(hass, "test-topic", record_calls, wait=True)

assert exc.value.translation_domain == mqtt.DOMAIN
assert exc.value.translation_key == "subscribe_timeout"


async def test_await_subscription_and_wait_timeout(
hass: HomeAssistant,
mock_debouncer: asyncio.Event,
mqtt_mock_entry: MqttMockHAClientGenerator,
record_calls: MessageCallbackType,
) -> None:
"""Test the subscription of a topic."""
await mqtt_mock_entry()
await mqtt.async_subscribe(hass, "test-topic", record_calls)
with (
patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.5),
patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.5),
patch("homeassistant.components.mqtt.client.SUBSCRIBE_TIMEOUT", 0.0),
pytest.raises(HomeAssistantError) as exc,
):
await mqtt.async_await_subscription(hass, "test-topic", 0)

assert exc.value.translation_domain == mqtt.DOMAIN
assert exc.value.translation_key == "subscribe_timeout"


@pytest.mark.usefixtures("mqtt_mock_entry")
async def test_subscribe_topic_not_initialize(
hass: HomeAssistant, record_calls: MessageCallbackType
Expand All @@ -293,6 +430,15 @@ async def test_subscribe_topic_not_initialize(
await mqtt.async_subscribe(hass, "test-topic", record_calls)


@pytest.mark.usefixtures("mqtt_mock_entry")
async def test_await_subscription_not_initialize(hass: HomeAssistant) -> None:
"""Test the subscription of a topic when MQTT was not initialized."""
with pytest.raises(
HomeAssistantError, match=r".*make sure MQTT is set up correctly"
):
await mqtt.async_await_subscription(hass, "test-topic")


async def test_subscribe_mqtt_config_entry_disabled(
hass: HomeAssistant, mqtt_mock: MqttMockHAClient, record_calls: MessageCallbackType
) -> None:
Expand Down
Loading