This commit is contained in:
2010-05-07 17:33:49 +00:00
parent c7c7498f19
commit cae7e001e3
127 changed files with 57530 additions and 0 deletions

1176
sqlalchemy/orm/__init__.py Normal file

File diff suppressed because it is too large Load Diff

1708
sqlalchemy/orm/attributes.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,575 @@
# orm/dependency.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Relationship dependencies.
Bridges the ``PropertyLoader`` (i.e. a ``relationship()``) and the
``UOWTransaction`` together to allow processing of relationship()-based
dependencies at flush time.
"""
from sqlalchemy import sql, util
import sqlalchemy.exceptions as sa_exc
from sqlalchemy.orm import attributes, exc, sync
from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
def create_dependency_processor(prop):
types = {
ONETOMANY : OneToManyDP,
MANYTOONE: ManyToOneDP,
MANYTOMANY : ManyToManyDP,
}
return types[prop.direction](prop)
class DependencyProcessor(object):
has_dependencies = True
def __init__(self, prop):
self.prop = prop
self.cascade = prop.cascade
self.mapper = prop.mapper
self.parent = prop.parent
self.secondary = prop.secondary
self.direction = prop.direction
self.post_update = prop.post_update
self.passive_deletes = prop.passive_deletes
self.passive_updates = prop.passive_updates
self.enable_typechecks = prop.enable_typechecks
self.key = prop.key
self.dependency_marker = MapperStub(self.parent, self.mapper, self.key)
if not self.prop.synchronize_pairs:
raise sa_exc.ArgumentError("Can't build a DependencyProcessor for relationship %s. "
"No target attributes to populate between parent and child are present" % self.prop)
def _get_instrumented_attribute(self):
"""Return the ``InstrumentedAttribute`` handled by this
``DependencyProecssor``.
"""
return self.parent.class_manager.get_impl(self.key)
def hasparent(self, state):
"""return True if the given object instance has a parent,
according to the ``InstrumentedAttribute`` handled by this ``DependencyProcessor``.
"""
# TODO: use correct API for this
return self._get_instrumented_attribute().hasparent(state)
def register_dependencies(self, uowcommit):
"""Tell a ``UOWTransaction`` what mappers are dependent on
which, with regards to the two or three mappers handled by
this ``DependencyProcessor``.
"""
raise NotImplementedError()
def register_processors(self, uowcommit):
"""Tell a ``UOWTransaction`` about this object as a processor,
which will be executed after that mapper's objects have been
saved or before they've been deleted. The process operation
manages attributes and dependent operations between two mappers.
"""
raise NotImplementedError()
def whose_dependent_on_who(self, state1, state2):
"""Given an object pair assuming `obj2` is a child of `obj1`,
return a tuple with the dependent object second, or None if
there is no dependency.
"""
if state1 is state2:
return None
elif self.direction == ONETOMANY:
return (state1, state2)
else:
return (state2, state1)
def process_dependencies(self, task, deplist, uowcommit, delete = False):
"""This method is called during a flush operation to
synchronize data between a parent and child object.
It is called within the context of the various mappers and
sometimes individual objects sorted according to their
insert/update/delete order (topological sort).
"""
raise NotImplementedError()
def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
"""Used before the flushes' topological sort to traverse
through related objects and ensure every instance which will
require save/update/delete is properly added to the
UOWTransaction.
"""
raise NotImplementedError()
def _verify_canload(self, state):
if state is not None and \
not self.mapper._canload(state, allow_subtypes=not self.enable_typechecks):
if self.mapper._canload(state, allow_subtypes=True):
raise exc.FlushError(
"Attempting to flush an item of type %s on collection '%s', "
"which is not the expected type %s. Configure mapper '%s' to "
"load this subtype polymorphically, or set "
"enable_typechecks=False to allow subtypes. "
"Mismatched typeloading may cause bi-directional relationships "
"(backrefs) to not function properly." %
(state.class_, self.prop, self.mapper.class_, self.mapper))
else:
raise exc.FlushError(
"Attempting to flush an item of type %s on collection '%s', "
"whose mapper does not inherit from that of %s." %
(state.class_, self.prop, self.mapper.class_))
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
"""Called during a flush to synchronize primary key identifier
values between a parent/child object, as well as to an
associationrow in the case of many-to-many.
"""
raise NotImplementedError()
def _check_reverse_action(self, uowcommit, parent, child, action):
"""Determine if an action has been performed by the 'reverse' property of this property.
this is used to ensure that only one side of a bidirectional relationship
issues a certain operation for a parent/child pair.
"""
for r in self.prop._reverse_property:
if not r.viewonly and (r._dependency_processor, action, parent, child) in uowcommit.attributes:
return True
return False
def _performed_action(self, uowcommit, parent, child, action):
"""Establish that an action has been performed for a certain parent/child pair.
Used only for actions that are sensitive to bidirectional double-action,
i.e. manytomany, post_update.
"""
uowcommit.attributes[(self, action, parent, child)] = True
def _conditional_post_update(self, state, uowcommit, related):
"""Execute a post_update call.
For relationships that contain the post_update flag, an additional
``UPDATE`` statement may be associated after an ``INSERT`` or
before a ``DELETE`` in order to resolve circular row
dependencies.
This method will check for the post_update flag being set on a
particular relationship, and given a target object and list of
one or more related objects, and execute the ``UPDATE`` if the
given related object list contains ``INSERT``s or ``DELETE``s.
"""
if state is not None and self.post_update:
for x in related:
if x is not None and not self._check_reverse_action(uowcommit, x, state, "postupdate"):
uowcommit.register_object(state, postupdate=True, post_update_cols=[r for l, r in self.prop.synchronize_pairs])
self._performed_action(uowcommit, x, state, "postupdate")
break
def _pks_changed(self, uowcommit, state):
raise NotImplementedError()
def __repr__(self):
return "%s(%s)" % (self.__class__.__name__, self.prop)
class OneToManyDP(DependencyProcessor):
def register_dependencies(self, uowcommit):
if self.post_update:
uowcommit.register_dependency(self.mapper, self.dependency_marker)
uowcommit.register_dependency(self.parent, self.dependency_marker)
else:
uowcommit.register_dependency(self.parent, self.mapper)
def register_processors(self, uowcommit):
if self.post_update:
uowcommit.register_processor(self.dependency_marker, self, self.parent)
else:
uowcommit.register_processor(self.parent, self, self.parent)
def process_dependencies(self, task, deplist, uowcommit, delete = False):
if delete:
# head object is being deleted, and we manage its list of child objects
# the child objects have to have their foreign key to the parent set to NULL
# this phase can be called safely for any cascade but is unnecessary if delete cascade
# is on.
if self.post_update or not self.passive_deletes == 'all':
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if history:
for child in history.deleted:
if child is not None and self.hasparent(child) is False:
self._synchronize(state, child, None, True, uowcommit)
self._conditional_post_update(child, uowcommit, [state])
if self.post_update or not self.cascade.delete:
for child in history.unchanged:
if child is not None:
self._synchronize(state, child, None, True, uowcommit)
self._conditional_post_update(child, uowcommit, [state])
else:
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key, passive=True)
if history:
for child in history.added:
self._synchronize(state, child, None, False, uowcommit)
if child is not None:
self._conditional_post_update(child, uowcommit, [state])
for child in history.deleted:
if not self.cascade.delete_orphan and not self.hasparent(child):
self._synchronize(state, child, None, True, uowcommit)
if self._pks_changed(uowcommit, state):
for child in history.unchanged:
self._synchronize(state, child, None, False, uowcommit)
def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
if delete:
# head object is being deleted, and we manage its list of child objects
# the child objects have to have their foreign key to the parent set to NULL
if not self.post_update:
should_null_fks = not self.cascade.delete and not self.passive_deletes == 'all'
for state in deplist:
history = uowcommit.get_attribute_history(
state, self.key, passive=self.passive_deletes)
if history:
for child in history.deleted:
if child is not None and self.hasparent(child) is False:
if self.cascade.delete_orphan:
uowcommit.register_object(child, isdelete=True)
else:
uowcommit.register_object(child)
if should_null_fks:
for child in history.unchanged:
if child is not None:
uowcommit.register_object(child)
else:
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key, passive=True)
if history:
for child in history.added:
if child is not None:
uowcommit.register_object(child)
for child in history.deleted:
if not self.cascade.delete_orphan:
uowcommit.register_object(child, isdelete=False)
elif self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(
attributes.instance_state(c),
isdelete=True)
if self._pks_changed(uowcommit, state):
if not history:
history = uowcommit.get_attribute_history(
state, self.key, passive=self.passive_updates)
if history:
for child in history.unchanged:
if child is not None:
uowcommit.register_object(child)
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
source = state
dest = child
if dest is None or (not self.post_update and uowcommit.is_deleted(dest)):
return
self._verify_canload(child)
if clearkeys:
sync.clear(dest, self.mapper, self.prop.synchronize_pairs)
else:
sync.populate(source, self.parent, dest, self.mapper,
self.prop.synchronize_pairs, uowcommit,
self.passive_updates)
def _pks_changed(self, uowcommit, state):
return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs)
class DetectKeySwitch(DependencyProcessor):
"""a special DP that works for many-to-one relationships, fires off for
child items who have changed their referenced key."""
has_dependencies = False
def register_dependencies(self, uowcommit):
pass
def register_processors(self, uowcommit):
uowcommit.register_processor(self.parent, self, self.mapper)
def preprocess_dependencies(self, task, deplist, uowcommit, delete=False):
# for non-passive updates, register in the preprocess stage
# so that mapper save_obj() gets a hold of changes
if not delete and not self.passive_updates:
self._process_key_switches(deplist, uowcommit)
def process_dependencies(self, task, deplist, uowcommit, delete=False):
# for passive updates, register objects in the process stage
# so that we avoid ManyToOneDP's registering the object without
# the listonly flag in its own preprocess stage (results in UPDATE)
# statements being emitted
if not delete and self.passive_updates:
self._process_key_switches(deplist, uowcommit)
def _process_key_switches(self, deplist, uowcommit):
switchers = set(s for s in deplist if self._pks_changed(uowcommit, s))
if switchers:
# yes, we're doing a linear search right now through the UOW. only
# takes effect when primary key values have actually changed.
# a possible optimization might be to enhance the "hasparents" capability of
# attributes to actually store all parent references, but this introduces
# more complicated attribute accounting.
for s in [elem for elem in uowcommit.session.identity_map.all_states()
if issubclass(elem.class_, self.parent.class_) and
self.key in elem.dict and
elem.dict[self.key] is not None and
attributes.instance_state(elem.dict[self.key]) in switchers
]:
uowcommit.register_object(s)
sync.populate(
attributes.instance_state(s.dict[self.key]),
self.mapper, s, self.parent, self.prop.synchronize_pairs,
uowcommit, self.passive_updates)
def _pks_changed(self, uowcommit, state):
return sync.source_modified(uowcommit, state, self.mapper, self.prop.synchronize_pairs)
class ManyToOneDP(DependencyProcessor):
def __init__(self, prop):
DependencyProcessor.__init__(self, prop)
self.mapper._dependency_processors.append(DetectKeySwitch(prop))
def register_dependencies(self, uowcommit):
if self.post_update:
uowcommit.register_dependency(self.mapper, self.dependency_marker)
uowcommit.register_dependency(self.parent, self.dependency_marker)
else:
uowcommit.register_dependency(self.mapper, self.parent)
def register_processors(self, uowcommit):
if self.post_update:
uowcommit.register_processor(self.dependency_marker, self, self.parent)
else:
uowcommit.register_processor(self.mapper, self, self.parent)
def process_dependencies(self, task, deplist, uowcommit, delete=False):
if delete:
if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes == 'all':
# post_update means we have to update our row to not reference the child object
# before we can DELETE the row
for state in deplist:
self._synchronize(state, None, None, True, uowcommit)
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if history:
self._conditional_post_update(state, uowcommit, history.sum())
else:
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key, passive=True)
if history:
for child in history.added:
self._synchronize(state, child, None, False, uowcommit)
self._conditional_post_update(state, uowcommit, history.sum())
def preprocess_dependencies(self, task, deplist, uowcommit, delete=False):
if self.post_update:
return
if delete:
if self.cascade.delete or self.cascade.delete_orphan:
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if history:
if self.cascade.delete_orphan:
todelete = history.sum()
else:
todelete = history.non_deleted()
for child in todelete:
if child is None:
continue
uowcommit.register_object(child, isdelete=True)
for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(
attributes.instance_state(c), isdelete=True)
else:
for state in deplist:
uowcommit.register_object(state)
if self.cascade.delete_orphan:
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if history:
for child in history.deleted:
if self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(
attributes.instance_state(c),
isdelete=True)
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
if state is None or (not self.post_update and uowcommit.is_deleted(state)):
return
if clearkeys or child is None:
sync.clear(state, self.parent, self.prop.synchronize_pairs)
else:
self._verify_canload(child)
sync.populate(child, self.mapper, state,
self.parent, self.prop.synchronize_pairs, uowcommit,
self.passive_updates
)
class ManyToManyDP(DependencyProcessor):
def register_dependencies(self, uowcommit):
# many-to-many. create a "Stub" mapper to represent the
# "middle table" in the relationship. This stub mapper doesnt save
# or delete any objects, but just marks a dependency on the two
# related mappers. its dependency processor then populates the
# association table.
uowcommit.register_dependency(self.parent, self.dependency_marker)
uowcommit.register_dependency(self.mapper, self.dependency_marker)
def register_processors(self, uowcommit):
uowcommit.register_processor(self.dependency_marker, self, self.parent)
def process_dependencies(self, task, deplist, uowcommit, delete = False):
connection = uowcommit.transaction.connection(self.mapper)
secondary_delete = []
secondary_insert = []
secondary_update = []
if delete:
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
if history:
for child in history.non_added():
if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"):
continue
associationrow = {}
self._synchronize(state, child, associationrow, False, uowcommit)
secondary_delete.append(associationrow)
self._performed_action(uowcommit, state, child, "manytomany")
else:
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key)
if history:
for child in history.added:
if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"):
continue
associationrow = {}
self._synchronize(state, child, associationrow, False, uowcommit)
self._performed_action(uowcommit, state, child, "manytomany")
secondary_insert.append(associationrow)
for child in history.deleted:
if child is None or self._check_reverse_action(uowcommit, child, state, "manytomany"):
continue
associationrow = {}
self._synchronize(state, child, associationrow, False, uowcommit)
self._performed_action(uowcommit, state, child, "manytomany")
secondary_delete.append(associationrow)
if not self.passive_updates and self._pks_changed(uowcommit, state):
if not history:
history = uowcommit.get_attribute_history(state, self.key, passive=False)
for child in history.unchanged:
associationrow = {}
sync.update(state, self.parent, associationrow, "old_", self.prop.synchronize_pairs)
sync.update(child, self.mapper, associationrow, "old_", self.prop.secondary_synchronize_pairs)
#self.syncrules.update(associationrow, state, child, "old_")
secondary_update.append(associationrow)
if secondary_delete:
statement = self.secondary.delete(sql.and_(*[
c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow
]))
result = connection.execute(statement, secondary_delete)
if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_delete):
raise exc.ConcurrentModificationError("Deleted rowcount %d does not match number of "
"secondary table rows deleted from table '%s': %d" %
(result.rowcount, self.secondary.description, len(secondary_delete)))
if secondary_update:
statement = self.secondary.update(sql.and_(*[
c == sql.bindparam("old_" + c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow
]))
result = connection.execute(statement, secondary_update)
if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_update):
raise exc.ConcurrentModificationError("Updated rowcount %d does not match number of "
"secondary table rows updated from table '%s': %d" %
(result.rowcount, self.secondary.description, len(secondary_update)))
if secondary_insert:
statement = self.secondary.insert()
connection.execute(statement, secondary_insert)
def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
if not delete:
for state in deplist:
history = uowcommit.get_attribute_history(state, self.key, passive=True)
if history:
for child in history.deleted:
if self.cascade.delete_orphan and self.hasparent(child) is False:
uowcommit.register_object(child, isdelete=True)
for c, m in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(
attributes.instance_state(c), isdelete=True)
def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
if associationrow is None:
return
self._verify_canload(child)
sync.populate_dict(state, self.parent, associationrow,
self.prop.synchronize_pairs)
sync.populate_dict(child, self.mapper, associationrow,
self.prop.secondary_synchronize_pairs)
def _pks_changed(self, uowcommit, state):
return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs)
class MapperStub(object):
"""Represent a many-to-many dependency within a flush
context.
The UOWTransaction corresponds dependencies to mappers.
MapperStub takes the place of the "association table"
so that a depedendency can be corresponded to it.
"""
def __init__(self, parent, mapper, key):
self.mapper = mapper
self.base_mapper = self
self.class_ = mapper.class_
self._inheriting_mappers = []
def polymorphic_iterator(self):
return iter((self,))
def _register_dependencies(self, uowcommit):
pass
def _register_procesors(self, uowcommit):
pass
def _save_obj(self, *args, **kwargs):
pass
def _delete_obj(self, *args, **kwargs):
pass
def primary_mapper(self):
return self

293
sqlalchemy/orm/dynamic.py Normal file
View File

@@ -0,0 +1,293 @@
# dynamic.py
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Dynamic collection API.
Dynamic collections act like Query() objects for read operations and support
basic add/delete mutation.
"""
from sqlalchemy import log, util
from sqlalchemy import exc as sa_exc
from sqlalchemy.orm import exc as sa_exc
from sqlalchemy.sql import operators
from sqlalchemy.orm import (
attributes, object_session, util as mapperutil, strategies, object_mapper
)
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.util import _state_has_identity, has_identity
from sqlalchemy.orm import attributes, collections
class DynaLoader(strategies.AbstractRelationshipLoader):
def init_class_attribute(self, mapper):
self.is_class_level = True
strategies._register_attribute(self,
mapper,
useobject=True,
impl_class=DynamicAttributeImpl,
target_mapper=self.parent_property.mapper,
order_by=self.parent_property.order_by,
query_class=self.parent_property.query_class
)
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
return (None, None)
log.class_logger(DynaLoader)
class DynamicAttributeImpl(attributes.AttributeImpl):
uses_objects = True
accepts_scalar_loader = False
def __init__(self, class_, key, typecallable,
target_mapper, order_by, query_class=None, **kwargs):
super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, **kwargs)
self.target_mapper = target_mapper
self.order_by = order_by
if not query_class:
self.query_class = AppenderQuery
elif AppenderMixin in query_class.mro():
self.query_class = query_class
else:
self.query_class = mixin_user_query(query_class)
def get(self, state, dict_, passive=False):
if passive:
return self._get_collection_history(state, passive=True).added_items
else:
return self.query_class(self, state)
def get_collection(self, state, dict_, user_data=None, passive=True):
if passive:
return self._get_collection_history(state, passive=passive).added_items
else:
history = self._get_collection_history(state, passive=passive)
return history.added_items + history.unchanged_items
def fire_append_event(self, state, dict_, value, initiator):
collection_history = self._modified_event(state, dict_)
collection_history.added_items.append(value)
for ext in self.extensions:
ext.append(state, value, initiator or self)
if self.trackparent and value is not None:
self.sethasparent(attributes.instance_state(value), True)
def fire_remove_event(self, state, dict_, value, initiator):
collection_history = self._modified_event(state, dict_)
collection_history.deleted_items.append(value)
if self.trackparent and value is not None:
self.sethasparent(attributes.instance_state(value), False)
for ext in self.extensions:
ext.remove(state, value, initiator or self)
def _modified_event(self, state, dict_):
if self.key not in state.committed_state:
state.committed_state[self.key] = CollectionHistory(self, state)
state.modified_event(dict_,
self,
False,
attributes.NEVER_SET,
passive=attributes.PASSIVE_NO_INITIALIZE)
# this is a hack to allow the _base.ComparableEntity fixture
# to work
dict_[self.key] = True
return state.committed_state[self.key]
def set(self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF):
if initiator is self:
return
self._set_iterable(state, dict_, value)
def _set_iterable(self, state, dict_, iterable, adapter=None):
collection_history = self._modified_event(state, dict_)
new_values = list(iterable)
if _state_has_identity(state):
old_collection = list(self.get(state, dict_))
else:
old_collection = []
collections.bulk_replace(new_values, DynCollectionAdapter(self, state, old_collection), DynCollectionAdapter(self, state, new_values))
def delete(self, *args, **kwargs):
raise NotImplementedError()
def get_history(self, state, dict_, passive=False):
c = self._get_collection_history(state, passive)
return attributes.History(c.added_items, c.unchanged_items, c.deleted_items)
def _get_collection_history(self, state, passive=False):
if self.key in state.committed_state:
c = state.committed_state[self.key]
else:
c = CollectionHistory(self, state)
if not passive:
return CollectionHistory(self, state, apply_to=c)
else:
return c
def append(self, state, dict_, value, initiator, passive=False):
if initiator is not self:
self.fire_append_event(state, dict_, value, initiator)
def remove(self, state, dict_, value, initiator, passive=False):
if initiator is not self:
self.fire_remove_event(state, dict_, value, initiator)
class DynCollectionAdapter(object):
"""the dynamic analogue to orm.collections.CollectionAdapter"""
def __init__(self, attr, owner_state, data):
self.attr = attr
self.state = owner_state
self.data = data
def __iter__(self):
return iter(self.data)
def append_with_event(self, item, initiator=None):
self.attr.append(self.state, self.state.dict, item, initiator)
def remove_with_event(self, item, initiator=None):
self.attr.remove(self.state, self.state.dict, item, initiator)
def append_without_event(self, item):
pass
def remove_without_event(self, item):
pass
class AppenderMixin(object):
query_class = None
def __init__(self, attr, state):
Query.__init__(self, attr.target_mapper, None)
self.instance = instance = state.obj()
self.attr = attr
mapper = object_mapper(instance)
prop = mapper.get_property(self.attr.key, resolve_synonyms=True)
self._criterion = prop.compare(
operators.eq,
instance,
value_is_parent=True,
alias_secondary=False)
if self.attr.order_by:
self._order_by = self.attr.order_by
def __session(self):
sess = object_session(self.instance)
if sess is not None and self.autoflush and sess.autoflush and self.instance in sess:
sess.flush()
if not has_identity(self.instance):
return None
else:
return sess
def session(self):
return self.__session()
session = property(session, lambda s, x:None)
def __iter__(self):
sess = self.__session()
if sess is None:
return iter(self.attr._get_collection_history(
attributes.instance_state(self.instance),
passive=True).added_items)
else:
return iter(self._clone(sess))
def __getitem__(self, index):
sess = self.__session()
if sess is None:
return self.attr._get_collection_history(
attributes.instance_state(self.instance),
passive=True).added_items.__getitem__(index)
else:
return self._clone(sess).__getitem__(index)
def count(self):
sess = self.__session()
if sess is None:
return len(self.attr._get_collection_history(
attributes.instance_state(self.instance),
passive=True).added_items)
else:
return self._clone(sess).count()
def _clone(self, sess=None):
# note we're returning an entirely new Query class instance
# here without any assignment capabilities; the class of this
# query is determined by the session.
instance = self.instance
if sess is None:
sess = object_session(instance)
if sess is None:
raise orm_exc.DetachedInstanceError(
"Parent instance %s is not bound to a Session, and no "
"contextual session is established; lazy load operation "
"of attribute '%s' cannot proceed" % (
mapperutil.instance_str(instance), self.attr.key))
if self.query_class:
query = self.query_class(self.attr.target_mapper, session=sess)
else:
query = sess.query(self.attr.target_mapper)
query._criterion = self._criterion
query._order_by = self._order_by
return query
def append(self, item):
self.attr.append(
attributes.instance_state(self.instance),
attributes.instance_dict(self.instance), item, None)
def remove(self, item):
self.attr.remove(
attributes.instance_state(self.instance),
attributes.instance_dict(self.instance), item, None)
class AppenderQuery(AppenderMixin, Query):
"""A dynamic query that supports basic collection storage operations."""
def mixin_user_query(cls):
"""Return a new class with AppenderQuery functionality layered over."""
name = 'Appender' + cls.__name__
return type(name, (AppenderMixin, cls), {'query_class': cls})
class CollectionHistory(object):
"""Overrides AttributeHistory to receive append/remove events directly."""
def __init__(self, attr, state, apply_to=None):
if apply_to:
deleted = util.IdentitySet(apply_to.deleted_items)
added = apply_to.added_items
coll = AppenderQuery(attr, state).autoflush(False)
self.unchanged_items = [o for o in util.IdentitySet(coll) if o not in deleted]
self.added_items = apply_to.added_items
self.deleted_items = apply_to.deleted_items
else:
self.deleted_items = []
self.added_items = []
self.unchanged_items = []

104
sqlalchemy/orm/evaluator.py Normal file
View File

@@ -0,0 +1,104 @@
import operator
from sqlalchemy.sql import operators, functions
from sqlalchemy.sql import expression as sql
class UnevaluatableError(Exception):
pass
_straight_ops = set(getattr(operators, op)
for op in ('add', 'mul', 'sub',
# Py2K
'div',
# end Py2K
'mod', 'truediv',
'lt', 'le', 'ne', 'gt', 'ge', 'eq'))
_notimplemented_ops = set(getattr(operators, op)
for op in ('like_op', 'notlike_op', 'ilike_op',
'notilike_op', 'between_op', 'in_op',
'notin_op', 'endswith_op', 'concat_op'))
class EvaluatorCompiler(object):
def process(self, clause):
meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
if not meth:
raise UnevaluatableError("Cannot evaluate %s" % type(clause).__name__)
return meth(clause)
def visit_grouping(self, clause):
return self.process(clause.element)
def visit_null(self, clause):
return lambda obj: None
def visit_column(self, clause):
if 'parentmapper' in clause._annotations:
key = clause._annotations['parentmapper']._get_col_to_prop(clause).key
else:
key = clause.key
get_corresponding_attr = operator.attrgetter(key)
return lambda obj: get_corresponding_attr(obj)
def visit_clauselist(self, clause):
evaluators = map(self.process, clause.clauses)
if clause.operator is operators.or_:
def evaluate(obj):
has_null = False
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
if value:
return True
has_null = has_null or value is None
if has_null:
return None
return False
elif clause.operator is operators.and_:
def evaluate(obj):
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
if not value:
if value is None:
return None
return False
return True
else:
raise UnevaluatableError("Cannot evaluate clauselist with operator %s" % clause.operator)
return evaluate
def visit_binary(self, clause):
eval_left,eval_right = map(self.process, [clause.left, clause.right])
operator = clause.operator
if operator is operators.is_:
def evaluate(obj):
return eval_left(obj) == eval_right(obj)
elif operator is operators.isnot:
def evaluate(obj):
return eval_left(obj) != eval_right(obj)
elif operator in _straight_ops:
def evaluate(obj):
left_val = eval_left(obj)
right_val = eval_right(obj)
if left_val is None or right_val is None:
return None
return operator(eval_left(obj), eval_right(obj))
else:
raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator))
return evaluate
def visit_unary(self, clause):
eval_inner = self.process(clause.element)
if clause.operator is operators.inv:
def evaluate(obj):
value = eval_inner(obj)
if value is None:
return None
return not value
return evaluate
raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator))
def visit_bindparam(self, clause):
val = clause.value
return lambda obj: val

98
sqlalchemy/orm/exc.py Normal file
View File

@@ -0,0 +1,98 @@
# exc.py - ORM exceptions
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""SQLAlchemy ORM exceptions."""
import sqlalchemy as sa
NO_STATE = (AttributeError, KeyError)
"""Exception types that may be raised by instrumentation implementations."""
class ConcurrentModificationError(sa.exc.SQLAlchemyError):
"""Rows have been modified outside of the unit of work."""
class FlushError(sa.exc.SQLAlchemyError):
"""A invalid condition was detected during flush()."""
class UnmappedError(sa.exc.InvalidRequestError):
"""TODO"""
class DetachedInstanceError(sa.exc.SQLAlchemyError):
"""An attempt to access unloaded attributes on a mapped instance that is detached."""
class UnmappedInstanceError(UnmappedError):
"""An mapping operation was requested for an unknown instance."""
def __init__(self, obj, msg=None):
if not msg:
try:
mapper = sa.orm.class_mapper(type(obj))
name = _safe_cls_name(type(obj))
msg = ("Class %r is mapped, but this instance lacks "
"instrumentation. This occurs when the instance is created "
"before sqlalchemy.orm.mapper(%s) was called." % (name, name))
except UnmappedClassError:
msg = _default_unmapped(type(obj))
if isinstance(obj, type):
msg += (
'; was a class (%s) supplied where an instance was '
'required?' % _safe_cls_name(obj))
UnmappedError.__init__(self, msg)
class UnmappedClassError(UnmappedError):
"""An mapping operation was requested for an unknown class."""
def __init__(self, cls, msg=None):
if not msg:
msg = _default_unmapped(cls)
UnmappedError.__init__(self, msg)
class ObjectDeletedError(sa.exc.InvalidRequestError):
"""An refresh() operation failed to re-retrieve an object's row."""
class UnmappedColumnError(sa.exc.InvalidRequestError):
"""Mapping operation was requested on an unknown column."""
class NoResultFound(sa.exc.InvalidRequestError):
"""A database result was required but none was found."""
class MultipleResultsFound(sa.exc.InvalidRequestError):
"""A single database result was required but more than one were found."""
# Legacy compat until 0.6.
sa.exc.ConcurrentModificationError = ConcurrentModificationError
sa.exc.FlushError = FlushError
sa.exc.UnmappedColumnError
def _safe_cls_name(cls):
try:
cls_name = '.'.join((cls.__module__, cls.__name__))
except AttributeError:
cls_name = getattr(cls, '__name__', None)
if cls_name is None:
cls_name = repr(cls)
return cls_name
def _default_unmapped(cls):
try:
mappers = sa.orm.attributes.manager_of_class(cls).mappers
except NO_STATE:
mappers = {}
except TypeError:
mappers = {}
name = _safe_cls_name(cls)
if not mappers:
return "Class '%s' is not mapped" % name

251
sqlalchemy/orm/identity.py Normal file
View File

@@ -0,0 +1,251 @@
# identity.py
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import weakref
from sqlalchemy import util as base_util
from sqlalchemy.orm import attributes
class IdentityMap(dict):
def __init__(self):
self._mutable_attrs = set()
self._modified = set()
self._wr = weakref.ref(self)
def replace(self, state):
raise NotImplementedError()
def add(self, state):
raise NotImplementedError()
def remove(self, state):
raise NotImplementedError()
def update(self, dict):
raise NotImplementedError("IdentityMap uses add() to insert data")
def clear(self):
raise NotImplementedError("IdentityMap uses remove() to remove data")
def _manage_incoming_state(self, state):
state._instance_dict = self._wr
if state.modified:
self._modified.add(state)
if state.manager.mutable_attributes:
self._mutable_attrs.add(state)
def _manage_removed_state(self, state):
del state._instance_dict
self._mutable_attrs.discard(state)
self._modified.discard(state)
def _dirty_states(self):
return self._modified.union(s for s in self._mutable_attrs.copy()
if s.modified)
def check_modified(self):
"""return True if any InstanceStates present have been marked as 'modified'."""
if self._modified:
return True
else:
for state in self._mutable_attrs.copy():
if state.modified:
return True
return False
def has_key(self, key):
return key in self
def popitem(self):
raise NotImplementedError("IdentityMap uses remove() to remove data")
def pop(self, key, *args):
raise NotImplementedError("IdentityMap uses remove() to remove data")
def setdefault(self, key, default=None):
raise NotImplementedError("IdentityMap uses add() to insert data")
def copy(self):
raise NotImplementedError()
def __setitem__(self, key, value):
raise NotImplementedError("IdentityMap uses add() to insert data")
def __delitem__(self, key):
raise NotImplementedError("IdentityMap uses remove() to remove data")
class WeakInstanceDict(IdentityMap):
def __getitem__(self, key):
state = dict.__getitem__(self, key)
o = state.obj()
if o is None:
o = state._is_really_none()
if o is None:
raise KeyError, key
return o
def __contains__(self, key):
try:
if dict.__contains__(self, key):
state = dict.__getitem__(self, key)
o = state.obj()
if o is None:
o = state._is_really_none()
else:
return False
except KeyError:
return False
else:
return o is not None
def contains_state(self, state):
return dict.get(self, state.key) is state
def replace(self, state):
if dict.__contains__(self, state.key):
existing = dict.__getitem__(self, state.key)
if existing is not state:
self._manage_removed_state(existing)
else:
return
dict.__setitem__(self, state.key, state)
self._manage_incoming_state(state)
def add(self, state):
if state.key in self:
if dict.__getitem__(self, state.key) is not state:
raise AssertionError("A conflicting state is already "
"present in the identity map for key %r"
% (state.key, ))
else:
dict.__setitem__(self, state.key, state)
self._manage_incoming_state(state)
def remove_key(self, key):
state = dict.__getitem__(self, key)
self.remove(state)
def remove(self, state):
if dict.pop(self, state.key) is not state:
raise AssertionError("State %s is not present in this identity map" % state)
self._manage_removed_state(state)
def discard(self, state):
if self.contains_state(state):
dict.__delitem__(self, state.key)
self._manage_removed_state(state)
def get(self, key, default=None):
state = dict.get(self, key, default)
if state is default:
return default
o = state.obj()
if o is None:
o = state._is_really_none()
if o is None:
return default
return o
# Py2K
def items(self):
return list(self.iteritems())
def iteritems(self):
for state in dict.itervalues(self):
# end Py2K
# Py3K
#def items(self):
# for state in dict.values(self):
value = state.obj()
if value is not None:
yield state.key, value
# Py2K
def values(self):
return list(self.itervalues())
def itervalues(self):
for state in dict.itervalues(self):
# end Py2K
# Py3K
#def values(self):
# for state in dict.values(self):
instance = state.obj()
if instance is not None:
yield instance
def all_states(self):
# Py3K
# return list(dict.values(self))
# Py2K
return dict.values(self)
# end Py2K
def prune(self):
return 0
class StrongInstanceDict(IdentityMap):
def all_states(self):
return [attributes.instance_state(o) for o in self.itervalues()]
def contains_state(self, state):
return state.key in self and attributes.instance_state(self[state.key]) is state
def replace(self, state):
if dict.__contains__(self, state.key):
existing = dict.__getitem__(self, state.key)
existing = attributes.instance_state(existing)
if existing is not state:
self._manage_removed_state(existing)
else:
return
dict.__setitem__(self, state.key, state.obj())
self._manage_incoming_state(state)
def add(self, state):
if state.key in self:
if attributes.instance_state(dict.__getitem__(self, state.key)) is not state:
raise AssertionError("A conflicting state is already present in the identity map for key %r" % (state.key, ))
else:
dict.__setitem__(self, state.key, state.obj())
self._manage_incoming_state(state)
def remove(self, state):
if attributes.instance_state(dict.pop(self, state.key)) is not state:
raise AssertionError("State %s is not present in this identity map" % state)
self._manage_removed_state(state)
def discard(self, state):
if self.contains_state(state):
dict.__delitem__(self, state.key)
self._manage_removed_state(state)
def remove_key(self, key):
state = attributes.instance_state(dict.__getitem__(self, key))
self.remove(state)
def prune(self):
"""prune unreferenced, non-dirty states."""
ref_count = len(self)
dirty = [s.obj() for s in self.all_states() if s.modified]
# work around http://bugs.python.org/issue6149
keepers = weakref.WeakValueDictionary()
keepers.update(self)
dict.clear(self)
dict.update(self, keepers)
self.modified = bool(dirty)
return ref_count - len(self)

1098
sqlalchemy/orm/interfaces.py Normal file

File diff suppressed because it is too large Load Diff

1958
sqlalchemy/orm/mapper.py Normal file

File diff suppressed because it is too large Load Diff

1205
sqlalchemy/orm/properties.py Normal file

File diff suppressed because it is too large Load Diff

2469
sqlalchemy/orm/query.py Normal file

File diff suppressed because it is too large Load Diff

205
sqlalchemy/orm/scoping.py Normal file
View File

@@ -0,0 +1,205 @@
# scoping.py
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import sqlalchemy.exceptions as sa_exc
from sqlalchemy.util import ScopedRegistry, ThreadLocalRegistry, \
to_list, get_cls_kwargs, deprecated
from sqlalchemy.orm import (
EXT_CONTINUE, MapperExtension, class_mapper, object_session
)
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm.session import Session
__all__ = ['ScopedSession']
class ScopedSession(object):
"""Provides thread-local management of Sessions.
Usage::
Session = scoped_session(sessionmaker(autoflush=True))
... use session normally.
"""
def __init__(self, session_factory, scopefunc=None):
self.session_factory = session_factory
if scopefunc:
self.registry = ScopedRegistry(session_factory, scopefunc)
else:
self.registry = ThreadLocalRegistry(session_factory)
self.extension = _ScopedExt(self)
def __call__(self, **kwargs):
if kwargs:
scope = kwargs.pop('scope', False)
if scope is not None:
if self.registry.has():
raise sa_exc.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")
else:
sess = self.session_factory(**kwargs)
self.registry.set(sess)
return sess
else:
return self.session_factory(**kwargs)
else:
return self.registry()
def remove(self):
"""Dispose of the current contextual session."""
if self.registry.has():
self.registry().close()
self.registry.clear()
@deprecated("Session.mapper is deprecated. "
"Please see http://www.sqlalchemy.org/trac/wiki/UsageRecipes/SessionAwareMapper "
"for information on how to replicate its behavior.")
def mapper(self, *args, **kwargs):
"""return a mapper() function which associates this ScopedSession with the Mapper.
DEPRECATED.
"""
from sqlalchemy.orm import mapper
extension_args = dict((arg, kwargs.pop(arg))
for arg in get_cls_kwargs(_ScopedExt)
if arg in kwargs)
kwargs['extension'] = extension = to_list(kwargs.get('extension', []))
if extension_args:
extension.append(self.extension.configure(**extension_args))
else:
extension.append(self.extension)
return mapper(*args, **kwargs)
def configure(self, **kwargs):
"""reconfigure the sessionmaker used by this ScopedSession."""
self.session_factory.configure(**kwargs)
def query_property(self, query_cls=None):
"""return a class property which produces a `Query` object against the
class when called.
e.g.::
Session = scoped_session(sessionmaker())
class MyClass(object):
query = Session.query_property()
# after mappers are defined
result = MyClass.query.filter(MyClass.name=='foo').all()
Produces instances of the session's configured query class by
default. To override and use a custom implementation, provide
a ``query_cls`` callable. The callable will be invoked with
the class's mapper as a positional argument and a session
keyword argument.
There is no limit to the number of query properties placed on
a class.
"""
class query(object):
def __get__(s, instance, owner):
try:
mapper = class_mapper(owner)
if mapper:
if query_cls:
# custom query class
return query_cls(mapper, session=self.registry())
else:
# session's configured query class
return self.registry().query(mapper)
except orm_exc.UnmappedClassError:
return None
return query()
def instrument(name):
def do(self, *args, **kwargs):
return getattr(self.registry(), name)(*args, **kwargs)
return do
for meth in Session.public_methods:
setattr(ScopedSession, meth, instrument(meth))
def makeprop(name):
def set(self, attr):
setattr(self.registry(), name, attr)
def get(self):
return getattr(self.registry(), name)
return property(get, set)
for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map', 'is_active', 'autoflush'):
setattr(ScopedSession, prop, makeprop(prop))
def clslevel(name):
def do(cls, *args, **kwargs):
return getattr(Session, name)(*args, **kwargs)
return classmethod(do)
for prop in ('close_all', 'object_session', 'identity_key'):
setattr(ScopedSession, prop, clslevel(prop))
class _ScopedExt(MapperExtension):
def __init__(self, context, validate=False, save_on_init=True):
self.context = context
self.validate = validate
self.save_on_init = save_on_init
self.set_kwargs_on_init = True
def validating(self):
return _ScopedExt(self.context, validate=True)
def configure(self, **kwargs):
return _ScopedExt(self.context, **kwargs)
def instrument_class(self, mapper, class_):
class query(object):
def __getattr__(s, key):
return getattr(self.context.registry().query(class_), key)
def __call__(s):
return self.context.registry().query(class_)
def __get__(self, instance, cls):
return self
if not 'query' in class_.__dict__:
class_.query = query()
if self.set_kwargs_on_init and class_.__init__ is object.__init__:
class_.__init__ = self._default__init__(mapper)
def _default__init__(ext, mapper):
def __init__(self, **kwargs):
for key, value in kwargs.iteritems():
if ext.validate:
if not mapper.get_property(key, resolve_synonyms=False,
raiseerr=False):
raise sa_exc.ArgumentError(
"Invalid __init__ argument: '%s'" % key)
setattr(self, key, value)
return __init__
def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
if self.save_on_init:
session = kwargs.pop('_sa_session', None)
if session is None:
session = self.context.registry()
session._save_without_cascade(instance)
return EXT_CONTINUE
def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
sess = object_session(instance)
if sess:
sess.expunge(instance)
return EXT_CONTINUE
def dispose_class(self, mapper, class_):
if hasattr(class_, 'query'):
delattr(class_, 'query')

1604
sqlalchemy/orm/session.py Normal file

File diff suppressed because it is too large Load Diff

15
sqlalchemy/orm/shard.py Normal file
View File

@@ -0,0 +1,15 @@
# shard.py
# Copyright (C) the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from sqlalchemy import util
util.warn_deprecated(
"Horizontal sharding is now importable via "
"'import sqlalchemy.ext.horizontal_shard"
)
from sqlalchemy.ext.horizontal_shard import *

527
sqlalchemy/orm/state.py Normal file
View File

@@ -0,0 +1,527 @@
from sqlalchemy.util import EMPTY_SET
import weakref
from sqlalchemy import util
from sqlalchemy.orm.attributes import PASSIVE_NO_RESULT, PASSIVE_OFF, \
NEVER_SET, NO_VALUE, manager_of_class, \
ATTR_WAS_SET
from sqlalchemy.orm import attributes, exc as orm_exc, interfaces
import sys
attributes.state = sys.modules['sqlalchemy.orm.state']
class InstanceState(object):
"""tracks state information at the instance level."""
session_id = None
key = None
runid = None
load_options = EMPTY_SET
load_path = ()
insert_order = None
mutable_dict = None
_strong_obj = None
modified = False
expired = False
def __init__(self, obj, manager):
self.class_ = obj.__class__
self.manager = manager
self.obj = weakref.ref(obj, self._cleanup)
@util.memoized_property
def committed_state(self):
return {}
@util.memoized_property
def parents(self):
return {}
@util.memoized_property
def pending(self):
return {}
@util.memoized_property
def callables(self):
return {}
def detach(self):
if self.session_id:
try:
del self.session_id
except AttributeError:
pass
def dispose(self):
self.detach()
del self.obj
def _cleanup(self, ref):
instance_dict = self._instance_dict()
if instance_dict:
try:
instance_dict.remove(self)
except AssertionError:
pass
# remove possible cycles
self.__dict__.pop('callables', None)
self.dispose()
def obj(self):
return None
@property
def dict(self):
o = self.obj()
if o is not None:
return attributes.instance_dict(o)
else:
return {}
@property
def sort_key(self):
return self.key and self.key[1] or (self.insert_order, )
def initialize_instance(*mixed, **kwargs):
self, instance, args = mixed[0], mixed[1], mixed[2:]
manager = self.manager
for fn in manager.events.on_init:
fn(self, instance, args, kwargs)
# LESSTHANIDEAL:
# adjust for the case where the InstanceState was created before
# mapper compilation, and this actually needs to be a MutableAttrInstanceState
if manager.mutable_attributes and self.__class__ is not MutableAttrInstanceState:
self.__class__ = MutableAttrInstanceState
self.obj = weakref.ref(self.obj(), self._cleanup)
self.mutable_dict = {}
try:
return manager.events.original_init(*mixed[1:], **kwargs)
except:
for fn in manager.events.on_init_failure:
fn(self, instance, args, kwargs)
raise
def get_history(self, key, **kwargs):
return self.manager.get_impl(key).get_history(self, self.dict, **kwargs)
def get_impl(self, key):
return self.manager.get_impl(key)
def get_pending(self, key):
if key not in self.pending:
self.pending[key] = PendingCollection()
return self.pending[key]
def value_as_iterable(self, key, passive=PASSIVE_OFF):
"""return an InstanceState attribute as a list,
regardless of it being a scalar or collection-based
attribute.
returns None if passive is not PASSIVE_OFF and the getter returns
PASSIVE_NO_RESULT.
"""
impl = self.get_impl(key)
dict_ = self.dict
x = impl.get(self, dict_, passive=passive)
if x is PASSIVE_NO_RESULT:
return None
elif hasattr(impl, 'get_collection'):
return impl.get_collection(self, dict_, x, passive=passive)
else:
return [x]
def _run_on_load(self, instance):
self.manager.events.run('on_load', instance)
def __getstate__(self):
d = {'instance':self.obj()}
d.update(
(k, self.__dict__[k]) for k in (
'committed_state', 'pending', 'parents', 'modified', 'expired',
'callables', 'key', 'load_options', 'mutable_dict'
) if k in self.__dict__
)
if self.load_path:
d['load_path'] = interfaces.serialize_path(self.load_path)
return d
def __setstate__(self, state):
self.obj = weakref.ref(state['instance'], self._cleanup)
self.class_ = state['instance'].__class__
self.manager = manager = manager_of_class(self.class_)
if manager is None:
raise orm_exc.UnmappedInstanceError(
state['instance'],
"Cannot deserialize object of type %r - no mapper() has"
" been configured for this class within the current Python process!" %
self.class_)
elif manager.mapper and not manager.mapper.compiled:
manager.mapper.compile()
self.committed_state = state.get('committed_state', {})
self.pending = state.get('pending', {})
self.parents = state.get('parents', {})
self.modified = state.get('modified', False)
self.expired = state.get('expired', False)
self.callables = state.get('callables', {})
if self.modified:
self._strong_obj = state['instance']
self.__dict__.update([
(k, state[k]) for k in (
'key', 'load_options', 'mutable_dict'
) if k in state
])
if 'load_path' in state:
self.load_path = interfaces.deserialize_path(state['load_path'])
def initialize(self, key):
"""Set this attribute to an empty value or collection,
based on the AttributeImpl in use."""
self.manager.get_impl(key).initialize(self, self.dict)
def reset(self, dict_, key):
"""Remove the given attribute and any
callables associated with it."""
dict_.pop(key, None)
self.callables.pop(key, None)
def expire_attribute_pre_commit(self, dict_, key):
"""a fast expire that can be called by column loaders during a load.
The additional bookkeeping is finished up in commit_all().
This method is actually called a lot with joined-table
loading, when the second table isn't present in the result.
"""
dict_.pop(key, None)
self.callables[key] = self
def set_callable(self, dict_, key, callable_):
"""Remove the given attribute and set the given callable
as a loader."""
dict_.pop(key, None)
self.callables[key] = callable_
def expire_attributes(self, dict_, attribute_names, instance_dict=None):
"""Expire all or a group of attributes.
If all attributes are expired, the "expired" flag is set to True.
"""
if attribute_names is None:
attribute_names = self.manager.keys()
self.expired = True
if self.modified:
if not instance_dict:
instance_dict = self._instance_dict()
if instance_dict:
instance_dict._modified.discard(self)
else:
instance_dict._modified.discard(self)
self.modified = False
filter_deferred = True
else:
filter_deferred = False
to_clear = (
self.__dict__.get('pending', None),
self.__dict__.get('committed_state', None),
self.mutable_dict
)
for key in attribute_names:
impl = self.manager[key].impl
if impl.accepts_scalar_loader and \
(not filter_deferred or impl.expire_missing or key in dict_):
self.callables[key] = self
dict_.pop(key, None)
for d in to_clear:
if d is not None:
d.pop(key, None)
def __call__(self, **kw):
"""__call__ allows the InstanceState to act as a deferred
callable for loading expired attributes, which is also
serializable (picklable).
"""
if kw.get('passive') is attributes.PASSIVE_NO_FETCH:
return attributes.PASSIVE_NO_RESULT
toload = self.expired_attributes.\
intersection(self.unmodified)
self.manager.deferred_scalar_loader(self, toload)
# if the loader failed, or this
# instance state didn't have an identity,
# the attributes still might be in the callables
# dict. ensure they are removed.
for k in toload.intersection(self.callables):
del self.callables[k]
return ATTR_WAS_SET
@property
def unmodified(self):
"""Return the set of keys which have no uncommitted changes"""
return set(self.manager).difference(self.committed_state)
@property
def unloaded(self):
"""Return the set of keys which do not have a loaded value.
This includes expired attributes and any other attribute that
was never populated or modified.
"""
return set(self.manager).\
difference(self.committed_state).\
difference(self.dict)
@property
def expired_attributes(self):
"""Return the set of keys which are 'expired' to be loaded by
the manager's deferred scalar loader, assuming no pending
changes.
see also the ``unmodified`` collection which is intersected
against this set when a refresh operation occurs.
"""
return set([k for k, v in self.callables.items() if v is self])
def _instance_dict(self):
return None
def _is_really_none(self):
return self.obj()
def modified_event(self, dict_, attr, should_copy, previous, passive=PASSIVE_OFF):
needs_committed = attr.key not in self.committed_state
if needs_committed:
if previous is NEVER_SET:
if passive:
if attr.key in dict_:
previous = dict_[attr.key]
else:
previous = attr.get(self, dict_)
if should_copy and previous not in (None, NO_VALUE, NEVER_SET):
previous = attr.copy(previous)
if needs_committed:
self.committed_state[attr.key] = previous
if not self.modified:
instance_dict = self._instance_dict()
if instance_dict:
instance_dict._modified.add(self)
self.modified = True
if self._strong_obj is None:
self._strong_obj = self.obj()
def commit(self, dict_, keys):
"""Commit attributes.
This is used by a partial-attribute load operation to mark committed
those attributes which were refreshed from the database.
Attributes marked as "expired" can potentially remain "expired" after
this step if a value was not populated in state.dict.
"""
class_manager = self.manager
for key in keys:
if key in dict_ and key in class_manager.mutable_attributes:
self.committed_state[key] = self.manager[key].impl.copy(dict_[key])
else:
self.committed_state.pop(key, None)
self.expired = False
for key in set(self.callables).\
intersection(keys).\
intersection(dict_):
del self.callables[key]
def commit_all(self, dict_, instance_dict=None):
"""commit all attributes unconditionally.
This is used after a flush() or a full load/refresh
to remove all pending state from the instance.
- all attributes are marked as "committed"
- the "strong dirty reference" is removed
- the "modified" flag is set to False
- any "expired" markers/callables for attributes loaded are removed.
Attributes marked as "expired" can potentially remain "expired" after this step
if a value was not populated in state.dict.
"""
self.__dict__.pop('committed_state', None)
self.__dict__.pop('pending', None)
if 'callables' in self.__dict__:
callables = self.callables
for key in list(callables):
if key in dict_ and callables[key] is self:
del callables[key]
for key in self.manager.mutable_attributes:
if key in dict_:
self.committed_state[key] = self.manager[key].impl.copy(dict_[key])
if instance_dict and self.modified:
instance_dict._modified.discard(self)
self.modified = self.expired = False
self._strong_obj = None
class MutableAttrInstanceState(InstanceState):
"""InstanceState implementation for objects that reference 'mutable'
attributes.
Has a more involved "cleanup" handler that checks mutable attributes
for changes upon dereference, resurrecting if needed.
"""
@util.memoized_property
def mutable_dict(self):
return {}
def _get_modified(self, dict_=None):
if self.__dict__.get('modified', False):
return True
else:
if dict_ is None:
dict_ = self.dict
for key in self.manager.mutable_attributes:
if self.manager[key].impl.check_mutable_modified(self, dict_):
return True
else:
return False
def _set_modified(self, value):
self.__dict__['modified'] = value
modified = property(_get_modified, _set_modified)
@property
def unmodified(self):
"""a set of keys which have no uncommitted changes"""
dict_ = self.dict
return set([
key for key in self.manager
if (key not in self.committed_state or
(key in self.manager.mutable_attributes and
not self.manager[key].impl.check_mutable_modified(self, dict_)))])
def _is_really_none(self):
"""do a check modified/resurrect.
This would be called in the extremely rare
race condition that the weakref returned None but
the cleanup handler had not yet established the
__resurrect callable as its replacement.
"""
if self.modified:
self.obj = self.__resurrect
return self.obj()
else:
return None
def reset(self, dict_, key):
self.mutable_dict.pop(key, None)
InstanceState.reset(self, dict_, key)
def _cleanup(self, ref):
"""weakref callback.
This method may be called by an asynchronous
gc.
If the state shows pending changes, the weakref
is replaced by the __resurrect callable which will
re-establish an object reference on next access,
else removes this InstanceState from the owning
identity map, if any.
"""
if self._get_modified(self.mutable_dict):
self.obj = self.__resurrect
else:
instance_dict = self._instance_dict()
if instance_dict:
try:
instance_dict.remove(self)
except AssertionError:
pass
self.dispose()
def __resurrect(self):
"""A substitute for the obj() weakref function which resurrects."""
# store strong ref'ed version of the object; will revert
# to weakref when changes are persisted
obj = self.manager.new_instance(state=self)
self.obj = weakref.ref(obj, self._cleanup)
self._strong_obj = obj
obj.__dict__.update(self.mutable_dict)
# re-establishes identity attributes from the key
self.manager.events.run('on_resurrect', self, obj)
# TODO: don't really think we should run this here.
# resurrect is only meant to preserve the minimal state needed to
# do an UPDATE, not to produce a fully usable object
self._run_on_load(obj)
return obj
class PendingCollection(object):
"""A writable placeholder for an unloaded collection.
Stores items appended to and removed from a collection that has not yet
been loaded. When the collection is loaded, the changes stored in
PendingCollection are applied to it to produce the final result.
"""
def __init__(self):
self.deleted_items = util.IdentitySet()
self.added_items = util.OrderedIdentitySet()
def append(self, value):
if value in self.deleted_items:
self.deleted_items.remove(value)
self.added_items.add(value)
def remove(self, value):
if value in self.added_items:
self.added_items.remove(value)
self.deleted_items.add(value)

1229
sqlalchemy/orm/strategies.py Normal file

File diff suppressed because it is too large Load Diff

98
sqlalchemy/orm/sync.py Normal file
View File

@@ -0,0 +1,98 @@
# mapper/sync.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""private module containing functions used for copying data
between instances based on join conditions.
"""
from sqlalchemy.orm import exc, util as mapperutil
def populate(source, source_mapper, dest, dest_mapper,
synchronize_pairs, uowcommit, passive_updates):
for l, r in synchronize_pairs:
try:
value = source_mapper._get_state_attr_by_column(source, l)
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, dest_mapper, r)
try:
dest_mapper._set_state_attr_by_column(dest, r, value)
except exc.UnmappedColumnError:
_raise_col_to_prop(True, source_mapper, l, dest_mapper, r)
# techically the "r.primary_key" check isn't
# needed here, but we check for this condition to limit
# how often this logic is invoked for memory/performance
# reasons, since we only need this info for a primary key
# destination.
if l.primary_key and r.primary_key and \
r.references(l) and passive_updates:
uowcommit.attributes[("pk_cascaded", dest, r)] = True
def clear(dest, dest_mapper, synchronize_pairs):
for l, r in synchronize_pairs:
if r.primary_key:
raise AssertionError(
"Dependency rule tried to blank-out primary key "
"column '%s' on instance '%s'" %
(r, mapperutil.state_str(dest))
)
try:
dest_mapper._set_state_attr_by_column(dest, r, None)
except exc.UnmappedColumnError:
_raise_col_to_prop(True, None, l, dest_mapper, r)
def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
for l, r in synchronize_pairs:
try:
oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l)
value = source_mapper._get_state_attr_by_column(source, l)
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r)
dest[r.key] = value
dest[old_prefix + r.key] = oldvalue
def populate_dict(source, source_mapper, dict_, synchronize_pairs):
for l, r in synchronize_pairs:
try:
value = source_mapper._get_state_attr_by_column(source, l)
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r)
dict_[r.key] = value
def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
"""return true if the source object has changes from an old to a
new value on the given synchronize pairs
"""
for l, r in synchronize_pairs:
try:
prop = source_mapper._get_col_to_prop(l)
except exc.UnmappedColumnError:
_raise_col_to_prop(False, source_mapper, l, None, r)
history = uowcommit.get_attribute_history(source, prop.key, passive=True)
if len(history.deleted):
return True
else:
return False
def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column):
if isdest:
raise exc.UnmappedColumnError(
"Can't execute sync rule for destination column '%s'; "
"mapper '%s' does not map this column. Try using an explicit"
" `foreign_keys` collection which does not include this column "
"(or use a viewonly=True relation)." % (dest_column, source_mapper)
)
else:
raise exc.UnmappedColumnError(
"Can't execute sync rule for source column '%s'; mapper '%s' "
"does not map this column. Try using an explicit `foreign_keys`"
" collection which does not include destination column '%s' (or "
"use a viewonly=True relation)." %
(source_column, source_mapper, dest_column)
)

View File

@@ -0,0 +1,781 @@
# orm/unitofwork.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""The internals for the Unit Of Work system.
Includes hooks into the attributes package enabling the routing of
change events to Unit Of Work objects, as well as the flush()
mechanism which creates a dependency structure that executes change
operations.
A Unit of Work is essentially a system of maintaining a graph of
in-memory objects and their modified state. Objects are maintained as
unique against their primary key identity using an *identity map*
pattern. The Unit of Work then maintains lists of objects that are
new, dirty, or deleted and provides the capability to flush all those
changes at once.
"""
from sqlalchemy import util, log, topological
from sqlalchemy.orm import attributes, interfaces
from sqlalchemy.orm import util as mapperutil
from sqlalchemy.orm.mapper import _state_mapper
# Load lazily
object_session = None
_state_session = None
class UOWEventHandler(interfaces.AttributeExtension):
"""An event handler added to all relationship attributes which handles
session cascade operations.
"""
active_history = False
def __init__(self, key):
self.key = key
def append(self, state, item, initiator):
# process "save_update" cascade rules for when an instance is appended to the list of another instance
sess = _state_session(state)
if sess:
prop = _state_mapper(state).get_property(self.key)
if prop.cascade.save_update and item not in sess:
sess.add(item)
return item
def remove(self, state, item, initiator):
sess = _state_session(state)
if sess:
prop = _state_mapper(state).get_property(self.key)
# expunge pending orphans
if prop.cascade.delete_orphan and \
item in sess.new and \
prop.mapper._is_orphan(attributes.instance_state(item)):
sess.expunge(item)
def set(self, state, newvalue, oldvalue, initiator):
# process "save_update" cascade rules for when an instance is attached to another instance
if oldvalue is newvalue:
return newvalue
sess = _state_session(state)
if sess:
prop = _state_mapper(state).get_property(self.key)
if newvalue is not None and prop.cascade.save_update and newvalue not in sess:
sess.add(newvalue)
if prop.cascade.delete_orphan and oldvalue in sess.new and \
prop.mapper._is_orphan(attributes.instance_state(oldvalue)):
sess.expunge(oldvalue)
return newvalue
class UOWTransaction(object):
"""Handles the details of organizing and executing transaction
tasks during a UnitOfWork object's flush() operation.
The central operation is to form a graph of nodes represented by the
``UOWTask`` class, which is then traversed by a ``UOWExecutor`` object
that issues SQL and instance-synchronizing operations via the related
packages.
"""
def __init__(self, session):
self.session = session
self.mapper_flush_opts = session._mapper_flush_opts
# stores tuples of mapper/dependent mapper pairs,
# representing a partial ordering fed into topological sort
self.dependencies = set()
# dictionary of mappers to UOWTasks
self.tasks = {}
# dictionary used by external actors to store arbitrary state
# information.
self.attributes = {}
self.processors = set()
def get_attribute_history(self, state, key, passive=True):
hashkey = ("history", state, key)
# cache the objects, not the states; the strong reference here
# prevents newly loaded objects from being dereferenced during the
# flush process
if hashkey in self.attributes:
(history, cached_passive) = self.attributes[hashkey]
# if the cached lookup was "passive" and now we want non-passive, do a non-passive
# lookup and re-cache
if cached_passive and not passive:
history = attributes.get_state_history(state, key, passive=False)
self.attributes[hashkey] = (history, passive)
else:
history = attributes.get_state_history(state, key, passive=passive)
self.attributes[hashkey] = (history, passive)
if not history or not state.get_impl(key).uses_objects:
return history
else:
return history.as_state()
def register_object(self, state, isdelete=False,
listonly=False, postupdate=False, post_update_cols=None):
# if object is not in the overall session, do nothing
if not self.session._contains_state(state):
return
mapper = _state_mapper(state)
task = self.get_task_by_mapper(mapper)
if postupdate:
task.append_postupdate(state, post_update_cols)
else:
task.append(state, listonly=listonly, isdelete=isdelete)
# ensure the mapper for this object has had its
# DependencyProcessors added.
if mapper not in self.processors:
mapper._register_processors(self)
self.processors.add(mapper)
if mapper.base_mapper not in self.processors:
mapper.base_mapper._register_processors(self)
self.processors.add(mapper.base_mapper)
def set_row_switch(self, state):
"""mark a deleted object as a 'row switch'.
this indicates that an INSERT statement elsewhere corresponds to this DELETE;
the INSERT is converted to an UPDATE and the DELETE does not occur.
"""
mapper = _state_mapper(state)
task = self.get_task_by_mapper(mapper)
taskelement = task._objects[state]
taskelement.isdelete = "rowswitch"
def is_deleted(self, state):
"""return true if the given state is marked as deleted within this UOWTransaction."""
mapper = _state_mapper(state)
task = self.get_task_by_mapper(mapper)
return task.is_deleted(state)
def get_task_by_mapper(self, mapper, dontcreate=False):
"""return UOWTask element corresponding to the given mapper.
Will create a new UOWTask, including a UOWTask corresponding to the
"base" inherited mapper, if needed, unless the dontcreate flag is True.
"""
try:
return self.tasks[mapper]
except KeyError:
if dontcreate:
return None
base_mapper = mapper.base_mapper
if base_mapper in self.tasks:
base_task = self.tasks[base_mapper]
else:
self.tasks[base_mapper] = base_task = UOWTask(self, base_mapper)
base_mapper._register_dependencies(self)
if mapper not in self.tasks:
self.tasks[mapper] = task = UOWTask(self, mapper, base_task=base_task)
mapper._register_dependencies(self)
else:
task = self.tasks[mapper]
return task
def register_dependency(self, mapper, dependency):
"""register a dependency between two mappers.
Called by ``mapper.PropertyLoader`` to register the objects
handled by one mapper being dependent on the objects handled
by another.
"""
# correct for primary mapper
# also convert to the "base mapper", the parentmost task at the top of an inheritance chain
# dependency sorting is done via non-inheriting mappers only, dependencies between mappers
# in the same inheritance chain is done at the per-object level
mapper = mapper.primary_mapper().base_mapper
dependency = dependency.primary_mapper().base_mapper
self.dependencies.add((mapper, dependency))
def register_processor(self, mapper, processor, mapperfrom):
"""register a dependency processor, corresponding to
operations which occur between two mappers.
"""
# correct for primary mapper
mapper = mapper.primary_mapper()
mapperfrom = mapperfrom.primary_mapper()
task = self.get_task_by_mapper(mapper)
targettask = self.get_task_by_mapper(mapperfrom)
up = UOWDependencyProcessor(processor, targettask)
task.dependencies.add(up)
def execute(self):
"""Execute this UOWTransaction.
This will organize all collected UOWTasks into a dependency-sorted
list which is then traversed using the traversal scheme
encoded in the UOWExecutor class. Operations to mappers and dependency
processors are fired off in order to issue SQL to the database and
synchronize instance attributes with database values and related
foreign key values."""
# pre-execute dependency processors. this process may
# result in new tasks, objects and/or dependency processors being added,
# particularly with 'delete-orphan' cascade rules.
# keep running through the full list of tasks until all
# objects have been processed.
while True:
ret = False
for task in self.tasks.values():
for up in list(task.dependencies):
if up.preexecute(self):
ret = True
if not ret:
break
tasks = self._sort_dependencies()
if self._should_log_info():
self.logger.info("Task dump:\n%s", self._dump(tasks))
UOWExecutor().execute(self, tasks)
self.logger.info("Execute Complete")
def _dump(self, tasks):
from uowdumper import UOWDumper
return UOWDumper.dump(tasks)
@property
def elements(self):
"""Iterate UOWTaskElements."""
for task in self.tasks.itervalues():
for elem in task.elements:
yield elem
def finalize_flush_changes(self):
"""mark processed objects as clean / deleted after a successful flush().
this method is called within the flush() method after the
execute() method has succeeded and the transaction has been committed.
"""
for elem in self.elements:
if elem.isdelete:
self.session._remove_newly_deleted(elem.state)
elif not elem.listonly:
self.session._register_newly_persistent(elem.state)
def _sort_dependencies(self):
nodes = topological.sort_with_cycles(self.dependencies,
[t.mapper for t in self.tasks.itervalues() if t.base_task is t]
)
ret = []
for item, cycles in nodes:
task = self.get_task_by_mapper(item)
if cycles:
for t in task._sort_circular_dependencies(
self,
[self.get_task_by_mapper(i) for i in cycles]
):
ret.append(t)
else:
ret.append(task)
return ret
log.class_logger(UOWTransaction)
class UOWTask(object):
"""A collection of mapped states corresponding to a particular mapper."""
def __init__(self, uowtransaction, mapper, base_task=None):
self.uowtransaction = uowtransaction
# base_task is the UOWTask which represents the "base mapper"
# in our mapper's inheritance chain. if the mapper does not
# inherit from any other mapper, the base_task is self.
# the _inheriting_tasks dictionary is a dictionary present only
# on the "base_task"-holding UOWTask, which maps all mappers within
# an inheritance hierarchy to their corresponding UOWTask instances.
if base_task is None:
self.base_task = self
self._inheriting_tasks = {mapper:self}
else:
self.base_task = base_task
base_task._inheriting_tasks[mapper] = self
# the Mapper which this UOWTask corresponds to
self.mapper = mapper
# mapping of InstanceState -> UOWTaskElement
self._objects = {}
self.dependent_tasks = []
self.dependencies = set()
self.cyclical_dependencies = set()
@util.memoized_property
def inheriting_mappers(self):
return list(self.mapper.polymorphic_iterator())
@property
def polymorphic_tasks(self):
"""Return an iterator of UOWTask objects corresponding to the
inheritance sequence of this UOWTask's mapper.
e.g. if mapper B and mapper C inherit from mapper A, and
mapper D inherits from B:
mapperA -> mapperB -> mapperD
-> mapperC
the inheritance sequence starting at mapper A is a depth-first
traversal:
[mapperA, mapperB, mapperD, mapperC]
this method will therefore return
[UOWTask(mapperA), UOWTask(mapperB), UOWTask(mapperD),
UOWTask(mapperC)]
The concept of "polymporphic iteration" is adapted into
several property-based iterators which return object
instances, UOWTaskElements and UOWDependencyProcessors in an
order corresponding to this sequence of parent UOWTasks. This
is used to issue operations related to inheritance-chains of
mappers in the proper order based on dependencies between
those mappers.
"""
for mapper in self.inheriting_mappers:
t = self.base_task._inheriting_tasks.get(mapper, None)
if t is not None:
yield t
def is_empty(self):
"""return True if this UOWTask is 'empty', meaning it has no child items.
used only for debugging output.
"""
return not self._objects and not self.dependencies
def append(self, state, listonly=False, isdelete=False):
if state not in self._objects:
self._objects[state] = rec = UOWTaskElement(state)
else:
rec = self._objects[state]
rec.update(listonly, isdelete)
def append_postupdate(self, state, post_update_cols):
"""issue a 'post update' UPDATE statement via this object's mapper immediately.
this operation is used only with relationships that specify the `post_update=True`
flag.
"""
# postupdates are UPDATED immeditely (for now)
# convert post_update_cols list to a Set so that __hash__() is used to compare columns
# instead of __eq__()
self.mapper._save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=set(post_update_cols))
def __contains__(self, state):
"""return True if the given object is contained within this UOWTask or inheriting tasks."""
for task in self.polymorphic_tasks:
if state in task._objects:
return True
else:
return False
def is_deleted(self, state):
"""return True if the given object is marked as to be deleted within this UOWTask."""
try:
return self._objects[state].isdelete
except KeyError:
return False
def _polymorphic_collection(fn):
"""return a property that will adapt the collection returned by the
given callable into a polymorphic traversal."""
@property
def collection(self):
for task in self.polymorphic_tasks:
for rec in fn(task):
yield rec
return collection
def _polymorphic_collection_filtered(fn):
def collection(self, mappers):
for task in self.polymorphic_tasks:
if task.mapper in mappers:
for rec in fn(task):
yield rec
return collection
@property
def elements(self):
return self._objects.values()
@_polymorphic_collection
def polymorphic_elements(self):
return self.elements
@_polymorphic_collection_filtered
def filter_polymorphic_elements(self):
return self.elements
@property
def polymorphic_tosave_elements(self):
return [rec for rec in self.polymorphic_elements if not rec.isdelete]
@property
def polymorphic_todelete_elements(self):
return [rec for rec in self.polymorphic_elements if rec.isdelete]
@property
def polymorphic_tosave_objects(self):
return [
rec.state for rec in self.polymorphic_elements
if rec.state is not None and not rec.listonly and rec.isdelete is False
]
@property
def polymorphic_todelete_objects(self):
return [
rec.state for rec in self.polymorphic_elements
if rec.state is not None and not rec.listonly and rec.isdelete is True
]
@_polymorphic_collection
def polymorphic_dependencies(self):
return self.dependencies
@_polymorphic_collection
def polymorphic_cyclical_dependencies(self):
return self.cyclical_dependencies
def _sort_circular_dependencies(self, trans, cycles):
"""Topologically sort individual entities with row-level dependencies.
Builds a modified UOWTask structure, and is invoked when the
per-mapper topological structure is found to have cycles.
"""
dependencies = {}
def set_processor_for_state(state, depprocessor, target_state, isdelete):
if state not in dependencies:
dependencies[state] = {}
tasks = dependencies[state]
if depprocessor not in tasks:
tasks[depprocessor] = UOWDependencyProcessor(
depprocessor.processor,
UOWTask(self.uowtransaction, depprocessor.targettask.mapper)
)
tasks[depprocessor].targettask.append(target_state, isdelete=isdelete)
cycles = set(cycles)
def dependency_in_cycles(dep):
proctask = trans.get_task_by_mapper(dep.processor.mapper.base_mapper, True)
targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper, True)
return targettask in cycles and (proctask is not None and proctask in cycles)
deps_by_targettask = {}
extradeplist = []
for task in cycles:
for dep in task.polymorphic_dependencies:
if not dependency_in_cycles(dep):
extradeplist.append(dep)
for t in dep.targettask.polymorphic_tasks:
l = deps_by_targettask.setdefault(t, [])
l.append(dep)
object_to_original_task = {}
tuples = []
for task in cycles:
for subtask in task.polymorphic_tasks:
for taskelement in subtask.elements:
state = taskelement.state
object_to_original_task[state] = subtask
if subtask not in deps_by_targettask:
continue
for dep in deps_by_targettask[subtask]:
if not dep.processor.has_dependencies or not dependency_in_cycles(dep):
continue
(processor, targettask) = (dep.processor, dep.targettask)
isdelete = taskelement.isdelete
# list of dependent objects from this object
(added, unchanged, deleted) = dep.get_object_dependencies(state, trans, passive=True)
if not added and not unchanged and not deleted:
continue
# the task corresponding to saving/deleting of those dependent objects
childtask = trans.get_task_by_mapper(processor.mapper)
childlist = added + unchanged + deleted
for o in childlist:
if o is None:
continue
if o not in childtask:
childtask.append(o, listonly=True)
object_to_original_task[o] = childtask
whosdep = dep.whose_dependent_on_who(state, o)
if whosdep is not None:
tuples.append(whosdep)
if whosdep[0] is state:
set_processor_for_state(whosdep[0], dep, whosdep[0], isdelete=isdelete)
else:
set_processor_for_state(whosdep[0], dep, whosdep[1], isdelete=isdelete)
else:
# TODO: no test coverage here
set_processor_for_state(state, dep, state, isdelete=isdelete)
t = UOWTask(self.uowtransaction, self.mapper)
t.dependencies.update(extradeplist)
used_tasks = set()
# rationale for "tree" sort as opposed to a straight
# dependency - keep non-dependent objects
# grouped together, so that insert ordering as determined
# by session.add() is maintained.
# An alternative might be to represent the "insert order"
# as part of the topological sort itself, which would
# eliminate the need for this step (but may make the original
# topological sort more expensive)
head = topological.sort_as_tree(tuples, object_to_original_task.iterkeys())
if head is not None:
original_to_tasks = {}
stack = [(head, t)]
while stack:
((state, cycles, children), parenttask) = stack.pop()
originating_task = object_to_original_task[state]
used_tasks.add(originating_task)
if (parenttask, originating_task) not in original_to_tasks:
task = UOWTask(self.uowtransaction, originating_task.mapper)
original_to_tasks[(parenttask, originating_task)] = task
parenttask.dependent_tasks.append(task)
else:
task = original_to_tasks[(parenttask, originating_task)]
task.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete)
if state in dependencies:
task.cyclical_dependencies.update(dependencies[state].itervalues())
stack += [(n, task) for n in children]
ret = [t]
# add tasks that were in the cycle, but didnt get assembled
# into the cyclical tree, to the start of the list
for t2 in cycles:
if t2 not in used_tasks and t2 is not self:
localtask = UOWTask(self.uowtransaction, t2.mapper)
for state in t2.elements:
localtask.append(state, t2.listonly, isdelete=t2._objects[state].isdelete)
for dep in t2.dependencies:
localtask.dependencies.add(dep)
ret.insert(0, localtask)
return ret
def __repr__(self):
return ("UOWTask(%s) Mapper: '%r'" % (hex(id(self)), self.mapper))
class UOWTaskElement(object):
"""Corresponds to a single InstanceState to be saved, deleted,
or otherwise marked as having dependencies. A collection of
UOWTaskElements are held by a UOWTask.
"""
def __init__(self, state):
self.state = state
self.listonly = True
self.isdelete = False
self.preprocessed = set()
def update(self, listonly, isdelete):
if not listonly and self.listonly:
self.listonly = False
self.preprocessed.clear()
if isdelete and not self.isdelete:
self.isdelete = True
self.preprocessed.clear()
def __repr__(self):
return "UOWTaskElement/%d: %s/%d %s" % (
id(self),
self.state.class_.__name__,
id(self.state.obj()),
(self.listonly and 'listonly' or (self.isdelete and 'delete' or 'save'))
)
class UOWDependencyProcessor(object):
"""In between the saving and deleting of objects, process
dependent data, such as filling in a foreign key on a child item
from a new primary key, or deleting association rows before a
delete. This object acts as a proxy to a DependencyProcessor.
"""
def __init__(self, processor, targettask):
self.processor = processor
self.targettask = targettask
prop = processor.prop
# define a set of mappers which
# will filter the lists of entities
# this UOWDP processes. this allows
# MapperProperties to be overridden
# at least for concrete mappers.
self._mappers = set([
m
for m in self.processor.parent.polymorphic_iterator()
if m._props[prop.key] is prop
]).union(self.processor.mapper.polymorphic_iterator())
def __repr__(self):
return "UOWDependencyProcessor(%s, %s)" % (str(self.processor), str(self.targettask))
def __eq__(self, other):
return other.processor is self.processor and other.targettask is self.targettask
def __hash__(self):
return hash((self.processor, self.targettask))
def preexecute(self, trans):
"""preprocess all objects contained within this ``UOWDependencyProcessor``s target task.
This may locate additional objects which should be part of the
transaction, such as those affected deletes, orphans to be
deleted, etc.
Once an object is preprocessed, its ``UOWTaskElement`` is marked as processed. If subsequent
changes occur to the ``UOWTaskElement``, its processed flag is reset, and will require processing
again.
Return True if any objects were preprocessed, or False if no
objects were preprocessed. If True is returned, the parent ``UOWTransaction`` will
ultimately call ``preexecute()`` again on all processors until no new objects are processed.
"""
def getobj(elem):
elem.preprocessed.add(self)
return elem.state
ret = False
elements = [getobj(elem) for elem in
self.targettask.filter_polymorphic_elements(self._mappers)
if self not in elem.preprocessed and not elem.isdelete]
if elements:
ret = True
self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False)
elements = [getobj(elem) for elem in
self.targettask.filter_polymorphic_elements(self._mappers)
if self not in elem.preprocessed and elem.isdelete]
if elements:
ret = True
self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True)
return ret
def execute(self, trans, delete):
"""process all objects contained within this ``UOWDependencyProcessor``s target task."""
elements = [e for e in
self.targettask.filter_polymorphic_elements(self._mappers)
if bool(e.isdelete)==delete]
self.processor.process_dependencies(
self.targettask,
[elem.state for elem in elements],
trans,
delete=delete)
def get_object_dependencies(self, state, trans, passive):
return trans.get_attribute_history(state, self.processor.key, passive=passive)
def whose_dependent_on_who(self, state1, state2):
"""establish which object is operationally dependent amongst a parent/child
using the semantics stated by the dependency processor.
This method is used to establish a partial ordering (set of dependency tuples)
when toplogically sorting on a per-instance basis.
"""
return self.processor.whose_dependent_on_who(state1, state2)
class UOWExecutor(object):
"""Encapsulates the execution traversal of a UOWTransaction structure."""
def execute(self, trans, tasks, isdelete=None):
if isdelete is not True:
for task in tasks:
self.execute_save_steps(trans, task)
if isdelete is not False:
for task in reversed(tasks):
self.execute_delete_steps(trans, task)
def save_objects(self, trans, task):
task.mapper._save_obj(task.polymorphic_tosave_objects, trans)
def delete_objects(self, trans, task):
task.mapper._delete_obj(task.polymorphic_todelete_objects, trans)
def execute_dependency(self, trans, dep, isdelete):
dep.execute(trans, isdelete)
def execute_save_steps(self, trans, task):
self.save_objects(trans, task)
for dep in task.polymorphic_cyclical_dependencies:
self.execute_dependency(trans, dep, False)
for dep in task.polymorphic_cyclical_dependencies:
self.execute_dependency(trans, dep, True)
self.execute_cyclical_dependencies(trans, task, False)
self.execute_dependencies(trans, task)
def execute_delete_steps(self, trans, task):
self.execute_cyclical_dependencies(trans, task, True)
self.delete_objects(trans, task)
def execute_dependencies(self, trans, task):
polymorphic_dependencies = list(task.polymorphic_dependencies)
for dep in polymorphic_dependencies:
self.execute_dependency(trans, dep, False)
for dep in reversed(polymorphic_dependencies):
self.execute_dependency(trans, dep, True)
def execute_cyclical_dependencies(self, trans, task, isdelete):
for t in task.dependent_tasks:
self.execute(trans, [t], isdelete)

101
sqlalchemy/orm/uowdumper.py Normal file
View File

@@ -0,0 +1,101 @@
# orm/uowdumper.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Dumps out a string representation of a UOWTask structure"""
from sqlalchemy.orm import unitofwork
from sqlalchemy.orm import util as mapperutil
import StringIO
class UOWDumper(unitofwork.UOWExecutor):
def __init__(self, tasks, buf):
self.indent = 0
self.tasks = tasks
self.buf = buf
self.execute(None, tasks)
@classmethod
def dump(cls, tasks):
buf = StringIO.StringIO()
UOWDumper(tasks, buf)
return buf.getvalue()
def execute(self, trans, tasks, isdelete=None):
if isdelete is not True:
for task in tasks:
self._execute(trans, task, False)
if isdelete is not False:
for task in reversed(tasks):
self._execute(trans, task, True)
def _execute(self, trans, task, isdelete):
try:
i = self._indent()
if i:
i = i[:-1] + "+-"
self.buf.write(i + " " + self._repr_task(task))
self.buf.write(" (" + (isdelete and "delete " or "save/update ") + "phase) \n")
self.indent += 1
super(UOWDumper, self).execute(trans, [task], isdelete)
finally:
self.indent -= 1
def save_objects(self, trans, task):
for rec in sorted(task.polymorphic_tosave_elements, key=lambda a: a.state.sort_key):
if rec.listonly:
continue
self.buf.write(self._indent()[:-1] + "+-" + self._repr_task_element(rec) + "\n")
def delete_objects(self, trans, task):
for rec in task.polymorphic_todelete_elements:
if rec.listonly:
continue
self.buf.write(self._indent() + "- " + self._repr_task_element(rec) + "\n")
def execute_dependency(self, transaction, dep, isdelete):
self._dump_processor(dep, isdelete)
def _dump_processor(self, proc, deletes):
if deletes:
val = proc.targettask.polymorphic_todelete_elements
else:
val = proc.targettask.polymorphic_tosave_elements
for v in val:
self.buf.write(self._indent() + " +- " + self._repr_task_element(v, proc.processor.key, process=True) + "\n")
def _repr_task_element(self, te, attribute=None, process=False):
if getattr(te, 'state', None) is None:
objid = "(placeholder)"
else:
if attribute is not None:
objid = "%s.%s" % (mapperutil.state_str(te.state), attribute)
else:
objid = mapperutil.state_str(te.state)
if process:
return "Process %s" % (objid)
else:
return "%s %s" % ((te.isdelete and "Delete" or "Save"), objid)
def _repr_task(self, task):
if task.mapper is not None:
if task.mapper.__class__.__name__ == 'Mapper':
name = task.mapper.class_.__name__ + "/" + task.mapper.local_table.description
else:
name = repr(task.mapper)
else:
name = '(none)'
return ("UOWTask(%s, %s)" % (hex(id(task)), name))
def _repr_task_class(self, task):
if task.mapper is not None and task.mapper.__class__.__name__ == 'Mapper':
return task.mapper.class_.__name__
else:
return '(none)'
def _indent(self):
return " |" * self.indent

668
sqlalchemy/orm/util.py Normal file
View File

@@ -0,0 +1,668 @@
# mapper/util.py
# Copyright (C) 2005, 2006, 2007, 2008, 2009, 2010 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
import sqlalchemy.exceptions as sa_exc
from sqlalchemy import sql, util
from sqlalchemy.sql import expression, util as sql_util, operators
from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, \
MapperProperty, AttributeExtension
from sqlalchemy.orm import attributes, exc
mapperlib = None
all_cascades = frozenset(("delete", "delete-orphan", "all", "merge",
"expunge", "save-update", "refresh-expire",
"none"))
_INSTRUMENTOR = ('mapper', 'instrumentor')
class CascadeOptions(object):
"""Keeps track of the options sent to relationship().cascade"""
def __init__(self, arg=""):
if not arg:
values = set()
else:
values = set(c.strip() for c in arg.split(','))
self.delete_orphan = "delete-orphan" in values
self.delete = "delete" in values or "all" in values
self.save_update = "save-update" in values or "all" in values
self.merge = "merge" in values or "all" in values
self.expunge = "expunge" in values or "all" in values
self.refresh_expire = "refresh-expire" in values or "all" in values
if self.delete_orphan and not self.delete:
util.warn("The 'delete-orphan' cascade option requires "
"'delete'. This will raise an error in 0.6.")
for x in values:
if x not in all_cascades:
raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x)
def __contains__(self, item):
return getattr(self, item.replace("-", "_"), False)
def __repr__(self):
return "CascadeOptions(%s)" % repr(",".join(
[x for x in ['delete', 'save_update', 'merge', 'expunge',
'delete_orphan', 'refresh-expire']
if getattr(self, x, False) is True]))
class Validator(AttributeExtension):
"""Runs a validation method on an attribute value to be set or appended.
The Validator class is used by the :func:`~sqlalchemy.orm.validates`
decorator, and direct access is usually not needed.
"""
def __init__(self, key, validator):
"""Construct a new Validator.
key - name of the attribute to be validated;
will be passed as the second argument to
the validation method (the first is the object instance itself).
validator - an function or instance method which accepts
three arguments; an instance (usually just 'self' for a method),
the key name of the attribute, and the value. The function should
return the same value given, unless it wishes to modify it.
"""
self.key = key
self.validator = validator
def append(self, state, value, initiator):
return self.validator(state.obj(), self.key, value)
def set(self, state, value, oldvalue, initiator):
return self.validator(state.obj(), self.key, value)
def polymorphic_union(table_map, typecolname, aliasname='p_union'):
"""Create a ``UNION`` statement used by a polymorphic mapper.
See :ref:`concrete_inheritance` for an example of how
this is used.
"""
colnames = set()
colnamemaps = {}
types = {}
for key in table_map.keys():
table = table_map[key]
# mysql doesnt like selecting from a select; make it an alias of the select
if isinstance(table, sql.Select):
table = table.alias()
table_map[key] = table
m = {}
for c in table.c:
colnames.add(c.key)
m[c.key] = c
types[c.key] = c.type
colnamemaps[table] = m
def col(name, table):
try:
return colnamemaps[table][name]
except KeyError:
return sql.cast(sql.null(), types[name]).label(name)
result = []
for type, table in table_map.iteritems():
if typecolname is not None:
result.append(sql.select([col(name, table) for name in colnames] +
[sql.literal_column("'%s'" % type).label(typecolname)],
from_obj=[table]))
else:
result.append(sql.select([col(name, table) for name in colnames],
from_obj=[table]))
return sql.union_all(*result).alias(aliasname)
def identity_key(*args, **kwargs):
"""Get an identity key.
Valid call signatures:
* ``identity_key(class, ident)``
class
mapped class (must be a positional argument)
ident
primary key, if the key is composite this is a tuple
* ``identity_key(instance=instance)``
instance
object instance (must be given as a keyword arg)
* ``identity_key(class, row=row)``
class
mapped class (must be a positional argument)
row
result proxy row (must be given as a keyword arg)
"""
if args:
if len(args) == 1:
class_ = args[0]
try:
row = kwargs.pop("row")
except KeyError:
ident = kwargs.pop("ident")
elif len(args) == 2:
class_, ident = args
elif len(args) == 3:
class_, ident = args
else:
raise sa_exc.ArgumentError("expected up to three "
"positional arguments, got %s" % len(args))
if kwargs:
raise sa_exc.ArgumentError("unknown keyword arguments: %s"
% ", ".join(kwargs.keys()))
mapper = class_mapper(class_)
if "ident" in locals():
return mapper.identity_key_from_primary_key(ident)
return mapper.identity_key_from_row(row)
instance = kwargs.pop("instance")
if kwargs:
raise sa_exc.ArgumentError("unknown keyword arguments: %s"
% ", ".join(kwargs.keys()))
mapper = object_mapper(instance)
return mapper.identity_key_from_instance(instance)
class ExtensionCarrier(dict):
"""Fronts an ordered collection of MapperExtension objects.
Bundles multiple MapperExtensions into a unified callable unit,
encapsulating ordering, looping and EXT_CONTINUE logic. The
ExtensionCarrier implements the MapperExtension interface, e.g.::
carrier.after_insert(...args...)
The dictionary interface provides containment for implemented
method names mapped to a callable which executes that method
for participating extensions.
"""
interface = set(method for method in dir(MapperExtension)
if not method.startswith('_'))
def __init__(self, extensions=None):
self._extensions = []
for ext in extensions or ():
self.append(ext)
def copy(self):
return ExtensionCarrier(self._extensions)
def push(self, extension):
"""Insert a MapperExtension at the beginning of the collection."""
self._register(extension)
self._extensions.insert(0, extension)
def append(self, extension):
"""Append a MapperExtension at the end of the collection."""
self._register(extension)
self._extensions.append(extension)
def __iter__(self):
"""Iterate over MapperExtensions in the collection."""
return iter(self._extensions)
def _register(self, extension):
"""Register callable fronts for overridden interface methods."""
for method in self.interface.difference(self):
impl = getattr(extension, method, None)
if impl and impl is not getattr(MapperExtension, method):
self[method] = self._create_do(method)
def _create_do(self, method):
"""Return a closure that loops over impls of the named method."""
def _do(*args, **kwargs):
for ext in self._extensions:
ret = getattr(ext, method)(*args, **kwargs)
if ret is not EXT_CONTINUE:
return ret
else:
return EXT_CONTINUE
_do.__name__ = method
return _do
@staticmethod
def _pass(*args, **kwargs):
return EXT_CONTINUE
def __getattr__(self, key):
"""Delegate MapperExtension methods to bundled fronts."""
if key not in self.interface:
raise AttributeError(key)
return self.get(key, self._pass)
class ORMAdapter(sql_util.ColumnAdapter):
"""Extends ColumnAdapter to accept ORM entities.
The selectable is extracted from the given entity,
and the AliasedClass if any is referenced.
"""
def __init__(self, entity, equivalents=None, chain_to=None, adapt_required=False):
self.mapper, selectable, is_aliased_class = _entity_info(entity)
if is_aliased_class:
self.aliased_class = entity
else:
self.aliased_class = None
sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to, adapt_required=adapt_required)
def replace(self, elem):
entity = elem._annotations.get('parentmapper', None)
if not entity or entity.isa(self.mapper):
return sql_util.ColumnAdapter.replace(self, elem)
else:
return None
class AliasedClass(object):
"""Represents an "aliased" form of a mapped class for usage with Query.
The ORM equivalent of a :func:`sqlalchemy.sql.expression.alias`
construct, this object mimics the mapped class using a
__getattr__ scheme and maintains a reference to a
real :class:`~sqlalchemy.sql.expression.Alias` object.
Usage is via the :class:`~sqlalchemy.orm.aliased()` synonym::
# find all pairs of users with the same name
user_alias = aliased(User)
session.query(User, user_alias).\\
join((user_alias, User.id > user_alias.id)).\\
filter(User.name==user_alias.name)
"""
def __init__(self, cls, alias=None, name=None):
self.__mapper = _class_to_mapper(cls)
self.__target = self.__mapper.class_
if alias is None:
alias = self.__mapper._with_polymorphic_selectable.alias()
self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
self.__alias = alias
# used to assign a name to the RowTuple object
# returned by Query.
self._sa_label_name = name
self.__name__ = 'AliasedClass_' + str(self.__target)
def __getstate__(self):
return {'mapper':self.__mapper, 'alias':self.__alias, 'name':self._sa_label_name}
def __setstate__(self, state):
self.__mapper = state['mapper']
self.__target = self.__mapper.class_
alias = state['alias']
self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
self.__alias = alias
name = state['name']
self._sa_label_name = name
self.__name__ = 'AliasedClass_' + str(self.__target)
def __adapt_element(self, elem):
return self.__adapter.traverse(elem)._annotate({'parententity': self, 'parentmapper':self.__mapper})
def __adapt_prop(self, prop):
existing = getattr(self.__target, prop.key)
comparator = existing.comparator.adapted(self.__adapt_element)
queryattr = attributes.QueryableAttribute(prop.key,
impl=existing.impl, parententity=self, comparator=comparator)
setattr(self, prop.key, queryattr)
return queryattr
def __getattr__(self, key):
prop = self.__mapper._get_property(key, raiseerr=False)
if prop:
return self.__adapt_prop(prop)
for base in self.__target.__mro__:
try:
attr = object.__getattribute__(base, key)
except AttributeError:
continue
else:
break
else:
raise AttributeError(key)
if hasattr(attr, 'func_code'):
is_method = getattr(self.__target, key, None)
if is_method and is_method.im_self is not None:
return util.types.MethodType(attr.im_func, self, self)
else:
return None
elif hasattr(attr, '__get__'):
return attr.__get__(None, self)
else:
return attr
def __repr__(self):
return '<AliasedClass at 0x%x; %s>' % (
id(self), self.__target.__name__)
def _orm_annotate(element, exclude=None):
"""Deep copy the given ClauseElement, annotating each element with the "_orm_adapt" flag.
Elements within the exclude collection will be cloned but not annotated.
"""
return sql_util._deep_annotate(element, {'_orm_adapt':True}, exclude)
_orm_deannotate = sql_util._deep_deannotate
class _ORMJoin(expression.Join):
"""Extend Join to support ORM constructs as input."""
__visit_name__ = expression.Join.__visit_name__
def __init__(self, left, right, onclause=None, isouter=False, join_to_left=True):
adapt_from = None
if hasattr(left, '_orm_mappers'):
left_mapper = left._orm_mappers[1]
if join_to_left:
adapt_from = left.right
else:
left_mapper, left, left_is_aliased = _entity_info(left)
if join_to_left and (left_is_aliased or not left_mapper):
adapt_from = left
right_mapper, right, right_is_aliased = _entity_info(right)
if right_is_aliased:
adapt_to = right
else:
adapt_to = None
if left_mapper or right_mapper:
self._orm_mappers = (left_mapper, right_mapper)
if isinstance(onclause, basestring):
prop = left_mapper.get_property(onclause)
elif isinstance(onclause, attributes.QueryableAttribute):
if adapt_from is None:
adapt_from = onclause.__clause_element__()
prop = onclause.property
elif isinstance(onclause, MapperProperty):
prop = onclause
else:
prop = None
if prop:
pj, sj, source, dest, secondary, target_adapter = prop._create_joins(
source_selectable=adapt_from,
dest_selectable=adapt_to,
source_polymorphic=True,
dest_polymorphic=True,
of_type=right_mapper)
if sj is not None:
left = sql.join(left, secondary, pj, isouter)
onclause = sj
else:
onclause = pj
self._target_adapter = target_adapter
expression.Join.__init__(self, left, right, onclause, isouter)
def join(self, right, onclause=None, isouter=False, join_to_left=True):
return _ORMJoin(self, right, onclause, isouter, join_to_left)
def outerjoin(self, right, onclause=None, join_to_left=True):
return _ORMJoin(self, right, onclause, True, join_to_left)
def join(left, right, onclause=None, isouter=False, join_to_left=True):
"""Produce an inner join between left and right clauses.
In addition to the interface provided by
:func:`~sqlalchemy.sql.expression.join()`, left and right may be mapped
classes or AliasedClass instances. The onclause may be a
string name of a relationship(), or a class-bound descriptor
representing a relationship.
join_to_left indicates to attempt aliasing the ON clause,
in whatever form it is passed, to the selectable
passed as the left side. If False, the onclause
is used as is.
"""
return _ORMJoin(left, right, onclause, isouter, join_to_left)
def outerjoin(left, right, onclause=None, join_to_left=True):
"""Produce a left outer join between left and right clauses.
In addition to the interface provided by
:func:`~sqlalchemy.sql.expression.outerjoin()`, left and right may be mapped
classes or AliasedClass instances. The onclause may be a
string name of a relationship(), or a class-bound descriptor
representing a relationship.
"""
return _ORMJoin(left, right, onclause, True, join_to_left)
def with_parent(instance, prop):
"""Return criterion which selects instances with a given parent.
instance
a parent instance, which should be persistent or detached.
property
a class-attached descriptor, MapperProperty or string property name
attached to the parent instance.
\**kwargs
all extra keyword arguments are propagated to the constructor of
Query.
"""
if isinstance(prop, basestring):
mapper = object_mapper(instance)
prop = mapper.get_property(prop, resolve_synonyms=True)
elif isinstance(prop, attributes.QueryableAttribute):
prop = prop.property
return prop.compare(operators.eq, instance, value_is_parent=True)
def _entity_info(entity, compile=True):
"""Return mapping information given a class, mapper, or AliasedClass.
Returns 3-tuple of: mapper, mapped selectable, boolean indicating if this
is an aliased() construct.
If the given entity is not a mapper, mapped class, or aliased construct,
returns None, the entity, False. This is typically used to allow
unmapped selectables through.
"""
if isinstance(entity, AliasedClass):
return entity._AliasedClass__mapper, entity._AliasedClass__alias, True
global mapperlib
if mapperlib is None:
from sqlalchemy.orm import mapperlib
if isinstance(entity, mapperlib.Mapper):
mapper = entity
elif isinstance(entity, type):
class_manager = attributes.manager_of_class(entity)
if class_manager is None:
return None, entity, False
mapper = class_manager.mapper
else:
return None, entity, False
if compile:
mapper = mapper.compile()
return mapper, mapper._with_polymorphic_selectable, False
def _entity_descriptor(entity, key):
"""Return attribute/property information given an entity and string name.
Returns a 2-tuple representing InstrumentedAttribute/MapperProperty.
"""
if isinstance(entity, AliasedClass):
try:
desc = getattr(entity, key)
return desc, desc.property
except AttributeError:
raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key))
elif isinstance(entity, type):
try:
desc = attributes.manager_of_class(entity)[key]
return desc, desc.property
except KeyError:
raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key))
else:
try:
desc = entity.class_manager[key]
return desc, desc.property
except KeyError:
raise sa_exc.InvalidRequestError("Entity '%s' has no property '%s'" % (entity, key))
def _orm_columns(entity):
mapper, selectable, is_aliased_class = _entity_info(entity)
if isinstance(selectable, expression.Selectable):
return [c for c in selectable.c]
else:
return [selectable]
def _orm_selectable(entity):
mapper, selectable, is_aliased_class = _entity_info(entity)
return selectable
def _is_aliased_class(entity):
return isinstance(entity, AliasedClass)
def _state_mapper(state):
return state.manager.mapper
def object_mapper(instance):
"""Given an object, return the primary Mapper associated with the object instance.
Raises UnmappedInstanceError if no mapping is configured.
"""
try:
state = attributes.instance_state(instance)
if not state.manager.mapper:
raise exc.UnmappedInstanceError(instance)
return state.manager.mapper
except exc.NO_STATE:
raise exc.UnmappedInstanceError(instance)
def class_mapper(class_, compile=True):
"""Given a class, return the primary Mapper associated with the key.
Raises UnmappedClassError if no mapping is configured.
"""
try:
class_manager = attributes.manager_of_class(class_)
mapper = class_manager.mapper
# HACK until [ticket:1142] is complete
if mapper is None:
raise AttributeError
except exc.NO_STATE:
raise exc.UnmappedClassError(class_)
if compile:
mapper = mapper.compile()
return mapper
def _class_to_mapper(class_or_mapper, compile=True):
if _is_aliased_class(class_or_mapper):
return class_or_mapper._AliasedClass__mapper
elif isinstance(class_or_mapper, type):
return class_mapper(class_or_mapper, compile=compile)
elif hasattr(class_or_mapper, 'compile'):
if compile:
return class_or_mapper.compile()
else:
return class_or_mapper
else:
raise exc.UnmappedClassError(class_or_mapper)
def has_identity(object):
state = attributes.instance_state(object)
return _state_has_identity(state)
def _state_has_identity(state):
return bool(state.key)
def _is_mapped_class(cls):
global mapperlib
if mapperlib is None:
from sqlalchemy.orm import mapperlib
if isinstance(cls, (AliasedClass, mapperlib.Mapper)):
return True
if isinstance(cls, expression.ClauseElement):
return False
if isinstance(cls, type):
manager = attributes.manager_of_class(cls)
return manager and _INSTRUMENTOR in manager.info
return False
def instance_str(instance):
"""Return a string describing an instance."""
return state_str(attributes.instance_state(instance))
def state_str(state):
"""Return a string describing an instance via its InstanceState."""
if state is None:
return "None"
else:
return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj()))
def attribute_str(instance, attribute):
return instance_str(instance) + "." + attribute
def state_attribute_str(state, attribute):
return state_str(state) + "." + attribute
def identity_equal(a, b):
if a is b:
return True
if a is None or b is None:
return False
try:
state_a = attributes.instance_state(a)
state_b = attributes.instance_state(b)
except exc.NO_STATE:
return False
if state_a.key is None or state_b.key is None:
return False
return state_a.key == state_b.key
# TODO: Avoid circular import.
attributes.identity_equal = identity_equal
attributes._is_aliased_class = _is_aliased_class
attributes._entity_info = _entity_info