# -*- coding: utf-8 -*-
"""This module contains various strategies for loading, saving and migrating objects in the archive
"""
from abc import ABCMeta, abstractmethod
import contextlib
import logging
from typing import Optional, Dict, Any, Iterable, Sequence
import deprecation
from pytray import tree
import mincepy
from . import archives
from . import exceptions
from . import operations
from . import records
from . import staging
from . import transactions # pylint: disable=unused-import
from . import version as version_mod
# Names exported by a star-import of this module
__all__ = ("Saver", "Loader", "SnapshotLoader", "LiveDepositor", "Migrator")

# Module-level logger, named after this module
logger = logging.getLogger(__name__)

# Built-in container types that Loader.decode unpacks recursively when there is no schema entry
# for them
CONTAINERS = (list, dict)
class Base(metaclass=ABCMeta):
    """Shared foundation for the loader and saver depositors.

    Stores the owning historian and exposes it, together with its archive, both through getter
    methods and through equivalent read-only properties.
    """

    def __init__(self, historian):
        self._historian: "mincepy.Historian" = historian

    def get_historian(self) -> "mincepy.Historian":
        """Get the owning historian"""
        return self._historian

    def get_archive(self) -> archives.Archive:
        """Get the archive of the owning historian"""
        return self._historian.archive

    # Read-only attribute-style access to the same values.  The properties are bound to the
    # getters defined on this class, so they read ``self._historian`` directly.
    historian = property(get_historian)
    archive = property(get_archive)
class Saver(Base, metaclass=ABCMeta):
    """A depositor that knows how to save records to the archive"""

    # NOTE(review): class-level dict, hence shared between all instances.  Nothing in this module
    # reads or writes it; confirm whether external code relies on it before removing.
    _extras: Dict[str, Dict] = {}

    @deprecation.deprecated(
        deprecated_in="0.14.2",
        removed_in="0.16.0",
        current_version=version_mod.__version__,
        details="Use get_snapshot_id() instead",
    )
    def ref(self, obj) -> records.SnapshotId:
        """Get a persistent reference for the given object.

        Deprecated alias that simply forwards to :meth:`get_snapshot_id`.
        """
        return self.get_snapshot_id(obj)

    @abstractmethod
    def get_snapshot_id(self, obj) -> records.SnapshotId:
        """Get a persistent reference for the given object"""

    def encode(self, obj, schema=None, path=()):
        """Encode an object for archiving.

        :param obj: the object to encode
        :param schema: an optional list; if supplied, a ``[path, type_id]`` entry (followed by the
            helper's version when that is not ``None``) is appended for every non-primitive
            object encountered
        :param path: the path of ``obj`` within the state being encoded
        :return: the encoded state
        :raises RuntimeError: if a type helper returns a saved state that the historian does not
            consider primitive
        """
        historian = self.get_historian()

        if historian.is_primitive(obj):
            # Deal with the special containers by encoding their values if need be
            return tree.transform(self.encode, obj, path, schema=schema)

        # Store by value: ask the type helper for the state of the object
        helper = historian.get_helper(type(obj), auto_register=True)
        save_state = helper.save_instance_state(obj, self)
        if not historian.is_primitive(save_state):
            raise RuntimeError("Saved state must be one of the primitive types")

        schema_entry = [path, helper.TYPE_ID]
        version = helper.get_version()
        if version is not None:
            schema_entry.append(version)

        if schema is not None:
            schema.append(schema_entry)

        # The saved state may itself contain objects that need encoding
        return self.encode(save_state, schema, path)

    def save_state(self, obj) -> dict:
        """Save the state of an object and return the encoded state ready to be archived in a
        record.

        :return: a dictionary holding the encoded state under ``records.STATE`` and the collected
            schema entries under ``records.STATE_TYPES``
        """
        schema = []
        saved_state = self.encode(obj, schema)
        return {records.STATE: saved_state, records.STATE_TYPES: schema}

    def set_extras(self, namespace: str, extras):
        """Hook accepting extras for a namespace.  This implementation does nothing."""
class Loader(Base, metaclass=ABCMeta):
    """A loader that knows how to load objects from the archive"""

    @abstractmethod
    def load(self, snapshot_id: records.SnapshotId):
        """Load an object"""

    def decode(
        self,
        encoded,
        schema: Optional[records.StateSchema] = None,
        path=(),
        created_callback=None,
        updates=None,
    ):
        """Given the encoded state and an optional schema that defines the type of the encoded
        objects this method will decode the saved state and load the object.

        :param encoded: the encoded state
        :param schema: a mapping from path to schema entry (an object with ``type_id`` and
            ``version`` attributes).  If ``None`` the state is treated as containing no typed
            objects, i.e. only containers are unpacked.
        :param path: the path of ``encoded`` within the state being decoded
        :param created_callback: optional callable invoked as ``created_callback(path, new_obj)``
            for each object right after it has been created by its helper
        :param updates: optional dictionary that is filled with ``path -> updated state`` for every
            object whose saved state had to be brought up to date by its helper, including
            objects nested inside containers and immutable objects
        :return: the decoded object
        :raises RuntimeError: if a helper fails to create an instance from the saved state
        """
        if schema is None:
            # An absent schema means there are no typed entries; use an empty mapping so the
            # lookup below raises KeyError rather than failing with a TypeError on None
            schema = {}

        try:
            entry = schema[path]
        except KeyError:
            # There is no schema entry so this is a primitive type and only containers need to
            # (potentially) be decoded further
            if isinstance(encoded, CONTAINERS):
                # Pass `updates` on so that migrations of objects nested in the container are
                # reported to the caller
                return self._recursive_unpack(
                    encoded, schema, path, created_callback, updates
                )

            # Fully decoded
            return encoded

        saved_state = encoded
        helper = self.get_historian().get_helper(entry.type_id)

        if helper.IMMUTABLE:
            # Immutable objects get their fully unpacked state at creation time
            saved_state = self._recursive_unpack(
                encoded, schema, path, created_callback, updates
            )

        new_obj = helper.new(saved_state)
        if new_obj is None:
            raise RuntimeError(
                f"Helper '{helper.__class__}' failed to create a class given state '{saved_state}'"
            )

        if created_callback is not None:
            created_callback(path, new_obj)

        if not helper.IMMUTABLE:
            # Mutable objects have their state unpacked only after creation (and after the
            # callback has been told about them)
            saved_state = self._recursive_unpack(
                encoded, schema, path, created_callback, updates
            )

        updated = helper.ensure_up_to_date(saved_state, entry.version, self)
        if updated is not None:
            # Use the current version of the record
            saved_state = updated
            if updates is not None:
                updates[path] = updated

        helper.load_instance_state(new_obj, saved_state, self)
        return new_obj

    def _recursive_unpack(
        self,
        encoded_saved_state,
        schema: Optional[records.StateSchema] = None,
        path=(),
        created_callback=None,
        updates=None,
    ):
        """Unpack a saved state expanding any contained objects.

        Applies :meth:`decode` to the saved state via ``tree.transform``, forwarding the schema,
        creation callback and updates dictionary.
        """
        return tree.transform(
            self.decode,
            encoded_saved_state,
            path,
            schema=schema,
            created_callback=created_callback,
            updates=updates,
        )
class LiveDepositor(Saver, Loader):
    """Depositor with strategy that all objects that get referenced should be saved"""

    def __init__(self, *args, **kwargs):
        # Just patch through
        super().__init__(*args, **kwargs)
        # ids (as returned by `id()`) of the objects currently being saved, see _cycle_protection
        self._saving_set = set()

    def get_snapshot_id(self, obj) -> Optional[records.SnapshotId]:
        """Get the snapshot id for a live object, saving the object if necessary.

        :return: the snapshot id, or ``None`` if ``obj`` is ``None``
        """
        if obj is None:
            return None

        try:
            # Try getting it from the transaction as there may be one from an in-progress save.
            # We can't use historian.get_snapshot_id here because we _only_ want one that's
            # currently being saved or we should try saving it as below to ensure it's up to date
            return self._get_current_snapshot_id(obj)
        except exceptions.NotFound:
            # Then we have to save it and get the resulting reference
            return self._save_object(obj).snapshot_id

    def _get_current_snapshot_id(self, obj) -> records.SnapshotId:
        """Get the current snapshot id of an object from the historian's current transaction"""
        return self._historian.current_transaction().get_snapshot_id_for_live_object(
            obj
        )

    def load(self, snapshot_id: records.SnapshotId):
        """Load the live object with the object id of the passed snapshot id.

        The historian is asked first; if it raises ``NotFound`` the object is loaded through
        :meth:`_load_object`.
        """
        try:
            return self._historian.get_obj(snapshot_id.obj_id)
        except exceptions.NotFound:
            return self._load_object(snapshot_id.obj_id)

    def load_from_record(self, record: records.DataRecord) -> object:
        """Load an object from a record"""
        with self._historian.in_transaction() as trans:

            def created(path, new_obj):
                """Called each time an object is created whilst decoding"""
                # For the root object, put it into the transaction as a live object
                if not path:
                    trans.insert_live_object(new_obj, record)

            updates = {}
            loaded = self.decode(
                record.state,
                record.get_state_schema(),
                created_callback=created,
                updates=updates,
            )

            if updates:
                logger.warning(
                    "Object snapshot '%s' is at an older version that your current codebase. It "
                    "can be migrated by using `mince migrate` from the command line. If this "
                    "object is saved the new entry will use the new version.",
                    record.snapshot_id,
                )

            return loaded

    def _load_object(self, obj_id) -> object:
        """Load an object from the database.  This method is deliberately private as it should
        only be used by the depositor and the historian.

        :raises exceptions.ObjectDeleted: if the current transaction reports the object as deleted
        """
        historian = self.get_historian()
        archive = self.get_archive()

        with historian.in_transaction() as trans:
            if trans.is_deleted(obj_id):
                raise exceptions.ObjectDeleted(obj_id)

            # Get the record from the database
            record = self._create_record(archive.objects.get(obj_id))  # DB HIT
            assert not record.is_deleted_record(), (
                f"Found a deleted record in the objects collection ({record.snapshot_id}), "
                f"this should never happen!"
            )

            try:
                obj = historian._live_objects.get_object(obj_id)  # pylint: disable=protected-access
            except exceptions.NotFound:
                logger.debug("Loading object from record: %s", record.snapshot_id)
                # Ok, just use the one from the archive
                return self.load_from_record(record)
            else:
                # Compare with the current, live, version
                live_record = historian._live_objects.get_record(obj)  # pylint: disable=protected-access
                if record.version != live_record.version:
                    # The one in the archive is newer, so use that
                    logger.debug("Updating object from record: %s", record.snapshot_id)
                    self.update_from_record(obj, record)

                return obj

    def update_from_record(self, obj: object, record: records.DataRecord) -> bool:
        """Do an in-place update of an object from a record.

        :return: always ``True``
        """
        historian = self.get_historian()
        helper = historian.get_helper(type(obj))

        with historian.in_transaction() as trans:
            # Make sure the record is in the transaction with the object
            trans.insert_live_object(obj, record)

            saved_state = self._recursive_unpack(
                record.state, record.get_state_schema()
            )
            helper.load_instance_state(obj, saved_state, self)
            return True

    def _save_object(self, obj: object) -> records.DataRecord:
        """Save a live object and return its record.

        An existing record is returned unchanged if the object is already part of the current
        transaction, if it is immutable and has been saved before, or if it is found to be
        identical to its last saved state.  Otherwise a new record is created and staged.

        :raises TypeError: if the historian has no helper for the type of ``obj``
        """
        historian = self._historian
        try:
            helper = historian.get_helper(type(obj), auto_register=True)
        except ValueError:
            raise TypeError(
                f"Type is incompatible with the historian: {type(obj).__name__}"
            ) from None

        with historian.in_transaction() as trans:
            # Check if an object is already being saved in the transaction
            try:
                return trans.get_record_for_live_object(obj)
            except exceptions.NotFound:
                pass

            with self._cycle_protection(obj):
                # Ok, have to save it
                current_hash = historian.hash(obj)

                try:
                    # Let's see if we have a record at all
                    record = historian._live_objects.get_record(obj)  # pylint: disable=protected-access
                except exceptions.NotFound:
                    # Object being saved for the first time
                    builder = self._create_builder(helper, snapshot_hash=current_hash)
                    record = self._save_from_builder(obj, builder)
                    if historian.meta.sticky:
                        # Apply the sticky meta
                        historian.meta.update(record.obj_id, historian.meta.sticky)
                    return record
                else:
                    if helper.IMMUTABLE:
                        logger.info(
                            "Tried to save immutable object with id '%s' again",
                            record.obj_id,
                        )
                        return record

                    # Check if our record is up-to-date
                    with historian.transaction() as nested:
                        loaded_obj = SnapshotLoader(historian).load_from_record(record)

                        if current_hash == record.snapshot_hash and historian.eq(
                            obj, loaded_obj
                        ):
                            # Objects identical
                            nested.rollback()
                        else:
                            builder = records.make_child_builder(
                                record, snapshot_hash=current_hash
                            )
                            record = self._save_from_builder(obj, builder)

                    return record

    def _save_from_builder(self, obj, builder: records.DataRecordBuilder):
        """Save a live object using the passed record builder and return the built record.

        The builder must already carry the snapshot hash.  The record is inserted into the
        current transaction and staged for insertion.
        """
        assert (
            builder.snapshot_hash is not None
        ), "The snapshot hash must be set on the builder before saving"

        historian = self.get_historian()

        with historian.in_transaction() as trans:  # type: transactions.Transaction
            # Insert the object into the transaction so others can refer to it
            sid = records.SnapshotId(builder.obj_id, builder.version)
            with trans.prepare_for_saving(sid, obj):
                # Inject the extras
                builder.extras.update(
                    self._get_extras(obj, builder.obj_id, builder.version)
                )

                # Now ask the object to save itself and create the record
                builder.update(self.save_state(obj))
                record = builder.build()

                # Insert the record into the transaction
                trans.insert_live_object(obj, record)
                trans.stage(operations.Insert(record))  # Stage it for being saved

        return record

    def _get_extras(self, obj, obj_id, version: int) -> dict:
        """Create the extras dictionary for an object that is going to be saved.

        Starts from the historian's user info.  For the first version (``version == 0``) the
        object that created this one and the snapshot it was copied from are added, provided the
        staging info names them and the historian can supply their snapshot ids.
        """
        historian = self.get_historian()
        extras = historian.get_user_info()

        if version == 0:
            # Stuff to be done the first time an object is saved
            obj_info = staging.get_info(obj)
            if obj_info:
                # Deal with a possible object creator
                created_by = obj_info.get(records.ExtraKeys.CREATED_BY, None)
                if created_by is not None:
                    try:
                        sid = historian.get_snapshot_id(created_by)
                        extras[records.ExtraKeys.CREATED_BY] = sid.obj_id
                    except exceptions.NotFound:
                        logger.info(
                            "Object with id '%s' is being saved but information about the "
                            "object it was created by will not be in the record because "
                            "the original object has not been saved yet and therefore has "
                            "no id.",
                            obj_id,
                        )

                # Deal with possible copied from
                copied_from = obj_info.get(records.ExtraKeys.COPIED_FROM, None)
                if copied_from is not None:
                    try:
                        sid = historian.get_snapshot_id(copied_from)
                        extras[records.ExtraKeys.COPIED_FROM] = sid.to_dict()
                    except exceptions.NotFound:
                        logger.info(
                            "Object with id '%s' is being saved but information about the "
                            "object it was copied from will not be in the record because "
                            "the original object has not been saved yet and therefore has "
                            "no id.",
                            obj_id,
                        )

        return extras

    def _create_builder(self, helper, **additional) -> records.DataRecordBuilder:
        """Create a record builder for a new object.

        The builder is given the helper's type id, a freshly created archive id and version 0.
        Any additional keyword arguments are applied to the builder afterwards.
        """
        builder = records.DataRecord.new_builder(
            type_id=helper.TYPE_ID,
            obj_id=self.get_archive().create_archive_id(),
            version=0,
        )
        # `additional` is always a dict (possibly empty) as it is collected from **kwargs
        builder.update(additional)
        return builder

    @contextlib.contextmanager
    def _cycle_protection(self, obj: object):
        """This context manager is used as a means of circular-reference identification.
        Naturally, such cyclic saving should never happen, however if there is a bug, at least
        this method allows us to catch it early and see the source.

        :raises RuntimeError: if ``obj`` is already in the process of being saved
        """
        obj_id = id(obj)
        if obj_id in self._saving_set:
            raise RuntimeError(
                "The object is already being saved, this cannot be called twice and suggests "
                "a circular reference is being made"
            )

        self._saving_set.add(obj_id)
        try:
            yield
        finally:
            # Always clear the marker, even if saving failed
            self._saving_set.remove(obj_id)

    @staticmethod
    def _create_record(entry_dict: dict) -> records.DataRecord:
        """Build a data record from a dictionary of its fields"""
        return records.DataRecord(**entry_dict)
class SnapshotLoader(Loader):
    """Responsible for loading snapshots.  This object should not be reused and only
    one external call to `load` should be made.  This is because it keeps an internal
    cache."""

    def __init__(self, historian):
        super().__init__(historian)
        # Cache of the snapshots loaded so far
        self._snapshots = {}  # type: Dict[records.SnapshotId, object]

    def load(self, snapshot_id: records.SnapshotId) -> object:
        """Load an object from its snapshot id.

        Results are cached per snapshot id.

        :return: the snapshot, or ``None`` if the archive record is a deleted record
        :raises TypeError: if ``snapshot_id`` is not a :class:`records.SnapshotId`
        """
        if not isinstance(snapshot_id, records.SnapshotId):
            raise TypeError(snapshot_id)

        try:
            snapshot = self._snapshots[snapshot_id]
        except KeyError:
            record = self.get_archive().load(snapshot_id)
            if record.is_deleted_record():
                snapshot = None
            else:
                snapshot = self.load_from_record(record)

            # Cache it
            self._snapshots[snapshot_id] = snapshot

        return snapshot

    def load_from_record(self, record: records.DataRecord) -> Any:
        """Decode the object stored in a record and insert it into the current transaction as a
        snapshot.

        A warning is logged if any part of the state had to be brought up to date while decoding.
        """
        with self._historian.in_transaction() as trans:  # type: transactions.Transaction
            updates = {}
            obj = self.decode(record.state, record.get_state_schema(), updates=updates)
            trans.insert_snapshot(obj, record.snapshot_id)

            if updates:
                logger.warning(
                    "Object snapshot '%s' is at an older version that your current codebase. It "
                    "can be migrated by using `mince migrate` from the command line.",
                    record.snapshot_id,
                )

            return obj
class Migrator(Saver, SnapshotLoader):
    """A migrating depositor used to make migrations to database records"""

    def get_snapshot_id(self, obj) -> records.SnapshotId:
        """Get the snapshot id for an object.

        The historian is asked first, then the snapshots of the current transaction are searched
        by identity.  If neither knows the object it is saved through the historian and the
        resulting snapshot id is returned.
        """
        try:
            return self.get_historian().get_snapshot_id(obj)
        except exceptions.NotFound:
            pass

        # Ok, try the current transaction
        trans = self.get_historian().current_transaction()
        if trans is not None:
            for sid, snapshot in trans.snapshots.items():
                if obj is snapshot:
                    return sid

        # Ok, it's a brand new object that's never been saved, so save it
        self._historian.save_one(obj)
        return self.get_historian().get_snapshot_id(obj)

    def migrate_records(
        self, to_migrate: Iterable[records.DataRecord]
    ) -> Sequence[records.DataRecord]:
        """Migrate multiple records.

        Each record is decoded; those for which decoding reported an update are re-encoded and
        staged for updating within a single historian transaction.

        :return: a list of the (original) records that were migrated
        """
        migrated = []
        with self._historian.in_transaction() as trans:  # type: transactions.Transaction
            for record in to_migrate:
                updates = {}
                obj = self.decode(
                    record.state, record.get_state_schema(), updates=updates
                )
                if updates:
                    self._migrate_record(record, obj, trans)
                    migrated.append(record)

        return migrated

    def _migrate_record(self, record, new_obj, trans):
        """Given the current record and the corresponding instance this will stage an update of
        the record's state and state schema by re-encoding the object.  The current transaction
        must be supplied."""
        new_schema = []
        new_state = self.encode(new_obj, new_schema)
        trans.stage(
            operations.Update(
                record.snapshot_id,
                {records.STATE: new_state, records.STATE_TYPES: new_schema},
            )
        )
        logger.info(
            "Snapshot %s has been migrated to the latest version", record.snapshot_id
        )