from typing import Callable, Dict, Iterable, Optional, Tuple
from feast import type_map
from feast.data_source import DataSource
from feast.errors import DataSourceNoNameException, DataSourceNotFoundException
from feast.feature_logging import LoggingDestination
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.protos.feast.core.FeatureService_pb2 import (
LoggingConfig as LoggingConfigProto,
)
from feast.protos.feast.core.SavedDataset_pb2 import (
SavedDatasetStorage as SavedDatasetStorageProto,
)
from feast.repo_config import RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.value_type import ValueType
class AthenaSource(DataSource):
    """Batch data source (``BATCH_ATHENA``) backed by an Athena table or query."""

    def __init__(
        self,
        *,
        timestamp_field: Optional[str] = "",
        table: Optional[str] = None,
        database: Optional[str] = None,
        data_source: Optional[str] = None,
        created_timestamp_column: Optional[str] = None,
        field_mapping: Optional[Dict[str, str]] = None,
        date_partition_column: Optional[str] = None,
        query: Optional[str] = None,
        name: Optional[str] = None,
        description: Optional[str] = "",
        tags: Optional[Dict[str, str]] = None,
        owner: Optional[str] = "",
    ):
        """
        Creates an AthenaSource object.

        Args:
            timestamp_field (optional): Event timestamp column.
            table (optional): Athena table where the features are stored. At least one
                of 'table' and 'query' must be specified; if both are set, the table is
                what get_table_query_string references.
            database (optional): Athena database name. Defaults to "default" when a
                table is given without a database.
            data_source (optional): Athena data source (used as the catalog name when
                looking up table metadata).
            created_timestamp_column (optional): Timestamp column indicating when the
                row was created, used for deduplicating rows.
            field_mapping (optional): A dictionary mapping of column names in this data
                source to column names in a feature table or view.
            date_partition_column (optional): Timestamp column used for partitioning.
            query (optional): The query to be executed to obtain the features.
            name (optional): Name for the source. Defaults to the table if not
                specified, in which case the table must be specified.
            description (optional): A human-readable description.
            tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
            owner (optional): The owner of the athena source, typically the email of the
                primary maintainer.

        Raises:
            ValueError: If neither 'table' nor 'query' is provided.
            DataSourceNoNameException: If no non-empty name can be derived from 'name'
                or 'table'.
        """
        if table is None and query is None:
            raise ValueError('No "table" or "query" argument provided.')

        # If no name is given, fall back to the table name. This is an explicit
        # raise rather than an assert so the check survives `python -O` and callers
        # get the dedicated exception type.
        _name = name or table
        if not _name:
            raise DataSourceNoNameException()

        # A table without an explicit database is looked up in "default".
        _database = "default" if table and not database else database
        self.athena_options = AthenaOptions(
            table=table, query=query, database=_database, data_source=data_source
        )

        super().__init__(
            name=_name,
            timestamp_field=timestamp_field,
            created_timestamp_column=created_timestamp_column,
            field_mapping=field_mapping,
            date_partition_column=date_partition_column,
            description=description,
            tags=tags,
            owner=owner,
        )

    @staticmethod
    def from_proto(data_source: DataSourceProto):
        """
        Creates an AthenaSource from a protobuf representation of an AthenaSource.

        Args:
            data_source: A protobuf representation of an AthenaSource.

        Returns:
            An AthenaSource object based on the data_source protobuf.
        """
        return AthenaSource(
            name=data_source.name,
            timestamp_field=data_source.timestamp_field,
            table=data_source.athena_options.table,
            database=data_source.athena_options.database,
            data_source=data_source.athena_options.data_source,
            created_timestamp_column=data_source.created_timestamp_column,
            field_mapping=dict(data_source.field_mapping),
            date_partition_column=data_source.date_partition_column,
            query=data_source.athena_options.query,
            description=data_source.description,
            tags=dict(data_source.tags),
            # Restore the owner so that to_proto -> from_proto is lossless.
            owner=data_source.owner,
        )

    # Note: Python requires redefining hash in child classes that override __eq__
    def __hash__(self):
        return super().__hash__()

    def __eq__(self, other):
        # Comparing against a non-AthenaSource raises TypeError.
        if not isinstance(other, AthenaSource):
            raise TypeError(
                "Comparisons should only involve AthenaSource class objects."
            )

        return (
            super().__eq__(other)
            and self.athena_options.table == other.athena_options.table
            and self.athena_options.query == other.athena_options.query
            and self.athena_options.database == other.athena_options.database
            and self.athena_options.data_source == other.athena_options.data_source
        )

    @property
    def table(self):
        """Returns the table of this Athena source."""
        return self.athena_options.table

    @property
    def database(self):
        """Returns the database of this Athena source."""
        return self.athena_options.database

    @property
    def query(self):
        """Returns the Athena query of this Athena source."""
        return self.athena_options.query

    @property
    def data_source(self):
        """Returns the Athena data_source of this Athena source."""
        return self.athena_options.data_source

    def to_proto(self) -> DataSourceProto:
        """
        Converts an AthenaSource object to its protobuf representation.

        Returns:
            A DataSourceProto object.
        """
        data_source_proto = DataSourceProto(
            type=DataSourceProto.BATCH_ATHENA,
            name=self.name,
            timestamp_field=self.timestamp_field,
            created_timestamp_column=self.created_timestamp_column,
            field_mapping=self.field_mapping,
            date_partition_column=self.date_partition_column,
            description=self.description,
            tags=self.tags,
            # Serialize the owner; previously it was silently dropped.
            owner=self.owner,
            athena_options=self.athena_options.to_proto(),
        )

        return data_source_proto

    def validate(self, config: RepoConfig):
        """Validates the source by fetching its column names and types.

        Any exception raised by that lookup propagates to the caller.
        """
        # As long as the query gets successfully executed, or the table exists,
        # the data source is validated. We don't need the results though.
        self.get_table_column_names_and_types(config)

    def get_table_query_string(self, config: Optional[RepoConfig] = None) -> str:
        """Returns a string that can directly be used to reference this table in SQL.

        For a table-based source, a given ``config`` overrides this source's own
        data source and database with those of ``config.offline_store``. For a
        query-based source, the query is returned wrapped in parentheses.
        """
        if self.table:
            data_source = self.data_source
            database = self.database
            if config:
                data_source = config.offline_store.data_source
                database = config.offline_store.database
            return f'"{data_source}"."{database}"."{self.table}"'
        else:
            return f"({self.query})"

    @staticmethod
    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
        """Returns the function mapping Athena type names to Feast ValueTypes."""
        return type_map.athena_to_feast_value_type

    def get_table_column_names_and_types(
        self, config: RepoConfig
    ) -> Iterable[Tuple[str, str]]:
        """
        Returns a mapping of column names to types for this Athena source.

        Args:
            config: A RepoConfig describing the feature repo. Its offline store must
                be an AthenaOfflineStoreConfig.

        Returns:
            A list of (column name, upper-cased column type) pairs.

        Raises:
            aws_utils.AthenaError: If the table metadata lookup fails.
            DataSourceNotFoundException: If the table has no columns (i.e. it does
                not exist).
        """
        from botocore.exceptions import ClientError

        from feast.infra.offline_stores.contrib.athena_offline_store.athena import (
            AthenaOfflineStoreConfig,
        )
        from feast.infra.utils import aws_utils

        assert isinstance(config.offline_store, AthenaOfflineStoreConfig)

        client = aws_utils.get_athena_data_client(config.offline_store.region)
        if self.table:
            # AthenaOptions stores unset values as "". Fall back to the offline
            # store's data source / database so an empty CatalogName is never sent.
            catalog = self.data_source or config.offline_store.data_source
            database = self.database or config.offline_store.database
            try:
                table = client.get_table_metadata(
                    CatalogName=catalog,
                    DatabaseName=database,
                    TableName=self.table,
                )
            except ClientError as e:
                raise aws_utils.AthenaError(e) from e

            columns = table["TableMetadata"]["Columns"]
            # The API returns valid JSON with empty column list when the table
            # doesn't exist.
            if not columns:
                raise DataSourceNotFoundException(self.table)
        else:
            statement_id = aws_utils.execute_athena_query(
                client,
                config.offline_store.data_source,
                config.offline_store.database,
                config.offline_store.workgroup,
                f"SELECT * FROM ({self.query}) LIMIT 1",
            )
            columns = aws_utils.get_athena_query_result(client, statement_id)[
                "ResultSetMetadata"
            ]["ColumnInfo"]

        return [(column["Name"], column["Type"].upper()) for column in columns]
class AthenaOptions:
    """
    Holds the table, query, database and data-source settings of an Athena source.

    Every setting that is passed as a falsy value (e.g. ``None``) is stored as an
    empty string.
    """

    def __init__(
        self,
        table: Optional[str],
        query: Optional[str],
        database: Optional[str],
        data_source: Optional[str],
    ):
        self.table = table if table else ""
        self.query = query if query else ""
        self.database = database if database else ""
        self.data_source = data_source if data_source else ""

    @classmethod
    def from_proto(cls, athena_options_proto: DataSourceProto.AthenaOptions):
        """
        Builds an AthenaOptions instance from its protobuf message.

        Args:
            athena_options_proto: The AthenaOptions message of a DataSource proto.

        Returns:
            An AthenaOptions object carrying the message's four settings.
        """
        proto = athena_options_proto
        return cls(
            table=proto.table,
            query=proto.query,
            database=proto.database,
            data_source=proto.data_source,
        )

    def to_proto(self) -> DataSourceProto.AthenaOptions:
        """
        Serializes these options into a protobuf message.

        Returns:
            A DataSourceProto.AthenaOptions message.
        """
        settings = dict(
            table=self.table,
            query=self.query,
            database=self.database,
            data_source=self.data_source,
        )
        return DataSourceProto.AthenaOptions(**settings)
class SavedDatasetAthenaStorage(SavedDatasetStorage):
    """Saved-dataset storage located in Athena, serialized through AthenaOptions."""

    _proto_attr_name = "athena_storage"

    athena_options: AthenaOptions

    def __init__(
        self,
        table_ref: str,
        query: Optional[str] = None,
        database: Optional[str] = None,
        data_source: Optional[str] = None,
    ):
        """
        Args:
            table_ref: Athena table holding the saved dataset.
            query (optional): Athena query associated with the dataset.
            database (optional): Athena database name.
            data_source (optional): Athena data source name.
        """
        self.athena_options = AthenaOptions(
            table=table_ref, query=query, database=database, data_source=data_source
        )

    @staticmethod
    def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage:
        """Rebuilds the storage from its proto, restoring every AthenaOptions field.

        The previous version kept only the table, so database / data_source / query
        written by to_proto were lost on a round-trip.
        """
        options = AthenaOptions.from_proto(storage_proto.athena_storage)
        return SavedDatasetAthenaStorage(
            table_ref=options.table,
            query=options.query,
            database=options.database,
            data_source=options.data_source,
        )

    def to_proto(self) -> SavedDatasetStorageProto:
        """Serializes the storage's AthenaOptions into a SavedDatasetStorage proto."""
        return SavedDatasetStorageProto(athena_storage=self.athena_options.to_proto())

    def to_data_source(self) -> DataSource:
        """Returns an AthenaSource pointing at the stored table.

        The database and data source are forwarded so the resulting source targets
        the same location. Empty values become None, which lets AthenaSource apply
        its own defaults.
        """
        return AthenaSource(
            table=self.athena_options.table,
            database=self.athena_options.database or None,
            data_source=self.athena_options.data_source or None,
        )
class AthenaLoggingDestination(LoggingDestination):
    """Logging destination identified by the name of an Athena table."""

    _proto_kind = "athena_destination"

    table_name: str

    def __init__(self, *, table_name: str):
        self.table_name = table_name

    @classmethod
    def from_proto(cls, config_proto: LoggingConfigProto) -> "LoggingDestination":
        """Builds the destination from the ``athena_destination`` of a LoggingConfig."""
        destination_proto = config_proto.athena_destination
        return AthenaLoggingDestination(table_name=destination_proto.table_name)

    def to_proto(self) -> LoggingConfigProto:
        """Wraps the table name in a LoggingConfig proto."""
        destination = LoggingConfigProto.AthenaDestination(table_name=self.table_name)
        return LoggingConfigProto(athena_destination=destination)

    def to_data_source(self) -> DataSource:
        """Returns an AthenaSource reading from the destination table."""
        source = AthenaSource(table=self.table_name)
        return source