From f54d8b2945bc530ef874c877fa4513c8a78dc4c3 Mon Sep 17 00:00:00 2001 From: h7x4 Date: Wed, 10 Dec 2025 16:00:36 +0900 Subject: [PATCH] Write a set of queries to go along with the event sourcing model --- dibbler/lib/__init__.py | 0 dibbler/lib/helpers.py | 76 +- dibbler/queries/__init__.py | 46 + dibbler/queries/add_product.py | 51 ++ dibbler/queries/adjust_balance.py | 33 + dibbler/queries/adjust_interest.py | 36 + dibbler/queries/adjust_penalty.py | 49 ++ dibbler/queries/adjust_stock.py | 40 + dibbler/queries/buy_product.py | 38 + dibbler/queries/create_product.py | 25 + dibbler/queries/create_user.py | 21 + dibbler/queries/current_interest.py | 55 ++ dibbler/queries/current_penalty.py | 59 ++ dibbler/queries/joint_buy_product.py | 68 ++ dibbler/queries/product_owners.py | 309 +++++++ dibbler/queries/product_price.py | 309 +++++++ dibbler/queries/product_stock.py | 126 +++ dibbler/queries/query_helpers.py | 80 ++ dibbler/queries/search_product.py | 42 + dibbler/queries/search_user.py | 39 + dibbler/queries/throw_product.py | 42 + dibbler/queries/transaction_log.py | 142 ++++ dibbler/queries/transfer.py | 38 + dibbler/queries/user_balance.py | 567 +++++++++++++ dibbler/queries/user_products.py | 48 ++ tests/helpers.py | 28 + tests/queries/__init__.py | 0 tests/queries/test_adjust_interest.py | 85 ++ tests/queries/test_adjust_penalty.py | 179 ++++ tests/queries/test_current_interest.py | 38 + tests/queries/test_current_penalty.py | 46 + tests/queries/test_joint_buy_product.py | 199 +++++ tests/queries/test_product_owners.py | 335 ++++++++ tests/queries/test_product_price.py | 447 ++++++++++ tests/queries/test_product_stock.py | 285 +++++++ tests/queries/test_search_product.py | 96 +++ tests/queries/test_search_user.py | 86 ++ tests/queries/test_transaction_log.py | 687 +++++++++++++++ tests/queries/test_user_balance.py | 1016 +++++++++++++++++++++++ 39 files changed, 5792 insertions(+), 74 deletions(-) create mode 100644 dibbler/lib/__init__.py create mode 100644 dibbler/queries/__init__.py create mode 100644 dibbler/queries/add_product.py create mode 100644 dibbler/queries/adjust_balance.py create mode 100644 dibbler/queries/adjust_interest.py create mode 100644 dibbler/queries/adjust_penalty.py create mode 100644 dibbler/queries/adjust_stock.py create mode 100644 dibbler/queries/buy_product.py create mode 100644 dibbler/queries/create_product.py create mode 100644 dibbler/queries/create_user.py create mode 100644 dibbler/queries/current_interest.py create mode 100644 dibbler/queries/current_penalty.py create mode 100644 dibbler/queries/joint_buy_product.py create mode 100644 dibbler/queries/product_owners.py create mode 100644 dibbler/queries/product_price.py create mode 100644 dibbler/queries/product_stock.py create mode 100644 dibbler/queries/query_helpers.py create mode 100644 dibbler/queries/search_product.py create mode 100644 dibbler/queries/search_user.py create mode 100644 dibbler/queries/throw_product.py create mode 100644 dibbler/queries/transaction_log.py create mode 100644 dibbler/queries/transfer.py create mode 100644 dibbler/queries/user_balance.py create mode 100644 dibbler/queries/user_products.py create mode 100644 tests/helpers.py create mode 100644 tests/queries/__init__.py create mode 100644 tests/queries/test_adjust_interest.py create mode 100644 tests/queries/test_adjust_penalty.py create mode 100644 tests/queries/test_current_interest.py create mode 100644 tests/queries/test_current_penalty.py create mode 100644 tests/queries/test_joint_buy_product.py create mode 
100644 tests/queries/test_product_owners.py create mode 100644 tests/queries/test_product_price.py create mode 100644 tests/queries/test_product_stock.py create mode 100644 tests/queries/test_search_product.py create mode 100644 tests/queries/test_search_user.py create mode 100644 tests/queries/test_transaction_log.py create mode 100644 tests/queries/test_user_balance.py diff --git a/dibbler/lib/__init__.py b/dibbler/lib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dibbler/lib/helpers.py b/dibbler/lib/helpers.py index 30926a3..0aab88e 100644 --- a/dibbler/lib/helpers.py +++ b/dibbler/lib/helpers.py @@ -1,79 +1,7 @@ -import pwd -import subprocess import os +import pwd import signal - -from sqlalchemy import or_, and_ - -from ..models import User, Product - - -def search_user(string, session, ignorethisflag=None): - string = string.lower() - exact_match = ( - session.query(User) - .filter(or_(User.name == string, User.card == string, User.rfid == string)) - .first() - ) - if exact_match: - return exact_match - user_list = ( - session.query(User) - .filter( - or_( - User.name.ilike(f"%{string}%"), - User.card.ilike(f"%{string}%"), - User.rfid.ilike(f"%{string}%"), - ) - ) - .all() - ) - return user_list - - -def search_product(string, session, find_hidden_products=True): - if find_hidden_products: - exact_match = ( - session.query(Product) - .filter(or_(Product.bar_code == string, Product.name == string)) - .first() - ) - else: - exact_match = ( - session.query(Product) - .filter( - or_( - Product.bar_code == string, - and_(Product.name == string, Product.hidden is False), - ) - ) - .first() - ) - if exact_match: - return exact_match - if find_hidden_products: - product_list = ( - session.query(Product) - .filter( - or_( - Product.bar_code.ilike(f"%{string}%"), - Product.name.ilike(f"%{string}%"), - ) - ) - .all() - ) - else: - product_list = ( - session.query(Product) - .filter( - or_( - Product.bar_code.ilike(f"%{string}%"), - and_(Product.name.ilike(f"%{string}%"), Product.hidden is False), - ) - ) - .all() - ) - return product_list +import subprocess def system_user_exists(username): diff --git a/dibbler/queries/__init__.py b/dibbler/queries/__init__.py new file mode 100644 index 0000000..f1b1854 --- /dev/null +++ b/dibbler/queries/__init__.py @@ -0,0 +1,46 @@ +__all__ = [ + "add_product", + "adjust_balance", + "adjust_interest", + "adjust_penalty", + "adjust_stock", + "create_product", + "create_user", + "current_interest", + "current_penalty", + "joint_buy_product", + "product_owners", + "product_owners_log", + "product_price", + "product_price_log", + "product_stock", + "search_product", + "search_user", + "throw_product", + "transaction_log", + "transfer", + "user_balance", + "user_balance_log", + "user_products", +] + +from .add_product import add_product +from .adjust_balance import adjust_balance +from .adjust_interest import adjust_interest +from .adjust_penalty import adjust_penalty +from .adjust_stock import adjust_stock +from .create_product import create_product +from .create_user import create_user +from .current_interest import current_interest +from .current_penalty import current_penalty +from .joint_buy_product import joint_buy_product +from .product_owners import product_owners, product_owners_log +from .product_price import product_price, product_price_log +from .product_stock import product_stock +from .search_product import search_product +from .search_user import search_user +from .throw_product import throw_product +from .transaction_log 
import transaction_log +from .transfer import transfer +from .user_balance import user_balance, user_balance_log +from .user_products import user_products diff --git a/dibbler/queries/add_product.py b/dibbler/queries/add_product.py new file mode 100644 index 0000000..5c46191 --- /dev/null +++ b/dibbler/queries/add_product.py @@ -0,0 +1,51 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User + + +def add_product( + sql_session: Session, + user: User, + product: Product, + amount: int, + per_product: int, + product_count: int, + time: datetime | None = None, + message: str | None = None, +) -> Transaction: + if user.id is None: + raise ValueError("User must be persisted in the database.") + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + if amount <= 0: + raise ValueError("Amount must be positive.") + + if per_product <= 0: + raise ValueError("Per product price must be positive.") + + if product_count <= 0: + raise ValueError("Product count must be positive.") + + if per_product * product_count < amount: + raise ValueError("Total per product price must be at least equal to amount.") + + # TODO: verify time is not behind last transaction's time + + transaction = Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=amount, + per_product=per_product, + product_count=product_count, + time=time, + message=message, + ) + + sql_session.add(transaction) + sql_session.commit() + + return transaction diff --git a/dibbler/queries/adjust_balance.py b/dibbler/queries/adjust_balance.py new file mode 100644 index 0000000..583a240 --- /dev/null +++ b/dibbler/queries/adjust_balance.py @@ -0,0 +1,33 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, User + + +def adjust_balance( + sql_session: Session, + user: User, + amount: int, + time: datetime | None = None, + message: str | None = None, +) -> Transaction: + if user.id is None: + raise ValueError("User must be persisted in the database.") + + if amount == 0: + raise ValueError("Amount must be non-zero.") + + # TODO: verify time is not behind last transaction's time + + transaction = Transaction.adjust_balance( + user_id=user.id, + amount=amount, + time=time, + message=message, + ) + + sql_session.add(transaction) + sql_session.commit() + + return transaction diff --git a/dibbler/queries/adjust_interest.py b/dibbler/queries/adjust_interest.py new file mode 100644 index 0000000..8d5e865 --- /dev/null +++ b/dibbler/queries/adjust_interest.py @@ -0,0 +1,36 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, User + +# TODO: this type of transaction should be password protected. +# the password can be set as a string literal in the config file. 
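For reference, a usage sketch of the add_product and adjust_balance queries above (illustrative only; `session`, `alice` and `cola` are assumed names for an open `Session` and already-persisted `User`/`Product` rows):

    from dibbler.queries import add_product, adjust_balance

    # alice registers 20 units she bought for 300 in total, at 15 per unit
    add_product(session, alice, cola, amount=300, per_product=15, product_count=20)

    # manual correction of alice's balance
    adjust_balance(session, alice, amount=-50, message="register correction")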
+ + +def adjust_interest( + sql_session: Session, + user: User, + new_interest: int, + time: datetime | None = None, + message: str | None = None, +) -> Transaction: + if new_interest < 0: + raise ValueError("Interest rate cannot be negative") + + if user.id is None: + raise ValueError("User must be persisted in the database.") + + # TODO: verify time is not behind last transaction's time + + transaction = Transaction.adjust_interest( + user_id=user.id, + interest_rate_percent=new_interest, + time=time, + message=message, + ) + + sql_session.add(transaction) + sql_session.commit() + + return transaction diff --git a/dibbler/queries/adjust_penalty.py b/dibbler/queries/adjust_penalty.py new file mode 100644 index 0000000..7bb04f4 --- /dev/null +++ b/dibbler/queries/adjust_penalty.py @@ -0,0 +1,49 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, User +from dibbler.queries.current_penalty import current_penalty + +# TODO: this type of transaction should be password protected. +# the password can be set as a string literal in the config file. + + +def adjust_penalty( + sql_session: Session, + user: User, + new_penalty: int | None = None, + new_penalty_multiplier: int | None = None, + time: datetime | None = None, + message: str | None = None, +) -> Transaction: + if new_penalty is None and new_penalty_multiplier is None: + raise ValueError("At least one of new_penalty or new_penalty_multiplier must be provided") + + if new_penalty_multiplier is not None and new_penalty_multiplier < 100: + raise ValueError("Penalty multiplier cannot be less than 100%") + + if user.id is None: + raise ValueError("User must be persisted in the database.") + + if new_penalty is None or new_penalty_multiplier is None: + existing_penalty, existing_penalty_multiplier = current_penalty(sql_session) + if new_penalty is None: + new_penalty = existing_penalty + if new_penalty_multiplier is None: + new_penalty_multiplier = existing_penalty_multiplier + + # TODO: verify time is not behind last transaction's time + + transaction = Transaction.adjust_penalty( + user_id=user.id, + penalty_threshold=new_penalty, + penalty_multiplier_percent=new_penalty_multiplier, + time=time, + message=message, + ) + + sql_session.add(transaction) + sql_session.commit() + + return transaction diff --git a/dibbler/queries/adjust_stock.py b/dibbler/queries/adjust_stock.py new file mode 100644 index 0000000..9fb182a --- /dev/null +++ b/dibbler/queries/adjust_stock.py @@ -0,0 +1,40 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User + + +def adjust_stock( + sql_session: Session, + user: User, + product: Product, + product_count: int, + time: datetime | None = None, + message: str | None = None, +) -> Transaction: + if user.id is None: + raise ValueError("User must be persisted in the database.") + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + if product_count == 0: + raise ValueError("Product count must be non-zero.") + + # TODO: it should not be possible to reduce stock below zero. 
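    # Illustrative sketch for the TODO above, assuming the product_stock query defined
    # later in this patch can be imported here without a cycle:
    #
    #     from dibbler.queries.product_stock import product_stock
    #     if product_count < 0 and product_stock(sql_session, product) + product_count < 0:
    #         raise ValueError("Stock cannot be adjusted below zero.")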
+ # + # TODO: verify time is not behind last transaction's time + + transaction = Transaction.adjust_stock( + user_id=user.id, + product_id=product.id, + product_count=product_count, + time=time, + message=message, + ) + + sql_session.add(transaction) + sql_session.commit() + + return transaction diff --git a/dibbler/queries/buy_product.py b/dibbler/queries/buy_product.py new file mode 100644 index 0000000..16b85ed --- /dev/null +++ b/dibbler/queries/buy_product.py @@ -0,0 +1,38 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User + + +def buy_product( + sql_session: Session, + user: User, + product: Product, + product_count: int, + time: datetime | None = None, + message: str | None = None, +) -> Transaction: + if user.id is None: + raise ValueError("User must be persisted in the database.") + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + if product_count <= 0: + raise ValueError("Product count must be positive.") + + # TODO: verify time is not behind last transaction's time + + transaction = Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=product_count, + time=time, + message=message, + ) + + sql_session.add(transaction) + sql_session.commit() + + return transaction diff --git a/dibbler/queries/create_product.py b/dibbler/queries/create_product.py new file mode 100644 index 0000000..d0e619d --- /dev/null +++ b/dibbler/queries/create_product.py @@ -0,0 +1,25 @@ +from sqlalchemy.orm import Session + +from dibbler.models import Product + + +def create_product( + sql_session: Session, + name: str, + barcode: str, +) -> Product: + if not name: + raise ValueError("Name cannot be empty.") + + if not barcode: + raise ValueError("Barcode cannot be empty.") + + # TODO: check for duplicate names, barcodes + + # TODO: add more validation for barcode + + product = Product(barcode, name) + sql_session.add(product) + sql_session.commit() + + return product diff --git a/dibbler/queries/create_user.py b/dibbler/queries/create_user.py new file mode 100644 index 0000000..eb89ae8 --- /dev/null +++ b/dibbler/queries/create_user.py @@ -0,0 +1,21 @@ +from sqlalchemy.orm import Session + +from dibbler.models import User + + +def create_user( + sql_session: Session, + name: str, + card: str | None, + rfid: str | None, +) -> User: + if not name: + raise ValueError("Name cannot be empty.") + + # TODO: check for duplicate names, cards, rfids + + user = User(name=name, card=card, rfid=rfid) + sql_session.add(user) + sql_session.commit() + + return user diff --git a/dibbler/queries/current_interest.py b/dibbler/queries/current_interest.py new file mode 100644 index 0000000..58fe2cd --- /dev/null +++ b/dibbler/queries/current_interest.py @@ -0,0 +1,55 @@ +from datetime import datetime + +from sqlalchemy import BindParameter, bindparam, select +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, TransactionType +from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT +from dibbler.queries.query_helpers import until_filter + + +def current_interest( + sql_session: Session, + until_time: BindParameter[datetime] | datetime | None = None, + until_transaction: BindParameter[Transaction] | Transaction | None = None, + until_inclusive: bool = True, +) -> int: + """ + Get the current interest rate percentage as of a given time or transaction. + + Returns the interest rate percentage as an integer. 
+ """ + + if not (until_time is None or until_transaction is None): + raise ValueError("Cannot filter by both until_time and until_transaction.") + + if isinstance(until_time, datetime): + until_time = BindParameter("until_time", value=until_time) + + if isinstance(until_transaction, Transaction): + if until_transaction.id is None: + raise ValueError("until_transaction must be persisted in the database.") + until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id) + else: + until_transaction_id = None + + result = sql_session.scalars( + select(Transaction) + .where( + Transaction.type_ == TransactionType.ADJUST_INTEREST, + until_filter( + until_time=until_time, + until_transaction_id=until_transaction_id, + until_inclusive=until_inclusive, + ), + ) + .order_by(Transaction.time.desc()) + .limit(1) + ).one_or_none() + + if result is None: + return DEFAULT_INTEREST_RATE_PERCENT + elif result.interest_rate_percent is None: + return DEFAULT_INTEREST_RATE_PERCENT + else: + return result.interest_rate_percent diff --git a/dibbler/queries/current_penalty.py b/dibbler/queries/current_penalty.py new file mode 100644 index 0000000..2aef237 --- /dev/null +++ b/dibbler/queries/current_penalty.py @@ -0,0 +1,59 @@ +from datetime import datetime + +from sqlalchemy import BindParameter, bindparam, select +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, TransactionType +from dibbler.models.Transaction import ( + DEFAULT_PENALTY_MULTIPLIER_PERCENT, + DEFAULT_PENALTY_THRESHOLD, +) +from dibbler.queries.query_helpers import until_filter + + +def current_penalty( + sql_session: Session, + until_time: BindParameter[datetime] | datetime | None = None, + until_transaction: BindParameter[Transaction] | Transaction | None = None, + until_inclusive: bool = True, +) -> tuple[int, int]: + """ + Get the current penalty settings (threshold and multiplier percentage) as of a given time or transaction. + + Returns a tuple of `(penalty_threshold, penalty_multiplier_percentage)`. 
+ """ + + if not (until_time is None or until_transaction is None): + raise ValueError("Cannot filter by both until_time and until_transaction.") + + if isinstance(until_time, datetime): + until_time = BindParameter("until_time", value=until_time) + + if isinstance(until_transaction, Transaction): + if until_transaction.id is None: + raise ValueError("until_transaction must be persisted in the database.") + until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id) + else: + until_transaction_id = None + + result = sql_session.scalars( + select(Transaction) + .where( + Transaction.type_ == TransactionType.ADJUST_PENALTY, + until_filter( + until_time=until_time, + until_transaction_id=until_transaction_id, + until_inclusive=until_inclusive, + ), + ) + .order_by(Transaction.time.desc()) + .limit(1) + ).one_or_none() + + if result is None: + return DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_MULTIPLIER_PERCENT + + assert result.penalty_threshold is not None, "Penalty threshold must be set" + assert result.penalty_multiplier_percent is not None, "Penalty multiplier percent must be set" + + return result.penalty_threshold, result.penalty_multiplier_percent diff --git a/dibbler/queries/joint_buy_product.py b/dibbler/queries/joint_buy_product.py new file mode 100644 index 0000000..51bbc49 --- /dev/null +++ b/dibbler/queries/joint_buy_product.py @@ -0,0 +1,68 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import ( + Product, + Transaction, + User, +) + + +def joint_buy_product( + sql_session: Session, + product: Product, + product_count: int, + instigator: User, + users: list[User], + time: datetime | None = None, + message: str | None = None, +) -> list[Transaction]: + """ + Create buy product transactions for multiple users at once. 
+    """
+
+    if product.id is None:
+        raise ValueError("Product must be persisted in the database.")
+
+    if instigator.id is None:
+        raise ValueError("Instigator must be persisted in the database.")
+
+    if len(users) == 0:
+        raise ValueError("At least one buying user must be specified.")
+
+    if any(user.id is None for user in users):
+        raise ValueError("All users must be persisted in the database.")
+
+    if instigator not in users:
+        raise ValueError("Instigator must be in the list of users buying the product.")
+
+    if product_count <= 0:
+        raise ValueError("Product count must be positive.")
+
+    # TODO: verify time is not behind last transaction's time
+
+    joint_transaction = Transaction.joint(
+        user_id=instigator.id,
+        product_id=product.id,
+        product_count=product_count,
+        time=time,
+        message=message,
+    )
+    sql_session.add(joint_transaction)
+    sql_session.flush()  # Ensure joint_transaction gets an ID
+
+    transactions = [joint_transaction]
+
+    for user in users:
+        buy_transaction = Transaction.joint_buy_product(
+            user_id=user.id,
+            joint_transaction_id=joint_transaction.id,
+            time=time,
+            message=message,
+        )
+        sql_session.add(buy_transaction)
+        transactions.append(buy_transaction)
+
+    sql_session.commit()
+    return transactions
diff --git a/dibbler/queries/product_owners.py b/dibbler/queries/product_owners.py
new file mode 100644
index 0000000..d9a82c0
--- /dev/null
+++ b/dibbler/queries/product_owners.py
@@ -0,0 +1,309 @@
+from dataclasses import dataclass
+from datetime import datetime
+
+from sqlalchemy import (
+    CTE,
+    BindParameter,
+    and_,
+    bindparam,
+    case,
+    func,
+    or_,
+    select,
+)
+from sqlalchemy.orm import Session
+
+from dibbler.models import (
+    Product,
+    Transaction,
+    TransactionType,
+    User,
+)
+from dibbler.queries.product_stock import _product_stock_query
+from dibbler.queries.query_helpers import (
+    CONST_NONE,
+    CONST_ONE,
+    CONST_ZERO,
+    until_filter,
+)
+
+
+def _product_owners_query(
+    product_id: BindParameter[int] | int,
+    use_cache: bool = True,
+    until_time: BindParameter[datetime] | datetime | None = None,
+    until_transaction: Transaction | None = None,
+    until_inclusive: bool = True,
+    cte_name: str = "rec_cte",
+    trx_subset_name: str = "trx_subset",
+) -> CTE:
+    """
+    The inner query for inferring the owners of a given product.
+    """
+
+    if use_cache:
+        print("WARNING: Using cache for users owning product query is not implemented yet.")
+
+    if isinstance(product_id, int):
+        product_id = bindparam("product_id", value=product_id)
+
+    if until_time is not None and until_transaction is not None:
+        raise ValueError("Cannot filter by both until_time and until_transaction.")
+
+    if isinstance(until_time, datetime):
+        until_time = bindparam("until_time", value=until_time)
+
+    if isinstance(until_transaction, Transaction):
+        if until_transaction.id is None:
+            raise ValueError("until_transaction must be persisted in the database.")
+        until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id)
+    else:
+        until_transaction_id = None
+
+    product_stock = _product_stock_query(
+        product_id=product_id,
+        use_cache=use_cache,
+        until_time=until_time,
+        until_transaction=until_transaction,
+        until_inclusive=until_inclusive,
+    )
+
+    # Subset of transactions that we'll want to iterate over.
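    # The recursive CTE built below walks backwards in time over ADD_PRODUCT transactions
    # and positive ADJUST_STOCK transactions, attributing the current stock (taken from
    # _product_stock_query above) to whoever added items most recently, or to no one for
    # upward stock adjustments, and stops once every item still in stock is accounted for.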
+    trx_subset = (
+        select(
+            func.row_number().over(order_by=Transaction.time.desc()).label("i"),
+            Transaction.time,
+            Transaction.id,
+            Transaction.type_,
+            Transaction.user_id,
+            Transaction.product_count,
+        )
+        .where(
+            or_(
+                Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
+                and_(
+                    Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
+                    Transaction.product_count > CONST_ZERO,
+                ),
+            ),
+            Transaction.product_id == product_id,
+            until_filter(
+                until_time=until_time,
+                until_transaction_id=until_transaction_id,
+                until_inclusive=until_inclusive,
+            ),
+        )
+        .order_by(Transaction.time.desc())
+        .subquery(trx_subset_name)
+    )
+
+    initial_element = select(
+        CONST_ZERO.label("i"),
+        CONST_ZERO.label("time"),
+        CONST_NONE.label("transaction_id"),
+        CONST_NONE.label("user_id"),
+        CONST_ZERO.label("product_count"),
+        product_stock.scalar_subquery().label("products_left_to_account_for"),
+    )
+
+    recursive_cte = initial_element.cte(name=cte_name, recursive=True)
+
+    recursive_elements = (
+        select(
+            trx_subset.c.i,
+            trx_subset.c.time,
+            trx_subset.c.id.label("transaction_id"),
+            # Who added the product (if any)
+            case(
+                # Someone adds the product -> they own it
+                (
+                    trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
+                    trx_subset.c.user_id,
+                ),
+                else_=CONST_NONE,
+            ).label("user_id"),
+            # How many products did they add (if any)
+            case(
+                # Someone adds the product -> they added a certain amount of products
+                (
+                    trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
+                    trx_subset.c.product_count,
+                ),
+                # Stock got adjusted upwards -> consider those products as added by nobody
+                (
+                    and_(
+                        trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
+                        trx_subset.c.product_count > CONST_ZERO,
+                    ),
+                    trx_subset.c.product_count,
+                ),
+                else_=CONST_ZERO,
+            ).label("product_count"),
+            # How many products left to account for
+            case(
+                # Someone adds the product -> known owner, decrease the number of products left to account for
+                (
+                    trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(),
+                    recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
+                ),
+                # Stock got adjusted upwards -> none owner, decrease the number of products left to account for
+                (
+                    and_(
+                        trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(),
+                        trx_subset.c.product_count > CONST_ZERO,
+                    ),
+                    recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count,
+                ),
+                else_=recursive_cte.c.products_left_to_account_for,
+            ).label("products_left_to_account_for"),
+        )
+        .select_from(trx_subset)
+        .where(
+            and_(
+                trx_subset.c.i == recursive_cte.c.i + CONST_ONE,
+                # Base case: stop if we've accounted for all products
+                recursive_cte.c.products_left_to_account_for > CONST_ZERO,
+            )
+        )
+    )
+
+    return recursive_cte.union_all(recursive_elements)
+
+
+@dataclass
+class ProductOwnersLogEntry:
+    transaction: Transaction
+    user: User | None
+    products_left_to_account_for: int
+
+
+def product_owners_log(
+    sql_session: Session,
+    product: Product,
+    use_cache: bool = True,
+    until_time: BindParameter[datetime] | datetime | None = None,
+    until_transaction: Transaction | None = None,
+    until_inclusive: bool = True,
+) -> list[ProductOwnersLogEntry]:
+    """
+    Returns a log of the product ownership calculation for the given product.
+
+    If 'until' is given, only transactions up to that time are considered.
+ """ + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + recursive_cte = _product_owners_query( + product_id=product.id, + use_cache=use_cache, + until_time=until_time, + until_transaction=until_transaction, + until_inclusive=until_inclusive, + ) + + result = sql_session.execute( + select( + Transaction, + User, + recursive_cte.c.products_left_to_account_for, + ) + .select_from(recursive_cte) + .join( + Transaction, + onclause=Transaction.id == recursive_cte.c.transaction_id, + ) + .join( + User, + onclause=User.id == recursive_cte.c.user_id, + isouter=True, + ) + .order_by(recursive_cte.c.time.desc()) + ).all() + + if result is None: + # If there are no transactions for this product, the query should return an empty list, not None. + raise RuntimeError( + f"Something went wrong while calculating the owner log for product {product.name} (ID: {product.id})." + ) + + return [ + ProductOwnersLogEntry( + transaction=row[0], + user=row[1], + products_left_to_account_for=row[2], + ) + for row in result + ] + + +def product_owners( + sql_session: Session, + product: Product, + use_cache: bool = True, + until_time: BindParameter[datetime] | datetime | None = None, + until_transaction: Transaction | None = None, + until_inclusive: bool = True, +) -> list[User | None]: + """ + Returns an ordered list of users owning the given product. + + If 'until' is given, only transactions up to that time are considered. + """ + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + recursive_cte = _product_owners_query( + product_id=product.id, + use_cache=use_cache, + until_time=until_time, + until_transaction=until_transaction, + until_inclusive=until_inclusive, + ) + + db_result = sql_session.execute( + select( + recursive_cte.c.products_left_to_account_for, + recursive_cte.c.product_count, + User, + ) + .join(User, User.id == recursive_cte.c.user_id, isouter=True) + .order_by(recursive_cte.c.time.desc()) + ).all() + + print(db_result) + + result: list[User | None] = [] + none_count = 0 + + # We are moving backwards through history, but this is the order we want to return the list + # There are 3 cases: + # User is not none -> add user product_count times + # User is none, and product_count is not 0 -> add None product_count times + # User is none, and product_count is 0 -> check how much products are left to account for, + + # TODO: embed this into the query itself? 
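    # Worked example with hypothetical data: stock is 5, and walking backwards the rows are
    # "alice added 3" (2 left to account for) followed by "stock adjusted by +4" (-2 left).
    # The loop below then yields [alice, alice, alice, None, None]: only the 2 adjusted
    # units that were still unaccounted for end up without an owner.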
+ for products_left_to_account_for, product_count, user in db_result: + if user is not None: + if products_left_to_account_for < 0: + result.extend([user] * (product_count + products_left_to_account_for)) + else: + result.extend([user] * product_count) + elif product_count != 0: + if products_left_to_account_for < 0: + none_count += product_count + products_left_to_account_for + else: + none_count += product_count + else: + pass + + # none_count += user_count + # else: + + result.extend([None] * none_count) + + # # NOTE: if the last line exeeds the product count, we need to truncate it + # result.extend([user] * min(user_count, products_left_to_account_for)) + + # redistribute the user counts to a list of users + + return list(result) diff --git a/dibbler/queries/product_price.py b/dibbler/queries/product_price.py new file mode 100644 index 0000000..a488123 --- /dev/null +++ b/dibbler/queries/product_price.py @@ -0,0 +1,309 @@ +import math +from dataclasses import dataclass +from datetime import datetime + +from sqlalchemy import ( + BindParameter, + ColumnElement, + Integer, + bindparam, + case, + cast, + func, + select, +) +from sqlalchemy.orm import Session + +from dibbler.models import ( + Product, + Transaction, + TransactionType, +) +from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT +from dibbler.queries.query_helpers import ( + CONST_NONE, + CONST_ONE, + CONST_ZERO, + until_filter, +) + + +def _product_price_query( + product_id: int | ColumnElement[int], + use_cache: bool = True, + until_time: BindParameter[datetime] | datetime | None = None, + until_transaction: Transaction | None = None, + until_inclusive: bool = True, + cte_name: str = "rec_cte", + trx_subset_name: str = "trx_subset", +): + """ + The inner query for calculating the product price. + """ + + if use_cache: + print("WARNING: Using cache for product price query is not implemented yet.") + + if isinstance(product_id, int): + product_id = BindParameter("product_id", value=product_id) + + if not (until_time is None or until_transaction is None): + raise ValueError("Cannot filter by both until_time and until_transaction.") + + if isinstance(until_time, datetime): + until_time = BindParameter("until_time", value=until_time) + + if isinstance(until_transaction, Transaction): + if until_transaction.id is None: + raise ValueError("until_transaction must be persisted in the database.") + until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id) + else: + until_transaction_id = None + + initial_element = select( + CONST_ZERO.label("i"), + CONST_ZERO.label("time"), + CONST_NONE.label("transaction_id"), + CONST_ZERO.label("price"), + CONST_ZERO.label("product_count"), + ) + + recursive_cte = initial_element.cte(name=cte_name, recursive=True) + + # Subset of transactions that we'll want to iterate over. 
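    # The recursion below replays BUY_PRODUCT, ADD_PRODUCT, ADJUST_STOCK and JOINT
    # transactions in chronological order, keeping a running product count and, on every
    # ADD_PRODUCT, recomputing the price as a stock-weighted average rounded up:
    # e.g. 3 items priced at 10 plus 2 items added at 15 each gives
    # ceil((3 * 10 + 2 * 15) / 5) = 12.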
+ trx_subset = ( + select( + func.row_number().over(order_by=Transaction.time.asc()).label("i"), + Transaction.id, + Transaction.time, + Transaction.type_, + Transaction.product_count, + Transaction.per_product, + ) + .where( + Transaction.type_.in_( + [ + TransactionType.BUY_PRODUCT.as_literal_column(), + TransactionType.ADD_PRODUCT.as_literal_column(), + TransactionType.ADJUST_STOCK.as_literal_column(), + TransactionType.JOINT.as_literal_column(), + ] + ), + Transaction.product_id == product_id, + until_filter( + until_time=until_time, + until_transaction_id=until_transaction_id, + until_inclusive=until_inclusive, + ), + ) + .order_by(Transaction.time.asc()) + .subquery(trx_subset_name) + ) + + recursive_elements = ( + select( + trx_subset.c.i, + trx_subset.c.time, + trx_subset.c.id.label("transaction_id"), + case( + # Someone buys the product -> price remains the same. + ( + trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(), + recursive_cte.c.price, + ), + # Someone adds the product -> price is recalculated based on + # product count, previous price, and new price. + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), + cast( + func.ceil( + ( + recursive_cte.c.price + * func.max(recursive_cte.c.product_count, CONST_ZERO) + + trx_subset.c.per_product * trx_subset.c.product_count + ) + / ( + # The running product count can be negative if the accounting is bad. + # This ensures that we never end up with negative prices or zero divisions + # and other disastrous phenomena. + func.max(recursive_cte.c.product_count, CONST_ZERO) + + trx_subset.c.product_count + ) + ), + Integer, + ), + ), + # Someone adjusts the stock -> price remains the same. + ( + trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(), + recursive_cte.c.price, + ), + # Should never happen + else_=recursive_cte.c.price, + ).label("price"), + case( + # Someone buys the product -> product count is reduced. + ( + trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(), + recursive_cte.c.product_count - trx_subset.c.product_count, + ), + ( + trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(), + recursive_cte.c.product_count - trx_subset.c.product_count, + ), + # Someone adds the product -> product count is increased. + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), + recursive_cte.c.product_count + trx_subset.c.product_count, + ), + # Someone adjusts the stock -> product count is adjusted. + ( + trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(), + recursive_cte.c.product_count + trx_subset.c.product_count, + ), + # Should never happen + else_=recursive_cte.c.product_count, + ).label("product_count"), + ) + .select_from(trx_subset) + .where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE) + ) + + return recursive_cte.union_all(recursive_elements) + + +# TODO: create a function for the log that pretty prints the log entries +# for debugging purposes + + +@dataclass +class ProductPriceLogEntry: + transaction: Transaction + price: int + product_count: int + + +def product_price_log( + sql_session: Session, + product: Product, + use_cache: bool = True, + until_time: BindParameter[datetime] | datetime | None = None, + until_transaction: Transaction | None = None, + until_inclusive: bool = True, +) -> list[ProductPriceLogEntry]: + """ + Calculates the price of a product and returns a log of the price changes. 
+ """ + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + recursive_cte = _product_price_query( + product.id, + use_cache=use_cache, + until_time=until_time, + until_transaction=until_transaction, + until_inclusive=until_inclusive, + ) + + result = sql_session.execute( + select( + Transaction, + recursive_cte.c.price, + recursive_cte.c.product_count, + ) + .select_from(recursive_cte) + .join( + Transaction, + onclause=Transaction.id == recursive_cte.c.transaction_id, + ) + .order_by(recursive_cte.c.i.asc()) + ).all() + + if result is None: + # If there are no transactions for this product, the query should return an empty list, not None. + raise RuntimeError( + f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id})." + ) + + return [ + ProductPriceLogEntry( + transaction=row[0], + price=row.price, + product_count=row.product_count, + ) + for row in result + ] + + +def product_price( + sql_session: Session, + product: Product, + use_cache: bool = True, + until_time: BindParameter[datetime] | datetime | None = None, + until_transaction: Transaction | None = None, + until_inclusive: bool = True, + include_interest: bool = False, +) -> int: + """ + Calculates the price of a product. + """ + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + if isinstance(until_time, datetime): + until_time = BindParameter("until_time", value=until_time) + + if isinstance(until_transaction, Transaction): + if until_transaction.id is None: + raise ValueError("until_transaction must be persisted in the database.") + until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id) + else: + until_transaction_id = None + + recursive_cte = _product_price_query( + product.id, + use_cache=use_cache, + until_time=until_time, + until_transaction=until_transaction, + until_inclusive=until_inclusive, + ) + + # TODO: optionally verify subresults: + # - product_count should never be negative (but this happens sometimes, so just a warning) + # - price should never be negative + + result = sql_session.scalars( + select(recursive_cte.c.price) + .order_by(recursive_cte.c.i.desc()) + .limit(CONST_ONE) + .offset(CONST_ZERO) + ).one_or_none() + + if result is None: + # If there are no transactions for this product, the query should return 0, not None. + raise RuntimeError( + f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})." 
+ ) + + if include_interest: + interest_rate = ( + sql_session.scalar( + select(Transaction.interest_rate_percent) + .where( + Transaction.type_ == TransactionType.ADJUST_INTEREST, + until_filter( + until_time=until_time, + until_transaction_id=until_transaction_id, + until_inclusive=until_inclusive, + ), + ) + .order_by(Transaction.time.desc()) + .limit(CONST_ONE) + ) + or DEFAULT_INTEREST_RATE_PERCENT + ) + result = math.ceil(result * interest_rate / 100) + + return result diff --git a/dibbler/queries/product_stock.py b/dibbler/queries/product_stock.py new file mode 100644 index 0000000..870b8cf --- /dev/null +++ b/dibbler/queries/product_stock.py @@ -0,0 +1,126 @@ +from datetime import datetime +from typing import Tuple + +from sqlalchemy import ( + BindParameter, + Select, + bindparam, + case, + func, + select, +) +from sqlalchemy.orm import Session + +from dibbler.models import ( + Product, + Transaction, + TransactionType, +) +from dibbler.queries.query_helpers import until_filter + + +def _product_stock_query( + product_id: BindParameter[int] | int, + use_cache: bool = True, + until_time: BindParameter[datetime] | datetime | None = None, + until_transaction: Transaction | None = None, + until_inclusive: bool = True, +) -> Select[Tuple[int]]: + """ + The inner query for calculating the product stock. + """ + + if use_cache: + print("WARNING: Using cache for product stock query is not implemented yet.") + + if isinstance(product_id, int): + product_id = BindParameter("product_id", value=product_id) + + if not (until_time is None or until_transaction is None): + raise ValueError("Cannot filter by both until_time and until_transaction.") + + if isinstance(until_time, datetime): + until_time = BindParameter("until_time", value=until_time) + + if isinstance(until_transaction, Transaction): + if until_transaction.id is None: + raise ValueError("until_transaction must be persisted in the database.") + until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id) + else: + until_transaction_id = None + + query = select( + func.sum( + case( + ( + Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), + Transaction.product_count, + ), + ( + Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(), + Transaction.product_count, + ), + ( + Transaction.type_ == TransactionType.BUY_PRODUCT.as_literal_column(), + -Transaction.product_count, + ), + ( + Transaction.type_ == TransactionType.JOINT.as_literal_column(), + -Transaction.product_count, + ), + ( + Transaction.type_ == TransactionType.THROW_PRODUCT.as_literal_column(), + -Transaction.product_count, + ), + else_=0, + ) + ).label("stock") + ).where( + Transaction.type_.in_( + [ + TransactionType.ADD_PRODUCT.as_literal_column(), + TransactionType.ADJUST_STOCK.as_literal_column(), + TransactionType.BUY_PRODUCT.as_literal_column(), + TransactionType.JOINT.as_literal_column(), + TransactionType.THROW_PRODUCT.as_literal_column(), + ] + ), + Transaction.product_id == product_id, + until_filter( + until_time=until_time, + until_transaction_id=until_transaction_id, + until_inclusive=until_inclusive, + ), + ) + + return query + + +def product_stock( + sql_session: Session, + product: Product, + use_cache: bool = True, + until_time: BindParameter[datetime] | datetime | None = None, + until_transaction: Transaction | None = None, + until_inclusive: bool = True, +) -> int: + """ + Returns the number of products in stock. + + If 'until' is given, only transactions up to that time are considered. 
+ """ + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + query = _product_stock_query( + product_id=product.id, + use_cache=use_cache, + until_time=until_time, + until_transaction=until_transaction, + until_inclusive=until_inclusive, + ) + + result = sql_session.scalars(query).one_or_none() + + return result or 0 diff --git a/dibbler/queries/query_helpers.py b/dibbler/queries/query_helpers.py new file mode 100644 index 0000000..1feb69a --- /dev/null +++ b/dibbler/queries/query_helpers.py @@ -0,0 +1,80 @@ +from datetime import datetime +from typing import TypeVar + +from sqlalchemy import ( + BindParameter, + ColumnExpressionArgument, + literal, + select, +) +from sqlalchemy.orm import QueryableAttribute + +from dibbler.models import Transaction + +T = TypeVar("T") + + +def const(value: T) -> BindParameter[T]: + """ + Create a constant SQL literal bind parameter. + + This is useful to avoid too many `?` bind parameters in SQL queries, + when the input value is known to be safe. + """ + + return literal(value, literal_execute=True) + + +CONST_ZERO: BindParameter[int] = const(0) +"""A constant SQL expression `0`. This will render as a literal `0` in SQL queries.""" + +CONST_ONE: BindParameter[int] = const(1) +"""A constant SQL expression `1`. This will render as a literal `1` in SQL queries.""" + +CONST_TRUE: BindParameter[bool] = const(True) +"""A constant SQL expression `TRUE`. This will render as a literal `TRUE` in SQL queries.""" + +CONST_FALSE: BindParameter[bool] = const(False) +"""A constant SQL expression `FALSE`. This will render as a literal `FALSE` in SQL queries.""" + +CONST_NONE: BindParameter[None] = const(None) +"""A constant SQL expression `NULL`. This will render as a literal `NULL` in SQL queries.""" + + +def until_filter( + until_time: BindParameter[datetime] | None = None, + until_transaction_id: BindParameter[int] | None = None, + until_inclusive: bool = True, + transaction_time: QueryableAttribute = Transaction.time, +) -> ColumnExpressionArgument[bool]: + """ + Create a filter condition for transactions up to a given time or transaction. + + Only one of `until_time` or `until_transaction_id` may be specified. + """ + + assert not (until_time is not None and until_transaction_id is not None), ( + "Cannot filter by both until_time and until_transaction_id." 
+ ) + + match (until_time, until_transaction_id, until_inclusive): + case (BindParameter(), None, True): + return transaction_time <= until_time + case (BindParameter(), None, False): + return transaction_time < until_time + case (None, BindParameter(), True): + return ( + transaction_time + <= select(Transaction.time) + .where(Transaction.id == until_transaction_id) + .scalar_subquery() + ) + case (None, BindParameter(), False): + return ( + transaction_time + < select(Transaction.time) + .where(Transaction.id == until_transaction_id) + .scalar_subquery() + ) + + return CONST_TRUE diff --git a/dibbler/queries/search_product.py b/dibbler/queries/search_product.py new file mode 100644 index 0000000..d29b363 --- /dev/null +++ b/dibbler/queries/search_product.py @@ -0,0 +1,42 @@ +from sqlalchemy import and_, literal, not_, or_, select +from sqlalchemy.orm import Session + +from dibbler.models import Product + + +def search_product( + string: str, + sql_session: Session, + find_hidden_products=False, +) -> Product | list[Product]: + if not string: + raise ValueError("Search string cannot be empty.") + + exact_match = sql_session.scalars( + select(Product).where( + or_( + Product.bar_code == string, + and_( + Product.name == string, + literal(True) if find_hidden_products else not_(Product.hidden), + ), + ) + ) + ).first() + + if exact_match: + return exact_match + + product_list = sql_session.scalars( + select(Product).where( + or_( + Product.bar_code.ilike(f"%{string}%"), + and_( + Product.name.ilike(f"%{string}%"), + literal(True) if find_hidden_products else not_(Product.hidden), + ), + ) + ) + ).all() + + return list(product_list) diff --git a/dibbler/queries/search_user.py b/dibbler/queries/search_user.py new file mode 100644 index 0000000..78219c2 --- /dev/null +++ b/dibbler/queries/search_user.py @@ -0,0 +1,39 @@ +from sqlalchemy import or_, select +from sqlalchemy.orm import Session + +from dibbler.models import User + + +def search_user( + string: str, + sql_session: Session, +) -> User | list[User]: + if not string: + raise ValueError("Search string cannot be empty.") + + string = string.lower() + + exact_match = sql_session.scalars( + select(User).where( + or_( + User.name == string, + User.card == string, + User.rfid == string, + ) + ) + ).first() + + if exact_match: + return exact_match + + user_list = sql_session.scalars( + select(User).where( + or_( + User.name.ilike(f"%{string}%"), + User.card.ilike(f"%{string}%"), + User.rfid.ilike(f"%{string}%"), + ) + ) + ).all() + + return list(user_list) diff --git a/dibbler/queries/throw_product.py b/dibbler/queries/throw_product.py new file mode 100644 index 0000000..18cd167 --- /dev/null +++ b/dibbler/queries/throw_product.py @@ -0,0 +1,42 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User + + +def throw_product( + sql_session: Session, + user: User, + product: Product, + product_count: int, + time: datetime | None = None, + message: str | None = None, +) -> Transaction: + if user.id is None: + raise ValueError("User must be persisted in the database.") + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + if product_count <= 0: + raise ValueError("Product count must be positive.") + + # TODO: verify time is not behind last transaction's time + + raise NotImplementedError( + "Please don't use this function until relevant calculations have been added to user_balance." 
+ ) + + transaction = Transaction.throw_product( + user_id=user.id, + product_id=product.id, + product_count=product_count, + time=time, + message=message, + ) + + sql_session.add(transaction) + sql_session.commit() + + return transaction diff --git a/dibbler/queries/transaction_log.py b/dibbler/queries/transaction_log.py new file mode 100644 index 0000000..a2619c1 --- /dev/null +++ b/dibbler/queries/transaction_log.py @@ -0,0 +1,142 @@ +from datetime import datetime + +from sqlalchemy import BindParameter, select +from sqlalchemy.orm import Session + +from dibbler.models import ( + Product, + Transaction, + TransactionType, + User, +) + + +# TODO: should this include full joint transactions that involve a user? +# TODO: should this involve throw-away transactions that affects a user? +def transaction_log( + sql_session: Session, + user: User | None = None, + product: Product | None = None, + until_time: BindParameter[datetime] | datetime | None = None, + until_transaction: Transaction | None = None, + until_inclusive: bool = True, + after_time: BindParameter[datetime] | datetime | None = None, + after_transaction: Transaction | None = None, + after_inlcusive: bool = True, + transaction_type: list[TransactionType] | None = None, + negate_transaction_type_filter: bool = False, + limit: int | None = None, +) -> list[Transaction]: + """ + Retrieve the transaction log, optionally filtered. + + Only one of `user` or `product` may be specified. + Only one of `until_time` or `until_transaction_id` may be specified. + Only one of `after_time` or `after_transaction_id` may be specified. + + The after and after filters are inclusive by default. + """ + + if not (user is None or product is None): + raise ValueError("Cannot filter by both user and product.") + + if isinstance(user, User): + if user.id is None: + raise ValueError("User must be persisted in the database.") + user_id = BindParameter("user_id", value=user.id) + else: + user_id = None + + if isinstance(product, Product): + if product.id is None: + raise ValueError("Product must be persisted in the database.") + product_id = BindParameter("product_id", value=product.id) + else: + product_id = None + + if not (until_time is None or until_transaction is None): + raise ValueError("Cannot filter by both after_time and after_transaction_id.") + + if isinstance(until_time, datetime): + until_time = BindParameter("until_time", value=until_time) + + if isinstance(until_transaction, Transaction): + if until_transaction.id is None: + raise ValueError("until_transaction must be persisted in the database.") + until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id) + else: + until_transaction_id = None + + if not (after_time is None or after_transaction is None): + raise ValueError("Cannot filter by both after_time and after_transaction_id.") + + if isinstance(after_time, datetime): + after_time = BindParameter("after_time", value=after_time) + + if isinstance(after_transaction, Transaction): + if after_transaction.id is None: + raise ValueError("after_transaction must be persisted in the database.") + after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id) + else: + after_transaction_id = None + + if after_time is not None and until_time is not None: + assert isinstance(after_time.value, datetime) + assert isinstance(until_time.value, datetime) + + if after_time.value > until_time.value: + raise ValueError("after_time cannot be after until_time.") + + if after_transaction is not None and 
until_transaction is not None: + assert after_transaction.time is not None + assert until_transaction.time is not None + + if after_transaction.time > until_transaction.time: + raise ValueError("after_transaction cannot be after until_transaction.") + + if limit is not None and limit <= 0: + raise ValueError("Limit must be positive.") + + query = select(Transaction) + if user is not None: + query = query.where(Transaction.user_id == user_id) + if product is not None: + query = query.where(Transaction.product_id == product_id) + + match (until_time, until_transaction_id, until_inclusive): + case (BindParameter(), None, True): + query = query.where(Transaction.time <= until_time) + case (BindParameter(), None, False): + query = query.where(Transaction.time < until_time) + case (None, BindParameter(), True): + query = query.where(Transaction.id <= until_transaction_id) + case (None, BindParameter(), False): + query = query.where(Transaction.id < until_transaction_id) + case _: + pass + + match (after_time, after_transaction_id, after_inlcusive): + case (BindParameter(), None, True): + query = query.where(Transaction.time >= after_time) + case (BindParameter(), None, False): + query = query.where(Transaction.time > after_time) + case (None, BindParameter(), True): + query = query.where(Transaction.id >= after_transaction_id) + case (None, BindParameter(), False): + query = query.where(Transaction.id > after_transaction_id) + case _: + pass + + if transaction_type is not None: + if negate_transaction_type_filter: + query = query.where(~Transaction.type_.in_(transaction_type)) + else: + query = query.where(Transaction.type_.in_(transaction_type)) + + if limit is not None: + query = query.limit(limit) + + query = query.order_by(Transaction.time.asc(), Transaction.id.asc()) + result = sql_session.scalars(query).all() + + return list(result) diff --git a/dibbler/queries/transfer.py b/dibbler/queries/transfer.py new file mode 100644 index 0000000..12e952b --- /dev/null +++ b/dibbler/queries/transfer.py @@ -0,0 +1,38 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, User + + +def transfer( + sql_session: Session, + from_user: User, + to_user: User, + amount: int, + time: datetime | None = None, + message: str | None = None, +) -> Transaction: + if from_user.id is None: + raise ValueError("From user must be persisted in the database.") + + if to_user.id is None: + raise ValueError("To user must be persisted in the database.") + + if amount <= 0: + raise ValueError("Amount must be positive.") + + # TODO: verify time is not behind last transaction's time + + transaction = Transaction.transfer( + user_id=from_user.id, + transfer_user_id=to_user.id, + amount=amount, + time=time, + message=message, + ) + + sql_session.add(transaction) + sql_session.commit() + + return transaction diff --git a/dibbler/queries/user_balance.py b/dibbler/queries/user_balance.py new file mode 100644 index 0000000..67b7171 --- /dev/null +++ b/dibbler/queries/user_balance.py @@ -0,0 +1,567 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Tuple + +from sqlalchemy import ( + CTE, + BindParameter, + Float, + Integer, + Select, + and_, + bindparam, + case, + cast, + column, + func, + or_, + select, +) +from sqlalchemy.orm import Session, aliased +from sqlalchemy.sql.elements import KeyedColumnElement + +from dibbler.models import ( + Transaction, + TransactionType, + User, +) +from dibbler.models.Transaction import ( + 
DEFAULT_INTEREST_RATE_PERCENT, + DEFAULT_PENALTY_MULTIPLIER_PERCENT, + DEFAULT_PENALTY_THRESHOLD, +) +from dibbler.queries.product_price import _product_price_query +from dibbler.queries.query_helpers import ( + CONST_NONE, + CONST_ONE, + CONST_ZERO, + const, + until_filter, +) + + +def _joint_transaction_query( + user_id: BindParameter[int] | int, + use_cache: bool = True, + until_time: BindParameter[datetime] | None = None, + until_transaction: Transaction | None = None, + until_inclusive: bool = True, +) -> Select[Tuple[int, int, int]]: + """ + The inner query for getting joint transactions relevant to a user. + + This scans for JOINT_BUY_PRODUCT transactions made by the user, + then finds the corresponding JOINT transactions, and counts how many "shares" + of the joint transaction the user has, as well as the total number of shares. + """ + + if isinstance(until_transaction, Transaction): + if until_transaction.id is None: + raise ValueError("until_transaction must be persisted in the database.") + until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id) + else: + until_transaction_id = None + + # First, select all joint buy product transactions for the given user + # sub_joint_transaction = aliased(Transaction, name="right_trx") + sub_joint_transaction = ( + select(Transaction.joint_transaction_id.distinct().label("joint_transaction_id")) + .where( + Transaction.type_ == TransactionType.JOINT_BUY_PRODUCT.as_literal_column(), + Transaction.user_id == user_id, + until_filter( + until_time=until_time, + until_transaction_id=until_transaction_id, + until_inclusive=until_inclusive, + transaction_time=Transaction.time, + ), + ) + .subquery("sub_joint_transaction") + ) + + # Join those with their main joint transaction + # (just use Transaction) + + # Then, count how many users are involved in each joint transaction + joint_transaction_count = aliased(Transaction, name="count_trx") + + joint_transaction = ( + select( + Transaction.id, + # Shares the user has in the transaction, + func.sum( + case( + (joint_transaction_count.user_id == user_id, CONST_ONE), + else_=CONST_ZERO, + ) + ).label("user_shares"), + # The total number of shares in the transaction, + func.count(joint_transaction_count.id).label("user_count"), + ) + .select_from(sub_joint_transaction) + .join( + Transaction, + onclause=Transaction.id == sub_joint_transaction.c.joint_transaction_id, + ) + .join( + joint_transaction_count, + onclause=joint_transaction_count.joint_transaction_id == Transaction.id, + ) + .group_by(joint_transaction_count.joint_transaction_id) + ) + + return joint_transaction + + +def _non_joint_transaction_query( + user_id: BindParameter[int] | int, + use_cache: bool = True, + until_time: BindParameter[datetime] | None = None, + until_transaction: Transaction | None = None, + until_inclusive: bool = True, +) -> Select[Tuple[int, None, None]]: + """ + The inner query for getting non-joint transactions relevant to a user. 
+ """ + + if isinstance(until_transaction, Transaction): + if until_transaction.id is None: + raise ValueError("until_transaction must be persisted in the database.") + until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id) + else: + until_transaction_id = None + + query = select( + Transaction.id, + CONST_NONE.label("user_shares"), + CONST_NONE.label("user_count"), + ).where( + or_( + and_( + Transaction.user_id == user_id, + Transaction.type_.in_( + [ + TransactionType.ADD_PRODUCT.as_literal_column(), + TransactionType.ADJUST_BALANCE.as_literal_column(), + TransactionType.BUY_PRODUCT.as_literal_column(), + TransactionType.TRANSFER.as_literal_column(), + ] + ), + ), + and_( + Transaction.type_ == TransactionType.TRANSFER.as_literal_column(), + Transaction.transfer_user_id == user_id, + ), + Transaction.type_.in_( + [ + TransactionType.THROW_PRODUCT.as_literal_column(), + TransactionType.ADJUST_INTEREST.as_literal_column(), + TransactionType.ADJUST_PENALTY.as_literal_column(), + ] + ), + ), + until_filter( + until_time=until_time, + until_transaction_id=until_transaction_id, + until_inclusive=until_inclusive, + ), + ) + + return query + + +def _product_cost_expression( + product_count_column: KeyedColumnElement[int], + product_id_column: KeyedColumnElement[int], + interest_rate_percent_column: KeyedColumnElement[int], + user_balance_column: KeyedColumnElement[int], + penalty_threshold_column: KeyedColumnElement[int], + penalty_multiplier_percent_column: KeyedColumnElement[int], + joint_user_shares_column: KeyedColumnElement[int], + joint_user_count_column: KeyedColumnElement[int], + use_cache: bool = True, + until_time: BindParameter[datetime] | None = None, + until_transaction: Transaction | None = None, + until_inclusive: bool = True, + cte_name: str = "product_price_cte", + trx_subset_name: str = "product_price_trx_subset", +): + # TODO: This can get quite expensive real quick, so we should do some caching of the + # product prices somehow. + expression = ( + select( + cast( + func.ceil( + # Base price + ( + cast( + column("price") * product_count_column * joint_user_shares_column, + Float, + ) + / joint_user_count_column + ) + # Interest + + ( + cast( + column("price") * product_count_column * joint_user_shares_column, + Float, + ) + / joint_user_count_column + * cast(interest_rate_percent_column - const(100), Float) + / const(100.0) + ) + # Penalty + + ( + ( + cast( + column("price") * product_count_column * joint_user_shares_column, + Float, + ) + / joint_user_count_column + ) + * cast(penalty_multiplier_percent_column - const(100), Float) + / const(100.0) + * cast(user_balance_column < penalty_threshold_column, Integer) + ) + ), + Integer, + ) + ) + .select_from( + _product_price_query( + product_id_column, + use_cache=use_cache, + until_time=until_time, + until_transaction=until_transaction, + until_inclusive=until_inclusive, + cte_name=cte_name, + trx_subset_name=trx_subset_name, + ) + ) + .order_by(column("i").desc()) + .limit(CONST_ONE) + .scalar_subquery() + ) + + return expression + + +def _user_balance_query( + user_id: BindParameter[int] | int, + use_cache: bool = True, + until_time: BindParameter[datetime] | None = None, + until_transaction: Transaction | None = None, + until_inclusive: bool = True, + cte_name: str = "rec_cte", + trx_subset_name: str = "trx_subset", +) -> CTE: + """ + The inner query for calculating the user's balance. 
+ """ + + if use_cache: + print("WARNING: Using cache for user balance query is not implemented yet.") + + if isinstance(user_id, int): + user_id = BindParameter("user_id", value=user_id) + + initial_element = select( + CONST_ZERO.label("i"), + CONST_ZERO.label("time"), + CONST_NONE.label("transaction_id"), + CONST_ZERO.label("balance"), + const(DEFAULT_INTEREST_RATE_PERCENT).label("interest_rate_percent"), + const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"), + const(DEFAULT_PENALTY_MULTIPLIER_PERCENT).label("penalty_multiplier_percent"), + ) + + recursive_cte = initial_element.cte(name=cte_name, recursive=True) + + trx_subset_subset = ( + _non_joint_transaction_query( + user_id=user_id, + use_cache=use_cache, + until_time=until_time, + until_transaction=until_transaction, + until_inclusive=until_inclusive, + ) + .union_all( + _joint_transaction_query( + user_id=user_id, + use_cache=use_cache, + until_time=until_time, + until_transaction=until_transaction, + until_inclusive=until_inclusive, + ) + ) + .subquery(f"{trx_subset_name}_subset") + ) + + # Subset of transactions that we'll want to iterate over. + trx_subset = ( + select( + func.row_number().over(order_by=Transaction.time.asc()).label("i"), + Transaction.id, + Transaction.amount, + Transaction.interest_rate_percent, + Transaction.penalty_multiplier_percent, + Transaction.penalty_threshold, + Transaction.product_count, + Transaction.product_id, + Transaction.time, + Transaction.transfer_user_id, + Transaction.type_, + trx_subset_subset.c.user_shares, + trx_subset_subset.c.user_count, + ) + .select_from(trx_subset_subset) + .join( + Transaction, + onclause=Transaction.id == trx_subset_subset.c.id, + ) + .order_by(Transaction.time.asc()) + .subquery(trx_subset_name) + ) + + recursive_elements = ( + select( + trx_subset.c.i, + trx_subset.c.time, + trx_subset.c.id.label("transaction_id"), + case( + # Adjusts balance -> balance gets adjusted + ( + trx_subset.c.type_ == TransactionType.ADJUST_BALANCE.as_literal_column(), + recursive_cte.c.balance + trx_subset.c.amount, + ), + # Adds a product -> balance increases + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), + recursive_cte.c.balance + trx_subset.c.amount, + ), + # Buys a product -> balance decreases + ( + trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(), + recursive_cte.c.balance + - _product_cost_expression( + product_count_column=trx_subset.c.product_count, + product_id_column=trx_subset.c.product_id, + interest_rate_percent_column=recursive_cte.c.interest_rate_percent, + user_balance_column=recursive_cte.c.balance, + penalty_threshold_column=recursive_cte.c.penalty_threshold, + penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent, + joint_user_shares_column=CONST_ONE, + joint_user_count_column=CONST_ONE, + use_cache=use_cache, + until_time=until_time, + until_transaction=until_transaction, + until_inclusive=until_inclusive, + cte_name=f"{cte_name}_price", + trx_subset_name=f"{trx_subset_name}_price", + ).label("product_cost"), + ), + # Joint transaction -> balance decreases proportionally + ( + trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(), + recursive_cte.c.balance + - _product_cost_expression( + product_count_column=trx_subset.c.product_count, + product_id_column=trx_subset.c.product_id, + interest_rate_percent_column=recursive_cte.c.interest_rate_percent, + user_balance_column=recursive_cte.c.balance, + penalty_threshold_column=recursive_cte.c.penalty_threshold, + 
penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent, + joint_user_shares_column=trx_subset.c.user_shares, + joint_user_count_column=trx_subset.c.user_count, + use_cache=use_cache, + until_time=until_time, + until_transaction=until_transaction, + until_inclusive=until_inclusive, + cte_name=f"{cte_name}_joint_price", + trx_subset_name=f"{trx_subset_name}_joint_price", + ).label("joint_product_cost"), + ), + # Transfers money to self -> balance increases + ( + and_( + trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(), + trx_subset.c.transfer_user_id == user_id, + ), + recursive_cte.c.balance + trx_subset.c.amount, + ), + # Transfers money from self -> balance decreases + ( + and_( + trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(), + trx_subset.c.transfer_user_id != user_id, + ), + recursive_cte.c.balance - trx_subset.c.amount, + ), + # Throws a product -> if the user is considered to have bought it, balance increases + # TODO: # ( + # trx_subset.c.type_ == TransactionType.THROW_PRODUCT, + # recursive_cte.c.balance + trx_subset.c.amount, + # ), + # Interest adjustment -> balance stays the same + # Penalty adjustment -> balance stays the same + else_=recursive_cte.c.balance, + ).label("balance"), + case( + ( + trx_subset.c.type_ == TransactionType.ADJUST_INTEREST.as_literal_column(), + trx_subset.c.interest_rate_percent, + ), + else_=recursive_cte.c.interest_rate_percent, + ).label("interest_rate_percent"), + case( + ( + trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(), + trx_subset.c.penalty_threshold, + ), + else_=recursive_cte.c.penalty_threshold, + ).label("penalty_threshold"), + case( + ( + trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(), + trx_subset.c.penalty_multiplier_percent, + ), + else_=recursive_cte.c.penalty_multiplier_percent, + ).label("penalty_multiplier_percent"), + ) + .select_from(trx_subset) + .where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE) + ) + + return recursive_cte.union_all(recursive_elements) + + +# TODO: create a function for the log that pretty prints the log entries +# for debugging purposes + + +@dataclass +class UserBalanceLogEntry: + transaction: Transaction + balance: int + interest_rate_percent: int + penalty_threshold: int + penalty_multiplier_percent: int + + def is_penalized(self) -> bool: + """ + Returns whether this exact transaction is penalized. + """ + + raise NotImplementedError("is_penalized is not implemented yet.") + + +def user_balance_log( + sql_session: Session, + user: User, + use_cache: bool = True, + until_time: BindParameter[datetime] | datetime | None = None, + until_transaction: Transaction | None = None, + until_inclusive: bool = True, +) -> list[UserBalanceLogEntry]: + """ + Returns a log of the user's balance over time, including interest and penalty adjustments. + + If 'until' is given, only transactions up to that time are considered. 
+ """ + + if user.id is None: + raise ValueError("User must be persisted in the database.") + + if not (until_time is None or until_transaction is None): + raise ValueError("Cannot filter by both until_time and until_transaction.") + + if isinstance(until_time, datetime): + until_time = BindParameter("until_time", value=until_time) + + recursive_cte = _user_balance_query( + user.id, + use_cache=use_cache, + until_time=until_time, + until_transaction=until_transaction, + until_inclusive=until_inclusive, + ) + + result = sql_session.execute( + select( + Transaction, + recursive_cte.c.balance, + recursive_cte.c.interest_rate_percent, + recursive_cte.c.penalty_threshold, + recursive_cte.c.penalty_multiplier_percent, + ) + .select_from(recursive_cte) + .join( + Transaction, + onclause=Transaction.id == recursive_cte.c.transaction_id, + ) + .order_by(recursive_cte.c.i.asc()) + ).all() + + if result is None: + # If there are no transactions for this user, the query should return 0, not None. + raise RuntimeError( + f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})." + ) + + return [ + UserBalanceLogEntry( + transaction=row[0], + balance=row.balance, + interest_rate_percent=row.interest_rate_percent, + penalty_threshold=row.penalty_threshold, + penalty_multiplier_percent=row.penalty_multiplier_percent, + ) + for row in result + ] + + +def user_balance( + sql_session: Session, + user: User, + use_cache: bool = True, + until_time: BindParameter[datetime] | datetime | None = None, + until_transaction: Transaction | None = None, + until_inclusive: bool = True, +) -> int: + """ + Calculates the balance of a user. + + If 'until' is given, only transactions up to that time are considered. + """ + + if user.id is None: + raise ValueError("User must be persisted in the database.") + + if not (until_time is None or until_transaction is None): + raise ValueError("Cannot filter by both until_time and until_transaction.") + + if isinstance(until_time, datetime): + until_time = BindParameter("until_time", value=until_time) + + recursive_cte = _user_balance_query( + user.id, + use_cache=use_cache, + until_time=until_time, + until_transaction=until_transaction, + until_inclusive=until_inclusive, + ) + + result = sql_session.scalar( + select(recursive_cte.c.balance) + .order_by(recursive_cte.c.i.desc()) + .limit(CONST_ONE) + .offset(CONST_ZERO) + ) + + if result is None: + # If there are no transactions for this user, the query should return 0, not None. + raise RuntimeError( + f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})." + ) + + return result diff --git a/dibbler/queries/user_products.py b/dibbler/queries/user_products.py new file mode 100644 index 0000000..03b200c --- /dev/null +++ b/dibbler/queries/user_products.py @@ -0,0 +1,48 @@ +from datetime import datetime + +from sqlalchemy import BindParameter, bindparam +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User + +# NOTE: This absoulutely needs a cache, else we can't stop recursing until we know all owners for all products... +# +# Since we know that the non-owned products will not get renowned by the user by other means, +# we can just check for ownership on the products that have an ADD_PRODUCT transaction for the user. +# between now and the cached time. +# +# However, the opposite way is more difficult. 
The cache will store which products are owned by which users, +# but we still need to check if the user passes out of ownership for the item, without needing to check past +# the cache time. Maybe we also need to store the queue number(s) per user/product combo in the cache? What if +# a user has products multiple places in the queue, interleaved with other users? + + +def user_products( + sql_session: Session, + user: User, + use_cache: bool = True, + until_time: BindParameter[datetime] | datetime | None = None, + until_transaction: Transaction | None = None, + until_inclusive: bool = True, +) -> list[tuple[Product, int]]: + """ + Returns the list of products owned by the user, along with how many of each product they own. + """ + + if user.id is None: + raise ValueError("User must be persisted in the database.") + + if not (until_time is None or until_transaction is None): + raise ValueError("Cannot filter by both until_time and until_transaction.") + + if isinstance(until_time, datetime): + until_time = BindParameter("until_time", value=until_time) + + if isinstance(until_transaction, Transaction): + if until_transaction.id is None: + raise ValueError("until_transaction must be persisted in the database.") + until_transaction_id = bindparam("until_transaction_id", value=until_transaction.id) + else: + until_transaction_id = None + + raise NotImplementedError("Not implemented yet, needs caching system first.") diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..c07ddcd --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,28 @@ +from datetime import datetime, timedelta + +from dibbler.models import Transaction + + +def assign_times( + transactions: list[Transaction], + start_time: datetime = datetime(2024, 1, 1, 0, 0, 0), + delta: timedelta = timedelta(minutes=1), +) -> None: + """Assigns datetimes to a list of transactions starting from start_time and incrementing by delta.""" + current_time = start_time + for transaction in transactions: + transaction.time = current_time + current_time += delta + + +def assert_id_order_similar_to_time_order(transactions: list[Transaction]) -> None: + """Asserts that the order of transaction IDs is similar to the order of their timestamps.""" + sorted_by_time = sorted(transactions, key=lambda t: t.time) + sorted_by_id = sorted(transactions, key=lambda t: t.id) + + for t1, t2 in zip(sorted_by_time, sorted_by_id): + assert t1.id == t2.id or t1.time == t2.time, ( + f"Transaction ID order does not match time order:\n" + f"ID {t1.id} at time {t1.time}\n" + f"ID {t2.id} at time {t2.time}" + ) diff --git a/tests/queries/__init__.py b/tests/queries/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/queries/test_adjust_interest.py b/tests/queries/test_adjust_interest.py new file mode 100644 index 0000000..0f9242c --- /dev/null +++ b/tests/queries/test_adjust_interest.py @@ -0,0 +1,85 @@ +from datetime import datetime, timedelta + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, User +from dibbler.queries import adjust_interest, current_interest + + +def insert_test_data(sql_session: Session) -> User: + user = User("Test User") + sql_session.add(user) + sql_session.commit() + + return user + + +def test_adjust_interest_unitialized_user(sql_session: Session) -> None: + user = User("Uninitialized User") + + with pytest.raises(ValueError, match="User must be persisted in the database."): + adjust_interest( + sql_session, + user=user, + new_interest=4, + 
message="Attempting to adjust interest for uninitialized user", + ) + + +def test_adjust_interest_no_history(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + adjust_interest( + sql_session, + user=user, + new_interest=3, + message="Setting initial interest rate", + ) + sql_session.commit() + + current_interest_rate = current_interest(sql_session) + + assert current_interest_rate == 3 + + +def test_adjust_interest_existing_history(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 9, 0, 0), + user_id=user.id, + interest_rate_percent=5, + message="Initial interest rate", + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + current_interest_rate = current_interest(sql_session) + assert current_interest_rate == 5 + + adjust_interest( + sql_session, + user=user, + new_interest=2, + message="Adjusting interest rate", + time=transactions[-1].time + timedelta(days=1), + ) + sql_session.commit() + + current_interest_rate = current_interest(sql_session) + assert current_interest_rate == 2 + + +def test_adjust_interest_negative_failure(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + with pytest.raises(ValueError, match="Interest rate cannot be negative"): + adjust_interest( + sql_session, + user=user, + new_interest=-1, + message="Attempting to set negative interest rate", + ) diff --git a/tests/queries/test_adjust_penalty.py b/tests/queries/test_adjust_penalty.py new file mode 100644 index 0000000..7575a50 --- /dev/null +++ b/tests/queries/test_adjust_penalty.py @@ -0,0 +1,179 @@ +from datetime import datetime, timedelta + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, User +from dibbler.models.Transaction import ( + DEFAULT_PENALTY_MULTIPLIER_PERCENT, + DEFAULT_PENALTY_THRESHOLD, +) +from dibbler.queries import adjust_penalty, current_penalty + + +def insert_test_data(sql_session: Session) -> User: + user = User("Test User") + sql_session.add(user) + sql_session.commit() + + return user + + +def test_adjust_penalty_empty_not_allowed(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + with pytest.raises(ValueError): + adjust_penalty( + sql_session, + user=user, + message="No penalty or multiplier provided", + ) + + +def test_adjust_penalty_unitialized_user(sql_session: Session) -> None: + user = User("Uninitialized User") + + with pytest.raises(ValueError): + adjust_penalty( + sql_session, + user=user, + new_penalty=-100, + new_penalty_multiplier=110, + message="Attempting to adjust penalty for uninitialized user", + ) + + +def test_adjust_penalty_no_history(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + adjust_penalty( + sql_session, + user=user, + new_penalty=-200, + message="Setting initial interest rate", + ) + sql_session.commit() + + (penalty, multiplier) = current_penalty(sql_session) + + assert penalty == -200 + assert multiplier == DEFAULT_PENALTY_MULTIPLIER_PERCENT + + +def test_adjust_penalty_multiplier_no_history(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + adjust_penalty( + sql_session, + user=user, + new_penalty_multiplier=125, + message="Setting initial interest rate", + ) + sql_session.commit() + + (penalty, multiplier) = current_penalty(sql_session) + + assert penalty == DEFAULT_PENALTY_THRESHOLD + assert multiplier == 125 + + +def test_adjust_penalty_multiplier_less_than_100_fail(sql_session: 
Session) -> None: + user = insert_test_data(sql_session) + + adjust_penalty( + sql_session, + user=user, + new_penalty_multiplier=100, + message="Setting initial interest rate", + ) + sql_session.commit() + + (_, multiplier) = current_penalty(sql_session) + + assert multiplier == 100 + + with pytest.raises(ValueError, match="Penalty multiplier cannot be less than 100%"): + adjust_penalty( + sql_session, + user=user, + new_penalty_multiplier=99, + message="Setting initial interest rate", + ) + + +def test_adjust_penalty_existing_history(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_penalty( + time=datetime(2024, 1, 1, 10, 0, 0), + user_id=user.id, + penalty_threshold=-150, + penalty_multiplier_percent=110, + message="Initial penalty settings", + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + (penalty, _) = current_penalty(sql_session) + assert penalty == -150 + + adjust_penalty( + sql_session, + user=user, + new_penalty=-250, + message="Adjusting penalty threshold", + time=transactions[-1].time + timedelta(days=1), + ) + sql_session.commit() + + (penalty, _) = current_penalty(sql_session) + assert penalty == -250 + + +def test_adjust_penalty_multiplier_existing_history(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_penalty( + time=datetime(2024, 1, 1, 10, 0, 0), + user_id=user.id, + penalty_threshold=-150, + penalty_multiplier_percent=110, + message="Initial penalty settings", + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + (_, multiplier) = current_penalty(sql_session) + assert multiplier == 110 + + adjust_penalty( + sql_session, + user=user, + new_penalty_multiplier=130, + message="Adjusting penalty multiplier", + time=transactions[-1].time + timedelta(days=1), + ) + sql_session.commit() + (_, multiplier) = current_penalty(sql_session) + assert multiplier == 130 + + +def test_adjust_penalty_and_multiplier(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + adjust_penalty( + sql_session, + user=user, + new_penalty=-300, + new_penalty_multiplier=150, + message="Setting both penalty and multiplier", + ) + sql_session.commit() + + (penalty, multiplier) = current_penalty(sql_session) + assert penalty == -300 + assert multiplier == 150 diff --git a/tests/queries/test_current_interest.py b/tests/queries/test_current_interest.py new file mode 100644 index 0000000..f35b8d9 --- /dev/null +++ b/tests/queries/test_current_interest.py @@ -0,0 +1,38 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENT +from dibbler.models import Transaction, User +from dibbler.queries import current_interest +from tests.helpers import assert_id_order_similar_to_time_order, assign_times + + +def test_current_interest_no_history(sql_session: Session) -> None: + assert current_interest(sql_session) == DEFAULT_INTEREST_RATE_PERCENT + + +def test_current_interest_with_history(sql_session: Session) -> None: + user = User("Admin User") + sql_session.add(user) + sql_session.commit() + + transactions = [ + Transaction.adjust_interest( + interest_rate_percent=5, + user_id=user.id, + ), + Transaction.adjust_interest( + interest_rate_percent=7, + user_id=user.id, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + assert 
current_interest(sql_session) == 7 diff --git a/tests/queries/test_current_penalty.py b/tests/queries/test_current_penalty.py new file mode 100644 index 0000000..1b92fd3 --- /dev/null +++ b/tests/queries/test_current_penalty.py @@ -0,0 +1,46 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, User +from dibbler.models.Transaction import ( + DEFAULT_PENALTY_MULTIPLIER_PERCENT, + DEFAULT_PENALTY_THRESHOLD, +) +from dibbler.queries import current_penalty +from tests.helpers import assign_times, assert_id_order_similar_to_time_order + + +def test_current_penalty_no_history(sql_session: Session) -> None: + assert current_penalty(sql_session) == ( + DEFAULT_PENALTY_THRESHOLD, + DEFAULT_PENALTY_MULTIPLIER_PERCENT, + ) + + +def test_current_penalty_with_history(sql_session: Session) -> None: + user = User("Admin User") + sql_session.add(user) + sql_session.commit() + + transactions = [ + Transaction.adjust_penalty( + penalty_threshold=-200, + penalty_multiplier_percent=150, + user_id=user.id, + ), + Transaction.adjust_penalty( + penalty_threshold=-300, + penalty_multiplier_percent=200, + user_id=user.id, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + assert current_penalty(sql_session) == (-300, 200) diff --git a/tests/queries/test_joint_buy_product.py b/tests/queries/test_joint_buy_product.py new file mode 100644 index 0000000..b47b40e --- /dev/null +++ b/tests/queries/test_joint_buy_product.py @@ -0,0 +1,199 @@ +from datetime import datetime, timedelta + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User +from dibbler.queries import joint_buy_product + + +def insert_test_data(sql_session: Session) -> tuple[User, User, User, Product]: + user1 = User("Test User 1") + user2 = User("Test User 2") + user3 = User("Test User 3") + product = Product("1234567890123", "Test Product") + + sql_session.add_all([user1, user2, user3, product]) + sql_session.commit() + + transactions = [ + Transaction.add_product( + user_id=user1.id, + product_id=product.id, + amount=30, + per_product=10, + product_count=3, + time=datetime(2024, 1, 1, 10, 0, 0), + ) + ] + + sql_session.add_all(transactions) + sql_session.commit() + + return user1, user2, user3, product + + +def test_joint_buy_product_uninitialized_product(sql_session: Session) -> None: + user = User("Test User 1") + sql_session.add(user) + sql_session.commit() + + product = Product("1234567890123", "Uninitialized Product") + + with pytest.raises(ValueError): + joint_buy_product( + sql_session, + instigator=user, + users=[user], + product=product, + product_count=1, + ) + + +def test_joint_buy_product_no_users(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + with pytest.raises(ValueError): + joint_buy_product( + sql_session, + instigator=user, + users=[], + product=product, + product_count=1, + ) + + +def test_joint_buy_product_uninitialized_instigator(sql_session: Session) -> None: + user, user2, _, product = insert_test_data(sql_session) + + uninitialized_user = User("Uninitialized User") + with pytest.raises(ValueError): + joint_buy_product( + sql_session, + instigator=uninitialized_user, + users=[user, user2], + product=product, + product_count=1, + ) + + +def test_joint_buy_product_uninitialized_user_in_list(sql_session: Session) -> None: + user, _, _, product = 
insert_test_data(sql_session) + + uninitialized_user = User("Uninitialized User") + with pytest.raises(ValueError): + joint_buy_product( + sql_session, + instigator=user, + users=[user, uninitialized_user], + product=product, + product_count=1, + ) + + +def test_joint_buy_product_invalid_product_count(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + with pytest.raises(ValueError): + joint_buy_product( + sql_session, + instigator=user, + users=[user], + product=product, + product_count=0, + ) + + with pytest.raises(ValueError): + joint_buy_product( + sql_session, + instigator=user, + users=[user], + product=product, + product_count=-1, + ) + + +def test_joint_single_user(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + joint_buy_product( + sql_session, + instigator=user, + users=[user], + product=product, + product_count=1, + ) + + +def test_joint_buy_product(sql_session: Session) -> None: + user, user2, user3, product = insert_test_data(sql_session) + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user2, user3], + product=product, + product_count=1, + ) + + +def test_joint_buy_product_more_than_in_stock(sql_session: Session) -> None: + user, user2, user3, product = insert_test_data(sql_session) + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user2, user3], + product=product, + product_count=5, + ) + + +def test_joint_buy_product_out_of_stock(sql_session: Session) -> None: + user, user2, user3, product = insert_test_data(sql_session) + + transactions = [ + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=3, + time=datetime(2024, 1, 2, 10, 0, 0), + ) + ] + + sql_session.add_all(transactions) + sql_session.commit() + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user2, user3], + product=product, + product_count=10, + time=transactions[-1].time + timedelta(days=1), + ) + + +def test_joint_buy_product_duplicate_user(sql_session: Session) -> None: + user, user2, _, product = insert_test_data(sql_session) + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user, user2], + product=product, + product_count=1, + ) + + +def test_joint_buy_product_non_involved_instigator(sql_session: Session) -> None: + user, user2, user3, product = insert_test_data(sql_session) + + with pytest.raises(ValueError): + joint_buy_product( + sql_session, + instigator=user, + users=[user2, user3], + product=product, + product_count=1, + ) diff --git a/tests/queries/test_product_owners.py b/tests/queries/test_product_owners.py new file mode 100644 index 0000000..b8bebac --- /dev/null +++ b/tests/queries/test_product_owners.py @@ -0,0 +1,335 @@ +from datetime import datetime +from pprint import pprint + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Product, User +from dibbler.models.Transaction import Transaction +from dibbler.queries import product_owners, product_owners_log, product_stock +from tests.helpers import assign_times, assert_id_order_similar_to_time_order + + +def insert_test_data(sql_session: Session) -> tuple[Product, User]: + user = User("testuser") + product = Product("1234567890123", "Test Product") + + sql_session.add(user) + sql_session.add(product) + + sql_session.commit() + + return product, user + + +def test_product_owners_unitilialized_product(sql_session: Session) -> None: + user = User("testuser") + sql_session.add(user) + sql_session.commit() + + product = 
Product("1234567890123", "Uninitialized Product") + + with pytest.raises(ValueError): + product_owners(sql_session, product) + + +def test_product_owners_no_transactions(sql_session: Session) -> None: + product, _ = insert_test_data(sql_session) + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [] + + +def test_product_owners_add_products(sql_session: Session) -> None: + product, user = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=30, + per_product=10, + product_count=3, + ) + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [user, user, user] + + +def test_product_owners_add_and_buy_products(sql_session: Session) -> None: + product, user = insert_test_data(sql_session) + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=30, + per_product=10, + product_count=3, + ), + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [user, user] + + +def test_product_owners_add_and_throw_products(sql_session: Session) -> None: + product, user = insert_test_data(sql_session) + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=40, + per_product=10, + product_count=4, + ), + Transaction.throw_product( + user_id=user.id, + product_id=product.id, + product_count=2, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [user, user] + + +def test_product_owners_multiple_users(sql_session: Session) -> None: + product, user1 = insert_test_data(sql_session) + user2 = User("testuser2") + sql_session.add(user2) + sql_session.commit() + transactions = [ + Transaction.add_product( + user_id=user1.id, + product_id=product.id, + amount=20, + per_product=10, + product_count=2, + ), + Transaction.add_product( + user_id=user2.id, + product_id=product.id, + amount=30, + per_product=10, + product_count=3, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [user2, user2, user2, user1, user1] + + +def test_product_owners_adjust_stock_down(sql_session: Session) -> None: + product, user = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=50, + per_product=10, + product_count=5, + ), + Transaction.adjust_stock( + user_id=user.id, + product_id=product.id, + product_count=-2, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + 
+ pprint(product_owners_log(sql_session, product)) + + assert product_stock(sql_session, product) == 3 + + owners = product_owners(sql_session, product) + assert owners == [user, user, user] + + +def test_product_owners_adjust_stock_up(sql_session: Session) -> None: + product, user = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=20, + per_product=10, + product_count=2, + ), + Transaction.adjust_stock( + user_id=user.id, + product_id=product.id, + product_count=3, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [user, user, None, None, None] + + +def test_product_owners_negative_stock(sql_session: Session) -> None: + product, user = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=10, + per_product=10, + product_count=1, + ), + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=2, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + owners = product_owners(sql_session, product) + assert owners == [] + + +def test_product_owners_add_products_from_negative_stock(sql_session: Session) -> None: + product, user = insert_test_data(sql_session) + + transactions = [ + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=2, + ), + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=30, + per_product=10, + product_count=3, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [user] + + +def test_product_owners_interleaved_users(sql_session: Session) -> None: + product, user1 = insert_test_data(sql_session) + user2 = User("testuser2") + sql_session.add(user2) + sql_session.commit() + + transactions = [ + Transaction.add_product( + user_id=user1.id, + product_id=product.id, + amount=20, + per_product=10, + product_count=2, + ), + Transaction.add_product( + user_id=user2.id, + product_id=product.id, + amount=30, + per_product=10, + product_count=3, + ), + Transaction.buy_product( + user_id=user1.id, + product_id=product.id, + product_count=1, + ), + Transaction.add_product( + user_id=user1.id, + product_id=product.id, + amount=10, + per_product=10, + product_count=1, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [user1, user2, user2, user2, user1] diff --git a/tests/queries/test_product_price.py b/tests/queries/test_product_price.py new file mode 100644 index 0000000..e1ad434 --- /dev/null +++ b/tests/queries/test_product_price.py @@ -0,0 +1,447 @@ +import math +from datetime import datetime, timedelta +from pprint import pprint + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User +from dibbler.queries import 
product_price, product_price_log, joint_buy_product +from tests.helpers import assert_id_order_similar_to_time_order, assign_times + +# TODO: see if we can use pytest_runtest_makereport to print the "product_price_log"s +# only on failures instead of inlining it in every test function + + +def insert_test_data(sql_session: Session) -> tuple[User, Product]: + user = User("Test User") + product = Product("1234567890123", "Test Product") + + sql_session.add(user) + sql_session.add(product) + sql_session.commit() + + return user, product + + +def test_product_price_uninitialized_product(sql_session: Session) -> None: + user = User("Test User") + sql_session.add(user) + sql_session.commit() + + product = Product("1234567890123", "Uninitialized Product") + + with pytest.raises(ValueError, match="Product must be persisted in the database."): + product_price(sql_session, product) + + +def test_product_price_no_transactions(sql_session: Session) -> None: + _, product = insert_test_data(sql_session) + + pprint(product_price_log(sql_session, product)) + + assert product_price(sql_session, product) == 0 + + +def test_product_price_basic_history(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + amount=27 * 2 - 1, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + assert product_price(sql_session, product) == 27 + + +def test_product_price_sold_out(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + amount=27 * 2 - 1, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + Transaction.buy_product( + product_count=2, + user_id=user.id, + product_id=product.id, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(product_price_log(sql_session, product)) + + assert product_price(sql_session, product) == 27 + + +def test_product_price_interest(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_interest( + interest_rate_percent=110, + user_id=user.id, + ), + Transaction.add_product( + amount=27 * 2 - 1, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(product_price_log(sql_session, product)) + + product_price_ = product_price(sql_session, product) + product_price_interest = product_price(sql_session, product, include_interest=True) + + assert product_price_ == 27 + assert product_price_interest == math.ceil(27 * 1.1) + + +def test_product_price_changing_interest(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_interest( + interest_rate_percent=110, + user_id=user.id, + ), + Transaction.add_product( + amount=27 * 2 - 1, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + Transaction.adjust_interest( + interest_rate_percent=120, + user_id=user.id, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + 
assert_id_order_similar_to_time_order(transactions) + + pprint(product_price_log(sql_session, product)) + + product_price_interest = product_price(sql_session, product, include_interest=True) + assert product_price_interest == math.ceil(27 * 1.2) + + +def test_product_price_old_transaction(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + # Price should be 27 + Transaction.add_product( + amount=38 * 3, + per_product=38, + product_count=3, + user_id=user.id, + product_id=product.id, + ), + # price should be averaged upwards + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + until_transaction = transactions[0] + + pprint( + product_price_log( + sql_session, + product, + until_transaction=until_transaction, + ) + ) + + product_price_ = product_price( + sql_session, + product, + until_transaction=until_transaction, + ) + assert product_price_ == 27 + + +# Price goes up and gets rounded up to the next integer +def test_product_price_round_up_from_below(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + # Price should be 27 + Transaction.add_product( + amount=38 * 3, + per_product=38, + product_count=3, + user_id=user.id, + product_id=product.id, + ), + # price should be averaged upwards + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(product_price_log(sql_session, product)) + + product_price_ = product_price(sql_session, product) + assert product_price_ == math.ceil((27 * 2 + 38 * 3) / (2 + 3)) + + +# Price goes down and gets rounded up to the next integer +def test_product_price_round_up_from_above(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + # Price should be 27 + Transaction.add_product( + amount=20 * 3, + per_product=20, + product_count=3, + user_id=user.id, + product_id=product.id, + ), + # price should be averaged downwards + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(product_price_log(sql_session, product)) + + product_price_ = product_price(sql_session, product) + assert product_price_ == math.ceil((27 * 2 + 20 * 3) / (2 + 3)) + + +def test_product_price_with_negative_stock_single_addition(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + amount=1, + per_product=10, + product_count=1, + user_id=user.id, + product_id=product.id, + ), + Transaction.buy_product( + product_count=10, + user_id=user.id, + product_id=product.id, + ), + Transaction.add_product( + amount=22, + per_product=22, + product_count=1, + user_id=user.id, + product_id=product.id, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + 
pprint(product_price_log(sql_session, product)) + + # Stock went subzero, price should be the last added product price + product1_price = product_price(sql_session, product) + assert product1_price == 22 + + +def test_product_price_with_negative_stock_multiple_additions(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + amount=1, + per_product=10, + product_count=1, + user_id=user.id, + product_id=product.id, + ), + Transaction.buy_product( + product_count=10, + user_id=user.id, + product_id=product.id, + ), + Transaction.add_product( + amount=22, + per_product=22, + product_count=1, + user_id=user.id, + product_id=product.id, + ), + Transaction.add_product( + amount=29, + per_product=29, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(product_price_log(sql_session, product)) + + # Stock went subzero, price should be the last added product price + product1_price = product_price(sql_session, product) + assert product1_price == math.ceil(29) + + +def test_product_price_joint_transactions(sql_session: Session) -> None: + user1, product = insert_test_data(sql_session) + user2 = User("Test User 2") + sql_session.add(user2) + sql_session.commit() + + transactions = [ + Transaction.add_product( + amount=30 * 3, + per_product=30, + product_count=3, + user_id=user1.id, + product_id=product.id, + ), + Transaction.add_product( + amount=20 * 2, + per_product=20, + product_count=2, + user_id=user2.id, + product_id=product.id, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + product_price_ = product_price(sql_session, product) + assert product_price_ == math.ceil((30 * 3 + 20 * 2) / (3 + 2)) + + transactions += joint_buy_product( + sql_session, + instigator=user1, + users=[user1, user2], + product=product, + product_count=2, + time=transactions[-1].time + timedelta(seconds=1), + ) + + pprint(product_price_log(sql_session, product)) + + old_product_price = product_price_ + product_price_ = product_price(sql_session, product) + assert product_price_ == old_product_price, ( + "Joint buy transactions should not affect product price" + ) + + transactions = [ + Transaction.add_product( + amount=25 * 4, + per_product=25, + product_count=4, + user_id=user1.id, + product_id=product.id, + time=transactions[-1].time + timedelta(seconds=1), + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + product_price_ = product_price(sql_session, product) + + # Expected state: + # Added products: + # Count: 3 + 2 = 5, Price: (30 * 3 + 20 * 2) / 5 = 26 + # Joint bought products: + # Count: 5 - 2 = 3, Price: n/a (should not affect price) + # Added products: + # Count: 3 + 4 = 7, Price: (26 * 3 + 25 * 4) / (3 + 4) = 25.57 -> 26 + + assert product_price_ == math.ceil((26 * 3 + 25 * 4) / (3 + 4)) + + +def test_product_price_until(sql_session: Session) -> None: ... 
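The price and cost arithmetic exercised by the tests above is spread across _product_cost_expression and the recursive balance CTE, so here is a small plain-Python sketch of the model the assertions encode. It is illustrative only and not part of the patch: the helper names (add_batch, buy, cost) and the sample penalty threshold/multiplier values are invented for the example, while the rules themselves (weighted-average price rounded up to the next whole credit, non-positive stock ignored when averaging, joint buys leaving the price untouched, interest and penalty applied as percentage surcharges with the penalty only charged below the threshold) are taken from the tests and query expressions in this patch.

import math


def add_batch(price: int, stock: int, per_product: int, count: int) -> tuple[int, int]:
    """Fold one ADD_PRODUCT transaction into (price, stock)."""
    owned = max(stock, 0)  # stock at or below zero contributes nothing to the average
    new_price = math.ceil((price * owned + per_product * count) / (owned + count))
    return new_price, stock + count


def buy(stock: int, count: int) -> int:
    """BUY_PRODUCT and JOINT buys only change stock, never the stored price."""
    return stock - count


def cost(price: int, count: int, interest_pct: int, balance: int,
         penalty_threshold: int, penalty_multiplier_pct: int,
         shares: int = 1, share_total: int = 1) -> int:
    """What a buyer is charged, mirroring _product_cost_expression:
    base share + interest surcharge + penalty surcharge (only below the threshold)."""
    base = price * count * shares / share_total
    interest = base * (interest_pct - 100) / 100
    penalty = base * (penalty_multiplier_pct - 100) / 100 if balance < penalty_threshold else 0
    return math.ceil(base + interest + penalty)


# Walk through test_product_price_joint_transactions:
price, stock = add_batch(0, 0, per_product=30, count=3)          # price 30, stock 3
price, stock = add_batch(price, stock, per_product=20, count=2)  # ceil(130 / 5) = 26, stock 5
stock = buy(stock, 2)                                            # joint buy: price stays 26, stock 3
price, stock = add_batch(price, stock, per_product=25, count=4)  # ceil(178 / 7) = 26, stock 7
assert (price, stock) == (26, 7)

# Surcharges, as in test_product_price_interest and _product_cost_expression: at 110%
# interest a buyer pays ceil(27 * 1.10) = 30; a buyer below an assumed threshold of -100
# with an assumed 150% multiplier pays ceil(27 * (1 + 0.10 + 0.50)) = 44 for the same item.
assert cost(27, 1, interest_pct=110, balance=0,
            penalty_threshold=-100, penalty_multiplier_pct=150) == 30
assert cost(27, 1, interest_pct=110, balance=-200,
            penalty_threshold=-100, penalty_multiplier_pct=150) == 44

Keeping the same arithmetic in a few pure functions like this can also serve as a cheap cross-check for the SQL expressions in the queries package when the pricing rules change.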
diff --git a/tests/queries/test_product_stock.py b/tests/queries/test_product_stock.py new file mode 100644 index 0000000..e06b24b --- /dev/null +++ b/tests/queries/test_product_stock.py @@ -0,0 +1,285 @@ +from datetime import datetime, timedelta + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User +from dibbler.queries import joint_buy_product, product_stock +from tests.helpers import assert_id_order_similar_to_time_order, assign_times + + +def insert_test_data(sql_session: Session) -> tuple[User, Product]: + user = User("Test User 1") + product = Product("1234567890123", "Test Product") + sql_session.add(user) + sql_session.add(product) + sql_session.commit() + return user, product + + +def test_product_stock_uninitialized_product(sql_session: Session) -> None: + user = User("Test User 1") + sql_session.add(user) + sql_session.commit() + + product = Product("1234567890123", "Uninitialized Product") + + with pytest.raises(ValueError): + product_stock(sql_session, product) + + +def test_product_stock_until_datetime_and_transaction_id_not_allowed(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transaction = Transaction.add_product( + amount=10, + per_product=10, + user_id=user.id, + product_id=product.id, + product_count=1, + ) + + with pytest.raises(ValueError): + product_stock( + sql_session, + product, + until_time=datetime.now(), + until_transaction=transaction, + ) + + +def test_product_stock_basic_history(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + sql_session.commit() + + transactions = [ + Transaction.add_product( + amount=10, + per_product=10, + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + assert product_stock(sql_session, product) == 1 + + +def test_product_stock_adjust_stock_up(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=50, + per_product=10, + product_count=5, + ), + Transaction.adjust_stock( + user_id=user.id, + product_id=product.id, + product_count=2, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + assert product_stock(sql_session, product) == 5 + 2 + + +def test_product_stock_adjust_stock_down(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=50, + per_product=10, + product_count=5, + ), + Transaction.adjust_stock( + user_id=user.id, + product_id=product.id, + product_count=-2, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + assert product_stock(sql_session, product) == 5 - 2 + + +def test_product_stock_complex_history(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + amount=27 * 2, + per_product=27, + user_id=user.id, + product_id=product.id, + product_count=2, + ), + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=3, + ), + Transaction.add_product( + amount=50 * 4, + per_product=50, + user_id=user.id, + product_id=product.id, + product_count=4, + ), + 
Transaction.adjust_stock( + user_id=user.id, + product_id=product.id, + product_count=3, + ), + Transaction.adjust_stock( + user_id=user.id, + product_id=product.id, + product_count=-2, + ), + Transaction.throw_product( + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + assert product_stock(sql_session, product) == 2 - 3 + 4 + 3 - 2 - 1 + + +def test_product_stock_no_transactions(sql_session: Session) -> None: + _, product = insert_test_data(sql_session) + + assert product_stock(sql_session, product) == 0 + + +def test_negative_product_stock(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + amount=50, + per_product=50, + user_id=user.id, + product_id=product.id, + product_count=1, + ), + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=2, + ), + Transaction.adjust_stock( + user_id=user.id, + product_id=product.id, + product_count=-1, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + # The stock should be negative because we added and bought the product + assert product_stock(sql_session, product) == 1 - 2 - 1 + + +def test_product_stock_joint_transaction(sql_session: Session) -> None: + user1, product = insert_test_data(sql_session) + + user2 = User("Test User 2") + sql_session.add(user2) + sql_session.commit() + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 17, 0, 0), + amount=100, + per_product=100, + user_id=user1.id, + product_id=product.id, + product_count=5, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + joint_buy_product( + sql_session, + time=transactions[0].time + timedelta(seconds=1), + instigator=user1, + users=[user1, user2], + product=product, + product_count=3, + ) + + assert product_stock(sql_session, product) == 5 - 3 + + +def test_product_stock_until_time(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + amount=10, + per_product=10, + user_id=user.id, + product_id=product.id, + product_count=1, + ), + Transaction.add_product( + amount=20, + per_product=10, + user_id=user.id, + product_id=product.id, + product_count=2, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + assert ( + product_stock( + sql_session, + product, + until_time=transactions[-1].time - timedelta(seconds=1), + ) + == 1 + ) diff --git a/tests/queries/test_search_product.py b/tests/queries/test_search_product.py new file mode 100644 index 0000000..2e33cc5 --- /dev/null +++ b/tests/queries/test_search_product.py @@ -0,0 +1,96 @@ +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Product +from dibbler.queries import search_product + + +def insert_test_data(sql_session: Session) -> list[Product]: + products = [ + Product("1234567890123", "Test Product A"), + Product("2345678901234", "Test Product B"), + Product("3456789012345", "Another Product"), + Product("4567890123456", "Hidden Product", hidden=True), + ] + + sql_session.add_all(products) + sql_session.commit() + + return products + + +def test_search_product_empty_not_allowed(sql_session: 
Session) -> None: + insert_test_data(sql_session) + + with pytest.raises(ValueError): + search_product("", sql_session) + + +def test_search_product_no_products(sql_session: Session) -> None: + result = search_product("Nonexistent Product", sql_session) + + assert isinstance(result, list) + + assert len(result) == 0 + + +def test_search_product_name_exact_match(sql_session: Session) -> None: + insert_test_data(sql_session) + + result = search_product("Test Product A", sql_session) + assert isinstance(result, Product) + assert result.bar_code == "1234567890123" + + +def test_search_product_name_partial_match(sql_session: Session) -> None: + insert_test_data(sql_session) + + result = search_product("Test Product", sql_session) + assert isinstance(result, list) + assert len(result) == 2 + names = {product.name for product in result} + assert names == {"Test Product A", "Test Product B"} + + +def test_search_product_name_no_match(sql_session: Session) -> None: + insert_test_data(sql_session) + + result = search_product("Nonexistent", sql_session) + assert isinstance(result, list) + assert len(result) == 0 + + +def test_search_product_barcode_exact_match(sql_session: Session) -> None: + products = insert_test_data(sql_session) + + product = products[1] # Test Product B + + result = search_product(product.bar_code, sql_session) + assert isinstance(result, Product) + assert result.name == product.name + + +# Should not be able to find hidden products +def test_search_product_hidden_products(sql_session: Session) -> None: + insert_test_data(sql_session) + result = search_product("Hidden Product", sql_session) + assert isinstance(result, list) + assert len(result) == 0 + + +# Should be able to find hidden products if specified +def test_search_product_find_hidden_products(sql_session: Session) -> None: + insert_test_data(sql_session) + result = search_product("Hidden Product", sql_session, find_hidden_products=True) + assert isinstance(result, Product) + assert result.name == "Hidden Product" + + +# Should be able to find hidden products by barcode despite not specified +def test_search_product_hidden_products_by_barcode(sql_session: Session) -> None: + products = insert_test_data(sql_session) + hidden_product = products[3] # Hidden Product + + result = search_product(hidden_product.bar_code, sql_session) + assert isinstance(result, Product) + assert result.name == "Hidden Product" diff --git a/tests/queries/test_search_user.py b/tests/queries/test_search_user.py new file mode 100644 index 0000000..110bd2a --- /dev/null +++ b/tests/queries/test_search_user.py @@ -0,0 +1,86 @@ +from sqlalchemy.orm import Session +import pytest + +from dibbler.models import User +from dibbler.queries import search_user + +USER = [ + ("alice", 123), + ("bob", 125), + ("charlie", 126), + ("david", 127), + ("eve", 128), + ("evey", 129), + ("evy", 130), + ("-symbol-man", 131), + ("user_123", 132), +] + + +def setup_users(sql_session: Session) -> None: + for username, rfid in USER: + user = User(name=username, rfid=str(rfid)) + sql_session.add(user) + sql_session.commit() + + +def test_search_user_empty_not_allowed(sql_session: Session) -> None: + setup_users(sql_session) + + with pytest.raises(ValueError): + search_user("", sql_session) + + +def test_search_user_exact_match(sql_session: Session) -> None: + setup_users(sql_session) + + user = search_user("alice", sql_session) + assert user is not None + assert isinstance(user, User) + assert user.name == "alice" + + user = search_user("125", sql_session) + assert user is 
not None + assert isinstance(user, User) + assert user.name == "bob" + + +def test_search_user_partial_match(sql_session: Session) -> None: + setup_users(sql_session) + + users = search_user("ev", sql_session) + assert isinstance(users, list) + assert len(users) == 3 + names = {user.name for user in users} + assert names == {"eve", "evey", "evy"} + + users = search_user("user", sql_session) + assert isinstance(users, list) + assert len(users) == 1 + assert users[0].name == "user_123" + + +def test_search_user_no_match(sql_session: Session) -> None: + setup_users(sql_session) + + result = search_user("nonexistent", sql_session) + assert isinstance(result, list) + assert len(result) == 0 + + +def test_search_user_special_characters(sql_session: Session) -> None: + setup_users(sql_session) + + user = search_user("-symbol-man", sql_session) + assert user is not None + assert isinstance(user, User) + assert user.name == "-symbol-man" + + +def test_search_by_rfid(sql_session: Session) -> None: + setup_users(sql_session) + + user = search_user("130", sql_session) + assert user is not None + assert isinstance(user, User) + assert user.name == "evy" diff --git a/tests/queries/test_transaction_log.py b/tests/queries/test_transaction_log.py new file mode 100644 index 0000000..e919726 --- /dev/null +++ b/tests/queries/test_transaction_log.py @@ -0,0 +1,687 @@ +from datetime import datetime, timedelta + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import ( + Product, + Transaction, + TransactionType, + User, +) +from dibbler.queries import transaction_log +from tests.helpers import assert_id_order_similar_to_time_order, assign_times + + +def insert_test_data(sql_session: Session) -> tuple[User, User, Product, Product]: + user1 = User("Test User 1") + user2 = User("Test User 2") + + product1 = Product("1234567890123", "Test Product 1") + product2 = Product("9876543210987", "Test Product 2") + sql_session.add_all([user1, user2, product1, product2]) + sql_session.commit() + + return user1, user2, product1, product2 + + +def insert_default_test_transactions( + sql_session: Session, + user1: User, + user2: User, + product1: Product, + product2: Product, +) -> list[Transaction]: + transactions = [ + Transaction.adjust_balance( + amount=100, + user_id=user1.id, + ), + Transaction.adjust_balance( + amount=50, + user_id=user2.id, + ), + Transaction.adjust_balance( + amount=-50, + user_id=user1.id, + ), + Transaction.add_product( + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user1.id, + product_id=product1.id, + ), + Transaction.buy_product( + product_count=1, + user_id=user2.id, + product_id=product2.id, + ), + Transaction.add_product( + amount=15 * 1, + per_product=15, + product_count=1, + user_id=user2.id, + product_id=product2.id, + ), + Transaction.transfer( + amount=30, + user_id=user1.id, + transfer_user_id=user2.id, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + return transactions + + +def test_transaction_log_invalid_limit(sql_session: Session) -> None: + with pytest.raises(ValueError): + transaction_log(sql_session, limit=0) + + with pytest.raises(ValueError): + transaction_log(sql_session, limit=-1) + + +def test_transaction_log_uninitialized_user(sql_session: Session) -> None: + user = User("Uninitialized User") + + with pytest.raises(ValueError): + transaction_log(sql_session, user=user) + + +def 
test_transaction_log_uninitialized_product(sql_session: Session) -> None: + product = Product("1234567890123", "Uninitialized Product") + + with pytest.raises(ValueError): + transaction_log(sql_session, product=product) + + +def test_transaction_log_uninitialized_after_until_transaction(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + uninitialized_transaction = Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 0), + amount=100, + user_id=user.id, + ) + + with pytest.raises(ValueError): + transaction_log( + sql_session, + user=user, + after_transaction=uninitialized_transaction, + ) + + with pytest.raises(ValueError): + transaction_log( + sql_session, + user=user, + until_transaction=uninitialized_transaction, + ) + + +def test_transaction_log_product_and_user_not_allowed(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + with pytest.raises(ValueError): + transaction_log( + sql_session, + user=user, + product=product, + ) + + +def test_transaction_log_until_datetime_and_transaction_id_not_allowed( + sql_session: Session, +) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + trx = Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 0), + amount=100, + user_id=user.id, + ) + sql_session.add(trx) + sql_session.commit() + + with pytest.raises(ValueError): + transaction_log( + sql_session, + user=user, + until_time=datetime(2023, 10, 1, 11, 0, 0), + until_transaction=trx, + ) + + +def test_transaction_log_after_datetime_and_transaction_id_not_allowed( + sql_session: Session, +) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + trx = Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 0), + amount=100, + user_id=user.id, + ) + sql_session.add(trx) + sql_session.commit() + + with pytest.raises(ValueError): + transaction_log( + sql_session, + user=user, + after_time=datetime(2023, 10, 1, 15, 0, 0), + after_transaction=trx, + ) + + +def test_user_transactions_no_transactions(sql_session: Session) -> None: + insert_test_data(sql_session) + + transactions = transaction_log(sql_session) + + assert len(transactions) == 0 + + +def test_transaction_log_basic(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert len(transaction_log(sql_session)) == 7 + + +def test_transaction_log_filtered_by_user(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert len(transaction_log(sql_session, user=user)) == 4 + assert len(transaction_log(sql_session, user=user2)) == 3 + + +def test_transaction_log_filtered_by_product(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert len(transaction_log(sql_session, product=product)) == 1 + assert len(transaction_log(sql_session, product=product2)) == 2 + + +def 
test_transaction_log_after_datetime(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert ( + len( + transaction_log( + sql_session, + after_time=transactions[2].time, + ) + ) + == len(transactions) - 2 + ) + + +def test_transaction_log_after_datetime_no_transactions(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + assert ( + len( + transaction_log( + sql_session, + after_time=transactions[-1].time + timedelta(seconds=1), + ) + ) + == 0 + ) + + +def test_transaction_log_after_datetime_exclusive(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert ( + len( + transaction_log( + sql_session, + after_time=transactions[2].time, + after_inlcusive=False, + ) + ) + == len(transactions) - 3 + ) + + +def test_transaction_log_after_transaction_id(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + first_transaction = transactions[0] + + assert len( + transaction_log( + sql_session, + after_transaction=first_transaction, + ) + ) == len(transactions) + + +def test_transaction_log_after_transaction_id_one_transaction(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + last_transaction = transactions[-1] + + assert ( + len( + transaction_log( + sql_session, + after_transaction=last_transaction, + ) + ) + == 1 + ) + + +def test_transaction_log_after_transaction_id_exclusive(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + third_transaction = transactions[2] + + assert ( + len( + transaction_log( + sql_session, + after_transaction=third_transaction, + after_inlcusive=False, + ) + ) + == len(transactions) - 3 + ) + + +def test_transaction_log_until_datetime(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert ( + len( + transaction_log( + sql_session, + until_time=transactions[-3].time, + ) + ) + == len(transactions) - 2 + ) + + +def test_transaction_log_until_datetime_no_transactions(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert ( + len( + transaction_log( + sql_session, + until_time=transactions[0].time - timedelta(seconds=1), + ) + ) + == 0 + ) + + +def test_transaction_log_until_datetime_exclusive(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert ( + len( + transaction_log( + sql_session, + until_time=transactions[-3].time, + until_inclusive=False, + ) + ) + == len(transactions) - 3 + ) + + +def 
test_transaction_log_until_transaction(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + last_transaction = transactions[-3] + + assert ( + len( + transaction_log( + sql_session, + until_transaction=last_transaction, + ) + ) + == len(transactions) - 2 + ) + + +def test_transaction_log_until_transaction_one_transaction(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + first_transaction = transactions[0] + + assert ( + len( + transaction_log( + sql_session, + until_transaction=first_transaction, + ) + ) + == 1 + ) + + +def test_transaction_log_until_transaction_exclusive(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + last_transaction = transactions[-3] + + assert ( + len( + transaction_log( + sql_session, + until_transaction=last_transaction, + until_inclusive=False, + ) + ) + == len(transactions) - 3 + ) + + +def test_transaction_log_after_until_datetime_illegal_order(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + fifth_transaction = transactions[4] + + with pytest.raises(ValueError): + transaction_log( + sql_session, + after_time=fifth_transaction.time, + until_time=second_transaction.time, + ) + + +def test_transaction_log_after_until_datetime_combined(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + fifth_transaction = transactions[4] + + assert ( + len( + transaction_log( + sql_session, + after_time=second_transaction.time, + until_time=fifth_transaction.time, + ) + ) + == 4 + ) + + +def test_transaction_log_after_until_transaction_illegal_order(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + fifth_transaction = transactions[4] + + with pytest.raises(ValueError): + transaction_log( + sql_session, + after_transaction=fifth_transaction, + until_transaction=second_transaction, + ) + + +def test_transaction_log_after_until_transaction_combined(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + fifth_transaction = transactions[4] + + assert ( + len( + transaction_log( + sql_session, + after_transaction=second_transaction, + until_transaction=fifth_transaction, + ) + ) + == 4 + ) + + +def test_transaction_log_after_date_until_transaction(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + fifth_transaction = transactions[4] + + assert ( + len( + 
transaction_log( + sql_session, + after_time=second_transaction.time, + until_transaction=fifth_transaction, + ) + ) + == 4 + ) + + +def test_transaction_log_after_transaction_until_date(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + fifth_transaction = transactions[4] + + assert ( + len( + transaction_log( + sql_session, + after_transaction=second_transaction, + until_time=fifth_transaction.time, + ) + ) + == 4 + ) + + +def test_transaction_log_limit(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert len(transaction_log(sql_session, limit=3)) == 3 + assert len(transaction_log(sql_session, limit=len(transactions) + 3)) == len(transactions) + + +def test_transaction_log_filtered_by_transaction_type(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert ( + len( + transaction_log( + sql_session, + transaction_type=[TransactionType.ADJUST_BALANCE], + ) + ) + == 3 + ) + assert ( + len( + transaction_log( + sql_session, + transaction_type=[TransactionType.ADD_PRODUCT], + ) + ) + == 2 + ) + assert ( + len( + transaction_log( + sql_session, + transaction_type=[TransactionType.BUY_PRODUCT, TransactionType.ADD_PRODUCT], + ) + ) + == 3 + ) + + +def test_transaction_log_filtered_by_transaction_type_negated(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert ( + len( + transaction_log( + sql_session, + transaction_type=[TransactionType.ADJUST_BALANCE], + negate_transaction_type_filter=True, + ) + ) + == len(transactions) - 3 + ) + assert ( + len( + transaction_log( + sql_session, + transaction_type=[TransactionType.ADD_PRODUCT], + negate_transaction_type_filter=True, + ) + ) + == len(transactions) - 2 + ) + assert ( + len( + transaction_log( + sql_session, + transaction_type=[TransactionType.BUY_PRODUCT, TransactionType.ADD_PRODUCT], + negate_transaction_type_filter=True, + ) + ) + == len(transactions) - 3 + ) + + +def test_transaction_log_combined_filter_user_datetime_transaction_type_limit( + sql_session: Session, +) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + sixth_transaction = transactions[5] + + result = transaction_log( + sql_session, + user=user, + after_time=second_transaction.time, + until_time=sixth_transaction.time, + transaction_type=[TransactionType.ADJUST_BALANCE, TransactionType.ADD_PRODUCT], + limit=2, + ) + + assert len(result) == 2 + + +def test_transaction_log_combined_filter_user_transaction_transaction_type_limit( + sql_session: Session, +) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + sixth_transaction = transactions[5] + + result = transaction_log( + sql_session, + user=user, + after_transaction=second_transaction, + 
until_transaction=sixth_transaction, + transaction_type=[TransactionType.ADJUST_BALANCE, TransactionType.ADD_PRODUCT], + limit=2, + ) + + assert len(result) == 2 + + +def test_transaction_log_combined_filter_product_datetime_transaction_type_limit( + sql_session: Session, +) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + sixth_transaction = transactions[5] + + result = transaction_log( + sql_session, + product=product2, + after_time=second_transaction.time, + until_time=sixth_transaction.time, + transaction_type=[TransactionType.BUY_PRODUCT, TransactionType.ADD_PRODUCT], + limit=2, + ) + + assert len(result) == 2 + + +def test_transaction_log_combined_filter_product_transaction_transaction_type_limit( + sql_session: Session, +) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + sixth_transaction = transactions[5] + + result = transaction_log( + sql_session, + product=product2, + after_transaction=second_transaction, + until_transaction=sixth_transaction, + transaction_type=[TransactionType.BUY_PRODUCT, TransactionType.ADD_PRODUCT], + limit=2, + ) + + assert len(result) == 2 + + +# NOTE: see the corresponding TODO's above the function definition + + +@pytest.mark.skip(reason="Not yet implemented") +def test_transaction_log_filtered_by_user_joint_transactions(sql_session: Session) -> None: ... + + +@pytest.mark.skip(reason="Not yet implemented") +def test_transaction_log_filtered_by_user_throw_away_transactions(sql_session: Session) -> None: ... 
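
Note for reviewers (not part of the patch itself): the bound and filter semantics that the test_transaction_log cases above pin down can be summarised in a small standalone sketch. The Tx dataclass and filter_log function below are hypothetical names used only for illustration; the real transaction_log builds an equivalent SQL query. The behaviour the tests assume is: bounds are inclusive by default, time-based and transaction-based bounds are mutually exclusive, and the transaction-type filter can be negated.

from dataclasses import dataclass
from datetime import datetime


@dataclass
class Tx:
    id: int
    time: datetime
    type: str


def filter_log(
    txs: list[Tx],
    after_time: datetime | None = None,
    until_time: datetime | None = None,
    after_inclusive: bool = True,
    until_inclusive: bool = True,
    types: list[str] | None = None,
    negate_types: bool = False,
    limit: int | None = None,
) -> list[Tx]:
    # Mirrors the validation the tests expect: a non-positive limit and an
    # "after" bound later than the "until" bound are both rejected.
    if limit is not None and limit < 1:
        raise ValueError("limit must be positive")
    if after_time is not None and until_time is not None and after_time > until_time:
        raise ValueError("after bound must not be later than until bound")

    kept = []
    for tx in sorted(txs, key=lambda t: t.time):
        if after_time is not None:
            # Inclusive by default; drop the boundary transaction when exclusive.
            if tx.time < after_time or (not after_inclusive and tx.time == after_time):
                continue
        if until_time is not None:
            if tx.time > until_time or (not until_inclusive and tx.time == until_time):
                continue
        if types is not None and (tx.type in types) == negate_types:
            continue
        kept.append(tx)
    return kept if limit is None else kept[:limit]

User and product filters are left out of the sketch; the tests above pin down their expected counts directly, and the joint-purchase and throw-away variants are still marked as skipped.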
diff --git a/tests/queries/test_user_balance.py b/tests/queries/test_user_balance.py new file mode 100644 index 0000000..50219e3 --- /dev/null +++ b/tests/queries/test_user_balance.py @@ -0,0 +1,1016 @@ +import math +from datetime import datetime, timedelta +from pprint import pprint + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User +from dibbler.models.Transaction import ( + DEFAULT_INTEREST_RATE_PERCENT, + DEFAULT_PENALTY_MULTIPLIER_PERCENT, +) +from dibbler.queries import joint_buy_product, user_balance, user_balance_log +from dibbler.queries.user_balance import _joint_transaction_query, _non_joint_transaction_query +from tests.helpers import assert_id_order_similar_to_time_order, assign_times + +# TODO: see if we can use pytest_runtest_makereport to print the "user_balance_log"s +# only on failures instead of inlining it in every test function + + +def insert_test_data(sql_session: Session) -> tuple[User, User, User, Product]: + user = User("Test User") + user2 = User("Test User 2") + user3 = User("Test User 3") + product = Product("1234567890123", "Test Product") + + sql_session.add_all([user, user2, user3, product]) + sql_session.commit() + + return user, user2, user3, product + + +# NOTE: see economics spec +def _product_cost( + per_product: int, + product_count: int, + interest_rate_percent: int = DEFAULT_INTEREST_RATE_PERCENT, + apply_penalty: bool = False, + penalty_multiplier_percent: int = DEFAULT_PENALTY_MULTIPLIER_PERCENT, + joint_shares: int = 1, + joint_total_shares: int = 1, +) -> int: + base_cost: float = per_product * product_count * joint_shares / joint_total_shares + added_interest: float = base_cost * ((interest_rate_percent - 100) / 100) + + penalty: float = 0.0 + if apply_penalty: + penalty: float = base_cost * ((penalty_multiplier_percent - 100) / 100) + + total_cost: int = math.ceil(base_cost + added_interest + penalty) + + return total_cost + + +def test_non_joint_transaction_query(sql_session) -> None: + user1, user2, user3, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_balance( + user_id=user1.id, + amount=100, + ), + Transaction.adjust_balance( + user_id=user2.id, + amount=50, + ), + Transaction.add_product( + user_id=user2.id, + amount=70, + product_id=product.id, + product_count=3, + per_product=30, + ), + Transaction.transfer( + user_id=user1.id, + transfer_user_id=user2.id, + amount=50, + ), + Transaction.transfer( + user_id=user2.id, + transfer_user_id=user3.id, + amount=30, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + t = transactions + + result = { + row[0] + for row in sql_session.execute( + _non_joint_transaction_query( + user_id=user1.id, + use_cache=False, + ), + ).all() + } + assert result == {t[0].id, t[3].id} + + result = { + row[0] + for row in sql_session.execute( + _non_joint_transaction_query( + user_id=user2.id, + use_cache=False, + ), + ).all() + } + assert result == { + t[1].id, + t[2].id, + t[3].id, + t[4].id, + } + + result = { + row[0] + for row in sql_session.execute( + _non_joint_transaction_query( + user_id=user3.id, + use_cache=False, + ), + ).all() + } + assert result == {t[4].id} + + +def test_joint_transaction_query(sql_session: Session) -> None: + user1, user2, user3, product = insert_test_data(sql_session) + + j1 = joint_buy_product( + sql_session, + product=product, + product_count=3, + instigator=user1, + 
users=[user1, user2], + ) + + j2 = joint_buy_product( + sql_session, + product=product, + product_count=2, + instigator=user1, + users=[user1, user1, user2], + time=j1[-1].time + timedelta(minutes=1), + ) + + j3 = joint_buy_product( + sql_session, + product=product, + product_count=2, + instigator=user1, + users=[user1, user3, user3], + time=j2[-1].time + timedelta(minutes=1), + ) + + j4 = joint_buy_product( + sql_session, + product=product, + product_count=2, + instigator=user2, + users=[user2, user3, user3], + time=j3[-1].time + timedelta(minutes=1), + ) + + assert_id_order_similar_to_time_order(j1 + j2 + j3 + j4) + + result = set( + sql_session.execute( + _joint_transaction_query( + user_id=user1.id, + use_cache=False, + ), + ).all(), + ) + assert result == { + (j1[0].id, 1, 2), + (j2[0].id, 2, 3), + (j3[0].id, 1, 3), + } + + result = set( + sql_session.execute( + _joint_transaction_query( + user_id=user2.id, + use_cache=False, + ), + ).all(), + ) + assert result == { + (j1[0].id, 1, 2), + (j2[0].id, 1, 3), + (j4[0].id, 1, 3), + } + + result = set( + sql_session.execute( + _joint_transaction_query( + user_id=user3.id, + use_cache=False, + ), + ).all(), + ) + assert result == { + (j3[0].id, 2, 3), + (j4[0].id, 2, 3), + } + + +def test_user_balance_no_transactions(sql_session: Session) -> None: + user, *_ = insert_test_data(sql_session) + + pprint(user_balance_log(sql_session, user)) + + balance = user_balance(sql_session, user) + + assert balance == 0 + + +def test_user_balance_basic_history(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_balance( + user_id=user.id, + amount=100, + ), + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(user_balance_log(sql_session, user)) + + balance = user_balance(sql_session, user) + + assert balance == 100 + 27 + + +def test_user_balance_with_transfers(sql_session: Session) -> None: + user1, user2, _, _ = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_balance( + user_id=user1.id, + amount=100, + ), + Transaction.transfer( + user_id=user1.id, + transfer_user_id=user2.id, + amount=50, + ), + Transaction.transfer( + user_id=user2.id, + transfer_user_id=user1.id, + amount=30, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(user_balance_log(sql_session, user1)) + + user1_balance = user_balance(sql_session, user1) + assert user1_balance == 100 - 50 + 30 + + pprint(user_balance_log(sql_session, user2)) + + user2_balance = user_balance(sql_session, user2) + assert user2_balance == 50 - 30 + + +def test_user_balance_penalty(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.adjust_balance( + user_id=user.id, + amount=-200, + ), + # Penalized, pays 2x the price (default penalty) + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + 
assert_id_order_similar_to_time_order(transactions) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == 27 - 200 - _product_cost(27, 1, apply_penalty=True) + + +def test_user_balance_changing_penalty(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.adjust_balance( + user_id=user.id, + amount=-200, + ), + # Penalized, pays 2x the price (default penalty) + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=1, + ), + Transaction.adjust_penalty( + user_id=user.id, + penalty_multiplier_percent=300, + penalty_threshold=-100, + ), + # Penalized, pays 3x the price + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + 27 + - 200 + - _product_cost(27, 1, apply_penalty=True) + - _product_cost(27, 1, apply_penalty=True, penalty_multiplier_percent=300) + ) + + +def test_user_balance_interest(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.adjust_interest( + user_id=user.id, + interest_rate_percent=110, + ), + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == 27 - _product_cost(27, 1, interest_rate_percent=110) + + +def test_user_balance_changing_interest(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + Transaction.adjust_interest( + user_id=user.id, + interest_rate_percent=110, + ), + # Pays 1.1x the price + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=1, + ), + Transaction.adjust_interest( + user_id=user.id, + interest_rate_percent=120, + ), + # Pays 1.2x the price + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + 27 * 3 + - _product_cost(27, 1, interest_rate_percent=110) + - _product_cost(27, 1, interest_rate_percent=120) + ) + + +def test_user_balance_penalty_interest_combined(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.adjust_interest( + user_id=user.id, + interest_rate_percent=110, + ), + Transaction.adjust_balance( + user_id=user.id, + amount=-200, + ), 
+ # Penalized, pays 2x the price (default penalty) + # Pays 1.1x the price + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + assert_id_order_similar_to_time_order(transactions) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + 27 + - 200 + - _product_cost( + 27, + 1, + interest_rate_percent=110, + apply_penalty=True, + ) + ) + + +def test_user_balance_joint_transaction_single_user(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + joint_buy_product( + sql_session, + instigator=user, + users=[user], + product=product, + product_count=2, + time=transactions[-1].time + timedelta(minutes=1), + ) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + (27 * 3) + - _product_cost( + 27, + 2, + joint_shares=1, + joint_total_shares=1, + ) + ) + + +def test_user_balance_joint_transactions_multiple_users(sql_session: Session) -> None: + user, user2, user3, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user2, user3], + product=product, + product_count=2, + time=transactions[-1].time + timedelta(minutes=1), + ) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + (27 * 3) + - _product_cost( + 27, + 2, + joint_shares=1, + joint_total_shares=3, + ) + ) + + +def test_user_balance_joint_transactions_multiple_times_self(sql_session: Session) -> None: + user, user2, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user, user2], + product=product, + product_count=2, + time=transactions[-1].time + timedelta(minutes=1), + ) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + (27 * 3) + - _product_cost( + 27, + 2, + joint_shares=2, + joint_total_shares=3, + ) + ) + + +def test_user_balance_joint_transactions_interest(sql_session: Session) -> None: + user, user2, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + Transaction.adjust_interest( + user_id=user.id, + interest_rate_percent=110, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + transactions += joint_buy_product( + sql_session, + instigator=user, + users=[user, user2], + product=product, + product_count=2, + time=transactions[-1].time + timedelta(minutes=1), + ) + + assert_id_order_similar_to_time_order(transactions) + + 
pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + (27 * 3) + - _product_cost( + 27, + 2, + joint_shares=1, + joint_total_shares=2, + interest_rate_percent=110, + ) + ) + + +def test_user_balance_joint_transactions_changing_interest(sql_session: Session) -> None: + user, user2, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=27 * 4, + per_product=27, + product_count=4, + ), + # Pays 1.1x the price + Transaction.adjust_interest( + user_id=user.id, + interest_rate_percent=110, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + transactions += joint_buy_product( + sql_session, + instigator=user, + users=[user, user2], + product=product, + product_count=2, + time=transactions[-1].time + timedelta(minutes=15), + ) + + transactions += [ + # Pays 1.2x the price + Transaction.adjust_interest( + time=transactions[-1].time + timedelta(minutes=15), + user_id=user.id, + interest_rate_percent=120, + ) + ] + sql_session.add_all(transactions) + sql_session.commit() + + transactions += joint_buy_product( + sql_session, + instigator=user, + users=[user, user2], + product=product, + product_count=1, + time=transactions[-1].time + timedelta(minutes=15), + ) + + assert_id_order_similar_to_time_order(transactions) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + (27 * 4) + - _product_cost( + 27, + 2, + joint_shares=1, + joint_total_shares=2, + interest_rate_percent=110, + ) + - _product_cost( + 27, + 1, + joint_shares=1, + joint_total_shares=2, + interest_rate_percent=120, + ) + ) + + +def test_user_balance_joint_transactions_penalty(sql_session: Session) -> None: + user, user2, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + Transaction.adjust_balance( + user_id=user.id, + amount=-200, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + transactions += joint_buy_product( + sql_session, + instigator=user, + users=[user, user2], + product=product, + product_count=2, + time=transactions[-1].time + timedelta(minutes=15), + ) + + assert_id_order_similar_to_time_order(transactions) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + (27 * 3) + - 200 + - _product_cost( + 27, + 2, + joint_shares=1, + joint_total_shares=2, + apply_penalty=True, + ) + ) + + +def test_user_balance_joint_transactions_changing_penalty(sql_session: Session) -> None: + user, user2, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + Transaction.adjust_balance( + user_id=user.id, + amount=-200, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + transactions += joint_buy_product( + sql_session, + instigator=user, + users=[user, user2], + product=product, + product_count=2, + time=transactions[-1].time + timedelta(minutes=15), + ) + + transactions += [ + Transaction.adjust_penalty( + time=transactions[-1].time + timedelta(minutes=30), + user_id=user.id, + penalty_multiplier_percent=300, + penalty_threshold=-100, + ) + ] + + sql_session.add_all(transactions) + 
sql_session.commit() + + transactions += joint_buy_product( + sql_session, + instigator=user, + users=[user, user2], + product=product, + product_count=1, + time=transactions[-1].time + timedelta(minutes=45), + ) + + assert_id_order_similar_to_time_order(transactions) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + (27 * 3) + - 200 + - _product_cost( + 27, + 2, + joint_shares=1, + joint_total_shares=2, + apply_penalty=True, + ) + - _product_cost( + 27, + 1, + joint_shares=1, + joint_total_shares=2, + apply_penalty=True, + penalty_multiplier_percent=300, + ) + ) + + +def test_user_balance_joint_transactions_penalty_interest_combined( + sql_session: Session, +) -> None: + user, user2, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + Transaction.adjust_interest( + user_id=user.id, + interest_rate_percent=110, + ), + Transaction.adjust_balance( + user_id=user.id, + amount=-200, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + transactions += joint_buy_product( + sql_session, + instigator=user, + users=[user, user2], + product=product, + product_count=2, + time=transactions[-1].time + timedelta(minutes=15), + ) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + (27 * 3) + - 200 + - _product_cost( + 27, + 2, + joint_shares=1, + joint_total_shares=2, + interest_rate_percent=110, + apply_penalty=True, + ) + ) + + +def test_user_balance_until_time(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_balance( + user_id=user.id, + amount=100, + ), + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.adjust_balance( + user_id=user.id, + amount=50, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + pprint( + user_balance_log( + sql_session, + user, + until_time=transactions[1].time + timedelta(seconds=30), + ) + ) + + balance = user_balance( + sql_session, + user, + until_time=transactions[1].time + timedelta(seconds=30), + ) + + assert balance == 100 + 27 + + +def test_user_balance_until_transaction(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_balance( + user_id=user.id, + amount=100, + ), + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.adjust_balance( + user_id=user.id, + amount=50, + ), + ] + + assign_times(transactions) + + sql_session.add_all(transactions) + sql_session.commit() + + until_transaction = transactions[1] + + pprint( + user_balance_log( + sql_session, + user, + until_transaction=until_transaction, + ) + ) + + balance = user_balance( + sql_session, + user, + until_transaction=until_transaction, + ) + + assert balance == 100 + 27 + + +@pytest.mark.skip(reason="Not yet implemented") +def test_user_balance_throw_away_products(sql_session: Session) -> None: ...
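
As a sanity check on the arithmetic mirrored by the _product_cost helper in these balance tests, here is one worked instance. The concrete numbers are assumptions for the example only: per_product=27, product_count=2, one of two joint shares, 110 % interest, and a 200 % penalty multiplier (the test comments describe the default penalty as "pays 2x the price"); the actual default rates come from dibbler.models.Transaction and are not restated here.

import math

per_product, product_count = 27, 2
joint_shares, joint_total_shares = 1, 2
interest_rate_percent = 110       # assumed for this example
penalty_multiplier_percent = 200  # assumed; "2x the price" per the test comments

base = per_product * product_count * joint_shares / joint_total_shares  # 27.0
interest = base * (interest_rate_percent - 100) / 100                   # 2.7
penalty = base * (penalty_multiplier_percent - 100) / 100               # 27.0
total = math.ceil(base + interest + penalty)                            # 57

assert total == 57

Note the single math.ceil at the very end: interest and penalty are accumulated as floats and only the final per-share total is rounded up, which is why the expected balances in the tests are computed through _product_cost rather than by rounding each component separately.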