fixup! WIP

This commit is contained in:
2025-06-13 22:19:40 +02:00
parent f5c4959e51
commit 3e4c3a44d2
19 changed files with 1034 additions and 576 deletions

View File

@@ -27,7 +27,7 @@ class Product(Base):
EAN-13 code. EAN-13 code.
""" """
name: Mapped[str] = mapped_column(String(45)) name: Mapped[str] = mapped_column(String(45), unique=True)
""" """
The name of the product. The name of the product.

View File

@@ -9,27 +9,32 @@ from sqlalchemy import (
ForeignKey, ForeignKey,
Integer, Integer,
Text, Text,
) and_,
from sqlalchemy import ( column,
Enum as SQLEnum, or_,
) )
from sqlalchemy.orm import ( from sqlalchemy.orm import (
Mapped, Mapped,
mapped_column, mapped_column,
relationship, relationship,
) )
from sqlalchemy.orm.collections import (
InstrumentedDict,
InstrumentedList,
InstrumentedSet,
)
from sqlalchemy.sql.schema import Index from sqlalchemy.sql.schema import Index
from .Base import Base from .Base import Base
from .TransactionType import TransactionType from .TransactionType import TransactionType, TransactionTypeSQL
if TYPE_CHECKING: if TYPE_CHECKING:
from .Product import Product from .Product import Product
from .User import User from .User import User
# TODO: rename to *_PERCENT
# NOTE: these only matter when there are no adjustments made in the database. # NOTE: these only matter when there are no adjustments made in the database.
DEFAULT_INTEREST_RATE_PERCENTAGE = 100 DEFAULT_INTEREST_RATE_PERCENTAGE = 100
DEFAULT_PENALTY_THRESHOLD = -100 DEFAULT_PENALTY_THRESHOLD = -100
DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE = 200 DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE = 200
@@ -64,7 +69,6 @@ assert all(x <= _DYNAMIC_FIELDS for x in _EXPECTED_FIELDS.values()), (
"All expected fields must be part of _DYNAMIC_FIELDS." "All expected fields must be part of _DYNAMIC_FIELDS."
) )
# TODO: ensure that the transaction types are not prefixed with 'TransactionType.' in the database
def _transaction_type_field_constraints( def _transaction_type_field_constraints(
transaction_type: TransactionType, transaction_type: TransactionType,
@@ -72,14 +76,14 @@ def _transaction_type_field_constraints(
) -> CheckConstraint: ) -> CheckConstraint:
unexpected_fields = _DYNAMIC_FIELDS - expected_fields unexpected_fields = _DYNAMIC_FIELDS - expected_fields
expected_constraints = ["{} IS NOT NULL".format(field) for field in expected_fields]
unexpected_constraints = ["{} IS NULL".format(field) for field in unexpected_fields]
constraints = expected_constraints + unexpected_constraints
# TODO: use sqlalchemy's `and_` and `or_` to build the constraints
return CheckConstraint( return CheckConstraint(
f"type <> '{transaction_type}' OR ({' AND '.join(constraints)})", or_(
column("type") != transaction_type.value,
and_(
*[column(field) != None for field in expected_fields],
*[column(field) == None for field in unexpected_fields],
),
),
name=f"trx_type_{transaction_type.value}_expected_fields", name=f"trx_type_{transaction_type.value}_expected_fields",
) )
@@ -91,7 +95,10 @@ class Transaction(Base):
for transaction_type, expected_fields in _EXPECTED_FIELDS.items() for transaction_type, expected_fields in _EXPECTED_FIELDS.items()
], ],
CheckConstraint( CheckConstraint(
f"type <> '{TransactionType.TRANSFER}' OR user_id <> transfer_user_id", or_(
column("type") != TransactionType.TRANSFER.value,
column("user_id") != column("transfer_user_id"),
),
name="trx_type_transfer_no_self_transfers", name="trx_type_transfer_no_self_transfers",
), ),
# Speed up product count calculation # Speed up product count calculation
@@ -125,7 +132,7 @@ class Transaction(Base):
This is not used for any calculations, but can be useful for debugging. This is not used for any calculations, but can be useful for debugging.
""" """
type_: Mapped[TransactionType] = mapped_column(SQLEnum(TransactionType), name="type") type_: Mapped[TransactionType] = mapped_column(TransactionTypeSQL, name="type")
""" """
Which type of transaction this is. Which type of transaction this is.
@@ -292,6 +299,39 @@ class Transaction(Base):
"The real amount of the transaction must be less than the total value of the products." "The real amount of the transaction must be less than the total value of the products."
) )
# TODO: improve printing further
def __repr__(self) -> str:
sort_order = [
"id",
"time",
]
columns = ", ".join(
f"{k}={repr(v)}"
for k, v in sorted(
self.__dict__.items(),
key=lambda item: chr(sort_order.index(item[0]))
if item[0] in sort_order
else item[0],
)
if not any(
[
k == "type_",
(k == "message" and v is None),
k.startswith("_"),
# Ensure that we don't try to print out the entire list of
# relationships, which could create an infinite loop
isinstance(v, Base),
isinstance(v, InstrumentedList),
isinstance(v, InstrumentedSet),
isinstance(v, InstrumentedDict),
*[k in (_DYNAMIC_FIELDS - _EXPECTED_FIELDS[self.type_])],
]
)
)
return f"{self.type_.upper()}({columns})"
################### ###################
# FACTORY METHODS # # FACTORY METHODS #
################### ###################

View File

@@ -1,15 +1,26 @@
from enum import Enum from enum import StrEnum, auto
from sqlalchemy import Enum as SQLEnum
class TransactionType(Enum): class TransactionType(StrEnum):
""" """
Enum for transaction types. Enum for transaction types.
""" """
ADD_PRODUCT = "add_product" ADD_PRODUCT = auto()
ADJUST_BALANCE = "adjust_balance" ADJUST_BALANCE = auto()
ADJUST_INTEREST = "adjust_interest" ADJUST_INTEREST = auto()
ADJUST_PENALTY = "adjust_penalty" ADJUST_PENALTY = auto()
ADJUST_STOCK = "adjust_stock" ADJUST_STOCK = auto()
BUY_PRODUCT = "buy_product" BUY_PRODUCT = auto()
TRANSFER = "transfer" TRANSFER = auto()
TransactionTypeSQL = SQLEnum(
TransactionType,
native_enum=True,
create_constraint=True,
validate_strings=True,
values_callable=lambda x: [i.value for i in x],
)

View File

@@ -45,20 +45,3 @@ class User(Base):
# def is_anonymous(self): # def is_anonymous(self):
# return self.card == "11122233" # return self.card == "11122233"
# TODO: move to 'queries'
# TODO: allow filtering out 'special transactions' like 'ADJUST_INTEREST' and 'ADJUST_PENALTY'
def transactions(self, sql_session: Session) -> list[Transaction]:
"""
Returns the transactions of the user in chronological order.
"""
from .Transaction import Transaction # Import here to avoid circular import
return list(
sql_session.scalars(
select(Transaction)
.where(Transaction.user_id == self.id)
.order_by(Transaction.time.asc())
).all()
)

View File

@@ -16,6 +16,4 @@ def current_interest(sql_session: Session) -> int:
if result is None: if result is None:
return DEFAULT_INTEREST_RATE_PERCENTAGE return DEFAULT_INTEREST_RATE_PERCENTAGE
assert result.interest_rate_percent is not None, "Interest rate percent must be set"
return result.interest_rate_percent return result.interest_rate_percent

View File

@@ -1,7 +1,11 @@
import math
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from sqlalchemy import ( from sqlalchemy import (
ColumnElement,
Integer, Integer,
SQLColumnExpression,
asc, asc,
case, case,
cast, cast,
@@ -16,13 +20,14 @@ from dibbler.models import (
Transaction, Transaction,
TransactionType, TransactionType,
) )
from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE
# TODO: include the transaction id in the log for easier debugging
def _product_price_query( def _product_price_query(
product_id: int, product_id: int | ColumnElement[int],
use_cache: bool = True, use_cache: bool = True,
until: datetime | None = None, until: datetime | SQLColumnExpression[datetime] | None = None,
until_including: bool = True,
cte_name: str = "rec_cte", cte_name: str = "rec_cte",
): ):
""" """
@@ -35,6 +40,7 @@ def _product_price_query(
initial_element = select( initial_element = select(
literal(0).label("i"), literal(0).label("i"),
literal(0).label("time"), literal(0).label("time"),
literal(None).label("transaction_id"),
literal(0).label("price"), literal(0).label("price"),
literal(0).label("product_count"), literal(0).label("product_count"),
) )
@@ -45,6 +51,7 @@ def _product_price_query(
trx_subset = ( trx_subset = (
select( select(
func.row_number().over(order_by=asc(Transaction.time)).label("i"), func.row_number().over(order_by=asc(Transaction.time)).label("i"),
Transaction.id,
Transaction.time, Transaction.time,
Transaction.type_, Transaction.type_,
Transaction.product_count, Transaction.product_count,
@@ -59,7 +66,12 @@ def _product_price_query(
] ]
), ),
Transaction.product_id == product_id, Transaction.product_id == product_id,
Transaction.time <= until if until is not None else 1 == 1, case(
(literal(until_including), Transaction.time <= until),
else_=Transaction.time < until,
)
if until is not None
else literal(True),
) )
.order_by(Transaction.time.asc()) .order_by(Transaction.time.asc())
.alias("trx_subset") .alias("trx_subset")
@@ -69,6 +81,7 @@ def _product_price_query(
select( select(
trx_subset.c.i, trx_subset.c.i,
trx_subset.c.time, trx_subset.c.time,
trx_subset.c.id.label("transaction_id"),
case( case(
# Someone buys the product -> price remains the same. # Someone buys the product -> price remains the same.
(trx_subset.c.type_ == TransactionType.BUY_PRODUCT, recursive_cte.c.price), (trx_subset.c.type_ == TransactionType.BUY_PRODUCT, recursive_cte.c.price),
@@ -78,7 +91,10 @@ def _product_price_query(
trx_subset.c.type_ == TransactionType.ADD_PRODUCT, trx_subset.c.type_ == TransactionType.ADD_PRODUCT,
cast( cast(
func.ceil( func.ceil(
(trx_subset.c.per_product * trx_subset.c.product_count) (
recursive_cte.c.price * func.max(recursive_cte.c.product_count, 0)
+ trx_subset.c.per_product * trx_subset.c.product_count
)
/ ( / (
# The running product count can be negative if the accounting is bad. # The running product count can be negative if the accounting is bad.
# This ensures that we never end up with negative prices or zero divisions # This ensures that we never end up with negative prices or zero divisions
@@ -122,19 +138,23 @@ def _product_price_query(
return recursive_cte.union_all(recursive_elements) return recursive_cte.union_all(recursive_elements)
# TODO: create a function for the log that pretty prints the log entries # TODO: create a function for the log that pretty prints the log entries
# for debugging purposes # for debugging purposes
# TODO: wrap the log entries in a dataclass, the don't cost that much @dataclass
class ProductPriceLogEntry:
transaction: Transaction
price: int
product_count: int
def product_price_log( def product_price_log(
sql_session: Session, sql_session: Session,
product: Product, product: Product,
use_cache: bool = True, use_cache: bool = True,
until: Transaction | None = None, until: Transaction | None = None,
) -> list[tuple[int, datetime, int, int]]: ) -> list[ProductPriceLogEntry]:
""" """
Calculates the price of a product and returns a log of the price changes. Calculates the price of a product and returns a log of the price changes.
""" """
@@ -147,20 +167,32 @@ def product_price_log(
result = sql_session.execute( result = sql_session.execute(
select( select(
recursive_cte.c.i, Transaction,
recursive_cte.c.time,
recursive_cte.c.price, recursive_cte.c.price,
recursive_cte.c.product_count, recursive_cte.c.product_count,
).order_by(recursive_cte.c.i.asc()) )
.select_from(recursive_cte)
.join(
Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id,
)
.order_by(recursive_cte.c.i.asc())
).all() ).all()
if not result: if result is None:
# If there are no transactions for this product, the query should return an empty list, not None. # If there are no transactions for this product, the query should return an empty list, not None.
raise RuntimeError( raise RuntimeError(
f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id})." f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id})."
) )
return [(row.i, row.time, row.price, row.product_count) for row in result] return [
ProductPriceLogEntry(
transaction=row[0],
price=row.price,
product_count=row.product_count,
)
for row in result
]
@staticmethod @staticmethod
@@ -169,6 +201,7 @@ def product_price(
product: Product, product: Product,
use_cache: bool = True, use_cache: bool = True,
until: Transaction | None = None, until: Transaction | None = None,
include_interest: bool = False,
) -> int: ) -> int:
""" """
Calculates the price of a product. Calculates the price of a product.
@@ -184,9 +217,9 @@ def product_price(
# - product_count should never be negative (but this happens sometimes, so just a warning) # - product_count should never be negative (but this happens sometimes, so just a warning)
# - price should never be negative # - price should never be negative
result = sql_session.scalar( result = sql_session.scalars(
select(recursive_cte.c.price).order_by(recursive_cte.c.i.desc()).limit(1) select(recursive_cte.c.price).order_by(recursive_cte.c.i.desc()).limit(1)
) ).one_or_none()
if result is None: if result is None:
# If there are no transactions for this product, the query should return 0, not None. # If there are no transactions for this product, the query should return 0, not None.
@@ -194,4 +227,19 @@ def product_price(
f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})." f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})."
) )
if include_interest:
interest_rate = (
sql_session.scalar(
select(Transaction.interest_rate_percent)
.where(
Transaction.type_ == TransactionType.ADJUST_INTEREST,
literal(True) if until is None else Transaction.time <= until.time,
)
.order_by(Transaction.time.desc())
.limit(1)
)
or DEFAULT_INTEREST_RATE_PERCENTAGE
)
result = math.ceil(result * interest_rate / 100)
return result return result

View File

@@ -1,6 +1,6 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import case, func, select from sqlalchemy import case, func, literal, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import ( from dibbler.models import (
@@ -10,6 +10,51 @@ from dibbler.models import (
) )
def _product_stock_query(
product_id: int,
use_cache: bool = True,
until: datetime | None = None,
):
"""
The inner query for calculating the product stock.
"""
if use_cache:
print("WARNING: Using cache for product stock query is not implemented yet.")
query = select(
func.sum(
case(
(
Transaction.type_ == TransactionType.ADD_PRODUCT,
Transaction.product_count,
),
(
Transaction.type_ == TransactionType.BUY_PRODUCT,
-Transaction.product_count,
),
(
Transaction.type_ == TransactionType.ADJUST_STOCK,
Transaction.product_count,
),
else_=0,
)
)
).where(
Transaction.type_.in_(
[
TransactionType.BUY_PRODUCT,
TransactionType.ADD_PRODUCT,
TransactionType.ADJUST_STOCK,
]
),
Transaction.product_id == product_id,
Transaction.time <= until if until is not None else literal(True),
)
return query
def product_stock( def product_stock(
sql_session: Session, sql_session: Session,
product: Product, product: Product,
@@ -20,39 +65,12 @@ def product_stock(
Returns the number of products in stock. Returns the number of products in stock.
""" """
if use_cache: query = _product_stock_query(
print("WARNING: Using cache for product stock query is not implemented yet.") product_id=product.id,
use_cache=use_cache,
until=until,
)
result = sql_session.scalars( result = sql_session.scalars(query).one_or_none()
select(
func.sum(
case(
(
Transaction.type_ == TransactionType.ADD_PRODUCT,
Transaction.product_count,
),
(
Transaction.type_ == TransactionType.BUY_PRODUCT,
-Transaction.product_count,
),
(
Transaction.type_ == TransactionType.ADJUST_STOCK,
Transaction.product_count,
),
else_=0,
)
)
).where(
Transaction.type_.in_(
[
TransactionType.BUY_PRODUCT,
TransactionType.ADD_PRODUCT,
TransactionType.ADJUST_STOCK,
]
),
Transaction.product_id == product.id,
Transaction.time <= until if until is not None else 1 == 1,
)
).one_or_none()
return result or 0 return result or 0

View File

@@ -1,54 +1,39 @@
from sqlalchemy import and_, or_ from sqlalchemy import and_, literal, or_, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import Product from dibbler.models import Product
# TODO: modernize queries to use SQLAlchemy 2.0 style
def search_product( def search_product(
string: str, string: str,
session: Session, sql_session: Session,
find_hidden_products=True, find_hidden_products=True,
) -> Product | list[Product]: ) -> Product | list[Product]:
if find_hidden_products: exact_match = sql_session.scalars(
exact_match = ( select(Product).where(
session.query(Product) or_(
.filter(or_(Product.bar_code == string, Product.name == string)) Product.bar_code == string,
.first() and_(
) Product.name == string,
else: literal(True) if find_hidden_products else not Product.hidden,
exact_match = ( ),
session.query(Product)
.filter(
or_(
Product.bar_code == string,
and_(Product.name == string, not Product.hidden),
)
) )
.first()
) )
).first()
if exact_match: if exact_match:
return exact_match return exact_match
if find_hidden_products:
product_list = ( product_list = sql_session.scalars(
session.query(Product) select(Product).where(
.filter( or_(
or_( Product.bar_code.ilike(f"%{string}%"),
Product.bar_code.ilike(f"%{string}%"), and_(
Product.name.ilike(f"%{string}%"), Product.name.ilike(f"%{string}%"),
) literal(True) if find_hidden_products else not Product.hidden,
),
) )
.all()
) )
else: ).all()
product_list = (
session.query(Product) return list(product_list)
.filter(
or_(
Product.bar_code.ilike(f"%{string}%"),
and_(Product.name.ilike(f"%{string}%"), not Product.hidden),
)
)
.all()
)
return product_list

View File

@@ -1,28 +1,37 @@
from sqlalchemy import or_ from sqlalchemy import or_, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import User from dibbler.models import User
# TODO: modernize queries to use SQLAlchemy 2.0 style def search_user(
def search_user(string: str, session: Session, ignorethisflag=None) -> User | list[User]: string: str,
sql_session: Session,
ignorethisflag=None,
) -> User | list[User]:
string = string.lower() string = string.lower()
exact_match = (
session.query(User) exact_match = sql_session.scalars(
.filter(or_(User.name == string, User.card == string, User.rfid == string)) select(User).where(
.first() or_(
) User.name == string,
User.card == string,
User.rfid == string,
)
)
).first()
if exact_match: if exact_match:
return exact_match return exact_match
user_list = (
session.query(User) user_list = sql_session.scalars(
.filter( select(User).where(
or_( or_(
User.name.ilike(f"%{string}%"), User.name.ilike(f"%{string}%"),
User.card.ilike(f"%{string}%"), User.card.ilike(f"%{string}%"),
User.rfid.ilike(f"%{string}%"), User.rfid.ilike(f"%{string}%"),
) )
) )
.all() ).all()
)
return user_list return list(user_list)

View File

@@ -1,6 +1,8 @@
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from sqlalchemy import ( from sqlalchemy import (
Float,
Integer, Integer,
and_, and_,
asc, asc,
@@ -26,12 +28,12 @@ from dibbler.models.Transaction import (
) )
from dibbler.queries.product_price import _product_price_query from dibbler.queries.product_price import _product_price_query
# TODO: include the transaction id in the log for easier debugging
def _user_balance_query( def _user_balance_query(
user: User, user_id: int,
use_cache: bool = True, use_cache: bool = True,
until: datetime | None = None, until: datetime | None = None,
until_including: bool = True,
cte_name: str = "rec_cte", cte_name: str = "rec_cte",
): ):
""" """
@@ -44,6 +46,7 @@ def _user_balance_query(
initial_element = select( initial_element = select(
literal(0).label("i"), literal(0).label("i"),
literal(0).label("time"), literal(0).label("time"),
literal(None).label("transaction_id"),
literal(0).label("balance"), literal(0).label("balance"),
literal(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"), literal(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"),
literal(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"), literal(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"),
@@ -56,20 +59,21 @@ def _user_balance_query(
trx_subset = ( trx_subset = (
select( select(
func.row_number().over(order_by=asc(Transaction.time)).label("i"), func.row_number().over(order_by=asc(Transaction.time)).label("i"),
Transaction.time,
Transaction.type_,
Transaction.amount, Transaction.amount,
Transaction.product_count, Transaction.id,
Transaction.product_id,
Transaction.transfer_user_id,
Transaction.interest_rate_percent, Transaction.interest_rate_percent,
Transaction.penalty_multiplier_percent, Transaction.penalty_multiplier_percent,
Transaction.penalty_threshold, Transaction.penalty_threshold,
Transaction.product_count,
Transaction.product_id,
Transaction.time,
Transaction.transfer_user_id,
Transaction.type_,
) )
.where( .where(
or_( or_(
and_( and_(
Transaction.user_id == user.id, Transaction.user_id == user_id,
Transaction.type_.in_( Transaction.type_.in_(
[ [
TransactionType.ADD_PRODUCT, TransactionType.ADD_PRODUCT,
@@ -81,7 +85,7 @@ def _user_balance_query(
), ),
and_( and_(
Transaction.type_ == TransactionType.TRANSFER, Transaction.type_ == TransactionType.TRANSFER,
Transaction.transfer_user_id == user.id, Transaction.transfer_user_id == user_id,
), ),
Transaction.type_.in_( Transaction.type_.in_(
[ [
@@ -90,7 +94,12 @@ def _user_balance_query(
] ]
), ),
), ),
Transaction.time <= until if until is not None else 1 == 1, case(
(literal(until_including), Transaction.time <= until),
else_=Transaction.time < until,
)
if until is not None
else literal(True),
) )
.order_by(Transaction.time.asc()) .order_by(Transaction.time.asc())
.alias("trx_subset") .alias("trx_subset")
@@ -100,6 +109,7 @@ def _user_balance_query(
select( select(
trx_subset.c.i, trx_subset.c.i,
trx_subset.c.time, trx_subset.c.time,
trx_subset.c.id.label("transaction_id"),
case( case(
# Adjusts balance -> balance gets adjusted # Adjusts balance -> balance gets adjusted
( (
@@ -124,12 +134,14 @@ def _user_balance_query(
# product prices somehow. # product prices somehow.
# Base price # Base price
( (
select(column("price")) # FIXME: this always returns 0 for some reason...
select(cast(column("price"), Float))
.select_from( .select_from(
_product_price_query( _product_price_query(
trx_subset.c.product_id, trx_subset.c.product_id,
use_cache=use_cache, use_cache=use_cache,
until=trx_subset.c.time, until=trx_subset.c.time,
until_including=False,
cte_name="product_price_cte", cte_name="product_price_cte",
) )
) )
@@ -139,12 +151,16 @@ def _user_balance_query(
# TODO: should interest be applied before or after the penalty multiplier? # TODO: should interest be applied before or after the penalty multiplier?
# at the moment of writing, after sound right, but maybe ask someone? # at the moment of writing, after sound right, but maybe ask someone?
# Interest # Interest
* (recursive_cte.c.interest_rate_percent / 100) * (cast(recursive_cte.c.interest_rate_percent, Float) / 100)
# Penalty # Penalty
* case( * case(
( (
# TODO: should this be <= or <?
recursive_cte.c.balance < recursive_cte.c.penalty_threshold, recursive_cte.c.balance < recursive_cte.c.penalty_threshold,
(recursive_cte.c.penalty_multiplier_percent / 100), (
cast(recursive_cte.c.penalty_multiplier_percent, Float)
/ 100
),
), ),
else_=1.0, else_=1.0,
) )
@@ -155,14 +171,18 @@ def _user_balance_query(
), ),
# Transfers money to self -> balance increases # Transfers money to self -> balance increases
( (
trx_subset.c.type_ == TransactionType.TRANSFER and_(
and trx_subset.c.transfer_user_id == user.id, trx_subset.c.type_ == TransactionType.TRANSFER,
trx_subset.c.transfer_user_id == user_id,
),
recursive_cte.c.balance + trx_subset.c.amount, recursive_cte.c.balance + trx_subset.c.amount,
), ),
# Transfers money from self -> balance decreases # Transfers money from self -> balance decreases
( (
trx_subset.c.type_ == TransactionType.TRANSFER and_(
and trx_subset.c.transfer_user_id != user.id, trx_subset.c.type_ == TransactionType.TRANSFER,
trx_subset.c.transfer_user_id != user_id,
),
recursive_cte.c.balance - trx_subset.c.amount, recursive_cte.c.balance - trx_subset.c.amount,
), ),
# Interest adjustment -> balance stays the same # Interest adjustment -> balance stays the same
@@ -201,32 +221,55 @@ def _user_balance_query(
# TODO: create a function for the log that pretty prints the log entries # TODO: create a function for the log that pretty prints the log entries
# for debugging purposes # for debugging purposes
# TODO: wrap the log entries in a dataclass, the don't cost that much
# TODO: add a method on the dataclass, using the running penalization data @dataclass
# to figure out if the current row was penalized or not. class UserBalanceLogEntry:
transaction: Transaction
balance: int
interest_rate_percent: int
penalty_threshold: int
penalty_multiplier_percent: int
def is_penalized(self) -> bool:
"""
Returns whether this exact transaction is penalized.
"""
return False
# return self.transaction.type_ == TransactionType.BUY_PRODUCT and prev?
def user_balance_log( def user_balance_log(
sql_session: Session, sql_session: Session,
user: User, user: User,
use_cache: bool = True, use_cache: bool = True,
until: Transaction | None = None, until: Transaction | None = None,
) -> list[tuple[int, datetime, int, int, int, int]]: ) -> list[UserBalanceLogEntry]:
"""
Returns a log of the user's balance over time, including interest and penalty adjustments.
"""
recursive_cte = _user_balance_query( recursive_cte = _user_balance_query(
user, user.id,
use_cache=use_cache, use_cache=use_cache,
until=until.time if until else None, until=until.time if until else None,
) )
result = sql_session.execute( result = sql_session.execute(
select( select(
recursive_cte.c.i, Transaction,
recursive_cte.c.time,
recursive_cte.c.balance, recursive_cte.c.balance,
recursive_cte.c.interest_rate_percent, recursive_cte.c.interest_rate_percent,
recursive_cte.c.penalty_threshold, recursive_cte.c.penalty_threshold,
recursive_cte.c.penalty_multiplier_percent, recursive_cte.c.penalty_multiplier_percent,
).order_by(recursive_cte.c.i.asc()) )
.select_from(recursive_cte)
.join(
Transaction,
onclause=Transaction.id == recursive_cte.c.transaction_id,
)
.order_by(recursive_cte.c.i.asc())
).all() ).all()
if result is None: if result is None:
@@ -235,7 +278,16 @@ def user_balance_log(
f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})." f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})."
) )
return result return [
UserBalanceLogEntry(
transaction=row[0],
balance=row.balance,
interest_rate_percent=row.interest_rate_percent,
penalty_threshold=row.penalty_threshold,
penalty_multiplier_percent=row.penalty_multiplier_percent,
)
for row in result
]
def user_balance( def user_balance(
@@ -249,7 +301,7 @@ def user_balance(
""" """
recursive_cte = _user_balance_query( recursive_cte = _user_balance_query(
user, user.id,
use_cache=use_cache, use_cache=use_cache,
until=until.time if until else None, until=until.time if until else None,
) )

View File

@@ -0,0 +1,20 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from dibbler.models import Transaction, User
# TODO: allow filtering out 'special transactions' like 'ADJUST_INTEREST' and 'ADJUST_PENALTY'
def user_transactions(sql_session: Session, user: User) -> list[Transaction]:
"""
Returns the transactions of the user in chronological order.
"""
return list(
sql_session.scalars(
select(Transaction)
.where(Transaction.user_id == user.id)
.order_by(Transaction.time.asc())
).all()
)

View File

@@ -2,7 +2,7 @@ from datetime import datetime
from pathlib import Path from pathlib import Path
from dibbler.db import Session from dibbler.db import Session
from dibbler.models import Product, Transaction, TransactionType, User from dibbler.models import Product, Transaction, User
JSON_FILE = Path(__file__).parent.parent.parent / "mock_data.json" JSON_FILE = Path(__file__).parent.parent.parent / "mock_data.json"
@@ -11,6 +11,7 @@ JSON_FILE = Path(__file__).parent.parent.parent / "mock_data.json"
# whether to seed test data, or by using command line arguments for # whether to seed test data, or by using command line arguments for
# automatating the answer. # automatating the answer.
def clear_db(sql_session): def clear_db(sql_session):
sql_session.query(Product).delete() sql_session.query(Product).delete()
sql_session.query(User).delete() sql_session.query(User).delete()
@@ -41,37 +42,31 @@ def main():
# Add transactions # Add transactions
transactions = [ transactions = [
Transaction( Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 0), time=datetime(2023, 10, 1, 10, 0, 0),
type_=TransactionType.ADJUST_BALANCE,
amount=100, amount=100,
user_id=user1.id, user_id=user1.id,
), ),
Transaction( Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 1), time=datetime(2023, 10, 1, 10, 0, 1),
type_=TransactionType.ADJUST_BALANCE,
amount=50, amount=50,
user_id=user2.id, user_id=user2.id,
), ),
Transaction( Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 2), time=datetime(2023, 10, 1, 10, 0, 2),
type_=TransactionType.ADJUST_BALANCE,
amount=-50, amount=-50,
user_id=user1.id, user_id=user1.id,
), ),
Transaction( Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0), time=datetime(2023, 10, 1, 12, 0, 0),
type_=TransactionType.ADD_PRODUCT,
amount=27 * 2, amount=27 * 2,
per_product=27, per_product=27,
product_count=2, product_count=2,
user_id=user1.id, user_id=user1.id,
product_id=product1.id, product_id=product1.id,
), ),
Transaction( Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 1), time=datetime(2023, 10, 1, 12, 0, 1),
type_=TransactionType.BUY_PRODUCT,
amount=27,
product_count=1, product_count=1,
user_id=user2.id, user_id=user2.id,
product_id=product1.id, product_id=product1.id,

View File

@@ -1,6 +1,6 @@
import pytest import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine, event
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import Base from dibbler.models import Base
@@ -24,6 +24,13 @@ def sql_session(request):
"sqlite:///:memory:", "sqlite:///:memory:",
echo=echo, echo=echo,
) )
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, _connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
with Session(engine) as sql_session: with Session(engine) as sql_session:
yield sql_session yield sql_session

View File

@@ -5,6 +5,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User from dibbler.models import Product, Transaction, User
from dibbler.queries.product_stock import product_stock
def insert_test_data(sql_session: Session) -> tuple[User, Product]: def insert_test_data(sql_session: Session) -> tuple[User, Product]:
@@ -118,3 +119,81 @@ def test_user_foreign_key_constraint(sql_session: Session) -> None:
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
sql_session.commit() sql_session.commit()
def test_transaction_buy_product_more_than_stock(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 13, 0, 0),
product_count=10,
user_id=user.id,
product_id=product.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
assert product_stock(sql_session, product) == 1 - 10
def test_transaction_buy_product_dont_allow_no_add_product_transactions(
sql_session: Session,
) -> None:
user, product = insert_test_data(sql_session)
transaction = Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
product_count=1,
user_id=user.id,
product_id=product.id,
)
sql_session.add(transaction)
with pytest.raises(ValueError):
sql_session.commit()
def test_transaction_add_product_deny_amount_over_per_product_times_product_count(
sql_session: Session,
) -> None:
user, product = insert_test_data(sql_session)
with pytest.raises(ValueError):
_transaction = Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27 * 2 + 1, # Invalid amount
per_product=27,
product_count=2,
)
def test_transaction_add_product_allow_amount_under_per_product_times_product_count(
sql_session: Session,
) -> None:
user, product = insert_test_data(sql_session)
transaction = Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27 * 2 - 1, # Valid amount
per_product=27,
product_count=2,
)
sql_session.add(transaction)
sql_session.commit()

View File

@@ -23,49 +23,3 @@ def test_ensure_no_duplicate_user_names(sql_session: Session):
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
sql_session.commit() sql_session.commit()
def test_user_transactions(sql_session: Session):
user = insert_test_data(sql_session)
product = Product("1234567890123", "Test Product")
user2 = User("Test User 2")
sql_session.add_all([product, user2])
sql_session.commit()
transactions = [
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 0),
amount=100,
user_id=user.id,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 1),
amount=50,
user_id=user2.id,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 2),
amount=-50,
user_id=user.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 1),
product_count=1,
user_id=user2.id,
product_id=product.id,
),
]
sql_session.add_all(transactions)
assert len(user.transactions(sql_session)) == 3
assert len(user2.transactions(sql_session)) == 2

View File

@@ -1,177 +0,0 @@
import math
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
from dibbler.queries.product_stock import product_stock
from dibbler.queries.user_balance import user_balance
def insert_test_data(sql_session: Session) -> tuple[User, Product]:
user = User("Test User")
product = Product("1234567890123", "Test Product")
sql_session.add(user)
sql_session.add(product)
sql_session.commit()
transactions = [
Transaction.adjust_penalty(
time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user.id,
penalty_multiplier_percent=200,
penalty_threshold=-100,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 1),
user_id=user.id,
amount=100,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 10, 0, 2),
user_id=user.id,
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
return user, product
def test_buy_product_basic(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transaction = Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
)
sql_session.add(transaction)
sql_session.commit()
def test_buy_product_with_penalty(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 11, 0, 0),
user_id=user.id,
amount=-200,
)
]
sql_session.add_all(transactions)
sql_session.commit()
transaction = Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
)
sql_session.add(transaction)
sql_session.commit()
assert user_balance(sql_session, user) == 100 + 27 - 200 - (27 * 2)
def test_buy_product_with_interest(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 11, 0, 0),
user_id=user.id,
interest_rate_percent=110,
)
]
sql_session.add_all(transactions)
sql_session.commit()
transaction = Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
)
sql_session.add(transaction)
sql_session.commit()
assert user_balance(sql_session, user) == 100 + 27 - math.ceil(27 * 1.1)
def test_buy_product_with_changing_penalty(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 11, 0, 0),
user_id=user.id,
amount=-200,
)
]
sql_session.add_all(transactions)
sql_session.commit()
transaction = Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
)
sql_session.add(transaction)
sql_session.commit()
assert user_balance(sql_session, user) == 100 + 27 - 200 - (27 * 2)
adjust_penalty = Transaction.adjust_penalty(
time=datetime(2023, 10, 1, 13, 0, 0),
user_id=user.id,
penalty_multiplier_percent=300,
penalty_threshold=-100,
)
sql_session.add(adjust_penalty)
sql_session.commit()
transaction = Transaction.buy_product(
time=datetime(2023, 10, 1, 14, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
)
sql_session.add(transaction)
sql_session.commit()
assert user_balance(sql_session, user) == 100 + 27 - 200 - (27 * 2) - (27 * 3)
def test_buy_product_with_changing_interest(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.")
def test_buy_product_with_penalty_interest_combined(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.")
def test_buy_product_more_than_stock(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transaction = Transaction.buy_product(
time=datetime(2023, 10, 1, 13, 0, 0),
product_count=10,
user_id=user.id,
product_id=product.id,
)
sql_session.add(transaction)
sql_session.commit()
assert product_stock(sql_session, product) == 1 - 10

View File

@@ -1,170 +1,342 @@
import math
from datetime import datetime from datetime import datetime
from pprint import pprint
from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User from dibbler.models import Product, Transaction, User
from dibbler.queries.product_price import product_price from dibbler.queries.product_price import product_price, product_price_log
# TODO: see if we can use pytest_runtest_makereport to print the "product_price_log"s
# only on failures instead of inlining it in every test function
def insert_test_data(sql_session: Session) -> None: def insert_test_data(sql_session: Session) -> tuple[User, Product]:
# Add users user = User("Test User")
user1 = User("Test User 1") product = Product("1234567890123", "Test Product")
user2 = User("Test User 2")
sql_session.add_all([user1, user2]) sql_session.add(user)
sql_session.add(product)
sql_session.commit() sql_session.commit()
# Add products return user, product
product1 = Product("1234567890123", "Test Product 1")
product2 = Product("9876543210987", "Test Product 2")
product3 = Product("1111111111111", "Test Product 3") def test_product_price_no_transactions(sql_session: Session) -> None:
sql_session.add_all([product1, product2, product3]) _, product = insert_test_data(sql_session)
sql_session.commit()
pprint(product_price_log(sql_session, product))
assert product_price(sql_session, product) == 0
def test_product_price_basic_history(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
# Add transactions
transactions = [ transactions = [
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 0),
amount=100,
user_id=user1.id,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 1),
amount=50,
user_id=user2.id,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 2),
amount=-50,
user_id=user1.id,
),
Transaction.add_product( Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0), time=datetime(2023, 10, 1, 12, 0, 0),
amount=27 * 2, amount=27 * 2 - 1,
per_product=27, per_product=27,
product_count=2, product_count=2,
user_id=user1.id, user_id=user.id,
product_id=product1.id, product_id=product.id,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 1),
product_count=1,
user_id=user2.id,
product_id=product1.id,
),
Transaction.adjust_stock(
time=datetime(2023, 10, 1, 12, 0, 2),
product_count=3,
user_id=user1.id,
product_id=product1.id,
),
Transaction.adjust_stock(
time=datetime(2023, 10, 1, 12, 0, 3),
product_count=-2,
user_id=user1.id,
product_id=product1.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 4),
amount=50,
per_product=50,
product_count=1,
user_id=user1.id,
product_id=product3.id,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 5),
product_count=1,
user_id=user1.id,
product_id=product3.id,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 12, 0, 6),
amount=1000,
user_id=user1.id,
), ),
] ]
sql_session.add_all(transactions) sql_session.add_all(transactions)
sql_session.commit() sql_session.commit()
pprint(product_price_log(sql_session, product))
def test_product_price(sql_session: Session) -> None: assert product_price(sql_session, product) == 27
insert_test_data(sql_session)
product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one()
assert product_price(sql_session, product1) == 27
def test_product_price_no_transactions(sql_session: Session) -> None:
insert_test_data(sql_session)
product2 = sql_session.scalars(select(Product).where(Product.name == "Test Product 2")).one()
assert product_price(sql_session, product2) == 0
def test_product_price_sold_out(sql_session: Session) -> None: def test_product_price_sold_out(sql_session: Session) -> None:
insert_test_data(sql_session) user, product = insert_test_data(sql_session)
product3 = sql_session.scalars(select(Product).where(Product.name == "Test Product 3")).one() transactions = [
assert product_price(sql_session, product3) == 50 Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
amount=27 * 2 - 1,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 1),
product_count=2,
user_id=user.id,
product_id=product.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
assert product_price(sql_session, product) == 27
def test_product_price_interest(sql_session: Session) -> None: def test_product_price_interest(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.") user, product = insert_test_data(sql_session)
transactions = [
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 12, 0, 0),
interest_rate_percent=110,
user_id=user.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 1),
amount=27 * 2 - 1,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
product_price_ = product_price(sql_session, product)
product_price_interest = product_price(sql_session, product, include_interest=True)
assert product_price_ == 27
assert product_price_interest == math.ceil(27 * 1.1)
def test_product_price_changing_interest(sql_session: Session) -> None: def test_product_price_changing_interest(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.") user, product = insert_test_data(sql_session)
transactions = [
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 12, 0, 0),
interest_rate_percent=110,
user_id=user.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 1),
amount=27 * 2 - 1,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 12, 0, 2),
interest_rate_percent=120,
user_id=user.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
product_price_interest = product_price(sql_session, product, include_interest=True)
assert product_price_interest == math.ceil(27 * 1.2)
def test_product_price_old_transaction(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 1),
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
# Price should be 27
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 2),
amount=38 * 3,
per_product=38,
product_count=3,
user_id=user.id,
product_id=product.id,
),
# price should be averaged upwards
]
sql_session.add_all(transactions)
sql_session.commit()
until_transaction = transactions[0]
pprint(
product_price_log(
sql_session,
product,
until=until_transaction,
)
)
product_price_ = product_price(
sql_session,
product,
until=until_transaction,
)
assert product_price_ == 27
# Price goes up and gets rounded up to the next integer # Price goes up and gets rounded up to the next integer
def test_product_price_round_up_from_below(sql_session: Session) -> None: def test_product_price_round_up_from_below(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.") user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 1),
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
# Price should be 27
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 2),
amount=38 * 3,
per_product=38,
product_count=3,
user_id=user.id,
product_id=product.id,
),
# price should be averaged upwards
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
product_price_ = product_price(sql_session, product)
assert product_price_ == math.ceil((27 * 2 + 38 * 3) / (2 + 3))
# Price goes down and gets rounded up to the next integer # Price goes down and gets rounded up to the next integer
def test_product_price_round_up_from_above(sql_session: Session) -> None: def test_product_price_round_up_from_above(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.") user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 1),
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
# Price should be 27
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 2),
amount=20 * 3,
per_product=20,
product_count=3,
user_id=user.id,
product_id=product.id,
),
# price should be averaged downwards
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
product_price_ = product_price(sql_session, product)
assert product_price_ == math.ceil((27 * 2 + 20 * 3) / (2 + 3))
def test_product_price_with_negative_stock_single_addition(sql_session: Session) -> None: def test_product_price_with_negative_stock_single_addition(sql_session: Session) -> None:
insert_test_data(sql_session) user, product = insert_test_data(sql_session)
product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one() transactions = [
user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() Transaction.add_product(
time=datetime(2023, 10, 1, 13, 0, 0),
amount=1,
per_product=10,
product_count=1,
user_id=user.id,
product_id=product.id,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 13, 0, 1),
product_count=10,
user_id=user.id,
product_id=product.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 13, 0, 2),
amount=22,
per_product=22,
product_count=1,
user_id=user.id,
product_id=product.id,
),
]
transaction = Transaction.buy_product( sql_session.add_all(transactions)
time=datetime(2023, 10, 1, 13, 0, 0),
product_count=10,
user_id=user1.id,
product_id=product1.id,
)
sql_session.add(transaction)
sql_session.commit() sql_session.commit()
product1_price = product_price(sql_session, product1) pprint(product_price_log(sql_session, product))
assert product1_price == 27
transaction = Transaction.add_product(
time=datetime(2023, 10, 1, 13, 0, 1),
amount=22,
per_product=22,
product_count=1,
user_id=user1.id,
product_id=product1.id,
)
sql_session.add(transaction)
sql_session.commit()
# Stock went subzero, price should be the last added product price # Stock went subzero, price should be the last added product price
product1_price = product_price(sql_session, product1) product1_price = product_price(sql_session, product)
assert product1_price == 22 assert product1_price == 22
# TODO: what happens when stock is still negative and yet new products are added? # TODO: what happens when stock is still negative and yet new products are added?
def test_product_price_with_negative_stock_multiple_additions(sql_session: Session) -> None: def test_product_price_with_negative_stock_multiple_additions(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.") user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 13, 0, 0),
amount=1,
per_product=10,
product_count=1,
user_id=user.id,
product_id=product.id,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 13, 0, 1),
product_count=10,
user_id=user.id,
product_id=product.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 13, 0, 2),
amount=22,
per_product=22,
product_count=1,
user_id=user.id,
product_id=product.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 13, 0, 3),
amount=29,
per_product=29,
product_count=2,
user_id=user.id,
product_id=product.id,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(product_price_log(sql_session, product))
# Stock went subzero, price should be the ceiled average of the last added products
product1_price = product_price(sql_session, product)
assert product1_price == math.ceil((22 + 29 * 2) / (1 + 2))

View File

@@ -1,102 +1,306 @@
import math
from datetime import datetime from datetime import datetime
from pprint import pprint
from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User from dibbler.models import Product, Transaction, User
from dibbler.queries.user_balance import user_balance, user_balance_log from dibbler.queries.user_balance import user_balance, user_balance_log
# TODO: see if we can use pytest_runtest_makereport to print the "user_balance_log"s
# only on failures instead of inlining it in every test function
def insert_test_data(sql_session: Session) -> None:
# Add users
user1 = User("Test User 1")
user2 = User("Test User 2")
sql_session.add(user1) def insert_test_data(sql_session: Session) -> tuple[User, Product]:
sql_session.add(user2) user = User("Test User")
product = Product("1234567890123", "Test Product")
sql_session.add(user)
sql_session.add(product)
sql_session.commit() sql_session.commit()
# Add products return user, product
product1 = Product("1234567890123", "Test Product 1")
product2 = Product("9876543210987", "Test Product 2")
sql_session.add(product1) def test_user_balance_no_transactions(sql_session: Session) -> None:
sql_session.add(product2) user, _ = insert_test_data(sql_session)
sql_session.commit()
pprint(user_balance_log(sql_session, user))
balance = user_balance(sql_session, user)
assert balance == 0
def test_user_balance_basic_history(sql_session: Session) -> None:
user, product = insert_test_data(sql_session)
# Add transactions
transactions = [ transactions = [
Transaction.adjust_balance( Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 0), time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user.id,
amount=100, amount=100,
user_id=user1.id,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 1),
amount=50,
user_id=user2.id,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 2),
amount=-50,
user_id=user1.id,
), ),
Transaction.add_product( Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0), time=datetime(2023, 10, 1, 10, 0, 1),
amount=27 * 2, user_id=user.id,
product_id=product.id,
amount=27,
per_product=27, per_product=27,
product_count=2,
user_id=user1.id,
product_id=product1.id,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 1),
product_count=1, product_count=1,
user_id=user2.id,
product_id=product1.id,
), ),
] ]
sql_session.add_all(transactions) sql_session.add_all(transactions)
sql_session.commit() sql_session.commit()
pprint(user_balance_log(sql_session, user))
def test_user_balance_basic_history(sql_session: Session) -> None: balance = user_balance(sql_session, user)
insert_test_data(sql_session)
user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() assert balance == 100 + 27
user2 = sql_session.scalars(select(User).where(User.name == "Test User 2")).one()
assert user_balance(sql_session, user1) == 100 - 50 + 27 * 2
assert user_balance(sql_session, user2) == 50 - 27
def test_user_balance_no_transactions(sql_session: Session) -> None: def test_user_balance_with_transfers(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.") user1, product = insert_test_data(sql_session)
user2 = User("Test User 2")
sql_session.add(user2)
sql_session.commit()
transactions = [
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user1.id,
amount=100,
),
Transaction.transfer(
time=datetime(2023, 10, 1, 10, 0, 1),
user_id=user1.id,
transfer_user_id=user2.id,
amount=50,
),
Transaction.transfer(
time=datetime(2023, 10, 1, 10, 0, 2),
user_id=user2.id,
transfer_user_id=user1.id,
amount=30,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(user_balance_log(sql_session, user1))
user1_balance = user_balance(sql_session, user1)
assert user1_balance == 100 - 50 + 30
pprint(user_balance_log(sql_session, user2))
user2_balance = user_balance(sql_session, user2)
assert user2_balance == 50 - 30
def test_user_balance_complex_history(sql_session: Session) -> None: def test_user_balance_complex_history(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.") raise NotImplementedError("This test is not implemented yet.")
def test_user_balance_with_tranfers(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.")
def test_user_balance_penalty(sql_session: Session) -> None: def test_user_balance_penalty(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.") user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 11, 0, 0),
user_id=user.id,
amount=-200,
),
# Penalized, pays 2x the price (default penalty)
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(user_balance_log(sql_session, user))
assert user_balance(sql_session, user) == 27 - 200 - (27 * 2)
def test_user_balance_changing_penalty(sql_session: Session) -> None: def test_user_balance_changing_penalty(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.") user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 11, 0, 0),
user_id=user.id,
amount=-200,
),
# Penalized, pays 2x the price (default penalty)
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
),
Transaction.adjust_penalty(
time=datetime(2023, 10, 1, 13, 0, 0),
user_id=user.id,
penalty_multiplier_percent=300,
penalty_threshold=-100,
),
# Penalized, pays 3x the price
Transaction.buy_product(
time=datetime(2023, 10, 1, 14, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(user_balance_log(sql_session, user))
assert user_balance(sql_session, user) == 27 - 200 - (27 * 2) - (27 * 3)
def test_user_balance_interest(sql_session: Session) -> None: def test_user_balance_interest(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.") user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
),
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 11, 0, 0),
user_id=user.id,
interest_rate_percent=110,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(user_balance_log(sql_session, user))
assert user_balance(sql_session, user) == 27 - math.ceil(27 * 1.1)
def test_user_balance_changing_interest(sql_session: Session) -> None: def test_user_balance_changing_interest(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.") user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27 * 3,
per_product=27,
product_count=3,
),
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 11, 0, 0),
user_id=user.id,
interest_rate_percent=110,
),
# Pays 1.1x the price
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
),
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 13, 0, 0),
user_id=user.id,
interest_rate_percent=120,
),
# Pays 1.2x the price
Transaction.buy_product(
time=datetime(2023, 10, 1, 14, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(user_balance_log(sql_session, user))
assert user_balance(sql_session, user) == 27 * 3 - math.ceil(27 * 1.1) - math.ceil(27 * 1.2)
def test_user_balance_penalty_interest_combined(sql_session: Session) -> None: def test_user_balance_penalty_interest_combined(sql_session: Session) -> None:
raise NotImplementedError("This test is not implemented yet.") user, product = insert_test_data(sql_session)
transactions = [
Transaction.add_product(
time=datetime(2023, 10, 1, 10, 0, 0),
user_id=user.id,
product_id=product.id,
amount=27,
per_product=27,
product_count=1,
),
Transaction.adjust_interest(
time=datetime(2023, 10, 1, 11, 0, 0),
user_id=user.id,
interest_rate_percent=110,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 12, 0, 0),
user_id=user.id,
amount=-200,
),
# Penalized, pays 2x the price (default penalty)
# Pays 1.1x the price
Transaction.buy_product(
time=datetime(2023, 10, 1, 13, 0, 0),
user_id=user.id,
product_id=product.id,
product_count=1,
),
]
sql_session.add_all(transactions)
sql_session.commit()
pprint(user_balance_log(sql_session, user))
assert user_balance(sql_session, user) == (
27
- 200
- math.ceil(27 * 2 * 1.1)
)

View File

@@ -0,0 +1,60 @@
from datetime import datetime
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
from dibbler.queries.user_transactions import user_transactions
def insert_test_data(sql_session: Session) -> User:
user = User("Test User")
sql_session.add(user)
sql_session.commit()
return user
def test_user_transactions(sql_session: Session):
user = insert_test_data(sql_session)
product = Product("1234567890123", "Test Product")
user2 = User("Test User 2")
sql_session.add_all([product, user2])
sql_session.commit()
transactions = [
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 0),
amount=100,
user_id=user.id,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 1),
amount=50,
user_id=user2.id,
),
Transaction.adjust_balance(
time=datetime(2023, 10, 1, 10, 0, 2),
amount=-50,
user_id=user.id,
),
Transaction.add_product(
time=datetime(2023, 10, 1, 12, 0, 0),
amount=27 * 2,
per_product=27,
product_count=2,
user_id=user.id,
product_id=product.id,
),
Transaction.buy_product(
time=datetime(2023, 10, 1, 12, 0, 1),
product_count=1,
user_id=user2.id,
product_id=product.id,
),
]
sql_session.add_all(transactions)
assert len(user_transactions(sql_session, user)) == 3
assert len(user_transactions(sql_session, user2)) == 2