Source code for feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source

import json
from typing import Callable, Dict, Iterable, Optional, Tuple

from typeguard import typechecked

from feast.data_source import DataSource
from feast.errors import DataSourceNoNameException
from feast.infra.utils.postgres.connection_utils import _get_conn
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.protos.feast.core.SavedDataset_pb2 import (
    SavedDatasetStorage as SavedDatasetStorageProto,
)
from feast.repo_config import RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.type_map import pg_type_code_to_pg_type, pg_type_to_feast_value_type
from feast.value_type import ValueType


[docs]@typechecked class PostgreSQLSource(DataSource): def __init__( self, name: Optional[str] = None, query: Optional[str] = None, table: Optional[str] = None, timestamp_field: Optional[str] = "", created_timestamp_column: Optional[str] = "", field_mapping: Optional[Dict[str, str]] = None, description: Optional[str] = "", tags: Optional[Dict[str, str]] = None, owner: Optional[str] = "", ): self._postgres_options = PostgreSQLOptions(name=name, query=query, table=table) # If no name, use the table as the default name. if name is None and table is None: raise DataSourceNoNameException() name = name or table assert name super().__init__( name=name, timestamp_field=timestamp_field, created_timestamp_column=created_timestamp_column, field_mapping=field_mapping, description=description, tags=tags, owner=owner, ) def __hash__(self): return super().__hash__() def __eq__(self, other): if not isinstance(other, PostgreSQLSource): raise TypeError( "Comparisons should only involve PostgreSQLSource class objects." ) return ( super().__eq__(other) and self._postgres_options._query == other._postgres_options._query and self.timestamp_field == other.timestamp_field and self.created_timestamp_column == other.created_timestamp_column and self.field_mapping == other.field_mapping )
[docs] @staticmethod def from_proto(data_source: DataSourceProto): assert data_source.HasField("custom_options") postgres_options = json.loads(data_source.custom_options.configuration) return PostgreSQLSource( name=postgres_options["name"], query=postgres_options["query"], table=postgres_options["table"], field_mapping=dict(data_source.field_mapping), timestamp_field=data_source.timestamp_field, created_timestamp_column=data_source.created_timestamp_column, description=data_source.description, tags=dict(data_source.tags), owner=data_source.owner, )
[docs] def to_proto(self) -> DataSourceProto: data_source_proto = DataSourceProto( name=self.name, type=DataSourceProto.CUSTOM_SOURCE, data_source_class_type="feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source.PostgreSQLSource", field_mapping=self.field_mapping, custom_options=self._postgres_options.to_proto(), description=self.description, tags=self.tags, owner=self.owner, ) data_source_proto.timestamp_field = self.timestamp_field data_source_proto.created_timestamp_column = self.created_timestamp_column return data_source_proto
[docs] def validate(self, config: RepoConfig): pass
[docs] @staticmethod def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: return pg_type_to_feast_value_type
[docs] def get_table_column_names_and_types( self, config: RepoConfig ) -> Iterable[Tuple[str, str]]: with _get_conn(config.offline_store) as conn, conn.cursor() as cur: cur.execute(f"SELECT * FROM {self.get_table_query_string()} AS sub LIMIT 0") return ( (c.name, pg_type_code_to_pg_type(c.type_code)) for c in cur.description )
[docs] def get_table_query_string(self) -> str: if self._postgres_options._table: return f"{self._postgres_options._table}" else: return f"({self._postgres_options._query})"
[docs]class PostgreSQLOptions: def __init__( self, name: Optional[str], query: Optional[str], table: Optional[str], ): self._name = name or "" self._query = query or "" self._table = table or ""
[docs] @classmethod def from_proto(cls, postgres_options_proto: DataSourceProto.CustomSourceOptions): config = json.loads(postgres_options_proto.configuration.decode("utf8")) postgres_options = cls( name=config["name"], query=config["query"], table=config["table"] ) return postgres_options
[docs] def to_proto(self) -> DataSourceProto.CustomSourceOptions: postgres_options_proto = DataSourceProto.CustomSourceOptions( configuration=json.dumps( {"name": self._name, "query": self._query, "table": self._table} ).encode() ) return postgres_options_proto
[docs]class SavedDatasetPostgreSQLStorage(SavedDatasetStorage): _proto_attr_name = "custom_storage" postgres_options: PostgreSQLOptions def __init__(self, table_ref: str): self.postgres_options = PostgreSQLOptions( table=table_ref, name=None, query=None )
[docs] @staticmethod def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage: return SavedDatasetPostgreSQLStorage( table_ref=PostgreSQLOptions.from_proto(storage_proto.custom_storage)._table )
[docs] def to_proto(self) -> SavedDatasetStorageProto: return SavedDatasetStorageProto(custom_storage=self.postgres_options.to_proto())
[docs] def to_data_source(self) -> DataSource: return PostgreSQLSource(table=self.postgres_options._table)