# -*- coding: utf-8 -*-
from abc import ABCMeta
import logging
from typing import Type, Optional, Sequence
import pytray.pretty
from . import depositors
from . import exceptions
from . import expr
from . import fields
from . import migrations
from . import saving
from . import tracking
from . import types
__all__ = "TypeHelper", "WrapperHelper", "BaseHelper"
logger = logging.getLogger(__name__)
def inject_creation_tracking(cls: Type):
# Check to make sure we don't do this twice!
if not hasattr(cls, "__orig_new"):
cls.__orig_new = cls.__new__ # pylint: disable=protected-access
def new(_cls, *_args, **_kwargs):
inst = cls.__orig_new(_cls) # pylint: disable=protected-access
tracking.obj_created(inst)
return inst
cls.__new__ = new
def remove_creation_tracking(cls: Type):
try:
cls.__new__ = cls.__orig_new # pylint: disable=protected-access
except AttributeError:
pass
[docs]class TypeHelper(fields.WithFields):
"""This interface provides the basic methods necessary to enable a type to be compatible with
the historian."""
#: The type this helper corresponds to
TYPE: Type = None
TYPE_ID = None # The unique id for this type of object
IMMUTABLE = False # If set to true then the object is decoded straight away
INJECT_CREATION_TRACKING = False
# The latest migration, if there is one
LATEST_MIGRATION = None # type: migrations.ObjectMigration
@classmethod
def init_field(cls, field: fields.Field, attr_name: str):
super().init_field(field, attr_name)
field.set_query_context(expr.Comparison("type_id", expr.Eq(cls.TYPE_ID)))
field.path_prefix = "state"
def __init__(self):
assert (
self.TYPE is not None
), "Must set the TYPE to a type of or a tuple of types"
if self.INJECT_CREATION_TRACKING:
inject_creation_tracking(self.TYPE)
[docs] def new(self, encoded_saved_state): # pylint: disable=unused-argument
"""Create a new blank object of this type"""
cls = self.TYPE
return cls.__new__(cls)
[docs] def yield_hashables(self, obj: object, hasher):
"""Yield values from this object that should be included in its hash"""
yield from hasher.yield_hashables(saving.save_instance_state(obj, type(self)))
[docs] def eq(self, one, other) -> bool: # pylint: disable=invalid-name
"""Determine if two objects are equal"""
if not isinstance(one, self.TYPE) or not isinstance(
other, self.TYPE
): # pylint: disable=isinstance-second-argument-not-valid-type
return False
return saving.save_instance_state(
one, type(self)
) == saving.save_instance_state(other, type(self))
[docs] def save_instance_state(self, obj, saver): # pylint: disable=unused-argument
"""Save the instance state of an object, should return a saved instance"""
return saving.save_instance_state(obj, type(self))
[docs] def load_instance_state(
self, obj, saved_state, loader: depositors.Loader
): # pylint: disable=unused-argument
"""Take the given blank object and load the instance state into it"""
saving.load_instance_state(obj, saved_state, type(self))
[docs] def get_version(self) -> Optional[int]:
"""Gets the version of the latest migration, returns None if there is not migration"""
if self.LATEST_MIGRATION is None:
return None
version = self.LATEST_MIGRATION.VERSION
if version is None:
raise RuntimeError(
f"Object '{self.TYPE}' has a migration ({self.LATEST_MIGRATION}) which has no version number"
)
return version
[docs] def ensure_up_to_date(
self, saved_state, version: Optional[int], loader: depositors.Loader
):
"""Apply any migrations that are necessary to this saved state. If no migrations are
necessary then None is returned"""
latest_version = (
None if self.LATEST_MIGRATION is None else self.LATEST_MIGRATION.VERSION
)
if latest_version == version:
return None
if latest_version is None or (version is not None and latest_version < version):
raise exceptions.VersionError(
f"This codebase's version of '{pytray.pretty.type_string(self.TYPE)}' is older "
f"({latest_version}) than the saved version ({version}). Check for updates."
)
to_apply = self._get_migrations(version)
if not to_apply:
return None
total = len(to_apply)
logger.info(
"Migrating saved state of '%s' from version %s to %i (%i migrations to apply)",
pytray.pretty.type_string(self.TYPE),
version,
self.get_version(),
total,
)
for i, migration in enumerate(to_apply):
saved_state = migration.upgrade(saved_state, loader)
logger.info(
"Migration '%s' applied (%i/%i)",
pytray.pretty.type_string(migration),
i + 1,
total,
)
logger.info(
"Migration of '%s' completed successfully",
pytray.pretty.type_string(self.TYPE),
)
return saved_state
def _get_migrations(
self, version: Optional[int]
) -> Sequence[migrations.ObjectMigration]:
"""Get the sequence of migrations that needs to be applied to a given version"""
if self.LATEST_MIGRATION is None:
return [] # No migrations we can apply
to_apply = []
current = self.LATEST_MIGRATION
while version is None or version < current.VERSION:
to_apply.append(current)
current = current.PREVIOUS
if current is None:
break
to_apply.reverse()
return to_apply
[docs]class BaseHelper(TypeHelper, metaclass=ABCMeta):
"""A base helper that defaults to yielding hashables directly on the object
and testing for equality using == given two objects. This behaviour is fairly
standard and therefore most type helpers will want to subclass from this class."""
[docs] def yield_hashables(self, obj, hasher):
yield from hasher.yield_hashables(obj)
[docs] def eq(self, one, other) -> bool:
return one == other
[docs]class WrapperHelper(TypeHelper):
"""Wraps up an object type to perform the necessary Historian actions"""
# pylint: disable=invalid-name
def __init__(self, obj_type: Type[types.SavableObject]):
self.TYPE = obj_type
self.TYPE_ID = obj_type.TYPE_ID
self.LATEST_MIGRATION = obj_type.LATEST_MIGRATION
super().__init__()
[docs] def yield_hashables(self, obj, hasher):
yield from self.TYPE.yield_hashables(obj, hasher)
[docs] def eq(self, one, other) -> bool:
return self.TYPE.__eq__(one, other) # pylint: disable=unnecessary-dunder-call
[docs] def save_instance_state(self, obj: types.Savable, saver):
return self.TYPE.save_instance_state(obj, saver)
[docs] def load_instance_state(self, obj, saved_state: types.Savable, loader):
self.TYPE.load_instance_state(obj, saved_state, loader)