|
38 | 38 | get_hassjob_callable_job_type, |
39 | 39 | ) |
40 | 40 | from homeassistant.exceptions import HomeAssistantError |
41 | | -from homeassistant.helpers.dispatcher import async_dispatcher_send |
| 41 | +from homeassistant.helpers.dispatcher import ( |
| 42 | + async_dispatcher_connect, |
| 43 | + async_dispatcher_send, |
| 44 | +) |
42 | 45 | from homeassistant.helpers.importlib import async_import_module |
43 | 46 | from homeassistant.helpers.start import async_at_started |
44 | 47 | from homeassistant.helpers.typing import ConfigType |
|
71 | 74 | DEFAULT_WS_PATH, |
72 | 75 | DOMAIN, |
73 | 76 | MQTT_CONNECTION_STATE, |
| 77 | + MQTT_PROCESSED_SUBSCRIPTIONS, |
74 | 78 | PROTOCOL_5, |
75 | 79 | PROTOCOL_31, |
76 | 80 | TRANSPORT_WEBSOCKETS, |
|
109 | 113 | SUBSCRIBE_COOLDOWN = 0.1 |
110 | 114 | UNSUBSCRIBE_COOLDOWN = 0.1 |
111 | 115 | TIMEOUT_ACK = 10 |
| 116 | +SUBSCRIBE_TIMEOUT = 10 |
112 | 117 | RECONNECT_INTERVAL_SECONDS = 10 |
113 | 118 |
|
114 | 119 | MAX_WILDCARD_SUBSCRIBES_PER_CALL = 1 |
@@ -191,11 +196,47 @@ async def async_subscribe( |
191 | 196 | msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None], |
192 | 197 | qos: int = DEFAULT_QOS, |
193 | 198 | encoding: str | None = DEFAULT_ENCODING, |
| 199 | + wait: bool = False, |
194 | 200 | ) -> CALLBACK_TYPE: |
195 | 201 | """Subscribe to an MQTT topic. |
196 | 202 |
|
197 | 203 | Call the return value to unsubscribe. |
198 | 204 | """ |
| 205 | + subscription_complete: asyncio.Future[None] |
| 206 | + |
| 207 | + async def _sync_mqtt_subscribe(subscriptions: list[tuple[str, int]]) -> None: |
| 208 | + if (topic, qos) not in subscriptions: |
| 209 | + return |
| 210 | + subscription_complete.set_result(None) |
| 211 | + |
| 212 | + def _async_timeout_subscribe() -> None: |
| 213 | + if not subscription_complete.done(): |
| 214 | + subscription_complete.set_exception(TimeoutError) |
| 215 | + |
| 216 | + if ( |
| 217 | + wait |
| 218 | + and DATA_MQTT in hass.data |
| 219 | + and not hass.data[DATA_MQTT].client._matching_subscriptions(topic) # noqa: SLF001 |
| 220 | + ): |
| 221 | + subscription_complete = hass.loop.create_future() |
| 222 | + dispatcher = async_dispatcher_connect( |
| 223 | + hass, MQTT_PROCESSED_SUBSCRIPTIONS, _sync_mqtt_subscribe |
| 224 | + ) |
| 225 | + subscribe_callback = async_subscribe_internal( |
| 226 | + hass, topic, msg_callback, qos, encoding |
| 227 | + ) |
| 228 | + try: |
| 229 | + hass.loop.call_later(SUBSCRIBE_TIMEOUT, _async_timeout_subscribe) |
| 230 | + await subscription_complete |
| 231 | + except TimeoutError as exc: |
| 232 | + raise HomeAssistantError( |
| 233 | + translation_domain=DOMAIN, |
| 234 | + translation_key="subscribe_timeout", |
| 235 | + ) from exc |
| 236 | + finally: |
| 237 | + dispatcher() |
| 238 | + return subscribe_callback |
| 239 | + |
199 | 240 | return async_subscribe_internal(hass, topic, msg_callback, qos, encoding) |
200 | 241 |
|
201 | 242 |
|
@@ -963,6 +1004,7 @@ async def _async_perform_subscriptions(self) -> None: |
963 | 1004 | self._last_subscribe = time.monotonic() |
964 | 1005 |
|
965 | 1006 | await self._async_wait_for_mid_or_raise(mid, result) |
| 1007 | + async_dispatcher_send(self.hass, MQTT_PROCESSED_SUBSCRIPTIONS, chunk_list) |
966 | 1008 |
|
967 | 1009 | async def _async_perform_unsubscribes(self) -> None: |
968 | 1010 | """Perform pending MQTT client unsubscribes.""" |
|
0 commit comments