dibbler/sqlalchemy/orm/evaluator.py

138 lines
4.7 KiB
Python
Raw Normal View History

2017-04-15 18:27:12 +02:00
# orm/evaluator.py
# Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
2010-05-07 19:33:49 +02:00
import operator
2017-04-15 18:27:12 +02:00
from ..sql import operators
2010-05-07 19:33:49 +02:00
class UnevaluatableError(Exception):
    """Raised when a clause cannot be turned into an in-Python evaluator.

    ``EvaluatorCompiler`` raises this for clause types, operators, or
    mapped classes it does not know how to evaluate.
    """
# Operators that EvaluatorCompiler.visit_binary applies directly, as
# ``op(left_value, right_value)``, to the two evaluated operands.
_straight_ops = set(
    getattr(operators, op_name)
    for op_name in (
        # arithmetic
        'add', 'mul', 'sub',
        'div',
        'mod', 'truediv',
        # comparison
        'lt', 'le', 'ne', 'gt', 'ge', 'eq',
    )
)
# Operators with no in-Python evaluation here; a binary clause using any of
# them falls through to UnevaluatableError in EvaluatorCompiler.visit_binary.
_notimplemented_ops = set(
    getattr(operators, op_name)
    for op_name in (
        # pattern matching
        'like_op', 'notlike_op', 'ilike_op', 'notilike_op',
        # range / membership
        'between_op', 'in_op', 'notin_op',
        # string
        'endswith_op', 'concat_op',
    )
)
class EvaluatorCompiler(object):
    """Compile SQL expression clauses into plain Python callables.

    ``process(clause)`` returns a function of one argument, ``obj``, that
    evaluates ``clause`` against the attributes of ``obj`` in Python rather
    than in the database.  SQL NULL is represented by ``None``, and the
    boolean/comparison visitors follow SQL's three-valued logic.

    :param target_cls: optional class.  When given, column clauses that are
        annotated with a ``parentmapper`` whose class is not a superclass of
        ``target_cls`` are rejected with :class:`UnevaluatableError`.
    """

    def __init__(self, target_cls=None):
        self.target_cls = target_cls

    def process(self, clause):
        """Return an evaluator callable for ``clause``.

        Dispatches to ``visit_<clause.__visit_name__>``.

        :raises UnevaluatableError: if there is no visitor for the clause.
        """
        meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
        if not meth:
            raise UnevaluatableError(
                "Cannot evaluate %s" % type(clause).__name__)
        return meth(clause)

    def visit_grouping(self, clause):
        # Parentheses carry no meaning once evaluated; use the inner element.
        return self.process(clause.element)

    def visit_null(self, clause):
        return lambda obj: None

    def visit_false(self, clause):
        return lambda obj: False

    def visit_true(self, clause):
        return lambda obj: True

    def visit_column(self, clause):
        """Return a callable that reads the attribute backing ``clause``.

        :raises UnevaluatableError: if the column's mapped class is not a
            superclass of ``target_cls``.
        """
        if 'parentmapper' in clause._annotations:
            parentmapper = clause._annotations['parentmapper']
            if self.target_cls and not issubclass(
                    self.target_cls, parentmapper.class_):
                raise UnevaluatableError(
                    "Can't evaluate criteria against alternate class %s" %
                    parentmapper.class_
                )
            # Use the key of the mapped property for this column, which is
            # what the object's attribute is named.
            key = parentmapper._columntoproperty[clause].key
        else:
            key = clause.key

        # attrgetter is already a one-argument callable; no wrapper needed.
        return operator.attrgetter(key)

    def visit_clauselist(self, clause):
        """Return an evaluator for an ``or_`` / ``and_`` clause list.

        Implements SQL three-valued logic with ``None`` as NULL:
        ``TRUE OR NULL`` is TRUE, ``FALSE OR NULL`` is NULL,
        ``FALSE AND NULL`` is FALSE, ``TRUE AND NULL`` is NULL.

        :raises UnevaluatableError: for any other clause-list operator.
        """
        evaluators = [self.process(sub) for sub in clause.clauses]

        if clause.operator is operators.or_:
            def evaluate(obj):
                # A true operand decides the result; otherwise any NULL
                # operand makes the whole disjunction NULL.
                has_null = False
                for sub_evaluate in evaluators:
                    value = sub_evaluate(obj)
                    if value:
                        return True
                    has_null = has_null or value is None
                if has_null:
                    return None
                return False
        elif clause.operator is operators.and_:
            def evaluate(obj):
                # A false operand decides the result even after a NULL has
                # been seen (NULL AND FALSE is FALSE); otherwise any NULL
                # operand makes the whole conjunction NULL.
                has_null = False
                for sub_evaluate in evaluators:
                    value = sub_evaluate(obj)
                    if not value:
                        if value is None:
                            has_null = True
                        else:
                            return False
                if has_null:
                    return None
                return True
        else:
            raise UnevaluatableError(
                "Cannot evaluate clauselist with operator %s" %
                clause.operator)

        return evaluate

    def visit_binary(self, clause):
        """Return an evaluator for a binary expression.

        ``IS`` / ``IS NOT`` compare with ``==`` / ``!=``.  Operators in
        ``_straight_ops`` yield ``None`` when either operand is ``None``.

        :raises UnevaluatableError: for any other operator.
        """
        eval_left = self.process(clause.left)
        eval_right = self.process(clause.right)

        # Named ``op`` so the module-level ``operator`` import is not shadowed.
        op = clause.operator
        if op is operators.is_:
            def evaluate(obj):
                return eval_left(obj) == eval_right(obj)
        elif op is operators.isnot:
            def evaluate(obj):
                return eval_left(obj) != eval_right(obj)
        elif op in _straight_ops:
            def evaluate(obj):
                left_val = eval_left(obj)
                right_val = eval_right(obj)
                if left_val is None or right_val is None:
                    return None
                # Reuse the operand values computed above rather than
                # evaluating each side a second time.
                return op(left_val, right_val)
        else:
            raise UnevaluatableError(
                "Cannot evaluate %s with operator %s" %
                (type(clause).__name__, clause.operator))
        return evaluate

    def visit_unary(self, clause):
        """Return an evaluator for a unary expression (only ``inv``/NOT).

        ``NOT NULL`` stays ``None``.

        :raises UnevaluatableError: for any other unary operator.
        """
        eval_inner = self.process(clause.element)
        if clause.operator is operators.inv:
            def evaluate(obj):
                value = eval_inner(obj)
                if value is None:
                    return None
                return not value
            return evaluate
        raise UnevaluatableError(
            "Cannot evaluate %s with operator %s" %
            (type(clause).__name__, clause.operator))

    def visit_bindparam(self, clause):
        # The bound value is resolved once, when the evaluator is built,
        # not on each call.
        if clause.callable:
            val = clause.callable()
        else:
            val = clause.value
        return lambda obj: val