fixup! WIP
All checks were successful
Run tests / run-tests (push) Successful in 1m30s

This commit is contained in:
2025-12-10 13:32:54 +09:00
parent 7f4a980eef
commit fa7ad3a258
8 changed files with 216 additions and 114 deletions

View File

@@ -0,0 +1,20 @@
from typing import TypeVar
from sqlalchemy import BindParameter, literal
T = TypeVar("T")
def const(value: T) -> BindParameter[T]:
"""
Create a constant SQL literal bind parameter.
This is useful to avoid too many `?` bind parameters in SQL queries,
when the input value is known to be safe.
"""
return literal(value, literal_execute=True)
CONST_ZERO: BindParameter[int] = const(0)
CONST_ONE: BindParameter[int] = const(1)
CONST_TRUE: BindParameter[bool] = const(True)
CONST_FALSE: BindParameter[bool] = const(False)
CONST_NONE: BindParameter[None] = const(None)

View File

@@ -19,6 +19,17 @@ class TransactionType(StrEnum):
THROW_PRODUCT = auto() THROW_PRODUCT = auto()
TRANSFER = auto() TRANSFER = auto()
def as_literal_column(self):
"""
Return the transaction type as a SQL literal column.
This is useful to avoid too many `?` bind parameters in SQL queries,
when the input value is known to be safe.
"""
from sqlalchemy import literal_column
return literal_column(f"'{self.value}'")
TransactionTypeSQL = SQLEnum( TransactionTypeSQL = SQLEnum(
TransactionType, TransactionType,

View File

@@ -1,10 +1,11 @@
from datetime import datetime
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime
from sqlalchemy import ( from sqlalchemy import (
CTE, CTE,
BindParameter,
and_, and_,
asc, bindparam,
case, case,
func, func,
literal, literal,
@@ -12,6 +13,7 @@ from sqlalchemy import (
) )
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
from dibbler.models import ( from dibbler.models import (
Product, Product,
Transaction, Transaction,
@@ -22,9 +24,9 @@ from dibbler.queries.product_stock import _product_stock_query
def _product_owners_query( def _product_owners_query(
product_id: int, product_id: BindParameter[int] | int,
use_cache: bool = True, use_cache: bool = True,
until: datetime | None = None, until: BindParameter[datetime] | datetime | None = None,
cte_name: str = "rec_cte", cte_name: str = "rec_cte",
) -> CTE: ) -> CTE:
""" """
@@ -34,6 +36,12 @@ def _product_owners_query(
if use_cache: if use_cache:
print("WARNING: Using cache for users owning product query is not implemented yet.") print("WARNING: Using cache for users owning product query is not implemented yet.")
if isinstance(product_id, int):
product_id = bindparam("product_id", value=product_id)
if isinstance(until, datetime):
until = BindParameter("until", value=until)
product_stock = _product_stock_query( product_stock = _product_stock_query(
product_id=product_id, product_id=product_id,
use_cache=use_cache, use_cache=use_cache,
@@ -54,26 +62,26 @@ def _product_owners_query(
.where( .where(
Transaction.type_.in_( Transaction.type_.in_(
[ [
TransactionType.ADD_PRODUCT, TransactionType.ADD_PRODUCT.as_literal_column(),
# TransactionType.BUY_PRODUCT, # TransactionType.BUY_PRODUCT,
TransactionType.ADJUST_STOCK, TransactionType.ADJUST_STOCK.as_literal_column(),
# TransactionType.JOINT, # TransactionType.JOINT,
# TransactionType.THROW_PRODUCT, # TransactionType.THROW_PRODUCT,
] ]
), ),
Transaction.product_id == product_id, Transaction.product_id == product_id,
literal(True) if until is None else Transaction.time <= until, CONST_TRUE if until is None else Transaction.time <= until,
) )
.order_by(Transaction.time.desc()) .order_by(Transaction.time.desc())
.subquery() .subquery()
) )
initial_element = select( initial_element = select(
literal(0).label("i"), CONST_ZERO.label("i"),
literal(0).label("time"), CONST_ZERO.label("time"),
literal(None).label("transaction_id"), CONST_NONE.label("transaction_id"),
literal(None).label("user_id"), CONST_NONE.label("user_id"),
literal(0).label("product_count"), CONST_ZERO.label("product_count"),
product_stock.scalar_subquery().label("products_left_to_account_for"), product_stock.scalar_subquery().label("products_left_to_account_for"),
) )
@@ -88,28 +96,31 @@ def _product_owners_query(
case( case(
# Someone adds the product -> they own it # Someone adds the product -> they own it
( (
trx_subset.c.type_ == TransactionType.ADD_PRODUCT, trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
trx_subset.c.user_id, trx_subset.c.user_id,
), ),
else_=None, else_=CONST_NONE,
).label("user_id"), ).label("user_id"),
# How many products did they add (if any) # How many products did they add (if any)
case( case(
# Someone adds the product -> they added a certain amount of products # Someone adds the product -> they added a certain amount of products
(trx_subset.c.type_ == TransactionType.ADD_PRODUCT, trx_subset.c.product_count),
# Stock got adjusted upwards -> consider those products as added by nobody
( (
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK) trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
& (trx_subset.c.product_count > 0),
trx_subset.c.product_count, trx_subset.c.product_count,
), ),
else_=0, # Stock got adjusted upwards -> consider those products as added by nobody
(
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column())
and (trx_subset.c.product_count > CONST_ZERO),
trx_subset.c.product_count,
),
else_=CONST_ZERO,
).label("product_count"), ).label("product_count"),
# How many products left to account for # How many products left to account for
case( case(
# Someone adds the product -> increase the number of products left to account for # Someone adds the product -> increase the number of products left to account for
( (
trx_subset.c.type_ == TransactionType.ADD_PRODUCT, trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count, recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
), ),
# Someone buys/joins/throws the product -> decrease the number of products left to account for # Someone buys/joins/throws the product -> decrease the number of products left to account for
@@ -127,8 +138,8 @@ def _product_owners_query(
# If adjusted upwards -> products owned by nobody, decrease products left to account for # If adjusted upwards -> products owned by nobody, decrease products left to account for
# If adjusted downwards -> products taken away from owners, decrease products left to account for # If adjusted downwards -> products taken away from owners, decrease products left to account for
( (
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK) (trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column())
and (trx_subset.c.product_count > 0), and (trx_subset.c.product_count > CONST_ZERO),
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count, recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
), ),
# ( # (
@@ -142,8 +153,8 @@ def _product_owners_query(
.select_from(trx_subset) .select_from(trx_subset)
.where( .where(
and_( and_(
trx_subset.c.i == recursive_cte.c.i + 1, trx_subset.c.i == recursive_cte.c.i + CONST_ONE,
recursive_cte.c.products_left_to_account_for > 0, recursive_cte.c.products_left_to_account_for > CONST_ZERO,
) )
) )
) )

View File

@@ -3,9 +3,9 @@ from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from sqlalchemy import ( from sqlalchemy import (
BindParameter,
ColumnElement, ColumnElement,
Integer, Integer,
SQLColumnExpression,
asc, asc,
case, case,
cast, cast,
@@ -15,6 +15,7 @@ from sqlalchemy import (
) )
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
from dibbler.models import ( from dibbler.models import (
Product, Product,
Transaction, Transaction,
@@ -26,8 +27,8 @@ from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE
def _product_price_query( def _product_price_query(
product_id: int | ColumnElement[int], product_id: int | ColumnElement[int],
use_cache: bool = True, use_cache: bool = True,
until: datetime | SQLColumnExpression[datetime] | None = None, until: BindParameter[datetime] | datetime | None = None,
until_including: bool = True, until_including: BindParameter[bool] | bool = True,
cte_name: str = "rec_cte", cte_name: str = "rec_cte",
): ):
""" """
@@ -37,12 +38,21 @@ def _product_price_query(
if use_cache: if use_cache:
print("WARNING: Using cache for product price query is not implemented yet.") print("WARNING: Using cache for product price query is not implemented yet.")
if isinstance(product_id, int):
product_id = BindParameter("product_id", value=product_id)
if isinstance(until, datetime):
until = BindParameter("until", value=until)
if isinstance(until_including, bool):
until_including = BindParameter("until_including", value=until_including)
initial_element = select( initial_element = select(
literal(0).label("i"), CONST_ZERO.label("i"),
literal(0).label("time"), CONST_ZERO.label("time"),
literal(None).label("transaction_id"), CONST_NONE.label("transaction_id"),
literal(0).label("price"), CONST_ZERO.label("price"),
literal(0).label("product_count"), CONST_ZERO.label("product_count"),
) )
recursive_cte = initial_element.cte(name=cte_name, recursive=True) recursive_cte = initial_element.cte(name=cte_name, recursive=True)
@@ -60,19 +70,19 @@ def _product_price_query(
.where( .where(
Transaction.type_.in_( Transaction.type_.in_(
[ [
TransactionType.BUY_PRODUCT, TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.ADD_PRODUCT, TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_STOCK, TransactionType.ADJUST_STOCK.as_literal_column(),
TransactionType.JOINT, TransactionType.JOINT.as_literal_column(),
] ]
), ),
Transaction.product_id == product_id, Transaction.product_id == product_id,
case( case(
(literal(until_including), Transaction.time <= until), (until_including, Transaction.time <= until),
else_=Transaction.time < until, else_=Transaction.time < until,
) )
if until is not None if until is not None
else literal(True), else CONST_TRUE,
) )
.order_by(Transaction.time.asc()) .order_by(Transaction.time.asc())
.alias("trx_subset") .alias("trx_subset")
@@ -85,22 +95,26 @@ def _product_price_query(
trx_subset.c.id.label("transaction_id"), trx_subset.c.id.label("transaction_id"),
case( case(
# Someone buys the product -> price remains the same. # Someone buys the product -> price remains the same.
(trx_subset.c.type_ == TransactionType.BUY_PRODUCT, recursive_cte.c.price), (
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
recursive_cte.c.price,
),
# Someone adds the product -> price is recalculated based on # Someone adds the product -> price is recalculated based on
# product count, previous price, and new price. # product count, previous price, and new price.
( (
trx_subset.c.type_ == TransactionType.ADD_PRODUCT, trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
cast( cast(
func.ceil( func.ceil(
( (
recursive_cte.c.price * func.max(recursive_cte.c.product_count, 0) recursive_cte.c.price
* func.max(recursive_cte.c.product_count, CONST_ZERO)
+ trx_subset.c.per_product * trx_subset.c.product_count + trx_subset.c.per_product * trx_subset.c.product_count
) )
/ ( / (
# The running product count can be negative if the accounting is bad. # The running product count can be negative if the accounting is bad.
# This ensures that we never end up with negative prices or zero divisions # This ensures that we never end up with negative prices or zero divisions
# and other disastrous phenomena. # and other disastrous phenomena.
func.max(recursive_cte.c.product_count, 0) func.max(recursive_cte.c.product_count, CONST_ZERO)
+ trx_subset.c.product_count + trx_subset.c.product_count
) )
), ),
@@ -108,28 +122,31 @@ def _product_price_query(
), ),
), ),
# Someone adjusts the stock -> price remains the same. # Someone adjusts the stock -> price remains the same.
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK, recursive_cte.c.price), (
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
recursive_cte.c.price,
),
# Should never happen # Should never happen
else_=recursive_cte.c.price, else_=recursive_cte.c.price,
).label("price"), ).label("price"),
case( case(
# Someone buys the product -> product count is reduced. # Someone buys the product -> product count is reduced.
( (
trx_subset.c.type_ == TransactionType.BUY_PRODUCT, trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
recursive_cte.c.product_count - trx_subset.c.product_count, recursive_cte.c.product_count - trx_subset.c.product_count,
), ),
( (
trx_subset.c.type_ == TransactionType.JOINT, trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(),
recursive_cte.c.product_count - trx_subset.c.product_count, recursive_cte.c.product_count - trx_subset.c.product_count,
), ),
# Someone adds the product -> product count is increased. # Someone adds the product -> product count is increased.
( (
trx_subset.c.type_ == TransactionType.ADD_PRODUCT, trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
recursive_cte.c.product_count + trx_subset.c.product_count, recursive_cte.c.product_count + trx_subset.c.product_count,
), ),
# Someone adjusts the stock -> product count is adjusted. # Someone adjusts the stock -> product count is adjusted.
( (
trx_subset.c.type_ == TransactionType.ADJUST_STOCK, trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
recursive_cte.c.product_count + trx_subset.c.product_count, recursive_cte.c.product_count + trx_subset.c.product_count,
), ),
# Should never happen # Should never happen
@@ -137,7 +154,7 @@ def _product_price_query(
).label("product_count"), ).label("product_count"),
) )
.select_from(trx_subset) .select_from(trx_subset)
.where(trx_subset.c.i == recursive_cte.c.i + 1) .where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE)
) )
return recursive_cte.union_all(recursive_elements) return recursive_cte.union_all(recursive_elements)
@@ -222,7 +239,10 @@ def product_price(
# - price should never be negative # - price should never be negative
result = sql_session.scalars( result = sql_session.scalars(
select(recursive_cte.c.price).order_by(recursive_cte.c.i.desc()).limit(1) select(recursive_cte.c.price)
.order_by(recursive_cte.c.i.desc())
.limit(CONST_ONE)
.offset(CONST_ZERO)
).one_or_none() ).one_or_none()
if result is None: if result is None:
@@ -237,10 +257,10 @@ def product_price(
select(Transaction.interest_rate_percent) select(Transaction.interest_rate_percent)
.where( .where(
Transaction.type_ == TransactionType.ADJUST_INTEREST, Transaction.type_ == TransactionType.ADJUST_INTEREST,
literal(True) if until is None else Transaction.time <= until.time, CONST_TRUE if until is None else Transaction.time <= until.time,
) )
.order_by(Transaction.time.desc()) .order_by(Transaction.time.desc())
.limit(1) .limit(CONST_ONE)
) )
or DEFAULT_INTEREST_RATE_PERCENTAGE or DEFAULT_INTEREST_RATE_PERCENTAGE
) )

View File

@@ -1,14 +1,15 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import ( from sqlalchemy import (
BindParameter,
Select, Select,
case, case,
func, func,
literal,
select, select,
) )
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.lib.query_helpers import CONST_TRUE
from dibbler.models import ( from dibbler.models import (
Product, Product,
Transaction, Transaction,
@@ -17,9 +18,9 @@ from dibbler.models import (
def _product_stock_query( def _product_stock_query(
product_id: int, product_id: BindParameter[int] | int,
use_cache: bool = True, use_cache: bool = True,
until: datetime | None = None, until: BindParameter[datetime] | datetime | None = None,
) -> Select: ) -> Select:
""" """
The inner query for calculating the product stock. The inner query for calculating the product stock.
@@ -28,27 +29,33 @@ def _product_stock_query(
if use_cache: if use_cache:
print("WARNING: Using cache for product stock query is not implemented yet.") print("WARNING: Using cache for product stock query is not implemented yet.")
if isinstance(product_id, int):
product_id = BindParameter("product_id", value=product_id)
if isinstance(until, datetime):
until = BindParameter("until", value=until)
query = select( query = select(
func.sum( func.sum(
case( case(
( (
Transaction.type_ == TransactionType.ADD_PRODUCT, Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
Transaction.product_count, Transaction.product_count,
), ),
( (
Transaction.type_ == TransactionType.ADJUST_STOCK, Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
Transaction.product_count, Transaction.product_count,
), ),
( (
Transaction.type_ == TransactionType.BUY_PRODUCT, Transaction.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
-Transaction.product_count, -Transaction.product_count,
), ),
( (
Transaction.type_ == TransactionType.JOINT, Transaction.type_ == TransactionType.JOINT.as_literal_column(),
-Transaction.product_count, -Transaction.product_count,
), ),
( (
Transaction.type_ == TransactionType.THROW_PRODUCT, Transaction.type_ == TransactionType.THROW_PRODUCT.as_literal_column(),
-Transaction.product_count, -Transaction.product_count,
), ),
else_=0, else_=0,
@@ -57,15 +64,15 @@ def _product_stock_query(
).where( ).where(
Transaction.type_.in_( Transaction.type_.in_(
[ [
TransactionType.ADD_PRODUCT, TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_STOCK, TransactionType.ADJUST_STOCK.as_literal_column(),
TransactionType.BUY_PRODUCT, TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.JOINT, TransactionType.JOINT.as_literal_column(),
TransactionType.THROW_PRODUCT, TransactionType.THROW_PRODUCT.as_literal_column(),
] ]
), ),
Transaction.product_id == product_id, Transaction.product_id == product_id,
Transaction.time <= until if until is not None else literal(True), Transaction.time <= until if until is not None else CONST_TRUE,
) )
return query return query

View File

@@ -3,10 +3,10 @@ from datetime import datetime
from sqlalchemy import ( from sqlalchemy import (
CTE, CTE,
BindParameter,
Float, Float,
Integer, Integer,
and_, and_,
asc,
case, case,
cast, cast,
column, column,
@@ -17,6 +17,7 @@ from sqlalchemy import (
) )
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
from dibbler.models import ( from dibbler.models import (
Transaction, Transaction,
TransactionType, TransactionType,
@@ -31,10 +32,10 @@ from dibbler.queries.product_price import _product_price_query
def _user_balance_query( def _user_balance_query(
user_id: int, user_id: BindParameter[int] | int,
use_cache: bool = True, use_cache: bool = True,
until: datetime | None = None, until: BindParameter[datetime] | BindParameter[None] | datetime | None = None,
until_including: bool = True, until_including: BindParameter[bool] | bool = True,
cte_name: str = "rec_cte", cte_name: str = "rec_cte",
) -> CTE: ) -> CTE:
""" """
@@ -44,14 +45,23 @@ def _user_balance_query(
if use_cache: if use_cache:
print("WARNING: Using cache for user balance query is not implemented yet.") print("WARNING: Using cache for user balance query is not implemented yet.")
if isinstance(user_id, int):
user_id = BindParameter("user_id", value=user_id)
if isinstance(until, datetime):
until = BindParameter("until", value=until, type_=datetime)
if isinstance(until_including, bool):
until_including = BindParameter("until_including", value=until_including, type_=bool)
initial_element = select( initial_element = select(
literal(0).label("i"), CONST_ZERO.label("i"),
literal(0).label("time"), CONST_ZERO.label("time"),
literal(None).label("transaction_id"), CONST_NONE.label("transaction_id"),
literal(0).label("balance"), CONST_ZERO.label("balance"),
literal(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"), const(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"),
literal(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"), const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
literal(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"), const(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"),
) )
recursive_cte = initial_element.cte(name=cte_name, recursive=True) recursive_cte = initial_element.cte(name=cte_name, recursive=True)
@@ -59,7 +69,7 @@ def _user_balance_query(
# Subset of transactions that we'll want to iterate over. # Subset of transactions that we'll want to iterate over.
trx_subset = ( trx_subset = (
select( select(
func.row_number().over(order_by=asc(Transaction.time)).label("i"), func.row_number().over(order_by=Transaction.time.asc()).label("i"),
Transaction.amount, Transaction.amount,
Transaction.id, Transaction.id,
Transaction.interest_rate_percent, Transaction.interest_rate_percent,
@@ -77,34 +87,34 @@ def _user_balance_query(
Transaction.user_id == user_id, Transaction.user_id == user_id,
Transaction.type_.in_( Transaction.type_.in_(
[ [
TransactionType.ADD_PRODUCT, TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_BALANCE, TransactionType.ADJUST_BALANCE.as_literal_column(),
TransactionType.BUY_PRODUCT, TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.TRANSFER, TransactionType.TRANSFER.as_literal_column(),
# TODO: join this with the JOINT transactions, and determine # TODO: join this with the JOINT transactions, and determine
# how much the current user paid for the product. # how much the current user paid for the product.
TransactionType.JOINT_BUY_PRODUCT, TransactionType.JOINT_BUY_PRODUCT.as_literal_column(),
] ]
), ),
), ),
and_( and_(
Transaction.type_ == TransactionType.TRANSFER, Transaction.type_ == TransactionType.TRANSFER.as_literal_column(),
Transaction.transfer_user_id == user_id, Transaction.transfer_user_id == user_id,
), ),
Transaction.type_.in_( Transaction.type_.in_(
[ [
TransactionType.THROW_PRODUCT, TransactionType.THROW_PRODUCT.as_literal_column(),
TransactionType.ADJUST_INTEREST, TransactionType.ADJUST_INTEREST.as_literal_column(),
TransactionType.ADJUST_PENALTY, TransactionType.ADJUST_PENALTY.as_literal_column(),
] ]
), ),
), ),
case( case(
(literal(until_including), Transaction.time <= until), (until_including, Transaction.time <= until),
else_=Transaction.time < until, else_=Transaction.time < until,
) )
if until is not None if until is not None
else literal(True), else CONST_TRUE,
) )
.order_by(Transaction.time.asc()) .order_by(Transaction.time.asc())
.alias("trx_subset") .alias("trx_subset")
@@ -118,17 +128,17 @@ def _user_balance_query(
case( case(
# Adjusts balance -> balance gets adjusted # Adjusts balance -> balance gets adjusted
( (
trx_subset.c.type_ == TransactionType.ADJUST_BALANCE, trx_subset.c.type_ == TransactionType.ADJUST_BALANCE.as_literal_column(),
recursive_cte.c.balance + trx_subset.c.amount, recursive_cte.c.balance + trx_subset.c.amount,
), ),
# Adds a product -> balance increases # Adds a product -> balance increases
( (
trx_subset.c.type_ == TransactionType.ADD_PRODUCT, trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
recursive_cte.c.balance + trx_subset.c.amount, recursive_cte.c.balance + trx_subset.c.amount,
), ),
# Buys a product -> balance decreases # Buys a product -> balance decreases
( (
trx_subset.c.type_ == TransactionType.BUY_PRODUCT, trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
recursive_cte.c.balance recursive_cte.c.balance
- ( - (
trx_subset.c.product_count trx_subset.c.product_count
@@ -151,12 +161,12 @@ def _user_balance_query(
) )
) )
.order_by(column("i").desc()) .order_by(column("i").desc())
.limit(1) .limit(CONST_ONE)
).scalar_subquery() ).scalar_subquery()
# TODO: should interest be applied before or after the penalty multiplier? # TODO: should interest be applied before or after the penalty multiplier?
# at the moment of writing, after sound right, but maybe ask someone? # at the moment of writing, after sound right, but maybe ask someone?
# Interest # Interest
* (cast(recursive_cte.c.interest_rate_percent, Float) / 100) * (cast(recursive_cte.c.interest_rate_percent, Float) / const(100))
# TODO: these should be added together, not multiplied, see specification # TODO: these should be added together, not multiplied, see specification
# Penalty # Penalty
* case( * case(
@@ -164,10 +174,10 @@ def _user_balance_query(
recursive_cte.c.balance < recursive_cte.c.penalty_threshold, recursive_cte.c.balance < recursive_cte.c.penalty_threshold,
( (
cast(recursive_cte.c.penalty_multiplier_percent, Float) cast(recursive_cte.c.penalty_multiplier_percent, Float)
/ 100 / const(100)
), ),
), ),
else_=1.0, else_=const(1.0),
) )
), ),
Integer, Integer,
@@ -177,7 +187,7 @@ def _user_balance_query(
# Transfers money to self -> balance increases # Transfers money to self -> balance increases
( (
and_( and_(
trx_subset.c.type_ == TransactionType.TRANSFER, trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(),
trx_subset.c.transfer_user_id == user_id, trx_subset.c.transfer_user_id == user_id,
), ),
recursive_cte.c.balance + trx_subset.c.amount, recursive_cte.c.balance + trx_subset.c.amount,
@@ -185,7 +195,7 @@ def _user_balance_query(
# Transfers money from self -> balance decreases # Transfers money from self -> balance decreases
( (
and_( and_(
trx_subset.c.type_ == TransactionType.TRANSFER, trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(),
trx_subset.c.transfer_user_id != user_id, trx_subset.c.transfer_user_id != user_id,
), ),
recursive_cte.c.balance - trx_subset.c.amount, recursive_cte.c.balance - trx_subset.c.amount,
@@ -202,28 +212,28 @@ def _user_balance_query(
).label("balance"), ).label("balance"),
case( case(
( (
trx_subset.c.type_ == TransactionType.ADJUST_INTEREST, trx_subset.c.type_ == TransactionType.ADJUST_INTEREST.as_literal_column(),
trx_subset.c.interest_rate_percent, trx_subset.c.interest_rate_percent,
), ),
else_=recursive_cte.c.interest_rate_percent, else_=recursive_cte.c.interest_rate_percent,
).label("interest_rate_percent"), ).label("interest_rate_percent"),
case( case(
( (
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY, trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(),
trx_subset.c.penalty_threshold, trx_subset.c.penalty_threshold,
), ),
else_=recursive_cte.c.penalty_threshold, else_=recursive_cte.c.penalty_threshold,
).label("penalty_threshold"), ).label("penalty_threshold"),
case( case(
( (
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY, trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(),
trx_subset.c.penalty_multiplier_percent, trx_subset.c.penalty_multiplier_percent,
), ),
else_=recursive_cte.c.penalty_multiplier_percent, else_=recursive_cte.c.penalty_multiplier_percent,
).label("penalty_multiplier_percent"), ).label("penalty_multiplier_percent"),
) )
.select_from(trx_subset) .select_from(trx_subset)
.where(trx_subset.c.i == recursive_cte.c.i + 1) .where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE)
) )
return recursive_cte.union_all(recursive_elements) return recursive_cte.union_all(recursive_elements)
@@ -322,7 +332,10 @@ def user_balance(
) )
result = sql_session.scalar( result = sql_session.scalar(
select(recursive_cte.c.balance).order_by(recursive_cte.c.i.desc()).limit(1) select(recursive_cte.c.balance)
.order_by(recursive_cte.c.i.desc())
.limit(CONST_ONE)
.offset(CONST_ZERO)
) )
if result is None: if result is None:

View File

@@ -3,6 +3,7 @@ import logging
import pytest import pytest
import sqlparse import sqlparse
from sqlalchemy import create_engine, event from sqlalchemy import create_engine, event
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import Base from dibbler.models import Base
@@ -35,21 +36,27 @@ class SqlParseFormatter(logging.Formatter):
return super().format(record) return super().format(record)
@pytest.fixture(scope="function") def pytest_configure(config):
def sql_session(request): """Setup pretty SQL logging if --echo is enabled."""
"""Create a new SQLAlchemy session for testing."""
logging.basicConfig()
logger = logging.getLogger("sqlalchemy.engine") logger = logging.getLogger("sqlalchemy.engine")
# TODO: it would be nice not to duplicate these logs.
# logging.NullHandler() does not seem to work here
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(SqlParseFormatter()) handler.setFormatter(SqlParseFormatter())
logger.addHandler(handler) logger.addHandler(handler)
echo = request.config.getoption("--echo") echo = config.getoption("--echo")
engine = create_engine( if echo:
"sqlite:///:memory:", logger.setLevel(logging.INFO)
echo=echo,
)
@pytest.fixture(scope="function")
def sql_session(request):
"""Create a new SQLAlchemy session for testing."""
engine = create_engine("sqlite:///:memory:")
@event.listens_for(engine, "connect") @event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, _connection_record): def set_sqlite_pragma(dbapi_connection, _connection_record):
@@ -60,3 +67,17 @@ def sql_session(request):
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
with Session(engine) as sql_session: with Session(engine) as sql_session:
yield sql_session yield sql_session
# FIXME: Declaring this hook seems to have a side effect where the database does not
# get reset between tests.
# @pytest.hookimpl(trylast=True)
# def pytest_runtest_call(item: pytest.Item):
# """Hook to format SQL statements in OperationalError exceptions."""
# try:
# item.runtest()
# except OperationalError as e:
# if e.statement is not None:
# formatted_sql = sqlparse.format(e.statement, reindent=True, keyword_case="upper")
# e.statement = "\n" + formatted_sql + "\n"
# raise e

View File

@@ -3,7 +3,6 @@ from datetime import datetime
from pprint import pprint from pprint import pprint
import pytest import pytest
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User from dibbler.models import Product, Transaction, User
@@ -104,8 +103,8 @@ def test_user_balance_with_transfers(sql_session: Session) -> None:
assert user2_balance == 50 - 30 assert user2_balance == 50 - 30
def test_user_balance_complex_history(sql_session: Session) -> None: @pytest.mark.skip(reason="Not yet implemented")
raise NotImplementedError("This test is not implemented yet.") def test_user_balance_complex_history(sql_session: Session) -> None: ...
def test_user_balance_penalty(sql_session: Session) -> None: def test_user_balance_penalty(sql_session: Session) -> None: