Source code for feast.infra.offline_stores.bigquery

import contextlib
import uuid
from datetime import date, datetime, timedelta
from typing import (

import numpy as np
import pandas as pd
import pyarrow
import pyarrow.parquet
from pydantic import StrictStr
from pydantic.typing import Literal
from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed

from feast import flags_helper
from feast.data_source import DataSource
from feast.errors import (
from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
from feast.infra.offline_stores import offline_utils
from feast.infra.offline_stores.offline_store import (
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.registry import Registry
from feast.repo_config import FeastConfigBaseModel, RepoConfig

from ...saved_dataset import SavedDatasetStorage
from ...usage import log_exceptions_and_usage
from .bigquery_source import BigQuerySource, SavedDatasetBigQueryStorage

    from google.api_core.exceptions import NotFound
    from google.auth.exceptions import DefaultCredentialsError
    from import bigquery
    from import Client, Table

except ImportError as e:
    from feast.errors import FeastExtrasDependencyImportError

    raise FeastExtrasDependencyImportError("gcp", str(e))

[docs]class BigQueryOfflineStoreConfig(FeastConfigBaseModel): """ Offline store config for GCP BigQuery """ type: Literal["bigquery"] = "bigquery" """ Offline store type selector""" dataset: StrictStr = "feast" """ (optional) BigQuery Dataset name for temporary tables """ project_id: Optional[StrictStr] = None """ (optional) GCP project name used for the BigQuery offline store """ location: Optional[StrictStr] = None """ (optional) GCP location name used for the BigQuery offline store. Examples of location names include ``US``, ``EU``, ``us-central1``, ``us-west4``. If a location is not specified, the location defaults to the ``US`` multi-regional location. For more information on BigQuery data locations see: """
[docs]class BigQueryOfflineStore(OfflineStore):
[docs] @staticmethod @log_exceptions_and_usage(offline_store="bigquery") def pull_latest_from_table_or_query( config: RepoConfig, data_source: DataSource, join_key_columns: List[str], feature_name_columns: List[str], event_timestamp_column: str, created_timestamp_column: Optional[str], start_date: datetime, end_date: datetime, ) -> RetrievalJob: assert isinstance(data_source, BigQuerySource) from_expression = data_source.get_table_query_string() partition_by_join_key_string = ", ".join(join_key_columns) if partition_by_join_key_string != "": partition_by_join_key_string = ( "PARTITION BY " + partition_by_join_key_string ) timestamps = [event_timestamp_column] if created_timestamp_column: timestamps.append(created_timestamp_column) timestamp_desc_string = " DESC, ".join(timestamps) + " DESC" field_string = ", ".join(join_key_columns + feature_name_columns + timestamps) client = _get_bigquery_client( project=config.offline_store.project_id, location=config.offline_store.location, ) query = f""" SELECT {field_string} {f", {repr(DUMMY_ENTITY_VAL)} AS {DUMMY_ENTITY_ID}" if not join_key_columns else ""} FROM ( SELECT {field_string}, ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS _feast_row FROM {from_expression} WHERE {event_timestamp_column} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}') ) WHERE _feast_row = 1 """ # When materializing a single feature view, we don't need full feature names. On demand transforms aren't materialized return BigQueryRetrievalJob( query=query, client=client, config=config, full_feature_names=False, )
[docs] @staticmethod @log_exceptions_and_usage(offline_store="bigquery") def pull_all_from_table_or_query( config: RepoConfig, data_source: DataSource, join_key_columns: List[str], feature_name_columns: List[str], event_timestamp_column: str, start_date: datetime, end_date: datetime, ) -> RetrievalJob: assert isinstance(data_source, BigQuerySource) from_expression = data_source.get_table_query_string() client = _get_bigquery_client( project=config.offline_store.project_id, location=config.offline_store.location, ) field_string = ", ".join( join_key_columns + feature_name_columns + [event_timestamp_column] ) query = f""" SELECT {field_string} FROM {from_expression} WHERE {event_timestamp_column} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}') """ return BigQueryRetrievalJob( query=query, client=client, config=config, full_feature_names=False, )
[docs] @staticmethod @log_exceptions_and_usage(offline_store="bigquery") def get_historical_features( config: RepoConfig, feature_views: List[FeatureView], feature_refs: List[str], entity_df: Union[pd.DataFrame, str], registry: Registry, project: str, full_feature_names: bool = False, ) -> RetrievalJob: # TODO: Add entity_df validation in order to fail before interacting with BigQuery assert isinstance(config.offline_store, BigQueryOfflineStoreConfig) client = _get_bigquery_client( project=config.offline_store.project_id, location=config.offline_store.location, ) assert isinstance(config.offline_store, BigQueryOfflineStoreConfig) table_reference = _get_table_reference_for_new_entity( client, client.project, config.offline_store.dataset, config.offline_store.location, ) entity_schema = _get_entity_schema(client=client, entity_df=entity_df,) entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( entity_schema ) entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range( entity_df, entity_df_event_timestamp_col, client, ) @contextlib.contextmanager def query_generator() -> Iterator[str]: _upload_entity_df( client=client, table_name=table_reference, entity_df=entity_df, ) expected_join_keys = offline_utils.get_expected_join_keys( project, feature_views, registry ) offline_utils.assert_expected_columns_in_entity_df( entity_schema, expected_join_keys, entity_df_event_timestamp_col ) # Build a query context containing all information required to template the BigQuery SQL query query_context = offline_utils.get_feature_view_query_context( feature_refs, feature_views, registry, project, entity_df_event_timestamp_range, ) # Generate the BigQuery SQL query from the query context query = offline_utils.build_point_in_time_query( query_context, left_table_query_string=table_reference, entity_df_event_timestamp_col=entity_df_event_timestamp_col, entity_df_columns=entity_schema.keys(), query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN, full_feature_names=full_feature_names, ) try: yield query finally: # Asynchronously clean up the uploaded Bigquery table, which will expire # if cleanup fails client.delete_table(table=table_reference, not_found_ok=True) return BigQueryRetrievalJob( query=query_generator, client=client, config=config, full_feature_names=full_feature_names, on_demand_feature_views=OnDemandFeatureView.get_requested_odfvs( feature_refs, project, registry ), metadata=RetrievalMetadata( features=feature_refs, keys=list(entity_schema.keys() - {entity_df_event_timestamp_col}), min_event_timestamp=entity_df_event_timestamp_range[0], max_event_timestamp=entity_df_event_timestamp_range[1], ), )
[docs]class BigQueryRetrievalJob(RetrievalJob): def __init__( self, query: Union[str, Callable[[], ContextManager[str]]], client: bigquery.Client, config: RepoConfig, full_feature_names: bool, on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None, metadata: Optional[RetrievalMetadata] = None, ): if not isinstance(query, str): self._query_generator = query else: @contextlib.contextmanager def query_generator() -> Iterator[str]: assert isinstance(query, str) yield query self._query_generator = query_generator self.client = client self.config = config self._full_feature_names = full_feature_names self._on_demand_feature_views = ( on_demand_feature_views if on_demand_feature_views else [] ) self._metadata = metadata @property def full_feature_names(self) -> bool: return self._full_feature_names @property def on_demand_feature_views(self) -> Optional[List[OnDemandFeatureView]]: return self._on_demand_feature_views def _to_df_internal(self) -> pd.DataFrame: with self._query_generator() as query: df = self._execute_query(query).to_dataframe(create_bqstorage_client=True) return df
[docs] def to_sql(self) -> str: """ Returns the SQL query that will be executed in BigQuery to build the historical feature table. """ with self._query_generator() as query: return query
[docs] def to_bigquery( self, job_config: bigquery.QueryJobConfig = None, timeout: int = 1800, retry_cadence: int = 10, ) -> Optional[str]: """ Triggers the execution of a historical feature retrieval query and exports the results to a BigQuery table. Runs for a maximum amount of time specified by the timeout parameter (defaulting to 30 minutes). Args: job_config: An optional bigquery.QueryJobConfig to specify options like destination table, dry run, etc. timeout: An optional number of seconds for setting the time limit of the QueryJob. retry_cadence: An optional number of seconds for setting how long the job should checked for completion. Returns: Returns the destination table name or returns None if job_config.dry_run is True. """ if not job_config: today ="%Y%m%d") rand_id = str(uuid.uuid4())[:7] path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}" job_config = bigquery.QueryJobConfig(destination=path) if not job_config.dry_run and self.on_demand_feature_views: job = self.client.load_table_from_dataframe( self.to_df(), job_config.destination ) job.result() print(f"Done writing to '{job_config.destination}'.") return str(job_config.destination) with self._query_generator() as query: self._execute_query(query, job_config, timeout) print(f"Done writing to '{job_config.destination}'.") return str(job_config.destination)
def _to_arrow_internal(self) -> pyarrow.Table: with self._query_generator() as query: return self._execute_query(query).to_arrow() @log_exceptions_and_usage def _execute_query( self, query, job_config=None, timeout: int = 1800 ) -> bigquery.job.query.QueryJob: bq_job = self.client.query(query, job_config=job_config) if job_config and job_config.dry_run: print( "This query will process {} bytes.".format(bq_job.total_bytes_processed) ) return None block_until_done(client=self.client, bq_job=bq_job, timeout=timeout) return bq_job
[docs] def persist(self, storage: SavedDatasetStorage): assert isinstance(storage, SavedDatasetBigQueryStorage) self.to_bigquery( bigquery.QueryJobConfig(destination=storage.bigquery_options.table_ref) )
@property def metadata(self) -> Optional[RetrievalMetadata]: return self._metadata
[docs]def block_until_done( client: Client, bq_job: Union[bigquery.job.query.QueryJob, bigquery.job.load.LoadJob], timeout: int = 1800, retry_cadence: float = 1, ): """ Waits for bq_job to finish running, up to a maximum amount of time specified by the timeout parameter (defaulting to 30 minutes). Args: client: A bigquery.client.Client to monitor the bq_job. bq_job: The bigquery.job.QueryJob that blocks until done runnning. timeout: An optional number of seconds for setting the time limit of the job. retry_cadence: An optional number of seconds for setting how long the job should checked for completion. Raises: BigQueryJobStillRunning exception if the function has blocked longer than 30 minutes. BigQueryJobCancelled exception to signify when that the job has been cancelled (i.e. from timeout or KeyboardInterrupt). """ # For test environments, retry more aggressively if flags_helper.is_test(): retry_cadence = 0.1 def _wait_until_done(bq_job): if client.get_job(bq_job).state in ["PENDING", "RUNNING"]: raise BigQueryJobStillRunning(job_id=bq_job.job_id) try: retryer = Retrying( wait=wait_fixed(retry_cadence), stop=stop_after_delay(timeout), retry=retry_if_exception_type(BigQueryJobStillRunning), reraise=True, ) retryer(_wait_until_done, bq_job) finally: if client.get_job(bq_job).state in ["PENDING", "RUNNING"]: client.cancel_job(bq_job) raise BigQueryJobCancelled(job_id=bq_job.job_id) if bq_job.exception(): raise bq_job.exception()
def _get_table_reference_for_new_entity( client: Client, dataset_project: str, dataset_name: str, dataset_location: Optional[str], ) -> str: """Gets the table_id for the new entity to be uploaded.""" # First create the BigQuery dataset if it doesn't exist dataset = bigquery.Dataset(f"{dataset_project}.{dataset_name}") dataset.location = dataset_location if dataset_location else "US" try: client.get_dataset(dataset) except NotFound: # Only create the dataset if it does not exist client.create_dataset(dataset, exists_ok=True) table_name = offline_utils.get_temp_entity_table_name() return f"{dataset_project}.{dataset_name}.{table_name}" def _upload_entity_df( client: Client, table_name: str, entity_df: Union[pd.DataFrame, str], ) -> Table: """Uploads a Pandas entity dataframe into a BigQuery table and returns the resulting table""" if isinstance(entity_df, str): job = client.query(f"CREATE TABLE {table_name} AS ({entity_df})") elif isinstance(entity_df, pd.DataFrame): # Drop the index so that we don't have unnecessary columns entity_df.reset_index(drop=True, inplace=True) job = client.load_table_from_dataframe(entity_df, table_name) else: raise InvalidEntityType(type(entity_df)) block_until_done(client, job) # Ensure that the table expires after some time table = client.get_table(table=table_name) table.expires = datetime.utcnow() + timedelta(minutes=30) client.update_table(table, ["expires"]) return table def _get_entity_schema( client: Client, entity_df: Union[pd.DataFrame, str] ) -> Dict[str, np.dtype]: if isinstance(entity_df, str): entity_df_sample = ( client.query(f"SELECT * FROM ({entity_df}) LIMIT 1").result().to_dataframe() ) entity_schema = dict(zip(entity_df_sample.columns, entity_df_sample.dtypes)) elif isinstance(entity_df, pd.DataFrame): entity_schema = dict(zip(entity_df.columns, entity_df.dtypes)) else: raise InvalidEntityType(type(entity_df)) return entity_schema def _get_entity_df_event_timestamp_range( entity_df: Union[pd.DataFrame, str], entity_df_event_timestamp_col: str, client: Client, ) -> Tuple[datetime, datetime]: if type(entity_df) is str: job = client.query( f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max " f"FROM ({entity_df})" ) res = next(job.result()) entity_df_event_timestamp_range = ( res.get("min"), res.get("max"), ) elif isinstance(entity_df, pd.DataFrame): entity_df_event_timestamp = entity_df.loc[ :, entity_df_event_timestamp_col ].infer_objects() if pd.api.types.is_string_dtype(entity_df_event_timestamp): entity_df_event_timestamp = pd.to_datetime( entity_df_event_timestamp, utc=True ) entity_df_event_timestamp_range = ( entity_df_event_timestamp.min().to_pydatetime(), entity_df_event_timestamp.max().to_pydatetime(), ) else: raise InvalidEntityType(type(entity_df)) return entity_df_event_timestamp_range def _get_bigquery_client(project: Optional[str] = None, location: Optional[str] = None): try: client = bigquery.Client(project=project, location=location) except DefaultCredentialsError as e: raise FeastProviderLoginError( str(e) + '\nIt may be necessary to run "gcloud auth application-default login" if you would like to use your ' "local Google Cloud account" ) except EnvironmentError as e: raise FeastProviderLoginError( "GCP error: " + str(e) + "\nIt may be necessary to set a default GCP project by running " '"gcloud config set project your-project"' ) return client # TODO: Optimizations # * Use GENERATE_UUID() instead of ROW_NUMBER(), or join on entity columns directly # * Precompute ROW_NUMBER() so that it doesn't have to be recomputed for every query on entity_dataframe # * Create temporary tables instead of keeping all tables in memory # Note: Keep this in sync with sdk/python/feast/infra/offline_stores/ MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """ /* Compute a deterministic hash for the `left_table_query_string` that will be used throughout all the logic as the field to GROUP BY the data */ WITH entity_dataframe AS ( SELECT *, {{entity_df_event_timestamp_col}} AS entity_timestamp {% for featureview in featureviews %} {% if featureview.entities %} ,CONCAT( {% for entity in featureview.entities %} CAST({{entity}} AS STRING), {% endfor %} CAST({{entity_df_event_timestamp_col}} AS STRING) ) AS {{}}__entity_row_unique_id {% else %} ,CAST({{entity_df_event_timestamp_col}} AS STRING) AS {{}}__entity_row_unique_id {% endif %} {% endfor %} FROM `{{ left_table_query_string }}` ), {% for featureview in featureviews %} {{ }}__entity_dataframe AS ( SELECT {{ featureview.entities | join(', ')}}{% if featureview.entities %},{% else %}{% endif %} entity_timestamp, {{}}__entity_row_unique_id FROM entity_dataframe GROUP BY {{ featureview.entities | join(', ')}}{% if featureview.entities %},{% else %}{% endif %} entity_timestamp, {{}}__entity_row_unique_id ), /* This query template performs the point-in-time correctness join for a single feature set table to the provided entity table. 1. We first join the current feature_view to the entity dataframe that has been passed. This JOIN has the following logic: - For each row of the entity dataframe, only keep the rows where the `event_timestamp_column` is less than the one provided in the entity dataframe - If there a TTL for the current feature_view, also keep the rows where the `event_timestamp_column` is higher the the one provided minus the TTL - For each row, Join on the entity key and retrieve the `entity_row_unique_id` that has been computed previously The output of this CTE will contain all the necessary information and already filtered out most of the data that is not relevant. */ {{ }}__subquery AS ( SELECT {{ featureview.event_timestamp_column }} as event_timestamp, {{ featureview.created_timestamp_column ~ ' as created_timestamp,' if featureview.created_timestamp_column else '' }} {{ featureview.entity_selections | join(', ')}}{% if featureview.entity_selections %},{% else %}{% endif %} {% for feature in featureview.features %} {{ feature }} as {% if full_feature_names %}{{ }}__{{featureview.field_mapping.get(feature, feature)}}{% else %}{{ featureview.field_mapping.get(feature, feature) }}{% endif %}{% if loop.last %}{% else %}, {% endif %} {% endfor %} FROM {{ featureview.table_subquery }} WHERE {{ featureview.event_timestamp_column }} <= '{{ featureview.max_event_timestamp }}' {% if featureview.ttl == 0 %}{% else %} AND {{ featureview.event_timestamp_column }} >= '{{ featureview.min_event_timestamp }}' {% endif %} ), {{ }}__base AS ( SELECT subquery.*, entity_dataframe.entity_timestamp, entity_dataframe.{{}}__entity_row_unique_id FROM {{ }}__subquery AS subquery INNER JOIN {{ }}__entity_dataframe AS entity_dataframe ON TRUE AND subquery.event_timestamp <= entity_dataframe.entity_timestamp {% if featureview.ttl == 0 %}{% else %} AND subquery.event_timestamp >= Timestamp_sub(entity_dataframe.entity_timestamp, interval {{ featureview.ttl }} second) {% endif %} {% for entity in featureview.entities %} AND subquery.{{ entity }} = entity_dataframe.{{ entity }} {% endfor %} ), /* 2. If the `created_timestamp_column` has been set, we need to deduplicate the data first. This is done by calculating the `MAX(created_at_timestamp)` for each event_timestamp. We then join the data on the next CTE */ {% if featureview.created_timestamp_column %} {{ }}__dedup AS ( SELECT {{}}__entity_row_unique_id, event_timestamp, MAX(created_timestamp) as created_timestamp FROM {{ }}__base GROUP BY {{}}__entity_row_unique_id, event_timestamp ), {% endif %} /* 3. The data has been filtered during the first CTE "*__base" Thus we only need to compute the latest timestamp of each feature. */ {{ }}__latest AS ( SELECT event_timestamp, {% if featureview.created_timestamp_column %}created_timestamp,{% endif %} {{}}__entity_row_unique_id FROM ( SELECT *, ROW_NUMBER() OVER( PARTITION BY {{}}__entity_row_unique_id ORDER BY event_timestamp DESC{% if featureview.created_timestamp_column %},created_timestamp DESC{% endif %} ) AS row_number FROM {{ }}__base {% if featureview.created_timestamp_column %} INNER JOIN {{ }}__dedup USING ({{}}__entity_row_unique_id, event_timestamp, created_timestamp) {% endif %} ) WHERE row_number = 1 ), /* 4. Once we know the latest value of each feature for a given timestamp, we can join again the data back to the original "base" dataset */ {{ }}__cleaned AS ( SELECT base.* FROM {{ }}__base as base INNER JOIN {{ }}__latest USING( {{}}__entity_row_unique_id, event_timestamp {% if featureview.created_timestamp_column %} ,created_timestamp {% endif %} ) ){% if loop.last %}{% else %}, {% endif %} {% endfor %} /* Joins the outputs of multiple time travel joins to a single table. The entity_dataframe dataset being our source of truth here. */ SELECT {{ final_output_feature_names | join(', ')}} FROM entity_dataframe {% for featureview in featureviews %} LEFT JOIN ( SELECT {{}}__entity_row_unique_id {% for feature in featureview.features %} ,{% if full_feature_names %}{{ }}__{{featureview.field_mapping.get(feature, feature)}}{% else %}{{ featureview.field_mapping.get(feature, feature) }}{% endif %} {% endfor %} FROM {{ }}__cleaned ) USING ({{}}__entity_row_unique_id) {% endfor %} """