dibbler/sqlalchemy/sql/visitors.py

329 lines
10 KiB
Python
Raw Normal View History

2017-04-15 18:27:12 +02:00
# sql/visitors.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
2010-05-07 19:33:49 +02:00
"""Visitor/traversal interface and library functions.

SQLAlchemy schema and expression constructs rely on a Python-centric
version of the classic "visitor" pattern as the primary way in which
they apply functionality. The most common use of this pattern
is statement compilation, where individual expression classes match
up to rendering methods that produce a string result. Beyond this,
the visitor system is also used to inspect expressions for various
information and patterns, as well as for usage in
some kinds of expression transformation. Other kinds of transformation
use a non-visitor traversal system.

For many examples of how the visit system is used, see the
sqlalchemy.sql.util and the sqlalchemy.sql.compiler modules.
For an introduction to clause adaption, see
http://techspot.zzzeek.org/2008/01/23/expression-transformations/

"""
from collections import deque
2017-04-15 18:27:12 +02:00
from .. import util
2010-05-07 19:33:49 +02:00
import operator
2017-04-15 18:27:12 +02:00
from .. import exc
# Public API of this module, in declaration order.
__all__ = [
    'VisitableType',
    'Visitable',
    'ClauseVisitor',
    'CloningVisitor',
    'ReplacingCloningVisitor',
    'iterate',
    'iterate_depthfirst',
    'traverse_using',
    'traverse',
    'traverse_depthfirst',
    'cloned_traverse',
    'replacement_traverse',
]
class VisitableType(type):
    """Metaclass that gives classes a ``_compiler_dispatch`` method.

    Every class built by this metaclass that exposes a ``__visit_name__``
    attribute (defined directly or inherited) is handed to
    ``_generate_dispatch()``.  A class named ``Visitable`` is skipped.
    ``_generate_dispatch()`` only installs a dispatcher when the class
    defines ``__visit_name__`` in its own namespace.  The installed method
    behaves approximately like::

        def _compiler_dispatch(self, visitor, **kw):
            '''Look for an attribute named "visit_" + self.__visit_name__
            on the visitor, and call it with the same kw params.'''
            visit_attr = 'visit_%s' % self.__visit_name__
            return getattr(visitor, visit_attr)(self, **kw)

    Classes that have no ``__visit_name__`` attribute are left untouched.
    """

    def __init__(cls, clsname, bases, clsdict):
        # The class named 'Visitable' (the base class of this module) never
        # gets a generated dispatcher; only classes exposing a visit name do.
        needs_dispatch = (
            clsname != 'Visitable' and hasattr(cls, '__visit_name__'))
        if needs_dispatch:
            _generate_dispatch(cls)
        super(VisitableType, cls).__init__(clsname, bases, clsdict)
def _generate_dispatch(cls):
    """Install an optimized ``_compiler_dispatch`` method on ``cls``.

    Nothing happens unless ``cls`` defines ``__visit_name__`` in its own
    ``__dict__``; a subclass that merely inherits the name keeps whatever
    dispatcher it inherits.  The generated method looks up
    ``visit_<__visit_name__>`` on the visitor it is given and calls it
    with ``self`` plus the same keyword arguments.  If the visitor lacks
    that attribute, ``exc.UnsupportedCompilationError(visitor, cls)`` is
    raised.

    Returns None; the method is attached to ``cls`` as a side effect.
    """
    if '__visit_name__' not in cls.__dict__:
        return

    visit_name = cls.__visit_name__

    if isinstance(visit_name, str):
        # The visit name is already a plain string at class-creation
        # (import) time, so the attribute getter is built once here
        # instead of on every dispatch.
        lookup = operator.attrgetter("visit_%s" % visit_name)

        def _compiler_dispatch(self, visitor, **kw):
            try:
                meth = lookup(visitor)
            except AttributeError:
                raise exc.UnsupportedCompilationError(visitor, cls)
            return meth(self, **kw)
    else:
        # __visit_name__ is not a string on the class itself (for example
        # a descriptor evaluated per instance), so the attribute name has
        # to be rebuilt on every dispatch.
        def _compiler_dispatch(self, visitor, **kw):
            attr_name = 'visit_%s' % self.__visit_name__
            try:
                meth = getattr(visitor, attr_name)
            except AttributeError:
                raise exc.UnsupportedCompilationError(visitor, cls)
            return meth(self, **kw)

    _compiler_dispatch.__doc__ = (
        'Look for an attribute named "visit_" + self.__visit_name__\n'
        'on the visitor, and call it with the same kw params.\n')
    cls._compiler_dispatch = _compiler_dispatch
class Visitable(util.with_metaclass(VisitableType, object)):
    """Base class for visitable objects.

    Its only role is to apply the ``VisitableType`` metaclass, so that
    subclasses which define ``__visit_name__`` receive a generated
    ``_compiler_dispatch`` method.
    """
class ClauseVisitor(object):
    """Base class for visitor objects which can traverse using
    the traverse() function.

    Subclasses provide ``visit_<name>`` methods; an element whose
    ``__visit_name__`` is ``<name>`` is handed to the matching method.
    Further visitors can be linked with :meth:`chain`.
    """

    # Options mapping passed as ``opts`` to the module-level traversal
    # functions; subclasses may override it.
    __traverse_options__ = {}

    def traverse_single(self, obj, **kw):
        """Dispatch ``obj`` to the first visitor in the chain that has a
        truthy ``visit_<obj.__visit_name__>`` attribute.

        Returns that method's result, or None if no visitor in the chain
        provides one.
        """
        for visitor in self._visitor_iterator:
            handler = getattr(visitor, "visit_%s" % obj.__visit_name__, None)
            if handler:
                return handler(obj, **kw)
        return None

    def iterate(self, obj):
        """Return an iterable over all elements of the expression
        structure rooted at ``obj``, using this visitor's traverse options.
        """
        return iterate(obj, self.__traverse_options__)

    def traverse(self, obj):
        """Visit every element of the expression structure rooted at
        ``obj`` with this visitor's ``visit_*`` methods; returns ``obj``.
        """
        return traverse(obj, self.__traverse_options__, self._visitor_dict)

    @util.memoized_property
    def _visitor_dict(self):
        """Map ``<name>`` to the bound ``visit_<name>`` method for every
        such method on this visitor (memoized via util.memoized_property).
        """
        prefix = 'visit_'
        return dict(
            (attr[len(prefix):], getattr(self, attr))
            for attr in dir(self)
            if attr.startswith(prefix)
        )

    @property
    def _visitor_iterator(self):
        """Yield this visitor, then each visitor linked via ``_next``."""
        current = self
        while current:
            yield current
            current = getattr(current, '_next', None)

    def chain(self, visitor):
        """Append ``visitor`` to the end of this visitor's chain.

        The chain is walked by :meth:`traverse_single` (and by subclasses
        that iterate ``_visitor_iterator``), with the appended visitor
        consulted after every visitor already in the chain.  Returns
        ``self``.
        """
        links = list(self._visitor_iterator)
        links[-1]._next = visitor
        return self
class CloningVisitor(ClauseVisitor):
    """Base class for visitor objects which can traverse using
    the cloned_traverse() function.
    """

    def copy_and_process(self, list_):
        """Run :meth:`traverse` on each element of ``list_`` and return
        the results as a new list, in the same order.
        """
        return list(map(self.traverse, list_))

    def traverse(self, obj):
        """Clone the expression structure rooted at ``obj``, applying this
        visitor's ``visit_*`` methods to the copies; returns the clone.
        """
        return cloned_traverse(
            obj, self.__traverse_options__, self._visitor_dict)
class ReplacingCloningVisitor(CloningVisitor):
    """Base class for visitor objects which can traverse using
    the replacement_traverse() function.
    """

    def replace(self, elem):
        """Receive pre-copied elements during a cloning traversal.

        Returning a non-None element makes the traversal use it in place of
        a plain copy of ``elem``; if that returned element is encountered
        again, it is left as-is.  This default returns None, meaning
        "no replacement".
        """
        return None

    def traverse(self, obj):
        """Clone the expression structure rooted at ``obj``, substituting
        elements for which a visitor in the chain offers a replacement.
        """
        def _first_replacement(elem):
            # Ask each chained visitor in turn; the first non-None answer
            # wins.
            for visitor in self._visitor_iterator:
                candidate = visitor.replace(elem)
                if candidate is not None:
                    return candidate
            return None

        return replacement_traverse(
            obj, self.__traverse_options__, _first_replacement)
def iterate(obj, opts):
    """traverse the given expression structure, returning an iterator.

    traversal is configured to be breadth-first: ``obj`` first, then its
    children in order, then their children, and so on.  ``opts`` is passed
    as keyword arguments to every ``get_children()`` call.  Elements
    reachable through several parents are yielded once per path.

    Both the atomic fast path and the general path return an iterator,
    so callers may use ``next()`` on the result.
    """
    # fasttrack for atomic elements like columns
    children = obj.get_children(**opts)
    if not children:
        # Previously a plain list was returned here, which broke the
        # documented iterator contract (e.g. next() raised TypeError).
        return iter([obj])

    # The root's children were fetched above; reuse them instead of
    # calling get_children() on the root a second time.
    traversal = deque([obj])
    queue = deque(children)
    while queue:
        t = queue.popleft()
        traversal.append(t)
        queue.extend(t.get_children(**opts))
    return iter(traversal)
def iterate_depthfirst(obj, opts):
    """traverse the given expression structure, returning an iterator.

    traversal is configured to be depth-first: every element appears after
    all of its descendants (children before parents, siblings in order).
    ``opts`` is passed as keyword arguments to every ``get_children()``
    call.

    Both the atomic fast path and the general path return an iterator,
    so callers may use ``next()`` on the result.
    """
    # fasttrack for atomic elements like columns
    children = obj.get_children(**opts)
    if not children:
        # Previously a plain list was returned here, which broke the
        # documented iterator contract (e.g. next() raised TypeError).
        return iter([obj])

    # Pop from the right and prepend to the result: the reversed
    # root/right/left pre-order this produces is a left-to-right post-order.
    # The root's children were fetched above, so seed the stack with them.
    traversal = deque([obj])
    stack = deque(children)
    while stack:
        t = stack.pop()
        traversal.appendleft(t)
        stack.extend(t.get_children(**opts))
    return iter(traversal)
def traverse_using(iterator, obj, visitors):
    """Visit each element produced by ``iterator``.

    For every element, the callable registered in ``visitors`` under the
    element's ``__visit_name__`` (if any, and if truthy) is called with
    that element.  Returns ``obj`` unchanged.
    """
    find_handler = visitors.get
    for element in iterator:
        handler = find_handler(element.__visit_name__, None)
        if handler:
            handler(element)
    return obj
def traverse(obj, opts, visitors):
    """Visit the expression structure rooted at ``obj`` in breadth-first
    order (via :func:`iterate`), dispatching each element to
    ``visitors``; returns ``obj``.
    """
    elements = iterate(obj, opts)
    return traverse_using(elements, obj, visitors)
def traverse_depthfirst(obj, opts, visitors):
    """Visit the expression structure rooted at ``obj`` in depth-first
    order (via :func:`iterate_depthfirst`), dispatching each element to
    ``visitors``; returns ``obj``.
    """
    elements = iterate_depthfirst(obj, opts)
    return traverse_using(elements, obj, visitors)
def cloned_traverse(obj, opts, visitors):
    """Clone the given expression structure, allowing modifications by
    visitors.

    :param obj: root element, or None (returned unchanged).
    :param opts: mapping whose optional ``'stop_on'`` entry lists elements
        that are returned as-is instead of being cloned.
    :param visitors: mapping of ``__visit_name__`` to a callable invoked on
        each fresh clone after its internals have been copied, so children
        are visited before their parents.
    :return: the cloned root element.

    Each distinct element (keyed by ``id()``) is cloned at most once per
    call, so sub-elements shared in the original stay shared in the copy.
    """
    cloned = {}
    stop_on = set(opts.get('stop_on', []))

    def clone(elem):
        if elem in stop_on:
            return elem
        key = id(elem)
        if key not in cloned:
            # Register the copy before recursing so repeated references
            # resolve to the same clone.
            cloned[key] = newelem = elem._clone()
            newelem._copy_internals(clone=clone)
            visit = visitors.get(newelem.__visit_name__, None)
            if visit:
                visit(newelem)
        return cloned[key]

    return clone(obj) if obj is not None else obj
def replacement_traverse(obj, opts, replace):
    """Clone the given expression structure, allowing element replacement
    by a given replacement function.

    :param obj: root element, or None (returned unchanged).
    :param opts: mapping whose optional ``'stop_on'`` entry lists elements
        returned as-is; the whole mapping is also forwarded as keyword
        arguments to the root's ``_copy_internals()``.
    :param replace: callable receiving each element; a non-None result is
        used in place of that element and is left untouched if it is
        encountered again.
    :return: the cloned (or replaced) root element.

    Elements whose ``_annotations`` contain ``'no_replacement_traverse'``
    are returned as-is.  Every other element without a replacement is
    shallow-cloned once per call and its internals are traversed.
    """
    cloned = {}
    stop_on = {id(x) for x in opts.get('stop_on', [])}

    def clone(elem, **kw):
        if id(elem) in stop_on or \
                'no_replacement_traverse' in elem._annotations:
            return elem

        substitute = replace(elem)
        if substitute is not None:
            # Never revisit or re-replace an element we just substituted.
            stop_on.add(id(substitute))
            return substitute

        if elem not in cloned:
            cloned[elem] = newelem = elem._clone()
            newelem._copy_internals(clone=clone, **kw)
        return cloned[elem]

    if obj is not None:
        obj = clone(obj, **opts)
    return obj