Source code for feast.infra.offline_stores.redshift_source

from typing import Callable, Dict, Iterable, Optional, Tuple

from feast import type_map
from feast.data_source import DataSource
from feast.errors import DataSourceNotFoundException, RedshiftCredentialsError
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.repo_config import RepoConfig
from feast.value_type import ValueType


class RedshiftSource(DataSource):
    """Batch data source backed by a Redshift table or a Redshift SQL query.

    The table/query pair is kept in a ``RedshiftOptions`` object, while the
    timestamp columns, field mapping and date partition column are handed to
    the ``DataSource`` base class.
    """

    def __init__(
        self,
        event_timestamp_column: Optional[str] = "",
        table: Optional[str] = None,
        created_timestamp_column: Optional[str] = "",
        field_mapping: Optional[Dict[str, str]] = None,
        date_partition_column: Optional[str] = "",
        query: Optional[str] = None,
    ):
        """
        Creates a RedshiftSource.

        Args:
            event_timestamp_column: Forwarded to ``DataSource.__init__``.
            table: Name of the Redshift table to read from.
            created_timestamp_column: Forwarded to ``DataSource.__init__``.
            field_mapping: Forwarded to ``DataSource.__init__``.
            date_partition_column: Forwarded to ``DataSource.__init__``.
            query: Redshift SQL query to read from when no table is given.
        """
        super().__init__(
            event_timestamp_column,
            created_timestamp_column,
            field_mapping,
            date_partition_column,
        )
        self._redshift_options = RedshiftOptions(table=table, query=query)

    @staticmethod
    def from_proto(data_source: DataSourceProto):
        """
        Creates a RedshiftSource from its protobuf representation.

        Args:
            data_source: A protobuf representation of a data source whose
                ``redshift_options`` field carries the table and query.

        Returns:
            A RedshiftSource built from the fields of the protobuf.
        """
        return RedshiftSource(
            field_mapping=dict(data_source.field_mapping),
            table=data_source.redshift_options.table,
            event_timestamp_column=data_source.event_timestamp_column,
            created_timestamp_column=data_source.created_timestamp_column,
            date_partition_column=data_source.date_partition_column,
            query=data_source.redshift_options.query,
        )

    def __eq__(self, other):
        """
        Compares table, query, both timestamp columns and the field mapping.

        Note that ``date_partition_column`` does not take part in the
        comparison.

        Raises:
            TypeError: If ``other`` is not a RedshiftSource.
        """
        if not isinstance(other, RedshiftSource):
            raise TypeError(
                "Comparisons should only involve RedshiftSource class objects."
            )

        return (
            self.redshift_options.table == other.redshift_options.table
            and self.redshift_options.query == other.redshift_options.query
            and self.event_timestamp_column == other.event_timestamp_column
            and self.created_timestamp_column == other.created_timestamp_column
            and self.field_mapping == other.field_mapping
        )

    @property
    def table(self):
        """
        Returns the table name stored in the Redshift options
        """
        return self._redshift_options.table

    @property
    def query(self):
        """
        Returns the SQL query stored in the Redshift options
        """
        return self._redshift_options.query

    @property
    def redshift_options(self):
        """
        Returns the Redshift options of this data source
        """
        return self._redshift_options

    @redshift_options.setter
    def redshift_options(self, _redshift_options):
        """
        Sets the Redshift options of this data source
        """
        self._redshift_options = _redshift_options

    def to_proto(self) -> DataSourceProto:
        """
        Converts this RedshiftSource to its protobuf representation.

        Returns:
            A DataSourceProto of type ``BATCH_REDSHIFT``.
        """
        data_source_proto = DataSourceProto(
            type=DataSourceProto.BATCH_REDSHIFT,
            field_mapping=self.field_mapping,
            redshift_options=self.redshift_options.to_proto(),
        )

        data_source_proto.event_timestamp_column = self.event_timestamp_column
        data_source_proto.created_timestamp_column = self.created_timestamp_column
        data_source_proto.date_partition_column = self.date_partition_column

        return data_source_proto

    def validate(self, config: RepoConfig):
        """
        Validates the source by fetching its column names and types.

        Args:
            config: Repo configuration passed through to
                ``get_table_column_names_and_types``.
        """
        # As long as the query gets successfully executed, or the table exists,
        # the data source is validated. We don't need the results though.
        self.get_table_column_names_and_types(config)

    def get_table_query_string(self) -> str:
        """Returns a string that can directly be used to reference this table in SQL"""
        if self.table:
            return f'"{self.table}"'
        else:
            return f"({self.query})"

    @staticmethod
    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
        """Returns the callable that maps a Redshift type name to a Feast ValueType"""
        return type_map.redshift_to_feast_value_type

    def get_table_column_names_and_types(
        self, config: RepoConfig
    ) -> Iterable[Tuple[str, str]]:
        """
        Returns the (column name, upper-cased type name) pairs of this source.

        When a table is set, the table is described through the Redshift data
        client. Otherwise the query is executed with ``LIMIT 1`` and the column
        metadata of the result is used.

        Args:
            config: Repo configuration; its ``offline_store`` must be a
                ``RedshiftOfflineStoreConfig``.

        Returns:
            A list of ``(name, type_name)`` tuples, one per column.

        Raises:
            RedshiftCredentialsError: If describing the table fails with a
                ``ValidationException`` client error.
            DataSourceNotFoundException: If the described table has no columns.
        """
        from botocore.exceptions import ClientError

        from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig
        from feast.infra.utils import aws_utils

        assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)

        client = aws_utils.get_redshift_data_client(config.offline_store.region)

        # Truthiness (not ``is not None``) keeps this in line with
        # get_table_query_string: a query-only source rebuilt by from_proto
        # carries table == "" because unset protobuf string fields read back
        # as the empty string, and it must take the query branch.
        if self.table:
            try:
                table = client.describe_table(
                    ClusterIdentifier=config.offline_store.cluster_id,
                    Database=config.offline_store.database,
                    DbUser=config.offline_store.user,
                    Table=self.table,
                )
            except ClientError as e:
                if e.response["Error"]["Code"] == "ValidationException":
                    raise RedshiftCredentialsError() from e
                raise

            # The API returns valid JSON with empty column list when the table doesn't exist
            if not table["ColumnList"]:
                raise DataSourceNotFoundException(self.table)

            columns = table["ColumnList"]
        else:
            statement_id = aws_utils.execute_redshift_statement(
                client,
                config.offline_store.cluster_id,
                config.offline_store.database,
                config.offline_store.user,
                f"SELECT * FROM ({self.query}) LIMIT 1",
            )
            columns = aws_utils.get_redshift_statement_result(client, statement_id)[
                "ColumnMetadata"
            ]

        return [(column["name"], column["typeName"].upper()) for column in columns]
class RedshiftOptions:
    """
    Holder for the Redshift table name and SQL query of a data source
    """

    def __init__(self, table: Optional[str], query: Optional[str]):
        self._query = query
        self._table = table

    @property
    def table(self):
        """
        Returns the table name held by these options
        """
        return self._table

    @table.setter
    def table(self, value):
        """
        Replaces the table name held by these options
        """
        self._table = value

    @property
    def query(self):
        """
        Returns the Redshift SQL query held by these options
        """
        return self._query

    @query.setter
    def query(self, value):
        """
        Replaces the Redshift SQL query held by these options
        """
        self._query = value

    @classmethod
    def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
        """
        Builds a RedshiftOptions from its protobuf representation

        Args:
            redshift_options_proto: Protobuf message whose ``table`` and
                ``query`` fields are copied into the new object

        Returns:
            A RedshiftOptions carrying the table and query of the protobuf
        """
        return cls(
            table=redshift_options_proto.table,
            query=redshift_options_proto.query,
        )

    def to_proto(self) -> DataSourceProto.RedshiftOptions:
        """
        Converts this RedshiftOptions object to its protobuf representation.

        Returns:
            A DataSourceProto.RedshiftOptions carrying the table and query
        """
        return DataSourceProto.RedshiftOptions(
            table=self.table,
            query=self.query,
        )