1059 lines
		
	
	
		
			39 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1059 lines
		
	
	
		
			39 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# maxdb.py
 | 
						|
#
 | 
						|
# This module is part of SQLAlchemy and is released under
 | 
						|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
 | 
						|
 | 
						|
"""Support for the MaxDB database.
 | 
						|
 | 
						|
This dialect is *not* ported to SQLAlchemy 0.6.
 | 
						|
 | 
						|
This dialect is *not* tested on SQLAlchemy 0.6.
 | 
						|
 | 
						|
Overview
 | 
						|
--------
 | 
						|
 | 
						|
The ``maxdb`` dialect is **experimental** and has only been tested on 7.6.03.007
 | 
						|
and 7.6.00.037.  Of these, **only 7.6.03.007 will work** with SQLAlchemy's ORM.
 | 
						|
The earlier version has severe ``LEFT JOIN`` limitations and will return
 | 
						|
incorrect results from even very simple ORM queries.
 | 
						|
 | 
						|
Only the native Python DB-API is currently supported.  ODBC driver support
 | 
						|
is a future enhancement.
 | 
						|
 | 
						|
Connecting
 | 
						|
----------
 | 
						|
 | 
						|
The username is case-sensitive.  If you usually connect to the
 | 
						|
database with sqlcli and other tools in lower case, you likely need to
 | 
						|
use upper case for DB-API.
 | 
						|
 | 
						|
Implementation Notes
 | 
						|
--------------------
 | 
						|
 | 
						|
Also check the DatabaseNotes page on the wiki for detailed information.
 | 
						|
 | 
						|
With the 7.6.00.37 driver and Python 2.5, it seems that all DB-API
 | 
						|
generated exceptions are broken and can cause Python to crash.
 | 
						|
 | 
						|
For 'somecol.in_([])' to work, the IN operator's generation must be changed
 | 
						|
to cast 'NULL' to a numeric, i.e. NUM(NULL).  The DB-API doesn't accept a
 | 
						|
bind parameter there, so that particular generation must inline the NULL value,
 | 
						|
which depends on [ticket:807].
 | 
						|
 | 
						|
The DB-API is very picky about where bind params may be used in queries.
 | 
						|
 | 
						|
Bind params for some functions (e.g. MOD) need type information supplied.
 | 
						|
The dialect does not yet do this automatically.
 | 
						|
 | 
						|
Max will occasionally throw up 'bad sql, compile again' exceptions for
 | 
						|
perfectly valid SQL.  The dialect does not currently handle these, more
 | 
						|
research is needed.
 | 
						|
 | 
						|
MaxDB 7.5 and Sap DB <= 7.4 reportedly do not support schemas.  A very
 | 
						|
slightly different version of this dialect would be required to support
 | 
						|
those versions, and can easily be added if there is demand.  Some other
 | 
						|
required components such as a Max-aware 'old oracle style' join compiler
 | 
						|
(thetas with (+) outer indicators) are already done and available for
 | 
						|
integration -- email the devel list if you're interested in working on
 | 
						|
this.
 | 
						|
 | 
						|
"""
 | 
						|
import datetime, itertools, re
 | 
						|
 | 
						|
from sqlalchemy import exc, schema, sql, util, processors
 | 
						|
from sqlalchemy.sql import operators as sql_operators, expression as sql_expr
 | 
						|
from sqlalchemy.sql import compiler, visitors
 | 
						|
from sqlalchemy.engine import base as engine_base, default, reflection
 | 
						|
from sqlalchemy import types as sqltypes
 | 
						|
 | 
						|
 | 
						|
class _StringType(sqltypes.String):
    """Shared base for the MaxDB character types.

    Extends ``sqltypes.String`` with an ``encoding`` attribute.  When the
    encoding is ``'unicode'`` no bind-side conversion is installed.
    """

    _type = None

    def __init__(self, length=None, encoding=None, **kw):
        """Store *encoding* after delegating *length* and *kw* to the base."""
        super(_StringType, self).__init__(length=length, **kw)
        self.encoding = encoding

    def bind_processor(self, dialect):
        """Return a bind converter, or ``None`` for ``'unicode'`` encoding.

        The converter encodes ``unicode`` values with ``dialect.encoding``
        and hands every other value back unchanged.
        """
        if self.encoding == 'unicode':
            return None

        def process(value):
            if not isinstance(value, unicode):
                return value
            return value.encode(dialect.encoding)

        return process

    def result_processor(self, dialect, coltype):
        """Return a converter for values fetched from the cursor."""
        #XXX: this code is probably very slow and one should try (if at all
        # possible) to determine the correct code path on a per-connection
        # basis (ie, here in result_processor, instead of inside the processor
        # function itself) and probably also use a few generic
        # processors, or possibly per query (though there is no mechanism
        # for that yet).
        def process(value):
            while True:
                if value is None or isinstance(value, unicode):
                    return value
                if isinstance(value, str):
                    if self.convert_unicode or dialect.convert_unicode:
                        return value.decode(dialect.encoding)
                    return value
                if not hasattr(value, 'read'):
                    # unexpected type, return as-is
                    return value
                # some sort of LONG: read the remainder, then classify the
                # data that was read on the next pass
                value = value.read(value.remainingLength())

        return process
 | 
						|
 | 
						|
 | 
						|
class MaxString(_StringType):
    """Variable-length string type; ``_type`` is ``'VARCHAR'``."""

    _type = 'VARCHAR'

    def __init__(self, *args, **kwargs):
        """Pass every argument straight through to ``_StringType``."""
        super(MaxString, self).__init__(*args, **kwargs)
 | 
						|
 | 
						|
 | 
						|
class MaxUnicode(_StringType):
    """String type whose ``encoding`` is always ``'unicode'``."""

    _type = 'VARCHAR'

    def __init__(self, length=None, **kw):
        """Initialise with *length* and a fixed ``'unicode'`` encoding."""
        # NOTE(review): **kw is accepted but not forwarded to the base
        # constructor -- confirm that dropping these options is intentional.
        super(MaxUnicode, self).__init__(length=length, encoding='unicode')
 | 
						|
 | 
						|
 | 
						|
class MaxChar(_StringType):
    """String type whose ``_type`` is ``'CHAR'``."""

    _type = 'CHAR'
 | 
						|
 | 
						|
 | 
						|
class MaxText(_StringType):
    """Text type rendered as a ``LONG`` column."""

    _type = 'LONG'

    def __init__(self, *args, **kwargs):
        """Pass every argument straight through to ``_StringType``."""
        super(MaxText, self).__init__(*args, **kwargs)

    def get_col_spec(self):
        """Return ``'LONG'``, optionally followed by an encoding keyword.

        An explicit ``encoding`` is appended as given; otherwise
        ``'UNICODE'`` is appended when ``convert_unicode`` is set.
        """
        if self.encoding is not None:
            suffix = self.encoding
        elif self.convert_unicode:
            suffix = 'UNICODE'
        else:
            return 'LONG'

        return ' '.join(('LONG', suffix))
 | 
						|
 | 
						|
 | 
						|
class MaxNumeric(sqltypes.Numeric):
    """The FIXED (also NUMERIC, DECIMAL) data type."""

    def __init__(self, precision=None, scale=None, **kw):
        """Initialise the base type, defaulting ``asdecimal`` to ``True``."""
        if 'asdecimal' not in kw:
            kw['asdecimal'] = True
        super(MaxNumeric, self).__init__(
            scale=scale, precision=precision, **kw)

    def bind_processor(self, dialect):
        """Return ``None``: no bind-side conversion is installed."""
        return None
 | 
						|
 | 
						|
 | 
						|
class MaxTimestamp(sqltypes.DateTime):
    """DateTime type exchanged with the DB-API as formatted strings.

    The string layout is chosen by ``dialect.datetimeformat``, which must be
    ``'internal'`` or ``'iso'``.
    """

    def bind_processor(self, dialect):
        """Return a callable that renders a bound value as a string.

        ``None`` and string values pass through unchanged.  The callable
        raises ``exc.InvalidRequestError`` when ``dialect.datetimeformat``
        is neither ``'internal'`` nor ``'iso'``.
        """
        def process(value):
            if value is None or isinstance(value, basestring):
                return value

            style = dialect.datetimeformat
            if style == 'internal':
                layout = "%Y%m%d%H%M%S"
            elif style == 'iso':
                layout = "%Y-%m-%d %H:%M:%S."
            else:
                raise exc.InvalidRequestError(
                    "datetimeformat '%s' is not supported." % (style,))

            # The zero-padded microsecond digits are appended to the format
            # string itself; a value without a ``microsecond`` attribute
            # contributes 000000.
            micros = "%06u" % getattr(value, 'microsecond', 0)
            return value.strftime(layout + micros)

        return process

    def result_processor(self, dialect, coltype):
        """Return a callable that parses a fetched string into a datetime.

        Raises ``exc.InvalidRequestError`` immediately when
        ``dialect.datetimeformat`` is neither ``'internal'`` nor ``'iso'``.
        """
        style = dialect.datetimeformat
        # (start, stop) slice bounds for year, month, day, hour, minute,
        # second and the trailing fraction (open-ended).
        if style == 'internal':
            bounds = ((0, 4), (4, 6), (6, 8), (8, 10), (10, 12), (12, 14),
                      (14, None))
        elif style == 'iso':
            bounds = ((0, 4), (5, 7), (8, 10), (11, 13), (14, 16), (17, 19),
                      (20, None))
        else:
            raise exc.InvalidRequestError(
                "datetimeformat '%s' is not supported." %
                style)

        def process(value):
            if value is None:
                return None
            return datetime.datetime(
                *[int(value[start:stop]) for start, stop in bounds])

        return process
 | 
						|
 | 
						|
 | 
						|
class MaxDate(sqltypes.Date):
    """Date type exchanged with the DB-API as formatted strings.

    The string layout is chosen by ``dialect.datetimeformat``, which must be
    ``'internal'`` or ``'iso'``.
    """

    def bind_processor(self, dialect):
        """Return a callable that renders a bound value as a date string.

        ``None`` and string values pass through unchanged.  The callable
        raises ``exc.InvalidRequestError`` when ``dialect.datetimeformat``
        is neither ``'internal'`` nor ``'iso'``.
        """
        def process(value):
            if value is None or isinstance(value, basestring):
                return value

            style = dialect.datetimeformat
            if style == 'internal':
                return value.strftime("%Y%m%d")
            if style == 'iso':
                return value.strftime("%Y-%m-%d")
            raise exc.InvalidRequestError(
                "datetimeformat '%s' is not supported." % (style,))

        return process

    def result_processor(self, dialect, coltype):
        """Return a callable that parses a fetched string into a date.

        Raises ``exc.InvalidRequestError`` immediately when
        ``dialect.datetimeformat`` is neither ``'internal'`` nor ``'iso'``.
        """
        style = dialect.datetimeformat
        # (start, stop) slice bounds for year, month and day.
        if style == 'internal':
            bounds = ((0, 4), (4, 6), (6, 8))
        elif style == 'iso':
            bounds = ((0, 4), (5, 7), (8, 10))
        else:
            raise exc.InvalidRequestError(
                "datetimeformat '%s' is not supported." %
                style)

        def process(value):
            if value is None:
                return None
            return datetime.date(
                *[int(value[start:stop]) for start, stop in bounds])

        return process
 | 
						|
 | 
						|
 | 
						|
class MaxTime(sqltypes.Time):
    """Time type exchanged with the DB-API as formatted strings.

    The string layout is chosen by ``dialect.datetimeformat``, which must be
    ``'internal'`` or ``'iso'``.
    """

    def bind_processor(self, dialect):
        """Return a callable that renders a bound value as a time string.

        ``None`` and string values pass through unchanged.  The callable
        raises ``exc.InvalidRequestError`` when ``dialect.datetimeformat``
        is neither ``'internal'`` nor ``'iso'``.
        """
        def process(value):
            if value is None or isinstance(value, basestring):
                return value

            style = dialect.datetimeformat
            if style == 'internal':
                return value.strftime("%H%M%S")
            if style == 'iso':
                # NOTE(review): '-' separators and a two-digit hour are
                # emitted here, while result_processor slices a
                # four-character hour field -- confirm against the string
                # form the DB-API actually expects.
                return value.strftime("%H-%M-%S")
            raise exc.InvalidRequestError(
                "datetimeformat '%s' is not supported." % (style,))

        return process

    def result_processor(self, dialect, coltype):
        """Return a callable that parses a fetched string into a time.

        Raises ``exc.InvalidRequestError`` immediately when
        ``dialect.datetimeformat`` is neither ``'internal'`` nor ``'iso'``.
        """
        style = dialect.datetimeformat
        # (start, stop) slice bounds for hour, minute and second.  The hour
        # field is read as four characters wide in both layouts.
        if style == 'internal':
            bounds = ((0, 4), (4, 6), (6, 8))
        elif style == 'iso':
            bounds = ((0, 4), (5, 7), (8, 10))
        else:
            raise exc.InvalidRequestError(
                "datetimeformat '%s' is not supported." %
                style)

        def process(value):
            if value is None:
                return None
            return datetime.time(
                *[int(value[start:stop]) for start, stop in bounds])

        return process
 | 
						|
 | 
						|
 | 
						|
class MaxBlob(sqltypes.LargeBinary):
    """Binary type whose fetched values are reader objects."""

    def bind_processor(self, dialect):
        """Return ``processors.to_str`` as the bind converter."""
        return processors.to_str

    def result_processor(self, dialect, coltype):
        """Return a callable that reads a fetched value to its end.

        ``None`` is passed through; any other value is expected to provide
        ``read()`` and ``remainingLength()``.
        """
        def process(value):
            if value is None:
                return None
            return value.read(value.remainingLength())

        return process
 | 
						|
 | 
						|
class MaxDBTypeCompiler(compiler.GenericTypeCompiler):
    """Renders type objects as MaxDB column type strings."""

    def _string_spec(self, string_spec, type_):
        """Build the spec for a character type.

        :param string_spec: base keyword, e.g. ``"CHAR"`` or ``"VARCHAR"``.
        :param type_: type object; ``length`` is required, ``encoding`` is
          optional.
        :return: ``'LONG'`` when no length is set, else
          ``'<string_spec>(<length>)'``; the upper-cased encoding is
          appended when the type carries a non-empty one.
        """
        if type_.length is None:
            spec = 'LONG'
        else:
            spec = '%s(%s)' % (string_spec, type_.length)

        # Types without an ``encoding`` attribute are treated as having
        # none, matching visit_text; a bare getattr() raised AttributeError.
        encoding = getattr(type_, 'encoding', None)
        if encoding:
            spec = ' '.join([spec, encoding.upper()])
        return spec

    def visit_text(self, type_):
        """Return ``'LONG'``, optionally followed by an encoding keyword.

        A non-empty ``encoding`` is appended as given; otherwise
        ``'UNICODE'`` is appended when ``convert_unicode`` is set.
        """
        spec = 'LONG'
        if getattr(type_, 'encoding', None):
            spec = ' '.join((spec, type_.encoding))
        elif type_.convert_unicode:
            spec = ' '.join((spec, 'UNICODE'))

        return spec

    def visit_char(self, type_):
        """Render a CHAR spec via ``_string_spec``."""
        return self._string_spec("CHAR", type_)

    def visit_string(self, type_):
        """Render a VARCHAR spec via ``_string_spec``."""
        return self._string_spec("VARCHAR", type_)

    def visit_large_binary(self, type_):
        """Render large binary types as ``LONG BYTE``."""
        return "LONG BYTE"

    def visit_numeric(self, type_):
        """Render ``FIXED(p, s)``, ``FIXED(p)`` or ``INTEGER``.

        The choice depends on which of ``precision`` and ``scale`` are
        truthy; without a precision the result is ``'INTEGER'``.
        """
        if type_.scale and type_.precision:
            return 'FIXED(%s, %s)' % (type_.precision, type_.scale)
        elif type_.precision:
            return 'FIXED(%s)' % type_.precision
        else:
            return 'INTEGER'

    def visit_BOOLEAN(self, type_):
        """Render the BOOLEAN keyword."""
        return "BOOLEAN"
 | 
						|
        
 | 
						|
# Maps generic ``sqltypes`` classes to the Max* implementations defined in
# this module.  (A second, identical ``sqltypes.Unicode`` entry was removed;
# the resulting mapping is unchanged.)
colspecs = {
    sqltypes.Numeric: MaxNumeric,
    sqltypes.DateTime: MaxTimestamp,
    sqltypes.Date: MaxDate,
    sqltypes.Time: MaxTime,
    sqltypes.String: MaxString,
    sqltypes.Unicode: MaxUnicode,
    sqltypes.LargeBinary: MaxBlob,
    sqltypes.Text: MaxText,
    sqltypes.CHAR: MaxChar,
    sqltypes.TIMESTAMP: MaxTimestamp,
    sqltypes.BLOB: MaxBlob,
    }
 | 
						|
 | 
						|
# Maps lower-case type names to ``sqltypes`` classes; presumably consulted
# during reflection -- confirm against the dialect's reflection code.
# NOTE(review): 'long' was listed twice with the same value; the repeated
# entry is dropped here.  If the second one was meant to be a different type
# name, add it back under the correct key.
ischema_names = {
    'boolean': sqltypes.BOOLEAN,
    'char': sqltypes.CHAR,
    'character': sqltypes.CHAR,
    'date': sqltypes.DATE,
    'fixed': sqltypes.Numeric,
    'float': sqltypes.FLOAT,
    'int': sqltypes.INT,
    'integer': sqltypes.INT,
    'long binary': sqltypes.BLOB,
    'long unicode': sqltypes.Text,
    'long': sqltypes.Text,
    'smallint': sqltypes.SmallInteger,
    'time': sqltypes.Time,
    'timestamp': sqltypes.TIMESTAMP,
    'varchar': sqltypes.VARCHAR,
    }
 | 
						|
 | 
						|
# TODO: migrate this to sapdb.py
 | 
						|
class MaxDBExecutionContext(default.DefaultExecutionContext):
    """Execution context with MaxDB-specific insert and result handling."""

    def post_exec(self):
        """Fill in the autoincrement value after a single-row INSERT.

        When the table has an autoserial column and its value is not
        reliably known, CURRVAL is selected from the column's implicit
        sequence and stored in ``_last_inserted_ids``.  The base class
        ``post_exec`` always runs afterwards.
        """
        # DB-API bug: if there were any functions as values,
        # then do another select and pull CURRVAL from the
        # autoincrement column's implicit sequence... ugh
        compiled = self.compiled
        if compiled.isinsert and not self.executemany:
            table = compiled.statement.table
            position, serial_col = _autoserial_column(table)

            must_fetch = serial_col and (
                not compiled._safeserial or
                not self._last_inserted_ids or
                self._last_inserted_ids[position] in (None, 0))

            if must_fetch:
                table_name = compiled.preparer.format_table(table)
                if table.schema:
                    statement = "SELECT %s.CURRVAL FROM DUAL" % (
                        table_name)
                else:
                    statement = (
                        "SELECT CURRENT_SCHEMA.%s.CURRVAL FROM DUAL" % (
                            table_name))

                fetched = self.cursor.execute(statement)
                serial_value = fetched.fetchone()[0]

                if not self._last_inserted_ids:
                    # This shouldn't ever be > 1?  Right?
                    key_width = len(table.primary_key.columns)
                    self._last_inserted_ids = [None] * key_width
                self._last_inserted_ids[position] = serial_value

        super(MaxDBExecutionContext, self).post_exec()

    def get_result_proxy(self):
        """Return a caching proxy when any result column is a LONG type.

        The second item of each ``cursor.description`` entry is compared
        with the LONG type names; otherwise a plain ``ResultProxy`` is used.
        """
        description = self.cursor.description
        if description is not None:
            long_types = ('Long Binary', 'Long', 'Long Unicode')
            if any(column[1] in long_types for column in description):
                return MaxDBResultProxy(self)
        return engine_base.ResultProxy(self)

    @property
    def rowcount(self):
        """``self._rowcount`` when that attribute exists, else the cursor's."""
        if not hasattr(self, '_rowcount'):
            return self.cursor.rowcount
        return self._rowcount

    def fire_sequence(self, seq):
        """Select NEXTVAL for *seq*; return ``None`` if it is optional."""
        if seq.optional:
            return None
        sequence_name = self.dialect.identifier_preparer.format_sequence(seq)
        return self._execute_scalar(
            "SELECT %s.NEXTVAL FROM DUAL" % (sequence_name))
 | 
						|
 | 
						|
class MaxDBCachedColumnRow(engine_base.RowProxy):
    """A RowProxy that only runs result_processors once per column."""

    def __init__(self, parent, row):
        """Initialise the base proxy and start with an empty value cache."""
        super(MaxDBCachedColumnRow, self).__init__(parent, row)
        # key -> value already obtained from the parent for that key
        self.columns = {}
        self._row = row
        self._parent = parent

    def _get_col(self, key):
        """Return the value for *key*, asking the parent only on a miss."""
        cache = self.columns
        if key in cache:
            return cache[key]
        cache[key] = self._parent._get_col(self._row, key)
        return cache[key]

    def __iter__(self):
        """Yield the column values in positional order."""
        width = len(self._row)
        for position in xrange(width):
            yield self._get_col(position)

    def __repr__(self):
        """Return the repr of the list of column values."""
        return repr(list(self))

    def __eq__(self, other):
        """Compare by identity, then against a tuple of all column values."""
        if other is self:
            return True
        return other == tuple([self._get_col(position)
                               for position in xrange(len(self._row))])

    def __getitem__(self, key):
        """Return one value, or a tuple of values when *key* is a slice."""
        if not isinstance(key, slice):
            return self._get_col(key)
        start, stop, step = key.indices(len(self._row))
        return tuple([self._get_col(position)
                      for position in xrange(start, stop, step)])

    def __getattr__(self, name):
        """Look *name* up as a column; a ``KeyError`` becomes
        ``AttributeError``."""
        try:
            return self._get_col(name)
        except KeyError:
            raise AttributeError(name)
 | 
						|
 | 
						|
 | 
						|
class MaxDBResultProxy(engine_base.ResultProxy):
    """ResultProxy whose rows are ``MaxDBCachedColumnRow`` instances."""

    _process_row = MaxDBCachedColumnRow
 | 
						|
 | 
						|
class MaxDBCompiler(compiler.SQLCompiler):
    """SQL statement compiler for MaxDB."""

    # Generic function name -> the name emitted in its place.
    function_conversion = {
        'CURRENT_DATE': 'DATE',
        'CURRENT_TIME': 'TIME',
        'CURRENT_TIMESTAMP': 'TIMESTAMP',
        }

    # These functions must be written without parens when called with no
    # parameters.  e.g. 'SELECT DATE FROM DUAL' not 'SELECT DATE() FROM DUAL'
    bare_functions = set([
        'CURRENT_SCHEMA', 'DATE', 'FALSE', 'SYSDBA', 'TIME', 'TIMESTAMP',
        'TIMEZONE', 'TRANSACTION', 'TRUE', 'USER', 'UID', 'USERGROUP',
        'UTCDATE', 'UTCDIFF'])

    def visit_mod(self, binary, **kw):
        """Render the modulo operator as a ``mod(left, right)`` call."""
        left = self.process(binary.left)
        right = self.process(binary.right)
        return "mod(%s, %s)" % (left, right)

    def default_from(self):
        """Return the FROM clause used when a SELECT has no table."""
        return ' FROM DUAL'

    def for_update_clause(self, select):
        """Translate ``select.for_update`` into a ``WITH LOCK`` clause.

        ``None`` and other falsy non-string values produce an empty string;
        the strings ``"read"``, ``"ignore"`` and ``"nowait"`` have fixed
        renderings and any other string is upper-cased after ``WITH LOCK``.
        """
        mode = select.for_update
        if mode is True:
            return " WITH LOCK EXCLUSIVE"
        if mode is None:
            return ""
        if mode == "read":
            return " WITH LOCK"
        if mode == "ignore":
            return " WITH LOCK (IGNORE) EXCLUSIVE"
        if mode == "nowait":
            return " WITH LOCK (NOWAIT) EXCLUSIVE"
        if isinstance(mode, basestring):
            return " WITH LOCK %s" % mode.upper()
        if not mode:
            return ""
        return " WITH LOCK EXCLUSIVE"

    def function_argspec(self, fn, **kw):
        """Return the argument list text for *fn*.

        Empty for names in ``bare_functions`` and for calls without
        arguments; otherwise the base compiler's rendering.
        """
        if fn.name.upper() in self.bare_functions or not len(fn.clauses):
            return ""
        return compiler.SQLCompiler.function_argspec(self, fn, **kw)

    def visit_function(self, fn, **kw):
        """Compile *fn*, renaming it first if ``function_conversion`` has
        an entry for its upper-cased name."""
        replacement = self.function_conversion.get(fn.name.upper(), None)
        if replacement:
            # Rename a clone so the caller's function object is untouched.
            fn = fn._clone()
            fn.name = replacement
        return super(MaxDBCompiler, self).visit_function(fn, **kw)

    def visit_cast(self, cast, **kwargs):
        """Render a cast as ``NUM(...)``, ``CHR(...)`` or the bare clause."""
        # MaxDB only supports casts * to NUMERIC, * to VARCHAR or
        # date/time to VARCHAR.  Casts of LONGs will fail.
        target = cast.type
        if isinstance(target, (sqltypes.Integer, sqltypes.Numeric)):
            return "NUM(%s)" % self.process(cast.clause)
        if isinstance(target, sqltypes.String):
            return "CHR(%s)" % self.process(cast.clause)
        return self.process(cast.clause)

    def visit_sequence(self, sequence):
        """Return ``<sequence>.NEXTVAL``, or ``None`` if it is optional."""
        if sequence.optional:
            return None
        formatted = self.dialect.identifier_preparer.format_sequence(sequence)
        return formatted + ".NEXTVAL"

    class ColumnSnagger(visitors.ClauseVisitor):
        """Counts visited columns and remembers the most recent one."""

        def __init__(self):
            self.count = 0
            self.column = None

        def visit_column(self, column):
            self.column = column
            self.count += 1

    def _find_labeled_columns(self, columns, use_labels=False):
        """Map column text to label name for single-column expressions.

        Plain strings are skipped, as are expressions in which the visitor
        does not count exactly one column.  Explicit labels are always
        recorded; other expressions contribute their ``_label`` only when
        *use_labels* is true.
        """
        found = {}
        for candidate in columns:
            if isinstance(candidate, basestring):
                continue
            visitor = self.ColumnSnagger()
            visitor.traverse(candidate)
            if visitor.count != 1:
                continue
            if isinstance(candidate, sql_expr._Label):
                found[unicode(visitor.column)] = candidate.name
            elif use_labels:
                found[unicode(visitor.column)] = candidate._label

        return found

    def order_by_clause(self, select, **kw):
        """Render the ORDER BY clause for *select*.

        Raises ``exc.InvalidRequestError`` for an ordered subquery that
        also has a limit; an ordered subquery without one loses its
        ordering.
        """
        ordering = self.process(select._order_by_clause, **kw)

        # ORDER BY clauses in DISTINCT queries must reference aliased
        # inner columns by alias name, not true column name.
        if ordering and getattr(select, '_distinct', False):
            aliases = self._find_labeled_columns(select.inner_columns,
                                                 select.use_labels)
            for true_name, alias in aliases.items():
                pattern = re.compile(r'(^| )(%s)(,| |$)' %
                                     re.escape(true_name))
                ordering = pattern.sub((r'\1%s\3' % alias), ordering)

        if not ordering:
            return ""

        # No ORDER BY in subqueries.
        if not self.is_subquery():
            return " ORDER BY " + ordering

        # It's safe to simply drop the ORDER BY if there is no
        # LIMIT.  Right?  Other dialects seem to get away with
        # dropping order.
        if select._limit:
            raise exc.InvalidRequestError(
                "MaxDB does not support ORDER BY in subqueries")
        return ""

    def get_select_precolumns(self, select):
        """Return the text placed between SELECT and the column list.

        Raises ``exc.InvalidRequestError`` for a limited subquery that
        also has an offset.
        """
        # Convert a subquery's LIMIT to TOP
        prefix = 'DISTINCT ' if select._distinct else ''
        if self.is_subquery() and select._limit:
            if select._offset:
                raise exc.InvalidRequestError(
                    'MaxDB does not support LIMIT with an offset.')
            prefix += 'TOP %s ' % select._limit
        return prefix

    def limit_clause(self, select):
        """Render the LIMIT clause; empty inside subqueries.

        Raises ``exc.InvalidRequestError`` when a non-subquery select has
        an offset.
        """
        # The docs say offsets are supported with LIMIT.  But they're not.
        # TODO: maybe emulate by adding a ROWNO/ROWNUM predicate?
        if self.is_subquery():
            # sub queries need TOP
            return ''
        if select._offset:
            raise exc.InvalidRequestError(
                'MaxDB does not support LIMIT with an offset.')
        return ' \n LIMIT %s' % (select._limit,)

    def visit_insert(self, insert):
        """Compile an INSERT statement.

        Sets ``isinsert`` and records in ``_safeserial`` whether none of
        the insert's parameter values is a SQL function.
        """
        self.isinsert = True
        self._safeserial = True

        colparams = self._get_colparams(insert)
        supplied = (insert.parameters or {}).itervalues()
        if any(isinstance(value, sql_expr.Function) for value in supplied):
            self._safeserial = False

        table_name = self.preparer.format_table(insert.table)
        column_list = ', '.join([self.preparer.format_column(pair[0])
                                 for pair in colparams])
        value_list = ', '.join([pair[1] for pair in colparams])

        return ('INSERT INTO ' + table_name +
                ' (' + column_list +
                ') VALUES (' + value_list +
                ')')
 | 
						|
 | 
						|
 | 
						|
class MaxDBIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer carrying the MaxDB reserved-word list.

    Also provides helpers that convert names between the all-upper-case
    and all-lower-case spellings when no quoting is required.
    """

    reserved_words = set([
        'abs', 'absolute', 'acos', 'adddate', 'addtime', 'all', 'alpha',
        'alter', 'any', 'ascii', 'asin', 'atan', 'atan2', 'avg', 'binary',
        'bit', 'boolean', 'byte', 'case', 'ceil', 'ceiling', 'char',
        'character', 'check', 'chr', 'column', 'concat', 'constraint', 'cos',
        'cosh', 'cot', 'count', 'cross', 'curdate', 'current', 'curtime',
        'database', 'date', 'datediff', 'day', 'dayname', 'dayofmonth',
        'dayofweek', 'dayofyear', 'dec', 'decimal', 'decode', 'default',
        'degrees', 'delete', 'digits', 'distinct', 'double', 'except',
        'exists', 'exp', 'expand', 'first', 'fixed', 'float', 'floor', 'for',
        'from', 'full', 'get_objectname', 'get_schema', 'graphic', 'greatest',
        'group', 'having', 'hex', 'hextoraw', 'hour', 'ifnull', 'ignore',
        'index', 'initcap', 'inner', 'insert', 'int', 'integer', 'internal',
        'intersect', 'into', 'join', 'key', 'last', 'lcase', 'least', 'left',
        'length', 'lfill', 'list', 'ln', 'locate', 'log', 'log10', 'long',
        'longfile', 'lower', 'lpad', 'ltrim', 'makedate', 'maketime',
        'mapchar', 'max', 'mbcs', 'microsecond', 'min', 'minute', 'mod',
        'month', 'monthname', 'natural', 'nchar', 'next', 'no', 'noround',
        'not', 'now', 'null', 'num', 'numeric', 'object', 'of', 'on',
        'order', 'packed', 'pi', 'power', 'prev', 'primary', 'radians',
        'real', 'reject', 'relative', 'replace', 'rfill', 'right', 'round',
        'rowid', 'rowno', 'rpad', 'rtrim', 'second', 'select', 'selupd',
        'serial', 'set', 'show', 'sign', 'sin', 'sinh', 'smallint', 'some',
        'soundex', 'space', 'sqrt', 'stamp', 'statistics', 'stddev',
        'subdate', 'substr', 'substring', 'subtime', 'sum', 'sysdba',
        'table', 'tan', 'tanh', 'time', 'timediff', 'timestamp', 'timezone',
        'to', 'toidentifier', 'transaction', 'translate', 'trim', 'trunc',
        'truncate', 'ucase', 'uid', 'unicode', 'union', 'update', 'upper',
        'user', 'usergroup', 'using', 'utcdate', 'utcdiff', 'value', 'values',
        'varchar', 'vargraphic', 'variance', 'week', 'weekofyear', 'when',
        'where', 'with', 'year', 'zoned' ])

    def _normalize_name(self, name):
        """Return ``name`` lower-cased when it is all upper case and the
        lower-cased form needs no quoting; otherwise return it unchanged.

        ``None`` is passed through.
        """
        if name is not None and name.isupper():
            lowered = name.lower()
            if not self._requires_quotes(lowered):
                return lowered
        return name

    def _denormalize_name(self, name):
        """Return ``name`` upper-cased when it is all lower case and needs
        no quoting; otherwise return it unchanged.

        ``None`` is passed through.
        """
        if name is None:
            return None
        # Mixed-case names and names that must be quoted keep their spelling.
        if not name.islower() or self._requires_quotes(name):
            return name
        return name.upper()

    def _maybe_quote_identifier(self, name):
        """Quote ``name`` only if ``_requires_quotes`` says it is needed."""
        if not self._requires_quotes(name):
            return name
        return self.quote_identifier(name)
 | 
						|
 | 
						|
 | 
						|
class MaxDBDDLCompiler(compiler.DDLCompiler):
    """DDL compiler producing MaxDB column specifications and sequences."""

    def get_column_specification(self, column, **kw):
        """Return the DDL fragment describing ``column``.

        The fragment is the formatted column name and compiled type,
        followed by ``NOT NULL`` for non-nullable columns and at most one
        ``DEFAULT`` clause.
        """
        colspec = [self.preparer.format_column(column),
                   self.dialect.type_compiler.process(column.type)]

        if not column.nullable:
            colspec.append('NOT NULL')

        default = column.default
        default_str = self.get_column_default_string(column)

        # No DDL default for columns specified with non-optional sequence-
        # this defaulting behavior is entirely client-side. (And as a
        # consequence, non-reflectable.)
        if (default and isinstance(default, schema.Sequence) and
            not default.optional):
            pass
        # Regular default
        elif default_str is not None:
            colspec.append('DEFAULT %s' % default_str)
        # Assign DEFAULT SERIAL heuristically
        elif column.primary_key and column.autoincrement:
            # For SERIAL on a non-primary key member, use
            # DefaultClause(text('SERIAL'))
            candidates = [c for c in column.table.primary_key.columns
                          if (c.autoincrement and
                              (isinstance(c.type, sqltypes.Integer) or
                               (isinstance(c.type, MaxNumeric) and
                                c.type.precision)) and
                              not c.foreign_keys)]
            # Only the first eligible primary key column becomes SERIAL.
            if candidates and column is candidates[0]:
                colspec.append('DEFAULT SERIAL')
        return ' '.join(colspec)

    def get_column_default_string(self, column):
        """Return the text for the column's ``DEFAULT`` clause, or None.

        Only a ``schema.DefaultClause`` server default produces a value.
        A string argument is rendered bare for Integer columns and inside
        single quotes otherwise; any other argument is compiled.
        """
        server_default = column.server_default
        if not isinstance(server_default, schema.DefaultClause):
            return None

        # Read the argument from the server default that was just
        # type-checked, not from the client-side ``column.default`` (which
        # may be None or an unrelated Python-side default).
        arg = server_default.arg
        if isinstance(arg, basestring):
            if isinstance(column.type, sqltypes.Integer):
                return str(arg)
            else:
                return "'%s'" % arg
        else:
            # NOTE(review): ``_compile`` is not defined in this class;
            # presumably inherited -- confirm it exists on the base compiler.
            return unicode(self._compile(arg, None))

    def visit_create_sequence(self, create):
        """Creates a SEQUENCE.

        TODO: move to module doc?

        Returns the ``CREATE SEQUENCE`` statement text, or None when the
        sequence is optional or when ``checkfirst`` is set and the sequence
        already exists.

        start
          With an integer value, set the START WITH option.

        increment
          An integer value to increment by.  Default is the database default.

        maxdb_minvalue
        maxdb_maxvalue
          With an integer value, sets the corresponding sequence option.

        maxdb_no_minvalue
        maxdb_no_maxvalue
          Defaults to False.  If true, sets the corresponding sequence option.

        maxdb_cycle
          Defaults to False.  If true, sets the CYCLE option.

        maxdb_cache
          With an integer value, sets the CACHE option.

        maxdb_no_cache
          Defaults to False.  If true, sets NOCACHE.
        """
        sequence = create.element

        if sequence.optional:
            return None
        if (self.checkfirst and
            self.dialect.has_sequence(self.connection, sequence.name)):
            return None

        ddl = ['CREATE SEQUENCE',
               self.preparer.format_sequence(sequence)]

        # Honour the configured increment and leave the Sequence object
        # untouched; with no increment given the database default applies.
        if sequence.increment is not None:
            ddl.extend(('INCREMENT BY', str(sequence.increment)))

        if sequence.start is not None:
            ddl.extend(('START WITH', str(sequence.start)))

        # Collect the 'maxdb_'-prefixed keyword options, prefix stripped.
        opts = dict([(pair[0][6:].lower(), pair[1])
                     for pair in sequence.kwargs.items()
                     if pair[0].startswith('maxdb_')])

        if 'maxvalue' in opts:
            ddl.extend(('MAXVALUE', str(opts['maxvalue'])))
        elif opts.get('no_maxvalue', False):
            ddl.append('NOMAXVALUE')
        if 'minvalue' in opts:
            ddl.extend(('MINVALUE', str(opts['minvalue'])))
        elif opts.get('no_minvalue', False):
            ddl.append('NOMINVALUE')

        if opts.get('cycle', False):
            ddl.append('CYCLE')

        if 'cache' in opts:
            ddl.extend(('CACHE', str(opts['cache'])))
        elif opts.get('no_cache', False):
            ddl.append('NOCACHE')

        return ' '.join(ddl)
 | 
						|
 | 
						|
 | 
						|
class MaxDBDialect(default.DefaultDialect):
    """Dialect definition for MaxDB.

    Wires together the MaxDB preparer, compilers and execution context and
    implements table, sequence and column reflection against the MaxDB
    system tables.
    """

    name = 'maxdb'
    supports_alter = True
    supports_unicode_statements = True
    max_identifier_length = 32
    supports_sane_rowcount = True
    supports_sane_multi_rowcount = False

    preparer = MaxDBIdentifierPreparer
    statement_compiler = MaxDBCompiler
    ddl_compiler = MaxDBDDLCompiler
    execution_ctx_cls = MaxDBExecutionContext

    ported_sqla_06 = False

    colspecs = colspecs
    ischema_names = ischema_names

    # MaxDB-specific
    datetimeformat = 'internal'

    def __init__(self, _raise_known_sql_errors=False, **kw):
        """Initialise the dialect and build the DB-API type map.

        ``_raise_known_sql_errors`` is stored as ``self._raise_known``.
        The type map is left empty when no DB-API module is available.
        """
        super(MaxDBDialect, self).__init__(**kw)
        self._raise_known = _raise_known_sql_errors

        if self.dbapi is None:
            self.dbapi_type_map = {}
        else:
            self.dbapi_type_map = {
                'Long Binary': MaxBlob(),
                'Long byte_t': MaxBlob(),
                'Long Unicode': MaxText(),
                'Timestamp': MaxTimestamp(),
                'Date': MaxDate(),
                'Time': MaxTime(),
                datetime.datetime: MaxTimestamp(),
                datetime.date: MaxDate(),
                datetime.time: MaxTime(),
            }

    def do_execute(self, cursor, statement, parameters, context=None):
        """Execute ``statement`` on ``cursor``.

        When the cursor's ``execute`` returns an int and a context was
        supplied, that int is stored as ``context._rowcount``.
        """
        res = cursor.execute(statement, parameters)
        if isinstance(res, int) and context is not None:
            context._rowcount = res

    def do_release_savepoint(self, connection, name):
        """Do nothing; RELEASE SAVEPOINT is deliberately not issued."""
        # Does MaxDB truly support RELEASE SAVEPOINT <id>?  All my attempts
        # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS
        # BEGIN SQLSTATE: I7065"
        # Note that ROLLBACK TO works fine.  In theory, a RELEASE should
        # just free up some transactional resources early, before the overall
        # COMMIT/ROLLBACK so omitting it should be relatively ok.
        pass

    def _get_default_schema_name(self, connection):
        """Return the normalized name of the connection's current schema."""
        return self.identifier_preparer._normalize_name(
                connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar())

    def has_table(self, connection, table_name, schema=None):
        """Return True if ``table_name`` exists in ``schema``.

        The current schema is searched when ``schema`` is None.
        """
        denormalize = self.identifier_preparer._denormalize_name
        bind = [denormalize(table_name)]
        if schema is None:
            sql = ("SELECT tablename FROM TABLES "
                   "WHERE TABLES.TABLENAME=? AND"
                   "  TABLES.SCHEMANAME=CURRENT_SCHEMA ")
        else:
            sql = ("SELECT tablename FROM TABLES "
                   "WHERE TABLES.TABLENAME = ? AND"
                   "  TABLES.SCHEMANAME=? ")
            bind.append(denormalize(schema))

        rp = connection.execute(sql, bind)
        return bool(rp.first())

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        """Return the normalized table names found in ``schema``.

        The current schema is listed when ``schema`` is None.
        """
        if schema is None:
            sql = (" SELECT TABLENAME FROM TABLES WHERE "
                   " SCHEMANAME=CURRENT_SCHEMA ")
            rs = connection.execute(sql)
        else:
            sql = (" SELECT TABLENAME FROM TABLES WHERE "
                   " SCHEMANAME=? ")
            matchname = self.identifier_preparer._denormalize_name(schema)
            rs = connection.execute(sql, matchname)
        normalize = self.identifier_preparer._normalize_name
        return [normalize(row[0]) for row in rs]

    def reflecttable(self, connection, table, include_columns):
        """Load columns and foreign keys of ``table`` from the database.

        Columns are appended to ``table`` in ``POS`` order.  When
        ``include_columns`` is non-empty, only the named columns are
        reflected, and only foreign keys whose columns were all reflected
        are added.  Referenced tables missing from the table's metadata are
        autoloaded.

        Raises ``exc.NoSuchTableError`` if the column query returns no rows.
        """
        denormalize = self.identifier_preparer._denormalize_name
        normalize = self.identifier_preparer._normalize_name

        st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, '
              '  NULLABLE, "DEFAULT", DEFAULTFUNCTION '
              'FROM COLUMNS '
              'WHERE TABLENAME=? AND SCHEMANAME=%s '
              'ORDER BY POS')

        fk = ('SELECT COLUMNNAME, FKEYNAME, '
              '  REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, '
              '  (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA '
              '   THEN 1 ELSE 0 END) AS in_schema '
              'FROM FOREIGNKEYCOLUMNS '
              'WHERE TABLENAME=? AND SCHEMANAME=%s '
              'ORDER BY FKEYNAME ')

        # The schema predicate is either the CURRENT_SCHEMA keyword or a
        # second bind parameter.
        params = [denormalize(table.name)]
        if not table.schema:
            st = st % 'CURRENT_SCHEMA'
            fk = fk % 'CURRENT_SCHEMA'
        else:
            st = st % '?'
            fk = fk % '?'
            params.append(denormalize(table.schema))

        rows = connection.execute(st, params).fetchall()
        if not rows:
            raise exc.NoSuchTableError(table.fullname)

        include_columns = set(include_columns or [])

        for row in rows:
            (name, mode, col_type, encoding, length, scale,
             nullable, constant_def, func_def) = row

            name = normalize(name)

            if include_columns and name not in include_columns:
                continue

            type_args, type_kw = [], {}
            if col_type == 'FIXED':
                type_args = length, scale
                # Convert FIXED(10) DEFAULT SERIAL to our Integer
                if (scale == 0 and
                    func_def is not None and func_def.startswith('SERIAL')):
                    col_type = 'INTEGER'
                    type_args = length,
            # Equality, not substring membership: only FLOAT takes this path.
            elif col_type == 'FLOAT':
                type_args = length,
            elif col_type in ('CHAR', 'VARCHAR'):
                type_args = length,
                type_kw['encoding'] = encoding
            elif col_type == 'LONG':
                type_kw['encoding'] = encoding

            try:
                type_cls = ischema_names[col_type.lower()]
                type_instance = type_cls(*type_args, **type_kw)
            except KeyError:
                util.warn("Did not recognize type '%s' of column '%s'" %
                          (col_type, name))
                type_instance = sqltypes.NullType

            col_kw = {'autoincrement': False}
            col_kw['nullable'] = (nullable == 'YES')
            col_kw['primary_key'] = (mode == 'KEY')

            if func_def is not None:
                if func_def.startswith('SERIAL'):
                    if col_kw['primary_key']:
                        # No special default- let the standard autoincrement
                        # support handle SERIAL pk columns.
                        col_kw['autoincrement'] = True
                    else:
                        # strip current numbering
                        col_kw['server_default'] = schema.DefaultClause(
                            sql.text('SERIAL'))
                        col_kw['autoincrement'] = True
                else:
                    col_kw['server_default'] = schema.DefaultClause(
                        sql.text(func_def))
            elif constant_def is not None:
                col_kw['server_default'] = schema.DefaultClause(sql.text(
                    "'%s'" % constant_def.replace("'", "''")))

            table.append_column(schema.Column(name, type_instance, **col_kw))

        quote = self.identifier_preparer._maybe_quote_identifier

        # Rows arrive ordered by FKEYNAME, so groupby yields one group per
        # foreign key constraint.
        fk_sets = itertools.groupby(connection.execute(fk, params),
                                    lambda row: row.FKEYNAME)
        for fkeyname, fkey in fk_sets:
            fkey = list(fkey)
            if include_columns:
                # Compare normalized names (include_columns holds normalized
                # names) and keep the constraint whenever every one of its
                # columns was reflected.
                key_cols = set([normalize(r.COLUMNNAME) for r in fkey])
                if not key_cols.issubset(include_columns):
                    continue

            columns, referants = [], []

            for row in fkey:
                columns.append(normalize(row.COLUMNNAME))
                if table.schema or not row.in_schema:
                    referants.append('.'.join(
                        [quote(normalize(row[c]))
                         for c in ('REFSCHEMANAME', 'REFTABLENAME',
                                   'REFCOLUMNNAME')]))
                else:
                    referants.append('.'.join(
                        [quote(normalize(row[c]))
                         for c in ('REFTABLENAME', 'REFCOLUMNNAME')]))

            constraint_kw = {'name': fkeyname.lower()}
            if fkey[0].RULE is not None:
                rule = fkey[0].RULE
                if rule.startswith('DELETE '):
                    rule = rule[7:]
                constraint_kw['ondelete'] = rule

            # Decide from the constraint's first row rather than from a
            # variable left over from the loop above.
            table_kw = {}
            if table.schema or not fkey[0].in_schema:
                table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME)

            ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME),
                                            table_kw.get('schema'))
            if ref_key not in table.metadata.tables:
                schema.Table(normalize(fkey[0].REFTABLENAME),
                             table.metadata,
                             autoload=True, autoload_with=connection,
                             **table_kw)

            constraint = schema.ForeignKeyConstraint(columns, referants,
                                                     link_to_name=True,
                                                     **constraint_kw)
            table.append_constraint(constraint)

    def has_sequence(self, connection, name):
        """Return True if a sequence called ``name`` exists."""
        # [ticket:726] makes this schema-aware.
        denormalize = self.identifier_preparer._denormalize_name
        sql = ("SELECT sequence_name FROM SEQUENCES "
               "WHERE SEQUENCE_NAME=? ")

        rp = connection.execute(sql, denormalize(name))
        return bool(rp.first())
 | 
						|
 | 
						|
 | 
						|
def _autoserial_column(table):
    """Finds the effective DEFAULT SERIAL column of a Table, if any.

    Walks the primary key columns in order and returns ``(index, column)``
    for the first autoincrementing Integer/Numeric column that either has
    an optional Sequence as its default, or has no client-side default or
    no ``DefaultClause`` server default.  Returns ``(None, None)`` when no
    column qualifies.
    """
    serial_types = (sqltypes.Integer, sqltypes.Numeric)

    for position, candidate in enumerate(table.primary_key.columns):
        if not (isinstance(candidate.type, serial_types) and
                candidate.autoincrement):
            continue

        client_default = candidate.default
        if isinstance(client_default, schema.Sequence):
            # A non-optional sequence disqualifies this column only.
            if client_default.optional:
                return position, candidate
            continue

        if (client_default is None or
            not isinstance(candidate.server_default, schema.DefaultClause)):
            return position, candidate

    return None, None
 | 
						|
 |