# dibbler/sqlalchemy/sql/compiler.py
# (source listing metadata: 3035 lines, 105 KiB, Python)

# sql/compiler.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Base SQL and DDL compiler implementations.
Classes provided include:
:class:`.compiler.SQLCompiler` - renders SQL
strings
:class:`.compiler.DDLCompiler` - renders DDL
(data definition language) strings
:class:`.compiler.GenericTypeCompiler` - renders
type specification strings.
To generate user-defined SQL strings, see
:doc:`/ext/compiler`.
"""
import contextlib
import re
from . import schema, sqltypes, operators, functions, visitors, \
elements, selectable, crud
from .. import util, exc
import itertools
# Words reserved in the generic SQL dialect.  NOTE(review): presumably
# consulted by the identifier preparer (not in this excerpt) to decide when
# a name must be quoted -- confirm.
RESERVED_WORDS = set([
    'all', 'analyse', 'analyze', 'and', 'any', 'array',
    'as', 'asc', 'asymmetric', 'authorization', 'between',
    'binary', 'both', 'case', 'cast', 'check', 'collate',
    'column', 'constraint', 'create', 'cross', 'current_date',
    'current_role', 'current_time', 'current_timestamp',
    'current_user', 'default', 'deferrable', 'desc',
    'distinct', 'do', 'else', 'end', 'except', 'false',
    'for', 'foreign', 'freeze', 'from', 'full', 'grant',
    'group', 'having', 'ilike', 'in', 'initially', 'inner',
    'intersect', 'into', 'is', 'isnull', 'join', 'leading',
    'left', 'like', 'limit', 'localtime', 'localtimestamp',
    'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset',
    'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps',
    'placing', 'primary', 'references', 'right', 'select',
    'session_user', 'set', 'similar', 'some', 'symmetric', 'table',
    'then', 'to', 'trailing', 'true', 'union', 'unique', 'user',
    'using', 'verbose', 'when', 'where'])

# Whole-string match for names made only of ASCII letters, digits, '_' and
# '$' (case-insensitive).
LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I)
# Characters a name may not start with: the digits 0-9 and '$'.
ILLEGAL_INITIAL_CHARACTERS = set([str(x) for x in range(0, 10)]).union(['$'])

# ":name" bind placeholders inside text() SQL (see visit_textclause).  The
# lookbehind skips a colon preceded by ':', a word character, '$' or a
# backslash (0x5c), so "::" casts and backslash-escaped placeholders are not
# treated as binds; the lookahead rejects a trailing ':', word char or '$'.
# Group 1 is the parameter name.
BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE)
# A backslash followed by ":name"; group 1 is the text after the backslash,
# which visit_textclause substitutes back in to un-escape it.
BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]*)(?![:\w\$])', re.UNICODE)

# Placeholder template per DBAPI paramstyle, chosen in SQLCompiler.__init__
# from dialect.paramstyle.  The 'numeric' template emits a "[_POSITION]"
# token which _apply_numbered_params() later replaces with 1, 2, ...
# NOTE(review): the doubled "%%" suggests these templates are themselves
# %-formatted with a ``name`` key (bindparam_string, not in this excerpt).
BIND_TEMPLATES = {
    'pyformat': "%%(%(name)s)s",
    'qmark': "?",
    'format': "%%s",
    'numeric': ":[_POSITION]",
    'named': ":%(name)s"
}

# SQL text for operator functions, looked up by visit_binary(),
# visit_unary() and visit_clauselist().  Binary operators include their
# surrounding spaces, unary operators a trailing space and modifiers a
# leading space.
OPERATORS = {
    # binary
    operators.and_: ' AND ',
    operators.or_: ' OR ',
    operators.add: ' + ',
    operators.mul: ' * ',
    operators.sub: ' - ',
    operators.div: ' / ',
    operators.mod: ' % ',
    operators.truediv: ' / ',
    # no surrounding spaces: presumably the unary negation prefix
    operators.neg: '-',
    operators.lt: ' < ',
    operators.le: ' <= ',
    operators.ne: ' != ',
    operators.gt: ' > ',
    operators.ge: ' >= ',
    operators.eq: ' = ',
    operators.is_distinct_from: ' IS DISTINCT FROM ',
    operators.isnot_distinct_from: ' IS NOT DISTINCT FROM ',
    operators.concat_op: ' || ',
    operators.match_op: ' MATCH ',
    operators.notmatch_op: ' NOT MATCH ',
    operators.in_op: ' IN ',
    operators.notin_op: ' NOT IN ',
    operators.comma_op: ', ',
    operators.from_: ' FROM ',
    operators.as_: ' AS ',
    operators.is_: ' IS ',
    operators.isnot: ' IS NOT ',
    operators.collate: ' COLLATE ',
    # unary
    operators.exists: 'EXISTS ',
    operators.distinct_op: 'DISTINCT ',
    operators.inv: 'NOT ',
    operators.any_op: 'ANY ',
    operators.all_op: 'ALL ',
    # modifiers
    operators.desc_op: ' DESC',
    operators.asc_op: ' ASC',
    operators.nullsfirst_op: ' NULLS FIRST',
    operators.nullslast_op: ' NULLS LAST',
}

# Fixed SQL renderings for known function classes, used by visit_function();
# "%(expr)s" is replaced with the argument text from function_argspec().
FUNCTIONS = {
    functions.coalesce: 'coalesce%(expr)s',
    functions.current_date: 'CURRENT_DATE',
    functions.current_time: 'CURRENT_TIME',
    functions.current_timestamp: 'CURRENT_TIMESTAMP',
    functions.current_user: 'CURRENT_USER',
    functions.localtime: 'LOCALTIME',
    functions.localtimestamp: 'LOCALTIMESTAMP',
    functions.random: 'random%(expr)s',
    functions.sysdate: 'sysdate',
    functions.session_user: 'SESSION_USER',
    functions.user: 'USER'
}

# EXTRACT() field names.  The identity mapping is the default for
# SQLCompiler.extract_map, which dialects may override; visit_extract()
# passes fields missing from the map through unchanged.
EXTRACT_MAP = {
    'month': 'month',
    'day': 'day',
    'year': 'year',
    'second': 'second',
    'hour': 'hour',
    'doy': 'doy',
    'minute': 'minute',
    'quarter': 'quarter',
    'dow': 'dow',
    'week': 'week',
    'epoch': 'epoch',
    'milliseconds': 'milliseconds',
    'microseconds': 'microseconds',
    'timezone_hour': 'timezone_hour',
    'timezone_minute': 'timezone_minute'
}

# Keyword text for compound SELECTs, used through
# SQLCompiler.compound_keywords in visit_compound_select().
COMPOUND_KEYWORDS = {
    selectable.CompoundSelect.UNION: 'UNION',
    selectable.CompoundSelect.UNION_ALL: 'UNION ALL',
    selectable.CompoundSelect.EXCEPT: 'EXCEPT',
    selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL',
    selectable.CompoundSelect.INTERSECT: 'INTERSECT',
    selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL'
}
class Compiled(object):
    """Represent a compiled SQL or DDL expression.

    The ``__str__`` method of the ``Compiled`` object should produce
    the actual text of the statement.  ``Compiled`` objects are
    specific to their underlying database dialect, and also may
    or may not be specific to the columns referenced within a
    particular set of bind parameters.  In no case should the
    ``Compiled`` object be dependent on the actual values of those
    bind parameters, even though it may reference those values as
    defaults.

    """

    _cached_metadata = None

    # Fallbacks for instances constructed with ``statement=None``.
    # __init__ only assigns ``statement``, ``can_execute`` and ``string``
    # when a statement is given.  Without these defaults, ``str()``,
    # ``_execute_on_connection()`` and subclass post-processing of
    # ``self.string`` would raise AttributeError on such instances.
    statement = None
    can_execute = False
    string = ''

    execution_options = util.immutabledict()
    """
    Execution options propagated from the statement.   In some cases,
    sub-elements of the statement can modify these.
    """

    def __init__(self, dialect, statement, bind=None,
                 schema_translate_map=None,
                 compile_kwargs=util.immutabledict()):
        """Construct a new :class:`.Compiled` object.

        The statement, when given, is compiled immediately and the
        result stored in ``self.string``.

        :param dialect: :class:`.Dialect` to compile against.

        :param statement: :class:`.ClauseElement` to be compiled, or
          ``None`` to create a compiler without a statement.

        :param bind: Optional Engine or Connection to compile this
          statement against.

        :param schema_translate_map: dictionary of schema names to be
          translated when forming the resultant SQL

          .. versionadded:: 1.1

          .. seealso::

            :ref:`schema_translating`

        :param compile_kwargs: additional kwargs that will be
          passed to the initial call to :meth:`.Compiled.process`.

        """
        self.dialect = dialect
        self.bind = bind
        self.preparer = self.dialect.identifier_preparer
        if schema_translate_map:
            # use a preparer configured with the schema translation map
            self.preparer = self.preparer._with_schema_translate(
                schema_translate_map)

        if statement is not None:
            self.statement = statement
            self.can_execute = statement.supports_execution
            if self.can_execute:
                self.execution_options = statement._execution_options
            # compilation happens eagerly, here in the constructor
            self.string = self.process(self.statement, **compile_kwargs)

    @util.deprecated("0.7", ":class:`.Compiled` objects now compile "
                     "within the constructor.")
    def compile(self):
        """Produce the internal string representation of this element.

        No-op retained for backwards compatibility; compilation already
        happened in the constructor.
        """
        pass

    def _execute_on_connection(self, connection, multiparams, params):
        """Execute on *connection* if the statement supports execution.

        :raises ObjectNotExecutableError: when ``can_execute`` is false.
        """
        if self.can_execute:
            return connection._execute_compiled(self, multiparams, params)
        else:
            raise exc.ObjectNotExecutableError(self.statement)

    @property
    def sql_compiler(self):
        """Return a Compiled that is capable of processing SQL expressions.

        If this compiler is one, it would likely just return 'self'.

        """
        raise NotImplementedError()

    def process(self, obj, **kwargs):
        """Compile *obj* by dispatching to its visit method on this compiler."""
        return obj._compiler_dispatch(self, **kwargs)

    def __str__(self):
        """Return the string text of the generated SQL or DDL."""
        return self.string or ''

    def construct_params(self, params=None):
        """Return the bind params for this compiled object.

        :param params: a dict of string/object pairs whose values will
                       override bind values compiled in to the
                       statement.
        """
        raise NotImplementedError()

    @property
    def params(self):
        """Return the bind params for this compiled object."""
        return self.construct_params()

    def execute(self, *multiparams, **params):
        """Execute this compiled object.

        :raises UnboundExecutionError: when no ``bind`` was supplied.
        """
        e = self.bind
        if e is None:
            raise exc.UnboundExecutionError(
                "This Compiled object is not bound to any Engine "
                "or Connection.")
        return e._execute_compiled(self, multiparams, params)

    def scalar(self, *multiparams, **params):
        """Execute this compiled object and return the result's
        scalar value."""
        return self.execute(*multiparams, **params).scalar()
class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
    """Produces DDL specification for TypeEngine objects.

    ``ensure_kwarg`` is a method-name pattern read by the
    ``EnsureKWArgType`` metaclass.  NOTE(review): presumably it makes the
    matching ``visit_*`` methods accept ``**kw``; confirm in ``util``.
    """

    ensure_kwarg = r'visit_\w+'

    def __init__(self, dialect):
        # the dialect whose type names this compiler renders
        self.dialect = dialect

    def process(self, type_, **kw):
        """Render *type_* via its visit method on this compiler."""
        dispatch = type_._compiler_dispatch
        return dispatch(self, **kw)
class _CompileLabel(visitors.Visitable):
    """lightweight label object which acts as an expression.Label."""

    __visit_name__ = 'label'
    # ``_alt_names`` is assigned in __init__, so it is declared alongside
    # ``element`` and ``name`` rather than silently relying on a base-class
    # ``__dict__``.
    __slots__ = 'element', 'name', '_alt_names'

    def __init__(self, col, name, alt_names=()):
        """
        :param col: the element being labeled.
        :param name: the label name to render.
        :param alt_names: tuple of extra objects recorded after *col* in
          ``_alt_names`` (used as result-map targets by visit_label).
        """
        self.element = col
        self.name = name
        self._alt_names = (col,) + alt_names

    @property
    def proxy_set(self):
        """Proxy set of the labeled element."""
        return self.element.proxy_set

    @property
    def type(self):
        """Type of the labeled element."""
        return self.element.type

    def self_group(self, **kw):
        """A label never needs grouping; return itself."""
        return self
class SQLCompiler(Compiled):
    """Default implementation of :class:`.Compiled`.

    Compiles :class:`.ClauseElement` objects into SQL strings.

    """

    # EXTRACT() field-name translation consulted by visit_extract()
    extract_map = EXTRACT_MAP

    # compound SELECT keyword text consulted by visit_compound_select()
    compound_keywords = COMPOUND_KEYWORDS

    isdelete = isinsert = isupdate = False
    """class-level defaults which can be set at the instance
    level to define if this Compiled instance represents
    INSERT/UPDATE/DELETE
    """

    # set to True by visit_textclause() when a text() construct is compiled
    # with no enclosing statement on the stack
    isplaintext = False

    returning = None
    """holds the "returning" collection of columns if
    the statement is CRUD and defines returning columns
    either implicitly or explicitly
    """

    returning_precedes_values = False
    """set to True classwide to generate RETURNING
    clauses before the VALUES or WHERE clause (e.g. MSSQL)
    """

    render_table_with_column_in_update_from = False
    """set to True classwide to indicate the SET clause
    in a multi-table UPDATE statement should qualify
    columns with the table name (i.e. MySQL only)
    """

    ansi_bind_rules = False
    """SQL 92 doesn't allow bind parameters to be used
    in the columns clause of a SELECT, nor does it allow
    ambiguous expressions like "? = ?".  A compiler
    subclass can set this flag to True if the target
    driver/DB enforces this; such bind parameters are then
    rendered as inline literal values instead.
    """

    _textual_ordered_columns = False
    """tell the result object that the column names as rendered are important,
    but they are also "ordered" vs. what is in the compiled object here.
    """

    _ordered_columns = True
    """
    if False, means we can't be sure the list of entries
    in _result_columns is actually the rendered order.  Usually
    True unless using an unordered TextAsFrom.
    """

    # INSERT / UPDATE prefetch collections; the ``prefetch`` property
    # returns them concatenated (INSERT entries first)
    insert_prefetch = update_prefetch = ()
def __init__(self, dialect, statement, column_keys=None,
             inline=False, **kwargs):
    """Construct a new :class:`.SQLCompiler` object.

    :param dialect: :class:`.Dialect` to be used

    :param statement: :class:`.ClauseElement` to be compiled

    :param column_keys:  a list of column names to be compiled into an
     INSERT or UPDATE statement.

    :param inline: whether to generate INSERT statements as "inline", e.g.
     not formatted to return any generated defaults

    :param kwargs: additional keyword arguments to be consumed by the
     superclass.

    """
    self.column_keys = column_keys

    # compile INSERT/UPDATE defaults/sequences inlined (no pre-
    # execute)
    self.inline = inline or getattr(statement, 'inline', False)

    # a dictionary of bind parameter keys to BindParameter
    # instances.
    self.binds = {}

    # a dictionary of BindParameter instances to "compiled" names
    # that are actually present in the generated SQL
    self.bind_names = util.column_dict()

    # stack which keeps track of nested SELECT statements
    self.stack = []

    # relates label names in the final SQL to a tuple of local
    # column/label name, ColumnElement object (if any) and
    # TypeEngine. ResultProxy uses this for type processing and
    # column targeting
    self._result_columns = []

    # true if the paramstyle is positional
    self.positional = dialect.positional
    if self.positional:
        self.positiontup = []
    # placeholder template for this paramstyle (see BIND_TEMPLATES)
    self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]

    # CTE collections are created lazily by _init_cte_state()
    self.ctes = None

    # fall back to the dialect's identifier limit when no explicit
    # label length is configured
    self.label_length = dialect.label_length \
        or dialect.max_identifier_length

    # a map which tracks "anonymous" identifiers that are created on
    # the fly here
    self.anon_map = util.PopulateDict(self._process_anon)

    # a map which tracks "truncated" names based on
    # dialect.label_length or dialect.max_identifier_length
    self.truncated_names = {}

    # Compiled.__init__ compiles the statement immediately (via
    # self.process), so all of the state above must exist before this call.
    Compiled.__init__(self, dialect, statement, **kwargs)

    if (
            self.isinsert or self.isupdate or self.isdelete
    ) and statement._returning:
        self.returning = statement._returning

    # the 'numeric' template renders a "[_POSITION]" token per bind;
    # number them in order of appearance now that the SQL is complete
    if self.positional and dialect.paramstyle == 'numeric':
        self._apply_numbered_params()
@property
def prefetch(self):
    """All prefetch entries as a list: INSERT prefetch, then UPDATE prefetch."""
    combined = self.insert_prefetch + self.update_prefetch
    return list(combined)
@util.memoized_instancemethod
def _init_cte_state(self):
    """Create the CTE bookkeeping collections on first use.

    Memoized, so it runs once per compiler.  Keeping the collections lazy
    avoids their overhead for statements that contain no CTE.
    """
    # CTEs gathered here are rendered ahead of the enclosing SELECT
    self.ctes = util.OrderedDict()
    self.ctes_by_name = {}
    self.ctes_recursive = False
    if not self.positional:
        return
    self.cte_positional = {}
@contextlib.contextmanager
def _nested_result(self):
    """special API to support the use case of 'nested result sets'

    Temporarily replaces ``_result_columns`` / ``_ordered_columns`` with
    a fresh empty list and ``False`` and yields them as a tuple.  When a
    SELECT is being compiled, the innermost stack entry is flagged with
    ``need_result_map_for_nested``.  On exit, the previous result-column
    state is restored, and so is the previous value of the flag.  This way
    a nested use on the same stack entry does not clear the flag while an
    enclosing use is still active.
    """
    missing = object()  # sentinel: flag was absent before we set it
    saved_columns, saved_ordered = (
        self._result_columns, self._ordered_columns)
    self._result_columns, self._ordered_columns = [], False
    entry = None
    previous_flag = missing
    try:
        if self.stack:
            entry = self.stack[-1]
            previous_flag = entry.get('need_result_map_for_nested', missing)
            entry['need_result_map_for_nested'] = True
        yield self._result_columns, self._ordered_columns
    finally:
        if entry is not None:
            if previous_flag is missing:
                entry.pop('need_result_map_for_nested', None)
            else:
                entry['need_result_map_for_nested'] = previous_flag
        self._result_columns, self._ordered_columns = (
            saved_columns, saved_ordered)
def _apply_numbered_params(self):
    """Number each ``[_POSITION]`` token in ``self.string`` as 1, 2, 3, ..."""
    counter = itertools.count(1)

    def _next_position(match):
        return str(util.next(counter))

    self.string = re.sub(r'\[_POSITION\]', _next_position, self.string)
@util.memoized_property
def _bind_processors(self):
    """Map each compiled bind name to its type's bind processor.

    Bind parameters whose type supplies no processor (``None``) are left
    out.  The result is memoized per compiler.
    """
    processors = {}
    for bindparam in self.bind_names:
        processor = bindparam.type._cached_bind_processor(self.dialect)
        if processor is None:
            continue
        processors[self.bind_names[bindparam]] = processor
    return processors
def is_subquery(self):
    """Return True when more than one statement is on the compile stack."""
    depth = len(self.stack)
    return depth > 1
@property
def sql_compiler(self):
    """A SQL compiler serves as its own SQL-expression compiler."""
    compiler = self
    return compiler
def construct_params(self, params=None, _group_number=None, _check=True):
    """return a dictionary of bind parameter keys and values

    :param params: optional mapping of override values.  Each bind
      parameter is looked up first by its ``key`` and then by its
      compiled name.
    :param _group_number: index of the parameter set being processed.
      When truthy, it is included in the "value is required" message.
    :param _check: when true, a ``required`` bind parameter with no value
      supplied through *params* raises.
    :return: dict keyed by compiled bind name.  Values come from *params*
      when supplied; otherwise from ``effective_value`` for bind parameters
      with a callable, and from ``value`` for the rest.
    :raises InvalidRequestError: for a missing required value when
      ``_check`` is true.
    """
    pd = {}
    for bindparam in self.bind_names:
        name = self.bind_names[bindparam]
        if params:
            if bindparam.key in params:
                pd[name] = params[bindparam.key]
                continue
            if name in params:
                pd[name] = params[name]
                continue
        # no override supplied for this parameter
        if _check and bindparam.required:
            if _group_number:
                raise exc.InvalidRequestError(
                    "A value is required for bind parameter %r, "
                    "in parameter group %d" %
                    (bindparam.key, _group_number))
            raise exc.InvalidRequestError(
                "A value is required for bind parameter %r"
                % bindparam.key)
        if bindparam.callable:
            pd[name] = bindparam.effective_value
        else:
            pd[name] = bindparam.value
    return pd
@property
def params(self):
    """Bind parameter dict for the values present, skipping the
    required-value check."""
    build = self.construct_params
    return build(_check=False)
@util.dependencies("sqlalchemy.engine.result")
def _create_result_map(self, result):
    """Build a result map from ``self._result_columns`` (unit tests only)."""
    metadata_cls = result.ResultMetaData
    return metadata_cls._create_result_map(self._result_columns)
def default_from(self):
    """Return the FROM text to use when a SELECT has no FROM objects.

    The base compiler renders nothing; a dialect may override this, for
    example Oracle to append ``FROM DUAL``.
    """
    no_from_clause = ""
    return no_from_clause
def visit_grouping(self, grouping, asfrom=False, **kwargs):
    """Render the grouped element inside parentheses.

    ``asfrom`` is accepted but not forwarded to the inner element.
    """
    inner = grouping.element._compiler_dispatch(self, **kwargs)
    return "".join(("(", inner, ")"))
def visit_label_reference(
        self, element, within_columns_clause=False, **kwargs):
    """Compile a reference to a label.

    When a SELECT is being compiled and the dialect sets
    ``supports_simple_order_by_label``, the referenced label may be
    rendered by name only.  This happens when the innermost selectable's
    label-resolution dict (``only_froms`` inside the columns clause,
    ``only_cols`` elsewhere) holds a same-named element sharing lineage
    with it; visit_label() then renders just the name because it receives
    itself as ``render_label_as_label``.  In every case the referenced
    element is compiled via ``self.process``.
    """
    if self.stack and self.dialect.supports_simple_order_by_label:
        # note: this local ``selectable`` shadows the module import
        selectable = self.stack[-1]['selectable']

        with_cols, only_froms, only_cols = selectable._label_resolve_dict
        if within_columns_clause:
            resolve_dict = only_froms
        else:
            resolve_dict = only_cols

        # this can be None in the case that a _label_reference()
        # were subject to a replacement operation, in which case
        # the replacement of the Label element may have changed
        # to something else like a ColumnClause expression.
        order_by_elem = element.element._order_by_label_element

        if order_by_elem is not None and order_by_elem.name in \
                resolve_dict and \
                order_by_elem.shares_lineage(
                    resolve_dict[order_by_elem.name]):
            kwargs['render_label_as_label'] = \
                element.element._order_by_label_element
    return self.process(
        element.element, within_columns_clause=within_columns_clause,
        **kwargs)
def visit_textual_label_reference(
        self, element, within_columns_clause=False, **kwargs):
    """Compile a label reference given by name (``element.element``).

    Outside any statement, the element's ``_text_clause`` is compiled.
    Inside a SELECT, the name is looked up in the innermost selectable's
    label-resolution dicts: ``only_froms`` within the columns clause,
    ``with_cols`` otherwise.  A hit is compiled with
    ``render_label_as_label`` set to the found column.  A miss emits a
    rate-limited warning and falls back to ``_text_clause``.  Both text
    fallbacks are compiled without ``kwargs``.
    """
    if not self.stack:
        # compiling the element outside of the context of a SELECT
        return self.process(
            element._text_clause
        )

    selectable = self.stack[-1]['selectable']
    with_cols, only_froms, only_cols = selectable._label_resolve_dict
    try:
        if within_columns_clause:
            col = only_froms[element.element]
        else:
            col = with_cols[element.element]
    except KeyError:
        # treat it like text()
        util.warn_limited(
            "Can't resolve label reference %r; converting to text()",
            util.ellipses_string(element.element))
        return self.process(
            element._text_clause
        )
    else:
        kwargs['render_label_as_label'] = col
        return self.process(
            col, within_columns_clause=within_columns_clause, **kwargs)
def visit_label(self, label,
                add_to_result_map=None,
                within_label_clause=False,
                within_columns_clause=False,
                render_label_as_label=None,
                **kw):
    """Compile a Label.

    * Within the columns clause and not already inside another label:
      render ``<element> AS <name>``, reporting the label to
      ``add_to_result_map`` when one is given.
    * When ``render_label_as_label`` is this very label: render only the
      formatted label name.
    * Otherwise: render only the labeled element, with
      ``within_columns_clause=False``.

    A ``_truncated_label`` name is shortened via ``_truncated_identifier``.
    """
    # only render labels within the columns clause
    # or ORDER BY clause of a select.  dialect-specific compilers
    # can modify this behavior.
    render_label_with_as = (within_columns_clause and not
                            within_label_clause)
    render_label_only = render_label_as_label is label

    if render_label_only or render_label_with_as:
        if isinstance(label.name, elements._truncated_label):
            labelname = self._truncated_identifier("colident", label.name)
        else:
            labelname = label.name

    if render_label_with_as:
        if add_to_result_map is not None:
            add_to_result_map(
                labelname,
                label.name,
                (label, labelname, ) + label._alt_names,
                label.type
            )

        # within_label_clause=True keeps a label nested inside the element
        # from rendering its own "AS <name>"
        return label.element._compiler_dispatch(
            self, within_columns_clause=True,
            within_label_clause=True, **kw) + \
            OPERATORS[operators.as_] + \
            self.preparer.format_label(label, labelname)
    elif render_label_only:
        return self.preparer.format_label(label, labelname)
    else:
        return label.element._compiler_dispatch(
            self, within_columns_clause=False, **kw)
def _fallback_column_name(self, column):
    """Hook for naming a Column whose ``name`` is None.

    The base compiler refuses by raising CompileError; visit_column() uses
    the return value when a subclass overrides this hook.
    """
    message = "Cannot compile Column object until its 'name' is assigned."
    raise exc.CompileError(message)
def visit_column(self, column, add_to_result_map=None,
                 include_table=True, **kwargs):
    """Compile a column reference.

    A missing name is obtained from :meth:`_fallback_column_name`.  For
    non-literal columns, a ``_truncated_label`` name is shortened.  When
    ``add_to_result_map`` is given, it receives the rendered name, the
    original name, the column targets and the type.  Literal columns are
    %-escaped via :meth:`escape_literal_column`; other names are quoted by
    the preparer.  When ``include_table`` is true and the column's table
    is ``named_with_column``, the result is qualified as
    ``[schema.]table.column``, with the table name truncated if needed.
    """
    name = orig_name = column.name
    if name is None:
        name = self._fallback_column_name(column)

    is_literal = column.is_literal
    if not is_literal and isinstance(name, elements._truncated_label):
        name = self._truncated_identifier("colident", name)

    if add_to_result_map is not None:
        add_to_result_map(
            name,
            orig_name,
            (column, name, column.key),
            column.type
        )

    if is_literal:
        name = self.escape_literal_column(name)
    else:
        name = self.preparer.quote(name)

    table = column.table
    if table is None or not include_table or not table.named_with_column:
        return name
    else:
        # the preparer decides the schema (it may apply a translate map)
        effective_schema = self.preparer.schema_for_object(table)

        if effective_schema:
            schema_prefix = self.preparer.quote_schema(
                effective_schema) + '.'
        else:
            schema_prefix = ''
        tablename = table.name
        if isinstance(tablename, elements._truncated_label):
            tablename = self._truncated_identifier("alias", tablename)

        return schema_prefix + \
            self.preparer.quote(tablename) + \
            "." + name
def escape_literal_column(self, text):
    """provide escaping for the literal_column() construct.

    Every ``%`` is doubled.  NOTE(review): presumably this lets the text
    survive DBAPI %-style parameter formatting; confirm per dialect.
    """
    # TODO: some dialects might need different behavior here
    return '%%'.join(text.split('%'))
def visit_fromclause(self, fromclause, **kwargs):
    """Render a generic FROM object as its bare name."""
    name = fromclause.name
    return name
def visit_index(self, index, **kwargs):
    """Render an index as its bare name."""
    name = index.name
    return name
def visit_typeclause(self, typeclause, **kw):
    """Render the clause's type through the dialect's type compiler.

    The TypeClause itself is passed along as ``type_expression``,
    overriding any value already in ``kw``.
    """
    options = dict(kw, type_expression=typeclause)
    return self.dialect.type_compiler.process(typeclause.type, **options)
def post_process_text(self, text):
    """Hook for rewriting text() SQL before bind placeholders are parsed.

    The base implementation returns the text unchanged.
    """
    unchanged = text
    return unchanged
def visit_textclause(self, textclause, **kw):
    """Render a text() construct.

    Each ``:name`` placeholder matched by ``BIND_PARAMS`` is rendered via
    the same-named BindParameter in ``textclause._bindparams`` when there
    is one, else via :meth:`bindparam_string`.  Backslash-escaped
    placeholders are then un-escaped.  A text() compiled with nothing on
    the stack marks the compiler as ``isplaintext``.
    """
    def render_placeholder(match):
        bind_name = match.group(1)
        if bind_name not in textclause._bindparams:
            return self.bindparam_string(bind_name, **kw)
        return self.process(textclause._bindparams[bind_name], **kw)

    if not self.stack:
        self.isplaintext = True

    sql_text = self.post_process_text(textclause.text)
    with_binds = BIND_PARAMS.sub(render_placeholder, sql_text)
    # un-escape any backslash-escaped :params
    return BIND_PARAMS_ESC.sub(lambda match: match.group(1), with_binds)
def visit_text_as_from(self, taf,
                       compound_index=None,
                       asfrom=False,
                       parens=True, **kw):
    """Compile a TextAsFrom (a text() construct with declared columns).

    The declared ``column_args`` are compiled only to populate the result
    map, and their SQL output is discarded.  This happens at top level,
    when this is the first SELECT of a compound that requested a result
    map, or inside a nested result set.  The text itself is then compiled
    and parenthesized when both ``asfrom`` and ``parens`` are true.
    """
    toplevel = not self.stack
    entry = self._default_stack_entry if toplevel else self.stack[-1]

    populate_result_map = toplevel or \
        (
            compound_index == 0 and entry.get(
                'need_result_map_for_compound', False)
        ) or entry.get('need_result_map_for_nested', False)

    if populate_result_map:
        # positional TextAsFrom columns match result columns by order
        self._ordered_columns = \
            self._textual_ordered_columns = taf.positional
        for c in taf.column_args:
            self.process(c, within_columns_clause=True,
                         add_to_result_map=self._add_to_result_map)

    text = self.process(taf.element, **kw)
    if asfrom and parens:
        text = "(%s)" % text
    return text
def visit_null(self, expr, **kw):
    """Render the SQL NULL keyword."""
    keyword = 'NULL'
    return keyword
def visit_true(self, expr, **kw):
    """Render boolean true: ``true`` natively, otherwise the integer ``1``."""
    return 'true' if self.dialect.supports_native_boolean else "1"
def visit_false(self, expr, **kw):
    """Render boolean false: ``false`` natively, otherwise the integer ``0``."""
    return 'false' if self.dialect.supports_native_boolean else "0"
def visit_clauselist(self, clauselist, **kw):
    """Join the compiled clauses of *clauselist* with its operator text.

    A ``None`` operator joins with a single space; otherwise the separator
    comes from ``OPERATORS``.  Clauses that compile to an empty string are
    dropped.
    """
    if clauselist.operator is None:
        separator = " "
    else:
        separator = OPERATORS[clauselist.operator]
    rendered = []
    for clause in clauselist.clauses:
        text = clause._compiler_dispatch(self, **kw)
        if text:
            rendered.append(text)
    return separator.join(rendered)
def visit_case(self, clause, **kwargs):
    """Render ``CASE [value] WHEN .. THEN .. [ELSE ..] END``.

    Each compiled piece is followed by a single space.
    """
    parts = ["CASE "]
    if clause.value is not None:
        parts.append(clause.value._compiler_dispatch(self, **kwargs) + " ")
    for condition, outcome in clause.whens:
        condition_sql = condition._compiler_dispatch(self, **kwargs)
        outcome_sql = outcome._compiler_dispatch(self, **kwargs)
        parts.append("WHEN " + condition_sql + " THEN " + outcome_sql + " ")
    if clause.else_ is not None:
        parts.append(
            "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " ")
    parts.append("END")
    return "".join(parts)
def visit_type_coerce(self, type_coerce, **kw):
    """Render only the wrapped ``typed_expression``; no SQL is added."""
    expression = type_coerce.typed_expression
    return expression._compiler_dispatch(self, **kw)
def visit_cast(self, cast, **kwargs):
    """Render ``CAST(<expression> AS <type>)``."""
    expression_sql = cast.clause._compiler_dispatch(self, **kwargs)
    type_sql = cast.typeclause._compiler_dispatch(self, **kwargs)
    return "CAST(%s AS %s)" % (expression_sql, type_sql)
def _format_frame_clause(self, range_, **kw):
    """Render the ``<start> AND <end>`` part of a window frame.

    ``range_`` is a (start, end) pair.  Each bound is either
    ``RANGE_UNBOUNDED``, ``RANGE_CURRENT`` or an expression compiled and
    suffixed with PRECEDING (start) / FOLLOWING (end).
    """
    def render_bound(bound, direction):
        if bound is elements.RANGE_UNBOUNDED:
            return "UNBOUNDED %s" % direction
        if bound is elements.RANGE_CURRENT:
            return "CURRENT ROW"
        return "%s %s" % (self.process(bound, **kw), direction)

    start = render_bound(range_[0], "PRECEDING")
    end = render_bound(range_[1], "FOLLOWING")
    return '%s AND %s' % (start, end)
def visit_over(self, over, **kwargs):
    """Render ``<func> OVER ([PARTITION BY ..] [ORDER BY ..] [frame])``.

    A RANGE frame takes precedence over a ROWS frame.  Partition/order
    clauses that are None or empty are omitted.
    """
    if over.range_:
        frame = "RANGE BETWEEN %s" % self._format_frame_clause(
            over.range_, **kwargs)
    elif over.rows:
        frame = "ROWS BETWEEN %s" % self._format_frame_clause(
            over.rows, **kwargs)
    else:
        frame = None

    function_sql = over.element._compiler_dispatch(self, **kwargs)

    window_parts = []
    for word, clause in (('PARTITION', over.partition_by),
                         ('ORDER', over.order_by)):
        if clause is None or not len(clause):
            continue
        window_parts.append('%s BY %s' % (
            word, clause._compiler_dispatch(self, **kwargs)))
    if frame:
        window_parts.append(frame)

    return "%s OVER (%s)" % (function_sql, ' '.join(window_parts))
def visit_withingroup(self, withingroup, **kwargs):
    """Render ``<func> WITHIN GROUP (ORDER BY <order_by>)``."""
    function_sql = withingroup.element._compiler_dispatch(self, **kwargs)
    order_sql = withingroup.order_by._compiler_dispatch(self, **kwargs)
    return "%s WITHIN GROUP (ORDER BY %s)" % (function_sql, order_sql)
def visit_funcfilter(self, funcfilter, **kwargs):
    """Render ``<func> FILTER (WHERE <criterion>)``."""
    function_sql = funcfilter.func._compiler_dispatch(self, **kwargs)
    criterion_sql = funcfilter.criterion._compiler_dispatch(self, **kwargs)
    return "%s FILTER (WHERE %s)" % (function_sql, criterion_sql)
def visit_extract(self, extract, **kwargs):
    """Render ``EXTRACT(<field> FROM <expr>)``.

    The field name is translated through ``extract_map``; names not in
    the map pass through unchanged.
    """
    mapping = self.extract_map
    if extract.field in mapping:
        field = mapping[extract.field]
    else:
        field = extract.field
    expression_sql = extract.expr._compiler_dispatch(self, **kwargs)
    return "EXTRACT(%s FROM %s)" % (field, expression_sql)
def visit_function(self, func, add_to_result_map=None, **kwargs):
    """Compile a SQL function call.

    The function is first reported to ``add_to_result_map`` when given.
    A ``visit_<lowercased name>_func`` method on this compiler, if
    present, renders the call.  Otherwise a template from ``FUNCTIONS``
    (default ``<name>%(expr)s``) is qualified with the package names and
    filled with the argument text from :meth:`function_argspec`.
    """
    if add_to_result_map is not None:
        add_to_result_map(func.name, func.name, (), func.type)

    special = getattr(self, "visit_%s_func" % func.name.lower(), None)
    if special:
        return special(func, **kwargs)

    template = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
    qualified = ".".join(list(func.packagenames) + [template])
    return qualified % {'expr': self.function_argspec(func, **kwargs)}
def visit_next_value_func(self, next_value, **kw):
    """Render next_value(<sequence>) by delegating to visit_sequence().

    Keyword arguments are not forwarded.
    """
    sequence = next_value.sequence
    return self.visit_sequence(sequence)
def visit_sequence(self, sequence, **kw):
    """Render a sequence increment; not supported by the base compiler.

    ``**kw`` is accepted so that a generic ``_compiler_dispatch`` call
    passing keyword arguments reaches the intended NotImplementedError
    rather than failing with TypeError.  Dialects that support sequences
    override this method.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError(
        "Dialect '%s' does not support sequence increments." %
        self.dialect.name
    )
def function_argspec(self, func, **kwargs):
    """Render ``func.clause_expr``; visit_function substitutes this text
    for ``%(expr)s``."""
    arguments = func.clause_expr
    return arguments._compiler_dispatch(self, **kwargs)
def visit_compound_select(self, cs, asfrom=False,
                          parens=True, compound_index=0, **kwargs):
    """Compile a compound SELECT (UNION / EXCEPT / INTERSECT [ALL]).

    Pushes a stack entry for *cs*.  That entry requests result-map
    population (``need_result_map_for_compound``) when this compound is
    the top-level statement, or is the first member of an enclosing
    compound that itself needs one.  Member SELECTs are compiled without
    parentheses and tagged with their ``compound_index``.  GROUP BY,
    ORDER BY and LIMIT/OFFSET are then appended, and collected CTEs are
    prefixed only at top level.  The text is parenthesized when both
    ``asfrom`` and ``parens`` are true.
    """
    toplevel = not self.stack
    entry = self._default_stack_entry if toplevel else self.stack[-1]
    need_result_map = toplevel or \
        (compound_index == 0
            and entry.get('need_result_map_for_compound', False))

    self.stack.append(
        {
            'correlate_froms': entry['correlate_froms'],
            'asfrom_froms': entry['asfrom_froms'],
            'selectable': cs,
            'need_result_map_for_compound': need_result_map
        })

    keyword = self.compound_keywords.get(cs.keyword)

    text = (" " + keyword + " ").join(
        (c._compiler_dispatch(self,
                              asfrom=asfrom, parens=False,
                              compound_index=i, **kwargs)
         for i, c in enumerate(cs.selects))
    )

    group_by = cs._group_by_clause._compiler_dispatch(
        self, asfrom=asfrom, **kwargs)
    if group_by:
        text += " GROUP BY " + group_by

    text += self.order_by_clause(cs, **kwargs)
    # append LIMIT/OFFSET only when either one is set
    text += (cs._limit_clause is not None
             or cs._offset_clause is not None) and \
        self.limit_clause(cs, **kwargs) or ""

    if self.ctes and toplevel:
        text = self._render_cte_clause() + text

    self.stack.pop(-1)
    if asfrom and parens:
        return "(" + text + ")"
    else:
        return text
def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
    """Return ``visit_<op>_<qualifier1>[_<qualifier2>]`` on self, or None."""
    suffix = qualifier1
    if qualifier2:
        suffix += "_" + qualifier2
    return getattr(self, "visit_%s_%s" % (operator_.__name__, suffix), None)
def visit_unary(self, unary, **kw):
    """Compile a unary expression: a prefix operator or a postfix modifier.

    A ``visit_<op>_unary_operator`` / ``visit_<op>_unary_modifier`` method
    is used when one exists; otherwise the text comes from ``OPERATORS``.

    :raises CompileError: when both or neither of operator and modifier
      are set.
    """
    if unary.operator and unary.modifier:
        raise exc.CompileError(
            "Unary expression does not support operator "
            "and modifier simultaneously")

    if unary.operator:
        dispatcher = self._get_operator_dispatch(
            unary.operator, "unary", "operator")
        if not dispatcher:
            return self._generate_generic_unary_operator(
                unary, OPERATORS[unary.operator], **kw)
        return dispatcher(unary, unary.operator, **kw)

    if unary.modifier:
        dispatcher = self._get_operator_dispatch(
            unary.modifier, "unary", "modifier")
        if not dispatcher:
            return self._generate_generic_unary_modifier(
                unary, OPERATORS[unary.modifier], **kw)
        return dispatcher(unary, unary.modifier, **kw)

    raise exc.CompileError(
        "Unary expression has no operator or modifier")
def visit_istrue_unary_operator(self, element, operator, **kw):
    """Render an "is true" test: the bare expression natively, else ``= 1``."""
    native = self.dialect.supports_native_boolean
    inner = self.process(element.element, **kw)
    return inner if native else "%s = 1" % inner
def visit_isfalse_unary_operator(self, element, operator, **kw):
    """Render an "is false" test: ``NOT expr`` natively, else ``expr = 0``."""
    native = self.dialect.supports_native_boolean
    inner = self.process(element.element, **kw)
    return "NOT %s" % inner if native else "%s = 0" % inner
def visit_notmatch_op_binary(self, binary, operator, **kw):
    """Render NOT MATCH as ``NOT <left MATCH right>``.

    Delegates to :meth:`visit_binary` with the MATCH operator and forwards
    the caller's keyword arguments (e.g. ``literal_binds``), so both
    operands are compiled under the same context as any other binary.
    """
    return "NOT %s" % self.visit_binary(
        binary, override_operator=operators.match_op, **kw)
def visit_binary(self, binary, override_operator=None,
                 eager_grouping=False, **kw):
    """Compile a binary expression.

    ``override_operator`` replaces ``binary.operator`` when given.
    ``eager_grouping`` is consumed here and not forwarded.  With
    ``ansi_bind_rules``, a bind-vs-bind comparison is rendered with
    literal values so "? = ?" is never emitted.  A
    ``visit_<op>_binary`` method is used when present; otherwise the text
    comes from ``OPERATORS``.

    :raises UnsupportedCompilationError: for an operator with neither.
    """
    both_binds = (
        self.ansi_bind_rules and
        isinstance(binary.left, elements.BindParameter) and
        isinstance(binary.right, elements.BindParameter))
    if both_binds:
        # don't allow "? = ?" to render
        kw['literal_binds'] = True

    operator_ = override_operator or binary.operator
    dispatcher = self._get_operator_dispatch(operator_, "binary", None)
    if dispatcher:
        return dispatcher(binary, operator_, **kw)

    try:
        opstring = OPERATORS[operator_]
    except KeyError:
        raise exc.UnsupportedCompilationError(self, operator_)
    return self._generate_generic_binary(binary, opstring, **kw)
def visit_custom_op_binary(self, element, operator, **kw):
    """Render a custom binary operator, its opstring padded by spaces.

    ``eager_grouping`` is taken from the operator, overriding ``kw``.
    """
    options = dict(kw)
    options['eager_grouping'] = operator.eager_grouping
    padded = "".join((" ", operator.opstring, " "))
    return self._generate_generic_binary(element, padded, **options)
def visit_custom_op_unary_operator(self, element, operator, **kw):
    """Render a custom prefix operator: its opstring plus a trailing space."""
    prefix = "".join((operator.opstring, " "))
    return self._generate_generic_unary_operator(element, prefix, **kw)
def visit_custom_op_unary_modifier(self, element, operator, **kw):
    """Render a custom postfix modifier: a leading space plus its opstring."""
    suffix = "".join((" ", operator.opstring))
    return self._generate_generic_unary_modifier(element, suffix, **kw)
def _generate_generic_binary(
        self, binary, opstring, eager_grouping=False, **kw):
    """Render ``<left><opstring><right>``.

    Both operands are compiled with ``_in_binary=True``.  The result is
    parenthesized only when ``eager_grouping`` is set and this binary is
    itself nested inside another binary.
    """
    nested = kw.get('_in_binary', False)
    kw['_in_binary'] = True

    left_sql = binary.left._compiler_dispatch(
        self, eager_grouping=eager_grouping, **kw)
    right_sql = binary.right._compiler_dispatch(
        self, eager_grouping=eager_grouping, **kw)
    text = left_sql + opstring + right_sql

    if not (nested and eager_grouping):
        return text
    return "(%s)" % text
def _generate_generic_unary_operator(self, unary, opstring, **kw):
    """Render ``<opstring><element>`` (prefix form)."""
    element_sql = unary.element._compiler_dispatch(self, **kw)
    return "".join((opstring, element_sql))
def _generate_generic_unary_modifier(self, unary, opstring, **kw):
    """Render ``<element><opstring>`` (postfix form)."""
    element_sql = unary.element._compiler_dispatch(self, **kw)
    return "".join((element_sql, opstring))
@util.memoized_property
def _like_percent_literal(self):
    """A memoized ``'%'`` string literal_column used to build LIKE patterns."""
    wildcard_sql = "'%'"
    return elements.literal_column(wildcard_sql, type_=sqltypes.STRINGTYPE)
def visit_contains_op_binary(self, binary, operator, **kw):
    """Render contains() as ``LIKE '%' || <right> || '%'`` on a clone.

    The input binary is not modified.
    """
    cloned = binary._clone()
    wildcard = self._like_percent_literal
    cloned.right = wildcard.__add__(cloned.right).__add__(wildcard)
    return self.visit_like_op_binary(cloned, operator, **kw)
def visit_notcontains_op_binary(self, binary, operator, **kw):
    """Render ~contains() as ``NOT LIKE '%' || <right> || '%'`` on a clone.

    The input binary is not modified.
    """
    cloned = binary._clone()
    wildcard = self._like_percent_literal
    cloned.right = wildcard.__add__(cloned.right).__add__(wildcard)
    return self.visit_notlike_op_binary(cloned, operator, **kw)
def visit_startswith_op_binary(self, binary, operator, **kw):
    """Render startswith() as ``LIKE <right> || '%'`` on a clone.

    The concatenation is dispatched through the ``'%'`` literal's
    ``__radd__``, so the literal's type drives the operator.  The input
    binary is not modified.
    """
    cloned = binary._clone()
    wildcard = self._like_percent_literal
    cloned.right = wildcard.__radd__(cloned.right)
    return self.visit_like_op_binary(cloned, operator, **kw)
def visit_notstartswith_op_binary(self, binary, operator, **kw):
    """Render ~startswith() as ``NOT LIKE <right> || '%'`` on a clone.

    The concatenation is dispatched through the ``'%'`` literal's
    ``__radd__``.  The input binary is not modified.
    """
    cloned = binary._clone()
    wildcard = self._like_percent_literal
    cloned.right = wildcard.__radd__(cloned.right)
    return self.visit_notlike_op_binary(cloned, operator, **kw)
def visit_endswith_op_binary(self, binary, operator, **kw):
    """Render endswith() as ``LIKE '%' || <right>`` on a clone.

    The input binary is not modified.
    """
    cloned = binary._clone()
    wildcard = self._like_percent_literal
    cloned.right = wildcard.__add__(cloned.right)
    return self.visit_like_op_binary(cloned, operator, **kw)
def visit_notendswith_op_binary(self, binary, operator, **kw):
    """Render ~endswith() as ``NOT LIKE '%' || <right>`` on a clone.

    The input binary is not modified.
    """
    cloned = binary._clone()
    wildcard = self._like_percent_literal
    cloned.right = wildcard.__add__(cloned.right)
    return self.visit_notlike_op_binary(cloned, operator, **kw)
def visit_like_op_binary(self, binary, operator, **kw):
    """Render ``<left> LIKE <right>``.

    When the binary has an ``escape`` modifier, `` ESCAPE <literal>`` is
    appended, with the literal rendered via render_literal_value().
    """
    escape = binary.modifiers.get("escape", None)
    left_sql = binary.left._compiler_dispatch(self, **kw)
    right_sql = binary.right._compiler_dispatch(self, **kw)
    text = '%s LIKE %s' % (left_sql, right_sql)
    if escape:
        text += ' ESCAPE ' + \
            self.render_literal_value(escape, sqltypes.STRINGTYPE)
    return text
def visit_notlike_op_binary(self, binary, operator, **kw):
    """Render ``<left> NOT LIKE <right>``, plus `` ESCAPE <literal>`` when
    the binary has an ``escape`` modifier."""
    escape = binary.modifiers.get("escape", None)
    left_sql = binary.left._compiler_dispatch(self, **kw)
    right_sql = binary.right._compiler_dispatch(self, **kw)
    text = '%s NOT LIKE %s' % (left_sql, right_sql)
    if escape:
        text += ' ESCAPE ' + \
            self.render_literal_value(escape, sqltypes.STRINGTYPE)
    return text
def visit_ilike_op_binary(self, binary, operator, **kw):
    """Render case-insensitive LIKE as ``lower(<left>) LIKE lower(<right>)``,
    plus `` ESCAPE <literal>`` when an ``escape`` modifier is set."""
    escape = binary.modifiers.get("escape", None)
    left_sql = binary.left._compiler_dispatch(self, **kw)
    right_sql = binary.right._compiler_dispatch(self, **kw)
    text = 'lower(%s) LIKE lower(%s)' % (left_sql, right_sql)
    if escape:
        text += ' ESCAPE ' + \
            self.render_literal_value(escape, sqltypes.STRINGTYPE)
    return text
def visit_notilike_op_binary(self, binary, operator, **kw):
    """Render ``lower(<left>) NOT LIKE lower(<right>)``, plus
    `` ESCAPE <literal>`` when an ``escape`` modifier is set."""
    escape = binary.modifiers.get("escape", None)
    left_sql = binary.left._compiler_dispatch(self, **kw)
    right_sql = binary.right._compiler_dispatch(self, **kw)
    text = 'lower(%s) NOT LIKE lower(%s)' % (left_sql, right_sql)
    if escape:
        text += ' ESCAPE ' + \
            self.render_literal_value(escape, sqltypes.STRINGTYPE)
    return text
def visit_between_op_binary(self, binary, operator, **kw):
    """Render BETWEEN, or BETWEEN SYMMETRIC when the ``symmetric``
    modifier is set."""
    if binary.modifiers.get("symmetric", False):
        opstring = " BETWEEN SYMMETRIC "
    else:
        opstring = " BETWEEN "
    return self._generate_generic_binary(binary, opstring, **kw)
def visit_notbetween_op_binary(self, binary, operator, **kw):
    """Render NOT BETWEEN, or NOT BETWEEN SYMMETRIC when the ``symmetric``
    modifier is set."""
    if binary.modifiers.get("symmetric", False):
        opstring = " NOT BETWEEN SYMMETRIC "
    else:
        opstring = " NOT BETWEEN "
    return self._generate_generic_binary(binary, opstring, **kw)
def visit_bindparam(self, bindparam, within_columns_clause=False,
                    literal_binds=False,
                    skip_bind_expression=False,
                    **kwargs):
    """Render a bound parameter as a placeholder or as an inline literal.

    * If the parameter's type defines a bind expression and
      ``skip_bind_expression`` is not set, that expression is compiled
      instead.
    * If ``literal_binds`` is set, or the parameter is in the columns
      clause and the compiler has ``ansi_bind_rules``, the value is
      rendered inline by :meth:`render_literal_bindparam`.
    * Otherwise the (possibly truncated) name is registered in
      ``self.binds`` and the placeholder from :meth:`bindparam_string`
      is returned.

    :raises CompileError: when an inline rendering is required but the
     parameter has neither a value nor a callable; when a unique
     parameter's name collides with a different parameter whose
     ``proxy_set`` does not overlap; or when the name collides with a
     parameter flagged ``_is_crud`` (generated for INSERT/UPDATE).
    """
    if not skip_bind_expression and bindparam.type._has_bind_expression:
        bind_expression = bindparam.type.bind_expression(bindparam)
        # skip_bind_expression=True makes any re-entry into this method
        # while compiling the wrapped expression bypass this branch.
        return self.process(bind_expression,
                            skip_bind_expression=True)
    if literal_binds or \
        (within_columns_clause and
            self.ansi_bind_rules):
        if bindparam.value is None and bindparam.callable is None:
            raise exc.CompileError("Bind parameter '%s' without a "
                                   "renderable value not allowed here."
                                   % bindparam.key)
        return self.render_literal_bindparam(
            bindparam, within_columns_clause=True, **kwargs)
    name = self._truncate_bindparam(bindparam)
    if name in self.binds:
        existing = self.binds[name]
        if existing is not bindparam:
            # Two distinct bindparam objects resolved to the same name.
            if (existing.unique or bindparam.unique) and \
                    not existing.proxy_set.intersection(
                        bindparam.proxy_set):
                raise exc.CompileError(
                    "Bind parameter '%s' conflicts with "
                    "unique bind parameter of the same name" %
                    bindparam.key
                )
            elif existing._is_crud or bindparam._is_crud:
                raise exc.CompileError(
                    "bindparam() name '%s' is reserved "
                    "for automatic usage in the VALUES or SET "
                    "clause of this "
                    "insert/update statement. Please use a "
                    "name other than column name when using bindparam() "
                    "with insert() or update() (for example, 'b_%s')." %
                    (bindparam.key, bindparam.key)
                )
    # Register under both the original key and the rendered name.
    self.binds[bindparam.key] = self.binds[name] = bindparam
    return self.bindparam_string(name, **kwargs)
def render_literal_bindparam(self, bindparam, **kw):
    """Render ``bindparam.effective_value`` inline as a SQL literal of
    the parameter's type.  Extra keyword arguments are ignored."""
    return self.render_literal_value(
        bindparam.effective_value, bindparam.type)
def render_literal_value(self, value, type_):
    """Render the value of a bind parameter as a quoted literal.

    This is used for statement sections that do not accept bind parameters
    on the target driver/database.

    This should be implemented by subclasses using the quoting services
    of the DBAPI.

    :param value: the Python value to render.
    :param type_: the SQL type whose literal processor (looked up for this
     compiler's dialect) renders ``value``.
    :returns: the processor's output for ``value``.
    :raises NotImplementedError: if the type has no literal processor for
     this dialect.
    """
    processor = type_._cached_literal_processor(self.dialect)
    if processor:
        return processor(value)
    else:
        # Wrap value in a 1-tuple: a bare tuple value would otherwise be
        # consumed as the %-format argument list, raising TypeError
        # instead of the intended NotImplementedError.
        raise NotImplementedError(
            "Don't know how to literal-quote value %r" % (value,))
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names:
return self.bind_names[bindparam]
bind_name = bindparam.key
if isinstance(bind_name, elements._truncated_label):
bind_name = self._truncated_identifier("bindparam", bind_name)
# add to bind_names for translation
self.bind_names[bindparam] = bind_name
return bind_name
def _truncated_identifier(self, ident_class, name):
if (ident_class, name) in self.truncated_names:
return self.truncated_names[(ident_class, name)]
anonname = name.apply_map(self.anon_map)
if len(anonname) > self.label_length - 6:
counter = self.truncated_names.get(ident_class, 1)
truncname = anonname[0:max(self.label_length - 6, 0)] + \
"_" + hex(counter)[2:]
self.truncated_names[ident_class] = counter + 1
else:
truncname = anonname
self.truncated_names[(ident_class, name)] = truncname
return truncname
def _anonymize(self, name):
return name % self.anon_map
def _process_anon(self, key):
(ident, derived) = key.split(' ', 1)
anonymous_counter = self.anon_map.get(derived, 1)
self.anon_map[derived] = anonymous_counter + 1
return derived + "_" + str(anonymous_counter)
def bindparam_string(self, name, positional_names=None, **kw):
    """Return the placeholder text for bind parameter ``name``.

    For positional paramstyles the name is also recorded, in order: in
    ``positional_names`` when one is supplied, otherwise in
    ``self.positiontup``.
    """
    if self.positional:
        target = self.positiontup if positional_names is None \
            else positional_names
        target.append(name)
    return self.bindtemplate % {'name': name}
def visit_cte(self, cte, asfrom=False, ashint=False,
              fromhints=None,
              **kwargs):
    """Compile a CTE and register its definition for the WITH clause.

    The ``<name>[(cols)] AS <select>`` definition text is stored in
    ``self.ctes`` (keyed by the CTE), to be emitted later by
    :meth:`_render_cte_clause`.  When ``asfrom`` is set, only a reference
    to the CTE is returned: its quoted name, or ``<alias> AS <name>`` for
    an aliased CTE.  When ``asfrom`` is false, nothing is returned.
    ``ashint`` and ``fromhints`` are accepted but not used here.

    :raises CompileError: if two unrelated CTEs share the same name.
    """
    # NOTE(review): presumably sets up ctes / ctes_by_name /
    # cte_positional if not already present; defined elsewhere.
    self._init_cte_state()
    if isinstance(cte.name, elements._truncated_label):
        cte_name = self._truncated_identifier("alias", cte.name)
    else:
        cte_name = cte.name
    if cte_name in self.ctes_by_name:
        existing_cte = self.ctes_by_name[cte_name]
        # we've generated a same-named CTE that we are enclosed in,
        # or this is the same CTE. just return the name.
        if cte in existing_cte._restates or cte is existing_cte:
            return self.preparer.format_alias(cte, cte_name)
        elif existing_cte in cte._restates:
            # we've generated a same-named CTE that is
            # enclosed in us - we take precedence, so
            # discard the text for the "inner".
            del self.ctes[existing_cte]
        else:
            raise exc.CompileError(
                "Multiple, unrelated CTEs found with "
                "the same name: %r" %
                cte_name)
    self.ctes_by_name[cte_name] = cte
    # look for embedded DML ctes and propagate autocommit
    if 'autocommit' in cte.element._execution_options and \
            'autocommit' not in self.execution_options:
        self.execution_options = self.execution_options.union(
            {"autocommit": cte.element._execution_options['autocommit']})
    if cte._cte_alias is not None:
        # An alias of another CTE: make sure the original's definition
        # is registered, then render only the alias name.
        orig_cte = cte._cte_alias
        if orig_cte not in self.ctes:
            self.visit_cte(orig_cte, **kwargs)
        cte_alias_name = cte._cte_alias.name
        if isinstance(cte_alias_name, elements._truncated_label):
            cte_alias_name = self._truncated_identifier(
                "alias", cte_alias_name)
    else:
        orig_cte = cte
        cte_alias_name = None
    if not cte_alias_name and cte not in self.ctes:
        # First encounter of a non-aliased CTE: build its definition.
        if cte.recursive:
            self.ctes_recursive = True
        text = self.preparer.format_alias(cte, cte_name)
        if cte.recursive:
            # Recursive CTEs list their column names explicitly, taken
            # from the SELECT (or the first SELECT of a compound).
            if isinstance(cte.original, selectable.Select):
                col_source = cte.original
            elif isinstance(cte.original, selectable.CompoundSelect):
                col_source = cte.original.selects[0]
            else:
                assert False
            recur_cols = [c for c in
                          util.unique_list(col_source.inner_columns)
                          if c is not None]
            text += "(%s)" % (", ".join(
                self.preparer.format_column(ident)
                for ident in recur_cols))
        if self.positional:
            # Collect this CTE's positional bind names separately so that
            # _render_cte_clause can move them ahead of the body's.
            kwargs['positional_names'] = self.cte_positional[cte] = []
        text += " AS \n" + \
            cte.original._compiler_dispatch(
                self, asfrom=True, **kwargs
            )
        if cte._suffixes:
            text += " " + self._generate_prefixes(
                cte, cte._suffixes, **kwargs)
        self.ctes[cte] = text
    if asfrom:
        if cte_alias_name:
            text = self.preparer.format_alias(cte, cte_alias_name)
            text += self.get_render_as_alias_suffix(cte_name)
        else:
            return self.preparer.format_alias(cte, cte_name)
        return text
def visit_alias(self, alias, asfrom=False, ashint=False,
                iscrud=False,
                fromhints=None, **kwargs):
    """Render an alias.

    * ``ashint``: just the quoted alias name.
    * ``asfrom``: ``<original> AS <alias name>``, with any FROM hint for
      this alias appended.
    * otherwise: the aliased element rendered on its own.

    Truncated-label names are shortened under the ``"alias"``
    identifier class.
    """
    if not (asfrom or ashint):
        return alias.original._compiler_dispatch(self, **kwargs)
    alias_name = alias.name
    if isinstance(alias_name, elements._truncated_label):
        alias_name = self._truncated_identifier("alias", alias_name)
    if ashint:
        return self.preparer.format_alias(alias, alias_name)
    rendered = alias.original._compiler_dispatch(
        self, asfrom=True, **kwargs) + \
        self.get_render_as_alias_suffix(
            self.preparer.format_alias(alias, alias_name))
    if fromhints and alias in fromhints:
        rendered = self.format_from_hint_text(
            rendered, alias, fromhints[alias], iscrud)
    return rendered
def visit_lateral(self, lateral, **kw):
    """Render ``LATERAL`` followed by :meth:`visit_alias` output.

    ``lateral=True`` is added to the keyword arguments passed on.
    """
    kw['lateral'] = True
    aliased = self.visit_alias(lateral, **kw)
    return "LATERAL %s" % (aliased,)
def visit_tablesample(self, tablesample, asfrom=False, **kw):
    """Render ``<alias> TABLESAMPLE <method>`` and, when a seed is set,
    `` REPEATABLE (<seed>)``.

    The sample is always rendered as a FROM element; the incoming
    ``asfrom`` value is accepted but not used.
    """
    aliased = self.visit_alias(tablesample, asfrom=True, **kw)
    method_sql = tablesample._get_method()._compiler_dispatch(self, **kw)
    text = "%s TABLESAMPLE %s" % (aliased, method_sql)
    if tablesample.seed is None:
        return text
    return text + " REPEATABLE (%s)" % (
        tablesample.seed._compiler_dispatch(self, **kw))
def get_render_as_alias_suffix(self, alias_name_text):
    """Return the text that attaches an alias name to a FROM element.

    This is a separate method so that it can be overridden.
    """
    suffix = " AS " + alias_name_text
    return suffix
def _add_to_result_map(self, keyname, name, objects, type_):
self._result_columns.append((keyname, name, objects, type_))
def _label_select_column(self, select, column,
                         populate_result_map,
                         asfrom, column_clause_args,
                         name=None,
                         within_columns_clause=True):
    """produce labeled columns present in a select().

    Decides whether ``column`` is wrapped in a ``_CompileLabel`` and then
    compiles it.  If the column's type has a column expression and a
    result map is being populated, that expression is compiled instead,
    and result-map entries are prefixed with the original ``column``.

    :param select: the enclosing SELECT, or None (StrSQLCompiler's
     RETURNING rendering passes None).
    :param column: the column expression to render.
    :param populate_result_map: whether result-map entries are recorded.
    :param asfrom: whether the enclosing SELECT is rendered as a FROM
     element; plain table columns are then labeled with their own name.
    :param column_clause_args: keyword dict used to compile the column;
     it is updated in place with ``within_columns_clause`` and
     ``add_to_result_map``.
    :param name: label name chosen by the caller, if any.
    :param within_columns_clause: when false, no label is applied.
    :returns: the compiled text of the (possibly labeled) column.
    """
    if column.type._has_column_expression and \
            populate_result_map:
        col_expr = column.type.column_expression(column)
        # Result-map entries for the wrapped expression also record the
        # original column object.
        add_to_result_map = lambda keyname, name, objects, type_: \
            self._add_to_result_map(
                keyname, name,
                (column,) + objects, type_)
    else:
        col_expr = column
        if populate_result_map:
            add_to_result_map = self._add_to_result_map
        else:
            add_to_result_map = None
    if not within_columns_clause:
        result_expr = col_expr
    elif isinstance(column, elements.Label):
        # An explicit label keeps its name; only re-label when the type
        # substituted a different column expression.
        if col_expr is not column:
            result_expr = _CompileLabel(
                col_expr,
                column.name,
                alt_names=(column.element,)
            )
        else:
            result_expr = col_expr
    elif select is not None and name:
        # The caller supplied a label name.
        result_expr = _CompileLabel(
            col_expr,
            name,
            alt_names=(column._key_label,)
        )
    elif \
        asfrom and \
        isinstance(column, elements.ColumnClause) and \
        not column.is_literal and \
        column.table is not None and \
        not isinstance(column, selectable.Select):
        # A plain table column inside a FROM-clause subquery is labeled
        # with its own name.
        result_expr = _CompileLabel(col_expr,
                                    elements._as_truncated(column.name),
                                    alt_names=(column.key,))
    elif (
        not isinstance(column, elements.TextClause) and
        (
            not isinstance(column, elements.UnaryExpression) or
            column.wraps_column_expression
        ) and
        (
            not hasattr(column, 'name') or
            isinstance(column, functions.Function)
        )
    ):
        # Unnamed expressions (and functions) get an anonymous label.
        result_expr = _CompileLabel(col_expr, column.anon_label)
    elif col_expr is not column:
        # TODO: are we sure "column" has a .name and .key here ?
        # assert isinstance(column, elements.ColumnClause)
        result_expr = _CompileLabel(col_expr,
                                    elements._as_truncated(column.name),
                                    alt_names=(column.key,))
    else:
        result_expr = col_expr
    column_clause_args.update(
        within_columns_clause=within_columns_clause,
        add_to_result_map=add_to_result_map
    )
    return result_expr._compiler_dispatch(
        self,
        **column_clause_args
    )
def format_from_hint_text(self, sqltext, table, hint, iscrud):
    """Append the dialect's FROM-level hint for ``table`` to ``sqltext``.

    Nothing is appended when :meth:`get_from_hint_text` returns a falsy
    value.  ``iscrud`` is accepted but not used here.
    """
    hinttext = self.get_from_hint_text(table, hint)
    if not hinttext:
        return sqltext
    return sqltext + (" " + hinttext)
def get_select_hint_text(self, byfroms):
    """Hook for SELECT-level hint text built from the per-FROM hints in
    ``byfroms``.

    The base compiler renders none; None means no hint text is added.
    """
    return None
def get_from_hint_text(self, table, text):
    """Hook for the hint rendered after a FROM element.

    The base compiler renders none.
    """
    return None
def get_crud_hint_text(self, table, text):
    """Hook for hints on an INSERT/UPDATE/DELETE target table.

    The base compiler renders none.
    """
    return None
def get_statement_hint_text(self, hint_texts):
    """Join statement-level hint strings with single spaces."""
    separator = " "
    return separator.join(hint_texts)
def _transform_select_for_nested_joins(self, select):
    """Rewrite any "a JOIN (b JOIN c)" expression as
    "a JOIN (select * from b JOIN c) AS anon", to support
    databases that can't parse a parenthesized join correctly
    (i.e. sqlite < 3.7.16).

    Returns a cloned copy of ``select``; the original is not modified.
    Column references that pointed into a rewritten right-nested join are
    redirected to the corresponding columns of the new anonymous
    subquery.
    """
    # element -> its clone, so each element is cloned at most once.
    cloned = {}
    # Stack of translation maps (old column/table -> replacement).  A new
    # map is pushed when entering a SELECT nested inside an Alias /
    # CompoundSelect / ScalarSelect, and popped when leaving it.
    column_translate = [{}]

    def visit(element, **kw):
        if element in column_translate[-1]:
            return column_translate[-1][element]
        elif element in cloned:
            return cloned[element]
        newelem = cloned[element] = element._clone()
        if newelem.is_selectable and newelem._is_join and \
                isinstance(newelem.right, selectable.FromGrouping):
            # A join whose right side is a parenthesized element: wrap
            # that right side in SELECT ... (use_labels) AS anon.
            newelem._reset_exported()
            newelem.left = visit(newelem.left, **kw)
            right = visit(newelem.right, **kw)
            selectable_ = selectable.Select(
                [right.element],
                use_labels=True).alias()
            for c in selectable_.c:
                c._key_label = c.key
                c._label = c.name
            translate_dict = dict(
                zip(newelem.right.element.c, selectable_.c)
            )
            # translating from both the old and the new
            # because different select() structures will lead us
            # to traverse differently
            translate_dict[right.element.left] = selectable_
            translate_dict[right.element.right] = selectable_
            translate_dict[newelem.right.element.left] = selectable_
            translate_dict[newelem.right.element.right] = selectable_
            # propagate translations that we've gained
            # from nested visit(newelem.right) outwards
            # to the enclosing select here. this happens
            # only when we have more than one level of right
            # join nesting, i.e. "a JOIN (b JOIN (c JOIN d))"
            for k, v in list(column_translate[-1].items()):
                if v in translate_dict:
                    # remarkably, no current ORM tests (May 2013)
                    # hit this condition, only test_join_rewriting
                    # does.
                    column_translate[-1][k] = translate_dict[v]
            column_translate[-1].update(translate_dict)
            newelem.right = selectable_
            newelem.onclause = visit(newelem.onclause, **kw)
        elif newelem._is_from_container:
            # if we hit an Alias, CompoundSelect or ScalarSelect, put a
            # marker in the stack.
            kw['transform_clue'] = 'select_container'
            newelem._copy_internals(clone=visit, **kw)
        elif newelem.is_selectable and newelem._is_select:
            barrier_select = kw.get('transform_clue', None) == \
                'select_container'
            # if we're still descended from an
            # Alias/CompoundSelect/ScalarSelect, we're
            # in a FROM clause, so start with a new translate collection
            if barrier_select:
                column_translate.append({})
            kw['transform_clue'] = 'inside_select'
            newelem._copy_internals(clone=visit, **kw)
            if barrier_select:
                del column_translate[-1]
        else:
            newelem._copy_internals(clone=visit, **kw)
        return newelem
    return visit(select)
def _transform_result_map_for_nested_joins(
self, select, transformed_select):
inner_col = dict((c._key_label, c) for
c in transformed_select.inner_columns)
d = dict(
(inner_col[c._key_label], c)
for c in select.inner_columns
)
self._result_columns = [
(key, name, tuple([d.get(col, col) for col in objs]), typ)
for key, name, objs, typ in self._result_columns
]
# Stack entry used when compiling at the top level (empty compiler
# stack): both FROM collections are empty, so nothing correlates.
_default_stack_entry = util.immutabledict(
    dict(correlate_froms=frozenset(), asfrom_froms=frozenset()))
def _display_froms_for_select(self, select, asfrom, lateral=False):
# utility method to help external dialects
# get the correct from list for a select.
# specifically the oracle dialect needs this feature
# right now.
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
correlate_froms = entry['correlate_froms']
asfrom_froms = entry['asfrom_froms']
if asfrom and not lateral:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms.difference(
asfrom_froms),
implicit_correlate_froms=())
else:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms,
implicit_correlate_froms=asfrom_froms)
return froms
def visit_select(self, select, asfrom=False, parens=True,
                 fromhints=None,
                 compound_index=0,
                 nested_join_translation=False,
                 select_wraps_for=None,
                 lateral=False,
                 **kwargs):
    """Compile a SELECT statement.

    For a top-level, ``use_labels`` SELECT on a dialect without
    ``supports_right_nested_joins``, a rewritten copy (see
    :meth:`_transform_select_for_nested_joins`) is compiled instead, and
    the result map is remapped to the original columns.

    Otherwise a stack entry is pushed for the duration of the
    compilation; hints, prefixes, precolumns, the labeled column list,
    the body clauses, statement hints, a CTE preamble (top level only)
    and suffixes are assembled in that order.  The result is wrapped in
    parentheses when rendered as a FROM element (or LATERAL) and
    ``parens`` is true.

    :param compound_index: position within an enclosing compound SELECT;
     only index 0 may populate the result map for a compound.
    :param select_wraps_for: when this SELECT is a compiler-generated
     wrapper around another, result-map entries are redirected to the
     wrapped SELECT's columns.
    """
    needs_nested_translation = \
        select.use_labels and \
        not nested_join_translation and \
        not self.stack and \
        not self.dialect.supports_right_nested_joins
    if needs_nested_translation:
        transformed_select = self._transform_select_for_nested_joins(
            select)
        text = self.visit_select(
            transformed_select, asfrom=asfrom, parens=parens,
            fromhints=fromhints,
            compound_index=compound_index,
            nested_join_translation=True, **kwargs
        )
    toplevel = not self.stack
    entry = self._default_stack_entry if toplevel else self.stack[-1]
    # Result-map entries are recorded for the top-level SELECT, for the
    # first SELECT of a compound whose entry requests it, or for a nested
    # SELECT whose entry requests it.
    populate_result_map = toplevel or \
        (
            compound_index == 0 and entry.get(
                'need_result_map_for_compound', False)
        ) or entry.get('need_result_map_for_nested', False)
    # this was first proposed as part of #3372; however, it is not
    # reached in current tests and could possibly be an assertion
    # instead.
    if not populate_result_map and 'add_to_result_map' in kwargs:
        del kwargs['add_to_result_map']
    if needs_nested_translation:
        if populate_result_map:
            self._transform_result_map_for_nested_joins(
                select, transformed_select)
        return text
    # Pushes a stack entry; popped near the end of this method.
    froms = self._setup_select_stack(select, entry, asfrom, lateral)
    column_clause_args = kwargs.copy()
    column_clause_args.update({
        'within_label_clause': False,
        'within_columns_clause': False
    })
    text = "SELECT "  # we're off to a good start !
    if select._hints:
        hint_text, byfrom = self._setup_select_hints(select)
        if hint_text:
            text += hint_text + " "
    else:
        byfrom = None
    if select._prefixes:
        text += self._generate_prefixes(
            select, select._prefixes, **kwargs)
    text += self.get_select_precolumns(select, **kwargs)
    # the actual list of columns to print in the SELECT column list.
    inner_columns = [
        c for c in [
            self._label_select_column(
                select,
                column,
                populate_result_map, asfrom,
                column_clause_args,
                name=name)
            for name, column in select._columns_plus_names
        ]
        if c is not None
    ]
    if populate_result_map and select_wraps_for is not None:
        # if this select is a compiler-generated wrapper,
        # rewrite the targeted columns in the result map
        # NOTE: _columns_plus_names yields (name, column) pairs, so the
        # loop variable called ``name`` below is the column object; this
        # maps each wrapper column to the wrapped select's column.
        translate = dict(
            zip(
                [name for (key, name) in select._columns_plus_names],
                [name for (key, name) in
                 select_wraps_for._columns_plus_names])
        )
        self._result_columns = [
            (key, name, tuple(translate.get(o, o) for o in obj), type_)
            for key, name, obj, type_ in self._result_columns
        ]
    text = self._compose_select_body(
        text, select, inner_columns, froms, byfrom, kwargs)
    if select._statement_hints:
        per_dialect = [
            ht for (dialect_name, ht)
            in select._statement_hints
            if dialect_name in ('*', self.dialect.name)
        ]
        if per_dialect:
            text += " " + self.get_statement_hint_text(per_dialect)
    if self.ctes and toplevel:
        text = self._render_cte_clause() + text
    if select._suffixes:
        text += " " + self._generate_prefixes(
            select, select._suffixes, **kwargs)
    self.stack.pop(-1)
    if (asfrom or lateral) and parens:
        return "(" + text + ")"
    else:
        return text
def _setup_select_hints(self, select):
byfrom = dict([
(from_, hinttext % {
'name': from_._compiler_dispatch(
self, ashint=True)
})
for (from_, dialect), hinttext in
select._hints.items()
if dialect in ('*', self.dialect.name)
])
hint_text = self.get_select_hint_text(byfrom)
return hint_text, byfrom
def _setup_select_stack(self, select, entry, asfrom, lateral):
    """Compute ``select``'s FROM list and push a new compiler stack entry.

    The new entry records this SELECT's FROM objects as ``asfrom_froms``,
    and their union with the enclosing entry's ``correlate_froms`` as
    ``correlate_froms``; nested compilations read it as their enclosing
    context.  Returns the FROM list.
    """
    outer_correlate = entry['correlate_froms']
    outer_asfrom = entry['asfrom_froms']
    if asfrom and not lateral:
        explicit = outer_correlate.difference(outer_asfrom)
        implicit = ()
    else:
        explicit = outer_correlate
        implicit = outer_asfrom
    froms = select._get_display_froms(
        explicit_correlate_froms=explicit,
        implicit_correlate_froms=implicit)
    local_froms = set(selectable._from_objects(*froms))
    self.stack.append({
        'asfrom_froms': local_froms,
        'correlate_froms': local_froms.union(outer_correlate),
        'selectable': select,
    })
    return froms
def _compose_select_body(
self, text, select, inner_columns, froms, byfrom, kwargs):
text += ', '.join(inner_columns)
if froms:
text += " \nFROM "
if select._hints:
text += ', '.join(
[f._compiler_dispatch(self, asfrom=True,
fromhints=byfrom, **kwargs)
for f in froms])
else:
text += ', '.join(
[f._compiler_dispatch(self, asfrom=True, **kwargs)
for f in froms])
else:
text += self.default_from()
if select._whereclause is not None:
t = select._whereclause._compiler_dispatch(self, **kwargs)
if t:
text += " \nWHERE " + t
if select._group_by_clause.clauses:
group_by = select._group_by_clause._compiler_dispatch(
self, **kwargs)
if group_by:
text += " GROUP BY " + group_by
if select._having is not None:
t = select._having._compiler_dispatch(self, **kwargs)
if t:
text += " \nHAVING " + t
if select._order_by_clause.clauses:
text += self.order_by_clause(select, **kwargs)
if (select._limit_clause is not None or
select._offset_clause is not None):
text += self.limit_clause(select, **kwargs)
if select._for_update_arg is not None:
text += self.for_update_clause(select, **kwargs)
return text
def _generate_prefixes(self, stmt, prefixes, **kw):
clause = " ".join(
prefix._compiler_dispatch(self, **kw)
for prefix, dialect_name in prefixes
if dialect_name is None or
dialect_name == self.dialect.name
)
if clause:
clause += " "
return clause
def _render_cte_clause(self):
if self.positional:
self.positiontup = sum([
self.cte_positional[cte]
for cte in self.ctes], []) + \
self.positiontup
cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
cte_text += ", \n".join(
[txt for txt in self.ctes.values()]
)
cte_text += "\n "
return cte_text
def get_cte_preamble(self, recursive):
    """Return the keyword(s) opening a WITH clause: ``WITH RECURSIVE``
    when ``recursive`` is truthy, else ``WITH``."""
    return "WITH RECURSIVE" if recursive else "WITH"
def get_select_precolumns(self, select, **kw):
    """Called when building a ``SELECT`` statement, position is just
    before column list.

    Returns ``"DISTINCT "`` when ``select._distinct`` is truthy, otherwise
    an empty string.
    """
    return "DISTINCT " if select._distinct else ""
def order_by_clause(self, select, **kw):
    """Render `` ORDER BY <criteria>`` for ``select``, or an empty string
    when the ORDER BY clause compiles to nothing."""
    criteria = select._order_by_clause._compiler_dispatch(self, **kw)
    return " ORDER BY " + criteria if criteria else ""
def for_update_clause(self, select, **kw):
    """Render the locking clause; the base compiler always emits a plain
    `` FOR UPDATE`` regardless of ``select``'s options."""
    clause = " FOR UPDATE"
    return clause
def returning_clause(self, stmt, returning_cols):
    """Hook for RETURNING support, which the base compiler lacks.

    :raises CompileError: always.
    """
    message = ("RETURNING is not supported by this "
               "dialect's statement compiler.")
    raise exc.CompileError(message)
def limit_clause(self, select, **kw):
    """Render LIMIT / OFFSET for ``select``.

    When only an offset is present, a placeholder ``LIMIT -1`` is emitted
    before the OFFSET.  Returns an empty string when neither is set.
    """
    limit = select._limit_clause
    offset = select._offset_clause
    parts = []
    if limit is not None:
        parts.append("\n LIMIT " + self.process(limit, **kw))
    if offset is not None:
        if limit is None:
            parts.append("\n LIMIT -1")
        parts.append(" OFFSET " + self.process(offset, **kw))
    return "".join(parts)
def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
                fromhints=None, use_schema=True, **kwargs):
    """Render a table name for a FROM clause or hint.

    Outside of FROM/hint context an empty string is returned.  The name
    is schema-qualified when ``use_schema`` is set and the preparer
    reports an effective schema.  A FROM hint is appended when
    ``fromhints`` contains one for this table.
    """
    if not (asfrom or ashint):
        return ""
    preparer = self.preparer
    effective_schema = preparer.schema_for_object(table)
    if use_schema and effective_schema:
        rendered = preparer.quote_schema(effective_schema) + \
            "." + preparer.quote(table.name)
    else:
        rendered = preparer.quote(table.name)
    if fromhints and table in fromhints:
        rendered = self.format_from_hint_text(
            rendered, table, fromhints[table], iscrud)
    return rendered
def visit_join(self, join, asfrom=False, **kwargs):
    """Render ``<left> [FULL OUTER | LEFT OUTER] JOIN <right> ON <cond>``.

    Both sides are rendered as FROM elements.  ``full`` takes precedence
    over ``isouter``.
    """
    if join.full:
        keyword = " FULL OUTER JOIN "
    elif join.isouter:
        keyword = " LEFT OUTER JOIN "
    else:
        keyword = " JOIN "
    left = join.left._compiler_dispatch(self, asfrom=True, **kwargs)
    right = join.right._compiler_dispatch(self, asfrom=True, **kwargs)
    onclause = join.onclause._compiler_dispatch(self, **kwargs)
    return left + keyword + right + " ON " + onclause
def _setup_crud_hints(self, stmt, table_text):
dialect_hints = dict([
(table, hint_text)
for (table, dialect), hint_text in
stmt._hints.items()
if dialect in ('*', self.dialect.name)
])
if stmt.table in dialect_hints:
table_text = self.format_from_hint_text(
table_text,
stmt.table,
dialect_hints[stmt.table],
True
)
return dialect_hints, table_text
def visit_insert(self, insert_stmt, asfrom=False, **kw):
    """Compile an INSERT statement.

    Pushes a stack entry for the statement and computes the column/value
    pairs via ``crud._setup_crud_params``.  It then renders
    ``INSERT [prefixes] INTO <table> [(<cols>)]`` followed by one of: the
    compiled INSERT..SELECT, ``DEFAULT VALUES``, a multi-row ``VALUES``
    list, or a single ``VALUES (...)``.  After that come an optional
    post-values clause and RETURNING, which is placed before the values
    instead when ``returning_precedes_values`` is set.  At the top level,
    collected CTEs are prefixed.  The result is parenthesized when
    ``asfrom`` is set.

    :raises CompileError: when there are no parameters and the dialect
     supports neither DEFAULT VALUES nor empty inserts, or for a
     multi-row insert on a dialect without multi-values support.
    """
    toplevel = not self.stack
    self.stack.append(
        {'correlate_froms': set(),
         "asfrom_froms": set(),
         "selectable": insert_stmt})
    crud_params = crud._setup_crud_params(
        self, insert_stmt, crud.ISINSERT, **kw)
    if not crud_params and \
            not self.dialect.supports_default_values and \
            not self.dialect.supports_empty_insert:
        raise exc.CompileError("The '%s' dialect with current database "
                               "version settings does not support empty "
                               "inserts." %
                               self.dialect.name)
    if insert_stmt._has_multi_parameters:
        if not self.dialect.supports_multivalues_insert:
            raise exc.CompileError(
                "The '%s' dialect with current database "
                "version settings does not support "
                "in-place multirow inserts." %
                self.dialect.name)
        # For multi-row inserts crud_params is a list of per-row param
        # lists; the first row determines the column list.
        crud_params_single = crud_params[0]
    else:
        crud_params_single = crud_params
    preparer = self.preparer
    supports_default_values = self.dialect.supports_default_values
    text = "INSERT "
    if insert_stmt._prefixes:
        text += self._generate_prefixes(insert_stmt,
                                        insert_stmt._prefixes, **kw)
    text += "INTO "
    table_text = preparer.format_table(insert_stmt.table)
    if insert_stmt._hints:
        # NOTE: dialect_hints is not used further in this method.
        dialect_hints, table_text = self._setup_crud_hints(
            insert_stmt, table_text)
    else:
        dialect_hints = None
    text += table_text
    # The column list is rendered when there are params, or (possibly
    # empty) when the dialect cannot say DEFAULT VALUES.
    if crud_params_single or not supports_default_values:
        text += " (%s)" % ', '.join([preparer.format_column(c[0])
                                     for c in crud_params_single])
    if self.returning or insert_stmt._returning:
        returning_clause = self.returning_clause(
            insert_stmt, self.returning or insert_stmt._returning)
        if self.returning_precedes_values:
            text += " " + returning_clause
    else:
        returning_clause = None
    if insert_stmt.select is not None:
        text += " %s" % self.process(self._insert_from_select, **kw)
    elif not crud_params and supports_default_values:
        text += " DEFAULT VALUES"
    elif insert_stmt._has_multi_parameters:
        text += " VALUES %s" % (
            ", ".join(
                "(%s)" % (
                    ', '.join(c[1] for c in crud_param_set)
                )
                for crud_param_set in crud_params
            )
        )
    else:
        text += " VALUES (%s)" % \
            ', '.join([c[1] for c in crud_params])
    if insert_stmt._post_values_clause is not None:
        post_values_clause = self.process(
            insert_stmt._post_values_clause, **kw)
        if post_values_clause:
            text += " " + post_values_clause
    if returning_clause and not self.returning_precedes_values:
        text += " " + returning_clause
    if self.ctes and toplevel:
        text = self._render_cte_clause() + text
    self.stack.pop(-1)
    if asfrom:
        return "(" + text + ")"
    else:
        return text
def update_limit_clause(self, update_stmt):
    """Provide a hook for MySQL to add LIMIT to the UPDATE.

    The base compiler adds nothing and returns None.
    """
    return None
def update_tables_clause(self, update_stmt, from_table,
                         extra_froms, **kw):
    """Provide a hook to override the initial table clause
    in an UPDATE statement.

    MySQL overrides this.  The base version renders only ``from_table``,
    as a FROM element flagged ``iscrud``; ``extra_froms`` is ignored.
    """
    kw['asfrom'] = True
    rendered = from_table._compiler_dispatch(self, iscrud=True, **kw)
    return rendered
def update_from_clause(self, update_stmt,
                       from_table, extra_froms,
                       from_hints,
                       **kw):
    """Provide a hook to override the generation of an
    UPDATE..FROM clause.

    MySQL and MSSQL override this.  The base version renders
    ``FROM <extra_froms, comma-separated>``, each as a FROM element with
    ``from_hints`` passed through.
    """
    rendered = [
        t._compiler_dispatch(self, asfrom=True,
                             fromhints=from_hints, **kw)
        for t in extra_froms]
    return "FROM " + ', '.join(rendered)
def visit_update(self, update_stmt, asfrom=False, **kw):
    """Compile an UPDATE statement.

    Pushes a stack entry whose correlate/asfrom sets contain the target
    table.  It then renders ``UPDATE [prefixes] <table> SET col=val, ...``,
    then an optional FROM clause for extra tables
    (:meth:`update_from_clause`), WHERE, a dialect LIMIT
    (:meth:`update_limit_clause`) and RETURNING.  RETURNING goes right
    after SET instead when ``returning_precedes_values`` is set.  At the
    top level, collected CTEs are prefixed.  The result is parenthesized
    when ``asfrom`` is set.
    """
    toplevel = not self.stack
    self.stack.append(
        {'correlate_froms': set([update_stmt.table]),
         "asfrom_froms": set([update_stmt.table]),
         "selectable": update_stmt})
    extra_froms = update_stmt._extra_froms
    text = "UPDATE "
    if update_stmt._prefixes:
        text += self._generate_prefixes(update_stmt,
                                        update_stmt._prefixes, **kw)
    table_text = self.update_tables_clause(update_stmt, update_stmt.table,
                                           extra_froms, **kw)
    crud_params = crud._setup_crud_params(
        self, update_stmt, crud.ISUPDATE, **kw)
    if update_stmt._hints:
        dialect_hints, table_text = self._setup_crud_hints(
            update_stmt, table_text)
    else:
        dialect_hints = None
    text += table_text
    text += ' SET '
    # SET columns are table-qualified only when there are extra FROM
    # tables and the dialect asks for it.
    include_table = extra_froms and \
        self.render_table_with_column_in_update_from
    text += ', '.join(
        c[0]._compiler_dispatch(self,
                                include_table=include_table) +
        '=' + c[1] for c in crud_params
    )
    if self.returning or update_stmt._returning:
        if self.returning_precedes_values:
            text += " " + self.returning_clause(
                update_stmt, self.returning or update_stmt._returning)
    if extra_froms:
        extra_from_text = self.update_from_clause(
            update_stmt,
            update_stmt.table,
            extra_froms,
            dialect_hints, **kw)
        if extra_from_text:
            text += " " + extra_from_text
    if update_stmt._whereclause is not None:
        t = self.process(update_stmt._whereclause, **kw)
        if t:
            text += " WHERE " + t
    limit_clause = self.update_limit_clause(update_stmt)
    if limit_clause:
        text += " " + limit_clause
    if (self.returning or update_stmt._returning) and \
            not self.returning_precedes_values:
        text += " " + self.returning_clause(
            update_stmt, self.returning or update_stmt._returning)
    if self.ctes and toplevel:
        text = self._render_cte_clause() + text
    self.stack.pop(-1)
    if asfrom:
        return "(" + text + ")"
    else:
        return text
@util.memoized_property
def _key_getters_for_crud_column(self):
    """Key-getter functions for the CRUD columns of ``self.statement``,
    as produced by ``crud._key_getters_for_crud_column``; memoized per
    compiler."""
    statement = self.statement
    return crud._key_getters_for_crud_column(self, statement)
def visit_delete(self, delete_stmt, asfrom=False, **kw):
    """Compile a DELETE statement.

    Pushes a stack entry whose correlate/asfrom sets contain the target
    table.  It then renders ``DELETE [prefixes] FROM <table> [WHERE ...]``
    with RETURNING before or after the WHERE clause according to
    ``returning_precedes_values``.  At the top level, collected CTEs are
    prefixed.  The result is parenthesized when ``asfrom`` is set.
    """
    toplevel = not self.stack
    self.stack.append({'correlate_froms': set([delete_stmt.table]),
                       "asfrom_froms": set([delete_stmt.table]),
                       "selectable": delete_stmt})
    # Called for its side effects only; its return value is not used.
    crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw)
    text = "DELETE "
    if delete_stmt._prefixes:
        text += self._generate_prefixes(delete_stmt,
                                        delete_stmt._prefixes, **kw)
    text += "FROM "
    # NOTE(review): **kw is not forwarded to the table here, unlike the
    # other sub-elements below.
    table_text = delete_stmt.table._compiler_dispatch(
        self, asfrom=True, iscrud=True)
    if delete_stmt._hints:
        dialect_hints, table_text = self._setup_crud_hints(
            delete_stmt, table_text)
    text += table_text
    if delete_stmt._returning:
        if self.returning_precedes_values:
            text += " " + self.returning_clause(
                delete_stmt, delete_stmt._returning)
    if delete_stmt._whereclause is not None:
        t = delete_stmt._whereclause._compiler_dispatch(self, **kw)
        if t:
            text += " WHERE " + t
    if delete_stmt._returning and not self.returning_precedes_values:
        text += " " + self.returning_clause(
            delete_stmt, delete_stmt._returning)
    if self.ctes and toplevel:
        text = self._render_cte_clause() + text
    self.stack.pop(-1)
    if asfrom:
        return "(" + text + ")"
    else:
        return text
def visit_savepoint(self, savepoint_stmt):
    """Render ``SAVEPOINT <name>``."""
    name = self.preparer.format_savepoint(savepoint_stmt)
    return "SAVEPOINT %s" % name
def visit_rollback_to_savepoint(self, savepoint_stmt):
    """Render ``ROLLBACK TO SAVEPOINT <name>``."""
    name = self.preparer.format_savepoint(savepoint_stmt)
    return "ROLLBACK TO SAVEPOINT %s" % name
def visit_release_savepoint(self, savepoint_stmt):
    """Render ``RELEASE SAVEPOINT <name>``."""
    name = self.preparer.format_savepoint(savepoint_stmt)
    return "RELEASE SAVEPOINT %s" % name
class StrSQLCompiler(SQLCompiler):
    """A compiler subclass with a few non-standard SQL features allowed.

    Used for stringification of SQL statements when a real dialect is not
    available.
    """

    def _fallback_column_name(self, column):
        # Placeholder name returned by this fallback hook.
        return "<name unknown>"

    def visit_getitem_binary(self, binary, operator, **kw):
        """Render an index/getitem operation as ``<left>[<right>]``."""
        return "%s[%s]" % (
            self.process(binary.left, **kw),
            self.process(binary.right, **kw)
        )

    def visit_json_getitem_op_binary(self, binary, operator, **kw):
        """Render JSON index access with the generic ``[...]`` syntax."""
        return self.visit_getitem_binary(binary, operator, **kw)

    def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
        """Render JSON path access with the generic ``[...]`` syntax."""
        return self.visit_getitem_binary(binary, operator, **kw)

    def returning_clause(self, stmt, returning_cols):
        """Render ``RETURNING <cols>`` without requiring dialect support.

        Each column is labeled via ``_label_select_column`` with no
        enclosing select.
        """
        columns = [
            self._label_select_column(None, c, True, False, {})
            for c in elements._select_iterables(returning_cols)
        ]
        return 'RETURNING ' + ', '.join(columns)
class DDLCompiler(Compiled):
@util.memoized_property
def sql_compiler(self):
    """A statement compiler for this dialect, created with no statement.

    It renders SQL expressions embedded in DDL (e.g. index expressions
    and DDL text post-processing).  Memoized.
    """
    dialect = self.dialect
    return dialect.statement_compiler(dialect, None)
@util.memoized_property
def type_compiler(self):
    """The dialect's type compiler (memoized)."""
    dialect = self.dialect
    return dialect.type_compiler
def construct_params(self, params=None):
    """DDL carries no bind parameters, so this always returns None;
    ``params`` is ignored."""
    return None
def visit_ddl(self, ddl, **kwargs):
    """Render a literal DDL construct.

    ``ddl.statement`` is %-interpolated with ``ddl.context``.  When the
    target is a Table, a copy of the context receives defaults for
    ``table``, ``schema`` and ``fullname`` taken from the formatted
    table; keys already present in the context win.  The text is then
    passed through the SQL compiler's ``post_process_text``.
    """
    context = ddl.context
    if isinstance(ddl.target, schema.Table):
        context = context.copy()
        preparer = self.preparer
        path = preparer.format_table_seq(ddl.target)
        if len(path) == 1:
            table_part, schema_part = path[0], ''
        else:
            table_part, schema_part = path[-1], path[0]
        context.setdefault('table', table_part)
        context.setdefault('schema', schema_part)
        context.setdefault('fullname', preparer.format_table(ddl.target))
    statement = ddl.statement % context
    return self.sql_compiler.post_process_text(statement)
def visit_create_schema(self, create):
    """Render ``CREATE SCHEMA <quoted name>``."""
    schema_name = self.preparer.format_schema(create.element)
    return "CREATE SCHEMA " + schema_name
def visit_drop_schema(self, drop):
    """Render ``DROP SCHEMA <quoted name>``, adding `` CASCADE`` when
    ``drop.cascade`` is set."""
    schema_name = self.preparer.format_schema(drop.element)
    suffix = " CASCADE" if drop.cascade else ""
    return "DROP SCHEMA " + schema_name + suffix
def visit_create_table(self, create):
    """Render a CREATE TABLE statement.

    The output is ``CREATE [prefixes] TABLE <name> [suffix] (`` followed
    by one line per column (system columns, which render as None, are
    skipped), then the table-level constraints, then ``)`` with
    :meth:`post_create_table` output.

    :raises CompileError: re-raised from a column's compilation, with the
     table and column names prepended to the message.
    """
    table = create.element
    preparer = self.preparer
    text = "\nCREATE "
    if table._prefixes:
        text += " ".join(table._prefixes) + " "
    text += "TABLE " + preparer.format_table(table) + " "
    create_table_suffix = self.create_table_suffix(table)
    if create_table_suffix:
        text += create_table_suffix + " "
    text += "("
    separator = "\n"
    # if only one primary key, specify it along with the column
    # (first_pk is passed as True only for the first primary-key column)
    first_pk = False
    for create_column in create.columns:
        column = create_column.element
        try:
            processed = self.process(create_column,
                                     first_pk=column.primary_key
                                     and not first_pk)
            if processed is not None:
                text += separator
                separator = ", \n"
                text += "\t" + processed
            if column.primary_key:
                first_pk = True
        except exc.CompileError as ce:
            util.raise_from_cause(
                exc.CompileError(
                    util.u("(in table '%s', column '%s'): %s") %
                    (table.description, column.name, ce.args[0])
                ))
    const = self.create_table_constraints(
        table, _include_foreign_key_constraints=  # noqa
        create.include_foreign_key_constraints)
    if const:
        text += separator + "\t" + const
    text += "\n)%s\n\n" % self.post_create_table(table)
    return text
def visit_create_column(self, create, first_pk=False):
    """Render one column definition for CREATE TABLE.

    Returns None for system columns, which visit_create_table then
    omits.  Inline column constraints are appended after the column
    specification, space-separated.
    """
    column = create.element
    if column.system:
        return None
    spec = self.get_column_specification(column, first_pk=first_pk)
    constraints_sql = " ".join(
        self.process(c) for c in column.constraints)
    if constraints_sql:
        spec += " " + constraints_sql
    return spec
def create_table_constraints(
        self, table,
        _include_foreign_key_constraints=None):
    """Render the table-level constraint clauses for CREATE TABLE.

    The primary key comes first, because order is significant on some
    databases.  The rest of ``table._sorted_constraints`` follows.  A
    constraint is skipped when:

    * it is a foreign key constraint not listed in
      ``_include_foreign_key_constraints`` (when that argument is given);
    * its ``_create_rule`` returns false for this compiler;
    * the dialect supports ALTER and the constraint has ``use_alter`` set.

    Constraints that compile to None are dropped.  The rest are joined
    with a comma, a newline and a tab.
    """
    ordered = []
    if table.primary_key:
        ordered.append(table.primary_key)
    all_fkcs = table.foreign_key_constraints
    if _include_foreign_key_constraints is None:
        omit_fkcs = set()
    else:
        omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints)
    for candidate in table._sorted_constraints:
        if candidate is not table.primary_key and \
                candidate not in omit_fkcs:
            ordered.append(candidate)

    rendered = []
    for constraint in ordered:
        if constraint._create_rule is not None and \
                not constraint._create_rule(self):
            continue
        if self.dialect.supports_alter and \
                getattr(constraint, 'use_alter', False):
            continue
        sql = self.process(constraint)
        if sql is not None:
            rendered.append(sql)
    return ", \n\t".join(rendered)
def visit_drop_table(self, drop):
    """Render ``DROP TABLE <name>``, preceded by a newline."""
    table_name = self.preparer.format_table(drop.element)
    return "\nDROP TABLE " + table_name
def visit_drop_view(self, drop):
    """Render ``DROP VIEW <name>``, preceded by a newline."""
    view_name = self.preparer.format_table(drop.element)
    return "\nDROP VIEW " + view_name
def _verify_index_table(self, index):
    """Raise CompileError unless ``index`` is attached to a table."""
    if index.table is not None:
        return
    raise exc.CompileError("Index '%s' is not associated "
                           "with any table." % index.name)
def visit_create_index(self, create, include_schema=False,
                       include_table_schema=True):
    """Render a CREATE [UNIQUE] INDEX statement.

    :param include_schema: passed to ``_prepared_index_name`` to
        schema-qualify the index name.
    :param include_table_schema: passed as ``use_schema`` when formatting
        the indexed table's name.

    Index expressions are compiled without table prefixes and with bound
    values rendered inline.
    """
    index = create.element
    self._verify_index_table(index)
    preparer = self.preparer

    index_name = self._prepared_index_name(index,
                                           include_schema=include_schema)
    table_name = preparer.format_table(index.table,
                                       use_schema=include_table_schema)
    column_list = ', '.join(
        self.sql_compiler.process(expr, include_table=False,
                                  literal_binds=True)
        for expr in index.expressions)

    head = "CREATE UNIQUE " if index.unique else "CREATE "
    return head + "INDEX %s ON %s (%s)" % (index_name, table_name,
                                           column_list)
def visit_drop_index(self, drop):
    """Render a DROP INDEX statement with a schema-qualified index name."""
    qualified_name = self._prepared_index_name(drop.element,
                                               include_schema=True)
    return "\nDROP INDEX " + qualified_name
def _prepared_index_name(self, index, include_schema=False):
    """Return the quoted, optionally schema-qualified, name of ``index``.

    Names that are ``elements._truncated_label`` instances and exceed the
    dialect limit are shortened: the first ``limit - 8`` characters are
    kept, then ``"_"`` and the last four hex digits of the name's MD5 are
    appended.  The limit is ``max_index_name_length``, falling back to
    ``max_identifier_length``.  All other names go through
    ``dialect.validate_identifier``.

    The schema prefix is added only when ``include_schema`` is true and
    the index's table resolves to a non-empty schema.
    """
    if index.table is None:
        effective_schema = None
    else:
        effective_schema = self.preparer.schema_for_object(index.table)

    schema_name = None
    if include_schema and effective_schema:
        schema_name = self.preparer.quote_schema(effective_schema)

    ident = index.name
    if not isinstance(ident, elements._truncated_label):
        self.dialect.validate_identifier(ident)
    else:
        limit = self.dialect.max_index_name_length or \
            self.dialect.max_identifier_length
        if len(ident) > limit:
            ident = ident[0:limit - 8] + "_" + util.md5_hex(ident)[-4:]

    quoted = self.preparer.quote(ident)
    if not schema_name:
        return quoted
    return schema_name + "." + quoted
def visit_add_constraint(self, create):
    """Render ALTER TABLE <table> ADD <constraint> for ``create.element``."""
    constraint = create.element
    table_name = self.preparer.format_table(constraint.table)
    rendered = self.process(constraint)
    return "ALTER TABLE %s ADD %s" % (table_name, rendered)
def visit_create_sequence(self, create):
    """Render a CREATE SEQUENCE statement for ``create.element``.

    Numeric options (INCREMENT BY, START WITH, MINVALUE, MAXVALUE) are
    emitted whenever they are not None, so a value of 0 is rendered.

    The flag options NO MINVALUE, NO MAXVALUE and CYCLE are emitted only
    when set to a true value.  Leaving them as None, or passing an
    explicit False, omits the clause.
    """
    seq = create.element
    text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(seq)

    for clause, value in (("INCREMENT BY", seq.increment),
                          ("START WITH", seq.start),
                          ("MINVALUE", seq.minvalue),
                          ("MAXVALUE", seq.maxvalue)):
        if value is not None:
            text += " %s %d" % (clause, value)

    # These are boolean flags: test truthiness rather than ``is not None``
    # so that e.g. ``cycle=False`` does not render " CYCLE".
    if seq.nominvalue:
        text += " NO MINVALUE"
    if seq.nomaxvalue:
        text += " NO MAXVALUE"
    if seq.cycle:
        text += " CYCLE"
    return text
def visit_drop_sequence(self, drop):
    """Render a DROP SEQUENCE statement for ``drop.element``."""
    seq_name = self.preparer.format_sequence(drop.element)
    return "DROP SEQUENCE %s" % seq_name
def visit_drop_constraint(self, drop):
    """Render ALTER TABLE ... DROP CONSTRAINT for ``drop.element``.

    Raises CompileError when the constraint's ``name`` is None or when
    ``preparer.format_constraint`` returns None for it.  A trailing
    " CASCADE" is added when ``drop.cascade`` is true.
    """
    constraint = drop.element
    formatted_name = (self.preparer.format_constraint(constraint)
                      if constraint.name is not None else None)
    if formatted_name is None:
        raise exc.CompileError(
            "Can't emit DROP CONSTRAINT for constraint %r; "
            "it has no name" % drop.element)
    table_name = self.preparer.format_table(drop.element.table)
    cascade = " CASCADE" if drop.cascade else ""
    return "ALTER TABLE %s DROP CONSTRAINT %s%s" % (
        table_name, formatted_name, cascade)
def get_column_specification(self, column, **kwargs):
    """Return "<name> <type> [DEFAULT <expr>] [NOT NULL]" for ``column``.

    Extra keyword arguments (``visit_create_column`` passes ``first_pk``)
    are accepted and ignored by this base implementation.
    """
    pieces = [
        self.preparer.format_column(column),
        self.dialect.type_compiler.process(
            column.type, type_expression=column),
    ]
    default = self.get_column_default_string(column)
    if default is not None:
        pieces.extend(["DEFAULT", default])
    if not column.nullable:
        pieces.append("NOT NULL")
    return " ".join(pieces)
def create_table_suffix(self, table):
    """Return extra text for CREATE TABLE; empty in the base compiler."""
    return ""
def post_create_table(self, table):
    """Return text placed after the closing paren of CREATE TABLE.

    The base compiler adds nothing.
    """
    return ""
def get_column_default_string(self, column):
    """Return the rendered server-side DEFAULT for ``column``, or None.

    Only a ``schema.DefaultClause`` server default is rendered.  A string
    argument becomes a string literal via ``render_literal_value``; any
    other argument is compiled with bound values rendered inline.
    """
    server_default = column.server_default
    if not isinstance(server_default, schema.DefaultClause):
        return None
    arg = server_default.arg
    if isinstance(arg, util.string_types):
        return self.sql_compiler.render_literal_value(
            arg, sqltypes.STRINGTYPE)
    return self.sql_compiler.process(arg, literal_binds=True)
def visit_check_constraint(self, constraint):
    """Render a table-level CHECK constraint.

    The ``sqltext`` is compiled without table prefixes and with bound
    values rendered inline.  A "CONSTRAINT <name>" prefix is added when a
    name is set and formats to a non-None value, and any deferrability
    clause is appended.
    """
    prefix = ""
    if constraint.name is not None:
        formatted_name = self.preparer.format_constraint(constraint)
        if formatted_name is not None:
            prefix = "CONSTRAINT %s " % formatted_name
    condition = self.sql_compiler.process(constraint.sqltext,
                                          include_table=False,
                                          literal_binds=True)
    deferrability = self.define_constraint_deferrability(constraint)
    return prefix + ("CHECK (%s)" % condition) + deferrability
def visit_column_check_constraint(self, constraint):
    """Render a column-level CHECK constraint.

    Like the table-level CHECK visitor, the ``sqltext`` is compiled
    through ``sql_compiler`` without table prefixes and with bound values
    rendered inline.  Interpolating it with ``str()`` would bypass the
    dialect compiler and leave bound parameters unrendered.
    """
    text = ""
    if constraint.name is not None:
        formatted_name = self.preparer.format_constraint(constraint)
        if formatted_name is not None:
            text += "CONSTRAINT %s " % formatted_name
    text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext,
                                                     include_table=False,
                                                     literal_binds=True)
    text += self.define_constraint_deferrability(constraint)
    return text
def visit_primary_key_constraint(self, constraint):
    """Render a PRIMARY KEY constraint, or '' when it has no columns.

    Implicitly generated primary keys list their columns in
    ``columns_autoinc_first`` order.  Explicit ones use ``columns``.
    """
    if not len(constraint):
        return ''
    pieces = []
    if constraint.name is not None:
        formatted_name = self.preparer.format_constraint(constraint)
        if formatted_name is not None:
            pieces.append("CONSTRAINT %s " % formatted_name)
    if constraint._implicit_generated:
        cols = constraint.columns_autoinc_first
    else:
        cols = constraint.columns
    quote = self.preparer.quote
    pieces.append("PRIMARY KEY (%s)" % ', '.join(quote(c.name) for c in cols))
    pieces.append(self.define_constraint_deferrability(constraint))
    return "".join(pieces)
def visit_foreign_key_constraint(self, constraint):
    """Render FOREIGN KEY(...) REFERENCES <table> (...) plus options.

    The referenced table is taken from the first element's target
    column.  MATCH, ON DELETE/ON UPDATE and deferrability clauses are
    appended in that order.
    """
    preparer = self.preparer
    pieces = []
    if constraint.name is not None:
        formatted_name = self.preparer.format_constraint(constraint)
        if formatted_name is not None:
            pieces.append("CONSTRAINT %s " % formatted_name)

    fk_elements = constraint.elements
    referred_table = list(fk_elements)[0].column.table
    local_cols = ', '.join(preparer.quote(fk.parent.name)
                           for fk in fk_elements)
    remote_table = self.define_constraint_remote_table(
        constraint, referred_table, preparer)
    remote_cols = ', '.join(preparer.quote(fk.column.name)
                            for fk in fk_elements)
    pieces.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
        local_cols, remote_table, remote_cols))

    pieces.append(self.define_constraint_match(constraint))
    pieces.append(self.define_constraint_cascades(constraint))
    pieces.append(self.define_constraint_deferrability(constraint))
    return "".join(pieces)
def define_constraint_remote_table(self, constraint, table, preparer):
    """Format the remote table clause of a CREATE CONSTRAINT clause.

    This base implementation only formats ``table`` with ``preparer``.
    ``constraint`` is not used.
    """
    formatted = preparer.format_table(table)
    return formatted
def visit_unique_constraint(self, constraint):
    """Render a UNIQUE constraint, or '' when it has no columns.

    The "CONSTRAINT <name>" prefix is omitted when the constraint has no
    name, or when ``format_constraint`` returns None for it.  Without the
    second check the output would be "CONSTRAINT None UNIQUE (...)"; the
    other constraint visitors already make the same check.
    """
    if len(constraint) == 0:
        return ''
    text = ""
    if constraint.name is not None:
        formatted_name = self.preparer.format_constraint(constraint)
        if formatted_name is not None:
            text += "CONSTRAINT %s " % formatted_name
    text += "UNIQUE (%s)" % (
        ', '.join(self.preparer.quote(c.name)
                  for c in constraint))
    text += self.define_constraint_deferrability(constraint)
    return text
def define_constraint_cascades(self, constraint):
    """Return the ON DELETE / ON UPDATE clauses of a foreign key.

    Each clause is emitted only when its rule is not None, and the rule
    text is inserted verbatim.
    """
    rules = (("DELETE", constraint.ondelete),
             ("UPDATE", constraint.onupdate))
    return "".join(" ON %s %s" % (event, action)
                   for event, action in rules
                   if action is not None)
def define_constraint_deferrability(self, constraint):
    """Return the [NOT] DEFERRABLE and INITIALLY clauses of a constraint.

    When ``deferrable`` is None nothing is emitted for it; otherwise its
    truthiness picks DEFERRABLE or NOT DEFERRABLE.  ``initially`` is
    inserted verbatim when it is not None.
    """
    deferrable = constraint.deferrable
    if deferrable is None:
        head = ""
    elif deferrable:
        head = " DEFERRABLE"
    else:
        head = " NOT DEFERRABLE"
    initially = constraint.initially
    tail = "" if initially is None else " INITIALLY %s" % initially
    return head + tail
def define_constraint_match(self, constraint):
    """Return the " MATCH <type>" clause of a foreign key, or ''."""
    match = constraint.match
    return "" if match is None else " MATCH %s" % match
class GenericTypeCompiler(TypeCompiler):
    """Render SQL type names for DDL using generic SQL keywords.

    Upper-case ``visit_<NAME>`` methods emit a concrete type keyword.
    Lower-case ``visit_<generic>`` methods delegate through ``self`` to
    an upper-case method, so overriding the upper-case method also
    changes what the generic one renders.
    """

    # -- numeric types -------------------------------------------------

    def visit_FLOAT(self, type_, **kw):
        return "FLOAT"

    def visit_REAL(self, type_, **kw):
        return "REAL"

    def _render_numeric_type(self, type_, name):
        """Render ``name``, ``name(precision)`` or ``name(precision, scale)``.

        Scale is only rendered when precision is also set.
        """
        precision = type_.precision
        scale = type_.scale
        if precision is None:
            return name
        if scale is None:
            return "%s(%s)" % (name, precision)
        return "%s(%s, %s)" % (name, precision, scale)

    def visit_NUMERIC(self, type_, **kw):
        return self._render_numeric_type(type_, "NUMERIC")

    def visit_DECIMAL(self, type_, **kw):
        return self._render_numeric_type(type_, "DECIMAL")

    def visit_INTEGER(self, type_, **kw):
        return "INTEGER"

    def visit_SMALLINT(self, type_, **kw):
        return "SMALLINT"

    def visit_BIGINT(self, type_, **kw):
        return "BIGINT"

    # -- date / time types ---------------------------------------------

    def visit_TIMESTAMP(self, type_, **kw):
        return 'TIMESTAMP'

    def visit_DATETIME(self, type_, **kw):
        return "DATETIME"

    def visit_DATE(self, type_, **kw):
        return "DATE"

    def visit_TIME(self, type_, **kw):
        return "TIME"

    # -- character types -----------------------------------------------

    def visit_CLOB(self, type_, **kw):
        return "CLOB"

    def visit_NCLOB(self, type_, **kw):
        return "NCLOB"

    def _render_string_type(self, type_, name):
        """Render ``name`` with an optional ``(length)`` and COLLATE clause.

        Both parts are skipped when the corresponding attribute is falsy.
        """
        pieces = [name]
        if type_.length:
            pieces.append("(%d)" % type_.length)
        if type_.collation:
            pieces.append(' COLLATE "%s"' % type_.collation)
        return "".join(pieces)

    def visit_CHAR(self, type_, **kw):
        return self._render_string_type(type_, "CHAR")

    def visit_NCHAR(self, type_, **kw):
        return self._render_string_type(type_, "NCHAR")

    def visit_VARCHAR(self, type_, **kw):
        return self._render_string_type(type_, "VARCHAR")

    def visit_NVARCHAR(self, type_, **kw):
        return self._render_string_type(type_, "NVARCHAR")

    def visit_TEXT(self, type_, **kw):
        return self._render_string_type(type_, "TEXT")

    # -- binary / boolean types ----------------------------------------

    def visit_BLOB(self, type_, **kw):
        return "BLOB"

    def _render_binary_type(self, type_, name):
        """Render ``name`` followed by ``(length)`` when length is truthy."""
        if type_.length:
            return "%s(%d)" % (name, type_.length)
        return name

    def visit_BINARY(self, type_, **kw):
        return self._render_binary_type(type_, "BINARY")

    def visit_VARBINARY(self, type_, **kw):
        return self._render_binary_type(type_, "VARBINARY")

    def visit_BOOLEAN(self, type_, **kw):
        return "BOOLEAN"

    # -- generic SQLAlchemy types, mapped onto the renderers above -----

    def visit_large_binary(self, type_, **kw):
        return self.visit_BLOB(type_, **kw)

    def visit_boolean(self, type_, **kw):
        return self.visit_BOOLEAN(type_, **kw)

    def visit_time(self, type_, **kw):
        return self.visit_TIME(type_, **kw)

    def visit_datetime(self, type_, **kw):
        return self.visit_DATETIME(type_, **kw)

    def visit_date(self, type_, **kw):
        return self.visit_DATE(type_, **kw)

    def visit_big_integer(self, type_, **kw):
        return self.visit_BIGINT(type_, **kw)

    def visit_small_integer(self, type_, **kw):
        return self.visit_SMALLINT(type_, **kw)

    def visit_integer(self, type_, **kw):
        return self.visit_INTEGER(type_, **kw)

    def visit_real(self, type_, **kw):
        return self.visit_REAL(type_, **kw)

    def visit_float(self, type_, **kw):
        return self.visit_FLOAT(type_, **kw)

    def visit_numeric(self, type_, **kw):
        return self.visit_NUMERIC(type_, **kw)

    def visit_string(self, type_, **kw):
        return self.visit_VARCHAR(type_, **kw)

    def visit_unicode(self, type_, **kw):
        return self.visit_VARCHAR(type_, **kw)

    def visit_text(self, type_, **kw):
        return self.visit_TEXT(type_, **kw)

    def visit_unicode_text(self, type_, **kw):
        return self.visit_TEXT(type_, **kw)

    def visit_enum(self, type_, **kw):
        return self.visit_VARCHAR(type_, **kw)

    # -- special cases -------------------------------------------------

    def visit_null(self, type_, **kw):
        raise exc.CompileError("Can't generate DDL for %r; "
                               "did you forget to specify a "
                               "type on this Column?" % type_)

    def visit_type_decorator(self, type_, **kw):
        # Render the dialect-specific implementation type of the decorator.
        return self.process(type_.type_engine(self.dialect), **kw)

    def visit_user_defined(self, type_, **kw):
        return type_.get_col_spec(**kw)
class StrSQLTypeCompiler(GenericTypeCompiler):
    """Lenient type compiler that never fails on unknown visit methods.

    Any ``visit_*`` attribute that normal lookup cannot find resolves to
    a fallback that renders the type's class name.  Other missing
    attributes still raise AttributeError.
    """

    def __getattr__(self, key):
        # Only reached when ordinary attribute lookup has already failed.
        if not key.startswith("visit_"):
            raise AttributeError(key)
        return self._visit_unknown

    def _visit_unknown(self, type_, **kw):
        class_name = type_.__class__.__name__
        return "{0}".format(class_name)
class IdentifierPreparer(object):
    """Handle quoting and case-folding of identifiers based on options."""

    # Words that force quoting; compared against the lower-cased identifier.
    reserved_words = RESERVED_WORDS

    # Pattern an identifier must fully match to be left unquoted.
    legal_characters = LEGAL_CHARACTERS

    # First characters (digits and '$') that force quoting.
    illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS

    # Callable returning the schema to render for a table/sequence/index
    # table; replaced on copies made by _with_schema_translate().
    schema_for_object = schema._schema_getter(None)

    def __init__(self, dialect, initial_quote='"',
                 final_quote=None, escape_quote='"', omit_schema=False):
        """Construct a new ``IdentifierPreparer`` object.

        initial_quote
          Character that begins a delimited identifier.

        final_quote
          Character that ends a delimited identifier. Defaults to
          `initial_quote`.

        escape_quote
          Character that is escaped, by doubling it, when it appears
          inside a delimited identifier.

        omit_schema
          Prevent prepending schema name. Useful for databases that do
          not support schemae.
        """
        self.dialect = dialect
        self.initial_quote = initial_quote
        self.final_quote = final_quote or self.initial_quote
        self.escape_quote = escape_quote
        self.escape_to_quote = self.escape_quote * 2
        self.omit_schema = omit_schema
        # Memo of identifier -> rendered form, filled by quote() for
        # identifiers that carry no explicit ``quote`` attribute.
        self._strings = {}

    def _with_schema_translate(self, schema_translate_map):
        """Return a copy whose ``schema_for_object`` uses a translate map.

        The copy gets this instance's attribute values via a shallow
        ``__dict__`` update, so the ``_strings`` memo dict is shared.
        """
        prep = self.__class__.__new__(self.__class__)
        prep.__dict__.update(self.__dict__)
        prep.schema_for_object = schema._schema_getter(schema_translate_map)
        return prep

    def _escape_identifier(self, value):
        """Escape an identifier.

        Subclasses should override this to provide database-dependent
        escaping behavior.
        """

        return value.replace(self.escape_quote, self.escape_to_quote)

    def _unescape_identifier(self, value):
        """Canonicalize an escaped identifier.

        Subclasses should override this to provide database-dependent
        unescaping behavior that reverses _escape_identifier.
        """

        return value.replace(self.escape_to_quote, self.escape_quote)

    def quote_identifier(self, value):
        """Quote an identifier.

        Subclasses should override this to provide database-dependent
        quoting behavior.
        """

        return self.initial_quote + \
            self._escape_identifier(value) + \
            self.final_quote

    def _requires_quotes(self, value):
        """Return True if the given identifier requires quoting.

        An identifier needs quoting if any of these hold:

        * its lower-cased form is a reserved word;
        * it starts with an illegal initial character;
        * it contains characters outside ``legal_characters``;
        * it is not already all lower case.

        NOTE(review): ``value[0]`` raises IndexError for an empty string.
        """
        lc_value = value.lower()
        return (lc_value in self.reserved_words
                or value[0] in self.illegal_initial_characters
                or not self.legal_characters.match(util.text_type(value))
                or (lc_value != value))

    def quote_schema(self, schema, force=None):
        """Conditionally quote a schema.

        Subclasses can override this to provide database-dependent
        quoting behavior for schema names.

        the 'force' flag should be considered deprecated.

        """
        return self.quote(schema, force)

    def quote(self, ident, force=None):
        """Conditionally quote an identifier.

        An identifier with a ``quote`` attribute is quoted when that
        attribute is true and returned unchanged when it is false (but
        not None).  Otherwise quoting is decided by ``_requires_quotes``
        and the result is memoized in ``self._strings``.

        the 'force' flag should be considered deprecated.
        """
        # NOTE(review): the ``force`` argument is overwritten here and so
        # has no effect; only ``ident.quote`` is consulted.
        force = getattr(ident, "quote", None)

        if force is None:
            if ident in self._strings:
                return self._strings[ident]
            else:
                if self._requires_quotes(ident):
                    self._strings[ident] = self.quote_identifier(ident)
                else:
                    self._strings[ident] = ident
                return self._strings[ident]
        elif force:
            return self.quote_identifier(ident)
        else:
            return ident

    def format_sequence(self, sequence, use_schema=True):
        """Return the quoted, optionally schema-qualified, sequence name.

        Like ``format_table``, the schema prefix is added only when the
        effective schema is truthy.  An empty-string schema is treated
        as "no schema" instead of being passed to ``quote_schema``,
        where ``_requires_quotes`` would fail on it.
        """
        name = self.quote(sequence.name)

        effective_schema = self.schema_for_object(sequence)

        if (not self.omit_schema and use_schema and
                effective_schema):
            name = self.quote_schema(effective_schema) + "." + name
        return name

    def format_label(self, label, name=None):
        """Quote ``name``, or ``label.name`` when ``name`` is falsy."""
        return self.quote(name or label.name)

    def format_alias(self, alias, name=None):
        """Quote ``name``, or ``alias.name`` when ``name`` is falsy."""
        return self.quote(name or alias.name)

    def format_savepoint(self, savepoint, name=None):
        """Return the savepoint identifier, quoted only if required.

        Uses ``quote_identifier`` directly rather than ``quote``, so the
        result is not memoized.
        """
        # Running the savepoint name through quoting is unnecessary
        # for all known dialects.  This is here to support potential
        # third party use cases
        ident = name or savepoint.ident
        if self._requires_quotes(ident):
            ident = self.quote_identifier(ident)
        return ident

    @util.dependencies("sqlalchemy.sql.naming")
    def format_constraint(self, naming, constraint):
        """Return the quoted constraint name, or None.

        ``naming`` is injected by the ``util.dependencies`` decorator.
        A ``_defer_name`` name is resolved through
        ``naming._constraint_name_for_table``.  If that yields a falsy
        value, ``constraint.name`` itself is quoted instead.  A
        ``_defer_none_name`` name produces None.
        """
        if isinstance(constraint.name, elements._defer_name):
            name = naming._constraint_name_for_table(
                constraint, constraint.table)
            if name:
                return self.quote(name)
        elif isinstance(constraint.name, elements._defer_none_name):
            return None
        return self.quote(constraint.name)

    def format_table(self, table, use_schema=True, name=None):
        """Prepare a quoted table and schema name."""

        if name is None:
            name = table.name
        result = self.quote(name)

        effective_schema = self.schema_for_object(table)

        if not self.omit_schema and use_schema \
                and effective_schema:
            result = self.quote_schema(effective_schema) + "." + result
        return result

    def format_schema(self, name, quote=None):
        """Prepare a quoted schema name."""

        return self.quote(name, quote)

    def format_column(self, column, use_table=False,
                      name=None, table_name=None):
        """Prepare a quoted column name.

        Columns whose ``is_literal`` attribute is true are emitted
        verbatim.  With ``use_table`` the column is prefixed by its
        table's name, without schema; ``table_name`` overrides that
        table name.
        """

        if name is None:
            name = column.name
        if not getattr(column, 'is_literal', False):
            if use_table:
                return self.format_table(
                    column.table, use_schema=False,
                    name=table_name) + "." + self.quote(name)
            else:
                return self.quote(name)
        else:
            # literal textual elements get stuck into ColumnClause a lot,
            # which shouldn't get quoted

            if use_table:
                return self.format_table(
                    column.table, use_schema=False,
                    name=table_name) + '.' + name
            else:
                return name

    def format_table_seq(self, table, use_schema=True):
        """Format table name and schema as a tuple."""

        # Dialects with more levels in their fully qualified references
        # ('database', 'owner', etc.) could override this and return
        # a longer sequence.

        effective_schema = self.schema_for_object(table)

        if not self.omit_schema and use_schema and \
                effective_schema:
            return (self.quote_schema(effective_schema),
                    self.format_table(table, use_schema=False))
        else:
            return (self.format_table(table, use_schema=False), )

    @util.memoized_property
    def _r_identifiers(self):
        """Regex splitting a dotted name into its identifier parts.

        Each part is either delimited (captured in group 1, which may
        contain escaped final quotes) or bare (group 2, with no dots).
        """
        initial, final, escaped_final = \
            [re.escape(s) for s in
             (self.initial_quote, self.final_quote,
              self._escape_identifier(self.final_quote))]
        r = re.compile(
            r'(?:'
            r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
            r'|([^\.]+))(?=\.|$))+' %
            {'initial': initial,
             'final': final,
             'escaped': escaped_final})
        return r

    def unformat_identifiers(self, identifiers):
        """Unpack 'schema.table.column'-like strings into components."""

        r = self._r_identifiers
        return [self._unescape_identifier(i)
                for i in [a or b for a, b in r.findall(identifiers)]]