From 3e4c3a44d22713fee2cc6764429ceb22259d8af8 Mon Sep 17 00:00:00 2001 From: h7x4 Date: Fri, 13 Jun 2025 22:19:40 +0200 Subject: [PATCH] fixup! WIP --- dibbler/models/Product.py | 2 +- dibbler/models/Transaction.py | 70 ++++- dibbler/models/TransactionType.py | 29 +- dibbler/models/User.py | 17 - dibbler/queries/current_interest.py | 2 - dibbler/queries/product_price.py | 78 ++++- dibbler/queries/product_stock.py | 86 +++-- dibbler/queries/search_product.py | 61 ++-- dibbler/queries/search_user.py | 37 ++- dibbler/queries/user_balance.py | 106 +++++-- dibbler/queries/user_transactions.py | 20 ++ dibbler/subcommands/seed_test_data.py | 19 +- tests/conftest.py | 9 +- tests/models/test_transaction.py | 79 +++++ tests/models/test_user.py | 46 --- tests/queries/test_buy_product.py | 177 ----------- tests/queries/test_product_price.py | 402 +++++++++++++++++------- tests/queries/test_user_balance.py | 310 ++++++++++++++---- tests/queries/test_user_transactions.py | 60 ++++ 19 files changed, 1034 insertions(+), 576 deletions(-) create mode 100644 dibbler/queries/user_transactions.py delete mode 100644 tests/queries/test_buy_product.py create mode 100644 tests/queries/test_user_transactions.py diff --git a/dibbler/models/Product.py b/dibbler/models/Product.py index dd10fd9..c69171f 100644 --- a/dibbler/models/Product.py +++ b/dibbler/models/Product.py @@ -27,7 +27,7 @@ class Product(Base): EAN-13 code. """ - name: Mapped[str] = mapped_column(String(45)) + name: Mapped[str] = mapped_column(String(45), unique=True) """ The name of the product. diff --git a/dibbler/models/Transaction.py b/dibbler/models/Transaction.py index bb76b44..3cf2fd5 100644 --- a/dibbler/models/Transaction.py +++ b/dibbler/models/Transaction.py @@ -9,27 +9,32 @@ from sqlalchemy import ( ForeignKey, Integer, Text, -) -from sqlalchemy import ( - Enum as SQLEnum, + and_, + column, + or_, ) from sqlalchemy.orm import ( Mapped, mapped_column, relationship, ) +from sqlalchemy.orm.collections import ( + InstrumentedDict, + InstrumentedList, + InstrumentedSet, +) from sqlalchemy.sql.schema import Index from .Base import Base -from .TransactionType import TransactionType +from .TransactionType import TransactionType, TransactionTypeSQL if TYPE_CHECKING: from .Product import Product from .User import User +# TODO: rename to *_PERCENT # NOTE: these only matter when there are no adjustments made in the database. - DEFAULT_INTEREST_RATE_PERCENTAGE = 100 DEFAULT_PENALTY_THRESHOLD = -100 DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE = 200 @@ -64,7 +69,6 @@ assert all(x <= _DYNAMIC_FIELDS for x in _EXPECTED_FIELDS.values()), ( "All expected fields must be part of _DYNAMIC_FIELDS." ) -# TODO: ensure that the transaction types are not prefixed with 'TransactionType.' in the database def _transaction_type_field_constraints( transaction_type: TransactionType, @@ -72,14 +76,14 @@ def _transaction_type_field_constraints( ) -> CheckConstraint: unexpected_fields = _DYNAMIC_FIELDS - expected_fields - expected_constraints = ["{} IS NOT NULL".format(field) for field in expected_fields] - unexpected_constraints = ["{} IS NULL".format(field) for field in unexpected_fields] - - constraints = expected_constraints + unexpected_constraints - - # TODO: use sqlalchemy's `and_` and `or_` to build the constraints return CheckConstraint( - f"type <> '{transaction_type}' OR ({' AND '.join(constraints)})", + or_( + column("type") != transaction_type.value, + and_( + *[column(field) != None for field in expected_fields], + *[column(field) == None for field in unexpected_fields], + ), + ), name=f"trx_type_{transaction_type.value}_expected_fields", ) @@ -91,7 +95,10 @@ class Transaction(Base): for transaction_type, expected_fields in _EXPECTED_FIELDS.items() ], CheckConstraint( - f"type <> '{TransactionType.TRANSFER}' OR user_id <> transfer_user_id", + or_( + column("type") != TransactionType.TRANSFER.value, + column("user_id") != column("transfer_user_id"), + ), name="trx_type_transfer_no_self_transfers", ), # Speed up product count calculation @@ -125,7 +132,7 @@ class Transaction(Base): This is not used for any calculations, but can be useful for debugging. """ - type_: Mapped[TransactionType] = mapped_column(SQLEnum(TransactionType), name="type") + type_: Mapped[TransactionType] = mapped_column(TransactionTypeSQL, name="type") """ Which type of transaction this is. @@ -292,6 +299,39 @@ class Transaction(Base): "The real amount of the transaction must be less than the total value of the products." ) + # TODO: improve printing further + + def __repr__(self) -> str: + sort_order = [ + "id", + "time", + ] + + columns = ", ".join( + f"{k}={repr(v)}" + for k, v in sorted( + self.__dict__.items(), + key=lambda item: chr(sort_order.index(item[0])) + if item[0] in sort_order + else item[0], + ) + if not any( + [ + k == "type_", + (k == "message" and v is None), + k.startswith("_"), + # Ensure that we don't try to print out the entire list of + # relationships, which could create an infinite loop + isinstance(v, Base), + isinstance(v, InstrumentedList), + isinstance(v, InstrumentedSet), + isinstance(v, InstrumentedDict), + *[k in (_DYNAMIC_FIELDS - _EXPECTED_FIELDS[self.type_])], + ] + ) + ) + return f"{self.type_.upper()}({columns})" + ################### # FACTORY METHODS # ################### diff --git a/dibbler/models/TransactionType.py b/dibbler/models/TransactionType.py index 4f4fa29..ed1f7e9 100644 --- a/dibbler/models/TransactionType.py +++ b/dibbler/models/TransactionType.py @@ -1,15 +1,26 @@ -from enum import Enum +from enum import StrEnum, auto + +from sqlalchemy import Enum as SQLEnum -class TransactionType(Enum): +class TransactionType(StrEnum): """ Enum for transaction types. """ - ADD_PRODUCT = "add_product" - ADJUST_BALANCE = "adjust_balance" - ADJUST_INTEREST = "adjust_interest" - ADJUST_PENALTY = "adjust_penalty" - ADJUST_STOCK = "adjust_stock" - BUY_PRODUCT = "buy_product" - TRANSFER = "transfer" + ADD_PRODUCT = auto() + ADJUST_BALANCE = auto() + ADJUST_INTEREST = auto() + ADJUST_PENALTY = auto() + ADJUST_STOCK = auto() + BUY_PRODUCT = auto() + TRANSFER = auto() + + +TransactionTypeSQL = SQLEnum( + TransactionType, + native_enum=True, + create_constraint=True, + validate_strings=True, + values_callable=lambda x: [i.value for i in x], +) diff --git a/dibbler/models/User.py b/dibbler/models/User.py index 7f7d01c..026c46a 100644 --- a/dibbler/models/User.py +++ b/dibbler/models/User.py @@ -45,20 +45,3 @@ class User(Base): # def is_anonymous(self): # return self.card == "11122233" - - # TODO: move to 'queries' - # TODO: allow filtering out 'special transactions' like 'ADJUST_INTEREST' and 'ADJUST_PENALTY' - def transactions(self, sql_session: Session) -> list[Transaction]: - """ - Returns the transactions of the user in chronological order. - """ - - from .Transaction import Transaction # Import here to avoid circular import - - return list( - sql_session.scalars( - select(Transaction) - .where(Transaction.user_id == self.id) - .order_by(Transaction.time.asc()) - ).all() - ) diff --git a/dibbler/queries/current_interest.py b/dibbler/queries/current_interest.py index 51272bf..615d485 100644 --- a/dibbler/queries/current_interest.py +++ b/dibbler/queries/current_interest.py @@ -16,6 +16,4 @@ def current_interest(sql_session: Session) -> int: if result is None: return DEFAULT_INTEREST_RATE_PERCENTAGE - assert result.interest_rate_percent is not None, "Interest rate percent must be set" - return result.interest_rate_percent diff --git a/dibbler/queries/product_price.py b/dibbler/queries/product_price.py index 46c2358..7abd877 100644 --- a/dibbler/queries/product_price.py +++ b/dibbler/queries/product_price.py @@ -1,7 +1,11 @@ +import math +from dataclasses import dataclass from datetime import datetime from sqlalchemy import ( + ColumnElement, Integer, + SQLColumnExpression, asc, case, cast, @@ -16,13 +20,14 @@ from dibbler.models import ( Transaction, TransactionType, ) +from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE -# TODO: include the transaction id in the log for easier debugging def _product_price_query( - product_id: int, + product_id: int | ColumnElement[int], use_cache: bool = True, - until: datetime | None = None, + until: datetime | SQLColumnExpression[datetime] | None = None, + until_including: bool = True, cte_name: str = "rec_cte", ): """ @@ -35,6 +40,7 @@ def _product_price_query( initial_element = select( literal(0).label("i"), literal(0).label("time"), + literal(None).label("transaction_id"), literal(0).label("price"), literal(0).label("product_count"), ) @@ -45,6 +51,7 @@ def _product_price_query( trx_subset = ( select( func.row_number().over(order_by=asc(Transaction.time)).label("i"), + Transaction.id, Transaction.time, Transaction.type_, Transaction.product_count, @@ -59,7 +66,12 @@ def _product_price_query( ] ), Transaction.product_id == product_id, - Transaction.time <= until if until is not None else 1 == 1, + case( + (literal(until_including), Transaction.time <= until), + else_=Transaction.time < until, + ) + if until is not None + else literal(True), ) .order_by(Transaction.time.asc()) .alias("trx_subset") @@ -69,6 +81,7 @@ def _product_price_query( select( trx_subset.c.i, trx_subset.c.time, + trx_subset.c.id.label("transaction_id"), case( # Someone buys the product -> price remains the same. (trx_subset.c.type_ == TransactionType.BUY_PRODUCT, recursive_cte.c.price), @@ -78,7 +91,10 @@ def _product_price_query( trx_subset.c.type_ == TransactionType.ADD_PRODUCT, cast( func.ceil( - (trx_subset.c.per_product * trx_subset.c.product_count) + ( + recursive_cte.c.price * func.max(recursive_cte.c.product_count, 0) + + trx_subset.c.per_product * trx_subset.c.product_count + ) / ( # The running product count can be negative if the accounting is bad. # This ensures that we never end up with negative prices or zero divisions @@ -122,19 +138,23 @@ def _product_price_query( return recursive_cte.union_all(recursive_elements) - # TODO: create a function for the log that pretty prints the log entries # for debugging purposes -# TODO: wrap the log entries in a dataclass, the don't cost that much +@dataclass +class ProductPriceLogEntry: + transaction: Transaction + price: int + product_count: int + def product_price_log( sql_session: Session, product: Product, use_cache: bool = True, until: Transaction | None = None, -) -> list[tuple[int, datetime, int, int]]: +) -> list[ProductPriceLogEntry]: """ Calculates the price of a product and returns a log of the price changes. """ @@ -147,20 +167,32 @@ def product_price_log( result = sql_session.execute( select( - recursive_cte.c.i, - recursive_cte.c.time, + Transaction, recursive_cte.c.price, recursive_cte.c.product_count, - ).order_by(recursive_cte.c.i.asc()) + ) + .select_from(recursive_cte) + .join( + Transaction, + onclause=Transaction.id == recursive_cte.c.transaction_id, + ) + .order_by(recursive_cte.c.i.asc()) ).all() - if not result: + if result is None: # If there are no transactions for this product, the query should return an empty list, not None. raise RuntimeError( f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id})." ) - return [(row.i, row.time, row.price, row.product_count) for row in result] + return [ + ProductPriceLogEntry( + transaction=row[0], + price=row.price, + product_count=row.product_count, + ) + for row in result + ] @staticmethod @@ -169,6 +201,7 @@ def product_price( product: Product, use_cache: bool = True, until: Transaction | None = None, + include_interest: bool = False, ) -> int: """ Calculates the price of a product. @@ -184,9 +217,9 @@ def product_price( # - product_count should never be negative (but this happens sometimes, so just a warning) # - price should never be negative - result = sql_session.scalar( + result = sql_session.scalars( select(recursive_cte.c.price).order_by(recursive_cte.c.i.desc()).limit(1) - ) + ).one_or_none() if result is None: # If there are no transactions for this product, the query should return 0, not None. @@ -194,4 +227,19 @@ def product_price( f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})." ) + if include_interest: + interest_rate = ( + sql_session.scalar( + select(Transaction.interest_rate_percent) + .where( + Transaction.type_ == TransactionType.ADJUST_INTEREST, + literal(True) if until is None else Transaction.time <= until.time, + ) + .order_by(Transaction.time.desc()) + .limit(1) + ) + or DEFAULT_INTEREST_RATE_PERCENTAGE + ) + result = math.ceil(result * interest_rate / 100) + return result diff --git a/dibbler/queries/product_stock.py b/dibbler/queries/product_stock.py index 8c591e5..cf5ce6c 100644 --- a/dibbler/queries/product_stock.py +++ b/dibbler/queries/product_stock.py @@ -1,6 +1,6 @@ from datetime import datetime -from sqlalchemy import case, func, select +from sqlalchemy import case, func, literal, select from sqlalchemy.orm import Session from dibbler.models import ( @@ -10,6 +10,51 @@ from dibbler.models import ( ) +def _product_stock_query( + product_id: int, + use_cache: bool = True, + until: datetime | None = None, +): + """ + The inner query for calculating the product stock. + """ + + if use_cache: + print("WARNING: Using cache for product stock query is not implemented yet.") + + query = select( + func.sum( + case( + ( + Transaction.type_ == TransactionType.ADD_PRODUCT, + Transaction.product_count, + ), + ( + Transaction.type_ == TransactionType.BUY_PRODUCT, + -Transaction.product_count, + ), + ( + Transaction.type_ == TransactionType.ADJUST_STOCK, + Transaction.product_count, + ), + else_=0, + ) + ) + ).where( + Transaction.type_.in_( + [ + TransactionType.BUY_PRODUCT, + TransactionType.ADD_PRODUCT, + TransactionType.ADJUST_STOCK, + ] + ), + Transaction.product_id == product_id, + Transaction.time <= until if until is not None else literal(True), + ) + + return query + + def product_stock( sql_session: Session, product: Product, @@ -20,39 +65,12 @@ def product_stock( Returns the number of products in stock. """ - if use_cache: - print("WARNING: Using cache for product stock query is not implemented yet.") + query = _product_stock_query( + product_id=product.id, + use_cache=use_cache, + until=until, + ) - result = sql_session.scalars( - select( - func.sum( - case( - ( - Transaction.type_ == TransactionType.ADD_PRODUCT, - Transaction.product_count, - ), - ( - Transaction.type_ == TransactionType.BUY_PRODUCT, - -Transaction.product_count, - ), - ( - Transaction.type_ == TransactionType.ADJUST_STOCK, - Transaction.product_count, - ), - else_=0, - ) - ) - ).where( - Transaction.type_.in_( - [ - TransactionType.BUY_PRODUCT, - TransactionType.ADD_PRODUCT, - TransactionType.ADJUST_STOCK, - ] - ), - Transaction.product_id == product.id, - Transaction.time <= until if until is not None else 1 == 1, - ) - ).one_or_none() + result = sql_session.scalars(query).one_or_none() return result or 0 diff --git a/dibbler/queries/search_product.py b/dibbler/queries/search_product.py index 771784d..bfd0786 100644 --- a/dibbler/queries/search_product.py +++ b/dibbler/queries/search_product.py @@ -1,54 +1,39 @@ -from sqlalchemy import and_, or_ +from sqlalchemy import and_, literal, or_, select from sqlalchemy.orm import Session from dibbler.models import Product -# TODO: modernize queries to use SQLAlchemy 2.0 style def search_product( string: str, - session: Session, + sql_session: Session, find_hidden_products=True, ) -> Product | list[Product]: - if find_hidden_products: - exact_match = ( - session.query(Product) - .filter(or_(Product.bar_code == string, Product.name == string)) - .first() - ) - else: - exact_match = ( - session.query(Product) - .filter( - or_( - Product.bar_code == string, - and_(Product.name == string, not Product.hidden), - ) + exact_match = sql_session.scalars( + select(Product).where( + or_( + Product.bar_code == string, + and_( + Product.name == string, + literal(True) if find_hidden_products else not Product.hidden, + ), ) - .first() ) + ).first() + if exact_match: return exact_match - if find_hidden_products: - product_list = ( - session.query(Product) - .filter( - or_( - Product.bar_code.ilike(f"%{string}%"), + + product_list = sql_session.scalars( + select(Product).where( + or_( + Product.bar_code.ilike(f"%{string}%"), + and_( Product.name.ilike(f"%{string}%"), - ) + literal(True) if find_hidden_products else not Product.hidden, + ), ) - .all() ) - else: - product_list = ( - session.query(Product) - .filter( - or_( - Product.bar_code.ilike(f"%{string}%"), - and_(Product.name.ilike(f"%{string}%"), not Product.hidden), - ) - ) - .all() - ) - return product_list + ).all() + + return list(product_list) diff --git a/dibbler/queries/search_user.py b/dibbler/queries/search_user.py index 62d8c53..421b794 100644 --- a/dibbler/queries/search_user.py +++ b/dibbler/queries/search_user.py @@ -1,28 +1,37 @@ -from sqlalchemy import or_ +from sqlalchemy import or_, select from sqlalchemy.orm import Session from dibbler.models import User -# TODO: modernize queries to use SQLAlchemy 2.0 style -def search_user(string: str, session: Session, ignorethisflag=None) -> User | list[User]: +def search_user( + string: str, + sql_session: Session, + ignorethisflag=None, +) -> User | list[User]: string = string.lower() - exact_match = ( - session.query(User) - .filter(or_(User.name == string, User.card == string, User.rfid == string)) - .first() - ) + + exact_match = sql_session.scalars( + select(User).where( + or_( + User.name == string, + User.card == string, + User.rfid == string, + ) + ) + ).first() + if exact_match: return exact_match - user_list = ( - session.query(User) - .filter( + + user_list = sql_session.scalars( + select(User).where( or_( User.name.ilike(f"%{string}%"), User.card.ilike(f"%{string}%"), User.rfid.ilike(f"%{string}%"), ) ) - .all() - ) - return user_list + ).all() + + return list(user_list) diff --git a/dibbler/queries/user_balance.py b/dibbler/queries/user_balance.py index 5d2316f..3e191f3 100644 --- a/dibbler/queries/user_balance.py +++ b/dibbler/queries/user_balance.py @@ -1,6 +1,8 @@ +from dataclasses import dataclass from datetime import datetime from sqlalchemy import ( + Float, Integer, and_, asc, @@ -26,12 +28,12 @@ from dibbler.models.Transaction import ( ) from dibbler.queries.product_price import _product_price_query -# TODO: include the transaction id in the log for easier debugging def _user_balance_query( - user: User, + user_id: int, use_cache: bool = True, until: datetime | None = None, + until_including: bool = True, cte_name: str = "rec_cte", ): """ @@ -44,6 +46,7 @@ def _user_balance_query( initial_element = select( literal(0).label("i"), literal(0).label("time"), + literal(None).label("transaction_id"), literal(0).label("balance"), literal(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"), literal(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"), @@ -56,20 +59,21 @@ def _user_balance_query( trx_subset = ( select( func.row_number().over(order_by=asc(Transaction.time)).label("i"), - Transaction.time, - Transaction.type_, Transaction.amount, - Transaction.product_count, - Transaction.product_id, - Transaction.transfer_user_id, + Transaction.id, Transaction.interest_rate_percent, Transaction.penalty_multiplier_percent, Transaction.penalty_threshold, + Transaction.product_count, + Transaction.product_id, + Transaction.time, + Transaction.transfer_user_id, + Transaction.type_, ) .where( or_( and_( - Transaction.user_id == user.id, + Transaction.user_id == user_id, Transaction.type_.in_( [ TransactionType.ADD_PRODUCT, @@ -81,7 +85,7 @@ def _user_balance_query( ), and_( Transaction.type_ == TransactionType.TRANSFER, - Transaction.transfer_user_id == user.id, + Transaction.transfer_user_id == user_id, ), Transaction.type_.in_( [ @@ -90,7 +94,12 @@ def _user_balance_query( ] ), ), - Transaction.time <= until if until is not None else 1 == 1, + case( + (literal(until_including), Transaction.time <= until), + else_=Transaction.time < until, + ) + if until is not None + else literal(True), ) .order_by(Transaction.time.asc()) .alias("trx_subset") @@ -100,6 +109,7 @@ def _user_balance_query( select( trx_subset.c.i, trx_subset.c.time, + trx_subset.c.id.label("transaction_id"), case( # Adjusts balance -> balance gets adjusted ( @@ -124,12 +134,14 @@ def _user_balance_query( # product prices somehow. # Base price ( - select(column("price")) + # FIXME: this always returns 0 for some reason... + select(cast(column("price"), Float)) .select_from( _product_price_query( trx_subset.c.product_id, use_cache=use_cache, until=trx_subset.c.time, + until_including=False, cte_name="product_price_cte", ) ) @@ -139,12 +151,16 @@ def _user_balance_query( # TODO: should interest be applied before or after the penalty multiplier? # at the moment of writing, after sound right, but maybe ask someone? # Interest - * (recursive_cte.c.interest_rate_percent / 100) + * (cast(recursive_cte.c.interest_rate_percent, Float) / 100) # Penalty * case( ( + # TODO: should this be <= or balance increases ( - trx_subset.c.type_ == TransactionType.TRANSFER - and trx_subset.c.transfer_user_id == user.id, + and_( + trx_subset.c.type_ == TransactionType.TRANSFER, + trx_subset.c.transfer_user_id == user_id, + ), recursive_cte.c.balance + trx_subset.c.amount, ), # Transfers money from self -> balance decreases ( - trx_subset.c.type_ == TransactionType.TRANSFER - and trx_subset.c.transfer_user_id != user.id, + and_( + trx_subset.c.type_ == TransactionType.TRANSFER, + trx_subset.c.transfer_user_id != user_id, + ), recursive_cte.c.balance - trx_subset.c.amount, ), # Interest adjustment -> balance stays the same @@ -201,32 +221,55 @@ def _user_balance_query( # TODO: create a function for the log that pretty prints the log entries # for debugging purposes -# TODO: wrap the log entries in a dataclass, the don't cost that much -# TODO: add a method on the dataclass, using the running penalization data -# to figure out if the current row was penalized or not. +@dataclass +class UserBalanceLogEntry: + transaction: Transaction + balance: int + interest_rate_percent: int + penalty_threshold: int + penalty_multiplier_percent: int + + def is_penalized(self) -> bool: + """ + Returns whether this exact transaction is penalized. + """ + + return False + + # return self.transaction.type_ == TransactionType.BUY_PRODUCT and prev? + def user_balance_log( sql_session: Session, user: User, use_cache: bool = True, until: Transaction | None = None, -) -> list[tuple[int, datetime, int, int, int, int]]: +) -> list[UserBalanceLogEntry]: + """ + Returns a log of the user's balance over time, including interest and penalty adjustments. + """ + recursive_cte = _user_balance_query( - user, + user.id, use_cache=use_cache, until=until.time if until else None, ) result = sql_session.execute( select( - recursive_cte.c.i, - recursive_cte.c.time, + Transaction, recursive_cte.c.balance, recursive_cte.c.interest_rate_percent, recursive_cte.c.penalty_threshold, recursive_cte.c.penalty_multiplier_percent, - ).order_by(recursive_cte.c.i.asc()) + ) + .select_from(recursive_cte) + .join( + Transaction, + onclause=Transaction.id == recursive_cte.c.transaction_id, + ) + .order_by(recursive_cte.c.i.asc()) ).all() if result is None: @@ -235,7 +278,16 @@ def user_balance_log( f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})." ) - return result + return [ + UserBalanceLogEntry( + transaction=row[0], + balance=row.balance, + interest_rate_percent=row.interest_rate_percent, + penalty_threshold=row.penalty_threshold, + penalty_multiplier_percent=row.penalty_multiplier_percent, + ) + for row in result + ] def user_balance( @@ -249,7 +301,7 @@ def user_balance( """ recursive_cte = _user_balance_query( - user, + user.id, use_cache=use_cache, until=until.time if until else None, ) diff --git a/dibbler/queries/user_transactions.py b/dibbler/queries/user_transactions.py new file mode 100644 index 0000000..240e5ba --- /dev/null +++ b/dibbler/queries/user_transactions.py @@ -0,0 +1,20 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, User + +# TODO: allow filtering out 'special transactions' like 'ADJUST_INTEREST' and 'ADJUST_PENALTY' + + +def user_transactions(sql_session: Session, user: User) -> list[Transaction]: + """ + Returns the transactions of the user in chronological order. + """ + + return list( + sql_session.scalars( + select(Transaction) + .where(Transaction.user_id == user.id) + .order_by(Transaction.time.asc()) + ).all() + ) diff --git a/dibbler/subcommands/seed_test_data.py b/dibbler/subcommands/seed_test_data.py index 1cb9720..57b185d 100644 --- a/dibbler/subcommands/seed_test_data.py +++ b/dibbler/subcommands/seed_test_data.py @@ -2,7 +2,7 @@ from datetime import datetime from pathlib import Path from dibbler.db import Session -from dibbler.models import Product, Transaction, TransactionType, User +from dibbler.models import Product, Transaction, User JSON_FILE = Path(__file__).parent.parent.parent / "mock_data.json" @@ -11,6 +11,7 @@ JSON_FILE = Path(__file__).parent.parent.parent / "mock_data.json" # whether to seed test data, or by using command line arguments for # automatating the answer. + def clear_db(sql_session): sql_session.query(Product).delete() sql_session.query(User).delete() @@ -41,37 +42,31 @@ def main(): # Add transactions transactions = [ - Transaction( + Transaction.adjust_balance( time=datetime(2023, 10, 1, 10, 0, 0), - type_=TransactionType.ADJUST_BALANCE, amount=100, user_id=user1.id, ), - Transaction( + Transaction.adjust_balance( time=datetime(2023, 10, 1, 10, 0, 1), - type_=TransactionType.ADJUST_BALANCE, amount=50, user_id=user2.id, ), - Transaction( + Transaction.adjust_balance( time=datetime(2023, 10, 1, 10, 0, 2), - type_=TransactionType.ADJUST_BALANCE, amount=-50, user_id=user1.id, ), - Transaction( + Transaction.add_product( time=datetime(2023, 10, 1, 12, 0, 0), - type_=TransactionType.ADD_PRODUCT, amount=27 * 2, per_product=27, product_count=2, user_id=user1.id, product_id=product1.id, ), - Transaction( + Transaction.buy_product( time=datetime(2023, 10, 1, 12, 0, 1), - type_=TransactionType.BUY_PRODUCT, - amount=27, product_count=1, user_id=user2.id, product_id=product1.id, diff --git a/tests/conftest.py b/tests/conftest.py index 293439f..cc4483f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import pytest -from sqlalchemy import create_engine +from sqlalchemy import create_engine, event from sqlalchemy.orm import Session from dibbler.models import Base @@ -24,6 +24,13 @@ def sql_session(request): "sqlite:///:memory:", echo=echo, ) + + @event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_connection, _connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + Base.metadata.create_all(engine) with Session(engine) as sql_session: yield sql_session diff --git a/tests/models/test_transaction.py b/tests/models/test_transaction.py index 2b2c785..4c178a1 100644 --- a/tests/models/test_transaction.py +++ b/tests/models/test_transaction.py @@ -5,6 +5,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from dibbler.models import Product, Transaction, User +from dibbler.queries.product_stock import product_stock def insert_test_data(sql_session: Session) -> tuple[User, Product]: @@ -118,3 +119,81 @@ def test_user_foreign_key_constraint(sql_session: Session) -> None: with pytest.raises(IntegrityError): sql_session.commit() + + +def test_transaction_buy_product_more_than_stock(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 13, 0, 0), + product_count=10, + user_id=user.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + assert product_stock(sql_session, product) == 1 - 10 + + +def test_transaction_buy_product_dont_allow_no_add_product_transactions( + sql_session: Session, +) -> None: + user, product = insert_test_data(sql_session) + + transaction = Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 0), + product_count=1, + user_id=user.id, + product_id=product.id, + ) + + sql_session.add(transaction) + + with pytest.raises(ValueError): + sql_session.commit() + + +def test_transaction_add_product_deny_amount_over_per_product_times_product_count( + sql_session: Session, +) -> None: + user, product = insert_test_data(sql_session) + + with pytest.raises(ValueError): + _transaction = Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 2 + 1, # Invalid amount + per_product=27, + product_count=2, + ) + + +def test_transaction_add_product_allow_amount_under_per_product_times_product_count( + sql_session: Session, +) -> None: + user, product = insert_test_data(sql_session) + + transaction = Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 2 - 1, # Valid amount + per_product=27, + product_count=2, + ) + + sql_session.add(transaction) + sql_session.commit() diff --git a/tests/models/test_user.py b/tests/models/test_user.py index 2b6aa9e..7e8e852 100644 --- a/tests/models/test_user.py +++ b/tests/models/test_user.py @@ -23,49 +23,3 @@ def test_ensure_no_duplicate_user_names(sql_session: Session): with pytest.raises(IntegrityError): sql_session.commit() - - -def test_user_transactions(sql_session: Session): - user = insert_test_data(sql_session) - - product = Product("1234567890123", "Test Product") - user2 = User("Test User 2") - sql_session.add_all([product, user2]) - sql_session.commit() - - transactions = [ - Transaction.adjust_balance( - time=datetime(2023, 10, 1, 10, 0, 0), - amount=100, - user_id=user.id, - ), - Transaction.adjust_balance( - time=datetime(2023, 10, 1, 10, 0, 1), - amount=50, - user_id=user2.id, - ), - Transaction.adjust_balance( - time=datetime(2023, 10, 1, 10, 0, 2), - amount=-50, - user_id=user.id, - ), - Transaction.add_product( - time=datetime(2023, 10, 1, 12, 0, 0), - amount=27 * 2, - per_product=27, - product_count=2, - user_id=user.id, - product_id=product.id, - ), - Transaction.buy_product( - time=datetime(2023, 10, 1, 12, 0, 1), - product_count=1, - user_id=user2.id, - product_id=product.id, - ), - ] - - sql_session.add_all(transactions) - - assert len(user.transactions(sql_session)) == 3 - assert len(user2.transactions(sql_session)) == 2 diff --git a/tests/queries/test_buy_product.py b/tests/queries/test_buy_product.py deleted file mode 100644 index 4791d6b..0000000 --- a/tests/queries/test_buy_product.py +++ /dev/null @@ -1,177 +0,0 @@ -import math -from datetime import datetime - -from sqlalchemy.orm import Session - -from dibbler.models import Product, Transaction, User -from dibbler.queries.product_stock import product_stock -from dibbler.queries.user_balance import user_balance - - -def insert_test_data(sql_session: Session) -> tuple[User, Product]: - user = User("Test User") - product = Product("1234567890123", "Test Product") - - sql_session.add(user) - sql_session.add(product) - sql_session.commit() - - transactions = [ - Transaction.adjust_penalty( - time=datetime(2023, 10, 1, 10, 0, 0), - user_id=user.id, - penalty_multiplier_percent=200, - penalty_threshold=-100, - ), - Transaction.adjust_balance( - time=datetime(2023, 10, 1, 10, 0, 1), - user_id=user.id, - amount=100, - ), - Transaction.add_product( - time=datetime(2023, 10, 1, 10, 0, 2), - user_id=user.id, - product_id=product.id, - amount=27, - per_product=27, - product_count=1, - ), - ] - - sql_session.add_all(transactions) - sql_session.commit() - - return user, product - - -def test_buy_product_basic(sql_session: Session) -> None: - user, product = insert_test_data(sql_session) - - transaction = Transaction.buy_product( - time=datetime(2023, 10, 1, 12, 0, 0), - user_id=user.id, - product_id=product.id, - product_count=1, - ) - - sql_session.add(transaction) - sql_session.commit() - - -def test_buy_product_with_penalty(sql_session: Session) -> None: - user, product = insert_test_data(sql_session) - - transactions = [ - Transaction.adjust_balance( - time=datetime(2023, 10, 1, 11, 0, 0), - user_id=user.id, - amount=-200, - ) - ] - sql_session.add_all(transactions) - sql_session.commit() - - transaction = Transaction.buy_product( - time=datetime(2023, 10, 1, 12, 0, 0), - user_id=user.id, - product_id=product.id, - product_count=1, - ) - sql_session.add(transaction) - sql_session.commit() - - assert user_balance(sql_session, user) == 100 + 27 - 200 - (27 * 2) - - -def test_buy_product_with_interest(sql_session: Session) -> None: - user, product = insert_test_data(sql_session) - - transactions = [ - Transaction.adjust_interest( - time=datetime(2023, 10, 1, 11, 0, 0), - user_id=user.id, - interest_rate_percent=110, - ) - ] - sql_session.add_all(transactions) - sql_session.commit() - - transaction = Transaction.buy_product( - time=datetime(2023, 10, 1, 12, 0, 0), - user_id=user.id, - product_id=product.id, - product_count=1, - ) - sql_session.add(transaction) - sql_session.commit() - - assert user_balance(sql_session, user) == 100 + 27 - math.ceil(27 * 1.1) - - -def test_buy_product_with_changing_penalty(sql_session: Session) -> None: - user, product = insert_test_data(sql_session) - - transactions = [ - Transaction.adjust_balance( - time=datetime(2023, 10, 1, 11, 0, 0), - user_id=user.id, - amount=-200, - ) - ] - sql_session.add_all(transactions) - sql_session.commit() - - transaction = Transaction.buy_product( - time=datetime(2023, 10, 1, 12, 0, 0), - user_id=user.id, - product_id=product.id, - product_count=1, - ) - sql_session.add(transaction) - sql_session.commit() - - assert user_balance(sql_session, user) == 100 + 27 - 200 - (27 * 2) - - adjust_penalty = Transaction.adjust_penalty( - time=datetime(2023, 10, 1, 13, 0, 0), - user_id=user.id, - penalty_multiplier_percent=300, - penalty_threshold=-100, - ) - sql_session.add(adjust_penalty) - sql_session.commit() - - transaction = Transaction.buy_product( - time=datetime(2023, 10, 1, 14, 0, 0), - user_id=user.id, - product_id=product.id, - product_count=1, - ) - sql_session.add(transaction) - sql_session.commit() - - assert user_balance(sql_session, user) == 100 + 27 - 200 - (27 * 2) - (27 * 3) - - -def test_buy_product_with_changing_interest(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") - - -def test_buy_product_with_penalty_interest_combined(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") - - -def test_buy_product_more_than_stock(sql_session: Session) -> None: - user, product = insert_test_data(sql_session) - - transaction = Transaction.buy_product( - time=datetime(2023, 10, 1, 13, 0, 0), - product_count=10, - user_id=user.id, - product_id=product.id, - ) - - sql_session.add(transaction) - sql_session.commit() - - assert product_stock(sql_session, product) == 1 - 10 diff --git a/tests/queries/test_product_price.py b/tests/queries/test_product_price.py index a8153e0..47c8892 100644 --- a/tests/queries/test_product_price.py +++ b/tests/queries/test_product_price.py @@ -1,170 +1,342 @@ +import math from datetime import datetime +from pprint import pprint -from sqlalchemy import select from sqlalchemy.orm import Session from dibbler.models import Product, Transaction, User -from dibbler.queries.product_price import product_price +from dibbler.queries.product_price import product_price, product_price_log + +# TODO: see if we can use pytest_runtest_makereport to print the "product_price_log"s +# only on failures instead of inlining it in every test function -def insert_test_data(sql_session: Session) -> None: - # Add users - user1 = User("Test User 1") - user2 = User("Test User 2") +def insert_test_data(sql_session: Session) -> tuple[User, Product]: + user = User("Test User") + product = Product("1234567890123", "Test Product") - sql_session.add_all([user1, user2]) + sql_session.add(user) + sql_session.add(product) sql_session.commit() - # Add products - product1 = Product("1234567890123", "Test Product 1") - product2 = Product("9876543210987", "Test Product 2") - product3 = Product("1111111111111", "Test Product 3") - sql_session.add_all([product1, product2, product3]) - sql_session.commit() + return user, product + + +def test_product_price_no_transactions(sql_session: Session) -> None: + _, product = insert_test_data(sql_session) + + pprint(product_price_log(sql_session, product)) + + assert product_price(sql_session, product) == 0 + + +def test_product_price_basic_history(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) - # Add transactions transactions = [ - Transaction.adjust_balance( - time=datetime(2023, 10, 1, 10, 0, 0), - amount=100, - user_id=user1.id, - ), - Transaction.adjust_balance( - time=datetime(2023, 10, 1, 10, 0, 1), - amount=50, - user_id=user2.id, - ), - Transaction.adjust_balance( - time=datetime(2023, 10, 1, 10, 0, 2), - amount=-50, - user_id=user1.id, - ), Transaction.add_product( time=datetime(2023, 10, 1, 12, 0, 0), - amount=27 * 2, + amount=27 * 2 - 1, per_product=27, product_count=2, - user_id=user1.id, - product_id=product1.id, - ), - Transaction.buy_product( - time=datetime(2023, 10, 1, 12, 0, 1), - product_count=1, - user_id=user2.id, - product_id=product1.id, - ), - Transaction.adjust_stock( - time=datetime(2023, 10, 1, 12, 0, 2), - product_count=3, - user_id=user1.id, - product_id=product1.id, - ), - Transaction.adjust_stock( - time=datetime(2023, 10, 1, 12, 0, 3), - product_count=-2, - user_id=user1.id, - product_id=product1.id, - ), - Transaction.add_product( - time=datetime(2023, 10, 1, 12, 0, 4), - amount=50, - per_product=50, - product_count=1, - user_id=user1.id, - product_id=product3.id, - ), - Transaction.buy_product( - time=datetime(2023, 10, 1, 12, 0, 5), - product_count=1, - user_id=user1.id, - product_id=product3.id, - ), - Transaction.adjust_balance( - time=datetime(2023, 10, 1, 12, 0, 6), - amount=1000, - user_id=user1.id, + user_id=user.id, + product_id=product.id, ), ] sql_session.add_all(transactions) sql_session.commit() + pprint(product_price_log(sql_session, product)) -def test_product_price(sql_session: Session) -> None: - insert_test_data(sql_session) - - product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one() - assert product_price(sql_session, product1) == 27 - - -def test_product_price_no_transactions(sql_session: Session) -> None: - insert_test_data(sql_session) - - product2 = sql_session.scalars(select(Product).where(Product.name == "Test Product 2")).one() - assert product_price(sql_session, product2) == 0 + assert product_price(sql_session, product) == 27 def test_product_price_sold_out(sql_session: Session) -> None: - insert_test_data(sql_session) + user, product = insert_test_data(sql_session) - product3 = sql_session.scalars(select(Product).where(Product.name == "Test Product 3")).one() - assert product_price(sql_session, product3) == 50 + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + amount=27 * 2 - 1, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 1), + product_count=2, + user_id=user.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + assert product_price(sql_session, product) == 27 def test_product_price_interest(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 12, 0, 0), + interest_rate_percent=110, + user_id=user.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 1), + amount=27 * 2 - 1, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + product_price_ = product_price(sql_session, product) + product_price_interest = product_price(sql_session, product, include_interest=True) + + assert product_price_ == 27 + assert product_price_interest == math.ceil(27 * 1.1) def test_product_price_changing_interest(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 12, 0, 0), + interest_rate_percent=110, + user_id=user.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 1), + amount=27 * 2 - 1, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 12, 0, 2), + interest_rate_percent=120, + user_id=user.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + product_price_interest = product_price(sql_session, product, include_interest=True) + assert product_price_interest == math.ceil(27 * 1.2) + + +def test_product_price_old_transaction(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 1), + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + # Price should be 27 + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 2), + amount=38 * 3, + per_product=38, + product_count=3, + user_id=user.id, + product_id=product.id, + ), + # price should be averaged upwards + ] + + sql_session.add_all(transactions) + sql_session.commit() + + until_transaction = transactions[0] + + pprint( + product_price_log( + sql_session, + product, + until=until_transaction, + ) + ) + + product_price_ = product_price( + sql_session, + product, + until=until_transaction, + ) + assert product_price_ == 27 # Price goes up and gets rounded up to the next integer def test_product_price_round_up_from_below(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 1), + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + # Price should be 27 + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 2), + amount=38 * 3, + per_product=38, + product_count=3, + user_id=user.id, + product_id=product.id, + ), + # price should be averaged upwards + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + product_price_ = product_price(sql_session, product) + assert product_price_ == math.ceil((27 * 2 + 38 * 3) / (2 + 3)) # Price goes down and gets rounded up to the next integer def test_product_price_round_up_from_above(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 1), + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + # Price should be 27 + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 2), + amount=20 * 3, + per_product=20, + product_count=3, + user_id=user.id, + product_id=product.id, + ), + # price should be averaged downwards + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + product_price_ = product_price(sql_session, product) + assert product_price_ == math.ceil((27 * 2 + 20 * 3) / (2 + 3)) def test_product_price_with_negative_stock_single_addition(sql_session: Session) -> None: - insert_test_data(sql_session) + user, product = insert_test_data(sql_session) - product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one() - user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 0), + amount=1, + per_product=10, + product_count=1, + user_id=user.id, + product_id=product.id, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 13, 0, 1), + product_count=10, + user_id=user.id, + product_id=product.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 2), + amount=22, + per_product=22, + product_count=1, + user_id=user.id, + product_id=product.id, + ), + ] - transaction = Transaction.buy_product( - time=datetime(2023, 10, 1, 13, 0, 0), - product_count=10, - user_id=user1.id, - product_id=product1.id, - ) - - sql_session.add(transaction) + sql_session.add_all(transactions) sql_session.commit() - product1_price = product_price(sql_session, product1) - assert product1_price == 27 - - transaction = Transaction.add_product( - time=datetime(2023, 10, 1, 13, 0, 1), - amount=22, - per_product=22, - product_count=1, - user_id=user1.id, - product_id=product1.id, - ) - - sql_session.add(transaction) - sql_session.commit() + pprint(product_price_log(sql_session, product)) # Stock went subzero, price should be the last added product price - product1_price = product_price(sql_session, product1) + product1_price = product_price(sql_session, product) assert product1_price == 22 # TODO: what happens when stock is still negative and yet new products are added? def test_product_price_with_negative_stock_multiple_additions(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 0), + amount=1, + per_product=10, + product_count=1, + user_id=user.id, + product_id=product.id, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 13, 0, 1), + product_count=10, + user_id=user.id, + product_id=product.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 2), + amount=22, + per_product=22, + product_count=1, + user_id=user.id, + product_id=product.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 3), + amount=29, + per_product=29, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + # Stock went subzero, price should be the ceiled average of the last added products + product1_price = product_price(sql_session, product) + assert product1_price == math.ceil((22 + 29 * 2) / (1 + 2)) diff --git a/tests/queries/test_user_balance.py b/tests/queries/test_user_balance.py index b0ce281..5661b37 100644 --- a/tests/queries/test_user_balance.py +++ b/tests/queries/test_user_balance.py @@ -1,102 +1,306 @@ +import math from datetime import datetime +from pprint import pprint -from sqlalchemy import select from sqlalchemy.orm import Session from dibbler.models import Product, Transaction, User from dibbler.queries.user_balance import user_balance, user_balance_log +# TODO: see if we can use pytest_runtest_makereport to print the "user_balance_log"s +# only on failures instead of inlining it in every test function -def insert_test_data(sql_session: Session) -> None: - # Add users - user1 = User("Test User 1") - user2 = User("Test User 2") - sql_session.add(user1) - sql_session.add(user2) +def insert_test_data(sql_session: Session) -> tuple[User, Product]: + user = User("Test User") + product = Product("1234567890123", "Test Product") + + sql_session.add(user) + sql_session.add(product) sql_session.commit() - # Add products - product1 = Product("1234567890123", "Test Product 1") - product2 = Product("9876543210987", "Test Product 2") - sql_session.add(product1) - sql_session.add(product2) - sql_session.commit() + return user, product + + +def test_user_balance_no_transactions(sql_session: Session) -> None: + user, _ = insert_test_data(sql_session) + + pprint(user_balance_log(sql_session, user)) + + balance = user_balance(sql_session, user) + + assert balance == 0 + + +def test_user_balance_basic_history(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) - # Add transactions transactions = [ Transaction.adjust_balance( time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, amount=100, - user_id=user1.id, - ), - Transaction.adjust_balance( - time=datetime(2023, 10, 1, 10, 0, 1), - amount=50, - user_id=user2.id, - ), - Transaction.adjust_balance( - time=datetime(2023, 10, 1, 10, 0, 2), - amount=-50, - user_id=user1.id, ), Transaction.add_product( - time=datetime(2023, 10, 1, 12, 0, 0), - amount=27 * 2, + time=datetime(2023, 10, 1, 10, 0, 1), + user_id=user.id, + product_id=product.id, + amount=27, per_product=27, - product_count=2, - user_id=user1.id, - product_id=product1.id, - ), - Transaction.buy_product( - time=datetime(2023, 10, 1, 12, 0, 1), product_count=1, - user_id=user2.id, - product_id=product1.id, ), ] sql_session.add_all(transactions) sql_session.commit() + pprint(user_balance_log(sql_session, user)) -def test_user_balance_basic_history(sql_session: Session) -> None: - insert_test_data(sql_session) + balance = user_balance(sql_session, user) - user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() - user2 = sql_session.scalars(select(User).where(User.name == "Test User 2")).one() - - assert user_balance(sql_session, user1) == 100 - 50 + 27 * 2 - assert user_balance(sql_session, user2) == 50 - 27 + assert balance == 100 + 27 -def test_user_balance_no_transactions(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") +def test_user_balance_with_transfers(sql_session: Session) -> None: + user1, product = insert_test_data(sql_session) + + user2 = User("Test User 2") + sql_session.add(user2) + sql_session.commit() + + transactions = [ + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user1.id, + amount=100, + ), + Transaction.transfer( + time=datetime(2023, 10, 1, 10, 0, 1), + user_id=user1.id, + transfer_user_id=user2.id, + amount=50, + ), + Transaction.transfer( + time=datetime(2023, 10, 1, 10, 0, 2), + user_id=user2.id, + transfer_user_id=user1.id, + amount=30, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(user_balance_log(sql_session, user1)) + + user1_balance = user_balance(sql_session, user1) + assert user1_balance == 100 - 50 + 30 + + pprint(user_balance_log(sql_session, user2)) + + user2_balance = user_balance(sql_session, user2) + assert user2_balance == 50 - 30 def test_user_balance_complex_history(sql_session: Session) -> None: raise NotImplementedError("This test is not implemented yet.") -def test_user_balance_with_tranfers(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") - - def test_user_balance_penalty(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + amount=-200, + ), + # Penalized, pays 2x the price (default penalty) + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == 27 - 200 - (27 * 2) def test_user_balance_changing_penalty(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + amount=-200, + ), + # Penalized, pays 2x the price (default penalty) + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + Transaction.adjust_penalty( + time=datetime(2023, 10, 1, 13, 0, 0), + user_id=user.id, + penalty_multiplier_percent=300, + penalty_threshold=-100, + ), + # Penalized, pays 3x the price + Transaction.buy_product( + time=datetime(2023, 10, 1, 14, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == 27 - 200 - (27 * 2) - (27 * 3) def test_user_balance_interest(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + interest_rate_percent=110, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == 27 - math.ceil(27 * 1.1) def test_user_balance_changing_interest(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + interest_rate_percent=110, + ), + # Pays 1.1x the price + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 13, 0, 0), + user_id=user.id, + interest_rate_percent=120, + ), + # Pays 1.2x the price + Transaction.buy_product( + time=datetime(2023, 10, 1, 14, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == 27 * 3 - math.ceil(27 * 1.1) - math.ceil(27 * 1.2) def test_user_balance_penalty_interest_combined(sql_session: Session) -> None: - raise NotImplementedError("This test is not implemented yet.") + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + interest_rate_percent=110, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + amount=-200, + ), + # Penalized, pays 2x the price (default penalty) + # Pays 1.1x the price + Transaction.buy_product( + time=datetime(2023, 10, 1, 13, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + 27 + - 200 + - math.ceil(27 * 2 * 1.1) + ) diff --git a/tests/queries/test_user_transactions.py b/tests/queries/test_user_transactions.py new file mode 100644 index 0000000..a129a24 --- /dev/null +++ b/tests/queries/test_user_transactions.py @@ -0,0 +1,60 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User +from dibbler.queries.user_transactions import user_transactions + + +def insert_test_data(sql_session: Session) -> User: + user = User("Test User") + sql_session.add(user) + sql_session.commit() + + return user + + +def test_user_transactions(sql_session: Session): + user = insert_test_data(sql_session) + + product = Product("1234567890123", "Test Product") + user2 = User("Test User 2") + sql_session.add_all([product, user2]) + sql_session.commit() + + transactions = [ + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 0), + amount=100, + user_id=user.id, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 1), + amount=50, + user_id=user2.id, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 2), + amount=-50, + user_id=user.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 1), + product_count=1, + user_id=user2.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + + assert len(user_transactions(sql_session, user)) == 3 + assert len(user_transactions(sql_session, user2)) == 2