Source code for feast.infra.offline_stores.contrib.athena_offline_store.tests.data_source

import os
import uuid
from typing import Dict, List, Optional

import pandas as pd

from feast import AthenaSource
from feast.data_source import DataSource
from feast.feature_logging import LoggingDestination
from feast.infra.offline_stores.contrib.athena_offline_store.athena import (
    AthenaOfflineStoreConfig,
)
from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import (
    AthenaLoggingDestination,
    SavedDatasetAthenaStorage,
)
from feast.infra.utils import aws_utils
from feast.repo_config import FeastConfigBaseModel
from tests.integration.feature_repos.universal.data_source_creator import (
    DataSourceCreator,
)


class AthenaDataSourceCreator(DataSourceCreator):
    """Test helper that stages DataFrames as Athena tables for integration tests.

    Connection settings are read from the ``ATHENA_*`` environment variables
    (see ``__init__``). Every table name handed out by this creator is recorded
    in ``self.tables`` so that ``teardown`` can drop it afterwards.
    """

    # Declared here for type checkers only; the list itself is created per
    # instance in ``__init__``. A class-level ``[]`` default would be shared by
    # every creator, so one creator's ``teardown`` would also drop tables that
    # belong to other creators in the same process.
    tables: List[str]

    def __init__(self, project_name: str, *args, **kwargs):
        """Build the Athena/S3 clients and the offline store config.

        Args:
            project_name: Forwarded to ``DataSourceCreator.__init__``.
            *args: Accepted for signature compatibility; unused here.
            **kwargs: Accepted for signature compatibility; unused here.
        """
        super().__init__(project_name)
        self.tables = []

        region = os.getenv("ATHENA_REGION", "ap-northeast-2")
        data_source = os.getenv("ATHENA_DATA_SOURCE", "AwsDataCatalog")
        database = os.getenv("ATHENA_DATABASE", "default")
        workgroup = os.getenv("ATHENA_WORKGROUP", "primary")
        bucket_name = os.getenv("ATHENA_S3_BUCKET_NAME", "feast-integration-tests")

        self.client = aws_utils.get_athena_data_client(region)
        self.s3 = aws_utils.get_s3_resource(region)
        self.offline_store_config = AthenaOfflineStoreConfig(
            data_source=data_source,
            region=region,
            database=database,
            workgroup=workgroup,
            s3_staging_location=f"s3://{bucket_name}/test_dir",
        )

    def create_data_source(
        self,
        df: pd.DataFrame,
        destination_name: str,
        suffix: Optional[str] = None,
        timestamp_field="ts",
        created_timestamp_column="created_ts",
        field_mapping: Optional[Dict[str, str]] = None,
    ) -> DataSource:
        """Upload ``df`` to Athena and return an ``AthenaSource`` that points at it.

        Args:
            df: Data handed to ``aws_utils.upload_df_to_athena``.
            destination_name: Used verbatim as the Athena table name and as part
                of the S3 object key.
            suffix: Accepted for interface compatibility; not used by this
                implementation.
            timestamp_field: Passed through to ``AthenaSource``.
            created_timestamp_column: Passed through to ``AthenaSource``.
            field_mapping: Passed through to ``AthenaSource``. A missing or empty
                mapping falls back to ``{"ts_1": "ts"}``.

        Returns:
            An ``AthenaSource`` bound to the configured database and data source.
        """
        config = self.offline_store_config
        table_name = destination_name

        # Layout: <staging>/<project>/<table>/<table>.parquet
        s3_target = (
            f"{config.s3_staging_location}/{self.project_name}"
            f"/{table_name}/{table_name}.parquet"
        )

        aws_utils.upload_df_to_athena(
            self.client,
            config.data_source,
            config.database,
            config.workgroup,
            self.s3,
            s3_target,
            table_name,
            df,
        )

        # Remember the table so teardown() can drop it.
        self.tables.append(table_name)

        return AthenaSource(
            table=table_name,
            timestamp_field=timestamp_field,
            created_timestamp_column=created_timestamp_column,
            field_mapping=field_mapping or {"ts_1": "ts"},
            database=config.database,
            data_source=config.data_source,
        )

    def create_saved_dataset_destination(self) -> SavedDatasetAthenaStorage:
        """Return storage for a saved dataset under a fresh, project-prefixed table name.

        The name is registered in ``self.tables`` for cleanup by ``teardown``.
        """
        table = self.get_prefixed_table_name(
            f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}"
        )
        self.tables.append(table)

        return SavedDatasetAthenaStorage(
            table_ref=table,
            database=self.offline_store_config.database,
            data_source=self.offline_store_config.data_source,
        )

    def create_logged_features_destination(self) -> LoggingDestination:
        """Return a logging destination under a fresh, project-prefixed table name.

        The name is registered in ``self.tables`` for cleanup by ``teardown``.
        """
        table = self.get_prefixed_table_name(
            f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}"
        )
        self.tables.append(table)

        return AthenaLoggingDestination(table_name=table)

    def create_offline_store_config(self) -> FeastConfigBaseModel:
        """Return the ``AthenaOfflineStoreConfig`` built in ``__init__``."""
        return self.offline_store_config

    def get_prefixed_table_name(self, suffix: str) -> str:
        """Return ``"<project_name>_<suffix>"``."""
        return f"{self.project_name}_{suffix}"

    def teardown(self):
        """Issue ``DROP TABLE IF EXISTS`` for every table this creator registered."""
        config = self.offline_store_config
        for table in self.tables:
            aws_utils.execute_athena_query(
                self.client,
                config.data_source,
                config.database,
                config.workgroup,
                f"DROP TABLE IF EXISTS {table}",
            )