Source code for feast.infra.offline_stores.file

import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

import dask.dataframe as dd
import pandas as pd
import pyarrow
import pyarrow.dataset
import pyarrow.parquet
import pytz
from pydantic.typing import Literal

from feast import FileSource, OnDemandFeatureView
from feast.data_source import DataSource
from feast.errors import FeastJoinKeysDuringMaterialization
from feast.feature_logging import LoggingConfig, LoggingSource
from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
from feast.infra.offline_stores.file_source import (
    FileLoggingDestination,
    SavedDatasetFileStorage,
)
from feast.infra.offline_stores.offline_store import (
    OfflineStore,
    RetrievalJob,
    RetrievalMetadata,
)
from feast.infra.offline_stores.offline_utils import (
    DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
)
from feast.infra.provider import (
    _get_requested_feature_views_to_features_dict,
    _run_dask_field_mapping,
)
from feast.registry import BaseRegistry
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.usage import log_exceptions_and_usage


class FileOfflineStoreConfig(FeastConfigBaseModel):
    """Offline store config for local (file-based) store"""

    type: Literal["file"] = "file"
    """ Offline store type selector"""
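
# Illustrative sketch (not part of the upstream module): the model above is what
# Feast parses an `offline_store:` block with `type: file` in feature_store.yaml
# into; built directly in Python it is simply:
_example_file_store_config = FileOfflineStoreConfig()  # type defaults to "file"
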

class FileRetrievalJob(RetrievalJob):
    def __init__(
        self,
        evaluation_function: Callable,
        full_feature_names: bool,
        on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
        metadata: Optional[RetrievalMetadata] = None,
    ):
        """Initialize a lazy historical retrieval job"""

        # The evaluation function executes a stored procedure to compute a historical retrieval.
        self.evaluation_function = evaluation_function
        self._full_feature_names = full_feature_names
        self._on_demand_feature_views = (
            on_demand_feature_views if on_demand_feature_views else []
        )
        self._metadata = metadata

    @property
    def full_feature_names(self) -> bool:
        return self._full_feature_names

    @property
    def on_demand_feature_views(self) -> Optional[List[OnDemandFeatureView]]:
        return self._on_demand_feature_views

    @log_exceptions_and_usage
    def _to_df_internal(self) -> pd.DataFrame:
        # Only execute the evaluation function to build the final historical retrieval dataframe at the last moment.
        df = self.evaluation_function().compute()
        df = df.reset_index(drop=True)
        return df

    @log_exceptions_and_usage
    def _to_arrow_internal(self):
        # Only execute the evaluation function to build the final historical retrieval dataframe at the last moment.
        df = self.evaluation_function().compute()
        return pyarrow.Table.from_pandas(df)

    def persist(self, storage: SavedDatasetStorage):
        assert isinstance(storage, SavedDatasetFileStorage)

        filesystem, path = FileSource.create_filesystem_and_path(
            storage.file_options.uri,
            storage.file_options.s3_endpoint_override,
        )
        if path.endswith(".parquet"):
            pyarrow.parquet.write_table(
                self.to_arrow(), where=path, filesystem=filesystem
            )
        else:
            # otherwise assume destination is directory
            pyarrow.parquet.write_to_dataset(
                self.to_arrow(), root_path=path, filesystem=filesystem
            )

    @property
    def metadata(self) -> Optional[RetrievalMetadata]:
        return self._metadata
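
# Illustrative usage sketch (not how Feast constructs the job internally): any
# zero-argument callable returning a dask DataFrame can drive a FileRetrievalJob,
# and evaluation is deferred until to_df()/to_arrow() is called. The toy frame
# below is made up and reuses the module-level pandas/dask imports.
_toy_frame = dd.from_pandas(
    pd.DataFrame({"driver_id": [1001], "conv_rate": [0.5]}), npartitions=1
)
_toy_job = FileRetrievalJob(
    evaluation_function=lambda: _toy_frame, full_feature_names=False
)
# _toy_job.to_df() would materialize the single-row dataframe above at that point.
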

class FileOfflineStore(OfflineStore):
    @staticmethod
    @log_exceptions_and_usage(offline_store="file")
    def get_historical_features(
        config: RepoConfig,
        feature_views: List[FeatureView],
        feature_refs: List[str],
        entity_df: Union[pd.DataFrame, str],
        registry: BaseRegistry,
        project: str,
        full_feature_names: bool = False,
    ) -> RetrievalJob:
        if not isinstance(entity_df, pd.DataFrame) and not isinstance(
            entity_df, dd.DataFrame
        ):
            raise ValueError(
                f"Please provide an entity_df of type {type(pd.DataFrame)} instead of type {type(entity_df)}"
            )
        entity_df_event_timestamp_col = (
            DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL  # local modifiable copy of global variable
        )
        if entity_df_event_timestamp_col not in entity_df.columns:
            datetime_columns = entity_df.select_dtypes(
                include=["datetime", "datetimetz"]
            ).columns
            if len(datetime_columns) == 1:
                print(
                    f"Using {datetime_columns[0]} as the event timestamp. To specify a column explicitly, please name it {DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL}."
                )
                entity_df_event_timestamp_col = datetime_columns[0]
            else:
                raise ValueError(
                    f"Please provide an entity_df with a column named {DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL} representing the time of events."
                )

        (
            feature_views_to_features,
            on_demand_feature_views_to_features,
        ) = _get_requested_feature_views_to_features_dict(
            feature_refs,
            feature_views,
            registry.list_on_demand_feature_views(config.project),
        )

        entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
            entity_df, entity_df_event_timestamp_col
        )

        # Create lazy function that is only called from the RetrievalJob object
        def evaluate_historical_retrieval():

            # Create a copy of entity_df to prevent modifying the original
            entity_df_with_features = entity_df.copy()

            entity_df_event_timestamp_col_type = entity_df_with_features.dtypes[
                entity_df_event_timestamp_col
            ]
            if (
                not hasattr(entity_df_event_timestamp_col_type, "tz")
                or entity_df_event_timestamp_col_type.tz != pytz.UTC
            ):
                # Make sure all event timestamp fields are tz-aware. We default tz-naive fields to UTC
                entity_df_with_features[
                    entity_df_event_timestamp_col
                ] = entity_df_with_features[entity_df_event_timestamp_col].apply(
                    lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc)
                )

                # Convert event timestamp column to datetime and normalize time zone to UTC
                # This is necessary to avoid issues with pd.merge_asof
                if isinstance(entity_df_with_features, dd.DataFrame):
                    entity_df_with_features[
                        entity_df_event_timestamp_col
                    ] = dd.to_datetime(
                        entity_df_with_features[entity_df_event_timestamp_col], utc=True
                    )
                else:
                    entity_df_with_features[
                        entity_df_event_timestamp_col
                    ] = pd.to_datetime(
                        entity_df_with_features[entity_df_event_timestamp_col], utc=True
                    )

            # Sort event timestamp values
            entity_df_with_features = entity_df_with_features.sort_values(
                entity_df_event_timestamp_col
            )

            all_join_keys = []

            # Load feature view data from sources and join them incrementally
            for feature_view, features in feature_views_to_features.items():
                timestamp_field = feature_view.batch_source.timestamp_field
                created_timestamp_column = (
                    feature_view.batch_source.created_timestamp_column
                )

                # Build a list of entity columns to join on (from the right table)
                join_keys = []
                for entity_column in feature_view.entity_columns:
                    join_key = feature_view.projection.join_key_map.get(
                        entity_column.name, entity_column.name
                    )
                    join_keys.append(join_key)

                right_entity_key_columns = [
                    timestamp_field,
                    created_timestamp_column,
                ] + join_keys
                right_entity_key_columns = [c for c in right_entity_key_columns if c]

                all_join_keys = list(set(all_join_keys + join_keys))

                df_to_join = _read_datasource(feature_view.batch_source)

                df_to_join, timestamp_field = _field_mapping(
                    df_to_join,
                    feature_view,
                    features,
                    right_entity_key_columns,
                    entity_df_event_timestamp_col,
                    timestamp_field,
                    full_feature_names,
                )

                df_to_join = _merge(entity_df_with_features, df_to_join, join_keys)

                df_to_join = _normalize_timestamp(
                    df_to_join, timestamp_field, created_timestamp_column
                )

                df_to_join = _filter_ttl(
                    df_to_join,
                    feature_view,
                    entity_df_event_timestamp_col,
                    timestamp_field,
                )

                df_to_join = _drop_duplicates(
                    df_to_join,
                    all_join_keys,
                    timestamp_field,
                    created_timestamp_column,
                    entity_df_event_timestamp_col,
                )

                entity_df_with_features = _drop_columns(
                    df_to_join, timestamp_field, created_timestamp_column
                )

                # Ensure that we delete dataframes to free up memory
                del df_to_join

            return entity_df_with_features.persist()

        job = FileRetrievalJob(
            evaluation_function=evaluate_historical_retrieval,
            full_feature_names=full_feature_names,
            on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs(
                feature_refs, project, registry
            ),
            metadata=RetrievalMetadata(
                features=feature_refs,
                keys=list(set(entity_df.columns) - {entity_df_event_timestamp_col}),
                min_event_timestamp=entity_df_event_timestamp_range[0],
                max_event_timestamp=entity_df_event_timestamp_range[1],
            ),
        )
        return job
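
    # Illustrative sketch of the usual entry point (all names below are
    # placeholders): callers normally go through feast.FeatureStore, which routes
    # here when the repo is configured with the file offline store.
    #
    #     store = FeatureStore(repo_path=".")  # placeholder repo path
    #     entity_df = pd.DataFrame(
    #         {
    #             "driver_id": [1001, 1002],  # placeholder join key values
    #             "event_timestamp": pd.to_datetime(
    #                 ["2021-04-12 10:59:42", "2021-04-12 08:12:10"], utc=True
    #             ),
    #         }
    #     )
    #     training_df = store.get_historical_features(
    #         entity_df=entity_df,
    #         features=["driver_hourly_stats:conv_rate"],  # placeholder feature ref
    #     ).to_df()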

    @staticmethod
    @log_exceptions_and_usage(offline_store="file")
    def pull_latest_from_table_or_query(
        config: RepoConfig,
        data_source: DataSource,
        join_key_columns: List[str],
        feature_name_columns: List[str],
        timestamp_field: str,
        created_timestamp_column: Optional[str],
        start_date: datetime,
        end_date: datetime,
    ) -> RetrievalJob:
        assert isinstance(data_source, FileSource)

        # Create lazy function that is only called from the RetrievalJob object
        def evaluate_offline_job():
            source_df = _read_datasource(data_source)

            source_df = _normalize_timestamp(
                source_df, timestamp_field, created_timestamp_column
            )

            source_columns = set(source_df.columns)
            if not set(join_key_columns).issubset(source_columns):
                raise FeastJoinKeysDuringMaterialization(
                    data_source.path, set(join_key_columns), source_columns
                )

            ts_columns = (
                [timestamp_field, created_timestamp_column]
                if created_timestamp_column
                else [timestamp_field]
            )
            # try-catch block is added to deal with this issue https://github.com/dask/dask/issues/8939.
            # TODO(kevjumba): remove try catch when fix is merged upstream in Dask.
            try:
                if created_timestamp_column:
                    source_df = source_df.sort_values(by=created_timestamp_column)

                source_df = source_df.sort_values(by=timestamp_field)

            except ZeroDivisionError:
                # Use 1 partition to get around case where everything in timestamp column is the same so the partition algorithm doesn't
                # try to divide by zero.
                if created_timestamp_column:
                    source_df = source_df.sort_values(
                        by=created_timestamp_column, npartitions=1
                    )

                source_df = source_df.sort_values(by=timestamp_field, npartitions=1)

            source_df = source_df[
                (source_df[timestamp_field] >= start_date)
                & (source_df[timestamp_field] < end_date)
            ]

            source_df = source_df.persist()

            columns_to_extract = set(
                join_key_columns + feature_name_columns + ts_columns
            )
            if join_key_columns:
                source_df = source_df.drop_duplicates(
                    join_key_columns, keep="last", ignore_index=True
                )
            else:
                source_df[DUMMY_ENTITY_ID] = DUMMY_ENTITY_VAL
                columns_to_extract.add(DUMMY_ENTITY_ID)

            source_df = source_df.persist()

            return source_df[list(columns_to_extract)].persist()

        # When materializing a single feature view, we don't need full feature names. On demand transforms aren't materialized
        return FileRetrievalJob(
            evaluation_function=evaluate_offline_job,
            full_feature_names=False,
        )

    @staticmethod
    @log_exceptions_and_usage(offline_store="file")
    def pull_all_from_table_or_query(
        config: RepoConfig,
        data_source: DataSource,
        join_key_columns: List[str],
        feature_name_columns: List[str],
        timestamp_field: str,
        start_date: datetime,
        end_date: datetime,
    ) -> RetrievalJob:
        return FileOfflineStore.pull_latest_from_table_or_query(
            config=config,
            data_source=data_source,
            join_key_columns=join_key_columns
            + [timestamp_field],  # avoid deduplication
            feature_name_columns=feature_name_columns,
            timestamp_field=timestamp_field,
            created_timestamp_column=None,
            start_date=start_date,
            end_date=end_date,
        )

    @staticmethod
    def write_logged_features(
        config: RepoConfig,
        data: Union[pyarrow.Table, Path],
        source: LoggingSource,
        logging_config: LoggingConfig,
        registry: BaseRegistry,
    ):
        destination = logging_config.destination
        assert isinstance(destination, FileLoggingDestination)

        if isinstance(data, Path):
            # Since this code will be mostly used from Go-created thread, it's better to avoid producing new threads
            data = pyarrow.parquet.read_table(data, use_threads=False, pre_buffer=False)

        filesystem, path = FileSource.create_filesystem_and_path(
            destination.path,
            destination.s3_endpoint_override,
        )

        pyarrow.dataset.write_dataset(
            data,
            base_dir=path,
            basename_template=f"{uuid.uuid4().hex}-{{i}}.parquet",
            partitioning=destination.partition_by,
            filesystem=filesystem,
            use_threads=False,
            format=pyarrow.dataset.ParquetFileFormat(),
            existing_data_behavior="overwrite_or_ignore",
        )

    @staticmethod
    def offline_write_batch(
        config: RepoConfig,
        feature_view: FeatureView,
        data: pyarrow.Table,
        progress: Optional[Callable[[int], Any]],
    ):
        if not feature_view.batch_source:
            raise ValueError(
                "feature view does not have a batch source to persist offline data"
            )
        if not isinstance(config.offline_store, FileOfflineStoreConfig):
            raise ValueError(
                f"offline store config is of type {type(config.offline_store)} when file type required"
            )
        if not isinstance(feature_view.batch_source, FileSource):
            raise ValueError(
                f"feature view batch source is {type(feature_view.batch_source)} not file source"
            )

        file_options = feature_view.batch_source.file_options
        filesystem, path = FileSource.create_filesystem_and_path(
            file_options.uri, file_options.s3_endpoint_override
        )
        prev_table = pyarrow.parquet.read_table(path, memory_map=True)
        if prev_table.column_names != data.column_names:
            raise ValueError(
                f"Input dataframe has incorrect schema or wrong order, expected columns are: {prev_table.column_names}"
            )
        if data.schema != prev_table.schema:
            data = data.cast(prev_table.schema)
        new_table = pyarrow.concat_tables([data, prev_table])
        writer = pyarrow.parquet.ParquetWriter(path, data.schema, filesystem=filesystem)
        writer.write_table(new_table)
        writer.close()
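
# Illustrative sketch (not part of the upstream module): offline_write_batch
# above appends rows to an existing Parquet file and requires the incoming
# table's column names to match it exactly, so the batch source's file must be
# seeded first. The schema and path below are made up; this reuses the
# module-level pyarrow imports.
_seed_table = pyarrow.Table.from_pydict(
    {
        "driver_id": [1001],
        "conv_rate": [0.5],
        "event_timestamp": [datetime(2021, 4, 12)],
    }
)
# pyarrow.parquet.write_table(_seed_table, "driver_stats.parquet")  # placeholder path
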

def _get_entity_df_event_timestamp_range(
    entity_df: Union[pd.DataFrame, str],
    entity_df_event_timestamp_col: str,
) -> Tuple[datetime, datetime]:
    if not isinstance(entity_df, pd.DataFrame):
        raise ValueError(
            f"Please provide an entity_df of type {type(pd.DataFrame)} instead of type {type(entity_df)}"
        )

    entity_df_event_timestamp = entity_df.loc[
        :, entity_df_event_timestamp_col
    ].infer_objects()
    if pd.api.types.is_string_dtype(entity_df_event_timestamp):
        entity_df_event_timestamp = pd.to_datetime(entity_df_event_timestamp, utc=True)

    return (
        entity_df_event_timestamp.min().to_pydatetime(),
        entity_df_event_timestamp.max().to_pydatetime(),
    )


def _read_datasource(data_source) -> dd.DataFrame:
    storage_options = (
        {
            "client_kwargs": {
                "endpoint_url": data_source.file_options.s3_endpoint_override
            }
        }
        if data_source.file_options.s3_endpoint_override
        else None
    )

    return dd.read_parquet(data_source.path, storage_options=storage_options)


def _field_mapping(
    df_to_join: dd.DataFrame,
    feature_view: FeatureView,
    features: List[str],
    right_entity_key_columns: List[str],
    entity_df_event_timestamp_col: str,
    timestamp_field: str,
    full_feature_names: bool,
) -> Tuple[dd.DataFrame, str]:
    # Rename columns by the field mapping dictionary if it exists
    if feature_view.batch_source.field_mapping:
        df_to_join = _run_dask_field_mapping(
            df_to_join, feature_view.batch_source.field_mapping
        )
    # Rename entity columns by the join_key_map dictionary if it exists
    if feature_view.projection.join_key_map:
        df_to_join = _run_dask_field_mapping(
            df_to_join, feature_view.projection.join_key_map
        )

    # Build a list of all the features we should select from this source
    feature_names = []
    columns_map = {}
    for feature in features:
        # Modify the separator for feature refs in column names to double underscore. We are using
        # double underscore as separator for consistency with other databases like BigQuery,
        # where there are very few characters available for use as separators
        if full_feature_names:
            formatted_feature_name = (
                f"{feature_view.projection.name_to_use()}__{feature}"
            )
        else:
            formatted_feature_name = feature
        # Add the feature name to the list of columns
        feature_names.append(formatted_feature_name)
        columns_map[feature] = formatted_feature_name

    # Ensure that the source dataframe feature column includes the feature view name as a prefix
    df_to_join = _run_dask_field_mapping(df_to_join, columns_map)

    # Select only the columns we need to join from the feature dataframe
    df_to_join = df_to_join[right_entity_key_columns + feature_names]
    df_to_join = df_to_join.persist()

    # Make sure to not have duplicated columns
    if entity_df_event_timestamp_col == timestamp_field:
        df_to_join = _run_dask_field_mapping(
            df_to_join,
            {timestamp_field: f"__{timestamp_field}"},
        )
        timestamp_field = f"__{timestamp_field}"

    return df_to_join.persist(), timestamp_field


def _merge(
    entity_df_with_features: dd.DataFrame,
    df_to_join: dd.DataFrame,
    join_keys: List[str],
) -> dd.DataFrame:
    # tmp join keys needed for cross join with null join table view
    tmp_join_keys = []
    if not join_keys:
        entity_df_with_features["__tmp"] = 1
        df_to_join["__tmp"] = 1
        tmp_join_keys = ["__tmp"]

    # Get only data with requested entities
    df_to_join = dd.merge(
        entity_df_with_features,
        df_to_join,
        left_on=join_keys or tmp_join_keys,
        right_on=join_keys or tmp_join_keys,
        suffixes=("", "__"),
        how="left",
    )

    if tmp_join_keys:
        df_to_join = df_to_join.drop(tmp_join_keys, axis=1).persist()
    else:
        df_to_join = df_to_join.persist()

    return df_to_join


def _normalize_timestamp(
    df_to_join: dd.DataFrame,
    timestamp_field: str,
    created_timestamp_column: str,
) -> dd.DataFrame:
    df_to_join_types = df_to_join.dtypes
    timestamp_field_type = df_to_join_types[timestamp_field]

    if created_timestamp_column:
        created_timestamp_column_type = df_to_join_types[created_timestamp_column]

    if not hasattr(timestamp_field_type, "tz") or timestamp_field_type.tz != pytz.UTC:
        # Make sure all timestamp fields are tz-aware. We default tz-naive fields to UTC
        df_to_join[timestamp_field] = df_to_join[timestamp_field].apply(
            lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc),
            meta=(timestamp_field, "datetime64[ns, UTC]"),
        )

    if created_timestamp_column and (
        not hasattr(created_timestamp_column_type, "tz")
        or created_timestamp_column_type.tz != pytz.UTC
    ):
        df_to_join[created_timestamp_column] = df_to_join[
            created_timestamp_column
        ].apply(
            lambda x: x if x.tzinfo is not None else x.replace(tzinfo=pytz.utc),
            meta=(timestamp_field, "datetime64[ns, UTC]"),
        )

    return df_to_join.persist()


def _filter_ttl(
    df_to_join: dd.DataFrame,
    feature_view: FeatureView,
    entity_df_event_timestamp_col: str,
    timestamp_field: str,
) -> dd.DataFrame:
    # Filter rows by defined timestamp tolerance
    if feature_view.ttl and feature_view.ttl.total_seconds() != 0:
        df_to_join = df_to_join[
            # do not drop entity rows if one of the sources returns NaNs
            df_to_join[timestamp_field].isna()
            | (
                (
                    df_to_join[timestamp_field]
                    >= df_to_join[entity_df_event_timestamp_col] - feature_view.ttl
                )
                & (
                    df_to_join[timestamp_field]
                    <= df_to_join[entity_df_event_timestamp_col]
                )
            )
        ]

        df_to_join = df_to_join.persist()

    return df_to_join


def _drop_duplicates(
    df_to_join: dd.DataFrame,
    all_join_keys: List[str],
    timestamp_field: str,
    created_timestamp_column: str,
    entity_df_event_timestamp_col: str,
) -> dd.DataFrame:
    if created_timestamp_column:
        df_to_join = df_to_join.sort_values(
            by=created_timestamp_column, na_position="first"
        )
        df_to_join = df_to_join.persist()

    df_to_join = df_to_join.sort_values(by=timestamp_field, na_position="first")
    df_to_join = df_to_join.persist()

    df_to_join = df_to_join.drop_duplicates(
        all_join_keys + [entity_df_event_timestamp_col],
        keep="last",
        ignore_index=True,
    )

    return df_to_join.persist()


def _drop_columns(
    df_to_join: dd.DataFrame,
    timestamp_field: str,
    created_timestamp_column: str,
) -> dd.DataFrame:
    entity_df_with_features = df_to_join.drop([timestamp_field], axis=1).persist()
    if created_timestamp_column:
        entity_df_with_features = entity_df_with_features.drop(
            [created_timestamp_column], axis=1
        ).persist()

    return entity_df_with_features
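
# Toy, pandas-only sketch (not used by the store; names and values are made up)
# of the point-in-time logic the dask helpers above implement together: join on
# the entity key, keep feature rows whose timestamp falls within
# [entity event time - ttl, entity event time], then keep the latest such row.
from datetime import timedelta

_toy_entities = pd.DataFrame(
    {
        "driver_id": [1001],
        "event_timestamp": pd.to_datetime(["2021-04-12 10:00"], utc=True),
    }
)
_toy_features = pd.DataFrame(
    {
        "driver_id": [1001, 1001],
        "feature_timestamp": pd.to_datetime(
            ["2021-04-12 08:00", "2021-04-12 09:30"], utc=True
        ),
        "conv_rate": [0.4, 0.5],
    }
)
_toy_ttl = timedelta(hours=2)

_joined = _toy_entities.merge(_toy_features, on="driver_id", how="left")
_in_ttl = _joined[
    (_joined["feature_timestamp"] >= _joined["event_timestamp"] - _toy_ttl)
    & (_joined["feature_timestamp"] <= _joined["event_timestamp"])
]
_latest = _in_ttl.sort_values("feature_timestamp").drop_duplicates(
    ["driver_id", "event_timestamp"], keep="last"
)
# _latest holds a single row with conv_rate == 0.5, the newest value inside the TTL window.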