# Source code for feast.data_source

# Copyright 2020 The Feast Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import warnings
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

from google.protobuf.duration_pb2 import Duration
from google.protobuf.json_format import MessageToJson
from typeguard import typechecked

from feast import type_map
from feast.data_format import StreamFormat
from feast.field import Field
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.repo_config import RepoConfig, get_data_source_class_from_type
from feast.types import from_value_type
from feast.value_type import ValueType

class KafkaOptions:
    """
    DataSource Kafka options used to source features from Kafka messages
    """

    def __init__(
        self,
        kafka_bootstrap_servers: str,
        message_format: StreamFormat,
        topic: str,
        watermark_delay_threshold: Optional[timedelta] = None,
    ):
        """
        Stores the Kafka connection and decoding settings.

        Args:
            kafka_bootstrap_servers: Broker address(es) to connect to.
            message_format: StreamFormat of serialized messages.
            topic: Name of the topic to read from.
            watermark_delay_threshold (optional): How late stream data may arrive.
                Any falsy value (None or a zero timedelta) is stored as None.
        """
        self.kafka_bootstrap_servers = kafka_bootstrap_servers
        self.message_format = message_format
        self.topic = topic
        # A zero delay is deliberately collapsed to None ("not set").
        self.watermark_delay_threshold = (
            watermark_delay_threshold if watermark_delay_threshold else None
        )

    @classmethod
    def from_proto(cls, kafka_options_proto: DataSourceProto.KafkaOptions):
        """
        Creates a KafkaOptions from a protobuf representation of a kafka option

        Args:
            kafka_options_proto: A protobuf representation of a DataSource

        Returns:
            Returns a KafkaOptions object based on the kafka_options protobuf
        """
        threshold: Optional[timedelta] = None
        if kafka_options_proto.HasField("watermark_delay_threshold"):
            duration_proto = kafka_options_proto.watermark_delay_threshold
            if duration_proto.ToNanoseconds() == 0:
                threshold = timedelta(days=0)
            else:
                threshold = duration_proto.ToTimedelta()

        return cls(
            kafka_bootstrap_servers=kafka_options_proto.kafka_bootstrap_servers,
            message_format=StreamFormat.from_proto(kafka_options_proto.message_format),
            topic=kafka_options_proto.topic,
            watermark_delay_threshold=threshold,
        )

    def to_proto(self) -> DataSourceProto.KafkaOptions:
        """
        Converts an KafkaOptionsProto object to its protobuf representation.

        Returns:
            KafkaOptionsProto protobuf
        """
        threshold_proto = None
        if self.watermark_delay_threshold is not None:
            threshold_proto = Duration()
            threshold_proto.FromTimedelta(self.watermark_delay_threshold)

        return DataSourceProto.KafkaOptions(
            kafka_bootstrap_servers=self.kafka_bootstrap_servers,
            message_format=self.message_format.to_proto(),
            topic=self.topic,
            watermark_delay_threshold=threshold_proto,
        )
class KinesisOptions:
    """
    DataSource Kinesis options used to source features from Kinesis records
    """

    def __init__(
        self,
        record_format: StreamFormat,
        region: str,
        stream_name: str,
    ):
        """
        Stores the Kinesis stream location and record decoding format.

        Args:
            record_format: StreamFormat of the serialized records.
            region: Region of the Kinesis stream.
            stream_name: Name of the Kinesis stream.
        """
        self.record_format = record_format
        self.region = region
        self.stream_name = stream_name

    @classmethod
    def from_proto(cls, kinesis_options_proto: DataSourceProto.KinesisOptions):
        """
        Creates a KinesisOptions from a protobuf representation of a kinesis option

        Args:
            kinesis_options_proto: A protobuf representation of a DataSource

        Returns:
            Returns a KinesisOptions object based on the kinesis_options protobuf
        """
        fmt = StreamFormat.from_proto(kinesis_options_proto.record_format)
        return cls(
            record_format=fmt,
            region=kinesis_options_proto.region,
            stream_name=kinesis_options_proto.stream_name,
        )

    def to_proto(self) -> DataSourceProto.KinesisOptions:
        """
        Converts an KinesisOptionsProto object to its protobuf representation.

        Returns:
            KinesisOptionsProto protobuf
        """
        fmt_proto = self.record_format.to_proto()
        return DataSourceProto.KinesisOptions(
            record_format=fmt_proto,
            region=self.region,
            stream_name=self.stream_name,
        )
# Maps each built-in DataSourceProto source type to the dotted import path of the
# Python class that deserializes it (consumed by DataSource.from_proto).
_DATA_SOURCE_OPTIONS = {
    # Batch (offline store) sources.
    DataSourceProto.SourceType.BATCH_FILE: "feast.infra.offline_stores.file_source.FileSource",
    DataSourceProto.SourceType.BATCH_BIGQUERY: "feast.infra.offline_stores.bigquery_source.BigQuerySource",
    DataSourceProto.SourceType.BATCH_REDSHIFT: "feast.infra.offline_stores.redshift_source.RedshiftSource",
    DataSourceProto.SourceType.BATCH_SNOWFLAKE: "feast.infra.offline_stores.snowflake_source.SnowflakeSource",
    DataSourceProto.SourceType.BATCH_TRINO: "feast.infra.offline_stores.contrib.trino_offline_store.trino_source.TrinoSource",
    DataSourceProto.SourceType.BATCH_SPARK: "feast.infra.offline_stores.contrib.spark_offline_store.spark_source.SparkSource",
    DataSourceProto.SourceType.BATCH_ATHENA: "feast.infra.offline_stores.contrib.athena_offline_store.athena_source.AthenaSource",
    # Stream sources defined in this module.
    DataSourceProto.SourceType.STREAM_KAFKA: "feast.data_source.KafkaSource",
    DataSourceProto.SourceType.STREAM_KINESIS: "feast.data_source.KinesisSource",
    # Request-time and push sources defined in this module.
    DataSourceProto.SourceType.REQUEST_SOURCE: "feast.data_source.RequestSource",
    DataSourceProto.SourceType.PUSH_SOURCE: "feast.data_source.PushSource",
}
@typechecked
class DataSource(ABC):
    """
    DataSource that can be used to source features.

    Args:
        name: Name of data source, which should be unique within a project
        timestamp_field (optional): Event timestamp field used for point-in-time joins of
            feature values.
        created_timestamp_column (optional): Timestamp column indicating when the row
            was created, used for deduplicating rows.
        field_mapping (optional): A dictionary mapping of column names in this data
            source to feature names in a feature table or view. Only used for feature
            columns, not entity or timestamp columns.
        description (optional): A human-readable description.
        tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
        owner (optional): The owner of the data source, typically the email of the primary
            maintainer.
        date_partition_column (optional): Timestamp column used for partitioning. Not supported by all offline stores.
    """

    name: str
    timestamp_field: str
    created_timestamp_column: str
    field_mapping: Dict[str, str]
    description: str
    tags: Dict[str, str]
    owner: str
    date_partition_column: str

    def __init__(
        self,
        *,
        name: str,
        timestamp_field: Optional[str] = None,
        created_timestamp_column: Optional[str] = None,
        field_mapping: Optional[Dict[str, str]] = None,
        description: Optional[str] = "",
        tags: Optional[Dict[str, str]] = None,
        owner: Optional[str] = "",
        date_partition_column: Optional[str] = None,
    ):
        """
        Creates a DataSource object.

        Args:
            name: Name of data source, which should be unique within a project.
            timestamp_field (optional): Event timestamp field used for point-in-time joins of
                feature values.
            created_timestamp_column (optional): Timestamp column indicating when the row
                was created, used for deduplicating rows.
            field_mapping (optional): A dictionary mapping of column names in this data
                source to feature names in a feature table or view. Only used for feature
                columns, not entity or timestamp columns.
            description (optional): A human-readable description.
            tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
            owner (optional): The owner of the data source, typically the email of the primary
                maintainer.
            date_partition_column (optional): Timestamp column used for partitioning. Not supported by all stores

        Raises:
            ValueError: timestamp_field is non-empty and equals created_timestamp_column.
        """
        self.name = name
        # Optional string/dict arguments are normalized to empty values so the
        # attributes always have the types declared on the class.
        self.timestamp_field = timestamp_field or ""
        self.created_timestamp_column = (
            created_timestamp_column if created_timestamp_column else ""
        )
        self.field_mapping = field_mapping if field_mapping else {}
        if (
            self.timestamp_field
            and self.timestamp_field == self.created_timestamp_column
        ):
            raise ValueError(
                "Please do not use the same column for 'timestamp_field' and 'created_timestamp_column'."
            )
        self.description = description or ""
        self.tags = tags or {}
        self.owner = owner or ""
        self.date_partition_column = (
            date_partition_column if date_partition_column else ""
        )

    def __hash__(self):
        return hash((self.name, self.timestamp_field))

    def __str__(self):
        return str(MessageToJson(self.to_proto()))

    def __eq__(self, other):
        if other is None:
            return False

        if not isinstance(other, DataSource):
            raise TypeError("Comparisons should only involve DataSource class objects.")

        if (
            self.name != other.name
            or self.timestamp_field != other.timestamp_field
            or self.created_timestamp_column != other.created_timestamp_column
            or self.field_mapping != other.field_mapping
            or self.date_partition_column != other.date_partition_column
            or self.description != other.description
            or self.tags != other.tags
            or self.owner != other.owner
        ):
            return False

        return True

    @staticmethod
    @abstractmethod
    def from_proto(data_source: DataSourceProto) -> Any:
        """
        Converts data source config in protobuf spec to a DataSource class object.

        Args:
            data_source: A protobuf representation of a DataSource.

        Returns:
            A DataSource class object.

        Raises:
            ValueError: The type of DataSource could not be identified.
        """
        data_source_type = data_source.type
        # A zero/unset type, or one that is neither built-in nor CUSTOM_SOURCE,
        # cannot be dispatched to a concrete class.
        if not data_source_type or (
            data_source_type not in _DATA_SOURCE_OPTIONS
            and data_source_type != DataSourceProto.SourceType.CUSTOM_SOURCE
        ):
            raise ValueError("Could not identify the source type being added.")

        # Custom sources carry their own class path; built-ins use the lookup table.
        if data_source_type == DataSourceProto.SourceType.CUSTOM_SOURCE:
            class_path = data_source.data_source_class_type
        else:
            class_path = _DATA_SOURCE_OPTIONS[data_source_type]
        source_class = get_data_source_class_from_type(class_path)
        return source_class.from_proto(data_source)

    @abstractmethod
    def to_proto(self) -> DataSourceProto:
        """
        Converts a DataSourceProto object to its protobuf representation.
        """
        raise NotImplementedError

    def validate(self, config: RepoConfig):
        """
        Validates the underlying data source.

        Args:
            config: Configuration object used to configure a feature store.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
        """
        Returns the callable method that returns Feast type given the raw column type.
        """
        raise NotImplementedError

    def get_table_column_names_and_types(
        self, config: RepoConfig
    ) -> Iterable[Tuple[str, str]]:
        """
        Returns the list of column names and raw column types.

        Args:
            config: Configuration object used to configure a feature store.
        """
        raise NotImplementedError

    def get_table_query_string(self) -> str:
        """
        Returns a string that can directly be used to reference this table in SQL.
        """
        raise NotImplementedError
@typechecked
class KafkaSource(DataSource):
    """
    A DataSource that reads stream data from a Kafka topic, optionally backed by a
    batch source.
    """

    def __init__(
        self,
        *,
        name: str,
        timestamp_field: str,
        message_format: StreamFormat,
        bootstrap_servers: Optional[str] = None,
        kafka_bootstrap_servers: Optional[str] = None,
        topic: Optional[str] = None,
        created_timestamp_column: Optional[str] = "",
        field_mapping: Optional[Dict[str, str]] = None,
        description: Optional[str] = "",
        tags: Optional[Dict[str, str]] = None,
        owner: Optional[str] = "",
        batch_source: Optional[DataSource] = None,
        watermark_delay_threshold: Optional[timedelta] = None,
    ):
        """
        Creates a KafkaSource object.

        Args:
            name: Name of data source, which should be unique within a project
            timestamp_field: Event timestamp field used for point-in-time joins of feature values.
            message_format: StreamFormat of serialized messages.
            bootstrap_servers: (Deprecated) The servers of the kafka broker in the form "localhost:9092".
            kafka_bootstrap_servers (optional): The servers of the kafka broker in the form "localhost:9092".
            topic (optional): The name of the topic to read from in the kafka source.
            created_timestamp_column (optional): Timestamp column indicating when the row
                was created, used for deduplicating rows.
            field_mapping (optional): A dictionary mapping of column names in this data
                source to feature names in a feature table or view. Only used for feature
                columns, not entity or timestamp columns.
            description (optional): A human-readable description.
            tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
            owner (optional): The owner of the data source, typically the email of the primary
                maintainer.
            batch_source (optional): The datasource that acts as a batch source.
            watermark_delay_threshold (optional): The watermark delay threshold for stream data.
                Specifically how late stream data can arrive without being discarded.
        """
        if bootstrap_servers:
            warnings.warn(
                (
                    "The 'bootstrap_servers' parameter has been deprecated in favor of 'kafka_bootstrap_servers'. "
                    "Feast 0.25 and onwards will not support the 'bootstrap_servers' parameter."
                ),
                DeprecationWarning,
            )

        super().__init__(
            name=name,
            timestamp_field=timestamp_field,
            created_timestamp_column=created_timestamp_column,
            field_mapping=field_mapping,
            description=description,
            tags=tags,
            owner=owner,
        )
        self.batch_source = batch_source

        # The new parameter wins over the deprecated one when both are given.
        kafka_bootstrap_servers = kafka_bootstrap_servers or bootstrap_servers or ""
        topic = topic or ""
        self.kafka_options = KafkaOptions(
            kafka_bootstrap_servers=kafka_bootstrap_servers,
            message_format=message_format,
            topic=topic,
            watermark_delay_threshold=watermark_delay_threshold,
        )

    def __eq__(self, other):
        if not isinstance(other, KafkaSource):
            raise TypeError(
                "Comparisons should only involve KafkaSource class objects."
            )

        if not super().__eq__(other):
            return False

        if (
            self.kafka_options.kafka_bootstrap_servers
            != other.kafka_options.kafka_bootstrap_servers
            or self.kafka_options.message_format != other.kafka_options.message_format
            or self.kafka_options.topic != other.kafka_options.topic
            or self.kafka_options.watermark_delay_threshold
            != other.kafka_options.watermark_delay_threshold
        ):
            return False

        return True

    def __hash__(self):
        return super().__hash__()

    @staticmethod
    def from_proto(data_source: DataSourceProto):
        """
        Creates a KafkaSource from its protobuf representation.

        Args:
            data_source: A DataSourceProto whose kafka_options are populated.

        Returns:
            A KafkaSource; batch_source is None when the proto has no batch_source set.
        """
        # Reuse KafkaOptions' parsing so the watermark presence/zero handling lives
        # in one place (it checks presence with HasField).
        kafka_options = KafkaOptions.from_proto(data_source.kafka_options)

        # Presence of a sub-message must be checked with HasField: reading an unset
        # message field still returns a (default) message object, so testing the
        # accessor's truthiness is not a reliable presence check.
        batch_source = (
            DataSource.from_proto(data_source.batch_source)
            if data_source.HasField("batch_source")
            else None
        )

        return KafkaSource(
            name=data_source.name,
            field_mapping=dict(data_source.field_mapping),
            kafka_bootstrap_servers=kafka_options.kafka_bootstrap_servers,
            message_format=kafka_options.message_format,
            watermark_delay_threshold=kafka_options.watermark_delay_threshold,
            topic=kafka_options.topic,
            created_timestamp_column=data_source.created_timestamp_column,
            timestamp_field=data_source.timestamp_field,
            description=data_source.description,
            tags=dict(data_source.tags),
            owner=data_source.owner,
            batch_source=batch_source,
        )

    def to_proto(self) -> DataSourceProto:
        """
        Converts this KafkaSource to its protobuf representation.
        """
        data_source_proto = DataSourceProto(
            name=self.name,
            type=DataSourceProto.STREAM_KAFKA,
            field_mapping=self.field_mapping,
            kafka_options=self.kafka_options.to_proto(),
            description=self.description,
            tags=self.tags,
            owner=self.owner,
        )

        data_source_proto.timestamp_field = self.timestamp_field
        data_source_proto.created_timestamp_column = self.created_timestamp_column
        if self.batch_source:
            data_source_proto.batch_source.MergeFrom(self.batch_source.to_proto())
        return data_source_proto

    def validate(self, config: RepoConfig):
        pass

    def get_table_column_names_and_types(
        self, config: RepoConfig
    ) -> Iterable[Tuple[str, str]]:
        # NOTE(review): returns None despite the Iterable annotation; under
        # @typechecked a call may fail the return-type check — confirm intended.
        pass

    @staticmethod
    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
        # NOTE(review): reuses the Redshift type mapping for Kafka columns — confirm intended.
        return type_map.redshift_to_feast_value_type

    def get_table_query_string(self) -> str:
        raise NotImplementedError
@typechecked
class RequestSource(DataSource):
    """
    RequestSource that can be used to provide input features for on demand transforms

    Attributes:
        name: Name of the request data source
        schema: Schema mapping from the input feature name to a ValueType
        description: A human-readable description.
        tags: A dictionary of key-value pairs to store arbitrary metadata.
        owner: The owner of the request data source, typically the email of the primary
            maintainer.
    """

    name: str
    schema: List[Field]
    description: str
    tags: Dict[str, str]
    owner: str

    def __init__(
        self,
        *,
        name: str,
        schema: List[Field],
        description: Optional[str] = "",
        tags: Optional[Dict[str, str]] = None,
        owner: Optional[str] = "",
    ):
        """Creates a RequestSource object."""
        super().__init__(name=name, description=description, tags=tags, owner=owner)
        self.schema = schema

    def validate(self, config: RepoConfig):
        pass

    def get_table_column_names_and_types(
        self, config: RepoConfig
    ) -> Iterable[Tuple[str, str]]:
        pass

    def __eq__(self, other):
        if not isinstance(other, RequestSource):
            raise TypeError(
                "Comparisons should only involve RequestSource class objects."
            )

        if not super().__eq__(other):
            return False

        if not (isinstance(self.schema, list) and isinstance(other.schema, list)):
            return False

        # zip() stops at the shorter list, so lengths must be compared first;
        # otherwise a schema would compare equal to any longer schema it prefixes.
        if len(self.schema) != len(other.schema):
            return False

        for field1, field2 in zip(self.schema, other.schema):
            if field1 != field2:
                return False
        return True

    def __hash__(self):
        return super().__hash__()

    @staticmethod
    def from_proto(data_source: DataSourceProto):
        """
        Creates a RequestSource from its protobuf representation.

        Args:
            data_source: A DataSourceProto whose request_data_options are populated.
        """
        list_schema = [
            Field.from_proto(field_proto)
            for field_proto in data_source.request_data_options.schema
        ]

        return RequestSource(
            name=data_source.name,
            schema=list_schema,
            description=data_source.description,
            tags=dict(data_source.tags),
            owner=data_source.owner,
        )

    def to_proto(self) -> DataSourceProto:
        """
        Converts this RequestSource to its protobuf representation.
        """
        schema_pb = []

        if isinstance(self.schema, dict):
            # Legacy dict-shaped schema (name -> type); presumably values are
            # ValueType members — NOTE(review): confirm this branch is still reachable.
            for key, value in self.schema.items():
                schema_pb.append(
                    Field(name=key, dtype=from_value_type(value.value)).to_proto()
                )
        else:
            for field in self.schema:
                schema_pb.append(field.to_proto())

        data_source_proto = DataSourceProto(
            name=self.name,
            type=DataSourceProto.REQUEST_SOURCE,
            description=self.description,
            tags=self.tags,
            owner=self.owner,
        )
        data_source_proto.request_data_options.schema.extend(schema_pb)

        return data_source_proto

    def get_table_query_string(self) -> str:
        raise NotImplementedError

    @staticmethod
    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
        raise NotImplementedError
@typechecked
class KinesisSource(DataSource):
    """
    A DataSource that reads stream records from a Kinesis stream, optionally
    backed by a batch source.
    """

    def __init__(
        self,
        *,
        name: str,
        record_format: StreamFormat,
        region: str,
        stream_name: str,
        timestamp_field: Optional[str] = "",
        created_timestamp_column: Optional[str] = "",
        field_mapping: Optional[Dict[str, str]] = None,
        description: Optional[str] = "",
        tags: Optional[Dict[str, str]] = None,
        owner: Optional[str] = "",
        batch_source: Optional[DataSource] = None,
    ):
        """
        Creates a KinesisSource object.

        Args:
            name: Name of data source, which should be unique within a project.
            record_format: StreamFormat of the serialized records.
            region: Region of the Kinesis stream.
            stream_name: Name of the Kinesis stream.
            timestamp_field (optional): Event timestamp field used for point-in-time joins.
            created_timestamp_column (optional): Timestamp column used for deduplicating rows.
            field_mapping (optional): Mapping of column names to feature names.
            description (optional): A human-readable description.
            tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
            owner (optional): The owner of the data source.
            batch_source (optional): The datasource that acts as a batch source.

        Raises:
            ValueError: record_format is None.
        """
        if record_format is None:
            raise ValueError("Record format must be specified for kinesis source")

        super().__init__(
            name=name,
            timestamp_field=timestamp_field,
            created_timestamp_column=created_timestamp_column,
            field_mapping=field_mapping,
            description=description,
            tags=tags,
            owner=owner,
        )
        self.batch_source = batch_source
        self.kinesis_options = KinesisOptions(
            record_format=record_format, region=region, stream_name=stream_name
        )

    def __eq__(self, other):
        if not isinstance(other, KinesisSource):
            raise TypeError(
                "Comparisons should only involve KinesisSource class objects."
            )

        if not super().__eq__(other):
            return False

        if (
            self.kinesis_options.record_format != other.kinesis_options.record_format
            or self.kinesis_options.region != other.kinesis_options.region
            or self.kinesis_options.stream_name != other.kinesis_options.stream_name
        ):
            return False

        return True

    def __hash__(self):
        return super().__hash__()

    def validate(self, config: RepoConfig):
        pass

    def get_table_column_names_and_types(
        self, config: RepoConfig
    ) -> Iterable[Tuple[str, str]]:
        pass

    @staticmethod
    def from_proto(data_source: DataSourceProto):
        """
        Creates a KinesisSource from its protobuf representation.

        Args:
            data_source: A DataSourceProto whose kinesis_options are populated.

        Returns:
            A KinesisSource; batch_source is None when the proto has no batch_source set.
        """
        # HasField is the reliable presence check for a sub-message; reading an
        # unset message field still yields a default message object.
        batch_source = (
            DataSource.from_proto(data_source.batch_source)
            if data_source.HasField("batch_source")
            else None
        )
        return KinesisSource(
            name=data_source.name,
            timestamp_field=data_source.timestamp_field,
            field_mapping=dict(data_source.field_mapping),
            record_format=StreamFormat.from_proto(
                data_source.kinesis_options.record_format
            ),
            region=data_source.kinesis_options.region,
            stream_name=data_source.kinesis_options.stream_name,
            created_timestamp_column=data_source.created_timestamp_column,
            description=data_source.description,
            tags=dict(data_source.tags),
            owner=data_source.owner,
            batch_source=batch_source,
        )

    def to_proto(self) -> DataSourceProto:
        """
        Converts this KinesisSource to its protobuf representation.
        """
        data_source_proto = DataSourceProto(
            name=self.name,
            type=DataSourceProto.STREAM_KINESIS,
            field_mapping=self.field_mapping,
            kinesis_options=self.kinesis_options.to_proto(),
            description=self.description,
            tags=self.tags,
            owner=self.owner,
        )

        data_source_proto.timestamp_field = self.timestamp_field
        data_source_proto.created_timestamp_column = self.created_timestamp_column
        if self.batch_source:
            data_source_proto.batch_source.MergeFrom(self.batch_source.to_proto())

        return data_source_proto

    @staticmethod
    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
        pass

    def get_table_query_string(self) -> str:
        raise NotImplementedError
class PushMode(enum.Enum):
    """Push modes: ONLINE, OFFLINE, or ONLINE_AND_OFFLINE (values 1, 2, 3)."""

    # enum.auto() numbers members sequentially from 1, matching the explicit values.
    ONLINE = enum.auto()
    OFFLINE = enum.auto()
    ONLINE_AND_OFFLINE = enum.auto()
@typechecked
class PushSource(DataSource):
    """
    A source that can be used to ingest features on request
    """

    # TODO(adchia): consider adding schema here in case where Feast manages pushing events to the offline store
    # TODO(adchia): consider a "mode" to support pushing raw vs transformed events
    batch_source: DataSource

    def __init__(
        self,
        *,
        name: str,
        batch_source: DataSource,
        description: Optional[str] = "",
        tags: Optional[Dict[str, str]] = None,
        owner: Optional[str] = "",
    ):
        """
        Creates a PushSource object.

        Args:
            name: Name of the push source
            batch_source: The batch source that backs this push source. It's used when materializing from the offline
                store to the online store, and when retrieving historical features.
            description (optional): A human-readable description.
            tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
            owner (optional): The owner of the data source, typically the email of the primary
                maintainer.
        """
        super().__init__(name=name, description=description, tags=tags, owner=owner)
        self.batch_source = batch_source

    def __eq__(self, other):
        # Unlike the sibling sources, a non-PushSource compares unequal rather
        # than raising TypeError.
        if not isinstance(other, PushSource):
            return False

        if not super().__eq__(other):
            return False

        if self.batch_source != other.batch_source:
            return False

        return True

    def __hash__(self):
        return super().__hash__()

    def validate(self, config: RepoConfig):
        pass

    def get_table_column_names_and_types(
        self, config: RepoConfig
    ) -> Iterable[Tuple[str, str]]:
        pass

    @staticmethod
    def from_proto(data_source: DataSourceProto):
        """
        Creates a PushSource from its protobuf representation.

        Args:
            data_source: A DataSourceProto with batch_source set.

        Raises:
            ValueError: The proto has no batch_source.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not data_source.HasField("batch_source"):
            raise ValueError(
                "PushSource proto is missing the required 'batch_source' field."
            )
        batch_source = DataSource.from_proto(data_source.batch_source)

        return PushSource(
            name=data_source.name,
            batch_source=batch_source,
            description=data_source.description,
            tags=dict(data_source.tags),
            owner=data_source.owner,
        )

    def to_proto(self) -> DataSourceProto:
        """
        Converts this PushSource to its protobuf representation.
        """
        batch_source_proto = None
        if self.batch_source:
            batch_source_proto = self.batch_source.to_proto()

        data_source_proto = DataSourceProto(
            name=self.name,
            type=DataSourceProto.PUSH_SOURCE,
            description=self.description,
            tags=self.tags,
            owner=self.owner,
            batch_source=batch_source_proto,
        )

        return data_source_proto

    def get_table_query_string(self) -> str:
        raise NotImplementedError

    @staticmethod
    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
        raise NotImplementedError