Source code for feast.infra.offline_stores.snowflake_source

from typing import Callable, Dict, Iterable, Optional, Tuple

from feast import type_map
from feast.data_source import DataSource
from feast.errors import DataSourceNoNameException
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.protos.feast.core.SavedDataset_pb2 import (
    SavedDatasetStorage as SavedDatasetStorageProto,
)
from feast.repo_config import RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.value_type import ValueType


class SnowflakeSource(DataSource):
    """Batch data source that reads features from a Snowflake table or SQL query."""

    def __init__(
        self,
        name: Optional[str] = None,
        database: Optional[str] = None,
        schema: Optional[str] = None,
        table: Optional[str] = None,
        query: Optional[str] = None,
        event_timestamp_column: Optional[str] = "",
        created_timestamp_column: Optional[str] = "",
        field_mapping: Optional[Dict[str, str]] = None,
        date_partition_column: Optional[str] = "",
    ):
        """
        Creates a SnowflakeSource object.

        Args:
            name (optional): Name for the source. Defaults to the table if not specified.
            database (optional): Snowflake database where the features are stored.
            schema (optional): Snowflake schema in which the table is located.
            table (optional): Snowflake table where the features are stored.
            query (optional): The query to be executed to obtain the features.
            event_timestamp_column (optional): Event timestamp column used for point in
                time joins of feature values.
            created_timestamp_column (optional): Timestamp column indicating when the
                row was created, used for deduplicating rows.
            field_mapping (optional): A dictionary mapping of column names in this data
                source to column names in a feature table or view.
            date_partition_column (optional): Timestamp column used for partitioning.

        Raises:
            ValueError: If neither ``table`` nor ``query`` is provided.
            DataSourceNoNameException: If no ``name`` is given and there is no
                ``table`` to fall back on.
        """
        if table is None and query is None:
            # Either a table or a query is sufficient, so the message names both.
            raise ValueError('No "table" or "query" argument provided.')

        # If no name, use the table as the default name.
        _name = name or table
        if not _name:
            raise DataSourceNoNameException()

        super().__init__(
            _name,
            event_timestamp_column,
            created_timestamp_column,
            field_mapping,
            date_partition_column,
        )

        # The default Snowflake schema is named "PUBLIC".
        _schema = "PUBLIC" if (database and table and not schema) else schema

        self.snowflake_options = SnowflakeOptions(
            database=database, schema=_schema, table=table, query=query
        )

    @staticmethod
    def from_proto(data_source: DataSourceProto):
        """
        Creates a SnowflakeSource from a protobuf representation of a SnowflakeSource.

        Args:
            data_source: A protobuf representation of a SnowflakeSource

        Returns:
            A SnowflakeSource object based on the data_source protobuf.
        """
        return SnowflakeSource(
            field_mapping=dict(data_source.field_mapping),
            database=data_source.snowflake_options.database,
            schema=data_source.snowflake_options.schema,
            table=data_source.snowflake_options.table,
            event_timestamp_column=data_source.event_timestamp_column,
            created_timestamp_column=data_source.created_timestamp_column,
            date_partition_column=data_source.date_partition_column,
            query=data_source.snowflake_options.query,
        )

    # Note: Python requires redefining hash in child classes that override __eq__
    def __hash__(self):
        return super().__hash__()

    def __eq__(self, other):
        """
        Compares name, Snowflake options, timestamp columns and field mapping.

        Note that ``date_partition_column`` is not part of the comparison.

        Raises:
            TypeError: If ``other`` is not a SnowflakeSource.
        """
        if not isinstance(other, SnowflakeSource):
            raise TypeError(
                "Comparisons should only involve SnowflakeSource class objects."
            )

        return (
            self.name == other.name
            and self.snowflake_options.database == other.snowflake_options.database
            and self.snowflake_options.schema == other.snowflake_options.schema
            and self.snowflake_options.table == other.snowflake_options.table
            and self.snowflake_options.query == other.snowflake_options.query
            and self.event_timestamp_column == other.event_timestamp_column
            and self.created_timestamp_column == other.created_timestamp_column
            and self.field_mapping == other.field_mapping
        )

    @property
    def database(self):
        """Returns the database of this snowflake source."""
        return self.snowflake_options.database

    @property
    def schema(self):
        """Returns the schema of this snowflake source."""
        return self.snowflake_options.schema

    @property
    def table(self):
        """Returns the table of this snowflake source."""
        return self.snowflake_options.table

    @property
    def query(self):
        """Returns the query of this snowflake source."""
        return self.snowflake_options.query

    def to_proto(self) -> DataSourceProto:
        """
        Converts a SnowflakeSource object to its protobuf representation.

        Returns:
            A DataSourceProto object.
        """
        data_source_proto = DataSourceProto(
            type=DataSourceProto.BATCH_SNOWFLAKE,
            field_mapping=self.field_mapping,
            snowflake_options=self.snowflake_options.to_proto(),
        )

        data_source_proto.event_timestamp_column = self.event_timestamp_column
        data_source_proto.created_timestamp_column = self.created_timestamp_column
        data_source_proto.date_partition_column = self.date_partition_column

        return data_source_proto

    def validate(self, config: RepoConfig):
        """
        Validates the source by fetching its column names and types.

        Args:
            config: A RepoConfig describing the feature repo
        """
        # As long as the query gets successfully executed, or the table exists,
        # the data source is validated. We don't need the results though.
        self.get_table_column_names_and_types(config)

    def get_table_query_string(self) -> str:
        """Returns a string that can directly be used to reference this table in SQL."""
        if self.database and self.table:
            return f'"{self.database}"."{self.schema}"."{self.table}"'
        elif self.table:
            return f'"{self.table}"'
        else:
            # No table configured: wrap the query so it can be used as a subquery.
            return f"({self.query})"

    @staticmethod
    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
        """Returns the function that maps a source column type name to a Feast ValueType."""
        return type_map.snowflake_python_type_to_feast_value_type

    def get_table_column_names_and_types(
        self, config: RepoConfig
    ) -> Iterable[Tuple[str, str]]:
        """
        Returns the (column name, dtype string) pairs for this snowflake source.

        A single row is selected from the source and the column types are read
        from the resulting pandas DataFrame.

        Args:
            config: A RepoConfig describing the feature repo

        Returns:
            A list of (column name, pandas dtype as string) tuples.

        Raises:
            ValueError: If the source returns no rows.
        """
        # Imported here rather than at module level, as in the original code.
        from feast.infra.offline_stores.snowflake import SnowflakeOfflineStoreConfig
        from feast.infra.utils.snowflake_utils import (
            execute_snowflake_statement,
            get_snowflake_conn,
        )

        assert isinstance(config.offline_store, SnowflakeOfflineStoreConfig)

        snowflake_conn = get_snowflake_conn(config.offline_store)

        # Reuse the shared table reference so both code paths always agree.
        query = f"SELECT * FROM {self.get_table_query_string()} LIMIT 1"

        result = execute_snowflake_statement(snowflake_conn, query).fetch_pandas_all()

        if result.empty:
            raise ValueError("The following source:\n" + query + "\n ... is empty")

        metadata = result.dtypes.apply(str)
        return list(zip(metadata.index, metadata))
class SnowflakeOptions:
    """
    Holds the Snowflake location (database, schema, table) and/or SQL query
    that a data source reads its features from.
    """

    def __init__(
        self,
        database: Optional[str],
        schema: Optional[str],
        table: Optional[str],
        query: Optional[str],
    ):
        self._query = query
        self._table = table
        self._schema = schema
        self._database = database

    @property
    def query(self):
        """The Snowflake SQL query referenced by this source."""
        return self._query

    @query.setter
    def query(self, query):
        """Replaces the Snowflake SQL query referenced by this source."""
        self._query = query

    @property
    def database(self):
        """The database name of this Snowflake table."""
        return self._database

    @database.setter
    def database(self, database):
        """Replaces the database name of this Snowflake table."""
        self._database = database

    @property
    def schema(self):
        """The schema name of this Snowflake table."""
        return self._schema

    @schema.setter
    def schema(self, schema):
        """Replaces the schema name of this Snowflake table."""
        self._schema = schema

    @property
    def table(self):
        """The table name of this Snowflake table."""
        return self._table

    @table.setter
    def table(self, table):
        """Replaces the table name of this Snowflake table."""
        self._table = table

    @classmethod
    def from_proto(cls, snowflake_options_proto: DataSourceProto.SnowflakeOptions):
        """
        Builds a SnowflakeOptions from its protobuf representation.

        Args:
            snowflake_options_proto: A protobuf representation of the snowflake options

        Returns:
            A SnowflakeOptions object carrying the proto's database, schema,
            table and query values.
        """
        proto = snowflake_options_proto
        return cls(
            database=proto.database,
            schema=proto.schema,
            table=proto.table,
            query=proto.query,
        )

    def to_proto(self) -> DataSourceProto.SnowflakeOptions:
        """
        Converts this SnowflakeOptions object to its protobuf representation.

        Returns:
            A DataSourceProto.SnowflakeOptions protobuf.
        """
        return DataSourceProto.SnowflakeOptions(
            database=self.database,
            schema=self.schema,
            table=self.table,
            query=self.query,
        )


class SavedDatasetSnowflakeStorage(SavedDatasetStorage):
    """Saved dataset storage that points at a Snowflake table."""

    _proto_attr_name = "snowflake_storage"

    snowflake_options: SnowflakeOptions

    def __init__(self, table_ref: str):
        """
        Args:
            table_ref: Name of the Snowflake table holding the saved dataset.
        """
        # Only the table is tracked; database, schema and query are left unset.
        self.snowflake_options = SnowflakeOptions(
            database=None, schema=None, table=table_ref, query=None
        )

    @staticmethod
    def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage:
        """Builds the storage from a proto, keeping only the table reference."""
        options = SnowflakeOptions.from_proto(storage_proto.snowflake_storage)
        return SavedDatasetSnowflakeStorage(table_ref=options.table)

    def to_proto(self) -> SavedDatasetStorageProto:
        """Converts this storage to its SavedDatasetStorageProto representation."""
        options_proto = self.snowflake_options.to_proto()
        return SavedDatasetStorageProto(snowflake_storage=options_proto)

    def to_data_source(self) -> DataSource:
        """Returns a SnowflakeSource that reads from the stored table."""
        table_name = self.snowflake_options.table
        return SnowflakeSource(table=table_name)