
Commit c8a8d76

Refactor APIClient to introduce a separate client type
This introduces a BaseAPIClient base class that is inherited by DetectorPlatformClient, which is the wrapper around Picterra's Detector platform public API. This prepares the ground for a PlotsAnalysisPlatformClient, which will wrap Picterra's Plots Analysis platform public API. Note that, to preserve backward compatibility, DetectorPlatformClient is still exported as APIClient.
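In code, the split looks roughly like the sketch below. This is a minimal illustration of the hierarchy described above: BaseAPIClient, DetectorPlatformClient and the APIClient alias are real names from this commit, while the api_url value is a placeholder (the actual path lives in detector_platform_client.py, which is not shown in this diff).

from picterra.base_client import BaseAPIClient

class DetectorPlatformClient(BaseAPIClient):
    """Wrapper around Picterra's Detector platform public API."""

    def __init__(self, **kwargs):
        # Each platform client pins its own API root on top of the base URL;
        # "public/api/v2/" is a placeholder, not necessarily the real path.
        super().__init__(api_url="public/api/v2/", **kwargs)

# Backward-compatible alias, matching the export in picterra/__init__.py
APIClient = DetectorPlatformClient

A future PlotsAnalysisPlatformClient would subclass the same base, passing its own api_url.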
1 parent 75d1a10 commit c8a8d76

7 files changed: +520 -482 lines changed


docs/api.rst

Lines changed: 5 additions & 5 deletions
@@ -3,16 +3,16 @@ API
 ===


-APIClient
----------
+DetectorPlatformClient
+----------------------

-.. autoclass:: picterra.client.APIClient
+.. autoclass:: picterra.APIClient
    :members:

 Pagination
 ----------

-.. autoclass:: picterra.client.ResultsPage
+.. autoclass:: picterra.ResultsPage
    :members:


@@ -26,5 +26,5 @@ nongeo
 Exceptions
 ----------

-.. autoclass:: picterra.client.APIError
+.. autoclass:: picterra.APIError
    :members:

src/picterra/__init__.py

Lines changed: 8 additions & 2 deletions
@@ -1,4 +1,10 @@
-from .client import APIClient
+from .base_client import APIError, ResultsPage
+# Note that we import DetectorPlatformClient twice, to export it under two names:
+# - DetectorPlatformClient as the name it should be used with
+# - APIClient to preserve backward compatibility, since that was the name it was
+# exported under previously (when we originally had only one platform and API client).
+from .detector_platform_client import DetectorPlatformClient as APIClient
+from .detector_platform_client import DetectorPlatformClient
 from .nongeo import nongeo_result_to_pixel

-__all__ = ["APIClient", "nongeo_result_to_pixel"]
+__all__ = ["APIClient", "DetectorPlatformClient", "nongeo_result_to_pixel", "APIError", "ResultsPage"]
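After this change both import paths resolve to the same class; a quick, hypothetical usage check (not part of the package's test suite):

from picterra import APIClient, APIError, DetectorPlatformClient, ResultsPage

# The alias and the new explicit name are the very same class object
assert APIClient is DetectorPlatformClient

# Existing code keeps working; this assumes PICTERRA_API_KEY is set and that the
# constructor still takes no required arguments, as before this refactor
client = APIClient()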

src/picterra/base_client.py

Lines changed: 241 additions & 0 deletions
@@ -0,0 +1,241 @@
+import json
+import logging
+import os
+import time
+from collections.abc import Callable
+from typing import Any, Generic, Iterator, Literal, TypedDict, TypeVar
+from urllib.parse import urlencode, urljoin
+
+import requests
+from requests.adapters import HTTPAdapter
+from urllib3.util.retry import Retry
+
+logger = logging.getLogger()
+
+CHUNK_SIZE_BYTES = 8192  # 8 KiB
+
+
+class APIError(Exception):
+    """Generic API error exception"""
+
+    pass
+
+
+class _RequestsSession(requests.Session):
+    """
+    Override requests session to implement a global session timeout
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.timeout = kwargs.pop("timeout")
+        super().__init__(*args, **kwargs)
+        self.headers.update(
+            {"User-Agent": "picterra-python %s" % self.headers["User-Agent"]}
+        )
+
+    def request(self, *args, **kwargs):
+        kwargs.setdefault("timeout", self.timeout)
+        return super().request(*args, **kwargs)
+
+
+def _download_to_file(url: str, filename: str):
+    # Given we do not use self.sess the timeout is disabled (requests default), and this
+    # is good as file download can take a long time
+    with requests.get(url, stream=True) as r:
+        r.raise_for_status()
+        with open(filename, "wb+") as f:
+            logger.debug("Downloading to file %s.." % filename)
+            for chunk in r.iter_content(chunk_size=CHUNK_SIZE_BYTES):
+                if chunk:  # filter out keep-alive new chunks
+                    f.write(chunk)
+
+
+def _upload_file_to_blobstore(upload_url: str, filename: str):
+    if not (os.path.exists(filename) and os.path.isfile(filename)):
+        raise ValueError("Invalid file: " + filename)
+    with open(
+        filename, "rb"
+    ) as f:  # binary recommended by requests stream upload (see link below)
+        logger.debug("Opening and streaming to upload file %s" % filename)
+        # Given we do not use self.sess the timeout is disabled (requests default), and this
+        # is good as file upload can take a long time. Also we use requests streaming upload
+        # (https://requests.readthedocs.io/en/latest/user/advanced/#streaming-uploads) to avoid
+        # reading the (potentially large) layer GeoJSON in memory
+        resp = requests.put(upload_url, data=f)
+    if not resp.ok:
+        logger.error("Error when uploading to blobstore %s" % upload_url)
+        raise APIError(resp.text)
+
+
+T = TypeVar("T")
+
+
+class ResultsPage(Generic[T]):
+    """
+    Interface for a paginated response from the API
+
+    Typically the endpoints returning lists of objects return them split
+    in pages (page 1, page 2, etc..) of a fixed dimension (eg 20). Thus
+    each `list_XX` function returns a ResultsPage (by default the first one);
+    once you have a ResultsPage for a given list of objects, you can:
+    * check its length with `len()` (eg `len(page)`)
+    * access a single element with the index operator `[]` (eg `page[5]`)
+    * turn it into a list of dictionaries with `list()` (eg `list(page)`)
+    * get the next page with `.next()` (eg `page.next()`); this could return
+      None if the list is finished
+    You can also get a specific page passing the page number to the `list_XX` function
+    """
+
+    def __init__(self, url: str, fetch: Callable[[str], requests.Response]):
+        resp = fetch(url)
+        if not resp.ok:
+            raise APIError(resp.text)
+        r: dict[str, Any] = resp.json()
+        next_url: str | None = r["next"]
+        results: list[T] = r["results"]
+
+        self._fetch = fetch
+        self._next_url = next_url
+        self._results = results
+        self._url = url
+
+    def next(self):
+        return ResultsPage(self._next_url, self._fetch) if self._next_url else None
+
+    def __len__(self) -> int:
+        return len(self._results)
+
+    def __getitem__(self, key: int) -> T:
+        return self._results[key]
+
+    def __iter__(self) -> Iterator[T]:
+        return iter([self._results[i] for i in range(len(self._results))])
+
+    def __str__(self) -> str:
+        return f"{len(self._results)} results from {self._url}"
+
+
+class Feature(TypedDict):
+    type: Literal["Feature"]
+    properties: dict[str, Any]
+    geometry: dict[str, Any]
+
+
+class FeatureCollection(TypedDict):
+    type: Literal["FeatureCollection"]
+    features: list[Feature]
+
+
+class BaseAPIClient:
+    """
+    Base class for Picterra API clients.
+
+    This is subclassed for the different products we have.
+    """
+
+    def __init__(
+        self, api_url: str, timeout: int = 30, max_retries: int = 3, backoff_factor: int = 10
+    ):
+        """
+        Args:
+            api_url: the API's base URL. This is different based on the Picterra product used
+                and is typically defined by implementations of this client
+            timeout: number of seconds before the request times out
+            max_retries: max attempts when encountering gateway issues or throttles; see
+                retry_strategy comment below
+            backoff_factor: factor used in the backoff algorithm; see retry_strategy comment below
+        """
+        base_url = os.environ.get(
+            "PICTERRA_BASE_URL", "https://app.picterra.ch/"
+        )
+        api_key = os.environ.get("PICTERRA_API_KEY", None)
+        if not api_key:
+            raise APIError("PICTERRA_API_KEY environment variable is not defined")
+        logger.info(
+            "Using base_url=%s, api_url=%s; %d max retries, %d backoff and %s timeout.",
+            base_url,
+            api_url,
+            max_retries,
+            backoff_factor,
+            timeout,
+        )
+        self.base_url = urljoin(base_url, api_url)
+        # Create the session with a default timeout (30 sec), that we can then
+        # override on a per-endpoint basis (will be disabled for file uploads and downloads)
+        self.sess = _RequestsSession(timeout=timeout)
+        # Retry: we set the HTTP codes for our throttle (429) plus possible gateway problems (50*),
+        # and for polling methods (GET), as non-idempotent ones should be addressed via idempotency
+        # key mechanism; given the algorithm is <backoff_factor> * (2 ** <retries - 1>), and we
+        # default to 30s for polling and max 30 req/min, the default 5-10-20 sequence should
+        # provide enough room for recovery
+        retry_strategy = Retry(
+            total=max_retries,
+            status_forcelist=[429, 502, 503, 504],
+            backoff_factor=backoff_factor,
+            allowed_methods=["GET"],
+        )
+        adapter = HTTPAdapter(max_retries=retry_strategy)
+        self.sess.mount("https://", adapter)
+        self.sess.mount("http://", adapter)
+        # Authentication
+        self.sess.headers.update({"X-Api-Key": api_key})
+
+    def _full_url(self, path: str, params: dict[str, Any] | None = None):
+        url = urljoin(self.base_url, path)
+        if not params:
+            return url
+        else:
+            qstr = urlencode(params)
+            return "%s?%s" % (url, qstr)
+
+    def _wait_until_operation_completes(
+        self, operation_response: dict[str, Any]
+    ) -> dict[str, Any]:
+        """Polls an operation and returns its data"""
+        operation_id = operation_response["operation_id"]
+        poll_interval = operation_response["poll_interval"]
+        # Just sleep for a short while the first time
+        time.sleep(poll_interval * 0.1)
+        while True:
+            logger.info("Polling operation id %s" % operation_id)
+            resp = self.sess.get(
+                self._full_url("operations/%s/" % operation_id),
+            )
+            if not resp.ok:
+                raise APIError(resp.text)
+            status = resp.json()["status"]
+            logger.info("status=%s" % status)
+            if status == "success":
+                break
+            if status == "failed":
+                errors = resp.json()["errors"]
+                raise APIError(
+                    "Operation %s failed: %s" % (operation_id, json.dumps(errors))
+                )
+            time.sleep(poll_interval)
+        return resp.json()
+
+    def _return_results_page(
+        self, resource_endpoint: str, params: dict[str, Any] | None = None
+    ) -> ResultsPage:
+        if params is None:
+            params = {}
+        if "page_number" not in params:
+            params["page_number"] = 1
+
+        url = self._full_url("%s/" % resource_endpoint, params=params)
+        return ResultsPage(url, self.sess.get)
+
+    def get_operation_results(self, operation_id: str) -> dict[str, Any]:
+        """
+        Return the 'results' dict of an operation
+
+        This is a **beta** function, subject to change.
+
+        Args:
+            operation_id: The id of the operation
+        """
+        resp = self.sess.get(
+            self._full_url("operations/%s/" % operation_id),
+        )
+        return resp.json()["results"]
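As a hedged illustration of how a concrete platform client is expected to build on these helpers, the sketch below defines a hypothetical client with a made-up resource name ("widgets") and a placeholder API path, then pages through results with ResultsPage. Only BaseAPIClient, ResultsPage and their methods come from the code above; everything else is illustrative.

from picterra.base_client import BaseAPIClient, ResultsPage

class ExamplePlatformClient(BaseAPIClient):
    def __init__(self, **kwargs):
        # "public/api/example/" is a placeholder path used only for this sketch
        super().__init__(api_url="public/api/example/", **kwargs)

    def list_widgets(self, page_number: int = 1) -> ResultsPage:
        # Delegates pagination to the shared helper defined in BaseAPIClient
        return self._return_results_page("widgets", {"page_number": page_number})

# Requires PICTERRA_API_KEY to be set, as BaseAPIClient.__init__ enforces
client = ExamplePlatformClient()
page = client.list_widgets()
while page is not None:
    for widget in page:   # ResultsPage supports len(), indexing and iteration
        print(widget)
    page = page.next()    # returns None once the last page has been consumed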
