diff --git a/src/swf_typed/_activities.py b/src/swf_typed/_activities.py index 0995033..20dbead 100644 --- a/src/swf_typed/_activities.py +++ b/src/swf_typed/_activities.py @@ -165,7 +165,7 @@ def list_activities( activity_filter: ActivityIdFilter = None, reverse: bool = False, client: "botocore.client.BaseClient" = None, -) -> t.Generator[ActivityInfo, None, None]: +) -> _common.PageConsumer[ActivityInfo]: """List activity types; retrieved semi-lazily. Args: diff --git a/src/swf_typed/_common.py b/src/swf_typed/_common.py index ebfc38a..1c3ef7c 100644 --- a/src/swf_typed/_common.py +++ b/src/swf_typed/_common.py @@ -5,6 +5,7 @@ import datetime import contextlib import typing as t +import collections.abc import concurrent.futures from . import _exceptions @@ -61,6 +62,132 @@ def get_api_args(self) -> t.Dict[str, t.Any]: """Serialise to SWF API request arguments.""" +class PageConsumer(collections.abc.Generator, t.Generic[T]): + """Paged SWF API response iterator.""" + + _next_page_token_key = "nextPageToken" + + def __init__( + self, + api_call: t.Callable[..., t.Dict[str, t.Any]], + model: t.Callable[[t.Dict[str, t.Any]], T], + data_key: str, + response: t.Dict[str, t.Any], + executor: concurrent.futures.Executor, + ) -> None: + """Initialise iteator. + + Args: + api_call: AWS SWF API SDK function + model: response model (constructor) + data_key: response results key + response: first response + executor: concurrency executor + """ + + self.api_call = api_call + self.model = model + self.data_key = data_key + self.response = response + self.executor = executor + + self._i = 0 + self._future: t.Union[concurrent.futures.Future, None] = None + + @property + def _items(self) -> t.List[t.Dict[str, t.Any]]: + return self.response.get(self.data_key) or [] + + def send(self, value: None) -> T: + if ( + self._i == 0 + and not self._future + and self.response.get(self._next_page_token_key) + ): + # Start getting next page (first iteration) + self._future = self.executor.submit( + self.api_call, nextPageToken=self.response[self._next_page_token_key] + ) + + if self._i >= len(self._items): + if not self._future: + raise StopIteration + # Recieve next page + self.response = self._future.result() + self._i = 0 + if self.response.get(self._next_page_token_key): + # Start getting next page + self._future = self.executor.submit( + self.api_call, + nextPageToken=self.response[self._next_page_token_key], + ) + else: + self._future = None + + item = self._items[self._i] + self._i += 1 + return self.model(item) + + def throw(self, typ, val=None, tb=None) -> T: + r = self.send(None) + self._future = None + self.response.pop(self._next_page_token_key, None) + self._i = len(self._items) + return r + + def get_page( + self, + page_token: t.Union[str, None] = None, + start_getting_next_page: bool = True, + ) -> t.Tuple[t.List[T], t.Union[str, None]]: + """Get a full page of results from SWF. + + Uses pre-fetched results if available. + + Args: + page_token: page token + start_getting_next_page: start fetching the next page in another + thread + + Returns: + page of results (structured), and next page's token + """ + + if not page_token and not self._future: + # Use pre-fetched first response + response = self.response + + if start_getting_next_page and self.response.get(self._next_page_token_key): + self._future = self.executor.submit( + self.api_call, + nextPageToken=self.response[self._next_page_token_key], + ) + elif ( + page_token + and self._future + and page_token == self.response.get(self._next_page_token_key) + ): + # Use in-flight response + response = self._future.result() + + if start_getting_next_page: + self.response = response + self._i = 0 + if self.response.get(self._next_page_token_key): + self._future = self.executor.submit( + self.api_call, + nextPageToken=self.response[self._next_page_token_key], + ) + elif page_token: + response = self.api_call(nextPageToken=page_token) + else: + # First page, but we're not certain if `self.response` is the first still + response = self.api_call() + + models = [self.model(item) for item in response.get(self.data_key) or []] + return models, response.get(self._next_page_token_key) + + def ensure_client( client: "botocore.client.BaseClient" = None, ) -> "botocore.client.BaseClient": @@ -95,7 +222,7 @@ def iter_paged( call: t.Callable[..., t.Dict[str, t.Any]], model: t.Callable[[t.Dict[str, t.Any]], T], data_key: str, -) -> t.Generator[T, None, None]: +) -> PageConsumer[T]: """Yield results from paginated method. Method is called immediately, then a generator is returned which yields @@ -112,18 +239,9 @@ def iter_paged( method results, transformed """ - def iter_() -> t.Generator[T, None, None]: - nonlocal response - - while response.get("nextPageToken"): - future = executor.submit(call, nextPageToken=response["nextPageToken"]) - yield from (model(d) for d in response.get(data_key) or []) - response = future.result() - yield from (model(d) for d in response.get(data_key) or []) - executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) response = call() - return iter_() + return PageConsumer(call, model, data_key, response, executor) @contextlib.contextmanager diff --git a/src/swf_typed/_domains.py b/src/swf_typed/_domains.py index d0a491d..770efa1 100644 --- a/src/swf_typed/_domains.py +++ b/src/swf_typed/_domains.py @@ -108,7 +108,7 @@ def list_domains( deprecated: bool = False, reverse: bool = False, client: "botocore.client.BaseClient" = None, -) -> t.Generator[DomainInfo, None, None]: +) -> _common.PageConsumer[DomainInfo]: """List domains; retrieved semi-lazily. Args: diff --git a/src/swf_typed/_executions.py b/src/swf_typed/_executions.py index 3fac67b..3b3f957 100644 --- a/src/swf_typed/_executions.py +++ b/src/swf_typed/_executions.py @@ -522,7 +522,7 @@ def list_closed_executions( ] = None, reverse: bool = False, client: "botocore.client.BaseClient" = None, -) -> t.Generator[ExecutionInfo, None, None]: +) -> _common.PageConsumer[ExecutionInfo]: """List closed workflow executions; retrieved semi-lazily. Args: @@ -562,7 +562,7 @@ def list_open_executions( ] = None, reverse: bool = False, client: "botocore.client.BaseClient" = None, -) -> t.Generator[ExecutionInfo, None, None]: +) -> _common.PageConsumer[ExecutionInfo]: """List open workflow executions; retrieved semi-lazily. Args: diff --git a/src/swf_typed/_history.py b/src/swf_typed/_history.py index 8f15da8..ecf4155 100644 --- a/src/swf_typed/_history.py +++ b/src/swf_typed/_history.py @@ -1899,7 +1899,7 @@ def get_execution_history( domain: str, reverse: bool = False, client: "botocore.client.BaseClient" = None, -) -> t.Generator[Event, None, None]: +) -> _common.PageConsumer[Event]: """Get workflow execution history; retrieved semi-lazily. Args: diff --git a/src/swf_typed/_workflows.py b/src/swf_typed/_workflows.py index edef6c1..da11388 100644 --- a/src/swf_typed/_workflows.py +++ b/src/swf_typed/_workflows.py @@ -169,7 +169,7 @@ def list_workflows( workflow_filter: WorkflowIdFilter = None, reverse: bool = False, client: "botocore.client.BaseClient" = None, -) -> t.Generator[WorkflowInfo, None, None]: +) -> _common.PageConsumer[WorkflowInfo]: """List workflow types; retrieved semi-lazily. Args: