dibbler/sqlalchemy/engine/reflection.py
2010-05-07 17:33:49 +00:00

371 lines
13 KiB
Python

"""Provides an abstraction for obtaining database schema information.
Usage Notes:
Here are some general conventions when accessing the low level inspector
methods such as get_table_names, get_columns, etc.
1. Inspector methods return lists of dicts in most cases for the following
reasons:
* They're both standard types that can be serialized.
* Using a dict instead of a tuple allows easy expansion of attributes.
* Using a list for the outer structure maintains order and is easy to work
with (e.g. list comprehension [d['name'] for d in cols]).
2. Records that contain a name, such as the column name in a column record,
use the key 'name'. So for most return values, each record will have a
'name' key.
"""
import sqlalchemy
from sqlalchemy import exc, sql
from sqlalchemy import util
from sqlalchemy.types import TypeEngine
from sqlalchemy import schema as sa_schema
@util.decorator
def cache(fn, self, con, *args, **kw):
    """Memoize a reflection call in the ``info_cache`` dict passed by keyword.

    When no ``info_cache`` keyword is supplied (or it is None) the wrapped
    function is simply called. Otherwise the result is stored in, and later
    served from, ``info_cache`` under a key made of the function name, the
    string positional arguments and the string/number keyword arguments.

    A result of None is never served from the cache: the function is called
    again on the next lookup.
    """
    info_cache = kw.get('info_cache', None)
    if info_cache is None:
        return fn(self, con, *args, **kw)

    # Only plain strings / numbers take part in the key; other objects (the
    # cache dict itself, for instance) are left out of it.
    string_args = tuple(
        arg for arg in args if isinstance(arg, basestring))
    scalar_kwargs = tuple(
        (name, value) for name, value in kw.iteritems()
        if isinstance(value, (basestring, int, float)))
    key = (fn.__name__, string_args, scalar_kwargs)

    cached = info_cache.get(key)
    if cached is not None:
        return cached

    result = fn(self, con, *args, **kw)
    info_cache[key] = result
    return result
class Inspector(object):
    """Performs database schema inspection.

    The Inspector acts as a proxy to the dialects' reflection methods and
    provides higher level functions for accessing database schema information.

    Every dialect call made through an instance receives the same
    ``info_cache`` dict, created once in ``__init__``.
    """

    def __init__(self, conn):
        """Initialize the instance.

        :param conn: a :class:`~sqlalchemy.engine.base.Connectable`
        """
        self.conn = conn
        # set the engine: use ``conn.engine`` when the connectable exposes
        # one, otherwise treat ``conn`` itself as the engine.
        if hasattr(conn, 'engine'):
            self.engine = conn.engine
        else:
            self.engine = conn
        self.dialect = self.engine.dialect
        self.info_cache = {}

    @classmethod
    def from_engine(cls, engine):
        """Return an inspector for `engine`.

        When the engine's dialect defines an ``inspector`` attribute it is
        called with the engine and its result is returned; otherwise a plain
        :class:`Inspector` is built.
        """
        if hasattr(engine.dialect, 'inspector'):
            return engine.dialect.inspector(engine)
        return Inspector(engine)

    @property
    def default_schema_name(self):
        """The dialect's ``default_schema_name``."""
        return self.dialect.default_schema_name

    def get_schema_names(self):
        """Return all schema names.

        Returns an empty list when the dialect has no ``get_schema_names``
        method.
        """
        if hasattr(self.dialect, 'get_schema_names'):
            return self.dialect.get_schema_names(self.conn,
                                                 info_cache=self.info_cache)
        return []

    def get_table_names(self, schema=None, order_by=None):
        """Return all table names in `schema`.

        :param schema: Optional, retrieve names from a non-default schema.
        :param order_by: Optional, may be the string "foreign_key" to sort
          the result on foreign key dependencies: each table is placed
          before the tables it references. Tables caught in a dependency
          cycle keep their relative order from the unsorted listing.

        This should probably not return view names or maybe it should return
        them with an indicator t or v.
        """
        if hasattr(self.dialect, 'get_table_names'):
            tnames = self.dialect.get_table_names(self.conn,
                                                  schema,
                                                  info_cache=self.info_cache)
        else:
            tnames = self.engine.table_names(schema)
        if order_by == 'foreign_key':
            tnames = self._sort_by_foreign_keys(tnames, schema)
        return tnames

    def _sort_by_foreign_keys(self, tnames, schema):
        """Return a new list of `tnames`, each table before those it references."""
        known = set(tnames)
        # referrers[t] holds the tables with a foreign key pointing at t;
        # all of them have to be emitted before t itself.
        referrers = dict((tname, set()) for tname in tnames)
        for tname in tnames:
            for fkey in self.get_foreign_keys(tname, schema):
                rtable = fkey['referred_table']
                # Self references and tables outside of this listing put no
                # constraint on the ordering.
                if rtable != tname and rtable in known:
                    referrers[rtable].add(tname)

        remaining = list(tnames)
        ordered = []
        while remaining:
            # Stable choice: the first table nobody still pending refers to.
            for candidate in remaining:
                if not referrers[candidate]:
                    break
            else:
                # Only cycles are left; break one at the first table so the
                # loop always terminates.
                candidate = remaining[0]
            remaining.remove(candidate)
            ordered.append(candidate)
            for pending in referrers.values():
                pending.discard(candidate)
        return ordered

    def get_table_options(self, table_name, schema=None, **kw):
        """Return the dialect's table options for `table_name`.

        Returns an empty dict when the dialect has no ``get_table_options``
        method. Extra keyword arguments are passed through to the dialect.
        """
        if hasattr(self.dialect, 'get_table_options'):
            return self.dialect.get_table_options(self.conn, table_name,
                                                  schema,
                                                  info_cache=self.info_cache,
                                                  **kw)
        return {}

    def get_view_names(self, schema=None):
        """Return all view names in `schema`.

        :param schema: Optional, retrieve names from a non-default schema.
        """
        return self.dialect.get_view_names(self.conn, schema,
                                           info_cache=self.info_cache)

    def get_view_definition(self, view_name, schema=None):
        """Return definition for `view_name`.

        :param schema: Optional, retrieve names from a non-default schema.
        """
        return self.dialect.get_view_definition(
            self.conn, view_name, schema, info_cache=self.info_cache)

    def get_columns(self, table_name, schema=None, **kw):
        """Return information about columns in `table_name`.

        Given a string `table_name` and an optional string `schema`, return
        column information as a list of dicts with these keys:

        name
          the column's name

        type
          :class:`~sqlalchemy.types.TypeEngine`

        nullable
          boolean

        default
          the column's default value

        attrs
          dict containing optional column attributes

        The dicts returned by the dialect are updated in place: a 'type'
        that is not already a ``TypeEngine`` instance is called with no
        arguments and replaced by the result.
        """
        col_defs = self.dialect.get_columns(self.conn, table_name, schema,
                                            info_cache=self.info_cache,
                                            **kw)
        for col_def in col_defs:
            # make this easy and only return instances for coltype
            coltype = col_def['type']
            if not isinstance(coltype, TypeEngine):
                col_def['type'] = coltype()
        return col_defs

    def get_primary_keys(self, table_name, schema=None, **kw):
        """Return information about primary keys in `table_name`.

        Given a string `table_name`, and an optional string `schema`, return
        primary key information as a list of column names.
        """
        pkeys = self.dialect.get_primary_keys(self.conn, table_name, schema,
                                              info_cache=self.info_cache,
                                              **kw)
        return pkeys

    def get_foreign_keys(self, table_name, schema=None, **kw):
        """Return information about foreign_keys in `table_name`.

        Given a string `table_name`, and an optional string `schema`, return
        foreign key information as a list of dicts with these keys:

        constrained_columns
          a list of column names that make up the foreign key

        referred_schema
          the name of the referred schema

        referred_table
          the name of the referred table

        referred_columns
          a list of column names in the referred table that correspond to
          constrained_columns

        ``**kw``
          other options passed to the dialect's get_foreign_keys() method.
        """
        fk_defs = self.dialect.get_foreign_keys(self.conn, table_name, schema,
                                                info_cache=self.info_cache,
                                                **kw)
        return fk_defs

    def get_indexes(self, table_name, schema=None, **kw):
        """Return information about indexes in `table_name`.

        Given a string `table_name` and an optional string `schema`, return
        index information as a list of dicts with these keys:

        name
          the index's name

        column_names
          list of column names in order

        unique
          boolean

        ``**kw``
          other options passed to the dialect's get_indexes() method.
        """
        indexes = self.dialect.get_indexes(self.conn, table_name,
                                           schema,
                                           info_cache=self.info_cache, **kw)
        return indexes

    def reflecttable(self, table, include_columns):
        """Fill `table` with what the database reports for it.

        Columns, the primary key constraint, foreign key constraints and
        indexes are read through this inspector and attached to `table`.
        Tables referred to by foreign keys are autoloaded into
        ``table.metadata`` through ``self.conn``.

        :param table: the table object to populate; its ``name``, ``schema``,
          ``kwargs`` and ``metadata`` are used.
        :param include_columns: collection of column names to reflect. When
          it is empty or None every column is reflected. Indexes covering an
          omitted column are skipped with a warning.
        :raises NoSuchTableError: when no column at all is reported for the
          table.
        """
        dialect = self.conn.dialect

        # MySQL dialect does this. Applicable with other dialects?
        if hasattr(dialect, '_connection_charset') \
                and hasattr(dialect, '_adjust_casing'):
            dialect._adjust_casing(table)

        # table attributes we might need.
        reflection_options = dict(
            (k, table.kwargs.get(k))
            for k in dialect.reflection_options if k in table.kwargs)

        schema = table.schema
        table_name = table.name

        # apply table options
        tbl_opts = self.get_table_options(table_name, schema, **table.kwargs)
        if tbl_opts:
            table.kwargs.update(tbl_opts)

        # table.kwargs will need to be passed to each reflection method. Make
        # sure keywords are strings. Built in one pass so the dict is never
        # modified while it is being iterated.
        tblkw = dict((str(k), v) for (k, v) in table.kwargs.items())

        # Py2K
        if isinstance(schema, str):
            schema = schema.decode(dialect.encoding)
        if isinstance(table_name, str):
            table_name = table_name.decode(dialect.encoding)
        # end Py2K

        # columns
        found_table = False
        for col_d in self.get_columns(table_name, schema, **tblkw):
            # Any reported column, even one filtered out below, proves the
            # table exists.
            found_table = True
            name = col_d['name']
            if include_columns and name not in include_columns:
                continue

            coltype = col_d['type']
            col_kw = {
                'nullable': col_d['nullable'],
            }
            if 'autoincrement' in col_d:
                col_kw['autoincrement'] = col_d['autoincrement']
            if 'quote' in col_d:
                col_kw['quote'] = col_d['quote']

            colargs = []
            if col_d.get('default') is not None:
                # the "default" value is assumed to be a literal SQL
                # expression, so is wrapped in text() so that no quoting
                # occurs on re-issuance.
                colargs.append(
                    sa_schema.DefaultClause(sql.text(col_d['default'])))

            if 'sequence' in col_d:
                # TODO: mssql, maxdb and sybase are using this.
                seq = col_d['sequence']
                sequence = sa_schema.Sequence(seq['name'], 1, 1)
                if 'start' in seq:
                    sequence.start = seq['start']
                if 'increment' in seq:
                    sequence.increment = seq['increment']
                colargs.append(sequence)

            col = sa_schema.Column(name, coltype, *colargs, **col_kw)
            table.append_column(col)

        if not found_table:
            raise exc.NoSuchTableError(table.name)

        # Primary keys: only columns that were actually reflected take part.
        primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[
            table.c[pk]
            for pk in self.get_primary_keys(table_name, schema, **tblkw)
            if pk in table.c
        ])
        table.append_constraint(primary_key_constraint)

        # Foreign keys
        fkeys = self.get_foreign_keys(table_name, schema, **tblkw)
        for fkey_d in fkeys:
            conname = fkey_d['name']
            constrained_columns = fkey_d['constrained_columns']
            referred_schema = fkey_d['referred_schema']
            referred_table = fkey_d['referred_table']
            referred_columns = fkey_d['referred_columns']
            refspec = []
            if referred_schema is not None:
                # Autoload the referred table, then point at its columns
                # with schema-qualified names.
                sa_schema.Table(referred_table, table.metadata,
                                autoload=True, schema=referred_schema,
                                autoload_with=self.conn,
                                **reflection_options
                                )
                for column in referred_columns:
                    refspec.append(".".join(
                        [referred_schema, referred_table, column]))
            else:
                sa_schema.Table(referred_table, table.metadata, autoload=True,
                                autoload_with=self.conn,
                                **reflection_options
                                )
                for column in referred_columns:
                    refspec.append(".".join([referred_table, column]))
            table.append_constraint(
                sa_schema.ForeignKeyConstraint(constrained_columns, refspec,
                                               conname, link_to_name=True))

        # Indexes
        indexes = self.get_indexes(table_name, schema)
        for index_d in indexes:
            name = index_d['name']
            columns = index_d['column_names']
            unique = index_d['unique']
            flavor = index_d.get('type', 'unknown type')
            if include_columns and \
                    not set(columns).issubset(include_columns):
                util.warn(
                    "Omitting %s KEY for (%s), key covers omitted columns." %
                    (flavor, ', '.join(columns)))
                continue
            sa_schema.Index(name, *[table.columns[c] for c in columns],
                            **dict(unique=unique))