fixup! WIP
All checks were successful
Run tests / run-tests (push) Successful in 1m30s

This commit is contained in:
2025-12-10 13:32:54 +09:00
parent 7f4a980eef
commit fa7ad3a258
8 changed files with 216 additions and 114 deletions

View File

@@ -0,0 +1,20 @@
from typing import TypeVar
from sqlalchemy import BindParameter, literal
T = TypeVar("T")
def const(value: T) -> BindParameter[T]:
"""
Create a constant SQL literal bind parameter.
This is useful to avoid too many `?` bind parameters in SQL queries,
when the input value is known to be safe.
"""
return literal(value, literal_execute=True)
CONST_ZERO: BindParameter[int] = const(0)
CONST_ONE: BindParameter[int] = const(1)
CONST_TRUE: BindParameter[bool] = const(True)
CONST_FALSE: BindParameter[bool] = const(False)
CONST_NONE: BindParameter[None] = const(None)

View File

@@ -19,6 +19,17 @@ class TransactionType(StrEnum):
THROW_PRODUCT = auto()
TRANSFER = auto()
def as_literal_column(self):
"""
Return the transaction type as a SQL literal column.
This is useful to avoid too many `?` bind parameters in SQL queries,
when the input value is known to be safe.
"""
from sqlalchemy import literal_column
return literal_column(f"'{self.value}'")
TransactionTypeSQL = SQLEnum(
TransactionType,

View File

@@ -1,10 +1,11 @@
from datetime import datetime
from dataclasses import dataclass
from datetime import datetime
from sqlalchemy import (
CTE,
BindParameter,
and_,
asc,
bindparam,
case,
func,
literal,
@@ -12,6 +13,7 @@ from sqlalchemy import (
)
from sqlalchemy.orm import Session
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
from dibbler.models import (
Product,
Transaction,
@@ -22,9 +24,9 @@ from dibbler.queries.product_stock import _product_stock_query
def _product_owners_query(
product_id: int,
product_id: BindParameter[int] | int,
use_cache: bool = True,
until: datetime | None = None,
until: BindParameter[datetime] | datetime | None = None,
cte_name: str = "rec_cte",
) -> CTE:
"""
@@ -34,6 +36,12 @@ def _product_owners_query(
if use_cache:
print("WARNING: Using cache for users owning product query is not implemented yet.")
if isinstance(product_id, int):
product_id = bindparam("product_id", value=product_id)
if isinstance(until, datetime):
until = BindParameter("until", value=until)
product_stock = _product_stock_query(
product_id=product_id,
use_cache=use_cache,
@@ -54,26 +62,26 @@ def _product_owners_query(
.where(
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT,
TransactionType.ADD_PRODUCT.as_literal_column(),
# TransactionType.BUY_PRODUCT,
TransactionType.ADJUST_STOCK,
TransactionType.ADJUST_STOCK.as_literal_column(),
# TransactionType.JOINT,
# TransactionType.THROW_PRODUCT,
]
),
Transaction.product_id == product_id,
literal(True) if until is None else Transaction.time <= until,
CONST_TRUE if until is None else Transaction.time <= until,
)
.order_by(Transaction.time.desc())
.subquery()
)
initial_element = select(
literal(0).label("i"),
literal(0).label("time"),
literal(None).label("transaction_id"),
literal(None).label("user_id"),
literal(0).label("product_count"),
CONST_ZERO.label("i"),
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_NONE.label("user_id"),
CONST_ZERO.label("product_count"),
product_stock.scalar_subquery().label("products_left_to_account_for"),
)
@@ -88,28 +96,31 @@ def _product_owners_query(
case(
# Someone adds the product -> they own it
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
trx_subset.c.user_id,
),
else_=None,
else_=CONST_NONE,
).label("user_id"),
# How many products did they add (if any)
case(
# Someone adds the product -> they added a certain amount of products
(trx_subset.c.type_ == TransactionType.ADD_PRODUCT, trx_subset.c.product_count),
# Stock got adjusted upwards -> consider those products as added by nobody
(
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK)
& (trx_subset.c.product_count > 0),
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
trx_subset.c.product_count,
),
else_=0,
# Stock got adjusted upwards -> consider those products as added by nobody
(
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column())
and (trx_subset.c.product_count > CONST_ZERO),
trx_subset.c.product_count,
),
else_=CONST_ZERO,
).label("product_count"),
# How many products left to account for
case(
# Someone adds the product -> increase the number of products left to account for
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
),
# Someone buys/joins/throws the product -> decrease the number of products left to account for
@@ -127,8 +138,8 @@ def _product_owners_query(
# If adjusted upwards -> products owned by nobody, decrease products left to account for
# If adjusted downwards -> products taken away from owners, decrease products left to account for
(
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK)
and (trx_subset.c.product_count > 0),
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column())
and (trx_subset.c.product_count > CONST_ZERO),
recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
),
# (
@@ -142,8 +153,8 @@ def _product_owners_query(
.select_from(trx_subset)
.where(
and_(
trx_subset.c.i == recursive_cte.c.i + 1,
recursive_cte.c.products_left_to_account_for > 0,
trx_subset.c.i == recursive_cte.c.i + CONST_ONE,
recursive_cte.c.products_left_to_account_for > CONST_ZERO,
)
)
)

View File

@@ -3,9 +3,9 @@ from dataclasses import dataclass
from datetime import datetime
from sqlalchemy import (
BindParameter,
ColumnElement,
Integer,
SQLColumnExpression,
asc,
case,
cast,
@@ -15,6 +15,7 @@ from sqlalchemy import (
)
from sqlalchemy.orm import Session
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
from dibbler.models import (
Product,
Transaction,
@@ -26,8 +27,8 @@ from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE
def _product_price_query(
product_id: int | ColumnElement[int],
use_cache: bool = True,
until: datetime | SQLColumnExpression[datetime] | None = None,
until_including: bool = True,
until: BindParameter[datetime] | datetime | None = None,
until_including: BindParameter[bool] | bool = True,
cte_name: str = "rec_cte",
):
"""
@@ -37,12 +38,21 @@ def _product_price_query(
if use_cache:
print("WARNING: Using cache for product price query is not implemented yet.")
if isinstance(product_id, int):
product_id = BindParameter("product_id", value=product_id)
if isinstance(until, datetime):
until = BindParameter("until", value=until)
if isinstance(until_including, bool):
until_including = BindParameter("until_including", value=until_including)
initial_element = select(
literal(0).label("i"),
literal(0).label("time"),
literal(None).label("transaction_id"),
literal(0).label("price"),
literal(0).label("product_count"),
CONST_ZERO.label("i"),
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_ZERO.label("price"),
CONST_ZERO.label("product_count"),
)
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
@@ -60,19 +70,19 @@ def _product_price_query(
.where(
Transaction.type_.in_(
[
TransactionType.BUY_PRODUCT,
TransactionType.ADD_PRODUCT,
TransactionType.ADJUST_STOCK,
TransactionType.JOINT,
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_STOCK.as_literal_column(),
TransactionType.JOINT.as_literal_column(),
]
),
Transaction.product_id == product_id,
case(
(literal(until_including), Transaction.time <= until),
(until_including, Transaction.time <= until),
else_=Transaction.time < until,
)
if until is not None
else literal(True),
else CONST_TRUE,
)
.order_by(Transaction.time.asc())
.alias("trx_subset")
@@ -85,22 +95,26 @@ def _product_price_query(
trx_subset.c.id.label("transaction_id"),
case(
# Someone buys the product -> price remains the same.
(trx_subset.c.type_ == TransactionType.BUY_PRODUCT, recursive_cte.c.price),
(
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
recursive_cte.c.price,
),
# Someone adds the product -> price is recalculated based on
# product count, previous price, and new price.
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
cast(
func.ceil(
(
recursive_cte.c.price * func.max(recursive_cte.c.product_count, 0)
recursive_cte.c.price
* func.max(recursive_cte.c.product_count, CONST_ZERO)
+ trx_subset.c.per_product * trx_subset.c.product_count
)
/ (
# The running product count can be negative if the accounting is bad.
# This ensures that we never end up with negative prices or zero divisions
# and other disastrous phenomena.
func.max(recursive_cte.c.product_count, 0)
func.max(recursive_cte.c.product_count, CONST_ZERO)
+ trx_subset.c.product_count
)
),
@@ -108,28 +122,31 @@ def _product_price_query(
),
),
# Someone adjusts the stock -> price remains the same.
(trx_subset.c.type_ == TransactionType.ADJUST_STOCK, recursive_cte.c.price),
(
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
recursive_cte.c.price,
),
# Should never happen
else_=recursive_cte.c.price,
).label("price"),
case(
# Someone buys the product -> product count is reduced.
(
trx_subset.c.type_ == TransactionType.BUY_PRODUCT,
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
recursive_cte.c.product_count - trx_subset.c.product_count,
),
(
trx_subset.c.type_ == TransactionType.JOINT,
trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(),
recursive_cte.c.product_count - trx_subset.c.product_count,
),
# Someone adds the product -> product count is increased.
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
recursive_cte.c.product_count + trx_subset.c.product_count,
),
# Someone adjusts the stock -> product count is adjusted.
(
trx_subset.c.type_ == TransactionType.ADJUST_STOCK,
trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
recursive_cte.c.product_count + trx_subset.c.product_count,
),
# Should never happen
@@ -137,7 +154,7 @@ def _product_price_query(
).label("product_count"),
)
.select_from(trx_subset)
.where(trx_subset.c.i == recursive_cte.c.i + 1)
.where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE)
)
return recursive_cte.union_all(recursive_elements)
@@ -222,7 +239,10 @@ def product_price(
# - price should never be negative
result = sql_session.scalars(
select(recursive_cte.c.price).order_by(recursive_cte.c.i.desc()).limit(1)
select(recursive_cte.c.price)
.order_by(recursive_cte.c.i.desc())
.limit(CONST_ONE)
.offset(CONST_ZERO)
).one_or_none()
if result is None:
@@ -237,10 +257,10 @@ def product_price(
select(Transaction.interest_rate_percent)
.where(
Transaction.type_ == TransactionType.ADJUST_INTEREST,
literal(True) if until is None else Transaction.time <= until.time,
CONST_TRUE if until is None else Transaction.time <= until.time,
)
.order_by(Transaction.time.desc())
.limit(1)
.limit(CONST_ONE)
)
or DEFAULT_INTEREST_RATE_PERCENTAGE
)

View File

@@ -1,14 +1,15 @@
from datetime import datetime
from sqlalchemy import (
BindParameter,
Select,
case,
func,
literal,
select,
)
from sqlalchemy.orm import Session
from dibbler.lib.query_helpers import CONST_TRUE
from dibbler.models import (
Product,
Transaction,
@@ -17,9 +18,9 @@ from dibbler.models import (
def _product_stock_query(
product_id: int,
product_id: BindParameter[int] | int,
use_cache: bool = True,
until: datetime | None = None,
until: BindParameter[datetime] | datetime | None = None,
) -> Select:
"""
The inner query for calculating the product stock.
@@ -28,27 +29,33 @@ def _product_stock_query(
if use_cache:
print("WARNING: Using cache for product stock query is not implemented yet.")
if isinstance(product_id, int):
product_id = BindParameter("product_id", value=product_id)
if isinstance(until, datetime):
until = BindParameter("until", value=until)
query = select(
func.sum(
case(
(
Transaction.type_ == TransactionType.ADD_PRODUCT,
Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
Transaction.product_count,
),
(
Transaction.type_ == TransactionType.ADJUST_STOCK,
Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
Transaction.product_count,
),
(
Transaction.type_ == TransactionType.BUY_PRODUCT,
Transaction.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
-Transaction.product_count,
),
(
Transaction.type_ == TransactionType.JOINT,
Transaction.type_ == TransactionType.JOINT.as_literal_column(),
-Transaction.product_count,
),
(
Transaction.type_ == TransactionType.THROW_PRODUCT,
Transaction.type_ == TransactionType.THROW_PRODUCT.as_literal_column(),
-Transaction.product_count,
),
else_=0,
@@ -57,15 +64,15 @@ def _product_stock_query(
).where(
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT,
TransactionType.ADJUST_STOCK,
TransactionType.BUY_PRODUCT,
TransactionType.JOINT,
TransactionType.THROW_PRODUCT,
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_STOCK.as_literal_column(),
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.JOINT.as_literal_column(),
TransactionType.THROW_PRODUCT.as_literal_column(),
]
),
Transaction.product_id == product_id,
Transaction.time <= until if until is not None else literal(True),
Transaction.time <= until if until is not None else CONST_TRUE,
)
return query

View File

@@ -3,10 +3,10 @@ from datetime import datetime
from sqlalchemy import (
CTE,
BindParameter,
Float,
Integer,
and_,
asc,
case,
cast,
column,
@@ -17,6 +17,7 @@ from sqlalchemy import (
)
from sqlalchemy.orm import Session
from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const
from dibbler.models import (
Transaction,
TransactionType,
@@ -31,10 +32,10 @@ from dibbler.queries.product_price import _product_price_query
def _user_balance_query(
user_id: int,
user_id: BindParameter[int] | int,
use_cache: bool = True,
until: datetime | None = None,
until_including: bool = True,
until: BindParameter[datetime] | BindParameter[None] | datetime | None = None,
until_including: BindParameter[bool] | bool = True,
cte_name: str = "rec_cte",
) -> CTE:
"""
@@ -44,14 +45,23 @@ def _user_balance_query(
if use_cache:
print("WARNING: Using cache for user balance query is not implemented yet.")
if isinstance(user_id, int):
user_id = BindParameter("user_id", value=user_id)
if isinstance(until, datetime):
until = BindParameter("until", value=until, type_=datetime)
if isinstance(until_including, bool):
until_including = BindParameter("until_including", value=until_including, type_=bool)
initial_element = select(
literal(0).label("i"),
literal(0).label("time"),
literal(None).label("transaction_id"),
literal(0).label("balance"),
literal(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"),
literal(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
literal(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"),
CONST_ZERO.label("i"),
CONST_ZERO.label("time"),
CONST_NONE.label("transaction_id"),
CONST_ZERO.label("balance"),
const(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"),
const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
const(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"),
)
recursive_cte = initial_element.cte(name=cte_name, recursive=True)
@@ -59,7 +69,7 @@ def _user_balance_query(
# Subset of transactions that we'll want to iterate over.
trx_subset = (
select(
func.row_number().over(order_by=asc(Transaction.time)).label("i"),
func.row_number().over(order_by=Transaction.time.asc()).label("i"),
Transaction.amount,
Transaction.id,
Transaction.interest_rate_percent,
@@ -77,34 +87,34 @@ def _user_balance_query(
Transaction.user_id == user_id,
Transaction.type_.in_(
[
TransactionType.ADD_PRODUCT,
TransactionType.ADJUST_BALANCE,
TransactionType.BUY_PRODUCT,
TransactionType.TRANSFER,
TransactionType.ADD_PRODUCT.as_literal_column(),
TransactionType.ADJUST_BALANCE.as_literal_column(),
TransactionType.BUY_PRODUCT.as_literal_column(),
TransactionType.TRANSFER.as_literal_column(),
# TODO: join this with the JOINT transactions, and determine
# how much the current user paid for the product.
TransactionType.JOINT_BUY_PRODUCT,
TransactionType.JOINT_BUY_PRODUCT.as_literal_column(),
]
),
),
and_(
Transaction.type_ == TransactionType.TRANSFER,
Transaction.type_ == TransactionType.TRANSFER.as_literal_column(),
Transaction.transfer_user_id == user_id,
),
Transaction.type_.in_(
[
TransactionType.THROW_PRODUCT,
TransactionType.ADJUST_INTEREST,
TransactionType.ADJUST_PENALTY,
TransactionType.THROW_PRODUCT.as_literal_column(),
TransactionType.ADJUST_INTEREST.as_literal_column(),
TransactionType.ADJUST_PENALTY.as_literal_column(),
]
),
),
case(
(literal(until_including), Transaction.time <= until),
(until_including, Transaction.time <= until),
else_=Transaction.time < until,
)
if until is not None
else literal(True),
else CONST_TRUE,
)
.order_by(Transaction.time.asc())
.alias("trx_subset")
@@ -118,17 +128,17 @@ def _user_balance_query(
case(
# Adjusts balance -> balance gets adjusted
(
trx_subset.c.type_ == TransactionType.ADJUST_BALANCE,
trx_subset.c.type_ == TransactionType.ADJUST_BALANCE.as_literal_column(),
recursive_cte.c.balance + trx_subset.c.amount,
),
# Adds a product -> balance increases
(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
recursive_cte.c.balance + trx_subset.c.amount,
),
# Buys a product -> balance decreases
(
trx_subset.c.type_ == TransactionType.BUY_PRODUCT,
trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(),
recursive_cte.c.balance
- (
trx_subset.c.product_count
@@ -151,12 +161,12 @@ def _user_balance_query(
)
)
.order_by(column("i").desc())
.limit(1)
.limit(CONST_ONE)
).scalar_subquery()
# TODO: should interest be applied before or after the penalty multiplier?
# at the moment of writing, after sound right, but maybe ask someone?
# Interest
* (cast(recursive_cte.c.interest_rate_percent, Float) / 100)
* (cast(recursive_cte.c.interest_rate_percent, Float) / const(100))
# TODO: these should be added together, not multiplied, see specification
# Penalty
* case(
@@ -164,10 +174,10 @@ def _user_balance_query(
recursive_cte.c.balance < recursive_cte.c.penalty_threshold,
(
cast(recursive_cte.c.penalty_multiplier_percent, Float)
/ 100
/ const(100)
),
),
else_=1.0,
else_=const(1.0),
)
),
Integer,
@@ -177,7 +187,7 @@ def _user_balance_query(
# Transfers money to self -> balance increases
(
and_(
trx_subset.c.type_ == TransactionType.TRANSFER,
trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(),
trx_subset.c.transfer_user_id == user_id,
),
recursive_cte.c.balance + trx_subset.c.amount,
@@ -185,7 +195,7 @@ def _user_balance_query(
# Transfers money from self -> balance decreases
(
and_(
trx_subset.c.type_ == TransactionType.TRANSFER,
trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(),
trx_subset.c.transfer_user_id != user_id,
),
recursive_cte.c.balance - trx_subset.c.amount,
@@ -202,28 +212,28 @@ def _user_balance_query(
).label("balance"),
case(
(
trx_subset.c.type_ == TransactionType.ADJUST_INTEREST,
trx_subset.c.type_ == TransactionType.ADJUST_INTEREST.as_literal_column(),
trx_subset.c.interest_rate_percent,
),
else_=recursive_cte.c.interest_rate_percent,
).label("interest_rate_percent"),
case(
(
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY,
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(),
trx_subset.c.penalty_threshold,
),
else_=recursive_cte.c.penalty_threshold,
).label("penalty_threshold"),
case(
(
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY,
trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(),
trx_subset.c.penalty_multiplier_percent,
),
else_=recursive_cte.c.penalty_multiplier_percent,
).label("penalty_multiplier_percent"),
)
.select_from(trx_subset)
.where(trx_subset.c.i == recursive_cte.c.i + 1)
.where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE)
)
return recursive_cte.union_all(recursive_elements)
@@ -322,7 +332,10 @@ def user_balance(
)
result = sql_session.scalar(
select(recursive_cte.c.balance).order_by(recursive_cte.c.i.desc()).limit(1)
select(recursive_cte.c.balance)
.order_by(recursive_cte.c.i.desc())
.limit(CONST_ONE)
.offset(CONST_ZERO)
)
if result is None:

View File

@@ -3,6 +3,7 @@ import logging
import pytest
import sqlparse
from sqlalchemy import create_engine, event
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from dibbler.models import Base
@@ -35,21 +36,27 @@ class SqlParseFormatter(logging.Formatter):
return super().format(record)
@pytest.fixture(scope="function")
def sql_session(request):
"""Create a new SQLAlchemy session for testing."""
def pytest_configure(config):
"""Setup pretty SQL logging if --echo is enabled."""
logging.basicConfig()
logger = logging.getLogger("sqlalchemy.engine")
# TODO: it would be nice not to duplicate these logs.
# logging.NullHandler() does not seem to work here
handler = logging.StreamHandler()
handler.setFormatter(SqlParseFormatter())
logger.addHandler(handler)
echo = request.config.getoption("--echo")
engine = create_engine(
"sqlite:///:memory:",
echo=echo,
)
echo = config.getoption("--echo")
if echo:
logger.setLevel(logging.INFO)
@pytest.fixture(scope="function")
def sql_session(request):
"""Create a new SQLAlchemy session for testing."""
engine = create_engine("sqlite:///:memory:")
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, _connection_record):
@@ -60,3 +67,17 @@ def sql_session(request):
Base.metadata.create_all(engine)
with Session(engine) as sql_session:
yield sql_session
# FIXME: Declaring this hook seems to have a side effect where the database does not
# get reset between tests.
# @pytest.hookimpl(trylast=True)
# def pytest_runtest_call(item: pytest.Item):
# """Hook to format SQL statements in OperationalError exceptions."""
# try:
# item.runtest()
# except OperationalError as e:
# if e.statement is not None:
# formatted_sql = sqlparse.format(e.statement, reindent=True, keyword_case="upper")
# e.statement = "\n" + formatted_sql + "\n"
# raise e

View File

@@ -3,7 +3,6 @@ from datetime import datetime
from pprint import pprint
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
@@ -104,8 +103,8 @@ def test_user_balance_with_transfers(sql_session: Session) -> None:
assert user2_balance == 50 - 30
def test_user_balance_complex_history(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.")
@pytest.mark.skip(reason="Not yet implemented")
def test_user_balance_complex_history(sql_session: Session) -> None: ...
def test_user_balance_penalty(sql_session: Session) -> None: