Source code for feast.infra.offline_stores.bigquery

import time
import uuid
from dataclasses import asdict, dataclass
from datetime import date, datetime, timedelta
from typing import List, Optional, Set, Union

import pandas
import pyarrow
from jinja2 import BaseLoader, Environment
from pydantic import StrictStr
from pydantic.typing import Literal
from tenacity import retry, retry_if_result, stop_after_delay, wait_fixed

from feast import errors
from feast.data_source import BigQuerySource, DataSource
from feast.errors import FeastProviderLoginError
from feast.feature_view import FeatureView
from feast.infra.offline_stores.offline_store import OfflineStore
from feast.infra.provider import (
    DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
    RetrievalJob,
    _get_requested_feature_views_to_features_dict,
)
from feast.registry import Registry
from feast.repo_config import FeastConfigBaseModel, RepoConfig

try:
    from google.api_core.exceptions import NotFound
    from google.auth.exceptions import DefaultCredentialsError
    from google.cloud import bigquery
    from google.cloud.bigquery import Client, Table

except ImportError as e:
    from feast.errors import FeastExtrasDependencyImportError

    raise FeastExtrasDependencyImportError("gcp", str(e))


class BigQueryOfflineStoreConfig(FeastConfigBaseModel):
    """ Offline store config for GCP BigQuery """

    type: Literal["bigquery"] = "bigquery"
    """ Offline store type selector"""

    dataset: StrictStr = "feast"
    """ (optional) BigQuery Dataset name for temporary tables """

    project_id: Optional[StrictStr] = None
    """ (optional) GCP project name used for the BigQuery offline store """

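# Illustrative example (not part of the original module, values are hypothetical): a
# feature_store.yaml selecting this offline store might look roughly like the snippet
# below; `dataset` and `project_id` map to the fields defined above, the remaining keys
# are standard Feast repo config.
#
#     project: my_project
#     registry: data/registry.db
#     provider: gcp
#     offline_store:
#         type: bigquery
#         dataset: feast
#         project_id: my-gcp-project
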
class BigQueryOfflineStore(OfflineStore):
    @staticmethod
    def pull_latest_from_table_or_query(
        data_source: DataSource,
        join_key_columns: List[str],
        feature_name_columns: List[str],
        event_timestamp_column: str,
        created_timestamp_column: Optional[str],
        start_date: datetime,
        end_date: datetime,
    ) -> pyarrow.Table:
        assert isinstance(data_source, BigQuerySource)
        from_expression = data_source.get_table_query_string()

        partition_by_join_key_string = ", ".join(join_key_columns)
        if partition_by_join_key_string != "":
            partition_by_join_key_string = (
                "PARTITION BY " + partition_by_join_key_string
            )
        timestamps = [event_timestamp_column]
        if created_timestamp_column:
            timestamps.append(created_timestamp_column)
        timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
        field_string = ", ".join(join_key_columns + feature_name_columns + timestamps)

        query = f"""
            SELECT {field_string}
            FROM (
                SELECT {field_string},
                ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS _feast_row
                FROM {from_expression}
                WHERE {event_timestamp_column} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}')
            )
            WHERE _feast_row = 1
            """
        return BigQueryOfflineStore._pull_query(query)

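    # Illustrative example (hypothetical argument values): for join_key_columns=["driver_id"],
    # feature_name_columns=["conv_rate"], event_timestamp_column="event_timestamp" and no
    # created_timestamp_column, the f-string above renders a query shaped like:
    #
    #   SELECT driver_id, conv_rate, event_timestamp
    #   FROM (
    #       SELECT driver_id, conv_rate, event_timestamp,
    #       ROW_NUMBER() OVER(PARTITION BY driver_id ORDER BY event_timestamp DESC) AS _feast_row
    #       FROM `some_project.some_dataset.some_table`
    #       WHERE event_timestamp BETWEEN TIMESTAMP('<start_date>') AND TIMESTAMP('<end_date>')
    #   )
    #   WHERE _feast_row = 1
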
    @staticmethod
    def _pull_query(query: str) -> pyarrow.Table:
        client = _get_bigquery_client()
        query_job = client.query(query)
        return query_job.to_arrow()

    @staticmethod
    def get_historical_features(
        config: RepoConfig,
        feature_views: List[FeatureView],
        feature_refs: List[str],
        entity_df: Union[pandas.DataFrame, str],
        registry: Registry,
        project: str,
    ) -> RetrievalJob:
        # TODO: Add entity_df validation in order to fail before interacting with BigQuery
        client = _get_bigquery_client()

        expected_join_keys = _get_join_keys(project, feature_views, registry)

        assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
        dataset_project = config.offline_store.project_id or client.project
        table = _upload_entity_df_into_bigquery(
            client=client,
            project=config.project,
            dataset_name=config.offline_store.dataset,
            dataset_project=dataset_project,
            entity_df=entity_df,
        )

        entity_df_event_timestamp_col = _infer_event_timestamp_from_bigquery_query(
            table.schema
        )
        _assert_expected_columns_in_bigquery(
            expected_join_keys, entity_df_event_timestamp_col, table.schema,
        )

        # Build a query context containing all information required to template the BigQuery SQL query
        query_context = get_feature_view_query_context(
            feature_refs, feature_views, registry, project
        )

        # TODO: Infer min_timestamp and max_timestamp from entity_df
        # Generate the BigQuery SQL query from the query context
        query = build_point_in_time_query(
            query_context,
            min_timestamp=datetime.now() - timedelta(days=365),
            max_timestamp=datetime.now() + timedelta(days=1),
            left_table_query_string=str(table.reference),
            entity_df_event_timestamp_col=entity_df_event_timestamp_col,
        )

        job = BigQueryRetrievalJob(query=query, client=client, config=config)
        return job

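    # Illustrative example (hypothetical feature references and entity_df): users normally
    # do not call this method directly; it is reached through the provider when running
    # something like the following with feast.FeatureStore:
    #
    #   store = FeatureStore(repo_path=".")
    #   job = store.get_historical_features(
    #       entity_df=entity_df,  # pandas DataFrame or BigQuery SQL string
    #       feature_refs=["driver_hourly_stats:conv_rate"],
    #   )
    #   df = job.to_df()
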
def _assert_expected_columns_in_dataframe(
    join_keys: Set[str], entity_df_event_timestamp_col: str, entity_df: pandas.DataFrame
):
    entity_df_columns = set(entity_df.columns.values)
    expected_columns = join_keys.copy()
    expected_columns.add(entity_df_event_timestamp_col)

    missing_keys = expected_columns - entity_df_columns

    if len(missing_keys) != 0:
        raise errors.FeastEntityDFMissingColumnsError(expected_columns, missing_keys)


def _assert_expected_columns_in_bigquery(
    join_keys: Set[str], entity_df_event_timestamp_col: str, table_schema
):
    entity_columns = set()
    for schema_field in table_schema:
        entity_columns.add(schema_field.name)

    expected_columns = join_keys.copy()
    expected_columns.add(entity_df_event_timestamp_col)

    missing_keys = expected_columns - entity_columns

    if len(missing_keys) != 0:
        raise errors.FeastEntityDFMissingColumnsError(expected_columns, missing_keys)


def _get_join_keys(
    project: str, feature_views: List[FeatureView], registry: Registry
) -> Set[str]:
    join_keys = set()
    for feature_view in feature_views:
        entities = feature_view.entities
        for entity_name in entities:
            entity = registry.get_entity(entity_name, project)
            join_keys.add(entity.join_key)
    return join_keys


def _infer_event_timestamp_from_bigquery_query(table_schema) -> str:
    if any(
        schema_field.name == DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
        for schema_field in table_schema
    ):
        return DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
    else:
        datetime_columns = list(
            filter(
                lambda schema_field: schema_field.field_type == "TIMESTAMP",
                table_schema,
            )
        )
        if len(datetime_columns) == 1:
            print(
                f"Using {datetime_columns[0].name} as the event timestamp. To specify a column explicitly, please name it {DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL}."
            )
            return datetime_columns[0].name
        else:
            raise ValueError(
                f"Please provide an entity_df with a column named {DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL} representing the time of events."
            )


def _infer_event_timestamp_from_dataframe(entity_df: pandas.DataFrame) -> str:
    if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in entity_df.columns:
        return DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
    else:
        datetime_columns = entity_df.select_dtypes(
            include=["datetime", "datetimetz"]
        ).columns
        if len(datetime_columns) == 1:
            print(
                f"Using {datetime_columns[0]} as the event timestamp. To specify a column explicitly, please name it {DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL}."
            )
            return datetime_columns[0]
        else:
            raise ValueError(
                f"Please provide an entity_df with a column named {DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL} representing the time of events."
            )

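# Illustrative note (hypothetical column names): the two inference helpers above return
# DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL when a column with that exact name is present;
# otherwise a single timestamp/datetime column (e.g. a hypothetical "pickup_time") is
# used with a printed notice, and zero or multiple candidate columns raise a ValueError.
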
class BigQueryRetrievalJob(RetrievalJob):
    def __init__(self, query, client, config):
        self.query = query
        self.client = client
        self.config = config

    def to_df(self):
        # TODO: Ideally only start this job when the user runs "get_historical_features", not when they run to_df()
        df = self.client.query(self.query).to_dataframe(create_bqstorage_client=True)
        return df

    def to_bigquery(self, dry_run=False) -> Optional[str]:
        @retry(
            wait=wait_fixed(10),
            stop=stop_after_delay(1800),
            # Retry while the result is truthy (job still pending/running); the default
            # retry predicate only fires on exceptions, so without retry_if_result this
            # polling helper would return immediately instead of blocking.
            retry=retry_if_result(lambda still_running: still_running),
            reraise=True,
        )
        def _block_until_done():
            return self.client.get_job(bq_job.job_id).state in ["PENDING", "RUNNING"]

        today = date.today().strftime("%Y%m%d")
        rand_id = str(uuid.uuid4())[:7]
        dataset_project = self.config.offline_store.project_id or self.client.project
        path = f"{dataset_project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
        job_config = bigquery.QueryJobConfig(destination=path, dry_run=dry_run)
        bq_job = self.client.query(self.query, job_config=job_config)

        if dry_run:
            print(
                "This query will process {} bytes.".format(bq_job.total_bytes_processed)
            )
            return None

        _block_until_done()

        if bq_job.exception():
            raise bq_job.exception()

        print(f"Done writing to '{path}'.")
        return path

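    # Illustrative example (hypothetical project/dataset names): a caller might persist
    # the retrieval result with
    #
    #   table_path = job.to_bigquery()
    #   # e.g. "my-gcp-project.feast.historical_20210501_ab12cd3"
    #
    # or pass dry_run=True to only print the estimated bytes processed without writing.
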
@dataclass(frozen=True)
class FeatureViewQueryContext:
    """Context object used to template a BigQuery point-in-time SQL query"""

    name: str
    ttl: int
    entities: List[str]
    features: List[str]  # feature reference format
    table_ref: str
    event_timestamp_column: str
    created_timestamp_column: Optional[str]
    query: str
    table_subquery: str
    entity_selections: List[str]

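# Illustrative example (hypothetical feature view): get_feature_view_query_context() below
# would build a context roughly like
#   FeatureViewQueryContext(name="driver_hourly_stats", ttl=3600, entities=["driver_id"],
#                           features=["conv_rate"], ...,
#                           entity_selections=["driver_id AS driver_id"])
# for a feature view with a 1-hour TTL, one "driver_id" join key and no field mapping.
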
def _get_table_id_for_new_entity(
    client: Client, project: str, dataset_name: str, dataset_project: str
) -> str:
    """Gets the table_id for the new entity to be uploaded."""

    # First create the BigQuery dataset if it doesn't exist
    dataset = bigquery.Dataset(f"{dataset_project}.{dataset_name}")
    dataset.location = "US"

    try:
        client.get_dataset(dataset)
    except NotFound:
        # Only create the dataset if it does not exist
        client.create_dataset(dataset, exists_ok=True)

    return f"{dataset_project}.{dataset_name}.entity_df_{project}_{int(time.time())}"


def _upload_entity_df_into_bigquery(
    client: Client,
    project: str,
    dataset_name: str,
    dataset_project: str,
    entity_df: Union[pandas.DataFrame, str],
) -> Table:
    """Uploads a Pandas entity dataframe into a BigQuery table and returns the resulting table"""

    table_id = _get_table_id_for_new_entity(
        client, project, dataset_name, dataset_project
    )

    if type(entity_df) is str:
        job = client.query(f"CREATE TABLE {table_id} AS ({entity_df})")
        job.result()
    elif isinstance(entity_df, pandas.DataFrame):
        # Drop the index so that we don't have unnecessary columns
        entity_df.reset_index(drop=True, inplace=True)

        # Upload the dataframe into BigQuery, creating a temporary table
        job_config = bigquery.LoadJobConfig()
        job = client.load_table_from_dataframe(
            entity_df, table_id, job_config=job_config
        )
        job.result()
    else:
        raise ValueError(
            f"The entity dataframe you have provided must be a Pandas DataFrame or BigQuery SQL query, "
            f"but we found: {type(entity_df)} "
        )

    # Ensure that the table expires after some time
    table = client.get_table(table=table_id)
    table.expires = datetime.utcnow() + timedelta(minutes=30)
    client.update_table(table, ["expires"])

    return table

def get_feature_view_query_context(
    feature_refs: List[str],
    feature_views: List[FeatureView],
    registry: Registry,
    project: str,
) -> List[FeatureViewQueryContext]:
    """Build a query context containing all information required to template a BigQuery point-in-time SQL query"""

    feature_views_to_feature_map = _get_requested_feature_views_to_features_dict(
        feature_refs, feature_views
    )

    query_context = []
    for feature_view, features in feature_views_to_feature_map.items():
        join_keys = []
        entity_selections = []
        reverse_field_mapping = {
            v: k for k, v in feature_view.input.field_mapping.items()
        }
        for entity_name in feature_view.entities:
            entity = registry.get_entity(entity_name, project)
            join_keys.append(entity.join_key)
            join_key_column = reverse_field_mapping.get(
                entity.join_key, entity.join_key
            )
            entity_selections.append(f"{join_key_column} AS {entity.join_key}")

        if isinstance(feature_view.ttl, timedelta):
            ttl_seconds = int(feature_view.ttl.total_seconds())
        else:
            ttl_seconds = 0

        assert isinstance(feature_view.input, BigQuerySource)

        event_timestamp_column = feature_view.input.event_timestamp_column
        created_timestamp_column = feature_view.input.created_timestamp_column

        context = FeatureViewQueryContext(
            name=feature_view.name,
            ttl=ttl_seconds,
            entities=join_keys,
            features=features,
            table_ref=feature_view.input.table_ref,
            event_timestamp_column=reverse_field_mapping.get(
                event_timestamp_column, event_timestamp_column
            ),
            created_timestamp_column=reverse_field_mapping.get(
                created_timestamp_column, created_timestamp_column
            ),
            # TODO: Make created column optional and not hardcoded
            query=feature_view.input.query,
            table_subquery=feature_view.input.get_table_query_string(),
            entity_selections=entity_selections,
        )
        query_context.append(context)
    return query_context

def build_point_in_time_query(
    feature_view_query_contexts: List[FeatureViewQueryContext],
    min_timestamp: datetime,
    max_timestamp: datetime,
    left_table_query_string: str,
    entity_df_event_timestamp_col: str,
):
    """Build point-in-time query between each feature view table and the entity dataframe"""
    template = Environment(loader=BaseLoader()).from_string(
        source=SINGLE_FEATURE_VIEW_POINT_IN_TIME_JOIN
    )

    # Add additional fields to dict
    template_context = {
        "min_timestamp": min_timestamp,
        "max_timestamp": max_timestamp,
        "left_table_query_string": left_table_query_string,
        "entity_df_event_timestamp_col": entity_df_event_timestamp_col,
        "unique_entity_keys": set(
            [entity for fv in feature_view_query_contexts for entity in fv.entities]
        ),
        "featureviews": [asdict(context) for context in feature_view_query_contexts],
    }

    query = template.render(template_context)
    return query

def _get_bigquery_client():
    try:
        client = bigquery.Client()
    except DefaultCredentialsError as e:
        raise FeastProviderLoginError(
            str(e)
            + '\nIt may be necessary to run "gcloud auth application-default login" if you would like to use your '
            "local Google Cloud account"
        )
    except EnvironmentError as e:
        raise FeastProviderLoginError(
            "GCP error: "
            + str(e)
            + "\nIt may be necessary to set a default GCP project by running "
            '"gcloud config set project your-project"'
        )

    return client


# TODO: Optimizations
#   * Use GENERATE_UUID() instead of ROW_NUMBER(), or join on entity columns directly
#   * Precompute ROW_NUMBER() so that it doesn't have to be recomputed for every query on entity_dataframe
#   * Create temporary tables instead of keeping all tables in memory

SINGLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """
/*
 Compute a deterministic row key from the `left_table_query_string` that will be used throughout
 all the logic as the field to GROUP BY the data
*/
WITH entity_dataframe AS (
    SELECT
        *,
        CONCAT(
            {% for entity_key in unique_entity_keys %}
                CAST({{entity_key}} AS STRING),
            {% endfor %}
            CAST({{entity_df_event_timestamp_col}} AS STRING)
        ) AS entity_row_unique_id
    FROM {{ left_table_query_string }}
),
{% for featureview in featureviews %}
/*
 This query template performs the point-in-time correctness join for a single feature set table
 to the provided entity table.

 1. We first join the current feature_view to the entity dataframe that has been passed.
 This JOIN has the following logic:
    - For each row of the entity dataframe, only keep the rows where the `event_timestamp_column`
      is less than or equal to the one provided in the entity dataframe
    - If there is a TTL for the current feature_view, also keep only the rows where the
      `event_timestamp_column` is later than the one provided minus the TTL
    - For each row, join on the entity key and retrieve the `entity_row_unique_id` that has been
      computed previously

 The output of this CTE will contain all the necessary information and already filtered out most
 of the data that is not relevant.
*/
{{ featureview.name }}__subquery AS (
    SELECT
        {{ featureview.event_timestamp_column }} as event_timestamp,
        {{ featureview.created_timestamp_column ~ ' as created_timestamp,' if featureview.created_timestamp_column else '' }}
        {{ featureview.entity_selections | join(', ')}},
        {% for feature in featureview.features %}
            {{ feature }} as {{ featureview.name }}__{{ feature }}{% if loop.last %}{% else %}, {% endif %}
        {% endfor %}
    FROM {{ featureview.table_subquery }}
),

{{ featureview.name }}__base AS (
    SELECT
        subquery.*,
        entity_dataframe.{{entity_df_event_timestamp_col}} AS entity_timestamp,
        entity_dataframe.entity_row_unique_id
    FROM {{ featureview.name }}__subquery AS subquery
    INNER JOIN entity_dataframe
    ON TRUE
        AND subquery.event_timestamp <= entity_dataframe.{{entity_df_event_timestamp_col}}

        {% if featureview.ttl == 0 %}{% else %}
        AND subquery.event_timestamp >= Timestamp_sub(entity_dataframe.{{entity_df_event_timestamp_col}}, interval {{ featureview.ttl }} second)
        {% endif %}

        {% for entity in featureview.entities %}
        AND subquery.{{ entity }} = entity_dataframe.{{ entity }}
        {% endfor %}
),

/*
 2. If the `created_timestamp_column` has been set, we need to
 deduplicate the data first. This is done by calculating the
 `MAX(created_at_timestamp)` for each event_timestamp.
 We then join the data on the next CTE
*/
{% if featureview.created_timestamp_column %}
{{ featureview.name }}__dedup AS (
    SELECT
        entity_row_unique_id,
        event_timestamp,
        MAX(created_timestamp) as created_timestamp,
    FROM {{ featureview.name }}__base
    GROUP BY entity_row_unique_id, event_timestamp
),
{% endif %}

/*
 3. The data has been filtered during the first CTE "*__base"
 Thus we only need to compute the latest timestamp of each feature.
*/
{{ featureview.name }}__latest AS (
    SELECT
        entity_row_unique_id,
        MAX(event_timestamp) AS event_timestamp
        {% if featureview.created_timestamp_column %}
            ,ANY_VALUE(created_timestamp) AS created_timestamp
        {% endif %}
    FROM {{ featureview.name }}__base
    {% if featureview.created_timestamp_column %}
        INNER JOIN {{ featureview.name }}__dedup
        USING (entity_row_unique_id, event_timestamp, created_timestamp)
    {% endif %}
    GROUP BY entity_row_unique_id
),

/*
 4. Once we know the latest value of each feature for a given timestamp,
 we can join again the data back to the original "base" dataset
*/
{{ featureview.name }}__cleaned AS (
    SELECT base.*
    FROM {{ featureview.name }}__base as base
    INNER JOIN {{ featureview.name }}__latest
    USING(
        entity_row_unique_id,
        event_timestamp
        {% if featureview.created_timestamp_column %}
            ,created_timestamp
        {% endif %}
    )
){% if loop.last %}{% else %}, {% endif %}
{% endfor %}

/*
 Joins the outputs of multiple time travel joins to a single table.
 The entity_dataframe dataset is our source of truth here.
*/
SELECT * EXCEPT (entity_row_unique_id)
FROM entity_dataframe
{% for featureview in featureviews %}
LEFT JOIN (
    SELECT
        entity_row_unique_id,
        {% for feature in featureview.features %}
            {{ featureview.name }}__{{ feature }},
        {% endfor %}
    FROM {{ featureview.name }}__cleaned
) USING (entity_row_unique_id)
{% endfor %}
"""
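
# Illustrative note (hypothetical feature view name): rendered for a single feature view
# called "driver_hourly_stats", the template above expands into a query with CTEs
# entity_dataframe, driver_hourly_stats__subquery, driver_hourly_stats__base,
# driver_hourly_stats__dedup (only when a created timestamp column is set),
# driver_hourly_stats__latest and driver_hourly_stats__cleaned, followed by the final
# SELECT that LEFT JOINs the cleaned features back onto the entity dataframe.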