from datetime import datetime
from pathlib import Path
from threading import Lock
from typing import Any, List, Optional, Set, Union
from sqlalchemy import ( # type: ignore
BigInteger,
Column,
LargeBinary,
MetaData,
String,
Table,
create_engine,
delete,
insert,
select,
update,
)
from sqlalchemy.engine import Engine
from feast.base_feature_view import BaseFeatureView
from feast.data_source import DataSource
from feast.entity import Entity
from feast.errors import (
DataSourceObjectNotFoundException,
EntityNotFoundException,
FeatureServiceNotFoundException,
FeatureViewNotFoundException,
SavedDatasetNotFound,
ValidationReferenceNotFound,
)
from feast.feature_service import FeatureService
from feast.feature_view import FeatureView
from feast.infra.infra_object import Infra
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto
from feast.protos.feast.core.FeatureService_pb2 import (
FeatureService as FeatureServiceProto,
)
from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto
from feast.protos.feast.core.OnDemandFeatureView_pb2 import (
OnDemandFeatureView as OnDemandFeatureViewProto,
)
from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
from feast.protos.feast.core.RequestFeatureView_pb2 import (
RequestFeatureView as RequestFeatureViewProto,
)
from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto
from feast.protos.feast.core.StreamFeatureView_pb2 import (
StreamFeatureView as StreamFeatureViewProto,
)
from feast.protos.feast.core.ValidationProfile_pb2 import (
ValidationReference as ValidationReferenceProto,
)
from feast.registry import BaseRegistry
from feast.repo_config import RegistryConfig
from feast.request_feature_view import RequestFeatureView
from feast.saved_dataset import SavedDataset, ValidationReference
from feast.stream_feature_view import StreamFeatureView
# SQLAlchemy schema for the SQL-backed registry. Every object table keys on
# (object name, project_id) and stores the object's serialized protobuf plus
# the epoch seconds of the last write.
metadata = MetaData()

# Registered entities, one row per (entity_name, project_id).
entities = Table(
    "entities",
    metadata,
    Column("entity_name", String(50), primary_key=True),
    Column("project_id", String(50), primary_key=True),
    Column("last_updated_timestamp", BigInteger, nullable=False),
    Column("entity_proto", LargeBinary, nullable=False),
)

# Registered data sources.
data_sources = Table(
    "data_sources",
    metadata,
    Column("data_source_name", String(50), primary_key=True),
    Column("project_id", String(50), primary_key=True),
    Column("last_updated_timestamp", BigInteger, nullable=False),
    Column("data_source_proto", LargeBinary, nullable=False),
)

# Plain (batch) feature views. materialized_intervals and user_metadata are
# optional sidecar blobs.
feature_views = Table(
    "feature_views",
    metadata,
    Column("feature_view_name", String(50), primary_key=True),
    Column("project_id", String(50), primary_key=True),
    Column("last_updated_timestamp", BigInteger, nullable=False),
    Column("materialized_intervals", LargeBinary, nullable=True),
    Column("feature_view_proto", LargeBinary, nullable=False),
    Column("user_metadata", LargeBinary, nullable=True),
)

# Request feature views (request-time data only).
request_feature_views = Table(
    "request_feature_views",
    metadata,
    Column("feature_view_name", String(50), primary_key=True),
    Column("project_id", String(50), primary_key=True),
    Column("last_updated_timestamp", BigInteger, nullable=False),
    Column("feature_view_proto", LargeBinary, nullable=False),
    Column("user_metadata", LargeBinary, nullable=True),
)

# Stream feature views.
stream_feature_views = Table(
    "stream_feature_views",
    metadata,
    Column("feature_view_name", String(50), primary_key=True),
    Column("project_id", String(50), primary_key=True),
    Column("last_updated_timestamp", BigInteger, nullable=False),
    Column("feature_view_proto", LargeBinary, nullable=False),
    Column("user_metadata", LargeBinary, nullable=True),
)

# On-demand (transformation) feature views.
on_demand_feature_views = Table(
    "on_demand_feature_views",
    metadata,
    Column("feature_view_name", String(50), primary_key=True),
    Column("project_id", String(50), primary_key=True),
    Column("last_updated_timestamp", BigInteger, nullable=False),
    Column("feature_view_proto", LargeBinary, nullable=False),
    Column("user_metadata", LargeBinary, nullable=True),
)

# Feature services (named groupings of features for serving).
feature_services = Table(
    "feature_services",
    metadata,
    Column("feature_service_name", String(50), primary_key=True),
    Column("project_id", String(50), primary_key=True),
    Column("last_updated_timestamp", BigInteger, nullable=False),
    Column("feature_service_proto", LargeBinary, nullable=False),
)

# Saved datasets produced by historical retrieval.
saved_datasets = Table(
    "saved_datasets",
    metadata,
    Column("saved_dataset_name", String(50), primary_key=True),
    Column("project_id", String(50), primary_key=True),
    Column("last_updated_timestamp", BigInteger, nullable=False),
    Column("saved_dataset_proto", LargeBinary, nullable=False),
)

# Validation references used for data-quality monitoring.
validation_references = Table(
    "validation_references",
    metadata,
    Column("validation_reference_name", String(50), primary_key=True),
    Column("project_id", String(50), primary_key=True),
    Column("last_updated_timestamp", BigInteger, nullable=False),
    Column("validation_reference_proto", LargeBinary, nullable=False),
)

# Managed infrastructure state; a single sentinel row name is used per project
# (see update_infra / get_infra).
managed_infra = Table(
    "managed_infra",
    metadata,
    Column("infra_name", String(50), primary_key=True),
    Column("project_id", String(50), primary_key=True),
    Column("last_updated_timestamp", BigInteger, nullable=False),
    Column("infra_proto", LargeBinary, nullable=False),
)

# Key/value metadata per project; currently holds "last_updated_timestamp".
feast_metadata = Table(
    "feast_metadata",
    metadata,
    Column("project_id", String(50), primary_key=True),
    Column("metadata_key", String(50), primary_key=True),
    Column("metadata_value", String(50), nullable=False),
    Column("last_updated_timestamp", BigInteger, nullable=False),
)
class SqlRegistry(BaseRegistry):
    """Feast registry backed by a SQL database through SQLAlchemy.

    Objects are stored as serialized protobufs, one table per object kind
    (see the module-level Table definitions). All reads and writes go
    straight to the database, so refresh()/commit() are no-ops.
    """

    def __init__(
        self, registry_config: Optional[RegistryConfig], repo_path: Optional[Path]
    ):
        # registry_config.path is handed to SQLAlchemy as the database URL.
        # NOTE(review): repo_path is accepted for interface parity but unused here.
        assert registry_config is not None, "SqlRegistry needs a valid registry_config"
        self.engine: Engine = create_engine(registry_config.path, echo=False)
        # Create any missing tables up front.
        metadata.create_all(self.engine)
        # _refresh_lock is not used by the SqlRegistry, but is present to conform to the
        # Registry class.
        # TODO: remove external references to _refresh_lock and remove field.
        self._refresh_lock = Lock()
[docs] def teardown(self):
for t in {
entities,
data_sources,
feature_views,
feature_services,
on_demand_feature_views,
request_feature_views,
saved_datasets,
validation_references,
}:
with self.engine.connect() as conn:
stmt = delete(t)
conn.execute(stmt)
    def refresh(self):
        """No-op: this registry always reads the latest values from the db."""
        pass
[docs] def get_stream_feature_view(
self, name: str, project: str, allow_cache: bool = False
):
return self._get_object(
stream_feature_views,
name,
project,
StreamFeatureViewProto,
StreamFeatureView,
"feature_view_name",
"feature_view_proto",
FeatureViewNotFoundException,
)
[docs] def list_stream_feature_views(
self, project: str, allow_cache: bool = False
) -> List[StreamFeatureView]:
return self._list_objects(
stream_feature_views,
project,
StreamFeatureViewProto,
StreamFeatureView,
"feature_view_proto",
)
[docs] def apply_entity(self, entity: Entity, project: str, commit: bool = True):
return self._apply_object(
entities, project, "entity_name", entity, "entity_proto"
)
[docs] def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity:
return self._get_object(
entities,
name,
project,
EntityProto,
Entity,
"entity_name",
"entity_proto",
EntityNotFoundException,
)
[docs] def get_feature_view(
self, name: str, project: str, allow_cache: bool = False
) -> FeatureView:
return self._get_object(
feature_views,
name,
project,
FeatureViewProto,
FeatureView,
"feature_view_name",
"feature_view_proto",
FeatureViewNotFoundException,
)
[docs] def get_on_demand_feature_view(
self, name: str, project: str, allow_cache: bool = False
) -> OnDemandFeatureView:
return self._get_object(
on_demand_feature_views,
name,
project,
OnDemandFeatureViewProto,
OnDemandFeatureView,
"feature_view_name",
"feature_view_proto",
FeatureViewNotFoundException,
)
[docs] def get_request_feature_view(self, name: str, project: str):
return self._get_object(
request_feature_views,
name,
project,
RequestFeatureViewProto,
RequestFeatureView,
"feature_view_name",
"feature_view_proto",
FeatureViewNotFoundException,
)
[docs] def get_feature_service(
self, name: str, project: str, allow_cache: bool = False
) -> FeatureService:
return self._get_object(
feature_services,
name,
project,
FeatureServiceProto,
FeatureService,
"feature_service_name",
"feature_service_proto",
FeatureServiceNotFoundException,
)
[docs] def get_saved_dataset(
self, name: str, project: str, allow_cache: bool = False
) -> SavedDataset:
return self._get_object(
saved_datasets,
name,
project,
SavedDatasetProto,
SavedDataset,
"saved_dataset_name",
"saved_dataset_proto",
SavedDatasetNotFound,
)
[docs] def get_validation_reference(
self, name: str, project: str, allow_cache: bool = False
) -> ValidationReference:
return self._get_object(
validation_references,
name,
project,
ValidationReferenceProto,
ValidationReference,
"validation_reference_name",
"validation_reference_proto",
ValidationReferenceNotFound,
)
[docs] def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]:
return self._list_objects(
entities, project, EntityProto, Entity, "entity_proto"
)
[docs] def delete_entity(self, name: str, project: str, commit: bool = True):
return self._delete_object(
entities, name, project, "entity_name", EntityNotFoundException
)
[docs] def delete_feature_view(self, name: str, project: str, commit: bool = True):
deleted_count = 0
for table in {
feature_views,
request_feature_views,
on_demand_feature_views,
stream_feature_views,
}:
deleted_count += self._delete_object(
table, name, project, "feature_view_name", None
)
if deleted_count == 0:
raise FeatureViewNotFoundException(name, project)
[docs] def delete_feature_service(self, name: str, project: str, commit: bool = True):
return self._delete_object(
feature_services,
name,
project,
"feature_service_name",
FeatureServiceNotFoundException,
)
[docs] def get_data_source(
self, name: str, project: str, allow_cache: bool = False
) -> DataSource:
return self._get_object(
data_sources,
name,
project,
DataSourceProto,
DataSource,
"data_source_name",
"data_source_proto",
DataSourceObjectNotFoundException,
)
[docs] def list_data_sources(
self, project: str, allow_cache: bool = False
) -> List[DataSource]:
return self._list_objects(
data_sources, project, DataSourceProto, DataSource, "data_source_proto"
)
[docs] def apply_data_source(
self, data_source: DataSource, project: str, commit: bool = True
):
return self._apply_object(
data_sources, project, "data_source_name", data_source, "data_source_proto"
)
[docs] def apply_feature_view(
self, feature_view: BaseFeatureView, project: str, commit: bool = True
):
fv_table = self._infer_fv_table(feature_view)
return self._apply_object(
fv_table, project, "feature_view_name", feature_view, "feature_view_proto"
)
[docs] def apply_feature_service(
self, feature_service: FeatureService, project: str, commit: bool = True
):
return self._apply_object(
feature_services,
project,
"feature_service_name",
feature_service,
"feature_service_proto",
)
[docs] def delete_data_source(self, name: str, project: str, commit: bool = True):
with self.engine.connect() as conn:
stmt = delete(data_sources).where(
data_sources.c.data_source_name == name,
data_sources.c.project_id == project,
)
rows = conn.execute(stmt)
if rows.rowcount < 1:
raise DataSourceObjectNotFoundException(name, project)
[docs] def list_feature_services(
self, project: str, allow_cache: bool = False
) -> List[FeatureService]:
return self._list_objects(
feature_services,
project,
FeatureServiceProto,
FeatureService,
"feature_service_proto",
)
[docs] def list_feature_views(
self, project: str, allow_cache: bool = False
) -> List[FeatureView]:
return self._list_objects(
feature_views, project, FeatureViewProto, FeatureView, "feature_view_proto"
)
[docs] def list_saved_datasets(
self, project: str, allow_cache: bool = False
) -> List[SavedDataset]:
return self._list_objects(
saved_datasets,
project,
SavedDatasetProto,
SavedDataset,
"saved_dataset_proto",
)
[docs] def list_request_feature_views(
self, project: str, allow_cache: bool = False
) -> List[RequestFeatureView]:
return self._list_objects(
request_feature_views,
project,
RequestFeatureViewProto,
RequestFeatureView,
"feature_view_proto",
)
[docs] def list_on_demand_feature_views(
self, project: str, allow_cache: bool = False
) -> List[OnDemandFeatureView]:
return self._list_objects(
on_demand_feature_views,
project,
OnDemandFeatureViewProto,
OnDemandFeatureView,
"feature_view_proto",
)
[docs] def apply_saved_dataset(
self, saved_dataset: SavedDataset, project: str, commit: bool = True,
):
return self._apply_object(
saved_datasets,
project,
"saved_dataset_name",
saved_dataset,
"saved_dataset_proto",
)
[docs] def apply_validation_reference(
self,
validation_reference: ValidationReference,
project: str,
commit: bool = True,
):
return self._apply_object(
validation_references,
project,
"validation_reference_name",
validation_reference,
"validation_reference_proto",
)
[docs] def apply_materialization(
self,
feature_view: FeatureView,
project: str,
start_date: datetime,
end_date: datetime,
commit: bool = True,
):
table = self._infer_fv_table(feature_view)
python_class, proto_class = self._infer_fv_classes(feature_view)
if python_class in {RequestFeatureView, OnDemandFeatureView}:
raise ValueError(
f"Cannot apply materialization for feature {feature_view.name} of type {python_class}"
)
fv: Union[FeatureView, StreamFeatureView] = self._get_object(
table,
feature_view.name,
project,
proto_class,
python_class,
"feature_view_name",
"feature_view_proto",
FeatureViewNotFoundException,
)
fv.materialization_intervals.append((start_date, end_date))
self._apply_object(
table, project, "feature_view_name", fv, "feature_view_proto"
)
[docs] def delete_validation_reference(self, name: str, project: str, commit: bool = True):
self._delete_object(
validation_references,
name,
project,
"validation_reference_name",
ValidationReferenceNotFound,
)
[docs] def update_infra(self, infra: Infra, project: str, commit: bool = True):
self._apply_object(
managed_infra, project, "infra_name", infra, "infra_proto", name="infra_obj"
)
[docs] def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
return self._get_object(
managed_infra,
"infra_obj",
project,
InfraProto,
Infra,
"infra_name",
"infra_proto",
None,
)
def _infer_fv_table(self, feature_view):
if isinstance(feature_view, StreamFeatureView):
table = stream_feature_views
elif isinstance(feature_view, FeatureView):
table = feature_views
elif isinstance(feature_view, OnDemandFeatureView):
table = on_demand_feature_views
elif isinstance(feature_view, RequestFeatureView):
table = request_feature_views
else:
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
return table
def _infer_fv_classes(self, feature_view):
if isinstance(feature_view, StreamFeatureView):
python_class, proto_class = StreamFeatureView, StreamFeatureViewProto
elif isinstance(feature_view, FeatureView):
python_class, proto_class = FeatureView, FeatureViewProto
elif isinstance(feature_view, OnDemandFeatureView):
python_class, proto_class = OnDemandFeatureView, OnDemandFeatureViewProto
elif isinstance(feature_view, RequestFeatureView):
python_class, proto_class = RequestFeatureView, RequestFeatureViewProto
else:
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
return python_class, proto_class
    def proto(self) -> RegistryProto:
        """Serialize the whole registry (all projects) into one RegistryProto.

        Iterates every project found in the database and extends the proto's
        repeated fields with each object kind. The returned proto's
        last_updated is the max of the per-project last-updated timestamps.
        """
        r = RegistryProto()
        last_updated_timestamps = []
        projects = self._get_all_projects()
        for project in projects:
            # Pair each list_* method with the proto field it populates.
            # NOTE(review): list_validation_references is not defined in this
            # file chunk — presumably provided elsewhere in the class.
            for lister, registry_proto_field in [
                (self.list_entities, r.entities),
                (self.list_feature_views, r.feature_views),
                (self.list_data_sources, r.data_sources),
                (self.list_on_demand_feature_views, r.on_demand_feature_views),
                (self.list_request_feature_views, r.request_feature_views),
                (self.list_stream_feature_views, r.stream_feature_views),
                (self.list_feature_services, r.feature_services),
                (self.list_saved_datasets, r.saved_datasets),
                (self.list_validation_references, r.validation_references),
            ]:
                objs: List[Any] = lister(project)  # type: ignore
                if objs:
                    registry_proto_field.extend([obj.to_proto() for obj in objs])
            # This is suuuper jank. Because of https://github.com/feast-dev/feast/issues/2783,
            # the registry proto only has a single infra field, which we're currently setting as the "last" project.
            r.infra.CopyFrom(self.get_infra(project).to_proto())
            last_updated_timestamps.append(self._get_last_updated_metadata(project))
        if last_updated_timestamps:
            r.last_updated.FromDatetime(max(last_updated_timestamps))
        return r
    def commit(self):
        """No-op: every write is applied to the database eagerly."""
        pass
def _apply_object(
self, table, project: str, id_field_name, obj, proto_field_name, name=None
):
name = name or obj.name
with self.engine.connect() as conn:
stmt = select(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
)
row = conn.execute(stmt).first()
update_datetime = datetime.utcnow()
update_time = int(update_datetime.timestamp())
if hasattr(obj, "last_updated_timestamp"):
obj.last_updated_timestamp = update_datetime
if row:
values = {
proto_field_name: obj.to_proto().SerializeToString(),
"last_updated_timestamp": update_time,
}
update_stmt = (
update(table)
.where(getattr(table.c, id_field_name) == name)
.values(values,)
)
conn.execute(update_stmt)
else:
values = {
id_field_name: name,
proto_field_name: obj.to_proto().SerializeToString(),
"last_updated_timestamp": update_time,
"project_id": project,
}
insert_stmt = insert(table).values(values,)
conn.execute(insert_stmt)
self._set_last_updated_metadata(update_datetime, project)
def _delete_object(self, table, name, project, id_field_name, not_found_exception):
with self.engine.connect() as conn:
stmt = delete(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
)
rows = conn.execute(stmt)
if rows.rowcount < 1 and not_found_exception:
raise not_found_exception(name, project)
self._set_last_updated_metadata(datetime.utcnow(), project)
return rows.rowcount
def _get_object(
self,
table,
name,
project,
proto_class,
python_class,
id_field_name,
proto_field_name,
not_found_exception,
):
with self.engine.connect() as conn:
stmt = select(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
)
row = conn.execute(stmt).first()
if row:
_proto = proto_class.FromString(row[proto_field_name])
return python_class.from_proto(_proto)
raise not_found_exception(name, project)
def _list_objects(
self, table, project, proto_class, python_class, proto_field_name
):
with self.engine.connect() as conn:
stmt = select(table).where(table.c.project_id == project)
rows = conn.execute(stmt).all()
if rows:
return [
python_class.from_proto(
proto_class.FromString(row[proto_field_name])
)
for row in rows
]
return []
def _set_last_updated_metadata(self, last_updated: datetime, project: str):
with self.engine.connect() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key == "last_updated_timestamp",
feast_metadata.c.project_id == project,
)
row = conn.execute(stmt).first()
update_time = int(last_updated.timestamp())
values = {
"metadata_key": "last_updated_timestamp",
"metadata_value": f"{update_time}",
"last_updated_timestamp": update_time,
"project_id": project,
}
if row:
update_stmt = (
update(feast_metadata)
.where(
feast_metadata.c.metadata_key == "last_updated_timestamp",
feast_metadata.c.project_id == project,
)
.values(values)
)
conn.execute(update_stmt)
else:
insert_stmt = insert(feast_metadata).values(values,)
conn.execute(insert_stmt)
def _get_last_updated_metadata(self, project: str):
with self.engine.connect() as conn:
stmt = select(feast_metadata).where(
feast_metadata.c.metadata_key == "last_updated_timestamp",
feast_metadata.c.project_id == project,
)
row = conn.execute(stmt).first()
if not row:
return None
update_time = int(row["last_updated_timestamp"])
return datetime.utcfromtimestamp(update_time)
def _get_all_projects(self) -> Set[str]:
projects = set()
with self.engine.connect() as conn:
for table in {
entities,
data_sources,
feature_views,
request_feature_views,
on_demand_feature_views,
stream_feature_views,
}:
stmt = select(table)
rows = conn.execute(stmt).all()
for row in rows:
projects.add(row["project_id"])
return projects