From fa7ad3a258d7bd541a9b0605ea83b64c6e0804b5 Mon Sep 17 00:00:00 2001 From: h7x4 Date: Wed, 10 Dec 2025 13:32:54 +0900 Subject: [PATCH] fixup! WIP --- dibbler/lib/query_helpers.py | 20 +++++++ dibbler/models/TransactionType.py | 11 ++++ dibbler/queries/product_owners.py | 59 +++++++++++--------- dibbler/queries/product_price.py | 74 +++++++++++++++---------- dibbler/queries/product_stock.py | 35 +++++++----- dibbler/queries/user_balance.py | 87 +++++++++++++++++------------- tests/conftest.py | 39 ++++++++++---- tests/queries/test_user_balance.py | 5 +- 8 files changed, 216 insertions(+), 114 deletions(-) create mode 100644 dibbler/lib/query_helpers.py diff --git a/dibbler/lib/query_helpers.py b/dibbler/lib/query_helpers.py new file mode 100644 index 0000000..bd9b7fa --- /dev/null +++ b/dibbler/lib/query_helpers.py @@ -0,0 +1,20 @@ +from typing import TypeVar +from sqlalchemy import BindParameter, literal + +T = TypeVar("T") + +def const(value: T) -> BindParameter[T]: + """ + Create a constant SQL literal bind parameter. + + This is useful to avoid too many `?` bind parameters in SQL queries, + when the input value is known to be safe. + """ + + return literal(value, literal_execute=True) + +CONST_ZERO: BindParameter[int] = const(0) +CONST_ONE: BindParameter[int] = const(1) +CONST_TRUE: BindParameter[bool] = const(True) +CONST_FALSE: BindParameter[bool] = const(False) +CONST_NONE: BindParameter[None] = const(None) diff --git a/dibbler/models/TransactionType.py b/dibbler/models/TransactionType.py index d823dbe..3b8f0f1 100644 --- a/dibbler/models/TransactionType.py +++ b/dibbler/models/TransactionType.py @@ -19,6 +19,17 @@ class TransactionType(StrEnum): THROW_PRODUCT = auto() TRANSFER = auto() + def as_literal_column(self): + """ + Return the transaction type as a SQL literal column. + + This is useful to avoid too many `?` bind parameters in SQL queries, + when the input value is known to be safe. + """ + from sqlalchemy import literal_column + + return literal_column(f"'{self.value}'") + TransactionTypeSQL = SQLEnum( TransactionType, diff --git a/dibbler/queries/product_owners.py b/dibbler/queries/product_owners.py index ee32df8..36199b2 100644 --- a/dibbler/queries/product_owners.py +++ b/dibbler/queries/product_owners.py @@ -1,10 +1,11 @@ -from datetime import datetime from dataclasses import dataclass +from datetime import datetime from sqlalchemy import ( CTE, + BindParameter, and_, - asc, + bindparam, case, func, literal, @@ -12,6 +13,7 @@ from sqlalchemy import ( ) from sqlalchemy.orm import Session +from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const from dibbler.models import ( Product, Transaction, @@ -22,9 +24,9 @@ from dibbler.queries.product_stock import _product_stock_query def _product_owners_query( - product_id: int, + product_id: BindParameter[int] | int, use_cache: bool = True, - until: datetime | None = None, + until: BindParameter[datetime] | datetime | None = None, cte_name: str = "rec_cte", ) -> CTE: """ @@ -34,6 +36,12 @@ def _product_owners_query( if use_cache: print("WARNING: Using cache for users owning product query is not implemented yet.") + if isinstance(product_id, int): + product_id = bindparam("product_id", value=product_id) + + if isinstance(until, datetime): + until = BindParameter("until", value=until) + product_stock = _product_stock_query( product_id=product_id, use_cache=use_cache, @@ -54,26 +62,26 @@ def _product_owners_query( .where( Transaction.type_.in_( [ - TransactionType.ADD_PRODUCT, + TransactionType.ADD_PRODUCT.as_literal_column(), # TransactionType.BUY_PRODUCT, - TransactionType.ADJUST_STOCK, + TransactionType.ADJUST_STOCK.as_literal_column(), # TransactionType.JOINT, # TransactionType.THROW_PRODUCT, ] ), Transaction.product_id == product_id, - literal(True) if until is None else Transaction.time <= until, + CONST_TRUE if until is None else Transaction.time <= until, ) .order_by(Transaction.time.desc()) .subquery() ) initial_element = select( - literal(0).label("i"), - literal(0).label("time"), - literal(None).label("transaction_id"), - literal(None).label("user_id"), - literal(0).label("product_count"), + CONST_ZERO.label("i"), + CONST_ZERO.label("time"), + CONST_NONE.label("transaction_id"), + CONST_NONE.label("user_id"), + CONST_ZERO.label("product_count"), product_stock.scalar_subquery().label("products_left_to_account_for"), ) @@ -88,28 +96,31 @@ def _product_owners_query( case( # Someone adds the product -> they own it ( - trx_subset.c.type_ == TransactionType.ADD_PRODUCT, + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), trx_subset.c.user_id, ), - else_=None, + else_=CONST_NONE, ).label("user_id"), # How many products did they add (if any) case( # Someone adds the product -> they added a certain amount of products - (trx_subset.c.type_ == TransactionType.ADD_PRODUCT, trx_subset.c.product_count), - # Stock got adjusted upwards -> consider those products as added by nobody ( - (trx_subset.c.type_ == TransactionType.ADJUST_STOCK) - & (trx_subset.c.product_count > 0), + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), trx_subset.c.product_count, ), - else_=0, + # Stock got adjusted upwards -> consider those products as added by nobody + ( + (trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column()) + and (trx_subset.c.product_count > CONST_ZERO), + trx_subset.c.product_count, + ), + else_=CONST_ZERO, ).label("product_count"), # How many products left to account for case( # Someone adds the product -> increase the number of products left to account for ( - trx_subset.c.type_ == TransactionType.ADD_PRODUCT, + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count, ), # Someone buys/joins/throws the product -> decrease the number of products left to account for @@ -127,8 +138,8 @@ def _product_owners_query( # If adjusted upwards -> products owned by nobody, decrease products left to account for # If adjusted downwards -> products taken away from owners, decrease products left to account for ( - (trx_subset.c.type_ == TransactionType.ADJUST_STOCK) - and (trx_subset.c.product_count > 0), + (trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column()) + and (trx_subset.c.product_count > CONST_ZERO), recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count, ), # ( @@ -142,8 +153,8 @@ def _product_owners_query( .select_from(trx_subset) .where( and_( - trx_subset.c.i == recursive_cte.c.i + 1, - recursive_cte.c.products_left_to_account_for > 0, + trx_subset.c.i == recursive_cte.c.i + CONST_ONE, + recursive_cte.c.products_left_to_account_for > CONST_ZERO, ) ) ) diff --git a/dibbler/queries/product_price.py b/dibbler/queries/product_price.py index 0b3d8bc..76a0172 100644 --- a/dibbler/queries/product_price.py +++ b/dibbler/queries/product_price.py @@ -3,9 +3,9 @@ from dataclasses import dataclass from datetime import datetime from sqlalchemy import ( + BindParameter, ColumnElement, Integer, - SQLColumnExpression, asc, case, cast, @@ -15,6 +15,7 @@ from sqlalchemy import ( ) from sqlalchemy.orm import Session +from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const from dibbler.models import ( Product, Transaction, @@ -26,8 +27,8 @@ from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE def _product_price_query( product_id: int | ColumnElement[int], use_cache: bool = True, - until: datetime | SQLColumnExpression[datetime] | None = None, - until_including: bool = True, + until: BindParameter[datetime] | datetime | None = None, + until_including: BindParameter[bool] | bool = True, cte_name: str = "rec_cte", ): """ @@ -37,12 +38,21 @@ def _product_price_query( if use_cache: print("WARNING: Using cache for product price query is not implemented yet.") + if isinstance(product_id, int): + product_id = BindParameter("product_id", value=product_id) + + if isinstance(until, datetime): + until = BindParameter("until", value=until) + + if isinstance(until_including, bool): + until_including = BindParameter("until_including", value=until_including) + initial_element = select( - literal(0).label("i"), - literal(0).label("time"), - literal(None).label("transaction_id"), - literal(0).label("price"), - literal(0).label("product_count"), + CONST_ZERO.label("i"), + CONST_ZERO.label("time"), + CONST_NONE.label("transaction_id"), + CONST_ZERO.label("price"), + CONST_ZERO.label("product_count"), ) recursive_cte = initial_element.cte(name=cte_name, recursive=True) @@ -60,19 +70,19 @@ def _product_price_query( .where( Transaction.type_.in_( [ - TransactionType.BUY_PRODUCT, - TransactionType.ADD_PRODUCT, - TransactionType.ADJUST_STOCK, - TransactionType.JOINT, + TransactionType.BUY_PRODUCT.as_literal_column(), + TransactionType.ADD_PRODUCT.as_literal_column(), + TransactionType.ADJUST_STOCK.as_literal_column(), + TransactionType.JOINT.as_literal_column(), ] ), Transaction.product_id == product_id, case( - (literal(until_including), Transaction.time <= until), + (until_including, Transaction.time <= until), else_=Transaction.time < until, ) if until is not None - else literal(True), + else CONST_TRUE, ) .order_by(Transaction.time.asc()) .alias("trx_subset") @@ -85,22 +95,26 @@ def _product_price_query( trx_subset.c.id.label("transaction_id"), case( # Someone buys the product -> price remains the same. - (trx_subset.c.type_ == TransactionType.BUY_PRODUCT, recursive_cte.c.price), + ( + trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(), + recursive_cte.c.price, + ), # Someone adds the product -> price is recalculated based on # product count, previous price, and new price. ( - trx_subset.c.type_ == TransactionType.ADD_PRODUCT, + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), cast( func.ceil( ( - recursive_cte.c.price * func.max(recursive_cte.c.product_count, 0) + recursive_cte.c.price + * func.max(recursive_cte.c.product_count, CONST_ZERO) + trx_subset.c.per_product * trx_subset.c.product_count ) / ( # The running product count can be negative if the accounting is bad. # This ensures that we never end up with negative prices or zero divisions # and other disastrous phenomena. - func.max(recursive_cte.c.product_count, 0) + func.max(recursive_cte.c.product_count, CONST_ZERO) + trx_subset.c.product_count ) ), @@ -108,28 +122,31 @@ def _product_price_query( ), ), # Someone adjusts the stock -> price remains the same. - (trx_subset.c.type_ == TransactionType.ADJUST_STOCK, recursive_cte.c.price), + ( + trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(), + recursive_cte.c.price, + ), # Should never happen else_=recursive_cte.c.price, ).label("price"), case( # Someone buys the product -> product count is reduced. ( - trx_subset.c.type_ == TransactionType.BUY_PRODUCT, + trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(), recursive_cte.c.product_count - trx_subset.c.product_count, ), ( - trx_subset.c.type_ == TransactionType.JOINT, + trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(), recursive_cte.c.product_count - trx_subset.c.product_count, ), # Someone adds the product -> product count is increased. ( - trx_subset.c.type_ == TransactionType.ADD_PRODUCT, + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), recursive_cte.c.product_count + trx_subset.c.product_count, ), # Someone adjusts the stock -> product count is adjusted. ( - trx_subset.c.type_ == TransactionType.ADJUST_STOCK, + trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(), recursive_cte.c.product_count + trx_subset.c.product_count, ), # Should never happen @@ -137,7 +154,7 @@ def _product_price_query( ).label("product_count"), ) .select_from(trx_subset) - .where(trx_subset.c.i == recursive_cte.c.i + 1) + .where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE) ) return recursive_cte.union_all(recursive_elements) @@ -222,7 +239,10 @@ def product_price( # - price should never be negative result = sql_session.scalars( - select(recursive_cte.c.price).order_by(recursive_cte.c.i.desc()).limit(1) + select(recursive_cte.c.price) + .order_by(recursive_cte.c.i.desc()) + .limit(CONST_ONE) + .offset(CONST_ZERO) ).one_or_none() if result is None: @@ -237,10 +257,10 @@ def product_price( select(Transaction.interest_rate_percent) .where( Transaction.type_ == TransactionType.ADJUST_INTEREST, - literal(True) if until is None else Transaction.time <= until.time, + CONST_TRUE if until is None else Transaction.time <= until.time, ) .order_by(Transaction.time.desc()) - .limit(1) + .limit(CONST_ONE) ) or DEFAULT_INTEREST_RATE_PERCENTAGE ) diff --git a/dibbler/queries/product_stock.py b/dibbler/queries/product_stock.py index 976d691..1732bdf 100644 --- a/dibbler/queries/product_stock.py +++ b/dibbler/queries/product_stock.py @@ -1,14 +1,15 @@ from datetime import datetime from sqlalchemy import ( + BindParameter, Select, case, func, - literal, select, ) from sqlalchemy.orm import Session +from dibbler.lib.query_helpers import CONST_TRUE from dibbler.models import ( Product, Transaction, @@ -17,9 +18,9 @@ from dibbler.models import ( def _product_stock_query( - product_id: int, + product_id: BindParameter[int] | int, use_cache: bool = True, - until: datetime | None = None, + until: BindParameter[datetime] | datetime | None = None, ) -> Select: """ The inner query for calculating the product stock. @@ -28,27 +29,33 @@ def _product_stock_query( if use_cache: print("WARNING: Using cache for product stock query is not implemented yet.") + if isinstance(product_id, int): + product_id = BindParameter("product_id", value=product_id) + + if isinstance(until, datetime): + until = BindParameter("until", value=until) + query = select( func.sum( case( ( - Transaction.type_ == TransactionType.ADD_PRODUCT, + Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), Transaction.product_count, ), ( - Transaction.type_ == TransactionType.ADJUST_STOCK, + Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(), Transaction.product_count, ), ( - Transaction.type_ == TransactionType.BUY_PRODUCT, + Transaction.type_ == TransactionType.BUY_PRODUCT.as_literal_column(), -Transaction.product_count, ), ( - Transaction.type_ == TransactionType.JOINT, + Transaction.type_ == TransactionType.JOINT.as_literal_column(), -Transaction.product_count, ), ( - Transaction.type_ == TransactionType.THROW_PRODUCT, + Transaction.type_ == TransactionType.THROW_PRODUCT.as_literal_column(), -Transaction.product_count, ), else_=0, @@ -57,15 +64,15 @@ def _product_stock_query( ).where( Transaction.type_.in_( [ - TransactionType.ADD_PRODUCT, - TransactionType.ADJUST_STOCK, - TransactionType.BUY_PRODUCT, - TransactionType.JOINT, - TransactionType.THROW_PRODUCT, + TransactionType.ADD_PRODUCT.as_literal_column(), + TransactionType.ADJUST_STOCK.as_literal_column(), + TransactionType.BUY_PRODUCT.as_literal_column(), + TransactionType.JOINT.as_literal_column(), + TransactionType.THROW_PRODUCT.as_literal_column(), ] ), Transaction.product_id == product_id, - Transaction.time <= until if until is not None else literal(True), + Transaction.time <= until if until is not None else CONST_TRUE, ) return query diff --git a/dibbler/queries/user_balance.py b/dibbler/queries/user_balance.py index b9dbf5c..cc1a31e 100644 --- a/dibbler/queries/user_balance.py +++ b/dibbler/queries/user_balance.py @@ -3,10 +3,10 @@ from datetime import datetime from sqlalchemy import ( CTE, + BindParameter, Float, Integer, and_, - asc, case, cast, column, @@ -17,6 +17,7 @@ from sqlalchemy import ( ) from sqlalchemy.orm import Session +from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const from dibbler.models import ( Transaction, TransactionType, @@ -31,10 +32,10 @@ from dibbler.queries.product_price import _product_price_query def _user_balance_query( - user_id: int, + user_id: BindParameter[int] | int, use_cache: bool = True, - until: datetime | None = None, - until_including: bool = True, + until: BindParameter[datetime] | BindParameter[None] | datetime | None = None, + until_including: BindParameter[bool] | bool = True, cte_name: str = "rec_cte", ) -> CTE: """ @@ -44,14 +45,23 @@ def _user_balance_query( if use_cache: print("WARNING: Using cache for user balance query is not implemented yet.") + if isinstance(user_id, int): + user_id = BindParameter("user_id", value=user_id) + + if isinstance(until, datetime): + until = BindParameter("until", value=until, type_=datetime) + + if isinstance(until_including, bool): + until_including = BindParameter("until_including", value=until_including, type_=bool) + initial_element = select( - literal(0).label("i"), - literal(0).label("time"), - literal(None).label("transaction_id"), - literal(0).label("balance"), - literal(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"), - literal(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"), - literal(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"), + CONST_ZERO.label("i"), + CONST_ZERO.label("time"), + CONST_NONE.label("transaction_id"), + CONST_ZERO.label("balance"), + const(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"), + const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"), + const(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"), ) recursive_cte = initial_element.cte(name=cte_name, recursive=True) @@ -59,7 +69,7 @@ def _user_balance_query( # Subset of transactions that we'll want to iterate over. trx_subset = ( select( - func.row_number().over(order_by=asc(Transaction.time)).label("i"), + func.row_number().over(order_by=Transaction.time.asc()).label("i"), Transaction.amount, Transaction.id, Transaction.interest_rate_percent, @@ -77,34 +87,34 @@ def _user_balance_query( Transaction.user_id == user_id, Transaction.type_.in_( [ - TransactionType.ADD_PRODUCT, - TransactionType.ADJUST_BALANCE, - TransactionType.BUY_PRODUCT, - TransactionType.TRANSFER, + TransactionType.ADD_PRODUCT.as_literal_column(), + TransactionType.ADJUST_BALANCE.as_literal_column(), + TransactionType.BUY_PRODUCT.as_literal_column(), + TransactionType.TRANSFER.as_literal_column(), # TODO: join this with the JOINT transactions, and determine # how much the current user paid for the product. - TransactionType.JOINT_BUY_PRODUCT, + TransactionType.JOINT_BUY_PRODUCT.as_literal_column(), ] ), ), and_( - Transaction.type_ == TransactionType.TRANSFER, + Transaction.type_ == TransactionType.TRANSFER.as_literal_column(), Transaction.transfer_user_id == user_id, ), Transaction.type_.in_( [ - TransactionType.THROW_PRODUCT, - TransactionType.ADJUST_INTEREST, - TransactionType.ADJUST_PENALTY, + TransactionType.THROW_PRODUCT.as_literal_column(), + TransactionType.ADJUST_INTEREST.as_literal_column(), + TransactionType.ADJUST_PENALTY.as_literal_column(), ] ), ), case( - (literal(until_including), Transaction.time <= until), + (until_including, Transaction.time <= until), else_=Transaction.time < until, ) if until is not None - else literal(True), + else CONST_TRUE, ) .order_by(Transaction.time.asc()) .alias("trx_subset") @@ -118,17 +128,17 @@ def _user_balance_query( case( # Adjusts balance -> balance gets adjusted ( - trx_subset.c.type_ == TransactionType.ADJUST_BALANCE, + trx_subset.c.type_ == TransactionType.ADJUST_BALANCE.as_literal_column(), recursive_cte.c.balance + trx_subset.c.amount, ), # Adds a product -> balance increases ( - trx_subset.c.type_ == TransactionType.ADD_PRODUCT, + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), recursive_cte.c.balance + trx_subset.c.amount, ), # Buys a product -> balance decreases ( - trx_subset.c.type_ == TransactionType.BUY_PRODUCT, + trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(), recursive_cte.c.balance - ( trx_subset.c.product_count @@ -151,12 +161,12 @@ def _user_balance_query( ) ) .order_by(column("i").desc()) - .limit(1) + .limit(CONST_ONE) ).scalar_subquery() # TODO: should interest be applied before or after the penalty multiplier? # at the moment of writing, after sound right, but maybe ask someone? # Interest - * (cast(recursive_cte.c.interest_rate_percent, Float) / 100) + * (cast(recursive_cte.c.interest_rate_percent, Float) / const(100)) # TODO: these should be added together, not multiplied, see specification # Penalty * case( @@ -164,10 +174,10 @@ def _user_balance_query( recursive_cte.c.balance < recursive_cte.c.penalty_threshold, ( cast(recursive_cte.c.penalty_multiplier_percent, Float) - / 100 + / const(100) ), ), - else_=1.0, + else_=const(1.0), ) ), Integer, @@ -177,7 +187,7 @@ def _user_balance_query( # Transfers money to self -> balance increases ( and_( - trx_subset.c.type_ == TransactionType.TRANSFER, + trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(), trx_subset.c.transfer_user_id == user_id, ), recursive_cte.c.balance + trx_subset.c.amount, @@ -185,7 +195,7 @@ def _user_balance_query( # Transfers money from self -> balance decreases ( and_( - trx_subset.c.type_ == TransactionType.TRANSFER, + trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(), trx_subset.c.transfer_user_id != user_id, ), recursive_cte.c.balance - trx_subset.c.amount, @@ -202,28 +212,28 @@ def _user_balance_query( ).label("balance"), case( ( - trx_subset.c.type_ == TransactionType.ADJUST_INTEREST, + trx_subset.c.type_ == TransactionType.ADJUST_INTEREST.as_literal_column(), trx_subset.c.interest_rate_percent, ), else_=recursive_cte.c.interest_rate_percent, ).label("interest_rate_percent"), case( ( - trx_subset.c.type_ == TransactionType.ADJUST_PENALTY, + trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(), trx_subset.c.penalty_threshold, ), else_=recursive_cte.c.penalty_threshold, ).label("penalty_threshold"), case( ( - trx_subset.c.type_ == TransactionType.ADJUST_PENALTY, + trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(), trx_subset.c.penalty_multiplier_percent, ), else_=recursive_cte.c.penalty_multiplier_percent, ).label("penalty_multiplier_percent"), ) .select_from(trx_subset) - .where(trx_subset.c.i == recursive_cte.c.i + 1) + .where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE) ) return recursive_cte.union_all(recursive_elements) @@ -322,7 +332,10 @@ def user_balance( ) result = sql_session.scalar( - select(recursive_cte.c.balance).order_by(recursive_cte.c.i.desc()).limit(1) + select(recursive_cte.c.balance) + .order_by(recursive_cte.c.i.desc()) + .limit(CONST_ONE) + .offset(CONST_ZERO) ) if result is None: diff --git a/tests/conftest.py b/tests/conftest.py index a5170ce..90a447b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import logging import pytest import sqlparse from sqlalchemy import create_engine, event +from sqlalchemy.exc import OperationalError from sqlalchemy.orm import Session from dibbler.models import Base @@ -35,21 +36,27 @@ class SqlParseFormatter(logging.Formatter): return super().format(record) -@pytest.fixture(scope="function") -def sql_session(request): - """Create a new SQLAlchemy session for testing.""" +def pytest_configure(config): + """Setup pretty SQL logging if --echo is enabled.""" - logging.basicConfig() logger = logging.getLogger("sqlalchemy.engine") + + # TODO: it would be nice not to duplicate these logs. + # logging.NullHandler() does not seem to work here handler = logging.StreamHandler() handler.setFormatter(SqlParseFormatter()) logger.addHandler(handler) - echo = request.config.getoption("--echo") - engine = create_engine( - "sqlite:///:memory:", - echo=echo, - ) + echo = config.getoption("--echo") + if echo: + logger.setLevel(logging.INFO) + + +@pytest.fixture(scope="function") +def sql_session(request): + """Create a new SQLAlchemy session for testing.""" + + engine = create_engine("sqlite:///:memory:") @event.listens_for(engine, "connect") def set_sqlite_pragma(dbapi_connection, _connection_record): @@ -60,3 +67,17 @@ def sql_session(request): Base.metadata.create_all(engine) with Session(engine) as sql_session: yield sql_session + + +# FIXME: Declaring this hook seems to have a side effect where the database does not +# get reset between tests. +# @pytest.hookimpl(trylast=True) +# def pytest_runtest_call(item: pytest.Item): +# """Hook to format SQL statements in OperationalError exceptions.""" +# try: +# item.runtest() +# except OperationalError as e: +# if e.statement is not None: +# formatted_sql = sqlparse.format(e.statement, reindent=True, keyword_case="upper") +# e.statement = "\n" + formatted_sql + "\n" +# raise e diff --git a/tests/queries/test_user_balance.py b/tests/queries/test_user_balance.py index 4f3643c..e62eebf 100644 --- a/tests/queries/test_user_balance.py +++ b/tests/queries/test_user_balance.py @@ -3,7 +3,6 @@ from datetime import datetime from pprint import pprint import pytest - from sqlalchemy.orm import Session from dibbler.models import Product, Transaction, User @@ -104,8 +103,8 @@ def test_user_balance_with_transfers(sql_session: Session) -> None: assert user2_balance == 50 - 30 -def test_user_balance_complex_history(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") +@pytest.mark.skip(reason="Not yet implemented") +def test_user_balance_complex_history(sql_session: Session) -> None: ... def test_user_balance_penalty(sql_session: Session) -> None: