Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
9baf002
Validate monitoring_location_id format in waterdata functions
thodson-usgs Apr 13, 2026
8d5d7e4
Widen _check_monitoring_location_id to accept iterables of strings
thodson-usgs May 13, 2026
36ca0be
Tidy _check_monitoring_location_id (clearer dispatch, helper, type hi…
thodson-usgs May 13, 2026
b5b7f94
Address Copilot review on PR #229
thodson-usgs May 13, 2026
df3c108
Apply _normalize_str_iterable to every multi-value string parameter
thodson-usgs May 13, 2026
e8c8e88
Polish from /simplify review
thodson-usgs May 13, 2026
71823ad
Centralize string-iterable normalization in _get_args (-153 LOC)
thodson-usgs May 13, 2026
0fd5730
_format_api_dates: materialize iterable inputs (Copilot #4)
thodson-usgs May 13, 2026
7c32bea
Extract _DATE_RANGE_PARAMS; trim docstrings/comments from /simplify
thodson-usgs May 13, 2026
749f72e
Add StringFilter/StringList aliases; fix Copilot bugs
thodson-usgs May 13, 2026
07ed123
Revert StringFilter/StringList aliases; use inline PEP 604 unions
thodson-usgs May 13, 2026
8dc70b2
Reject list-of-non-strings at boundary instead of silently passing th…
thodson-usgs May 13, 2026
1d7a6e7
Allow int-valued list filters in _get_args (water_year, year, month, …
thodson-usgs May 13, 2026
aa98d23
Close Copilot review gaps: extend AGENCY-ID check + iterable normaliz…
thodson-usgs May 13, 2026
42852a8
Centralize monitoring_location_id check in _get_args; trim narration
thodson-usgs May 13, 2026
702ea29
Widen string-filter annotations on get_combined_metadata / get_field_…
thodson-usgs May 13, 2026
5260a8f
Fix CI and reject Mapping inputs to _format_api_dates
thodson-usgs May 13, 2026
463912a
Update parameter docstrings to "string or iterable of strings"
thodson-usgs May 13, 2026
0cf981e
Widen properties annotation to Iterable[str] for consistency
thodson-usgs May 13, 2026
b0f2289
Validate monitoring_location_id in get_ratings and widen annotations …
thodson-usgs May 13, 2026
007b76b
Drop dead `monitoring_location_id` entry from _NO_NORMALIZE_PARAMS; t…
thodson-usgs May 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
860 changes: 428 additions & 432 deletions dataretrieval/waterdata/api.py

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions dataretrieval/waterdata/nearest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

from collections.abc import Iterable
from typing import Literal, get_args

import pandas as pd
Expand All @@ -18,8 +19,8 @@

def get_nearest_continuous(
targets,
monitoring_location_id: str | list[str] | None = None,
parameter_code: str | list[str] | None = None,
monitoring_location_id: str | Iterable[str] | None = None,
parameter_code: str | Iterable[str] | None = None,
*,
window: str | pd.Timedelta = "PT7M30S",
on_tie: OnTie = "first",
Expand All @@ -44,9 +45,9 @@ def get_nearest_continuous(
Target timestamps. Naive datetimes are treated as UTC. Accepts a
list, ``pandas.Series``, ``pandas.DatetimeIndex``, ``numpy.ndarray``,
or anything ``pandas.to_datetime`` consumes.
monitoring_location_id : string or list of strings, optional
monitoring_location_id : string or iterable of strings, optional
Forwarded to ``get_continuous``.
parameter_code : string or list of strings, optional
parameter_code : string or iterable of strings, optional
Forwarded to ``get_continuous``.
window : string or ``pandas.Timedelta``, default ``"PT7M30S"``
Half-window around each target, as an ISO 8601 duration
Expand Down
13 changes: 10 additions & 3 deletions dataretrieval/waterdata/ratings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@

from dataretrieval.rdb import extract_rdb_comment, read_rdb

from .utils import _DURATION_RE, BASE_URL, _default_headers, _format_api_dates
from .utils import (
_DURATION_RE,
BASE_URL,
_check_monitoring_location_id,
_default_headers,
_format_api_dates,
)

logger = logging.getLogger(__name__)

Expand All @@ -33,7 +39,7 @@


def get_ratings(
monitoring_location_id: str | list[str] | None = None,
monitoring_location_id: str | Iterable[str] | None = None,
file_type: RATING_FILE_TYPE | list[RATING_FILE_TYPE] = "exsa",
file_path: str | None = None,
time: str | list[str] | None = None,
Expand Down Expand Up @@ -62,7 +68,7 @@ def get_ratings(

Parameters
----------
monitoring_location_id : string or list of strings, optional
monitoring_location_id : string or iterable of strings, optional
One or more identifiers in ``AGENCY-ID`` form (e.g.
``"USGS-01104475"``). If omitted, the spatial / temporal filters
determine the result set.
Expand Down Expand Up @@ -142,6 +148,7 @@ def get_ratings(
... )

"""
monitoring_location_id = _check_monitoring_location_id(monitoring_location_id)
file_types = _as_list(file_type)
invalid = [ft for ft in file_types if ft not in _VALID_FILE_TYPES]
if invalid:
Expand Down
179 changes: 167 additions & 12 deletions dataretrieval/waterdata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import re
from collections.abc import Iterable, Mapping
from datetime import datetime
from typing import Any, get_args
from zoneinfo import ZoneInfo
Expand Down Expand Up @@ -143,6 +144,15 @@ def _switch_properties_id(properties: list[str] | None, id_name: str, service: s
# admits time-only forms like ``PT36H``.
_DURATION_RE = re.compile(r"^[Pp]T?\d")

# OGC API parameters that carry a date/datetime value (single string,
# two-element range, or interval/duration string) rather than a multi-value
# string list. Used by ``_construct_api_requests`` to keep them out of the
# POST/CQL2 multi-value path and to route them through ``_format_api_dates``,
# and by ``_NO_NORMALIZE_PARAMS`` to bypass string-iterable normalization.
_DATE_RANGE_PARAMS = frozenset(
{"datetime", "last_modified", "begin", "begin_utc", "end", "end_utc", "time"}
)


def _parse_datetime(value: str) -> datetime | None:
"""Parse a single datetime string against the supported formats.
Expand Down Expand Up @@ -223,12 +233,24 @@ def _format_api_dates(
converted from that offset to UTC; naive inputs are interpreted in the
local time zone for backwards compatibility.
"""
if datetime_input is None:
return None
# Get timezone
local_timezone = datetime.now().astimezone().tzinfo

# Convert single string to list for uniform processing
if isinstance(datetime_input, str):
datetime_input = [datetime_input]
elif isinstance(datetime_input, Mapping):
# `list(mapping)` returns keys, which silently accepts the wrong shape.
raise TypeError(
f"date input must be a string or sequence of strings, "
f"not {type(datetime_input).__name__}."
)
elif not isinstance(datetime_input, (list, tuple)):
# Materialize any other iterable (pandas.Series, numpy.ndarray,
# generator, ...) so the len()/subscript operations below work.
datetime_input = list(datetime_input)
Comment thread
thodson-usgs marked this conversation as resolved.
Comment thread
thodson-usgs marked this conversation as resolved.

# Check for null or all NA and return None
if all(pd.isna(dt) or dt == "" or dt is None for dt in datetime_input):
Expand Down Expand Up @@ -429,14 +451,11 @@ def _construct_api_requests(
"""
service_url = f"{OGC_API_URL}/collections/{service}/items"

# Single parameters can only have one value
single_params = {"datetime", "last_modified", "begin", "end", "time"}

# Identify which parameters should be included in the POST content body
post_params = {
k: v
for k, v in kwargs.items()
if k not in single_params and isinstance(v, (list, tuple)) and len(v) > 1
if k not in _DATE_RANGE_PARAMS and isinstance(v, (list, tuple)) and len(v) > 1
}

# Everything else goes into the params dictionary for the URL
Expand All @@ -452,15 +471,13 @@ def _construct_api_requests(
POST = bool(post_params)

# Convert dates to ISO 8601 format
time_periods = {"last_modified", "datetime", "time", "begin", "end"}
for i in time_periods:
for i in _DATE_RANGE_PARAMS:
if i in params:
dates = service == "daily" and i != "last_modified"
params[i] = _format_api_dates(params[i], date=dates)

# String together bbox elements from a list to a comma-separated string,
# and string together properties if provided
if bbox:
# `len()` instead of truthiness: a numpy ndarray would raise on `if bbox:`.
if bbox is not None and len(bbox) > 0:
params["bbox"] = ",".join(map(str, bbox))
if properties:
params["properties"] = ",".join(properties)
Expand Down Expand Up @@ -1168,6 +1185,129 @@ def _check_profiles(
)


# Shape of an ``AGENCY-ID`` identifier (e.g. ``USGS-01646500``): a run of
# characters that are neither hyphens nor whitespace, a single hyphen, then
# another such run. ``_check_id_format`` applies it with ``fullmatch``, so a
# value with no hyphen, more than one hyphen, or any whitespace is rejected.
_MONITORING_LOCATION_ID_RE = re.compile(r"[^-\s]+-[^-\s]+")


# Parameters whose values are iterable but must be passed through by
# ``_get_args`` untouched rather than fed to ``_normalize_str_iterable``.
# Scalar non-string options never reach the normalizer (they are filtered by
# runtime type), so only iterables needing special treatment are listed:
#   * date-range params: may hold ``pd.NaT``/None or interval strings
#   * ``bbox`` / ``boundingBox``: ``list[float]``, sometimes ``numpy.ndarray``
#   * ``get_peaks`` int-valued filters (``water_year`` etc.): ``list[int]``
#   * ``get_combined_metadata``'s ``thresholds``: ``list[float]``
_NO_NORMALIZE_PARAMS = frozenset(
    {
        *_DATE_RANGE_PARAMS,
        "bbox",
        "boundingBox",
        "water_year",
        "year",
        "month",
        "day",
        "peak_since",
        "thresholds",
    }
)
Comment thread
thodson-usgs marked this conversation as resolved.


def _normalize_str_iterable(
value: str | Iterable[str] | None,
param_name: str = "value",
) -> str | list[str] | None:
"""Validate that ``value`` is None, a string, or an iterable of strings.

Non-string iterables (``list``, ``tuple``, ``pandas.Series``,
``pandas.Index``, ``numpy.ndarray``, generators) are materialized to a
``list`` so downstream code that branches on ``isinstance(v, (list,
tuple))`` keeps working. ``Mapping`` types are rejected because
iterating a mapping yields keys, not values.

Parameters
----------
value : None, str, or iterable of str
param_name : str, optional
Used in error messages. Defaults to ``"value"``.

Returns
-------
None, str, or list of str

Raises
------
TypeError
If the input isn't ``None``, ``str``, or a non-``Mapping``
iterable; or if any iterable element isn't a string.
"""
if value is None:
return None
if isinstance(value, str):
return value
if isinstance(value, Mapping) or not isinstance(value, Iterable):
raise TypeError(
f"{param_name} must be a string or iterable of strings, "
f"not {type(value).__name__} (got {value!r})."
)
values: list[str] = []
for v in value:
if not isinstance(v, str):
raise TypeError(
f"{param_name} elements must be strings, "
f"not {type(v).__name__} (got {v!r})."
)
values.append(v)
return values


def _check_monitoring_location_id(
    monitoring_location_id: str | Iterable[str] | None,
) -> str | list[str] | None:
    """Normalize ``monitoring_location_id`` and enforce the AGENCY-ID format.

    The value is first passed through :func:`_normalize_str_iterable`;
    every resulting identifier is then handed to :func:`_check_id_format`,
    which requires the hyphen-separated ``AGENCY-ID`` form (e.g.
    ``USGS-01646500``).

    Parameters
    ----------
    monitoring_location_id : None, str, or iterable of str
        A single identifier, several identifiers, or ``None``.

    Returns
    -------
    None, str, or list of str
        ``None`` for ``None``; the same string for a string input; a
        ``list`` of the identifiers for any other accepted iterable.

    Raises
    ------
    TypeError
        If :func:`_normalize_str_iterable` rejects the input. The message
        is extended with the expected ``AGENCY-ID`` format.
    ValueError
        If any identifier is rejected by :func:`_check_id_format`.
    """
    try:
        normalized = _normalize_str_iterable(
            monitoring_location_id, "monitoring_location_id"
        )
    except TypeError as err:
        # The generic helper knows nothing about AGENCY-ID, so append the
        # format hint and drop the chained context to keep the trace short.
        raise TypeError(
            f"{err} Expected 'AGENCY-ID' format, e.g., 'USGS-01646500'."
        ) from None

    if normalized is None:
        return None

    if isinstance(normalized, str):
        _check_id_format(normalized)
    else:
        for identifier in normalized:
            _check_id_format(identifier)
    return normalized


def _check_id_format(value: str) -> None:
    """Raise ``ValueError`` if ``value`` is not in ``AGENCY-ID`` format.

    Parameters
    ----------
    value : str
        A single monitoring-location identifier.

    Raises
    ------
    ValueError
        If ``value`` does not fully match ``_MONITORING_LOCATION_ID_RE``.
    """
    # ``fullmatch`` so the whole identifier must fit the pattern, not a prefix.
    if not _MONITORING_LOCATION_ID_RE.fullmatch(value):
        raise ValueError(
            f"Invalid monitoring_location_id: {value!r}. "
            f"Expected 'AGENCY-ID' format, e.g., 'USGS-01646500'."
        )


def _get_args(
local_vars: dict[str, Any], exclude: set[str] | None = None
) -> dict[str, Any]:
Expand All @@ -1194,6 +1334,21 @@ def _get_args(
if exclude:
to_exclude.update(exclude)

return {
k: v for k, v in local_vars.items() if k not in to_exclude and v is not None
}
args: dict[str, Any] = {}
for k, v in local_vars.items():
if k in to_exclude or v is None:
continue
if k == "monitoring_location_id":
args[k] = _check_monitoring_location_id(v)
Comment thread
thodson-usgs marked this conversation as resolved.
elif k == "properties":
# `",".join(properties)` would iterate a bare string as characters.
args[k] = [v] if isinstance(v, str) else _normalize_str_iterable(v, k)
elif (
k in _NO_NORMALIZE_PARAMS
or isinstance(v, str)
or not isinstance(v, Iterable)
):
args[k] = v
else:
args[k] = _normalize_str_iterable(v, k)
Comment thread
thodson-usgs marked this conversation as resolved.
return args
Loading
Loading