import json
import logging
import os
import time
from collections.abc import Callable, Iterator
from typing import Any, Generic, Literal, TypedDict, TypeVar
from urllib.parse import urlencode, urljoin

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

logger = logging.getLogger(__name__)

CHUNK_SIZE_BYTES = 8192  # 8 KiB


class APIError(Exception):
    """Generic API error exception"""


class _RequestsSession(requests.Session):
    """
    Override requests session to implement a global session timeout
    """

    def __init__(self, *args, **kwargs):
        self.timeout = kwargs.pop("timeout")
        super().__init__(*args, **kwargs)
        self.headers.update(
            {"User-Agent": "picterra-python %s" % self.headers["User-Agent"]}
        )

    def request(self, *args, **kwargs):
        kwargs.setdefault("timeout", self.timeout)
        return super().request(*args, **kwargs)
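
# Example (illustrative, not part of the original module): every request made
# through _RequestsSession inherits the session-wide timeout unless a call
# overrides it explicitly:
#
#   sess = _RequestsSession(timeout=30)
#   sess.get("https://example.com")             # uses the 30 s default
#   sess.get("https://example.com", timeout=5)  # per-request override wins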


def _download_to_file(url: str, filename: str):
    # Because we use a bare requests.get instead of the client session, the timeout
    # is disabled (the requests default), which is desirable as a file download can
    # take a long time
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(filename, "wb") as f:
            logger.debug("Downloading to file %s", filename)
            for chunk in r.iter_content(chunk_size=CHUNK_SIZE_BYTES):
                if chunk:  # filter out keep-alive chunks
                    f.write(chunk)


def _upload_file_to_blobstore(upload_url: str, filename: str):
    if not os.path.isfile(filename):  # isfile() also implies the path exists
        raise ValueError("Invalid file: " + filename)
    # Binary mode is recommended for requests streaming uploads (see link below)
    with open(filename, "rb") as f:
        logger.debug("Opening and streaming upload of file %s", filename)
        # Because we use a bare requests.put instead of the client session, the
        # timeout is disabled (the requests default), which is desirable as a file
        # upload can take a long time. We also use requests streaming uploads
        # (https://requests.readthedocs.io/en/latest/user/advanced/#streaming-uploads)
        # to avoid reading the (potentially large) layer GeoJSON into memory
        resp = requests.put(upload_url, data=f)
        if not resp.ok:
            logger.error("Error when uploading to blobstore %s", upload_url)
            raise APIError(resp.text)
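
# Example (illustrative; `upload_url` would come from a prior, product-specific
# API call that returns a signed blobstore URL):
#
#   _upload_file_to_blobstore(upload_url, "layer.geojson")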


T = TypeVar("T")


class ResultsPage(Generic[T]):
    """
    Interface for a paginated response from the API

    Endpoints returning a list of objects typically split them into
    pages (page 1, page 2, etc.) of a fixed size (e.g. 20). Thus each
    `list_XX` function returns a ResultsPage (by default the first one);
    once you have a ResultsPage for a given list of objects, you can:
    * check its length with `len()` (e.g. `len(page)`)
    * access a single element with the index operator `[]` (e.g. `page[5]`)
    * turn it into a list of dictionaries with `list()` (e.g. `list(page)`)
    * get the next page with `.next()` (e.g. `page.next()`); this returns
      None once the list is exhausted
    You can also get a specific page by passing the page number to the `list_XX` function
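
    Example (illustrative; assumes a client exposing a `list_rasters` method):

        page = client.list_rasters()
        while page is not None:
            for item in page:
                print(item["id"])
            page = page.next()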
    """

    def __init__(self, url: str, fetch: Callable[[str], requests.Response]):
        resp = fetch(url)
        if not resp.ok:
            raise APIError(resp.text)
        r: dict[str, Any] = resp.json()
        next_url: str | None = r["next"]
        results: list[T] = r["results"]

        self._fetch = fetch
        self._next_url = next_url
        self._results = results
        self._url = url

    def next(self) -> "ResultsPage[T] | None":
        return ResultsPage(self._next_url, self._fetch) if self._next_url else None

    def __len__(self) -> int:
        return len(self._results)

    def __getitem__(self, key: int) -> T:
        return self._results[key]

    def __iter__(self) -> Iterator[T]:
        return iter(self._results)

    def __str__(self) -> str:
        return f"{len(self._results)} results from {self._url}"


class Feature(TypedDict):
    type: Literal["Feature"]
    properties: dict[str, Any]
    geometry: dict[str, Any]


class FeatureCollection(TypedDict):
    type: Literal["FeatureCollection"]
    features: list[Feature]
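
# Example (illustrative): a minimal value matching these TypedDicts:
#
#   fc: FeatureCollection = {
#       "type": "FeatureCollection",
#       "features": [{
#           "type": "Feature",
#           "properties": {},
#           "geometry": {"type": "Point", "coordinates": [7.45, 46.95]},
#       }],
#   }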


class BaseAPIClient:
    """
    Base class for Picterra API clients.

    This is subclassed for the different products we have.
    """

    def __init__(
        self, api_url: str, timeout: int = 30, max_retries: int = 3, backoff_factor: int = 10
    ):
        """
        Args:
            api_url: the API's base URL. This differs based on the Picterra product used
                and is typically defined by implementations of this client
            timeout: number of seconds before a request times out
            max_retries: max attempts when encountering gateway issues or throttles; see
                the retry_strategy comment below
            backoff_factor: factor used in the backoff algorithm; see the retry_strategy
                comment below
        """
        base_url = os.environ.get("PICTERRA_BASE_URL", "https://app.picterra.ch/")
        api_key = os.environ.get("PICTERRA_API_KEY", None)
        if not api_key:
            raise APIError("PICTERRA_API_KEY environment variable is not defined")
        logger.info(
            "Using base_url=%s, api_url=%s; %d max retries, %d backoff factor and %s timeout.",
            base_url,
            api_url,
            max_retries,
            backoff_factor,
            timeout,
        )
        self.base_url = urljoin(base_url, api_url)
        # Create the session with a default timeout (30 sec) that we can then
        # override on a per-endpoint basis (file uploads and downloads bypass the
        # session entirely, so no timeout applies there)
        self.sess = _RequestsSession(timeout=timeout)
        # Retry: we set the HTTP codes for our throttle (429) plus possible gateway
        # problems (50*), and only for polling methods (GET), as non-idempotent ones
        # should be addressed via an idempotency-key mechanism; given the backoff
        # algorithm is {backoff_factor} * (2 ** ({retry number} - 1)), and we default
        # to 30s for polling and max 30 req/min, the default 10-20-40 sequence should
        # provide enough room for recovery
        retry_strategy = Retry(
            total=max_retries,
            status_forcelist=[429, 502, 503, 504],
            backoff_factor=backoff_factor,
            allowed_methods=["GET"],
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.sess.mount("https://", adapter)
        self.sess.mount("http://", adapter)
        # Authentication
        self.sess.headers.update({"X-Api-Key": api_key})

    def _full_url(self, path: str, params: dict[str, Any] | None = None) -> str:
        url = urljoin(self.base_url, path)
        if not params:
            return url
        else:
            qstr = urlencode(params)
            return "%s?%s" % (url, qstr)

    def _wait_until_operation_completes(
        self, operation_response: dict[str, Any]
    ) -> dict[str, Any]:
        """Polls an operation and returns its data"""
        operation_id = operation_response["operation_id"]
        poll_interval = operation_response["poll_interval"]
        # Just sleep for a short while the first time
        time.sleep(poll_interval * 0.1)
        while True:
            logger.info("Polling operation id %s", operation_id)
            resp = self.sess.get(
                self._full_url("operations/%s/" % operation_id),
            )
            if not resp.ok:
                raise APIError(resp.text)
            status = resp.json()["status"]
            logger.info("status=%s", status)
            if status == "success":
                break
            if status == "failed":
                errors = resp.json()["errors"]
                raise APIError(
                    "Operation %s failed: %s" % (operation_id, json.dumps(errors))
                )
            time.sleep(poll_interval)
        return resp.json()
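
    # Illustrative shape of the dict this method expects (field names taken from
    # the code above; the values are hypothetical, and other fields may be present):
    #
    #   {"operation_id": "abc123", "poll_interval": 30}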

    def _return_results_page(
        self, resource_endpoint: str, params: dict[str, Any] | None = None
    ) -> ResultsPage:
        if params is None:
            params = {}
        params.setdefault("page_number", 1)

        url = self._full_url("%s/" % resource_endpoint, params=params)
        return ResultsPage(url, self.sess.get)

    def get_operation_results(self, operation_id: str) -> dict[str, Any]:
        """
        Return the 'results' dict of an operation

        This is a **beta** function, subject to change.

        Args:
            operation_id: The id of the operation
        """
        resp = self.sess.get(
            self._full_url("operations/%s/" % operation_id),
        )
        if not resp.ok:
            raise APIError(resp.text)
        return resp.json()["results"]
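

# A minimal sketch (not part of this module) of how a product client might
# subclass BaseAPIClient; the "public/v2/" path and `list_rasters` endpoint
# are hypothetical:
#
#   class DetectorAPIClient(BaseAPIClient):
#       def __init__(self, **kwargs):
#           super().__init__("public/v2/", **kwargs)
#
#       def list_rasters(self, page_number: int | None = None) -> ResultsPage:
#           params = {"page_number": page_number} if page_number else None
#           return self._return_results_page("rasters", params)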