Source code for feast.infra.offline_stores.contrib.postgres_offline_store.postgres

import contextlib
from dataclasses import asdict
from datetime import datetime
from typing import (
    Any,
    Callable,
    ContextManager,
    Iterator,
    KeysView,
    List,
    Optional,
    Tuple,
    Union,
)

import pandas as pd
import pyarrow as pa
from jinja2 import BaseLoader, Environment
from psycopg2 import sql
from pydantic.typing import Literal
from pytz import utc

from feast.data_source import DataSource
from feast.errors import InvalidEntityType
from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
from feast.infra.offline_stores import offline_utils
from feast.infra.offline_stores.offline_store import (
    OfflineStore,
    RetrievalJob,
    RetrievalMetadata,
)
from feast.infra.utils.postgres.connection_utils import (
    _get_conn,
    df_to_postgres_table,
    get_query_schema,
)
from feast.infra.utils.postgres.postgres_config import PostgreSQLConfig
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.registry import Registry
from feast.repo_config import RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.type_map import pg_type_code_to_arrow
from feast.usage import log_exceptions_and_usage

from .postgres_source import PostgreSQLSource


class PostgreSQLOfflineStoreConfig(PostgreSQLConfig):
    type: Literal["postgres"] = "postgres"
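
# Example (illustrative, not part of the original module): an offline_store
# section of feature_store.yaml that selects this store. The field names are
# assumed to follow PostgreSQLConfig; all host/credential values below are
# placeholders.
#
#     offline_store:
#         type: postgres
#         host: localhost
#         port: 5432
#         database: feast
#         db_schema: public
#         user: feast
#         password: feast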

class PostgreSQLOfflineStore(OfflineStore):
[docs] @staticmethod @log_exceptions_and_usage(offline_store="postgres") def pull_latest_from_table_or_query( config: RepoConfig, data_source: DataSource, join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, created_timestamp_column: Optional[str], start_date: datetime, end_date: datetime, ) -> RetrievalJob: assert isinstance(data_source, PostgreSQLSource) from_expression = data_source.get_table_query_string() partition_by_join_key_string = ", ".join(_append_alias(join_key_columns, "a")) if partition_by_join_key_string != "": partition_by_join_key_string = ( "PARTITION BY " + partition_by_join_key_string ) timestamps = [timestamp_field] if created_timestamp_column: timestamps.append(created_timestamp_column) timestamp_desc_string = " DESC, ".join(_append_alias(timestamps, "a")) + " DESC" a_field_string = ", ".join( _append_alias(join_key_columns + feature_name_columns + timestamps, "a") ) b_field_string = ", ".join( _append_alias(join_key_columns + feature_name_columns + timestamps, "b") ) query = f""" SELECT {b_field_string} {f", {repr(DUMMY_ENTITY_VAL)} AS {DUMMY_ENTITY_ID}" if not join_key_columns else ""} FROM ( SELECT {a_field_string}, ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS _feast_row FROM ({from_expression}) a WHERE a."{timestamp_field}" BETWEEN '{start_date}'::timestamptz AND '{end_date}'::timestamptz ) b WHERE _feast_row = 1 """ return PostgreSQLRetrievalJob( query=query, config=config, full_feature_names=False, on_demand_feature_views=None, )
[docs] @staticmethod @log_exceptions_and_usage(offline_store="postgres") def get_historical_features( config: RepoConfig, feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pd.DataFrame, str], registry: Registry, project: str, full_feature_names: bool = False, ) -> RetrievalJob: @contextlib.contextmanager def query_generator() -> Iterator[str]: table_name = None if isinstance(entity_df, pd.DataFrame): table_name = offline_utils.get_temp_entity_table_name() entity_schema = df_to_postgres_table( config.offline_store, entity_df, table_name ) df_query = table_name elif isinstance(entity_df, str): df_query = f"({entity_df}) AS sub" entity_schema = get_query_schema(config.offline_store, df_query) else: raise TypeError(entity_df) entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( entity_schema ) expected_join_keys = offline_utils.get_expected_join_keys( project, feature_views, registry ) offline_utils.assert_expected_columns_in_entity_df( entity_schema, expected_join_keys, entity_df_event_timestamp_col ) entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( entity_df, entity_df_event_timestamp_col, config, df_query, ) query_context = offline_utils.get_feature_view_query_context( feature_refs, feature_views, registry, project, entity_df_event_timestamp_range, ) query_context_dict = [asdict(context) for context in query_context] # Hack for query_context.entity_selections to support uppercase in columns for context in query_context_dict: context["entity_selections"] = [ f'''"{entity_selection.replace(' AS ', '" AS "')}\"''' for entity_selection in context["entity_selections"] ] try: yield build_point_in_time_query( query_context_dict, left_table_query_string=df_query, entity_df_event_timestamp_col=entity_df_event_timestamp_col, entity_df_columns=entity_schema.keys(), query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN, full_feature_names=full_feature_names, ) finally: if table_name: with _get_conn(config.offline_store) as conn, conn.cursor() as cur: cur.execute( sql.SQL( """ DROP TABLE IF EXISTS {}; """ ).format(sql.Identifier(table_name)), ) return PostgreSQLRetrievalJob( query=query_generator, config=config, full_feature_names=full_feature_names, on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs( feature_refs, project, registry ), )
[docs] @staticmethod @log_exceptions_and_usage(offline_store="postgres") def pull_all_from_table_or_query( config: RepoConfig, data_source: DataSource, join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, start_date: datetime, end_date: datetime, ) -> RetrievalJob: assert isinstance(data_source, PostgreSQLSource) from_expression = data_source.get_table_query_string() field_string = ", ".join( join_key_columns + feature_name_columns + [timestamp_field] ) start_date = start_date.astimezone(tz=utc) end_date = end_date.astimezone(tz=utc) query = f""" SELECT {field_string} FROM {from_expression} WHERE "{timestamp_field}" BETWEEN '{start_date}'::timestamptz AND '{end_date}'::timestamptz """ return PostgreSQLRetrievalJob( query=query, config=config, full_feature_names=False, on_demand_feature_views=None, )

class PostgreSQLRetrievalJob(RetrievalJob):
    def __init__(
        self,
        query: Union[str, Callable[[], ContextManager[str]]],
        config: RepoConfig,
        full_feature_names: bool,
        on_demand_feature_views: Optional[List[OnDemandFeatureView]],
        metadata: Optional[RetrievalMetadata] = None,
    ):
        if not isinstance(query, str):
            self._query_generator = query
        else:

            @contextlib.contextmanager
            def query_generator() -> Iterator[str]:
                assert isinstance(query, str)
                yield query

            self._query_generator = query_generator
        self.config = config
        self._full_feature_names = full_feature_names
        self._on_demand_feature_views = on_demand_feature_views
        self._metadata = metadata

    @property
    def full_feature_names(self) -> bool:
        return self._full_feature_names

    @property
    def on_demand_feature_views(self) -> Optional[List[OnDemandFeatureView]]:
        return self._on_demand_feature_views

    def _to_df_internal(self) -> pd.DataFrame:
        # We use arrow format because it gives better control of the table schema
        return self._to_arrow_internal().to_pandas()
    def to_sql(self) -> str:
        with self._query_generator() as query:
            return query
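    # Usage sketch: to_sql() exposes the rendered query without executing it,
    # which is handy for debugging the generated point-in-time join:
    #
    #     print(job.to_sql())
    #
    # Note that for jobs built by get_historical_features this enters the
    # query_generator context, so the temporary entity table is created and
    # then dropped as a side effect.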
    def _to_arrow_internal(self) -> pa.Table:
        with self._query_generator() as query:
            with _get_conn(self.config.offline_store) as conn, conn.cursor() as cur:
                conn.set_session(readonly=True)
                cur.execute(query)
                fields = [
                    (c.name, pg_type_code_to_arrow(c.type_code))
                    for c in cur.description
                ]
                data = cur.fetchall()
                schema = pa.schema(fields)

                # TODO: Fix...
                data_transposed: List[List[Any]] = []
                for col in range(len(fields)):
                    data_transposed.append([])
                    for row in range(len(data)):
                        data_transposed[col].append(data[row][col])

                table = pa.Table.from_arrays(
                    [pa.array(row) for row in data_transposed], schema=schema
                )
                return table

    @property
    def metadata(self) -> Optional[RetrievalMetadata]:
        return self._metadata
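    # Note on _to_arrow_internal above: the manual loop transposes psycopg2's
    # row-major fetchall() result into the column-major layout pyarrow expects.
    # An equivalent one-liner would be
    #
    #     data_transposed = [list(col) for col in zip(*data)]
    #
    # but that yields no columns at all (rather than len(fields) empty ones)
    # when the query returns zero rows, which would break
    # pa.Table.from_arrays against the schema.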
    def persist(self, storage: SavedDatasetStorage):
        pass

def _get_entity_df_event_timestamp_range(
    entity_df: Union[pd.DataFrame, str],
    entity_df_event_timestamp_col: str,
    config: RepoConfig,
    table_name: str,
) -> Tuple[datetime, datetime]:
    if isinstance(entity_df, pd.DataFrame):
        entity_df_event_timestamp = entity_df.loc[
            :, entity_df_event_timestamp_col
        ].infer_objects()
        if pd.api.types.is_string_dtype(entity_df_event_timestamp):
            entity_df_event_timestamp = pd.to_datetime(
                entity_df_event_timestamp, utc=True
            )
        entity_df_event_timestamp_range = (
            entity_df_event_timestamp.min(),
            entity_df_event_timestamp.max(),
        )
    elif isinstance(entity_df, str):
        # If the entity_df is a string (SQL query), determine the range
        # from the table
        with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
            cur.execute(
                f"SELECT MIN({entity_df_event_timestamp_col}) AS min, "
                f"MAX({entity_df_event_timestamp_col}) AS max FROM {table_name}"
            )
            res = cur.fetchone()
        entity_df_event_timestamp_range = (res[0], res[1])
    else:
        raise InvalidEntityType(type(entity_df))

    return entity_df_event_timestamp_range


def _append_alias(field_names: List[str], alias: str) -> List[str]:
    return [f'{alias}."{field_name}"' for field_name in field_names]
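# Example: _append_alias(["driver_id", "event_timestamp"], "a") returns
# ['a."driver_id"', 'a."event_timestamp"']; quoting each column preserves
# case-sensitive identifiers in PostgreSQL.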

def build_point_in_time_query(
    feature_view_query_contexts: List[dict],
    left_table_query_string: str,
    entity_df_event_timestamp_col: str,
    entity_df_columns: KeysView[str],
    query_template: str,
    full_feature_names: bool = False,
) -> str:
    """Build point-in-time query between each feature view table and the entity dataframe for PostgreSQL"""
    template = Environment(loader=BaseLoader()).from_string(source=query_template)

    final_output_feature_names = list(entity_df_columns)
    final_output_feature_names.extend(
        [
            (
                f'{fv["name"]}__{fv["field_mapping"].get(feature, feature)}'
                if full_feature_names
                else fv["field_mapping"].get(feature, feature)
            )
            for fv in feature_view_query_contexts
            for feature in fv["features"]
        ]
    )

    # Add additional fields to dict
    template_context = {
        "left_table_query_string": left_table_query_string,
        "entity_df_event_timestamp_col": entity_df_event_timestamp_col,
        "unique_entity_keys": set(
            [entity for fv in feature_view_query_contexts for entity in fv["entities"]]
        ),
        "featureviews": feature_view_query_contexts,
        "full_feature_names": full_feature_names,
        "final_output_feature_names": final_output_feature_names,
    }

    query = template.render(template_context)
    return query
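
# Minimal sketch (hypothetical values) of the dict each query context
# contributes to the Jinja template below; the real dicts are produced by
# offline_utils.get_feature_view_query_context and converted via asdict().
# The keys shown are the ones the template consumes.
#
#     {
#         "name": "driver_hourly_stats",
#         "entities": ["driver_id"],
#         "features": ["conv_rate"],
#         "field_mapping": {},
#         "timestamp_field": "event_timestamp",
#         "created_timestamp_column": "created",
#         "ttl": 86400,
#         "table_subquery": '(SELECT * FROM "driver_stats") AS sub',
#         "entity_selections": ['"driver_id" AS "driver_id"'],
#     }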

# Copied from the Feast Redshift offline store implementation
# Note: Keep this in sync with sdk/python/feast/infra/offline_stores/redshift.py:
# MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN
# https://github.com/feast-dev/feast/blob/master/sdk/python/feast/infra/offline_stores/redshift.py

MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """
/*
 Compute a deterministic `entity_row_unique_id` for each row of the
 `left_table_query_string` that will be used throughout all the logic
 as the field to GROUP BY the data
*/
WITH entity_dataframe AS (
    SELECT *,
        {{entity_df_event_timestamp_col}} AS entity_timestamp
        {% for featureview in featureviews %}
            {% if featureview.entities %}
            ,(
                {% for entity in featureview.entities %}
                    CAST("{{entity}}" as VARCHAR) ||
                {% endfor %}
                CAST("{{entity_df_event_timestamp_col}}" AS VARCHAR)
            ) AS "{{featureview.name}}__entity_row_unique_id"
            {% else %}
            ,CAST("{{entity_df_event_timestamp_col}}" AS VARCHAR) AS "{{featureview.name}}__entity_row_unique_id"
            {% endif %}
        {% endfor %}
    FROM {{ left_table_query_string }}
),

{% for featureview in featureviews %}

"{{ featureview.name }}__entity_dataframe" AS (
    SELECT
        {% if featureview.entities %}"{{ featureview.entities | join('", "') }}",{% endif %}
        entity_timestamp,
        "{{featureview.name}}__entity_row_unique_id"
    FROM entity_dataframe
    GROUP BY
        {% if featureview.entities %}"{{ featureview.entities | join('", "')}}",{% endif %}
        entity_timestamp,
        "{{featureview.name}}__entity_row_unique_id"
),

/*
 This query template performs the point-in-time correctness join for a single feature set table
 to the provided entity table.

 1. We first join the current feature_view to the entity dataframe that has been passed.
 This JOIN has the following logic:
    - For each row of the entity dataframe, only keep the rows where the `timestamp_field`
    is less than the one provided in the entity dataframe
    - If there is a TTL for the current feature_view, also keep the rows where the `timestamp_field`
    is higher than the one provided minus the TTL
    - For each row, join on the entity key and retrieve the `entity_row_unique_id` that has been
    computed previously

 The output of this CTE will contain all the necessary information and already filtered out most
 of the data that is not relevant.
*/

"{{ featureview.name }}__subquery" AS (
    SELECT
        "{{ featureview.timestamp_field }}" as event_timestamp,
        {{ '"' ~ featureview.created_timestamp_column ~ '" as created_timestamp,' if featureview.created_timestamp_column else '' }}
        {{ featureview.entity_selections | join(', ')}}{% if featureview.entity_selections %},{% else %}{% endif %}
        {% for feature in featureview.features %}
            "{{ feature }}" as {% if full_feature_names %}"{{ featureview.name }}__{{featureview.field_mapping.get(feature, feature)}}"{% else %}"{{ featureview.field_mapping.get(feature, feature) }}"{% endif %}{% if loop.last %}{% else %}, {% endif %}
        {% endfor %}
    FROM {{ featureview.table_subquery }} AS sub
    WHERE "{{ featureview.timestamp_field }}" <= (SELECT MAX(entity_timestamp) FROM entity_dataframe)
    {% if featureview.ttl == 0 %}{% else %}
    AND "{{ featureview.timestamp_field }}" >= (SELECT MIN(entity_timestamp) FROM entity_dataframe) - {{ featureview.ttl }} * interval '1' second
    {% endif %}
),

"{{ featureview.name }}__base" AS (
    SELECT
        subquery.*,
        entity_dataframe.entity_timestamp,
        entity_dataframe."{{featureview.name}}__entity_row_unique_id"
    FROM "{{ featureview.name }}__subquery" AS subquery
    INNER JOIN "{{ featureview.name }}__entity_dataframe" AS entity_dataframe
    ON TRUE
        AND subquery.event_timestamp <= entity_dataframe.entity_timestamp

        {% if featureview.ttl == 0 %}{% else %}
        AND subquery.event_timestamp >= entity_dataframe.entity_timestamp - {{ featureview.ttl }} * interval '1' second
        {% endif %}

        {% for entity in featureview.entities %}
        AND subquery."{{ entity }}" = entity_dataframe."{{ entity }}"
        {% endfor %}
),

/*
 2. If the `created_timestamp_column` has been set, we need to
 deduplicate the data first. This is done by calculating the
 `MAX(created_at_timestamp)` for each event_timestamp.
 We then join the data on the next CTE
*/
{% if featureview.created_timestamp_column %}
"{{ featureview.name }}__dedup" AS (
    SELECT
        "{{featureview.name}}__entity_row_unique_id",
        event_timestamp,
        MAX(created_timestamp) as created_timestamp
    FROM "{{ featureview.name }}__base"
    GROUP BY "{{featureview.name}}__entity_row_unique_id", event_timestamp
),
{% endif %}

/*
 3. The data has been filtered during the first CTE "*__base"
 Thus we only need to compute the latest timestamp of each feature.
*/
"{{ featureview.name }}__latest" AS (
    SELECT
        event_timestamp,
        {% if featureview.created_timestamp_column %}created_timestamp,{% endif %}
        "{{featureview.name}}__entity_row_unique_id"
    FROM
    (
        SELECT *,
            ROW_NUMBER() OVER(
                PARTITION BY "{{featureview.name}}__entity_row_unique_id"
                ORDER BY event_timestamp DESC{% if featureview.created_timestamp_column %},created_timestamp DESC{% endif %}
            ) AS row_number
        FROM "{{ featureview.name }}__base"
        {% if featureview.created_timestamp_column %}
            INNER JOIN "{{ featureview.name }}__dedup"
            USING ("{{featureview.name}}__entity_row_unique_id", event_timestamp, created_timestamp)
        {% endif %}
    ) AS sub
    WHERE row_number = 1
),

/*
 4. Once we know the latest value of each feature for a given timestamp,
 we can join again the data back to the original "base" dataset
*/
"{{ featureview.name }}__cleaned" AS (
    SELECT base.*
    FROM "{{ featureview.name }}__base" as base
    INNER JOIN "{{ featureview.name }}__latest"
    USING(
        "{{featureview.name}}__entity_row_unique_id",
        event_timestamp
        {% if featureview.created_timestamp_column %}
            ,created_timestamp
        {% endif %}
    )
){% if loop.last %}{% else %}, {% endif %}

{% endfor %}

/*
 Joins the outputs of multiple time travel joins to a single table.
 The entity_dataframe dataset is our source of truth here.
*/

SELECT "{{ final_output_feature_names | join('", "')}}"
FROM entity_dataframe
{% for featureview in featureviews %}
LEFT JOIN (
    SELECT
        "{{featureview.name}}__entity_row_unique_id"
        {% for feature in featureview.features %}
            ,"{% if full_feature_names %}{{ featureview.name }}__{{featureview.field_mapping.get(feature, feature)}}{% else %}{{ featureview.field_mapping.get(feature, feature) }}{% endif %}"
        {% endfor %}
    FROM "{{ featureview.name }}__cleaned"
) AS "{{featureview.name}}" USING ("{{featureview.name}}__entity_row_unique_id")
{% endfor %}
"""