13 changes: 13 additions & 0 deletions python/hsfs/constructor/fs_query.py
@@ -31,6 +31,8 @@ def __init__(
query_online: Optional[str] = None,
pit_query: Optional[str] = None,
pit_query_asof: Optional[str] = None,
hqs_payload: Optional[str] = None,
hqs_payload_signature: Optional[str] = None,
href: Optional[str] = None,
expand: Optional[List[str]] = None,
items: Optional[List[Dict[str, Any]]] = None,
@@ -43,6 +45,9 @@ def __init__(
self._pit_query = pit_query
self._pit_query_asof = pit_query_asof

self._hqs_payload = hqs_payload
self._hqs_payload_signature = hqs_payload_signature

if on_demand_feature_groups is not None:
self._on_demand_fg_aliases = [
external_feature_group_alias.ExternalFeatureGroupAlias.from_response_json(
@@ -102,6 +107,14 @@ def hudi_cached_feature_groups(
) -> List["hudi_feature_group_alias.HudiFeatureGroupAlias"]:
return self._hudi_cached_feature_groups

@property
def hqs_payload(self) -> Optional[str]:
return self._hqs_payload

@property
def hqs_payload_signature(self) -> Optional[str]:
return self._hqs_payload_signature

def register_external(
self,
spine: Optional[
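The two fields added here are opaque to the client: `hqs_payload` is the query command the backend builds for the Hopsworks Query Service (HQS), and `hqs_payload_signature` is its signature. A minimal sketch of how a caller might forward them over Arrow Flight, assuming the `hopsworks-signature` header name used in `read_query` below (everything else is illustrative):

```python
import pyarrow.flight

def build_flight_call(fs_query):
    # The payload is an opaque JSON string produced by the backend.
    descriptor = pyarrow.flight.FlightDescriptor.for_command(
        fs_query.hqs_payload.encode("ascii")
    )
    # The signature, when present, travels as a gRPC metadata header so the
    # query service can check the payload was not modified client-side.
    headers = []
    if fs_query.hqs_payload_signature:
        headers.append(
            (b"hopsworks-signature", fs_query.hqs_payload_signature.encode("ascii"))
        )
    return descriptor, headers
```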
19 changes: 12 additions & 7 deletions python/hsfs/constructor/query.py
@@ -98,9 +98,9 @@ def _prep_read(
Union[str, Dict[str, Any]], Optional["storage_connector.StorageConnector"]
]:
self._check_read_supported(online)
fs_query = self._query_constructor_api.construct_query(self)

if online:
fs_query = self._query_constructor_api.construct_query(self)
sql_query = self._to_string(fs_query, online)
online_conn = self._storage_connector_api.get_online_connector(
self._feature_store_id
@@ -109,13 +109,10 @@
online_conn = None

if engine.get_instance().is_flyingduck_query_supported(self, read_options):
from hsfs.core import arrow_flight_client

sql_query = self._to_string(fs_query, online, asof=True)
sql_query = arrow_flight_client.get_instance().create_query_object(
self, sql_query, fs_query.on_demand_fg_aliases
)
# The FlyingDuck (Hopsworks Query Service) payload is built in the backend
sql_query = self._query_constructor_api.construct_query(self, hqs=True)
else:
fs_query = self._query_constructor_api.construct_query(self)
sql_query = self._to_string(fs_query, online)
# Register on demand feature groups as temporary tables
if isinstance(self._left_feature_group, fg_mod.SpineGroup):
@@ -724,6 +721,14 @@ def _to_string(
def __str__(self) -> str:
return self._query_constructor_api.construct_query(self)

def _get_signature(self, fs_query: "FsQuery", asof: bool = False) -> Optional[str]:
if fs_query.pit_query is not None:
if asof:
return fs_query.pit_query_asof_signature
else:
return fs_query.pit_query_signature
return fs_query.query_signature

@property
def left_feature_group_start_time(
self,
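Net effect in `_prep_read`: the client no longer serializes the query for FlyingDuck itself; it asks the backend for the payload, and in that branch the returned value is the `FsQuery` object rather than a SQL string. A condensed sketch of the new branching (names are from this diff; the full method body is elided):

```python
# Condensed sketch of _prep_read after this change (not the full method).
self._check_read_supported(online)
if online:
    fs_query = self._query_constructor_api.construct_query(self)
    sql_query = self._to_string(fs_query, online)      # plain online SQL
elif engine.get_instance().is_flyingduck_query_supported(self, read_options):
    # Backend builds and signs the HQS payload; note this branch now
    # yields an FsQuery object, which read_query() consumes directly.
    sql_query = self._query_constructor_api.construct_query(self, hqs=True)
else:
    fs_query = self._query_constructor_api.construct_query(self)
    sql_query = self._to_string(fs_query, online)      # offline SQL fallback
```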
244 changes: 12 additions & 232 deletions python/hsfs/core/arrow_flight_client.py
@@ -16,7 +16,6 @@
from __future__ import annotations

import base64
import datetime
import json
import logging
import warnings
@@ -37,6 +36,7 @@
from hopsworks_common.core.constants import HAS_POLARS, polars_not_installed_message
from hsfs import feature_group
from hsfs.constructor import query
from hsfs.constructor.fs_query import FsQuery
from hsfs.core.variable_api import VariableApi
from hsfs.storage_connector import StorageConnector
from pyarrow.flight import FlightServerError
@@ -462,14 +462,19 @@ def get_flight_info(self, descriptor):
stop_max_attempt_number=3,
retry_on_exception=_should_retry,
)
def _get_dataset(self, descriptor, timeout=None, dataframe_type="pandas"):
def _get_dataset(self, descriptor, timeout=None, headers=None, dataframe_type="pandas"):
if timeout is None:
timeout = self.timeout
info = self.get_flight_info(descriptor)
_logger.debug("Retrieved flight info: %s. Fetching dataset.", str(info))

if headers is None:
headers = self._certificates_headers()

options = pyarrow.flight.FlightCallOptions(
timeout=timeout, headers=self._certificates_headers()
timeout=timeout, headers=headers
)

reader = self._connection.do_get(info.endpoints[0].ticket, options)
_logger.debug("Dataset fetched. Converting to dataframe %s.", dataframe_type)
if dataframe_type.lower() == "polars":
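`_get_dataset` now takes optional gRPC metadata headers and falls back to the existing certificate headers when none are given. If I read pyarrow correctly, `FlightCallOptions` accepts `headers` as a list of `(bytes, bytes)` pairs, so a caller can attach extra metadata like this (values are placeholders):

```python
import pyarrow.flight

# Metadata travels as (bytes, bytes) pairs alongside the Flight call.
options = pyarrow.flight.FlightCallOptions(
    timeout=900,  # seconds
    headers=[
        (b"hopsworks-signature", b"<signature>"),  # as attached by read_query()
    ],
)
```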
@@ -481,8 +486,8 @@ def _get_dataset(self, descriptor, timeout=None, dataframe_type="pandas"):

# retry is handled in get_dataset
@_handle_afs_exception(user_message=READ_ERROR)
def read_query(self, query_object, arrow_flight_config, dataframe_type):
query_encoded = json.dumps(query_object).encode("ascii")
def read_query(self, query_object: FsQuery, arrow_flight_config, dataframe_type):
query_encoded = query_object.hqs_payload.encode("ascii")
descriptor = pyarrow.flight.FlightDescriptor.for_command(query_encoded)
return self._get_dataset(
descriptor,
@@ -491,7 +496,8 @@ def read_query(self, query_object, arrow_flight_config, dataframe_type):
if arrow_flight_config
else self.timeout
),
dataframe_type,
headers=[(b'hopsworks-signature', query_object.hqs_payload_signature.encode('ascii'))] if query_object.hqs_payload_signature else None,
dataframe_type=dataframe_type,
)

# retry is handled in get_dataset
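`read_query` used to `json.dumps` a locally built command; now it ships the backend's `hqs_payload` verbatim and, when present, attaches the signature as a `hopsworks-signature` metadata header. The diff does not say what scheme the signature uses; assuming an HMAC-style MAC over the payload, the service-side check could look roughly like this (key handling and algorithm are assumptions, not taken from this PR):

```python
import hashlib
import hmac

# Hypothetical server-side check, assuming an HMAC-SHA256 signature scheme.
def verify_hqs_payload(payload: bytes, signature: str, secret: bytes) -> bool:
    expected = hmac.new(secret, payload, hashlib.sha256).hexdigest()
    # compare_digest avoids timing side channels on the comparison.
    return hmac.compare_digest(expected, signature)
```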
@@ -549,30 +555,6 @@ def create_training_dataset(
_logger.exception(e)
print("Error calling action:", e)

def create_query_object(self, query, query_str, on_demand_fg_aliases=None):
if on_demand_fg_aliases is None:
on_demand_fg_aliases = []
features = {}
connectors = {}
for fg in query.featuregroups:
fg_name = _serialize_featuregroup_name(fg)
fg_connector = _serialize_featuregroup_connector(
fg, query, on_demand_fg_aliases
)
features[fg_name] = [
{"name": feat.name, "type": feat.type} for feat in fg.features
]
connectors[fg_name] = fg_connector
filters = _serialize_filter_expression(query.filters, query)

query = {
"query_string": _translate_to_duckdb(query, query_str),
"features": features,
"filters": filters,
"connectors": connectors,
}
return query

def is_enabled(self):
if self._disabled_for_session or not self._enabled_on_cluster:
return False
@@ -616,208 +598,6 @@ def enabled_on_cluster(self) -> bool:
return self._enabled_on_cluster


def _serialize_featuregroup_connector(fg, query, on_demand_fg_aliases):
# Add feature_group_id to build cache key in flyingduck
connector = {"feature_group_id": fg.id}
if isinstance(fg, feature_group.ExternalFeatureGroup):
connector["time_travel_type"] = None
connector["type"] = fg.storage_connector.type
connector["options"] = _get_connector_options(fg)
connector["query"] = fg.data_source.query
for on_demand_fg_alias in on_demand_fg_aliases:
# backend attaches dynamic query to on_demand_fg_alias.on_demand_feature_group.query if any
if on_demand_fg_alias.on_demand_feature_group.name == fg.name:
connector["query"] = (
on_demand_fg_alias.on_demand_feature_group.data_source.query
if fg.data_source.query is None
else fg.data_source.query
)
connector["alias"] = on_demand_fg_alias.alias
break
connector["query"] = (
connector["query"][:-1]
if connector["query"].endswith(";")
else connector["query"]
)
if query._left_feature_group == fg:
connector["filters"] = _serialize_filter_expression(
query._filter, query, True
)
else:
for join_obj in query._joins:
if join_obj._query._left_feature_group == fg:
connector["filters"] = _serialize_filter_expression(
join_obj._query._filter, join_obj._query, True
)
elif fg.time_travel_format == "DELTA":
connector["time_travel_type"] = "delta"
if fg.storage_connector:
connector["type"] = fg.storage_connector.type
connector["options"] = _get_connector_options(fg)
else:
connector["type"] = ""
connector["options"] = {}
connector["query"] = ""
if query._left_feature_group == fg:
connector["filters"] = _serialize_filter_expression(
query._filter, query, True
)
else:
for join_obj in query._joins:
if join_obj._query._left_feature_group == fg:
connector["filters"] = _serialize_filter_expression(
join_obj._query._filter, join_obj._query, True
)
else:
connector["time_travel_type"] = "hudi"
return connector


def _get_connector_options(fg):
# same as in the backend (maybe move to common?)
option_map = {}

datasource = fg.data_source
connector = fg.storage_connector
connector_type = connector.type

if connector_type == StorageConnector.SNOWFLAKE:
option_map = {
"user": connector.user,
"account": connector.account,
"database": datasource.database,
"schema": datasource.group,
}
if connector.password:
option_map["password"] = connector.password
elif connector.token:
option_map["authenticator"] = "oauth"
option_map["token"] = connector.token
else:
option_map["snowflake_private_key"] = connector.private_key
option_map["passphrase"] = connector.passphrase

if connector.warehouse:
option_map["warehouse"] = connector.warehouse
if connector.application:
option_map["application"] = connector.application
elif connector_type == StorageConnector.BIGQUERY:
option_map = {
"key_path": connector.key_path,
"project_id": datasource.database,
"dataset_id": datasource.group,
"parent_project": connector.parent_project,
}
elif connector_type == StorageConnector.REDSHIFT:
option_map = {
"host": connector.cluster_identifier + "." + connector.database_endpoint,
"port": connector.database_port,
"database": datasource.database,
}
if connector.database_user_name:
option_map["user"] = connector.database_user_name
if connector.database_password:
option_map["password"] = connector.database_password
if connector.iam_role:
option_map["iam_role"] = connector.iam_role
option_map["iam"] = "True"
elif connector_type == StorageConnector.RDS:
option_map = {
"host": connector.host,
"port": connector.port,
"database": datasource.database,
}
if connector.user:
option_map["user"] = connector.user
if connector.password:
option_map["password"] = connector.password
elif connector_type == StorageConnector.S3:
option_map = {
"access_key": connector.access_key,
"secret_key": connector.secret_key,
"session_token": connector.session_token,
"region": connector.region,
}
if connector.arguments.get("fs.s3a.endpoint"):
option_map["endpoint"] = connector.arguments.get("fs.s3a.endpoint")
option_map["path"] = fg.location
elif connector_type == StorageConnector.GCS:
option_map = {
"key_path": connector.key_path,
"path": fg.location,
}
else:
raise FeatureStoreException(
f"Arrow Flight doesn't support connector of type: {connector_type}"
)

return option_map


def _serialize_featuregroup_name(fg):
return f"{fg._get_project_name()}.{fg.name}_{fg.version}"


def _serialize_filter_expression(filters, query, short_name=False):
if filters is None:
return None
return _serialize_logic(filters, query, short_name)


def _serialize_logic(logic, query, short_name):
return {
"type": "logic",
"logic_type": logic._type,
"left_filter": _serialize_filter_or_logic(
logic._left_f, logic._left_l, query, short_name
),
"right_filter": _serialize_filter_or_logic(
logic._right_f, logic._right_l, query, short_name
),
}


def _serialize_filter_or_logic(filter, logic, query, short_name):
if filter:
return _serialize_filter(filter, query, short_name)
elif logic:
return _serialize_logic(logic, query, short_name)
else:
return None


def _serialize_filter(filter, query, short_name):
if isinstance(filter._value, datetime.datetime):
filter_value = filter._value.strftime("%Y-%m-%d %H:%M:%S")
else:
filter_value = filter._value

return {
"type": "filter",
"condition": filter._condition,
"value": filter_value,
"feature": _serialize_feature_name(filter._feature, query, short_name),
}


def _serialize_feature_name(feature, query, short_name):
if short_name:
return feature.name
fg = query._get_featuregroup_by_feature(feature)
fg_name = _serialize_featuregroup_name(fg)
return f"{fg_name}.{feature.name}"


def _translate_to_duckdb(query, query_str):
translated = query_str
for fg in query.featuregroups:
translated = translated.replace(
f"`{fg.feature_store_name}`.`",
f"`{fg._get_project_name()}.",
)
return translated.replace("`", '"')


def supports(featuregroups):
if len(featuregroups) > sum(
1
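The roughly 230 deleted lines (`create_query_object`, `_serialize_featuregroup_name`/`_connector`, `_get_connector_options`, the filter serializers, and `_translate_to_duckdb`) were the client-side construction of the HQS command, including connector credentials and DuckDB identifier quoting; all of that now happens in the backend. For reference, the removed `_translate_to_duckdb` boiled down to a string rewrite; isolated as a standalone function (with `prefix_map` standing in for the per-feature-group loop over `query.featuregroups`):

```python
def translate_to_duckdb(query_str: str, prefix_map: dict) -> str:
    # Rewrite `feature_store`.`fg_name` prefixes into `project.fg_name` form...
    for hive_prefix, duckdb_prefix in prefix_map.items():
        query_str = query_str.replace(hive_prefix, duckdb_prefix)
    # ...then switch Hive backtick quoting to DuckDB double quotes.
    return query_str.replace("`", '"')
```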
13 changes: 11 additions & 2 deletions python/hsfs/core/query_constructor_api.py
@@ -20,12 +20,21 @@


class QueryConstructorApi:
def construct_query(self, query):
def construct_query(self, query, hqs=False):
_client = client.get_instance()
path_params = ["project", _client._project_id, "featurestores", "query"]

query_params = {
"hqs": hqs,
}

headers = {"content-type": "application/json"}
return fs_query.FsQuery.from_response_json(
_client._send_request(
"PUT", path_params, headers=headers, data=query.json()
"PUT",
path_params,
headers=headers,
query_params=query_params,
data=query.json(),
)
)
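`construct_query` keeps the same `PUT /project/{id}/featurestores/query` endpoint and simply threads `hqs` through as a query parameter, so one call can return either plain SQL strings or the signed HQS payload. Illustrative usage, assuming an already-initialized client session (`q` is any `Query` object):

```python
api = QueryConstructorApi()

fs_query = api.construct_query(q)             # classic: SQL strings populated
hqs_query = api.construct_query(q, hqs=True)  # HQS: backend-built payload

assert hqs_query.hqs_payload is not None      # opaque command for Arrow Flight
# Signature may be absent; read_query() only sends the header when it is set.
print(hqs_query.hqs_payload_signature)
```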