# -*- coding: utf-8 -*-
# pylint: disable=too-many-lines
import collections
import contextlib
try:
from contextlib import nullcontext
except ImportError:
from contextlib2 import nullcontext
import getpass
import logging
import socket
from typing import (
MutableMapping,
Any,
Optional,
Iterable,
Union,
Iterator,
Type,
Dict,
Callable,
Sequence,
)
import weakref
import deprecation
import networkx
from . import archives
from . import builtins
from . import frontend
from . import defaults
from . import depositors
from . import refs
from . import exceptions
from . import expr
from . import files
from . import helpers
from . import hist
from . import migrate
from . import operations
from . import qops
from . import records as recordsm # The records module
from . import result_types
from . import staging
from . import tracking
from . import types
from . import type_registry
from . import utils
from . import version as version_mod
from .transactions import RollbackTransaction, Transaction, LiveObjects
__all__ = "Historian", "ObjectEntry"
logger = logging.getLogger(__name__)
ObjectEntry = collections.namedtuple("ObjectEntry", "ref obj")
HistorianType = Union[helpers.TypeHelper, Type[types.SavableObject]]
[docs]class Historian: # pylint: disable=too-many-public-methods, too-many-instance-attributes
"""The historian acts as a go-between between your python objects and the archive which is
a persistent store of the records. It will keep track of all live objects (i.e. those that
have active references to them) that have been loaded and/or saved as well as enabling the
user to lookup objects in the archive."""
[docs] @deprecation.deprecated(
deprecated_in="0.14.5",
removed_in="0.16.0",
current_version=version_mod.__version__,
details="Use mincepy.copy() instead",
)
def copy(self, obj):
"""Create a shallow copy of the object. Using this method allows the historian to inject
information about where the object was copied from into the record if saved."""
return tracking.copy(obj)
[docs] @deprecation.deprecated(
deprecated_in="0.15.10",
removed_in="0.17.0",
current_version=version_mod.__version__,
details="Use mincepy.records.find() instead",
)
def find_records(self, *args, **kwargs) -> Iterator[recordsm.DataRecord]:
"""Find records
Has same signature as py:meth:`mincepy.Records.find`.
"""
yield from self.records.find(*args, **kwargs)
[docs] @deprecation.deprecated(
deprecated_in="0.15.10",
removed_in="0.17.0",
current_version=version_mod.__version__,
details="Use mincepy.records.distinct() instead",
)
def find_distinct(self, *args, **kwargs):
"""Get distinct values of the given record key
Has same signature as py:meth:`mincepy.Records.distinct`.
"""
yield from self.records.distinct(*args, **kwargs)
def __init__(self, archive: archives.Archive, equators=()):
self._archive = archive
self._equator = types.Equator(defaults.get_default_equators() + equators)
# Register default types
self._type_registry = type_registry.TypeRegistry()
self.register_type(refs.ObjRef)
self.register_types(builtins.HISTORIAN_TYPES)
self.register_type(builtins.SnapshotIdHelper())
self.register_types(archive.get_types())
# Snapshot objects -> reference. Objects that were loaded from historical snapshots
self._snapshots_objects = (
utils.WeakObjectIdDict()
) # type: MutableMapping[Any, recordsm.SnapshotId]
self._live_objects = LiveObjects()
self._transactions = None
self._user = getpass.getuser()
self._hostname = socket.gethostname()
self._live_depositor = depositors.LiveDepositor(self)
self._meta = hist.Meta(self, self._archive)
self._migrate = migrate.Migrations(self)
self._references = hist.References(self)
self._snapshots = hist.SnapshotsCollection(self, self._archive.snapshots)
self._objects = hist.LiveObjectsCollection(self, self._archive.objects)
@property
def archive(self):
return self._archive
@property
def meta(self) -> hist.Meta:
"""Access to functions that operate on the metadata"""
return self._meta
@property
def primitives(self) -> tuple:
"""A tuple of all the primitive types"""
return types.PRIMITIVE_TYPES + (self._archive.get_id_type(),)
@property
def migrations(self) -> migrate.Migrations:
"""Access the migration possibilities"""
return self._migrate
@property
def records(self) -> frontend.EntriesCollection[recordsm.DataRecord]:
"""Access methods and properties that act on and return data records"""
return self._objects.records
@property
def objects(self) -> hist.LiveObjectsCollection:
"""Access the snapshots"""
return self._objects
@property
def references(self) -> hist.References:
"""Access the references collection"""
return self._references
@property
def snapshots(self) -> hist.SnapshotsCollection:
"""Access the snapshots"""
return self._snapshots
[docs] def create_file(
self, filename: str = None, encoding: str = None
) -> builtins.BaseFile:
"""Create a new file. The historian will supply file type compatible with the archive in
use."""
return files.File(self._archive.file_store, filename, encoding)
[docs] def save(self, *objs: object):
"""Save multiple objects producing corresponding object identifiers. This returns a
sequence of ids that is in the same order as the passed objects.
:param objs: the object(s) to save. Can also be a tuple of (obj, meta) to optionally
include metadata to be saved with the object(s)
"""
to_save = []
# Convert everything to tuples
for entry in objs:
if isinstance(entry, tuple):
if len(entry) > 2:
raise ValueError(
f"Supplied tuples can only contain (object, meta), got '{entry}'"
)
else:
entry = (entry,)
to_save.append(entry)
ids = []
with self.in_transaction():
for entry in to_save:
ids.append(self.save_one(*entry))
if len(objs) == 1:
return ids[0]
return ids
[docs] def save_one(self, obj: object, meta: dict = None):
"""Save the object returning an object id. If metadata is supplied it will be set on the
object.
Developer note: this is the front end point-of-entry for a user/client code saving an object
however subsequent objects being saved in this transaction will only go through _save_object
and therefore any code common to all objects being saved should possibly go there.
"""
if obj in self._snapshots_objects:
raise exceptions.ModificationError(
"Cannot save a snapshot object, that would rewrite history!"
)
if meta and not isinstance(meta, dict):
raise TypeError(f"Metadata must be a dictionary, got type '{type(meta)}'")
# Save the object and metadata
with self.in_transaction():
record = self._live_depositor._save_object(
obj
) # pylint: disable=protected-access
if meta:
self.meta.update(record.obj_id, meta)
return record.obj_id
[docs] def is_known(self, obj: object) -> bool:
"""Check if an object has ever been saved and is therefore known to the historian
:return: True if ever saved, False otherwise
"""
return self.get_obj_id(obj) is not None
[docs] def replace(self, old: object, new: object):
"""Replace a live object with a new version.
This is especially useful if you have made a copy of an object and modified it but you want
to continue the history of the object as the original rather than a brand new object. Then
just replace the old object with the new one by calling this function.
"""
if self.current_transaction() is not None:
raise RuntimeError("Can't replace during a transaction for the time being")
if not isinstance(new, type(old)):
raise TypeError(f"Can't replace type '{type(old)} with type '{type(new)}!")
# Get the current record and replace the object with the new one
record = self._live_objects.get_record(old)
self._live_objects.remove(record.obj_id)
self._live_objects.insert(new, record)
# Make sure creators is correct as well
staging.replace(old, new)
def load_snapshot(self, snapshot_id: recordsm.SnapshotId) -> object:
return self._new_snapshot_depositor().load(snapshot_id)
def load_snapshot_from_record(self, record: recordsm.DataRecord) -> object:
return self._new_snapshot_depositor().load_from_record(record)
[docs] def load(self, *obj_id_or_snapshot_id):
"""Load object(s) or snapshot(s)."""
loaded = []
for entry in obj_id_or_snapshot_id:
loaded.append(self.load_one(entry))
if len(obj_id_or_snapshot_id) == 1:
return loaded[0]
return loaded
[docs] def load_one(self, obj_id_or_snapshot_id) -> object:
"""Load one object or snapshot from the database"""
if isinstance(obj_id_or_snapshot_id, recordsm.SnapshotId):
return self.load_snapshot(obj_id_or_snapshot_id)
# OK, assume we're dealing with an object id
obj_id = self._ensure_obj_id(obj_id_or_snapshot_id)
# Try getting the object from our dict of up-to-date ones
try:
return self.get_obj(obj_id)
except exceptions.NotFound:
# Going to have to load from the database
return self._live_depositor._load_object(
obj_id
) # pylint: disable=protected-access
[docs] def get(self, obj_id) -> object:
"""Get a live object using the object id"""
return self._objects.get(obj_id)
[docs] def sync(self, obj: object) -> bool:
"""Update an object with the latest state in the database.
If there is no new version in the archive then the current version remains
unchanged including any modifications.
:return: True if the object was updated, False otherwise
"""
obj_id = self.get_obj_id(obj)
if obj_id is None:
# Never saved so the object is as up to date as can be!
return False
record = self._objects.records.get(obj_id)
if record.is_deleted_record():
raise exceptions.ObjectDeleted(
f"Object with id '{obj_id}' has been deleted"
)
if record.version == self.get_snapshot_id(obj).version:
# Nothing has changed
return False
# The one in the archive is newer, so use that
return self._live_depositor.update_from_record(obj, record)
[docs] def delete(self, *obj_or_identifier, imperative=True) -> result_types.DeleteResult:
"""Delete objects.
:param imperative: if True, this means that the caller explicitly expects this call to delete the passed objects
and it should therefore raise if an object cannot be found or has been deleted already. If False, the
function will ignore these cases and continue.
:raises mincepy.NotFound: if the object cannot be found (potentially because it was
already deleted)
"""
# We need the current records to be able to build the delete records
obj_ids = list(map(self._ensure_obj_id, obj_or_identifier))
# Find the current records (i.e. from our cache)
records = {} # type: Dict[Any, recordsm.DataRecord]
left_to_find = set()
for obj_id in obj_ids:
try:
records[obj_id] = self.get_current_record(self.get_obj(obj_id))
except exceptions.ObjectDeleted:
if imperative:
# Object deleted already so reraise
raise
except exceptions.NotFound:
left_to_find.add(obj_id)
# Those that we don't have cached records for and need to look up
if left_to_find:
# Have a look in the archive
for record in self._objects.records.find(
recordsm.DataRecord.obj_id.in_(*left_to_find)
):
records[record.obj_id] = record
left_to_find.remove(record.obj_id)
if left_to_find and imperative:
# Still couldn't find them so raise
raise exceptions.NotFound(left_to_find)
deleted = []
with self.in_transaction() as trans:
# Mark each object as deleted in the transaction and stage the 'delete record' for insertion
# in the order that they were passed to us, in case this makes a difference to the caller
for obj_id in obj_ids:
record = records.get(obj_id, None)
if record is None:
continue
builder = recordsm.make_deleted_builder(record)
deleted_record = self._record_builder_created(builder).build()
trans.delete(record.obj_id)
trans.stage(operations.Insert(deleted_record))
deleted.append(record.obj_id)
return result_types.DeleteResult(deleted, left_to_find)
[docs] def history(
self, obj_or_obj_id, idx_or_slice="*", as_objects=True
) -> [Sequence[ObjectEntry], Sequence[recordsm.DataRecord]]:
"""Get a sequence of object ids and instances from the history of the given object.
:param obj_or_obj_id: The instance or id of the object to get the history for
:param idx_or_slice: The particular index or a slice of which historical versions to get
:param as_objects: if True return the object instances, otherwise returns the DataRecords
Example:
>>> import mincepy, mincepy.testing
>>> historian = mincepy.get_historian()
>>> car = mincepy.testing.Car('ferrari', 'white')
>>> car_id = historian.save(car)
>>> car.colour = 'red'
>>> historian.save(car)
>>> history = historian.history(car_id)
>>> len(history)
2
>>> history[0].obj.colour == 'white'
True
>>> history[1].obj.colour == 'red'
True
>>> history[1].obj is car
"""
obj_id = self._ensure_obj_id(obj_or_obj_id)
snapshot_ids = self._archive.get_snapshot_ids(obj_id)
indices = utils.to_slice(idx_or_slice)
to_get = snapshot_ids[indices]
if as_objects:
return [ObjectEntry(sid, self.load(sid)) for sid in to_get]
return [self._archive.load(ref) for ref in to_get]
[docs] def get_current_record(self, obj: object) -> recordsm.DataRecord:
"""Get the current record that the historian has cached for the passed object"""
trans = self.current_transaction()
# Try the transaction first
if trans:
try:
return trans.get_record_for_live_object(obj)
except exceptions.ObjectDeleted:
# ObjectDeleted is a specialisation of the NotFound error but it means that we
# should consider the object as being gone so reraise it
raise
except exceptions.NotFound:
pass
return self._live_objects.get_record(obj)
[docs] def get_obj_id(self, obj: object) -> Any:
"""Get the object ID for a live object.
:return: the object id or None if the object is not known to the historian
"""
trans = self.current_transaction()
if trans is not None:
try:
return trans.get_snapshot_id_for_live_object(obj).obj_id
except exceptions.NotFound:
pass
try:
obj_id = self._live_objects.get_record(obj).obj_id
if trans is not None and trans.is_deleted(obj_id):
# The object has been deleted in the transaction, so it is not known
return None
return obj_id
except exceptions.NotFound:
return None
[docs] def get_obj(self, obj_id) -> object:
"""Get a currently live object"""
trans = self.current_transaction()
if trans:
try:
return trans.get_live_object(obj_id)
except exceptions.ObjectDeleted:
# ObjectDeleted is a specialisation of the NotFound error but it means that we
# should consider the object as being gone so reraise it
raise
except exceptions.NotFound:
pass
return self._live_objects.get_object(obj_id)
[docs] def is_saved(self, obj: object) -> bool:
"""Test if an object is saved with this historian. This is equivalent to
`historian.get_obj_id(obj) is not None`."""
return self.get_obj_id(obj) is not None
[docs] def to_obj_id(self, obj_or_identifier):
"""
This call will try and get an object id from the passed parameter. The possibilities are:
1. Passed an object ID in which case it will be returned unchanged
2. Passed a snapshot ID, in which case the corresponding object ID will be returned
2. Passed a live object instance, in which case the id of that object will be returned
3. Passed a type that can be understood by the archive as an object id e.g. a string of
version, in which case the archive will attempt to convert it
Returns None if neither of these cases were True.
"""
if self.is_obj_id(obj_or_identifier):
return obj_or_identifier
if isinstance(obj_or_identifier, recordsm.SnapshotId):
return obj_or_identifier.obj_id
try:
# Try creating it for the user by calling the constructor with the argument passed.
# This helps for common obj id types which can be constructed from a string
return self._archive.construct_archive_id(obj_or_identifier)
except (ValueError, TypeError):
# Maybe we've been passed an object
pass
return self.get_obj_id(obj_or_identifier)
[docs] def get_snapshot_id(self, obj: object) -> recordsm.SnapshotId:
"""Get the current snapshot id for a live object. Will return the id or raise
:class:`mincepy.NotFound` exception"""
trans = self.current_transaction()
if trans:
try:
return trans.get_snapshot_id_for_live_object(obj)
except exceptions.NotFound:
pass
return self._live_objects.get_record(obj).snapshot_id
def hash(self, obj: object):
return self._equator.hash(obj)
def eq(self, one: object, other: object): # pylint: disable=invalid-name
return self._equator.eq(one, other)
# region Types
[docs] @classmethod
def is_trackable(cls, obj):
"""Determine if an object is trackable i.e. we can treat these as live objects and
automatically keep track of their history when saving. Ultimately this is determined by
whether the type is weak referencable or not.
"""
try:
weakref.ref(obj)
return True
except TypeError:
return False
@property
def type_registry(self) -> type_registry.TypeRegistry:
return self._type_registry
[docs] def is_primitive(self, obj) -> bool:
"""Check if the object is one of the primitives and should be saved by value in the
archive"""
return obj.__class__ in self.primitives
[docs] def is_obj_id(self, obj_id) -> bool:
"""Check if an object is of the object id type"""
return isinstance(obj_id, self._archive.get_id_type())
def register_type(
self, obj_class_or_helper: "type_registry.RegisterableType", replace=True
) -> helpers.TypeHelper:
helper = self._type_registry.register_type(obj_class_or_helper, replace=replace)
self._equator.add_equator(helper)
return helper
def register_types(
self, obj_clases_or_helpers: Iterable[HistorianType], replace=True
):
for item in obj_clases_or_helpers:
self.register_type(item, replace=replace)
def get_obj_type_id(self, obj_type):
return self._type_registry.get_type_id(obj_type)
def get_obj_type(self, type_id):
return self.get_helper(type_id).TYPE
def get_helper(self, type_id_or_type, auto_register=False) -> helpers.TypeHelper:
if (
auto_register
and issubclass(type_id_or_type, types.SavableObject)
and type_id_or_type not in self._type_registry
):
self.register_type(type_id_or_type)
return self._type_registry.get_helper(type_id_or_type)
# endregion
# pylint: disable=redefined-builtin
[docs] def find(
self,
*filter: expr.FilterSpec,
obj_type=None,
obj_id=None,
version: int = -1,
state=None,
meta: dict = None,
sort=None,
limit=0,
skip=0,
) -> frontend.ResultSet[object]:
"""
.. _MongoDB: https://docs.mongofrontend.com/manual/tutorial/query-documents/
Find objects. This call will search the archive for objects matching the given criteria.
In many cases the main arguments of interest will be `state` and `meta` which allow you to
apply filters on the stored state of the object and metadata respectively. To understand
how the state is stored in the database (and therefore how to apply filters to it) it may
be necessary to look at the details of the `save_instance_state()` method for that type.
Metadata is always a dictionary containing primitives (strings, dicts, lists, etc).
For the most part, the filter syntax of `mincePy` conforms to that of `MongoDB`_ with
convenience functions locate in :py:mod:`mincepy.qops` that can make it easier to
to build a query.
Examples:
Find all :py:class:`~mincepy.testing.Car`s that are brown or red:
>>> import mincepy as mpy
>>> historian = mpy.get_historian()
>>> historian.find(mpy.testing.Car.colour.in_('brown', 'red'))
Find all people that are older than 34 and live in Edinburgh:
>>> historian.find(mpy.testing.Person.age > 34, meta=dict(city='Edinburgh'))
:param obj_type: the object type to look for
:param obj_id: an object or multiple object ids to look for
:param version: the version of the object to retrieve, -1 means latest
:param state: the criteria on the state of the object to apply
:type state: must be subclass of historian.primitive
:param meta: the search criteria to apply on the metadata of the object
:param sort: the sort criteria
:param limit: the maximum number of results to return, 0 means unlimited
:param skip: the page to get results from
"""
return self._objects.find(
*filter,
obj_type=obj_type,
obj_id=obj_id,
version=version,
state=state,
meta=meta,
sort=sort,
limit=limit,
skip=skip,
)
[docs] def get_creator(self, obj_or_identifier) -> object:
"""Get the object that created the passed object"""
if not self.is_obj_id(obj_or_identifier):
# Object instance, try the staging area
info = staging.get_info(obj_or_identifier, create=False) or {}
created_by = info.get(recordsm.ExtraKeys.CREATED_BY, None)
if created_by is not None:
return created_by
creator_id = self.created_by(obj_or_identifier)
return self.load_one(creator_id)
[docs] def created_by(self, obj_or_identifier):
"""Return the id of the object that created the passed object"""
try:
record = self.get_current_record(obj_or_identifier)
except exceptions.NotFound as exc:
if not self.is_obj_id(obj_or_identifier):
raise
try:
record = self._objects.records.find(obj_id=obj_or_identifier).one()
except exceptions.NotOneError:
raise exc from None
return record.created_by
[docs] def get_user_info(self) -> dict:
"""Get information about the current user and host"""
user_info = {}
if self._user:
user_info[recordsm.ExtraKeys.USER] = self._user
if self._hostname:
user_info[recordsm.ExtraKeys.HOSTNAME] = self._hostname
return user_info
[docs] def merge(
self,
result_set: frontend.ResultSet[object],
*,
meta=None, # pylint: disable=unused-argument
batch_size=1024,
progress_callback: Callable[
[utils.Progress, Optional[result_types.MergeResult]], None
] = None,
) -> result_types.MergeResult:
"""Merge a set of objects into this database.
Given a set of results from another archive this will attempt to merge the corresponding records
into this historian's archive.
:param result_set: the set of records to merge from the source historian
:param meta: option for merging metadata, allowed values:
None - Don't merge metadata
'update' - Perform dictionary update with existing metadata
'overwrite' - In the case of an existing metadata dictionary, overwrite it
"""
# REMOTE
remote = result_set.historian # type: Historian
# Get information about the records that we've been asked to merge
# pylint: disable=protected-access
remote_partial_records = result_set._project(recordsm.OBJ_ID, recordsm.VERSION)
remote_snapshot_ids = set(
map(recordsm.SnapshotId.from_dict, remote_partial_records)
) # DB HIT
progress = utils.Progress(len(remote_snapshot_ids))
if progress_callback is not None:
progress_callback(progress, result_types.MergeResult())
result = result_types.MergeResult()
# get the outgoing snapshot ref. graph
while remote_snapshot_ids:
# Get a batch
batch = set()
try:
for _ in range(batch_size):
batch.add(remote_snapshot_ids.pop())
except KeyError:
pass
graph = remote.references.get_snapshot_ref_graph(*batch)
# The graph may contain nodes that are still in our list of remote snapshots to transfer
# so check and remove these because they will be done in this batch
extras = set(graph.nodes) - batch
remote_snapshot_ids.difference_update(extras)
partial_result = self._merge_batch(remote, graph)
result.update(partial_result)
progress.done = progress.total - len(remote_snapshot_ids)
if progress_callback is not None:
progress_callback(progress, partial_result)
return result
[docs] def purge(
self, deleted=True, unreferenced=True, dry_run=True
) -> result_types.PurgeResult:
"""Purge the archive of unused snapshots"""
snapshot_purge = self.snapshots.purge(deleted=deleted, dry_run=dry_run)
unreferenced_deleted = set()
if unreferenced:
# Let's get snapshot ids for all live object
live_snapshot_ids = list(
map(
recordsm.SnapshotId.from_dict,
# pylint: disable=protected-access
self.objects.records.find()._project(
recordsm.OBJ_ID, recordsm.VERSION
),
)
)
# Now, find all the snapshots that they refer to, these will be the ones we DON'T delete
reg_graph = self.references.get_snapshot_ref_graph(
*live_snapshot_ids, direction=archives.OUTGOING, max_dist=-1
) # DB HIT
to_keep = set(reg_graph.nodes)
res = self.snapshots.records.find(
expr.Comparison("_id", expr.Nin(list(map(str, to_keep))))
)
unreferenced_deleted = set(
map(
recordsm.SnapshotId.from_dict,
# pylint: disable=protected-access
res._project(recordsm.OBJ_ID, recordsm.VERSION),
)
)
if unreferenced_deleted and not dry_run:
self._archive.bulk_write(
list(map(operations.Delete, unreferenced_deleted))
)
return result_types.PurgeResult(
snapshot_purge.deleted_purged, unreferenced_deleted
)
def _merge_batch(
self, remote: "Historian", remote_ref_graph: networkx.DiGraph
) -> result_types.MergeResult:
sid_strings = list(map(str, remote_ref_graph.nodes))
# REMOTE
# Get the partial records for all these snapshots indexed by the SID
remote_partial_records = {}
for entry in remote.archive.snapshots.find(
{"_id": qops.in_(*sid_strings)},
projection={
recordsm.OBJ_ID: 1,
recordsm.VERSION: 1,
recordsm.SNAPSHOT_HASH: 1,
},
): # DB HIT
remote_partial_records[recordsm.SnapshotId.from_dict(entry)] = entry
# LOCAL
# Find the local snapshots along with their hashes
local_partial_records = {}
for entry in self.archive.snapshots.find(
{"_id": qops.in_(*sid_strings)},
projection={
recordsm.OBJ_ID: 1,
recordsm.VERSION: 1,
recordsm.SNAPSHOT_HASH: 1,
},
): # DB HIT
local_partial_records[recordsm.SnapshotId.from_dict(entry)] = entry
# Remove all those that match and log any that have conflicting hashes
conflicting = []
for sid, local_partial in local_partial_records.items():
remote_record = remote_partial_records.pop(sid)
if (
remote_record[recordsm.SNAPSHOT_HASH]
!= local_partial[recordsm.SNAPSHOT_HASH]
):
conflicting.append(sid)
if conflicting:
raise exceptions.MergeError(
f"Cannot merge, the following snapshots have conflicting hashes: {conflicting}"
)
# Finally, get all the records to merge and create merge operations
ops = []
files_to_transfer = []
for remote_record in remote.archive.snapshots.find(
{"_id": qops.in_(*map(str, remote_partial_records.keys()))}
): # DB HIT
record = recordsm.DataRecord(**remote_record)
ops.append(operations.Merge(record))
files_in_record = record.get_files()
if files_in_record:
# Extract the second entry in the tuple as this contains the actual state dictionary
files_to_transfer.extend(entry[1] for entry in files_in_record)
# and write the new records into our archive
if ops:
# Copy the files first. This way if the user cancels prematurely the files are there but no the objects
# that refer to them. The other way around would result in the objects being there but failing when
# someone tries to load the files
file_store = self.archive.file_store
for file_dict in files_to_transfer:
file_id = file_dict[expr.field_name(files.File.file_id)]
filename = file_dict[expr.field_name(files.File.filename)] or ""
with remote.archive.file_store.open_download_stream(
file_id
) as down_stream:
file_store.upload_from_stream_with_id(
file_id, filename, down_stream
)
self._archive.bulk_write(ops) # DB HIT
return result_types.MergeResult(
all_snapshots=remote_ref_graph.nodes,
merged_snapshots=remote_partial_records.keys(),
)
[docs] @contextlib.contextmanager
def in_transaction(self) -> Iterator[Transaction]:
"""This context will either re-use an existing transaction, if one is currently taking place
or create a new one if not."""
current = self.current_transaction()
if current is None:
ctx = self.transaction()
else:
ctx = nullcontext(current)
with ctx as trans:
yield trans
[docs] @contextlib.contextmanager
def transaction(self) -> Iterator[Transaction]:
"""Start a new transaction. Will be nested if there is already one underway"""
if self._transactions:
# Start a nested one
with self._transactions[-1].nested() as nested:
self._transactions.append(nested)
try:
yield nested
except Exception: # Need this so we can have 'else' pylint: disable=try-except-raise
raise
else:
self._closing_transaction(nested)
finally:
popped = self._transactions.pop()
assert popped is nested
else:
# New transaction
trans = Transaction()
self._transactions = [trans]
try:
yield trans
except RollbackTransaction:
pass
else:
self._closing_transaction(trans)
self._commit_transaction(trans)
finally:
assert len(self._transactions) == 1
assert self._transactions[0] is trans
self._transactions = None
[docs] def current_transaction(self) -> Optional[Transaction]:
"""Get the current transaction if there is one, otherwise returns None"""
if not self._transactions:
return None
return self._transactions[-1]
def _closing_transaction(self, trans: Transaction):
conflicting = set()
# Filter out the deleted records
del_ops = filter(
lambda _: isinstance(_, operations.Insert) and _.record.is_deleted_record(),
trans.staged,
)
obj_ids = set(operation.obj_id for operation in del_ops)
ref_graph = self.references.get_obj_ref_graph(
*obj_ids, direction=archives.INCOMING
)
for obj_id in obj_ids:
for edge in ref_graph.in_edges(obj_id):
conflicting.add(edge[1])
if conflicting:
raise exceptions.ReferenceError("Cannot perform delete", conflicting)
def _commit_transaction(self, trans: Transaction):
"""Commit a transaction that is finishing"""
# Perform the database operations first because if these fail we shouldn't update ourselves
# Save any records that were staged for archiving
if trans.staged:
self._archive.bulk_write(trans.staged)
# Now all is good we can update
# Live objects
self._live_objects.update(trans.live_objects)
# Deleted objects
for deleted in trans.deleted:
try:
self._live_objects.remove(deleted)
except exceptions.NotFound:
pass
# Snapshots
for ref, obj in trans.snapshots.items():
self._snapshots_objects[obj] = ref
# Finally update the metadata as this is least important
# Metas
if trans.metas:
self._archive.meta_set_many(trans.metas)
def _load_object_from_record(self, record: recordsm.DataRecord):
depositor = self._live_depositor
# Try getting the object from the our dict of up to date ones
obj_id = record.obj_id
try:
return self.get_obj(obj_id)
except exceptions.NotFound:
pass
with self.in_transaction() as trans:
if trans.is_deleted(obj_id):
raise exceptions.ObjectDeleted(obj_id)
if record.is_deleted_record():
raise exceptions.ObjectDeleted(obj_id)
logger.debug("Loading object from record: %s", record.snapshot_id)
# Ok, just use the one from the archive
return depositor.load_from_record(record)
def _ensure_obj_id(self, obj_or_identifier):
"""
This call will try and get an object id from the passed parameter. Uses .to_obj_id() and raises NotFound if it
is not possible to get the object id.
"""
obj_id = self.to_obj_id(obj_or_identifier)
if obj_id is None:
raise exceptions.NotFound(
f"Could not get an object id from '{obj_or_identifier}'"
)
return obj_id
def _prepare_obj_id(self, obj_id):
if obj_id is None:
return None
# Convert object ids to the expected type before passing to archive
try:
return self._ensure_obj_id(obj_id)
except exceptions.NotFound as exc:
# Maybe it is multiple object ids
if not isinstance(
obj_id, Iterable
): # pylint: disable=isinstance-second-argument-not-valid-type
raise TypeError(f"Cannot get object id(s) from '{obj_id}'") from exc
return list(map(self._ensure_obj_id, obj_id))
def _prepare_type_id(self, obj_type):
if obj_type is None:
return None
try:
return self.get_obj_type_id(obj_type)
except TypeError as exc:
# Maybe it is multiple type ids
if not isinstance(
obj_type, Iterable
): # pylint: disable=isinstance-second-argument-not-valid-type
raise TypeError(f"Cannot get type id(s) from '{obj_type}'") from exc
return list(map(self.get_obj_type_id, obj_type))
def _record_builder_created(
self, builder: recordsm.DataRecordBuilder
) -> recordsm.DataRecordBuilder:
"""Update a data record builder with standard information."""
builder.extras.update(self.get_user_info())
return builder
def _new_snapshot_depositor(self):
return depositors.SnapshotLoader(self)