morro
This commit is contained in:
418
sqlalchemy/dialects/access/base.py
Normal file
418
sqlalchemy/dialects/access/base.py
Normal file
@@ -0,0 +1,418 @@
|
||||
# access.py
|
||||
# Copyright (C) 2007 Paul Johnston, paj@pajhome.org.uk
|
||||
# Portions derived from jet2sql.py by Matt Keranen, mksql@yahoo.com
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""
|
||||
Support for the Microsoft Access database.
|
||||
|
||||
This dialect is *not* ported to SQLAlchemy 0.6.
|
||||
|
||||
This dialect is *not* tested on SQLAlchemy 0.6.
|
||||
|
||||
|
||||
"""
|
||||
from sqlalchemy import sql, schema, types, exc, pool
|
||||
from sqlalchemy.sql import compiler, expression
|
||||
from sqlalchemy.engine import default, base, reflection
|
||||
from sqlalchemy import processors
|
||||
|
||||
class AcNumeric(types.Numeric):
|
||||
def get_col_spec(self):
|
||||
return "NUMERIC"
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
return processors.to_str
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
return None
|
||||
|
||||
class AcFloat(types.Float):
|
||||
def get_col_spec(self):
|
||||
return "FLOAT"
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
"""By converting to string, we can use Decimal types round-trip."""
|
||||
return processors.to_str
|
||||
|
||||
class AcInteger(types.Integer):
|
||||
def get_col_spec(self):
|
||||
return "INTEGER"
|
||||
|
||||
class AcTinyInteger(types.Integer):
|
||||
def get_col_spec(self):
|
||||
return "TINYINT"
|
||||
|
||||
class AcSmallInteger(types.SmallInteger):
|
||||
def get_col_spec(self):
|
||||
return "SMALLINT"
|
||||
|
||||
class AcDateTime(types.DateTime):
|
||||
def __init__(self, *a, **kw):
|
||||
super(AcDateTime, self).__init__(False)
|
||||
|
||||
def get_col_spec(self):
|
||||
return "DATETIME"
|
||||
|
||||
class AcDate(types.Date):
|
||||
def __init__(self, *a, **kw):
|
||||
super(AcDate, self).__init__(False)
|
||||
|
||||
def get_col_spec(self):
|
||||
return "DATETIME"
|
||||
|
||||
class AcText(types.Text):
|
||||
def get_col_spec(self):
|
||||
return "MEMO"
|
||||
|
||||
class AcString(types.String):
|
||||
def get_col_spec(self):
|
||||
return "TEXT" + (self.length and ("(%d)" % self.length) or "")
|
||||
|
||||
class AcUnicode(types.Unicode):
|
||||
def get_col_spec(self):
|
||||
return "TEXT" + (self.length and ("(%d)" % self.length) or "")
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
return None
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
return None
|
||||
|
||||
class AcChar(types.CHAR):
|
||||
def get_col_spec(self):
|
||||
return "TEXT" + (self.length and ("(%d)" % self.length) or "")
|
||||
|
||||
class AcBinary(types.LargeBinary):
|
||||
def get_col_spec(self):
|
||||
return "BINARY"
|
||||
|
||||
class AcBoolean(types.Boolean):
|
||||
def get_col_spec(self):
|
||||
return "YESNO"
|
||||
|
||||
class AcTimeStamp(types.TIMESTAMP):
|
||||
def get_col_spec(self):
|
||||
return "TIMESTAMP"
|
||||
|
||||
class AccessExecutionContext(default.DefaultExecutionContext):
|
||||
def _has_implicit_sequence(self, column):
|
||||
if column.primary_key and column.autoincrement:
|
||||
if isinstance(column.type, types.Integer) and not column.foreign_keys:
|
||||
if column.default is None or (isinstance(column.default, schema.Sequence) and \
|
||||
column.default.optional):
|
||||
return True
|
||||
return False
|
||||
|
||||
def post_exec(self):
|
||||
"""If we inserted into a row with a COUNTER column, fetch the ID"""
|
||||
|
||||
if self.compiled.isinsert:
|
||||
tbl = self.compiled.statement.table
|
||||
if not hasattr(tbl, 'has_sequence'):
|
||||
tbl.has_sequence = None
|
||||
for column in tbl.c:
|
||||
if getattr(column, 'sequence', False) or self._has_implicit_sequence(column):
|
||||
tbl.has_sequence = column
|
||||
break
|
||||
|
||||
if bool(tbl.has_sequence):
|
||||
# TBD: for some reason _last_inserted_ids doesn't exist here
|
||||
# (but it does at corresponding point in mssql???)
|
||||
#if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
|
||||
self.cursor.execute("SELECT @@identity AS lastrowid")
|
||||
row = self.cursor.fetchone()
|
||||
self._last_inserted_ids = [int(row[0])] #+ self._last_inserted_ids[1:]
|
||||
# print "LAST ROW ID", self._last_inserted_ids
|
||||
|
||||
super(AccessExecutionContext, self).post_exec()
|
||||
|
||||
|
||||
const, daoEngine = None, None
|
||||
class AccessDialect(default.DefaultDialect):
|
||||
colspecs = {
|
||||
types.Unicode : AcUnicode,
|
||||
types.Integer : AcInteger,
|
||||
types.SmallInteger: AcSmallInteger,
|
||||
types.Numeric : AcNumeric,
|
||||
types.Float : AcFloat,
|
||||
types.DateTime : AcDateTime,
|
||||
types.Date : AcDate,
|
||||
types.String : AcString,
|
||||
types.LargeBinary : AcBinary,
|
||||
types.Boolean : AcBoolean,
|
||||
types.Text : AcText,
|
||||
types.CHAR: AcChar,
|
||||
types.TIMESTAMP: AcTimeStamp,
|
||||
}
|
||||
name = 'access'
|
||||
supports_sane_rowcount = False
|
||||
supports_sane_multi_rowcount = False
|
||||
|
||||
ported_sqla_06 = False
|
||||
|
||||
def type_descriptor(self, typeobj):
|
||||
newobj = types.adapt_type(typeobj, self.colspecs)
|
||||
return newobj
|
||||
|
||||
def __init__(self, **params):
|
||||
super(AccessDialect, self).__init__(**params)
|
||||
self.text_as_varchar = False
|
||||
self._dtbs = None
|
||||
|
||||
def dbapi(cls):
|
||||
import win32com.client, pythoncom
|
||||
|
||||
global const, daoEngine
|
||||
if const is None:
|
||||
const = win32com.client.constants
|
||||
for suffix in (".36", ".35", ".30"):
|
||||
try:
|
||||
daoEngine = win32com.client.gencache.EnsureDispatch("DAO.DBEngine" + suffix)
|
||||
break
|
||||
except pythoncom.com_error:
|
||||
pass
|
||||
else:
|
||||
raise exc.InvalidRequestError("Can't find a DB engine. Check http://support.microsoft.com/kb/239114 for details.")
|
||||
|
||||
import pyodbc as module
|
||||
return module
|
||||
dbapi = classmethod(dbapi)
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args()
|
||||
connectors = ["Driver={Microsoft Access Driver (*.mdb)}"]
|
||||
connectors.append("Dbq=%s" % opts["database"])
|
||||
user = opts.get("username", None)
|
||||
if user:
|
||||
connectors.append("UID=%s" % user)
|
||||
connectors.append("PWD=%s" % opts.get("password", ""))
|
||||
return [[";".join(connectors)], {}]
|
||||
|
||||
def last_inserted_ids(self):
|
||||
return self.context.last_inserted_ids
|
||||
|
||||
def do_execute(self, cursor, statement, params, **kwargs):
|
||||
if params == {}:
|
||||
params = ()
|
||||
super(AccessDialect, self).do_execute(cursor, statement, params, **kwargs)
|
||||
|
||||
def _execute(self, c, statement, parameters):
|
||||
try:
|
||||
if parameters == {}:
|
||||
parameters = ()
|
||||
c.execute(statement, parameters)
|
||||
self.context.rowcount = c.rowcount
|
||||
except Exception, e:
|
||||
raise exc.DBAPIError.instance(statement, parameters, e)
|
||||
|
||||
def has_table(self, connection, tablename, schema=None):
|
||||
# This approach seems to be more reliable that using DAO
|
||||
try:
|
||||
connection.execute('select top 1 * from [%s]' % tablename)
|
||||
return True
|
||||
except Exception, e:
|
||||
return False
|
||||
|
||||
def reflecttable(self, connection, table, include_columns):
|
||||
# This is defined in the function, as it relies on win32com constants,
|
||||
# that aren't imported until dbapi method is called
|
||||
if not hasattr(self, 'ischema_names'):
|
||||
self.ischema_names = {
|
||||
const.dbByte: AcBinary,
|
||||
const.dbInteger: AcInteger,
|
||||
const.dbLong: AcInteger,
|
||||
const.dbSingle: AcFloat,
|
||||
const.dbDouble: AcFloat,
|
||||
const.dbDate: AcDateTime,
|
||||
const.dbLongBinary: AcBinary,
|
||||
const.dbMemo: AcText,
|
||||
const.dbBoolean: AcBoolean,
|
||||
const.dbText: AcUnicode, # All Access strings are unicode
|
||||
const.dbCurrency: AcNumeric,
|
||||
}
|
||||
|
||||
# A fresh DAO connection is opened for each reflection
|
||||
# This is necessary, so we get the latest updates
|
||||
dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
|
||||
|
||||
try:
|
||||
for tbl in dtbs.TableDefs:
|
||||
if tbl.Name.lower() == table.name.lower():
|
||||
break
|
||||
else:
|
||||
raise exc.NoSuchTableError(table.name)
|
||||
|
||||
for col in tbl.Fields:
|
||||
coltype = self.ischema_names[col.Type]
|
||||
if col.Type == const.dbText:
|
||||
coltype = coltype(col.Size)
|
||||
|
||||
colargs = \
|
||||
{
|
||||
'nullable': not(col.Required or col.Attributes & const.dbAutoIncrField),
|
||||
}
|
||||
default = col.DefaultValue
|
||||
|
||||
if col.Attributes & const.dbAutoIncrField:
|
||||
colargs['default'] = schema.Sequence(col.Name + '_seq')
|
||||
elif default:
|
||||
if col.Type == const.dbBoolean:
|
||||
default = default == 'Yes' and '1' or '0'
|
||||
colargs['server_default'] = schema.DefaultClause(sql.text(default))
|
||||
|
||||
table.append_column(schema.Column(col.Name, coltype, **colargs))
|
||||
|
||||
# TBD: check constraints
|
||||
|
||||
# Find primary key columns first
|
||||
for idx in tbl.Indexes:
|
||||
if idx.Primary:
|
||||
for col in idx.Fields:
|
||||
thecol = table.c[col.Name]
|
||||
table.primary_key.add(thecol)
|
||||
if isinstance(thecol.type, AcInteger) and \
|
||||
not (thecol.default and isinstance(thecol.default.arg, schema.Sequence)):
|
||||
thecol.autoincrement = False
|
||||
|
||||
# Then add other indexes
|
||||
for idx in tbl.Indexes:
|
||||
if not idx.Primary:
|
||||
if len(idx.Fields) == 1:
|
||||
col = table.c[idx.Fields[0].Name]
|
||||
if not col.primary_key:
|
||||
col.index = True
|
||||
col.unique = idx.Unique
|
||||
else:
|
||||
pass # TBD: multi-column indexes
|
||||
|
||||
|
||||
for fk in dtbs.Relations:
|
||||
if fk.ForeignTable != table.name:
|
||||
continue
|
||||
scols = [c.ForeignName for c in fk.Fields]
|
||||
rcols = ['%s.%s' % (fk.Table, c.Name) for c in fk.Fields]
|
||||
table.append_constraint(schema.ForeignKeyConstraint(scols, rcols, link_to_name=True))
|
||||
|
||||
finally:
|
||||
dtbs.Close()
|
||||
|
||||
@reflection.cache
|
||||
def get_table_names(self, connection, schema=None, **kw):
|
||||
# A fresh DAO connection is opened for each reflection
|
||||
# This is necessary, so we get the latest updates
|
||||
dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
|
||||
|
||||
names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] != "~TMP"]
|
||||
dtbs.Close()
|
||||
return names
|
||||
|
||||
|
||||
class AccessCompiler(compiler.SQLCompiler):
|
||||
extract_map = compiler.SQLCompiler.extract_map.copy()
|
||||
extract_map.update ({
|
||||
'month': 'm',
|
||||
'day': 'd',
|
||||
'year': 'yyyy',
|
||||
'second': 's',
|
||||
'hour': 'h',
|
||||
'doy': 'y',
|
||||
'minute': 'n',
|
||||
'quarter': 'q',
|
||||
'dow': 'w',
|
||||
'week': 'ww'
|
||||
})
|
||||
|
||||
def visit_select_precolumns(self, select):
|
||||
"""Access puts TOP, it's version of LIMIT here """
|
||||
s = select.distinct and "DISTINCT " or ""
|
||||
if select.limit:
|
||||
s += "TOP %s " % (select.limit)
|
||||
if select.offset:
|
||||
raise exc.InvalidRequestError('Access does not support LIMIT with an offset')
|
||||
return s
|
||||
|
||||
def limit_clause(self, select):
|
||||
"""Limit in access is after the select keyword"""
|
||||
return ""
|
||||
|
||||
def binary_operator_string(self, binary):
|
||||
"""Access uses "mod" instead of "%" """
|
||||
return binary.operator == '%' and 'mod' or binary.operator
|
||||
|
||||
def label_select_column(self, select, column, asfrom):
|
||||
if isinstance(column, expression.Function):
|
||||
return column.label()
|
||||
else:
|
||||
return super(AccessCompiler, self).label_select_column(select, column, asfrom)
|
||||
|
||||
function_rewrites = {'current_date': 'now',
|
||||
'current_timestamp': 'now',
|
||||
'length': 'len',
|
||||
}
|
||||
def visit_function(self, func):
|
||||
"""Access function names differ from the ANSI SQL names; rewrite common ones"""
|
||||
func.name = self.function_rewrites.get(func.name, func.name)
|
||||
return super(AccessCompiler, self).visit_function(func)
|
||||
|
||||
def for_update_clause(self, select):
|
||||
"""FOR UPDATE is not supported by Access; silently ignore"""
|
||||
return ''
|
||||
|
||||
# Strip schema
|
||||
def visit_table(self, table, asfrom=False, **kwargs):
|
||||
if asfrom:
|
||||
return self.preparer.quote(table.name, table.quote)
|
||||
else:
|
||||
return ""
|
||||
|
||||
def visit_join(self, join, asfrom=False, **kwargs):
|
||||
return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN ") + \
|
||||
self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))
|
||||
|
||||
def visit_extract(self, extract, **kw):
|
||||
field = self.extract_map.get(extract.field, extract.field)
|
||||
return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw))
|
||||
|
||||
class AccessDDLCompiler(compiler.DDLCompiler):
|
||||
def get_column_specification(self, column, **kwargs):
|
||||
colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
|
||||
|
||||
# install a sequence if we have an implicit IDENTITY column
|
||||
if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
|
||||
column.autoincrement and isinstance(column.type, types.Integer) and not column.foreign_keys:
|
||||
if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
|
||||
column.sequence = schema.Sequence(column.name + '_seq')
|
||||
|
||||
if not column.nullable:
|
||||
colspec += " NOT NULL"
|
||||
|
||||
if hasattr(column, 'sequence'):
|
||||
column.table.has_sequence = column
|
||||
colspec = self.preparer.format_column(column) + " counter"
|
||||
else:
|
||||
default = self.get_column_default_string(column)
|
||||
if default is not None:
|
||||
colspec += " DEFAULT " + default
|
||||
|
||||
return colspec
|
||||
|
||||
def visit_drop_index(self, drop):
|
||||
index = drop.element
|
||||
self.append("\nDROP INDEX [%s].[%s]" % (index.table.name, self._validate_identifier(index.name, False)))
|
||||
|
||||
class AccessIdentifierPreparer(compiler.IdentifierPreparer):
|
||||
reserved_words = compiler.RESERVED_WORDS.copy()
|
||||
reserved_words.update(['value', 'text'])
|
||||
def __init__(self, dialect):
|
||||
super(AccessIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
|
||||
|
||||
|
||||
dialect = AccessDialect
|
||||
dialect.poolclass = pool.SingletonThreadPool
|
||||
dialect.statement_compiler = AccessCompiler
|
||||
dialect.ddlcompiler = AccessDDLCompiler
|
||||
dialect.preparer = AccessIdentifierPreparer
|
||||
dialect.execution_ctx_cls = AccessExecutionContext
|
Reference in New Issue
Block a user