257 lines
7.8 KiB
Python
257 lines
7.8 KiB
Python
"""Visitor/traversal interface and library functions.
|
|
|
|
SQLAlchemy schema and expression constructs rely on a Python-centric
version of the classic "visitor" pattern as the primary way in which
they apply functionality.  The most common use of this pattern
is statement compilation, where individual expression classes match
up to rendering methods that produce a string result.  Beyond this,
the visitor system is also used to inspect expressions for various
information and patterns, as well as in some kinds of expression
transformation.  Other kinds of transformation use a non-visitor
traversal system.

For many examples of how the visit system is used, see the
sqlalchemy.sql.util and the sqlalchemy.sql.compiler modules.
For an introduction to clause adaptation, see
http://techspot.zzzeek.org/?p=19 .
|
|
|
|
"""
|
|
|
|
from collections import deque
|
|
import re
|
|
from sqlalchemy import util
|
|
import operator
|
|
|
|
# Public API of this module.  ``traverse_depthfirst`` is exported alongside
# ``traverse`` since both are module-level traversal entry points.
__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
           'CloningVisitor', 'ReplacingCloningVisitor', 'iterate',
           'iterate_depthfirst', 'traverse_using', 'traverse',
           'traverse_depthfirst', 'cloned_traverse', 'replacement_traverse']
|
|
|
|
class VisitableType(type):
    """Metaclass which checks for a `__visit_name__` attribute and
    applies `_compiler_dispatch` method to classes.

    The generated ``_compiler_dispatch(self, visitor, **kw)`` looks up
    ``visit_<__visit_name__>`` on ``visitor``, calls it with the element
    and ``**kw``, and returns the result.  The class named ``Visitable``
    and classes lacking ``__visit_name__`` are left unchanged.
    """

    def __init__(cls, clsname, bases, clsdict):
        wants_dispatch = (cls.__name__ != 'Visitable'
                          and hasattr(cls, '__visit_name__'))
        if wants_dispatch:
            declared = cls.__visit_name__
            if isinstance(declared, str):
                # Fixed name: build the attribute getter once per class
                # so each dispatch avoids string formatting.
                fetch_visit = operator.attrgetter("visit_%s" % declared)

                def _compiler_dispatch(self, visitor, **kw):
                    return fetch_visit(visitor)(self, **kw)
            else:
                # Non-string attribute (presumably a descriptor that varies
                # per instance): resolve the name on every call.
                def _compiler_dispatch(self, visitor, **kw):
                    method_name = 'visit_%s' % self.__visit_name__
                    return getattr(visitor, method_name)(self, **kw)

            cls._compiler_dispatch = _compiler_dispatch

        super(VisitableType, cls).__init__(clsname, bases, clsdict)
|
|
|
|
# Built by calling the metaclass directly: a ``__metaclass__`` class
# attribute is only honoured by Python 2, so under Python 3 a class
# statement using it yields a plain ``type`` subclass and subclasses never
# receive ``_compiler_dispatch``.  This form produces the same class on
# both Python 2 and Python 3.
Visitable = VisitableType('Visitable', (object,), {
    '__doc__': """Base class for visitable objects, applies the
    ``VisitableType`` metaclass.

    Subclasses that define ``__visit_name__`` are given a
    ``_compiler_dispatch`` method by the metaclass.
    """,
    '__module__': __name__,
    # Keeps ``Visitable.__metaclass__`` resolvable, as the former class
    # statement defined it.
    '__metaclass__': VisitableType,
})
|
|
|
|
class ClauseVisitor(object):
    """Base class for visitor objects which can traverse using
    the traverse() function.

    Visit methods are found by naming convention: a method named
    ``visit_<name>`` handles elements whose ``__visit_name__`` is
    ``<name>``.  Further visitors can be linked with :meth:`chain`, and
    :meth:`traverse_single` consults them in order.
    """

    # Keyword options handed to each element's get_children() by
    # iterate()/traverse().  This single dict is shared by every instance
    # (and by subclasses that do not override it), so treat it as read-only.
    __traverse_options__ = {}

    def traverse_single(self, obj, **kw):
        """Visit only ``obj`` itself, without descending into children.

        The first visitor in the chain that has a truthy
        ``visit_<obj.__visit_name__>`` attribute is called with ``obj`` and
        any extra keyword arguments, and its result is returned.  Returns
        None if no visitor in the chain handles ``obj``.
        """
        for v in self._visitor_iterator:
            meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
            if meth:
                return meth(obj, **kw)
        return None

    def iterate(self, obj):
        """traverse the given expression structure, returning an iterator of all elements.

        Breadth-first, using this visitor's ``__traverse_options__``.
        """
        return iterate(obj, self.__traverse_options__)

    def traverse(self, obj):
        """traverse and visit the given expression structure.

        Each element is passed to the matching ``visit_<name>`` method of
        this visitor only; chained visitors are not consulted.  Returns
        ``obj``.
        """
        return traverse(obj, self.__traverse_options__, self._visitor_dict)

    @util.memoized_property
    def _visitor_dict(self):
        """Map of visit name -> bound ``visit_<name>`` method of this visitor.

        Collected from ``dir(self)`` and cached by
        ``util.memoized_property``.
        """
        prefix = 'visit_'
        return dict(
            (name[len(prefix):], getattr(self, name))
            for name in dir(self)
            if name.startswith(prefix)
        )

    @property
    def _visitor_iterator(self):
        """iterate through this visitor and each 'chained' visitor."""
        v = self
        # Compare against None explicitly.  A visitor whose class defines
        # __len__ or __bool__ (__nonzero__ on Python 2) could otherwise
        # evaluate falsy and silently cut the chain short.
        while v is not None:
            yield v
            v = getattr(v, '_next', None)

    def chain(self, visitor):
        """'chain' an additional ClauseVisitor onto this ClauseVisitor.

        ``visitor`` is attached, through the ``_next`` attribute, after the
        current last visitor in the chain.  Chained visitors are consulted
        by traverse_single(); traverse() uses only this visitor's own visit
        methods.  Returns ``self`` so calls can be strung together.
        """
        tail = list(self._visitor_iterator)[-1]
        tail._next = visitor
        return self
|
|
|
|
class CloningVisitor(ClauseVisitor):
    """Base class for visitor objects which can traverse using
    the cloned_traverse() function.

    The ``visit_<name>`` methods receive the copies produced by
    cloned_traverse(), and traverse() returns the copied root.
    """

    def copy_and_process(self, list_):
        """Run a cloned traversal over each element of ``list_``.

        Returns a new list holding the traversed copies, in input order.
        """
        return list(map(self.traverse, list_))

    def traverse(self, obj):
        """Clone ``obj`` and visit the copy, returning the cloned root."""
        options = self.__traverse_options__
        handlers = self._visitor_dict
        return cloned_traverse(obj, options, handlers)
|
|
|
|
class ReplacingCloningVisitor(CloningVisitor):
    """Base class for visitor objects which can traverse using
    the replacement_traverse() function.

    Subclasses override :meth:`replace`.  During traverse(), every visitor
    in the chain is offered each element, and the first non-None result
    wins.
    """

    def replace(self, elem):
        """receive pre-copied elements during a cloning traversal.

        If the method returns a new element, the element is used
        instead of creating a simple copy of the element.  Traversal
        will halt on the newly returned element if it is re-encountered.
        The default implementation never replaces anything.
        """
        return None

    def traverse(self, obj):
        """Clone ``obj``, letting the visitor chain substitute elements."""

        def _first_replacement(elem):
            # Ask each chained visitor in turn; stop at the first answer.
            for visitor in self._visitor_iterator:
                candidate = visitor.replace(elem)
                if candidate is not None:
                    return candidate
            return None

        options = self.__traverse_options__
        return replacement_traverse(obj, options, _first_replacement)
|
|
|
|
def iterate(obj, opts):
    """Lazily yield ``obj`` and every element reachable from it, breadth-first.

    Children are obtained via ``get_children(**opts)`` once their parent
    has been yielded.  No de-duplication is performed, so an element
    reachable along several paths is yielded once per path.
    """
    queue = deque((obj,))
    while queue:
        current = queue.popleft()
        yield current
        queue.extend(current.get_children(**opts))
|
|
|
|
def iterate_depthfirst(obj, opts):
    """Return an iterator over ``obj`` and its descendants, children first.

    Every element comes after all of its descendants; siblings keep the
    order given by ``get_children(**opts)``.  The whole structure is walked
    eagerly before the iterator is returned.
    """
    pending = deque([obj])
    ordered = deque()
    while pending:
        node = pending.pop()
        # Prepending reverses the root-first walk into children-first order.
        ordered.appendleft(node)
        pending.extend(node.get_children(**opts))
    return iter(ordered)
|
|
|
|
def traverse_using(iterator, obj, visitors):
    """Call the matching visitor for each element produced by ``iterator``.

    ``visitors`` maps a ``__visit_name__`` to a callable.  Elements with no
    entry, or with a falsy entry, are skipped.  Returns ``obj`` unchanged.
    """
    for element in iterator:
        handler = visitors.get(element.__visit_name__)
        if not handler:
            continue
        handler(element)
    return obj
|
|
|
|
def traverse(obj, opts, visitors):
    """Visit ``obj`` and its descendants in breadth-first order.

    Pairs :func:`iterate` with :func:`traverse_using`, and returns ``obj``.
    """
    breadth_first = iterate(obj, opts)
    return traverse_using(breadth_first, obj, visitors)
|
|
|
|
def traverse_depthfirst(obj, opts, visitors):
    """Visit ``obj`` and its descendants, children before parents.

    Pairs :func:`iterate_depthfirst` with :func:`traverse_using`, and
    returns ``obj``.
    """
    children_first = iterate_depthfirst(obj, opts)
    return traverse_using(children_first, obj, visitors)
|
|
|
|
def cloned_traverse(obj, opts, visitors):
    """clone the given expression structure, allowing modifications by visitors.

    Each reachable element is copied via its ``_clone()`` method, and the
    copy's references to sub-elements are rewritten through
    ``_copy_internals(clone=...)``.  For every copy processed, the callable
    registered in ``visitors`` under its ``__visit_name__`` (if any) is
    called with the copy.  Returns the copy of ``obj``.

    :param obj: root element to clone.
    :param opts: keyword options passed to ``get_children(**opts)``.
    :param visitors: mapping of ``__visit_name__`` to a callable.
    """

    # original element -> its copy (util.column_dict; presumably dict-like,
    # keyed by element -- confirm).
    cloned = util.column_dict()

    def clone(element):
        # Copy each original at most once, so an element reached along
        # several paths maps to a single copy.
        if element not in cloned:
            cloned[element] = element._clone()
        return cloned[element]

    obj = clone(obj)
    stack = [obj]

    while stack:
        t = stack.pop()
        # Skip elements that were themselves passed to clone() (keys of
        # ``cloned``); the stack is expected to hold copies.
        # NOTE(review): copies are never keys of ``cloned``, so a copy pushed
        # by two different parents appears to be processed twice, and its
        # second _copy_internals() would clone its already-cloned children
        # again -- confirm whether shared sub-elements can occur here.
        if t in cloned:
            continue
        # Point t's sub-element references at their copies.
        t._copy_internals(clone=clone)

        meth = visitors.get(t.__visit_name__, None)
        if meth:
            meth(t)

        # Presumably get_children() now returns the copies installed above.
        for c in t.get_children(**opts):
            stack.append(c)
    return obj
|
|
|
|
def replacement_traverse(obj, opts, replace):
    """clone the given expression structure, allowing element replacement by a given replacement function.

    ``replace(element)`` is called for each element about to be cloned.
    If it returns something other than None, that object is used instead
    of a copy, and it is added to ``stop_on`` so the walk does not descend
    into it.  Otherwise the element is copied once via ``_clone()`` (the
    copy is reused on later encounters), and the copy's sub-element
    references are rewritten through ``_copy_internals(clone=...)``.
    Returns the replacement or copy of ``obj``.

    :param obj: root element.
    :param opts: keyword options passed to ``get_children(**opts)``; its
      optional ``'stop_on'`` entry seeds the set of elements the walk skips.
    :param replace: callable returning a replacement element or None.
    """

    # original element -> its copy (util.column_dict; presumably dict-like).
    cloned = util.column_dict()
    # Elements that, when popped from the stack, get neither their internals
    # copied nor their children visited: the caller's opts['stop_on'] plus
    # every replacement produced by clone() below.
    stop_on = util.column_set(opts.get('stop_on', []))

    def clone(element):
        # A replacement takes precedence over copying and is never copied.
        newelem = replace(element)
        if newelem is not None:
            stop_on.add(newelem)
            return newelem

        # NOTE(review): clone() does not consult ``stop_on``, so an element
        # listed in opts['stop_on'] is still copied here.  Its copy (unless
        # _clone() returned the element itself) is not in ``stop_on`` and is
        # therefore traversed -- confirm this is intended.
        if element not in cloned:
            cloned[element] = element._clone()
        return cloned[element]

    obj = clone(obj)
    stack = [obj]
    while stack:
        t = stack.pop()
        if t in stop_on:
            continue
        # Rewrite t's sub-element references to replacements/copies, then
        # walk its children (presumably reflecting the rewritten references).
        t._copy_internals(clone=clone)
        for c in t.get_children(**opts):
            stack.append(c)
    return obj
|