Source code for feast.infra.offline_stores.contrib.spark_offline_store.spark_source

import logging
import pickle
import traceback
import warnings
from enum import Enum
from typing import Any, Callable, Dict, Iterable, Optional, Tuple

from pyspark.sql import SparkSession

from feast.data_source import DataSource
from feast.errors import DataSourceNoNameException
from feast.infra.offline_stores.offline_utils import get_temp_entity_table_name
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.protos.feast.core.SavedDataset_pb2 import (
    SavedDatasetStorage as SavedDatasetStorageProto,
)
from feast.repo_config import RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.type_map import spark_to_feast_value_type
from feast.value_type import ValueType

logger = logging.getLogger(__name__)


class SparkSourceFormat(Enum):
    csv = "csv"
    json = "json"
    parquet = "parquet"


class SparkSource(DataSource):
    def __init__(
        self,
        name: Optional[str] = None,
        table: Optional[str] = None,
        query: Optional[str] = None,
        path: Optional[str] = None,
        file_format: Optional[str] = None,
        event_timestamp_column: Optional[str] = None,
        created_timestamp_column: Optional[str] = None,
        field_mapping: Optional[Dict[str, str]] = None,
        date_partition_column: Optional[str] = None,
    ):
        # If no name is provided, use the table as the default name.
        _name = name
        if not _name:
            if table:
                _name = table
            else:
                raise DataSourceNoNameException()

        super().__init__(
            _name,
            event_timestamp_column,
            created_timestamp_column,
            field_mapping,
            date_partition_column,
        )

        warnings.warn(
            "The spark data source API is an experimental feature in alpha development. "
            "This API is unstable and may change in future releases.",
            RuntimeWarning,
        )

        self.allowed_formats = [format.value for format in SparkSourceFormat]

        # Check that exactly one of the ways to load a spark dataframe is specified.
        if sum([(arg is not None) for arg in [table, query, path]]) != 1:
            raise ValueError(
                "Exactly one of the parameters (table, query, path) must be specified."
            )

        if path is not None:
            if file_format is None:
                raise ValueError(
                    "If 'path' is specified, then 'file_format' is required."
                )
            if file_format not in self.allowed_formats:
                raise ValueError(
                    f"'file_format' should be one of {self.allowed_formats}"
                )

        self.spark_options = SparkOptions(
            table=table,
            query=query,
            path=path,
            file_format=file_format,
        )

    @property
    def table(self):
        """Returns the table of this feature data source."""
        return self.spark_options.table

    @property
    def query(self):
        """Returns the query of this feature data source."""
        return self.spark_options.query

    @property
    def path(self):
        """Returns the path of the spark data source file."""
        return self.spark_options.path

    @property
    def file_format(self):
        """Returns the file format of this feature data source."""
        return self.spark_options.file_format

    @staticmethod
    def from_proto(data_source: DataSourceProto) -> Any:
        assert data_source.HasField("custom_options")

        spark_options = SparkOptions.from_proto(data_source.custom_options)

        return SparkSource(
            name=data_source.name,
            field_mapping=dict(data_source.field_mapping),
            table=spark_options.table,
            query=spark_options.query,
            path=spark_options.path,
            file_format=spark_options.file_format,
            event_timestamp_column=data_source.event_timestamp_column,
            created_timestamp_column=data_source.created_timestamp_column,
            date_partition_column=data_source.date_partition_column,
        )

    def to_proto(self) -> DataSourceProto:
        data_source_proto = DataSourceProto(
            name=self.name,
            type=DataSourceProto.CUSTOM_SOURCE,
            field_mapping=self.field_mapping,
            custom_options=self.spark_options.to_proto(),
        )
        data_source_proto.event_timestamp_column = self.event_timestamp_column
        data_source_proto.created_timestamp_column = self.created_timestamp_column
        data_source_proto.date_partition_column = self.date_partition_column

        return data_source_proto

    def validate(self, config: RepoConfig):
        self.get_table_column_names_and_types(config)

    @staticmethod
    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
        return spark_to_feast_value_type

    def get_table_column_names_and_types(
        self, config: RepoConfig
    ) -> Iterable[Tuple[str, str]]:
        from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
            get_spark_session_or_start_new_with_repoconfig,
        )

        spark_session = get_spark_session_or_start_new_with_repoconfig(
            store_config=config.offline_store
        )
        # Query the source once and read column names and types from the
        # resulting Spark dataframe schema.
        df = spark_session.sql(f"SELECT * FROM {self.get_table_query_string()}")
        return (
            (fields["name"], fields["type"])
            for fields in df.schema.jsonValue()["fields"]
        )

    def get_table_query_string(self) -> str:
        """Returns a string that can directly be used to reference this table in SQL."""
        if self.table:
            # Backticks make sure that spark sql knows this is a table reference.
            return f"`{self.table}`"
        if self.query:
            return f"({self.query})"

        # If both the table reference and the query are null, we can load from file.
        spark_session = SparkSession.getActiveSession()
        if spark_session is None:
            raise AssertionError("Could not find an active spark session.")
        try:
            df = spark_session.read.format(self.file_format).load(self.path)
        except Exception:
            logger.exception(
                "Spark read of file source failed.\n" + traceback.format_exc()
            )
            # Re-raise so we do not fall through with an undefined dataframe.
            raise
        tmp_table_name = get_temp_entity_table_name()
        df.createOrReplaceTempView(tmp_table_name)

        return f"`{tmp_table_name}`"
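
# --- Illustrative usage (not part of the upstream module) ---------------------
# A SparkSource must be built from exactly one of `table`, `query`, or
# `path` + `file_format`; query- and path-based sources also need an explicit
# `name`, since there is no table to fall back on. The identifiers below
# ("driver_hourly_stats", the parquet path) are made-up examples:
#
#   SparkSource(table="driver_hourly_stats",
#               event_timestamp_column="event_timestamp")
#   SparkSource(name="driver_hourly_stats",
#               query="SELECT * FROM driver_hourly_stats")
#   SparkSource(name="driver_hourly_stats",
#               path="data/driver_hourly_stats.parquet",
#               file_format="parquet")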


class SparkOptions:
    def __init__(
        self,
        table: Optional[str] = None,
        query: Optional[str] = None,
        path: Optional[str] = None,
        file_format: Optional[str] = None,
    ):
        self.table = table
        self.query = query
        self.path = path
        self.file_format = file_format

    @classmethod
    def from_proto(cls, spark_options_proto: DataSourceProto.CustomSourceOptions):
        """
        Creates a SparkOptions object from a protobuf representation of a spark option.

        Args:
            spark_options_proto: a protobuf representation of a datasource

        Returns:
            A SparkOptions object based on the spark_options protobuf
        """
        spark_configuration = pickle.loads(spark_options_proto.configuration)

        spark_options = cls(
            table=spark_configuration.table,
            query=spark_configuration.query,
            path=spark_configuration.path,
            file_format=spark_configuration.file_format,
        )
        return spark_options

    def to_proto(self) -> DataSourceProto.CustomSourceOptions:
        """
        Converts this SparkOptions object to its protobuf representation.

        Returns:
            A CustomSourceOptions protobuf containing the pickled SparkOptions
        """
        spark_options_proto = DataSourceProto.CustomSourceOptions(
            configuration=pickle.dumps(self),
        )
        return spark_options_proto


class SavedDatasetSparkStorage(SavedDatasetStorage):
    _proto_attr_name = "spark_storage"

    spark_options: SparkOptions

    def __init__(self, table_ref: Optional[str] = None, query: Optional[str] = None):
        self.spark_options = SparkOptions(table=table_ref, query=query)

    @staticmethod
    def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage:
        # TODO: implementation is not correct. Needs fix and update to protos.
        return SavedDatasetSparkStorage(table_ref="", query=None)

    def to_proto(self) -> SavedDatasetStorageProto:
        return SavedDatasetStorageProto()

    def to_data_source(self) -> DataSource:
        return SparkSource(table=self.spark_options.table)
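

# --- Illustrative round-trip (not part of the upstream module) ----------------
# A minimal sketch of how SparkOptions survive proto serialization: the options
# are pickled into DataSourceProto.CustomSourceOptions by to_proto() and
# unpickled again by from_proto(). The source name and parquet path below are
# made-up examples; no Spark session is required for this part.
if __name__ == "__main__":
    example_source = SparkSource(
        name="driver_hourly_stats",
        path="data/driver_hourly_stats.parquet",
        file_format="parquet",
    )
    options_proto = example_source.spark_options.to_proto()
    restored_options = SparkOptions.from_proto(options_proto)
    print(restored_options.path)  # -> data/driver_hourly_stats.parquet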