# Source code for feast.infra.online_stores.redis

# Copyright 2021 The Feast Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from datetime import datetime, timezone
from enum import Enum
from typing import (
    Any,
    ByteString,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

from google.protobuf.timestamp_pb2 import Timestamp
from pydantic import StrictStr
from pydantic.typing import Literal

from feast import Entity, FeatureView, RepoConfig, utils
from feast.infra.online_stores.helpers import _mmh3, _redis_key, _redis_key_prefix
from feast.infra.online_stores.online_store import OnlineStore
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import FeastConfigBaseModel
from feast.usage import log_exceptions_and_usage, tracing_span

try:
    from redis import Redis
    from redis.cluster import ClusterNode, RedisCluster
except ImportError as e:
    from feast.errors import FeastExtrasDependencyImportError

    raise FeastExtrasDependencyImportError("redis", str(e))

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)

class RedisType(str, Enum):
    """Kind of Redis deployment the online store talks to."""

    # Standalone server; the store builds a ``Redis`` client for this value.
    redis = "redis"
    # Cluster deployment; the store builds a ``RedisCluster`` client for this value.
    redis_cluster = "redis_cluster"
class RedisOnlineStoreConfig(FeastConfigBaseModel):
    """Configuration model for the Redis online store."""

    #: Online store type selector; must be "redis" for this store.
    type: Literal["redis"] = "redis"

    #: Deployment kind: RedisType.redis (standalone) or RedisType.redis_cluster.
    redis_type: RedisType = RedisType.redis

    #: Comma-separated connection string holding host(s), port(s) and client
    #: parameters, formatted as ``host:port,parameter1,parameter2``,
    #: e.g. ``redis:6379,db=0``.
    connection_string: StrictStr = "localhost:6379"
class RedisOnlineStore(OnlineStore):
    """Online store backed by a standalone Redis server or a Redis Cluster.

    Storage layout (as written by ``online_write_batch``):
      * one Redis hash per entity key, stored under ``_redis_key(project, entity_key)``;
      * inside that hash, each feature value is a serialized ``ValueProto`` stored
        under the field ``_mmh3(f"{feature_view}:{feature_name}")``;
      * the row's event time is a serialized ``Timestamp`` (whole seconds) stored
        under the field ``f"_ts:{feature_view}"``.
    """

    # Created lazily by ``_get_client`` and reused on later calls.
    _client: Optional[Union[Redis, RedisCluster]] = None

    def delete_table_values(self, config: RepoConfig, table: FeatureView):
        """Delete all keys of ``config.project`` that match ``table``'s entity prefix.

        Matching keys are found with ``scan_iter`` using the pattern
        ``<_redis_key_prefix(table.entities)>*<project>`` and deleted through one
        pipeline.

        NOTE(review): the whole hash is deleted. Any other feature view written
        for the same entity keys loses its fields too, because the Redis key
        does not include the feature view name.
        """
        client = self._get_client(config.online_store)
        prefix = _redis_key_prefix(table.entities)
        pattern = b"".join([prefix, b"*", config.project.encode("utf8")])

        deleted_count = 0
        with client.pipeline() as pipeline:
            for key in client.scan_iter(pattern):
                pipeline.delete(key)
                deleted_count += 1
            pipeline.execute()

        logger.debug("Deleted %d keys for %s", deleted_count, table.name)

    @log_exceptions_and_usage(online_store="redis")
    def update(
        self,
        config: RepoConfig,
        tables_to_delete: Sequence[FeatureView],
        tables_to_keep: Sequence[FeatureView],
        entities_to_delete: Sequence[Entity],
        entities_to_keep: Sequence[Entity],
        partial: bool,
    ):
        """
        We delete the keys in redis for tables/views being removed.

        Only ``tables_to_delete`` is acted on; the remaining arguments are
        accepted for interface compatibility and ignored here.
        """
        for table in tables_to_delete:
            self.delete_table_values(config, table)

    def teardown(
        self,
        config: RepoConfig,
        tables: Sequence[FeatureView],
        entities: Sequence[Entity],
    ):
        """
        We delete the keys in redis for tables/views being removed.

        Every table in ``tables`` is cleared; ``entities`` is ignored here.
        """
        for table in tables:
            self.delete_table_values(config, table)

    @staticmethod
    def _parse_connection_string(connection_string: str):
        """
        Reads Redis connections string using format
        for RedisCluster:
            redis1:6379,redis2:6379,decode_responses=true,skip_full_coverage_check=true,ssl=true,password=...
        for Redis:
            redis_master:6379,db=0,ssl=true,password=...

        Returns:
            ``(startup_nodes, params)``.

            ``startup_nodes`` is a list of ``{"host": ..., "port": ...}`` dicts
            (string values), one per comma-separated item without ``=``.

            ``params`` maps each ``key=value`` item to its value. The value is
            JSON-decoded when possible (``true`` -> ``True``, ``0`` -> ``0``) and
            kept as the raw string otherwise. A repeated key keeps the last value.
        """
        startup_nodes = []
        params: Dict[str, Any] = {}
        for item in connection_string.split(","):
            if "=" in item:
                key, raw_value = item.split("=", 1)
                try:
                    params[key] = json.loads(raw_value)
                except json.JSONDecodeError:
                    # Not JSON (e.g. a plain-text password): keep it verbatim.
                    params[key] = raw_value
            else:
                startup_nodes.append(dict(zip(["host", "port"], item.split(":"))))
        return startup_nodes, params

    def _get_client(self, online_store_config: RedisOnlineStoreConfig):
        """
        Creates the Redis client RedisCluster or Redis depending on configuration

        The client is created on the first call and cached on ``self._client``.
        Later calls return the cached client even if a different config is
        passed.
        """
        if self._client is None:
            startup_nodes, kwargs = self._parse_connection_string(
                online_store_config.connection_string
            )
            if online_store_config.redis_type == RedisType.redis_cluster:
                # redis-py's RedisCluster expects ClusterNode objects, not dicts.
                kwargs["startup_nodes"] = [
                    ClusterNode(host=node["host"], port=int(node["port"]))
                    for node in startup_nodes
                ]
                self._client = RedisCluster(**kwargs)
            else:
                # Standalone mode: only the first host:port entry is used.
                kwargs["host"] = startup_nodes[0]["host"]
                kwargs["port"] = startup_nodes[0]["port"]
                self._client = Redis(**kwargs)
        return self._client

    @log_exceptions_and_usage(online_store="redis")
    def online_write_batch(
        self,
        config: RepoConfig,
        table: FeatureView,
        data: List[
            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
        ],
        progress: Optional[Callable[[int], Any]],
    ) -> None:
        """Write feature rows for ``table`` into Redis.

        A row is skipped when a timestamp is already stored for this feature
        view and the row's event time (whole seconds, UTC epoch) is not newer.

        Args:
            config: Repo config; its ``online_store`` must be a RedisOnlineStoreConfig.
            table: Feature view being written; its name namespaces the hash fields.
            data: ``(entity_key, {feature_name: value}, event_ts, created_ts)``
                tuples. ``created_ts`` is not used here.
            progress: Optional callback. It is called with 1 for each skipped row,
                then once with the number of rows actually written.
        """
        online_store_config = config.online_store
        assert isinstance(online_store_config, RedisOnlineStoreConfig)

        client = self._get_client(online_store_config)
        project = config.project

        feature_view = table.name
        ts_key = f"_ts:{feature_view}"
        keys = []
        # redis pipelining optimization: send multiple commands to redis server without waiting for every reply
        with client.pipeline() as pipe:
            # check if a previous record under the key bin exists
            # TODO: investigate if check and set is a better approach rather than pulling all entity ts and then setting
            # it may be significantly slower but avoids potential (rare) race conditions
            for entity_key, _, _, _ in data:
                redis_key_bin = _redis_key(project, entity_key)
                keys.append(redis_key_bin)
                pipe.hmget(redis_key_bin, ts_key)
            prev_event_timestamps = pipe.execute()
            # flattening the list of lists. `hmget` does the lookup assuming a list of keys in the key bin
            prev_event_timestamps = [i[0] for i in prev_event_timestamps]

            for redis_key_bin, prev_event_time, (_, values, timestamp, _) in zip(
                keys, prev_event_timestamps, data
            ):
                event_time_seconds = int(utils.make_tzaware(timestamp).timestamp())

                # ignore if event_timestamp is before the event features that are currently in the feature store
                if prev_event_time:
                    prev_ts = Timestamp()
                    prev_ts.ParseFromString(prev_event_time)
                    if prev_ts.seconds and event_time_seconds <= prev_ts.seconds:
                        # TODO: somehow signal that it's not overwriting the current record?
                        if progress:
                            progress(1)
                        continue

                ts = Timestamp()
                ts.seconds = event_time_seconds
                entity_hset = dict()
                entity_hset[ts_key] = ts.SerializeToString()

                for feature_name, val in values.items():
                    f_key = _mmh3(f"{feature_view}:{feature_name}")
                    entity_hset[f_key] = val.SerializeToString()

                pipe.hset(redis_key_bin, mapping=entity_hset)
                # TODO: support expiring the entity / features in Redis
                # otherwise entity features remain in redis until cleaned up in separate process
                # client.expire redis_key_bin based a ttl setting
            results = pipe.execute()
            if progress:
                progress(len(results))

    @log_exceptions_and_usage(online_store="redis")
    def online_read(
        self,
        config: RepoConfig,
        table: FeatureView,
        entity_keys: List[EntityKeyProto],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Read features of ``table`` for each entity key, in input order.

        Args:
            config: Repo config; its ``online_store`` must be a RedisOnlineStoreConfig.
            table: Feature view to read from.
            entity_keys: Entity keys to look up.
            requested_features: Feature names to fetch. ``None`` or an empty list
                means all features of ``table``. The caller's list is not modified.

        Returns:
            One ``(event_timestamp, {feature_name: ValueProto})`` tuple per entity
            key. The timestamp is a naive UTC datetime with whole-second
            precision. Fields absent in Redis come back as empty ``ValueProto``
            objects. NOTE: an entity with no stored data therefore yields the
            epoch timestamp with empty values rather than ``(None, None)``.
        """
        online_store_config = config.online_store
        assert isinstance(online_store_config, RedisOnlineStoreConfig)

        client = self._get_client(online_store_config)
        feature_view = table.name
        project = config.project

        result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

        if not requested_features:
            requested_features = [f.name for f in table.features]

        ts_key = f"_ts:{feature_view}"
        hset_keys = [_mmh3(f"{feature_view}:{k}") for k in requested_features]
        hset_keys.append(ts_key)
        # Build a new list so the caller's requested_features is not mutated;
        # the ts key lines up with the last hset key.
        requested_features = [*requested_features, ts_key]

        keys = [_redis_key(project, entity_key) for entity_key in entity_keys]

        with client.pipeline() as pipe:
            for redis_key_bin in keys:
                pipe.hmget(redis_key_bin, hset_keys)

            with tracing_span(name="remote_call"):
                redis_values = pipe.execute()

        for values in redis_values:
            features = self._get_features_for_entity(
                values, feature_view, requested_features
            )
            result.append(features)

        return result

    def _get_features_for_entity(
        self,
        values: List[ByteString],
        feature_view: str,
        requested_features: List[str],
    ) -> Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]:
        """Decode one ``hmget`` reply into ``(timestamp, {feature_name: ValueProto})``.

        ``requested_features`` must line up with ``values`` and must include the
        ``f"_ts:{feature_view}"`` entry. That entry is removed from the returned
        mapping and decoded as the row timestamp.
        """
        res_val = dict(zip(requested_features, values))

        res_ts = Timestamp()
        ts_val = res_val.pop(f"_ts:{feature_view}")
        if ts_val:
            res_ts.ParseFromString(ts_val)

        res = {}
        for feature_name, val_bin in res_val.items():
            val = ValueProto()
            if val_bin:
                val.ParseFromString(val_bin)
            res[feature_name] = val

        if not res:
            return None, None
        # Timestamps are written as UTC epoch seconds; convert back to a naive UTC
        # datetime (fromtimestamp without tz would yield host-local wall time).
        timestamp = datetime.fromtimestamp(res_ts.seconds, tz=timezone.utc).replace(
            tzinfo=None
        )
        return timestamp, res