dibbler/sqlalchemy/dialects/postgresql/pygresql.py

244 lines
7.4 KiB
Python

# postgresql/pygresql.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: postgresql+pygresql
    :name: pygresql
    :dbapi: pgdb
    :connectstring: postgresql+pygresql://user:password@host:port/dbname\
[?key=value&key=value...]
    :url: http://www.pygresql.org/
"""
import decimal
import re
from ... import exc, processors, util
from ...types import Numeric, JSON as Json
from ...sql.elements import Null
from .base import PGDialect, PGCompiler, PGIdentifierPreparer, \
_DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID
from .hstore import HSTORE
from .json import JSON, JSONB
class _PGNumeric(Numeric):
    """Numeric type whose conversions lean on PyGreSQL's native results."""

    def bind_processor(self, dialect):
        """Return ``None``; no bind-side conversion is installed."""
        return None

    def result_processor(self, dialect, coltype):
        """Choose the result converter for a numeric column.

        :param dialect: the dialect in use (not consulted here).
        :param coltype: an integer type OID, or an object exposing the
            OID as its ``oid`` attribute.
        :return: a converter callable, or ``None`` when the value the
            driver returns is used as is.
        :raises exc.InvalidRequestError: if the OID is not one of the
            known float, decimal or integer types.
        """
        oid = coltype if isinstance(coltype, int) else coltype.oid

        from_float = oid in _FLOAT_TYPES
        if not from_float and not (
                oid in _DECIMAL_TYPES or oid in _INT_TYPES):
            raise exc.InvalidRequestError(
                "Unknown PG numeric type: %d" % oid)

        if self.asdecimal:
            if from_float:
                return processors.to_decimal_processor_factory(
                    decimal.Decimal,
                    self._effective_decimal_return_scale)
            # PyGreSQL returns Decimal natively for 1700 (numeric)
            return None

        if from_float:
            # PyGreSQL returns float natively for 701 (float8)
            return None
        return processors.to_float
class _PGHStore(HSTORE):
    """HSTORE type that uses the DBAPI's ``Hstore`` wrapper when enabled."""

    def bind_processor(self, dialect):
        """Return the bind converter for hstore values.

        When ``dialect.has_native_hstore`` is false the inherited
        processor is returned.  Otherwise the returned callable wraps
        ``dict`` values in ``dialect.dbapi.Hstore`` and passes every
        other value through unchanged.
        """
        if dialect.has_native_hstore:
            wrap = dialect.dbapi.Hstore

            def process(value):
                return wrap(value) if isinstance(value, dict) else value

            return process
        return super(_PGHStore, self).bind_processor(dialect)

    def result_processor(self, dialect, coltype):
        """Return the inherited result converter, or ``None`` when
        ``dialect.has_native_hstore`` is true.
        """
        if dialect.has_native_hstore:
            return None
        return super(_PGHStore, self).result_processor(dialect, coltype)
class _PGJSON(JSON):
    """JSON type that uses the DBAPI's ``Json`` wrapper when enabled."""

    def bind_processor(self, dialect):
        """Return the bind converter for JSON values.

        When ``dialect.has_native_json`` is false the inherited
        processor is returned.  Otherwise the returned callable maps:

        * ``self.NULL`` to ``Json(None)``;
        * a ``Null`` element to ``None``;
        * ``None`` to ``None`` if ``self.none_as_null`` is set, else to
          ``Json(None)``;
        * ``dict`` and ``list`` values to ``Json(value)``;
        * anything else to itself.
        """
        if dialect.has_native_json:
            wrap = dialect.dbapi.Json

            def process(value):
                if value is self.NULL:
                    return wrap(None)
                if isinstance(value, Null):
                    return None
                if value is None:
                    return None if self.none_as_null else wrap(None)
                if isinstance(value, (dict, list)):
                    return wrap(value)
                return value

            return process
        return super(_PGJSON, self).bind_processor(dialect)

    def result_processor(self, dialect, coltype):
        """Return the inherited result converter, or ``None`` when
        ``dialect.has_native_json`` is true.
        """
        if dialect.has_native_json:
            return None
        return super(_PGJSON, self).result_processor(dialect, coltype)
class _PGJSONB(JSONB):
    """JSONB type that uses the DBAPI's ``Json`` wrapper when enabled."""

    def bind_processor(self, dialect):
        """Return the bind converter for JSONB values.

        When ``dialect.has_native_json`` is false the inherited
        processor is returned.  Otherwise the returned callable maps:

        * ``self.NULL`` to ``Json(None)``;
        * a ``Null`` element to ``None``;
        * ``None`` to ``None`` if ``self.none_as_null`` is set, else to
          ``Json(None)``;
        * ``dict`` and ``list`` values to ``Json(value)``;
        * anything else to itself.
        """
        if dialect.has_native_json:
            wrap = dialect.dbapi.Json

            def process(value):
                if value is self.NULL:
                    return wrap(None)
                if isinstance(value, Null):
                    return None
                if value is None:
                    return None if self.none_as_null else wrap(None)
                if isinstance(value, (dict, list)):
                    return wrap(value)
                return value

            return process
        return super(_PGJSONB, self).bind_processor(dialect)

    def result_processor(self, dialect, coltype):
        """Return the inherited result converter, or ``None`` when
        ``dialect.has_native_json`` is true.
        """
        if dialect.has_native_json:
            return None
        return super(_PGJSONB, self).result_processor(dialect, coltype)
class _PGUUID(UUID):
    """UUID type that uses the DBAPI's ``Uuid`` constructor when enabled."""

    def bind_processor(self, dialect):
        """Return the bind converter for UUID values.

        When ``dialect.has_native_uuid`` is false the inherited
        processor is returned.  Otherwise the returned callable builds a
        ``dialect.dbapi.Uuid`` from ``str``/``bytes`` input (a value of
        length 16 is passed as ``bytes=``, any other length
        positionally) and from ``int`` input (passed as ``int=``).
        ``None`` and every other value are passed through unchanged.
        """
        if dialect.has_native_uuid:
            make_uuid = dialect.dbapi.Uuid

            def process(value):
                if value is None:
                    return None
                if isinstance(value, int):
                    return make_uuid(int=value)
                if isinstance(value, (str, bytes)):
                    return (make_uuid(bytes=value) if len(value) == 16
                            else make_uuid(value))
                return value

            return process
        return super(_PGUUID, self).bind_processor(dialect)

    def result_processor(self, dialect, coltype):
        """Return the result converter for UUID values.

        When ``dialect.has_native_uuid`` is false the inherited
        processor is returned.  Otherwise, with ``self.as_uuid`` set, no
        converter is installed; without it, the returned callable turns
        every non-``None`` value into ``str(value)``.
        """
        if not dialect.has_native_uuid:
            return super(_PGUUID, self).result_processor(dialect, coltype)
        if self.as_uuid:
            return None

        def process(value):
            return None if value is None else str(value)

        return process
class _PGCompiler(PGCompiler):
    """Statement compiler that emits percent signs doubled (``%%``)."""

    def visit_mod_binary(self, binary, operator, **kw):
        """Render a modulo expression as ``<left> %% <right>``."""
        lhs = self.process(binary.left, **kw)
        rhs = self.process(binary.right, **kw)
        return lhs + " %% " + rhs

    def post_process_text(self, text):
        """Return *text* with every ``%`` replaced by ``%%``."""
        doubled = text.replace('%', '%%')
        return doubled
class _PGIdentifierPreparer(PGIdentifierPreparer):
    """Identifier preparer that also doubles percent signs."""

    def _escape_identifier(self, value):
        """Escape quote characters in *value*, then double every ``%``."""
        return value.replace(
            self.escape_quote, self.escape_to_quote
        ).replace('%', '%%')
class PGDialect_pygresql(PGDialect):
    """PostgreSQL dialect for the PyGreSQL DBAPI module (``pgdb``)."""

    driver = 'pygresql'

    statement_compiler = _PGCompiler
    preparer = _PGIdentifierPreparer

    @classmethod
    def dbapi(cls):
        """Import and return the ``pgdb`` module."""
        import pgdb
        return pgdb

    # Type overrides layered on top of the base dialect's colspecs.
    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            Numeric: _PGNumeric,
            HSTORE: _PGHStore,
            Json: _PGJSON,
            JSON: _PGJSON,
            JSONB: _PGJSONB,
            UUID: _PGUUID,
        }
    )

    def __init__(self, **kwargs):
        """Initialise the dialect and set feature flags from the DBAPI version.

        The leading ``major.minor`` of ``self.dbapi.version`` is parsed
        into ``self.dbapi_version``; when it cannot be read or parsed,
        ``(0, 0)`` is stored.  From version 5.0 on, the native hstore,
        JSON and UUID flags and the unicode statement/bind flags are
        enabled.  Below 5.0 the native flags are disabled and a warning
        is emitted, unless the version is the ``(0, 0)`` fallback.
        """
        super(PGDialect_pygresql, self).__init__(**kwargs)

        try:
            match = re.match(r'(\d+)\.(\d+)', self.dbapi.version)
            parsed = (int(match.group(1)), int(match.group(2)))
        except (AttributeError, ValueError, TypeError):
            # no DBAPI version attribute, or a value that does not parse
            parsed = (0, 0)
        self.dbapi_version = parsed

        native = not parsed < (5, 0)
        if native:
            self.supports_unicode_statements = True
            self.supports_unicode_binds = True
        elif parsed != (0, 0):
            util.warn("PyGreSQL is only fully supported by SQLAlchemy"
                      " since version 5.0.")

        self.has_native_hstore = native
        self.has_native_json = native
        self.has_native_uuid = native

    def create_connect_args(self, url):
        """Build the ``(args, kwargs)`` pair for the DBAPI ``connect`` call.

        The URL's connect arguments are translated with ``username``
        renamed to ``user``.  When a port is present it is removed from
        the options and folded into ``host`` as ``"host:port"``, after
        dropping whatever follows the last ``:`` already in the host.
        The URL query entries are then merged in.

        :return: an empty positional list and the keyword options dict.
        """
        connect_kw = url.translate_connect_args(username='user')
        if 'port' in connect_kw:
            bare_host = connect_kw.get('host', '').rsplit(':', 1)[0]
            port = connect_kw.pop('port')
            connect_kw['host'] = '%s:%s' % (bare_host, port)
        connect_kw.update(url.query)
        return [], connect_kw

    def is_disconnect(self, e, connection, cursor):
        """Report whether error *e* happened on a closed connection.

        Returns ``False`` when *e* is not a DBAPI ``Error`` or when no
        connection is available.  If *connection* has a ``connection``
        attribute, that inner object is inspected instead.  The result
        is the connection's ``closed`` attribute, or, where that
        attribute is missing, whether its ``_cnx`` is ``None``.
        """
        if not isinstance(e, self.dbapi.Error):
            return False
        if not connection:
            return False

        try:
            raw = connection.connection
        except AttributeError:
            raw = connection
        else:
            if not raw:
                return False

        try:
            return raw.closed
        except AttributeError:  # PyGreSQL < 5.0
            return raw._cnx is None
# Module-level name bound to the dialect class defined in this module
# (presumably the attribute SQLAlchemy's dialect loader looks up -- confirm).
dialect = PGDialect_pygresql