Skip to content

Commit 55ec8d2

Browse files
committed
Allow to wait for MQTT subscription
1 parent 5613be3 commit 55ec8d2

File tree

3 files changed

+103
-1
lines changed

3 files changed

+103
-1
lines changed

homeassistant/components/mqtt/client.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@
3838
get_hassjob_callable_job_type,
3939
)
4040
from homeassistant.exceptions import HomeAssistantError
41-
from homeassistant.helpers.dispatcher import async_dispatcher_send
41+
from homeassistant.helpers.dispatcher import (
42+
async_dispatcher_connect,
43+
async_dispatcher_send,
44+
)
4245
from homeassistant.helpers.importlib import async_import_module
4346
from homeassistant.helpers.start import async_at_started
4447
from homeassistant.helpers.typing import ConfigType
@@ -71,6 +74,7 @@
7174
DEFAULT_WS_PATH,
7275
DOMAIN,
7376
MQTT_CONNECTION_STATE,
77+
MQTT_PROCESSED_SUBSCRIPTIONS,
7478
PROTOCOL_5,
7579
PROTOCOL_31,
7680
TRANSPORT_WEBSOCKETS,
@@ -109,6 +113,7 @@
109113
SUBSCRIBE_COOLDOWN = 0.1
110114
UNSUBSCRIBE_COOLDOWN = 0.1
111115
TIMEOUT_ACK = 10
116+
SUBSCRIBE_TIMEOUT = 10
112117
RECONNECT_INTERVAL_SECONDS = 10
113118

114119
MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1
@@ -191,11 +196,47 @@ async def async_subscribe(
191196
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
192197
qos: int = DEFAULT_QOS,
193198
encoding: str | None = DEFAULT_ENCODING,
199+
wait: bool = False,
194200
) -> CALLBACK_TYPE:
195201
"""Subscribe to an MQTT topic.
196202
197203
Call the return value to unsubscribe.
198204
"""
205+
subscription_complete: asyncio.Future[None]
206+
207+
async def _sync_mqtt_subscribe(subscriptions: list[tuple[str, int]]) -> None:
208+
if (topic, qos) not in subscriptions:
209+
return
210+
subscription_complete.set_result(None)
211+
212+
def _async_timeout_subscribe() -> None:
213+
if not subscription_complete.done():
214+
subscription_complete.set_exception(TimeoutError)
215+
216+
if (
217+
wait
218+
and DATA_MQTT in hass.data
219+
and not hass.data[DATA_MQTT].client._matching_subscriptions(topic) # noqa: SLF001
220+
):
221+
subscription_complete = hass.loop.create_future()
222+
dispatcher = async_dispatcher_connect(
223+
hass, MQTT_PROCESSED_SUBSCRIPTIONS, _sync_mqtt_subscribe
224+
)
225+
subscribe_callback = async_subscribe_internal(
226+
hass, topic, msg_callback, qos, encoding
227+
)
228+
try:
229+
hass.loop.call_later(SUBSCRIBE_TIMEOUT, _async_timeout_subscribe)
230+
await subscription_complete
231+
except TimeoutError as exc:
232+
raise HomeAssistantError(
233+
translation_domain=DOMAIN,
234+
translation_key="subscribe_timeout",
235+
) from exc
236+
finally:
237+
dispatcher()
238+
return subscribe_callback
239+
199240
return async_subscribe_internal(hass, topic, msg_callback, qos, encoding)
200241

201242

@@ -963,6 +1004,7 @@ async def _async_perform_subscriptions(self) -> None:
9631004
self._last_subscribe = time.monotonic()
9641005

9651006
await self._async_wait_for_mid_or_raise(mid, result)
1007+
async_dispatcher_send(self.hass, MQTT_PROCESSED_SUBSCRIPTIONS, chunk_list)
9661008

9671009
async def _async_perform_unsubscribes(self) -> None:
9681010
"""Perform pending MQTT client unsubscribes."""

homeassistant/components/mqtt/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@
373373
LOGGER = logging.getLogger(__package__)
374374

375375
MQTT_CONNECTION_STATE = "mqtt_connection_state"
376+
MQTT_PROCESSED_SUBSCRIPTIONS = "mqtt_processed_subscriptions"
376377

377378
PAYLOAD_EMPTY_JSON = "{}"
378379
PAYLOAD_NONE = "None"

tests/components/mqtt/test_client.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,65 @@ async def test_subscribe_topic(
282282
unsub()
283283

284284

285+
async def test_subscribe_topic_and_wait(
286+
hass: HomeAssistant,
287+
mock_debouncer: asyncio.Event,
288+
setup_with_birth_msg_client_mock: MqttMockPahoClient,
289+
recorded_calls: list[ReceiveMessage],
290+
record_calls: MessageCallbackType,
291+
) -> None:
292+
"""Test the subscription of a topic."""
293+
await mock_debouncer.wait()
294+
mock_debouncer.clear()
295+
unsub_no_wait = await mqtt.async_subscribe(hass, "other-test-topic/#", record_calls)
296+
unsub_wait = await mqtt.async_subscribe(hass, "test-topic", record_calls, wait=True)
297+
298+
async_fire_mqtt_message(hass, "test-topic", "test-payload")
299+
async_fire_mqtt_message(hass, "other-test-topic/test", "other-test-payload")
300+
301+
await hass.async_block_till_done()
302+
assert len(recorded_calls) == 2
303+
assert recorded_calls[0].topic == "test-topic"
304+
assert recorded_calls[0].payload == "test-payload"
305+
assert recorded_calls[1].topic == "other-test-topic/test"
306+
assert recorded_calls[1].payload == "other-test-payload"
307+
308+
unsub_no_wait()
309+
unsub_wait()
310+
311+
async_fire_mqtt_message(hass, "test-topic", "test-payload")
312+
313+
await hass.async_block_till_done()
314+
assert len(recorded_calls) == 2
315+
316+
# Cannot unsubscribe twice
317+
with pytest.raises(HomeAssistantError):
318+
unsub_no_wait()
319+
320+
with pytest.raises(HomeAssistantError):
321+
unsub_wait()
322+
323+
324+
async def test_subscribe_topic_and_wait_timeout(
325+
hass: HomeAssistant,
326+
mock_debouncer: asyncio.Event,
327+
setup_with_birth_msg_client_mock: MqttMockPahoClient,
328+
recorded_calls: list[ReceiveMessage],
329+
record_calls: MessageCallbackType,
330+
) -> None:
331+
"""Test the subscription of a topic."""
332+
await mock_debouncer.wait()
333+
mock_debouncer.clear()
334+
with (
335+
patch("homeassistant.components.mqtt.client.SUBSCRIBE_TIMEOUT", 0),
336+
pytest.raises(HomeAssistantError) as exc,
337+
):
338+
await mqtt.async_subscribe(hass, "test-topic", record_calls, wait=True)
339+
340+
assert exc.value.translation_domain == mqtt.DOMAIN
341+
assert exc.value.translation_key == "subscribe_timeout"
342+
343+
285344
@pytest.mark.usefixtures("mqtt_mock_entry")
286345
async def test_subscribe_topic_not_initialize(
287346
hass: HomeAssistant, record_calls: MessageCallbackType

0 commit comments

Comments
 (0)