From f229c5c4d0c9fc52a80c46c47eb4df5b440ffe22 Mon Sep 17 00:00:00 2001 From: h7x4 Date: Wed, 10 Dec 2025 16:00:36 +0900 Subject: [PATCH] Write a set of queries to go along with the event sourcing model --- dibbler/lib/__init__.py | 0 dibbler/lib/helpers.py | 76 +-- dibbler/lib/query_helpers.py | 22 + dibbler/queries/__init__.py | 37 ++ dibbler/queries/add_product.py | 1 + dibbler/queries/add_user.py | 1 + dibbler/queries/adjust_interest.py | 28 + dibbler/queries/adjust_penalty.py | 41 ++ dibbler/queries/current_interest.py | 25 + dibbler/queries/current_penalty.py | 29 + dibbler/queries/joint_buy_product.py | 66 +++ dibbler/queries/product_owners.py | 296 ++++++++++ dibbler/queries/product_price.py | 285 ++++++++++ dibbler/queries/product_stock.py | 113 ++++ dibbler/queries/search_product.py | 39 ++ dibbler/queries/search_user.py | 36 ++ dibbler/queries/transaction_log.py | 95 ++++ dibbler/queries/user_balance.py | 520 +++++++++++++++++ dibbler/queries/user_products.py | 10 + tests/queries/__init__.py | 0 tests/queries/test_add_product.py | 0 tests/queries/test_add_user.py | 0 tests/queries/test_adjust_interest.py | 84 +++ tests/queries/test_adjust_penalty.py | 166 ++++++ tests/queries/test_current_interest.py | 35 ++ tests/queries/test_current_penalty.py | 42 ++ tests/queries/test_joint_buy_product.py | 198 +++++++ tests/queries/test_product_owners.py | 311 +++++++++++ tests/queries/test_product_price.py | 432 +++++++++++++++ tests/queries/test_product_stock.py | 250 +++++++++ tests/queries/test_search_product.py | 88 +++ tests/queries/test_search_user.py | 78 +++ tests/queries/test_transaction_log.py | 624 +++++++++++++++++++++ tests/queries/test_user_balance.py | 709 ++++++++++++++++++++++++ 34 files changed, 4663 insertions(+), 74 deletions(-) create mode 100644 dibbler/lib/__init__.py create mode 100644 dibbler/lib/query_helpers.py create mode 100644 dibbler/queries/__init__.py create mode 100644 dibbler/queries/add_product.py create mode 100644 dibbler/queries/add_user.py create mode 100644 dibbler/queries/adjust_interest.py create mode 100644 dibbler/queries/adjust_penalty.py create mode 100644 dibbler/queries/current_interest.py create mode 100644 dibbler/queries/current_penalty.py create mode 100644 dibbler/queries/joint_buy_product.py create mode 100644 dibbler/queries/product_owners.py create mode 100644 dibbler/queries/product_price.py create mode 100644 dibbler/queries/product_stock.py create mode 100644 dibbler/queries/search_product.py create mode 100644 dibbler/queries/search_user.py create mode 100644 dibbler/queries/transaction_log.py create mode 100644 dibbler/queries/user_balance.py create mode 100644 dibbler/queries/user_products.py create mode 100644 tests/queries/__init__.py create mode 100644 tests/queries/test_add_product.py create mode 100644 tests/queries/test_add_user.py create mode 100644 tests/queries/test_adjust_interest.py create mode 100644 tests/queries/test_adjust_penalty.py create mode 100644 tests/queries/test_current_interest.py create mode 100644 tests/queries/test_current_penalty.py create mode 100644 tests/queries/test_joint_buy_product.py create mode 100644 tests/queries/test_product_owners.py create mode 100644 tests/queries/test_product_price.py create mode 100644 tests/queries/test_product_stock.py create mode 100644 tests/queries/test_search_product.py create mode 100644 tests/queries/test_search_user.py create mode 100644 tests/queries/test_transaction_log.py create mode 100644 tests/queries/test_user_balance.py diff --git a/dibbler/lib/__init__.py b/dibbler/lib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dibbler/lib/helpers.py b/dibbler/lib/helpers.py index 30926a3..0aab88e 100644 --- a/dibbler/lib/helpers.py +++ b/dibbler/lib/helpers.py @@ -1,79 +1,7 @@ -import pwd -import subprocess import os +import pwd import signal - -from sqlalchemy import or_, and_ - -from ..models import User, Product - - -def search_user(string, session, ignorethisflag=None): - string = string.lower() - exact_match = ( - session.query(User) - .filter(or_(User.name == string, User.card == string, User.rfid == string)) - .first() - ) - if exact_match: - return exact_match - user_list = ( - session.query(User) - .filter( - or_( - User.name.ilike(f"%{string}%"), - User.card.ilike(f"%{string}%"), - User.rfid.ilike(f"%{string}%"), - ) - ) - .all() - ) - return user_list - - -def search_product(string, session, find_hidden_products=True): - if find_hidden_products: - exact_match = ( - session.query(Product) - .filter(or_(Product.bar_code == string, Product.name == string)) - .first() - ) - else: - exact_match = ( - session.query(Product) - .filter( - or_( - Product.bar_code == string, - and_(Product.name == string, Product.hidden is False), - ) - ) - .first() - ) - if exact_match: - return exact_match - if find_hidden_products: - product_list = ( - session.query(Product) - .filter( - or_( - Product.bar_code.ilike(f"%{string}%"), - Product.name.ilike(f"%{string}%"), - ) - ) - .all() - ) - else: - product_list = ( - session.query(Product) - .filter( - or_( - Product.bar_code.ilike(f"%{string}%"), - and_(Product.name.ilike(f"%{string}%"), Product.hidden is False), - ) - ) - .all() - ) - return product_list +import subprocess def system_user_exists(username): diff --git a/dibbler/lib/query_helpers.py b/dibbler/lib/query_helpers.py new file mode 100644 index 0000000..9062d96 --- /dev/null +++ b/dibbler/lib/query_helpers.py @@ -0,0 +1,22 @@ +from typing import TypeVar +from sqlalchemy import BindParameter, literal + +T = TypeVar("T") + + +def const(value: T) -> BindParameter[T]: + """ + Create a constant SQL literal bind parameter. + + This is useful to avoid too many `?` bind parameters in SQL queries, + when the input value is known to be safe. + """ + + return literal(value, literal_execute=True) + + +CONST_ZERO: BindParameter[int] = const(0) +CONST_ONE: BindParameter[int] = const(1) +CONST_TRUE: BindParameter[bool] = const(True) +CONST_FALSE: BindParameter[bool] = const(False) +CONST_NONE: BindParameter[None] = const(None) diff --git a/dibbler/queries/__init__.py b/dibbler/queries/__init__.py new file mode 100644 index 0000000..dfba1e5 --- /dev/null +++ b/dibbler/queries/__init__.py @@ -0,0 +1,37 @@ +__all__ = [ + # "add_product", + # "add_user", + "adjust_interest", + "adjust_penalty", + "current_interest", + "current_penalty", + "joint_buy_product", + "product_owners", + "product_owners_log", + "product_price", + "product_price_log", + "product_stock", + # "products_owned_by_user", + "search_product", + "search_user", + "transaction_log", + "user_balance", + "user_balance_log", +] + +# from .add_product import add_product +# from .add_user import add_user +from .adjust_interest import adjust_interest +from .adjust_penalty import adjust_penalty +from .current_interest import current_interest +from .current_penalty import current_penalty +from .joint_buy_product import joint_buy_product +from .product_owners import product_owners, product_owners_log +from .product_price import product_price, product_price_log +from .product_stock import product_stock + +# from .products_owned_by_user import products_owned_by_user +from .search_product import search_product +from .search_user import search_user +from .transaction_log import transaction_log +from .user_balance import user_balance, user_balance_log diff --git a/dibbler/queries/add_product.py b/dibbler/queries/add_product.py new file mode 100644 index 0000000..5e9e5a1 --- /dev/null +++ b/dibbler/queries/add_product.py @@ -0,0 +1 @@ +# TODO: implement me diff --git a/dibbler/queries/add_user.py b/dibbler/queries/add_user.py new file mode 100644 index 0000000..5e9e5a1 --- /dev/null +++ b/dibbler/queries/add_user.py @@ -0,0 +1 @@ +# TODO: implement me diff --git a/dibbler/queries/adjust_interest.py b/dibbler/queries/adjust_interest.py new file mode 100644 index 0000000..c66c360 --- /dev/null +++ b/dibbler/queries/adjust_interest.py @@ -0,0 +1,28 @@ +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, User + +# TODO: this type of transaction should be password protected. +# the password can be set as a string literal in the config file. + + +def adjust_interest( + sql_session: Session, + user: User, + new_interest: int, + message: str | None = None, +) -> None: + if new_interest < 0: + raise ValueError("Interest rate cannot be negative") + + if user.id is None: + raise ValueError("User must be persisted in the database.") + + transaction = Transaction.adjust_interest( + user_id=user.id, + interest_rate_percent=new_interest, + message=message, + ) + + sql_session.add(transaction) + sql_session.commit() diff --git a/dibbler/queries/adjust_penalty.py b/dibbler/queries/adjust_penalty.py new file mode 100644 index 0000000..31e58b0 --- /dev/null +++ b/dibbler/queries/adjust_penalty.py @@ -0,0 +1,41 @@ +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, User +from dibbler.queries.current_penalty import current_penalty + +# TODO: this type of transaction should be password protected. +# the password can be set as a string literal in the config file. + + +def adjust_penalty( + sql_session: Session, + user: User, + new_penalty: int | None = None, + new_penalty_multiplier: int | None = None, + message: str | None = None, +) -> None: + if new_penalty is None and new_penalty_multiplier is None: + raise ValueError("At least one of new_penalty or new_penalty_multiplier must be provided") + + if new_penalty_multiplier is not None and new_penalty_multiplier < 100: + raise ValueError("Penalty multiplier cannot be less than 100%") + + if user.id is None: + raise ValueError("User must be persisted in the database.") + + if new_penalty is None or new_penalty_multiplier is None: + existing_penalty, existing_penalty_multiplier = current_penalty(sql_session) + if new_penalty is None: + new_penalty = existing_penalty + if new_penalty_multiplier is None: + new_penalty_multiplier = existing_penalty_multiplier + + transaction = Transaction.adjust_penalty( + user_id=user.id, + penalty_threshold=new_penalty, + penalty_multiplier_percent=new_penalty_multiplier, + message=message, + ) + + sql_session.add(transaction) + sql_session.commit() diff --git a/dibbler/queries/current_interest.py b/dibbler/queries/current_interest.py new file mode 100644 index 0000000..7810826 --- /dev/null +++ b/dibbler/queries/current_interest.py @@ -0,0 +1,25 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, TransactionType +from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE + + +# TODO: add until transaction parameter +# TODO: add until datetime parameter + + +def current_interest(sql_session: Session) -> int: + result = sql_session.scalars( + select(Transaction) + .where(Transaction.type_ == TransactionType.ADJUST_INTEREST) + .order_by(Transaction.time.desc()) + .limit(1) + ).one_or_none() + + if result is None: + return DEFAULT_INTEREST_RATE_PERCENTAGE + elif result.interest_rate_percent is None: + return DEFAULT_INTEREST_RATE_PERCENTAGE + else: + return result.interest_rate_percent diff --git a/dibbler/queries/current_penalty.py b/dibbler/queries/current_penalty.py new file mode 100644 index 0000000..ff79ce0 --- /dev/null +++ b/dibbler/queries/current_penalty.py @@ -0,0 +1,29 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, TransactionType +from dibbler.models.Transaction import ( + DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE, + DEFAULT_PENALTY_THRESHOLD, +) + + +# TODO: add until transaction parameter +# TODO: add until datetime parameter + + +def current_penalty(sql_session: Session) -> tuple[int, int]: + result = sql_session.scalars( + select(Transaction) + .where(Transaction.type_ == TransactionType.ADJUST_PENALTY) + .order_by(Transaction.time.desc()) + .limit(1) + ).one_or_none() + + if result is None: + return DEFAULT_PENALTY_THRESHOLD, DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE + + assert result.penalty_threshold is not None, "Penalty threshold must be set" + assert result.penalty_multiplier_percent is not None, "Penalty multiplier percent must be set" + + return result.penalty_threshold, result.penalty_multiplier_percent diff --git a/dibbler/queries/joint_buy_product.py b/dibbler/queries/joint_buy_product.py new file mode 100644 index 0000000..156a730 --- /dev/null +++ b/dibbler/queries/joint_buy_product.py @@ -0,0 +1,66 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import ( + Product, + Transaction, + User, +) + + +def joint_buy_product( + sql_session: Session, + product: Product, + product_count: int, + instigator: User, + users: list[User], + time: datetime | None = None, + message: str | None = None, +) -> list[Transaction]: + """ + Create buy product transactions for multiple users at once. + """ + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + if instigator.id is None: + raise ValueError("Instigator must be persisted in the database.") + + if len(users) == 0: + raise ValueError("At least bying one user must be specified.") + + if any(user.id is None for user in users): + raise ValueError("All users must be persisted in the database.") + + if instigator not in users: + raise ValueError("Instigator must be in the list of users buying the product.") + + if product_count <= 0: + raise ValueError("Product count must be positive.") + + joint_transaction = Transaction.joint( + user_id=instigator.id, + product_id=product.id, + product_count=product_count, + time=time, + message=message, + ) + sql_session.add(joint_transaction) + sql_session.flush() # Ensure joint_transaction gets an ID + + transactions = [joint_transaction] + + for user in users: + buy_transaction = Transaction.joint_buy_product( + user_id=user.id, + joint_transaction_id=joint_transaction.id, + time=time, + message=message, + ) + sql_session.add(buy_transaction) + transactions.append(buy_transaction) + + sql_session.commit() + return transactions diff --git a/dibbler/queries/product_owners.py b/dibbler/queries/product_owners.py new file mode 100644 index 0000000..43df648 --- /dev/null +++ b/dibbler/queries/product_owners.py @@ -0,0 +1,296 @@ +from dataclasses import dataclass +from datetime import datetime + +from sqlalchemy import ( + CTE, + BindParameter, + and_, + bindparam, + case, + func, + or_, + select, +) +from sqlalchemy.orm import Session + +from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO +from dibbler.models import ( + Product, + Transaction, + TransactionType, + User, +) +from dibbler.queries.product_stock import _product_stock_query + + +def _product_owners_query( + product_id: BindParameter[int] | int, + use_cache: bool = True, + until: BindParameter[datetime] | datetime | None = None, + until_inclusive: bool = True, + cte_name: str = "rec_cte", + trx_subset_name: str = "trx_subset", +) -> CTE: + """ + The inner query for inferring the owners of a given product. + """ + + if use_cache: + print("WARNING: Using cache for users owning product query is not implemented yet.") + + if isinstance(product_id, int): + product_id = bindparam("product_id", value=product_id) + + if isinstance(until, datetime): + until = BindParameter("until", value=until) + + product_stock = _product_stock_query( + product_id=product_id, + use_cache=use_cache, + until=until, + until_inclusive=until_inclusive, + ) + + # Subset of transactions that we'll want to iterate over. + trx_subset = ( + select( + func.row_number().over(order_by=Transaction.time.desc()).label("i"), + Transaction.time, + Transaction.id, + Transaction.type_, + Transaction.user_id, + Transaction.product_count, + ) + # TODO: maybe add value constraint on ADJUST_STOCK? + .where( + or_( + Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), + and_( + Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(), + Transaction.product_count > CONST_ZERO, + ), + ), + Transaction.product_id == product_id, + (Transaction.time <= until if until_inclusive else Transaction.time < until) + if until is not None + else CONST_TRUE, + ) + .order_by(Transaction.time.desc()) + .subquery(trx_subset_name) + ) + + initial_element = select( + CONST_ZERO.label("i"), + CONST_ZERO.label("time"), + CONST_NONE.label("transaction_id"), + CONST_NONE.label("user_id"), + CONST_ZERO.label("product_count"), + product_stock.scalar_subquery().label("products_left_to_account_for"), + ) + + recursive_cte = initial_element.cte(name=cte_name, recursive=True) + + recursive_elements = ( + select( + trx_subset.c.i, + trx_subset.c.time, + trx_subset.c.id.label("transaction_id"), + # Who added the product (if any) + case( + # Someone adds the product -> they own it + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), + trx_subset.c.user_id, + ), + else_=CONST_NONE, + ).label("user_id"), + # How many products did they add (if any) + case( + # Someone adds the product -> they added a certain amount of products + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), + trx_subset.c.product_count, + ), + # Stock got adjusted upwards -> consider those products as added by nobody + ( + (trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column()) + and (trx_subset.c.product_count > CONST_ZERO), + trx_subset.c.product_count, + ), + else_=CONST_ZERO, + ).label("product_count"), + # How many products left to account for + case( + # Someone adds the product -> known owner, decrease the number of products left to account for + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), + recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count, + ), + # Stock got adjusted upwards -> none owner, decrease the number of products left to account for + ( + and_( + trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(), + trx_subset.c.product_count > CONST_ZERO, + ), + recursive_cte.c.products_left_to_account_for - trx_subset.c.product_count, + ), + else_=recursive_cte.c.products_left_to_account_for, + ).label("products_left_to_account_for"), + ) + .select_from(trx_subset) + .where( + and_( + trx_subset.c.i == recursive_cte.c.i + CONST_ONE, + # Base case: stop if we've accounted for all products + recursive_cte.c.products_left_to_account_for > CONST_ZERO, + ) + ) + ) + + return recursive_cte.union_all(recursive_elements) + + +@dataclass +class ProductOwnersLogEntry: + transaction: Transaction + user: User | None + products_left_to_account_for: int + + +# TODO: add until datetime parameter + + +def product_owners_log( + sql_session: Session, + product: Product, + use_cache: bool = True, + until: Transaction | None = None, + until_inclusive: bool = True, +) -> list[ProductOwnersLogEntry]: + """ + Returns a log of the product ownership calculation for the given product. + + If 'until' is given, only transactions up to that time are considered. + """ + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + if until is not None and until.id is None: + raise ValueError("'until' transaction must be persisted in the database.") + + recursive_cte = _product_owners_query( + product_id=product.id, + use_cache=use_cache, + until=until.time if until else None, + until_inclusive=until_inclusive, + ) + + result = sql_session.execute( + select( + Transaction, + User, + recursive_cte.c.products_left_to_account_for, + ) + .select_from(recursive_cte) + .join( + Transaction, + onclause=Transaction.id == recursive_cte.c.transaction_id, + ) + .join( + User, + onclause=User.id == recursive_cte.c.user_id, + isouter=True, + ) + .order_by(recursive_cte.c.time.desc()) + ).all() + + if result is None: + # If there are no transactions for this product, the query should return an empty list, not None. + raise RuntimeError( + f"Something went wrong while calculating the owner log for product {product.name} (ID: {product.id})." + ) + + return [ + ProductOwnersLogEntry( + transaction=row[0], + user=row[1], + products_left_to_account_for=row[2], + ) + for row in result + ] + + +# TODO: add until transaction parameter + + +def product_owners( + sql_session: Session, + product: Product, + use_cache: bool = True, + until: datetime | None = None, + until_inclusive: bool = True, +) -> list[User | None]: + """ + Returns an ordered list of users owning the given product. + + If 'until' is given, only transactions up to that time are considered. + """ + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + recursive_cte = _product_owners_query( + product_id=product.id, + use_cache=use_cache, + until=until, + until_inclusive=until_inclusive, + ) + + db_result = sql_session.execute( + select( + recursive_cte.c.products_left_to_account_for, + recursive_cte.c.product_count, + User, + ) + .join(User, User.id == recursive_cte.c.user_id, isouter=True) + .order_by(recursive_cte.c.time.desc()) + ).all() + + print(db_result) + + result: list[User | None] = [] + none_count = 0 + + # We are moving backwards through history, but this is the order we want to return the list + # There are 3 cases: + # User is not none -> add user product_count times + # User is none, and product_count is not 0 -> add None product_count times + # User is none, and product_count is 0 -> check how much products are left to account for, + + # TODO: embed this into the query itself? + for products_left_to_account_for, product_count, user in db_result: + if user is not None: + if products_left_to_account_for < 0: + result.extend([user] * (product_count + products_left_to_account_for)) + else: + result.extend([user] * product_count) + elif product_count != 0: + if products_left_to_account_for < 0: + none_count += product_count + products_left_to_account_for + else: + none_count += product_count + else: + pass + + # none_count += user_count + # else: + + result.extend([None] * none_count) + + # # NOTE: if the last line exeeds the product count, we need to truncate it + # result.extend([user] * min(user_count, products_left_to_account_for)) + + # redistribute the user counts to a list of users + + return list(result) diff --git a/dibbler/queries/product_price.py b/dibbler/queries/product_price.py new file mode 100644 index 0000000..12028f3 --- /dev/null +++ b/dibbler/queries/product_price.py @@ -0,0 +1,285 @@ +import math +from dataclasses import dataclass +from datetime import datetime + +from sqlalchemy import ( + BindParameter, + ColumnElement, + Integer, + asc, + case, + cast, + func, + select, +) +from sqlalchemy.orm import Session + +from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO +from dibbler.models import ( + Product, + Transaction, + TransactionType, +) +from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE + + +def _product_price_query( + product_id: int | ColumnElement[int], + use_cache: bool = True, + until: BindParameter[datetime] | datetime | None = None, + until_inclusive: bool = True, + cte_name: str = "rec_cte", + trx_subset_name: str = "trx_subset", +): + """ + The inner query for calculating the product price. + """ + + if use_cache: + print("WARNING: Using cache for product price query is not implemented yet.") + + if isinstance(product_id, int): + product_id = BindParameter("product_id", value=product_id) + + if isinstance(until, datetime): + until = BindParameter("until", value=until) + + initial_element = select( + CONST_ZERO.label("i"), + CONST_ZERO.label("time"), + CONST_NONE.label("transaction_id"), + CONST_ZERO.label("price"), + CONST_ZERO.label("product_count"), + ) + + recursive_cte = initial_element.cte(name=cte_name, recursive=True) + + # Subset of transactions that we'll want to iterate over. + trx_subset = ( + select( + func.row_number().over(order_by=asc(Transaction.time)).label("i"), + Transaction.id, + Transaction.time, + Transaction.type_, + Transaction.product_count, + Transaction.per_product, + ) + .where( + Transaction.type_.in_( + [ + TransactionType.BUY_PRODUCT.as_literal_column(), + TransactionType.ADD_PRODUCT.as_literal_column(), + TransactionType.ADJUST_STOCK.as_literal_column(), + TransactionType.JOINT.as_literal_column(), + ] + ), + Transaction.product_id == product_id, + (Transaction.time <= until if until_inclusive else Transaction.time < until) + if until is not None + else CONST_TRUE, + ) + .order_by(Transaction.time.asc()) + .subquery(trx_subset_name) + ) + + recursive_elements = ( + select( + trx_subset.c.i, + trx_subset.c.time, + trx_subset.c.id.label("transaction_id"), + case( + # Someone buys the product -> price remains the same. + ( + trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(), + recursive_cte.c.price, + ), + # Someone adds the product -> price is recalculated based on + # product count, previous price, and new price. + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), + cast( + func.ceil( + ( + recursive_cte.c.price + * func.max(recursive_cte.c.product_count, CONST_ZERO) + + trx_subset.c.per_product * trx_subset.c.product_count + ) + / ( + # The running product count can be negative if the accounting is bad. + # This ensures that we never end up with negative prices or zero divisions + # and other disastrous phenomena. + func.max(recursive_cte.c.product_count, CONST_ZERO) + + trx_subset.c.product_count + ) + ), + Integer, + ), + ), + # Someone adjusts the stock -> price remains the same. + ( + trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(), + recursive_cte.c.price, + ), + # Should never happen + else_=recursive_cte.c.price, + ).label("price"), + case( + # Someone buys the product -> product count is reduced. + ( + trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(), + recursive_cte.c.product_count - trx_subset.c.product_count, + ), + ( + trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(), + recursive_cte.c.product_count - trx_subset.c.product_count, + ), + # Someone adds the product -> product count is increased. + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), + recursive_cte.c.product_count + trx_subset.c.product_count, + ), + # Someone adjusts the stock -> product count is adjusted. + ( + trx_subset.c.type_ == TransactionType.ADJUST_STOCK.as_literal_column(), + recursive_cte.c.product_count + trx_subset.c.product_count, + ), + # Should never happen + else_=recursive_cte.c.product_count, + ).label("product_count"), + ) + .select_from(trx_subset) + .where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE) + ) + + return recursive_cte.union_all(recursive_elements) + + +# TODO: create a function for the log that pretty prints the log entries +# for debugging purposes + + +@dataclass +class ProductPriceLogEntry: + transaction: Transaction + price: int + product_count: int + + +# TODO: add until datetime parameter + + +def product_price_log( + sql_session: Session, + product: Product, + use_cache: bool = True, + until: Transaction | None = None, + until_inclusive: bool = True, +) -> list[ProductPriceLogEntry]: + """ + Calculates the price of a product and returns a log of the price changes. + """ + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + if until is not None and until.id is None: + raise ValueError("'until' transaction must be persisted in the database.") + + recursive_cte = _product_price_query( + product.id, + use_cache=use_cache, + until=until.time if until else None, + until_inclusive=until_inclusive, + ) + + result = sql_session.execute( + select( + Transaction, + recursive_cte.c.price, + recursive_cte.c.product_count, + ) + .select_from(recursive_cte) + .join( + Transaction, + onclause=Transaction.id == recursive_cte.c.transaction_id, + ) + .order_by(recursive_cte.c.i.asc()) + ).all() + + if result is None: + # If there are no transactions for this product, the query should return an empty list, not None. + raise RuntimeError( + f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id})." + ) + + return [ + ProductPriceLogEntry( + transaction=row[0], + price=row.price, + product_count=row.product_count, + ) + for row in result + ] + + +# TODO: add until datetime parameter + + +def product_price( + sql_session: Session, + product: Product, + use_cache: bool = True, + until: Transaction | None = None, + until_inclusive: bool = True, + include_interest: bool = False, +) -> int: + """ + Calculates the price of a product. + """ + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + if until is not None and until.id is None: + raise ValueError("'until' transaction must be persisted in the database.") + + recursive_cte = _product_price_query( + product.id, + use_cache=use_cache, + until=until.time if until else None, + until_inclusive=until_inclusive, + ) + + # TODO: optionally verify subresults: + # - product_count should never be negative (but this happens sometimes, so just a warning) + # - price should never be negative + + result = sql_session.scalars( + select(recursive_cte.c.price) + .order_by(recursive_cte.c.i.desc()) + .limit(CONST_ONE) + .offset(CONST_ZERO) + ).one_or_none() + + if result is None: + # If there are no transactions for this product, the query should return 0, not None. + raise RuntimeError( + f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})." + ) + + if include_interest: + interest_rate = ( + sql_session.scalar( + select(Transaction.interest_rate_percent) + .where( + Transaction.type_ == TransactionType.ADJUST_INTEREST, + CONST_TRUE if until is None else Transaction.time <= until.time, + ) + .order_by(Transaction.time.desc()) + .limit(CONST_ONE) + ) + or DEFAULT_INTEREST_RATE_PERCENTAGE + ) + result = math.ceil(result * interest_rate / 100) + + return result diff --git a/dibbler/queries/product_stock.py b/dibbler/queries/product_stock.py new file mode 100644 index 0000000..4482acf --- /dev/null +++ b/dibbler/queries/product_stock.py @@ -0,0 +1,113 @@ +from datetime import datetime +from typing import Tuple + +from sqlalchemy import ( + BindParameter, + Select, + case, + func, + select, +) +from sqlalchemy.orm import Session + +from dibbler.lib.query_helpers import CONST_TRUE +from dibbler.models import ( + Product, + Transaction, + TransactionType, +) + + +def _product_stock_query( + product_id: BindParameter[int] | int, + use_cache: bool = True, + until: BindParameter[datetime] | datetime | None = None, + until_inclusive: bool = True, +) -> Select[Tuple[int]]: + """ + The inner query for calculating the product stock. + """ + + if use_cache: + print("WARNING: Using cache for product stock query is not implemented yet.") + + if isinstance(product_id, int): + product_id = BindParameter("product_id", value=product_id) + + if isinstance(until, datetime): + until = BindParameter("until", value=until) + + query = select( + func.sum( + case( + ( + Transaction.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), + Transaction.product_count, + ), + ( + Transaction.type_ == TransactionType.ADJUST_STOCK.as_literal_column(), + Transaction.product_count, + ), + ( + Transaction.type_ == TransactionType.BUY_PRODUCT.as_literal_column(), + -Transaction.product_count, + ), + ( + Transaction.type_ == TransactionType.JOINT.as_literal_column(), + -Transaction.product_count, + ), + ( + Transaction.type_ == TransactionType.THROW_PRODUCT.as_literal_column(), + -Transaction.product_count, + ), + else_=0, + ) + ).label("stock") + ).where( + Transaction.type_.in_( + [ + TransactionType.ADD_PRODUCT.as_literal_column(), + TransactionType.ADJUST_STOCK.as_literal_column(), + TransactionType.BUY_PRODUCT.as_literal_column(), + TransactionType.JOINT.as_literal_column(), + TransactionType.THROW_PRODUCT.as_literal_column(), + ] + ), + Transaction.product_id == product_id, + (Transaction.time <= until if until_inclusive else Transaction.time < until) + if until is not None + else CONST_TRUE, + ) + + return query + + +# TODO: add until transaction parameter + + +def product_stock( + sql_session: Session, + product: Product, + use_cache: bool = True, + until: datetime | None = None, + until_inclusive: bool = True, +) -> int: + """ + Returns the number of products in stock. + + If 'until' is given, only transactions up to that time are considered. + """ + + if product.id is None: + raise ValueError("Product must be persisted in the database.") + + query = _product_stock_query( + product_id=product.id, + use_cache=use_cache, + until=until, + until_inclusive=until_inclusive, + ) + + result = sql_session.scalars(query).one_or_none() + + return result or 0 diff --git a/dibbler/queries/search_product.py b/dibbler/queries/search_product.py new file mode 100644 index 0000000..b017d0d --- /dev/null +++ b/dibbler/queries/search_product.py @@ -0,0 +1,39 @@ +from sqlalchemy import and_, literal, not_, or_, select +from sqlalchemy.orm import Session + +from dibbler.models import Product + + +def search_product( + string: str, + sql_session: Session, + find_hidden_products=False, +) -> Product | list[Product]: + exact_match = sql_session.scalars( + select(Product).where( + or_( + Product.bar_code == string, + and_( + Product.name == string, + literal(True) if find_hidden_products else not_(Product.hidden), + ), + ) + ) + ).first() + + if exact_match: + return exact_match + + product_list = sql_session.scalars( + select(Product).where( + or_( + Product.bar_code.ilike(f"%{string}%"), + and_( + Product.name.ilike(f"%{string}%"), + literal(True) if find_hidden_products else not_(Product.hidden), + ), + ) + ) + ).all() + + return list(product_list) diff --git a/dibbler/queries/search_user.py b/dibbler/queries/search_user.py new file mode 100644 index 0000000..5cc9c00 --- /dev/null +++ b/dibbler/queries/search_user.py @@ -0,0 +1,36 @@ +from sqlalchemy import or_, select +from sqlalchemy.orm import Session + +from dibbler.models import User + + +def search_user( + string: str, + sql_session: Session, +) -> User | list[User]: + string = string.lower() + + exact_match = sql_session.scalars( + select(User).where( + or_( + User.name == string, + User.card == string, + User.rfid == string, + ) + ) + ).first() + + if exact_match: + return exact_match + + user_list = sql_session.scalars( + select(User).where( + or_( + User.name.ilike(f"%{string}%"), + User.card.ilike(f"%{string}%"), + User.rfid.ilike(f"%{string}%"), + ) + ) + ).all() + + return list(user_list) diff --git a/dibbler/queries/transaction_log.py b/dibbler/queries/transaction_log.py new file mode 100644 index 0000000..2bd3557 --- /dev/null +++ b/dibbler/queries/transaction_log.py @@ -0,0 +1,95 @@ +from datetime import datetime + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from dibbler.models import ( + Product, + Transaction, + TransactionType, + User, +) + + +# TODO: should this include full joint transactions that involve a user? +# TODO: should this involve throw-away transactions that affects a user? +def transaction_log( + sql_session: Session, + user: User | None = None, + product: Product | None = None, + exclusive_after: bool = False, + after_time: datetime | None = None, + after_transaction_id: int | None = None, + exclusive_before: bool = False, + before_time: datetime | None = None, + before_transaction_id: int | None = None, + transaction_type: list[TransactionType] | None = None, + negate_transaction_type_filter: bool = False, + limit: int | None = None, +) -> list[Transaction]: + """ + Retrieve the transaction log, optionally filtered. + + Only one of `user` or `product` may be specified. + Only one of `after_time` or `after_transaction_id` may be specified. + Only one of `before_time` or `before_transaction_id` may be specified. + + The before and after filters are inclusive by default. + """ + + if not (user is None or product is None): + raise ValueError("Cannot filter by both user and product.") + + if user is not None and user.id is None: + raise ValueError("User must be persisted in the database.") + + if product is not None and product.id is None: + raise ValueError("Product must be persisted in the database.") + + if not (after_time is None or after_transaction_id is None): + raise ValueError("Cannot filter by both after_time and after_transaction_id.") + + if not (before_time is None or before_transaction_id is None): + raise ValueError("Cannot filter by both before_time and before_transaction_id.") + + query = select(Transaction) + if user is not None: + query = query.where(Transaction.user_id == user.id) + if product is not None: + query = query.where(Transaction.product_id == product.id) + + if after_time is not None: + if exclusive_after: + query = query.where(Transaction.time > after_time) + else: + query = query.where(Transaction.time >= after_time) + if after_transaction_id is not None: + if exclusive_after: + query = query.where(Transaction.id > after_transaction_id) + else: + query = query.where(Transaction.id >= after_transaction_id) + + if before_time is not None: + if exclusive_before: + query = query.where(Transaction.time < before_time) + else: + query = query.where(Transaction.time <= before_time) + if before_transaction_id is not None: + if exclusive_before: + query = query.where(Transaction.id < before_transaction_id) + else: + query = query.where(Transaction.id <= before_transaction_id) + + if transaction_type is not None: + if negate_transaction_type_filter: + query = query.where(~Transaction.type_.in_(transaction_type)) + else: + query = query.where(Transaction.type_.in_(transaction_type)) + + if limit is not None: + query = query.limit(limit) + + query = query.order_by(Transaction.time.asc(), Transaction.id.asc()) + result = sql_session.scalars(query).all() + + return list(result) diff --git a/dibbler/queries/user_balance.py b/dibbler/queries/user_balance.py new file mode 100644 index 0000000..297f8bc --- /dev/null +++ b/dibbler/queries/user_balance.py @@ -0,0 +1,520 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Tuple + +from sqlalchemy import ( + CTE, + BindParameter, + Float, + Integer, + Select, + and_, + case, + cast, + column, + func, + or_, + select, +) +from sqlalchemy.orm import Session, aliased + +from dibbler.lib.query_helpers import CONST_NONE, CONST_ONE, CONST_TRUE, CONST_ZERO, const +from dibbler.models import ( + Transaction, + TransactionType, + User, +) +from dibbler.models.Transaction import ( + DEFAULT_INTEREST_RATE_PERCENTAGE, + DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE, + DEFAULT_PENALTY_THRESHOLD, +) +from dibbler.queries.product_price import _product_price_query + + +def _joint_transaction_query( + user_id: BindParameter[int] | int, + use_cache: bool = True, + until: BindParameter[datetime] | None = None, + until_inclusive: bool = True, +) -> Select[Tuple[int, int, int]]: + """ + The inner query for getting joint transactions relevant to a user. + + This scans for JOINT_BUY_PRODUCT transactions made by the user, + then finds the corresponding JOINT transactions, and counts how many "shares" + of the joint transaction the user has, as well as the total number of shares. + """ + + # First, select all joint buy product transactions for the given user + # sub_joint_transaction = aliased(Transaction, name="right_trx") + sub_joint_transaction = ( + select(Transaction.joint_transaction_id.distinct().label("joint_transaction_id")) + .where( + Transaction.type_ == TransactionType.JOINT_BUY_PRODUCT.as_literal_column(), + Transaction.user_id == user_id, + ) + .subquery("sub_joint_transaction") + ) + + # Join those with their main joint transaction + # (just use Transaction) + + # Then, count how many users are involved in each joint transaction + joint_transaction_count = aliased(Transaction, name="count_trx") + + joint_transaction = ( + select( + Transaction.id, + # Shares the user has in the transaction, + func.sum( + case( + (joint_transaction_count.user_id == user_id, CONST_ONE), + else_=CONST_ZERO, + ) + ).label("user_shares"), + # The total number of shares in the transaction, + func.count(joint_transaction_count.id).label("user_count"), + ) + .select_from(sub_joint_transaction) + .join( + Transaction, + onclause=Transaction.id == sub_joint_transaction.c.joint_transaction_id, + ) + .join( + joint_transaction_count, + onclause=joint_transaction_count.joint_transaction_id == Transaction.id, + ) + .where( + ( + sub_joint_transaction.c.time <= until + if until_inclusive + else sub_joint_transaction.c.time < until + ) + if until is not None + else CONST_TRUE, + ) + .group_by(joint_transaction_count.joint_transaction_id) + ) + + return joint_transaction + + +def _non_joint_transaction_query( + user_id: BindParameter[int] | int, + use_cache: bool = True, + until: BindParameter[datetime] | None = None, + until_inclusive: bool = True, +) -> Select[Tuple[int, None, None]]: + """ + The inner query for getting non-joint transactions relevant to a user. + """ + + query = select( + Transaction.id, + CONST_NONE.label("user_shares"), + CONST_NONE.label("user_count"), + ).where( + or_( + and_( + Transaction.user_id == user_id, + Transaction.type_.in_( + [ + TransactionType.ADD_PRODUCT.as_literal_column(), + TransactionType.ADJUST_BALANCE.as_literal_column(), + TransactionType.BUY_PRODUCT.as_literal_column(), + TransactionType.TRANSFER.as_literal_column(), + ] + ), + ), + and_( + Transaction.type_ == TransactionType.TRANSFER.as_literal_column(), + Transaction.transfer_user_id == user_id, + ), + Transaction.type_.in_( + [ + TransactionType.THROW_PRODUCT.as_literal_column(), + TransactionType.ADJUST_INTEREST.as_literal_column(), + TransactionType.ADJUST_PENALTY.as_literal_column(), + ] + ), + ), + (Transaction.time <= until if until_inclusive else Transaction.time < until) + if until is not None + else CONST_TRUE, + ) + + return query + + +def _product_cost_expression( + product_count_column, + product_id_column, + interest_rate_percent_column, + user_balance_column, + penalty_threshold_column, + penalty_multiplier_percent_column, + use_cache: bool = True, + until: BindParameter[datetime] | None = None, + until_inclusive: bool = True, + cte_name: str = "product_price_cte", + trx_subset_name: str = "product_price_trx_subset", +): + expression = ( + product_count_column + # Price of a single product, accounted for penalties and interest. + * cast( + func.ceil( + # TODO: This can get quite expensive real quick, so we should do some caching of the + # product prices somehow. + # Base price + ( + # FIXME: this always returns 0 for some reason... + select(cast(column("price"), Float)) + .select_from( + _product_price_query( + product_id_column, + use_cache=use_cache, + until=until, + until_inclusive=until_inclusive, + cte_name=cte_name, + trx_subset_name=trx_subset_name, + ) + ) + .order_by(column("i").desc()) + .limit(CONST_ONE) + ).scalar_subquery() + # TODO: should interest be applied before or after the penalty multiplier? + # at the moment of writing, after sound right, but maybe ask someone? + # Interest + * (cast(interest_rate_percent_column, Float) / const(100)) + # TODO: these should be added together, not multiplied, see specification + # Penalty + * case( + ( + user_balance_column < penalty_threshold_column, + (cast(penalty_multiplier_percent_column, Float) / const(100)), + ), + else_=const(1.0), + ) + ), + Integer, + ) + ) + + return expression + + +def _user_balance_query( + user_id: BindParameter[int] | int, + use_cache: bool = True, + until: BindParameter[datetime] | datetime | None = None, + until_inclusive: bool = True, + cte_name: str = "rec_cte", + trx_subset_name: str = "trx_subset", +) -> CTE: + """ + The inner query for calculating the user's balance. + """ + + if use_cache: + print("WARNING: Using cache for user balance query is not implemented yet.") + + if isinstance(user_id, int): + user_id = BindParameter("user_id", value=user_id) + + if isinstance(until, datetime): + until = BindParameter("until", value=until) + + initial_element = select( + CONST_ZERO.label("i"), + CONST_ZERO.label("time"), + CONST_NONE.label("transaction_id"), + CONST_ZERO.label("balance"), + const(DEFAULT_INTEREST_RATE_PERCENTAGE).label("interest_rate_percent"), + const(DEFAULT_PENALTY_THRESHOLD).label("penalty_threshold"), + const(DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE).label("penalty_multiplier_percent"), + ) + + recursive_cte = initial_element.cte(name=cte_name, recursive=True) + + trx_subset_subset = ( + _non_joint_transaction_query( + user_id=user_id, + use_cache=use_cache, + until=until, + until_inclusive=until_inclusive, + ) + .union_all( + _joint_transaction_query( + user_id=user_id, + use_cache=use_cache, + until=until, + until_inclusive=until_inclusive, + ) + ) + .subquery(f"{trx_subset_name}_subset") + ) + + # Subset of transactions that we'll want to iterate over. + trx_subset = ( + select( + func.row_number().over(order_by=Transaction.time.asc()).label("i"), + Transaction.id, + Transaction.amount, + Transaction.interest_rate_percent, + Transaction.penalty_multiplier_percent, + Transaction.penalty_threshold, + Transaction.product_count, + Transaction.product_id, + Transaction.time, + Transaction.transfer_user_id, + Transaction.type_, + trx_subset_subset.c.user_shares, + trx_subset_subset.c.user_count, + ) + .select_from(trx_subset_subset) + .join( + Transaction, + onclause=Transaction.id == trx_subset_subset.c.id, + ) + .order_by(Transaction.time.asc()) + .subquery(trx_subset_name) + ) + + recursive_elements = ( + select( + trx_subset.c.i, + trx_subset.c.time, + trx_subset.c.id.label("transaction_id"), + case( + # Adjusts balance -> balance gets adjusted + ( + trx_subset.c.type_ == TransactionType.ADJUST_BALANCE.as_literal_column(), + recursive_cte.c.balance + trx_subset.c.amount, + ), + # Adds a product -> balance increases + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT.as_literal_column(), + recursive_cte.c.balance + trx_subset.c.amount, + ), + # Buys a product -> balance decreases + ( + trx_subset.c.type_ == TransactionType.BUY_PRODUCT.as_literal_column(), + recursive_cte.c.balance + - _product_cost_expression( + product_count_column=trx_subset.c.product_count, + product_id_column=trx_subset.c.product_id, + interest_rate_percent_column=recursive_cte.c.interest_rate_percent, + user_balance_column=recursive_cte.c.balance, + penalty_threshold_column=recursive_cte.c.penalty_threshold, + penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent, + use_cache=use_cache, + until=until, + until_inclusive=until_inclusive, + cte_name=f"{cte_name}_price", + trx_subset_name=f"{trx_subset_name}_price", + ), + ), + # Joint transaction -> balance decreases proportionally + ( + trx_subset.c.type_ == TransactionType.JOINT.as_literal_column(), + recursive_cte.c.balance + - func.ceil( + _product_cost_expression( + product_count_column=trx_subset.c.product_count, + product_id_column=trx_subset.c.product_id, + interest_rate_percent_column=recursive_cte.c.interest_rate_percent, + user_balance_column=recursive_cte.c.balance, + penalty_threshold_column=recursive_cte.c.penalty_threshold, + penalty_multiplier_percent_column=recursive_cte.c.penalty_multiplier_percent, + use_cache=use_cache, + until=until, + until_inclusive=until_inclusive, + cte_name=f"{cte_name}_joint_price", + trx_subset_name=f"{trx_subset_name}_joint_price", + ) + # TODO: move this inside of the product cost expression + * trx_subset.c.user_shares + / trx_subset.c.user_count + ), + ), + # Transfers money to self -> balance increases + ( + and_( + trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(), + trx_subset.c.transfer_user_id == user_id, + ), + recursive_cte.c.balance + trx_subset.c.amount, + ), + # Transfers money from self -> balance decreases + ( + and_( + trx_subset.c.type_ == TransactionType.TRANSFER.as_literal_column(), + trx_subset.c.transfer_user_id != user_id, + ), + recursive_cte.c.balance - trx_subset.c.amount, + ), + # Throws a product -> if the user is considered to have bought it, balance increases + # TODO: # ( + # trx_subset.c.type_ == TransactionType.THROW_PRODUCT, + # recursive_cte.c.balance + trx_subset.c.amount, + # ), + # Interest adjustment -> balance stays the same + # Penalty adjustment -> balance stays the same + else_=recursive_cte.c.balance, + ).label("balance"), + case( + ( + trx_subset.c.type_ == TransactionType.ADJUST_INTEREST.as_literal_column(), + trx_subset.c.interest_rate_percent, + ), + else_=recursive_cte.c.interest_rate_percent, + ).label("interest_rate_percent"), + case( + ( + trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(), + trx_subset.c.penalty_threshold, + ), + else_=recursive_cte.c.penalty_threshold, + ).label("penalty_threshold"), + case( + ( + trx_subset.c.type_ == TransactionType.ADJUST_PENALTY.as_literal_column(), + trx_subset.c.penalty_multiplier_percent, + ), + else_=recursive_cte.c.penalty_multiplier_percent, + ).label("penalty_multiplier_percent"), + ) + .select_from(trx_subset) + .where(trx_subset.c.i == recursive_cte.c.i + CONST_ONE) + ) + + return recursive_cte.union_all(recursive_elements) + + +# TODO: create a function for the log that pretty prints the log entries +# for debugging purposes + + +@dataclass +class UserBalanceLogEntry: + transaction: Transaction + balance: int + interest_rate_percent: int + penalty_threshold: int + penalty_multiplier_percent: int + + def is_penalized(self) -> bool: + """ + Returns whether this exact transaction is penalized. + """ + + return False + + # return self.transaction.type_ == TransactionType.BUY_PRODUCT and prev? + + +# TODO: add until datetime parameter + + +def user_balance_log( + sql_session: Session, + user: User, + use_cache: bool = True, + until: Transaction | None = None, + until_inclusive: bool = True, +) -> list[UserBalanceLogEntry]: + """ + Returns a log of the user's balance over time, including interest and penalty adjustments. + + If 'until' is given, only transactions up to that time are considered. + """ + + if user.id is None: + raise ValueError("User must be persisted in the database.") + + if until is not None and until.id is None: + raise ValueError("'until' transaction must be persisted in the database.") + + recursive_cte = _user_balance_query( + user.id, + use_cache=use_cache, + until=until.time if until else None, + until_inclusive=until_inclusive, + ) + + result = sql_session.execute( + select( + Transaction, + recursive_cte.c.balance, + recursive_cte.c.interest_rate_percent, + recursive_cte.c.penalty_threshold, + recursive_cte.c.penalty_multiplier_percent, + ) + .select_from(recursive_cte) + .join( + Transaction, + onclause=Transaction.id == recursive_cte.c.transaction_id, + ) + .order_by(recursive_cte.c.i.asc()) + ).all() + + if result is None: + # If there are no transactions for this user, the query should return 0, not None. + raise RuntimeError( + f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})." + ) + + return [ + UserBalanceLogEntry( + transaction=row[0], + balance=row.balance, + interest_rate_percent=row.interest_rate_percent, + penalty_threshold=row.penalty_threshold, + penalty_multiplier_percent=row.penalty_multiplier_percent, + ) + for row in result + ] + + +# TODO: add until datetime parameter + + +def user_balance( + sql_session: Session, + user: User, + use_cache: bool = True, + until: Transaction | None = None, + until_inclusive: bool = True, +) -> int: + """ + Calculates the balance of a user. + + If 'until' is given, only transactions up to that time are considered. + """ + + if user.id is None: + raise ValueError("User must be persisted in the database.") + + recursive_cte = _user_balance_query( + user.id, + use_cache=use_cache, + until=until.time if until else None, + until_inclusive=until_inclusive, + ) + + result = sql_session.scalar( + select(recursive_cte.c.balance) + .order_by(recursive_cte.c.i.desc()) + .limit(CONST_ONE) + .offset(CONST_ZERO) + ) + + if result is None: + # If there are no transactions for this user, the query should return 0, not None. + raise RuntimeError( + f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})." + ) + + return result diff --git a/dibbler/queries/user_products.py b/dibbler/queries/user_products.py new file mode 100644 index 0000000..806f877 --- /dev/null +++ b/dibbler/queries/user_products.py @@ -0,0 +1,10 @@ +# This absoulutely needs a cache, else we can't stop recursing until we know all owners for all products... +# +# Since we know that the non-owned products will not get renowned by the user by other means, +# we can just check for ownership on the products that have an ADD_PRODUCT transaction for the user. +# between now and the cached time. +# +# However, the opposite way is more difficult. The cache will store which products are owned by which users, +# but we still need to check if the user passes out of ownership for the item, without needing to check past +# the cache time. Maybe we also need to store the queue number(s) per user/product combo in the cache? What if +# a user has products multiple places in the queue, interleaved with other users? diff --git a/tests/queries/__init__.py b/tests/queries/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/queries/test_add_product.py b/tests/queries/test_add_product.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/queries/test_add_user.py b/tests/queries/test_add_user.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/queries/test_adjust_interest.py b/tests/queries/test_adjust_interest.py new file mode 100644 index 0000000..5da393c --- /dev/null +++ b/tests/queries/test_adjust_interest.py @@ -0,0 +1,84 @@ +from datetime import datetime + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, User +from dibbler.queries import adjust_interest, current_interest + + +def insert_test_data(sql_session: Session) -> User: + user = User("Test User") + sql_session.add(user) + sql_session.commit() + + return user + + +def test_adjust_interest_unitialized_user(sql_session: Session) -> None: + user = User("Uninitialized User") + + with pytest.raises(ValueError, match="User must be persisted in the database."): + adjust_interest( + sql_session, + user=user, + new_interest=4, + message="Attempting to adjust interest for uninitialized user", + ) + + +def test_adjust_interest_no_history(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + adjust_interest( + sql_session, + user=user, + new_interest=3, + message="Setting initial interest rate", + ) + sql_session.commit() + + current_interest_rate = current_interest(sql_session) + + assert current_interest_rate == 3 + + +def test_adjust_interest_existing_history(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 9, 0, 0), + user_id=user.id, + interest_rate_percent=5, + message="Initial interest rate", + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + current_interest_rate = current_interest(sql_session) + assert current_interest_rate == 5 + + adjust_interest( + sql_session, + user=user, + new_interest=2, + message="Adjusting interest rate", + ) + sql_session.commit() + + current_interest_rate = current_interest(sql_session) + assert current_interest_rate == 2 + + +def test_adjust_interest_negative_failure(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + with pytest.raises(ValueError, match="Interest rate cannot be negative"): + adjust_interest( + sql_session, + user=user, + new_interest=-1, + message="Attempting to set negative interest rate", + ) diff --git a/tests/queries/test_adjust_penalty.py b/tests/queries/test_adjust_penalty.py new file mode 100644 index 0000000..3867eee --- /dev/null +++ b/tests/queries/test_adjust_penalty.py @@ -0,0 +1,166 @@ +from datetime import datetime + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, User +from dibbler.models.Transaction import ( + DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE, + DEFAULT_PENALTY_THRESHOLD, +) +from dibbler.queries import adjust_penalty, current_penalty + + +def insert_test_data(sql_session: Session) -> User: + user = User("Test User") + sql_session.add(user) + sql_session.commit() + + return user + + +def test_adjust_penalty_unitialized_user(sql_session: Session) -> None: + user = User("Uninitialized User") + + with pytest.raises(ValueError, match="User must be persisted in the database."): + adjust_penalty( + sql_session, + user=user, + new_penalty=-100, + new_penalty_multiplier=110, + message="Attempting to adjust penalty for uninitialized user", + ) + + +def test_adjust_penalty_no_history(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + adjust_penalty( + sql_session, + user=user, + new_penalty=-200, + message="Setting initial interest rate", + ) + sql_session.commit() + + (penalty, multiplier) = current_penalty(sql_session) + + assert penalty == -200 + assert multiplier == DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE + + +def test_adjust_penalty_multiplier_no_history(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + adjust_penalty( + sql_session, + user=user, + new_penalty_multiplier=125, + message="Setting initial interest rate", + ) + sql_session.commit() + + (penalty, multiplier) = current_penalty(sql_session) + + assert penalty == DEFAULT_PENALTY_THRESHOLD + assert multiplier == 125 + + +def test_adjust_penalty_multiplier_less_than_100_fail(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + adjust_penalty( + sql_session, + user=user, + new_penalty_multiplier=100, + message="Setting initial interest rate", + ) + sql_session.commit() + + (_, multiplier) = current_penalty(sql_session) + + assert multiplier == 100 + + with pytest.raises(ValueError, match="Penalty multiplier cannot be less than 100%"): + adjust_penalty( + sql_session, + user=user, + new_penalty_multiplier=99, + message="Setting initial interest rate", + ) + + +def test_adjust_penalty_existing_history(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_penalty( + time=datetime(2024, 1, 1, 10, 0, 0), + user_id=user.id, + penalty_threshold=-150, + penalty_multiplier_percent=110, + message="Initial penalty settings", + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + (penalty, _) = current_penalty(sql_session) + assert penalty == -150 + + adjust_penalty( + sql_session, + user=user, + new_penalty=-250, + message="Adjusting penalty threshold", + ) + sql_session.commit() + + (penalty, _) = current_penalty(sql_session) + assert penalty == -250 + + +def test_adjust_penalty_multiplier_existing_history(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_penalty( + time=datetime(2024, 1, 1, 10, 0, 0), + user_id=user.id, + penalty_threshold=-150, + penalty_multiplier_percent=110, + message="Initial penalty settings", + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + (_, multiplier) = current_penalty(sql_session) + assert multiplier == 110 + + adjust_penalty( + sql_session, + user=user, + new_penalty_multiplier=130, + message="Adjusting penalty multiplier", + ) + sql_session.commit() + (_, multiplier) = current_penalty(sql_session) + assert multiplier == 130 + + +def test_adjust_penalty_and_multiplier(sql_session: Session) -> None: + user = insert_test_data(sql_session) + + adjust_penalty( + sql_session, + user=user, + new_penalty=-300, + new_penalty_multiplier=150, + message="Setting both penalty and multiplier", + ) + sql_session.commit() + + (penalty, multiplier) = current_penalty(sql_session) + assert penalty == -300 + assert multiplier == 150 diff --git a/tests/queries/test_current_interest.py b/tests/queries/test_current_interest.py new file mode 100644 index 0000000..0cac74c --- /dev/null +++ b/tests/queries/test_current_interest.py @@ -0,0 +1,35 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models.Transaction import DEFAULT_INTEREST_RATE_PERCENTAGE +from dibbler.models import Transaction, User +from dibbler.queries import current_interest + + +def test_current_interest_no_history(sql_session: Session) -> None: + assert current_interest(sql_session) == DEFAULT_INTEREST_RATE_PERCENTAGE + + +def test_current_interest_with_history(sql_session: Session) -> None: + user = User("Admin User") + sql_session.add(user) + sql_session.commit() + + transactions = [ + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 10, 0, 0), + interest_rate_percent=5, + user_id=user.id, + ), + Transaction.adjust_interest( + time=datetime(2023, 11, 1, 10, 0, 0), + interest_rate_percent=7, + user_id=user.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + assert current_interest(sql_session) == 7 diff --git a/tests/queries/test_current_penalty.py b/tests/queries/test_current_penalty.py new file mode 100644 index 0000000..81fe975 --- /dev/null +++ b/tests/queries/test_current_penalty.py @@ -0,0 +1,42 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import Transaction, User +from dibbler.models.Transaction import ( + DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE, + DEFAULT_PENALTY_THRESHOLD, +) +from dibbler.queries import current_penalty + + +def test_current_penalty_no_history(sql_session: Session) -> None: + assert current_penalty(sql_session) == ( + DEFAULT_PENALTY_THRESHOLD, + DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE, + ) + + +def test_current_penalty_with_history(sql_session: Session) -> None: + user = User("Admin User") + sql_session.add(user) + sql_session.commit() + + transactions = [ + Transaction.adjust_penalty( + time=datetime(2023, 10, 1, 10, 0, 0), + penalty_threshold=-200, + penalty_multiplier_percent=150, + user_id=user.id, + ), + Transaction.adjust_penalty( + time=datetime(2023, 10, 2, 10, 0, 0), + penalty_threshold=-300, + penalty_multiplier_percent=200, + user_id=user.id, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + assert current_penalty(sql_session) == (-300, 200) diff --git a/tests/queries/test_joint_buy_product.py b/tests/queries/test_joint_buy_product.py new file mode 100644 index 0000000..5bcff27 --- /dev/null +++ b/tests/queries/test_joint_buy_product.py @@ -0,0 +1,198 @@ +from datetime import datetime + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User +from dibbler.queries import joint_buy_product + + +def insert_test_data(sql_session: Session) -> tuple[User, User, User, Product]: + user1 = User("Test User 1") + user2 = User("Test User 2") + user3 = User("Test User 3") + product = Product("1234567890123", "Test Product") + + sql_session.add_all([user1, user2, user3, product]) + sql_session.commit() + + transactions = [ + Transaction.add_product( + user_id=user1.id, + product_id=product.id, + amount=30, + per_product=10, + product_count=3, + time=datetime(2024, 1, 1, 10, 0, 0), + ) + ] + + sql_session.add_all(transactions) + sql_session.commit() + + return user1, user2, user3, product + + +def test_joint_buy_product_uninitialized_product(sql_session: Session) -> None: + user = User("Test User 1") + sql_session.add(user) + sql_session.commit() + + product = Product("1234567890123", "Uninitialized Product") + + with pytest.raises(ValueError): + joint_buy_product( + sql_session, + instigator=user, + users=[user], + product=product, + product_count=1, + ) + + +def test_joint_buy_product_no_users(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + with pytest.raises(ValueError): + joint_buy_product( + sql_session, + instigator=user, + users=[], + product=product, + product_count=1, + ) + + +def test_joint_buy_product_uninitialized_instigator(sql_session: Session) -> None: + user, user2, _, product = insert_test_data(sql_session) + + uninitialized_user = User("Uninitialized User") + with pytest.raises(ValueError): + joint_buy_product( + sql_session, + instigator=uninitialized_user, + users=[user, user2], + product=product, + product_count=1, + ) + + +def test_joint_buy_product_uninitialized_user_in_list(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + uninitialized_user = User("Uninitialized User") + with pytest.raises(ValueError): + joint_buy_product( + sql_session, + instigator=user, + users=[user, uninitialized_user], + product=product, + product_count=1, + ) + + +def test_joint_buy_product_invalid_product_count(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + with pytest.raises(ValueError): + joint_buy_product( + sql_session, + instigator=user, + users=[user], + product=product, + product_count=0, + ) + + with pytest.raises(ValueError): + joint_buy_product( + sql_session, + instigator=user, + users=[user], + product=product, + product_count=-1, + ) + + +def test_joint_single_user(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + joint_buy_product( + sql_session, + instigator=user, + users=[user], + product=product, + product_count=1, + ) + + +def test_joint_buy_product(sql_session: Session) -> None: + user, user2, user3, product = insert_test_data(sql_session) + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user2, user3], + product=product, + product_count=1, + ) + + +def test_joint_buy_product_more_than_in_stock(sql_session: Session) -> None: + user, user2, user3, product = insert_test_data(sql_session) + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user2, user3], + product=product, + product_count=5, + ) + + +def test_joint_buy_product_out_of_stock(sql_session: Session) -> None: + user, user2, user3, product = insert_test_data(sql_session) + + transactions = [ + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=3, + time=datetime(2024, 1, 2, 10, 0, 0), + ) + ] + + sql_session.add_all(transactions) + sql_session.commit() + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user2, user3], + product=product, + product_count=10, + ) + + +def test_joint_buy_product_duplicate_user(sql_session: Session) -> None: + user, user2, _, product = insert_test_data(sql_session) + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user, user2], + product=product, + product_count=1, + ) + + +def test_joint_buy_product_non_involved_instigator(sql_session: Session) -> None: + user, user2, user3, product = insert_test_data(sql_session) + + with pytest.raises(ValueError): + joint_buy_product( + sql_session, + instigator=user, + users=[user2, user3], + product=product, + product_count=1, + ) diff --git a/tests/queries/test_product_owners.py b/tests/queries/test_product_owners.py new file mode 100644 index 0000000..185f4bf --- /dev/null +++ b/tests/queries/test_product_owners.py @@ -0,0 +1,311 @@ +from datetime import datetime +from pprint import pprint + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Product, User +from dibbler.models.Transaction import Transaction +from dibbler.queries import product_owners, product_owners_log, product_stock + + +def insert_test_data(sql_session: Session) -> tuple[Product, User]: + user = User("testuser") + product = Product("1234567890123", "Test Product") + + sql_session.add(user) + sql_session.add(product) + + sql_session.commit() + + return product, user + + +def test_product_owners_unitilialized_product(sql_session: Session) -> None: + user = User("testuser") + sql_session.add(user) + sql_session.commit() + + product = Product("1234567890123", "Uninitialized Product") + + with pytest.raises(ValueError): + product_owners(sql_session, product) + + +def test_product_owners_no_transactions(sql_session: Session) -> None: + product, _ = insert_test_data(sql_session) + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [] + + +def test_product_owners_add_products(sql_session: Session) -> None: + product, user = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=30, + per_product=10, + product_count=3, + time=datetime(2024, 1, 1, 10, 0, 0), + ) + ] + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [user, user, user] + + +def test_product_owners_add_and_buy_products(sql_session: Session) -> None: + product, user = insert_test_data(sql_session) + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=30, + per_product=10, + product_count=3, + time=datetime(2024, 1, 1, 10, 0, 0), + ), + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=1, + time=datetime(2024, 1, 2, 10, 0, 0), + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [user, user] + + +def test_product_owners_add_and_throw_products(sql_session: Session) -> None: + product, user = insert_test_data(sql_session) + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=40, + per_product=10, + product_count=4, + time=datetime(2024, 1, 1, 10, 0, 0), + ), + Transaction.throw_product( + user_id=user.id, + product_id=product.id, + product_count=2, + time=datetime(2024, 1, 2, 10, 0, 0), + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [user, user] + + +def test_product_owners_multiple_users(sql_session: Session) -> None: + product, user1 = insert_test_data(sql_session) + user2 = User("testuser2") + sql_session.add(user2) + sql_session.commit() + transactions = [ + Transaction.add_product( + user_id=user1.id, + product_id=product.id, + amount=20, + per_product=10, + product_count=2, + time=datetime(2024, 1, 1, 10, 0, 0), + ), + Transaction.add_product( + user_id=user2.id, + product_id=product.id, + amount=30, + per_product=10, + product_count=3, + time=datetime(2024, 1, 2, 10, 0, 0), + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [user2, user2, user2, user1, user1] + + +def test_product_owners_adjust_stock_down(sql_session: Session) -> None: + product, user = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=50, + per_product=10, + product_count=5, + time=datetime(2024, 1, 1, 10, 0, 0), + ), + Transaction.adjust_stock( + user_id=user.id, + product_id=product.id, + product_count=-2, + time=datetime(2024, 1, 2, 10, 0, 0), + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_owners_log(sql_session, product)) + + assert product_stock(sql_session, product) == 3 + + owners = product_owners(sql_session, product) + assert owners == [user, user, user] + + +def test_product_owners_adjust_stock_up(sql_session: Session) -> None: + product, user = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=20, + per_product=10, + product_count=2, + time=datetime(2024, 1, 1, 10, 0, 0), + ), + Transaction.adjust_stock( + user_id=user.id, + product_id=product.id, + product_count=3, + time=datetime(2024, 1, 2, 10, 0, 0), + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [user, user, None, None, None] + + +def test_product_owners_negative_stock(sql_session: Session) -> None: + product, user = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=10, + per_product=10, + product_count=1, + time=datetime(2024, 1, 1, 10, 0, 0), + ), + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=2, + time=datetime(2024, 1, 2, 10, 0, 0), + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + owners = product_owners(sql_session, product) + assert owners == [] + + +def test_product_owners_add_products_from_negative_stock(sql_session: Session) -> None: + product, user = insert_test_data(sql_session) + + transactions = [ + Transaction.buy_product( + user_id=user.id, + product_id=product.id, + product_count=2, + time=datetime(2024, 1, 1, 10, 0, 0), + ), + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=30, + per_product=10, + product_count=3, + time=datetime(2024, 1, 2, 10, 0, 0), + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [user] + + +def test_product_owners_interleaved_users(sql_session: Session) -> None: + product, user1 = insert_test_data(sql_session) + user2 = User("testuser2") + sql_session.add(user2) + sql_session.commit() + + transactions = [ + Transaction.add_product( + user_id=user1.id, + product_id=product.id, + amount=20, + per_product=10, + product_count=2, + time=datetime(2024, 1, 1, 10, 0, 0), + ), + Transaction.add_product( + user_id=user2.id, + product_id=product.id, + amount=30, + per_product=10, + product_count=3, + time=datetime(2024, 1, 2, 10, 0, 0), + ), + Transaction.buy_product( + user_id=user1.id, + product_id=product.id, + product_count=1, + time=datetime(2024, 1, 3, 10, 0, 0), + ), + Transaction.add_product( + user_id=user1.id, + product_id=product.id, + amount=10, + per_product=10, + product_count=1, + time=datetime(2024, 1, 4, 10, 0, 0), + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_owners_log(sql_session, product)) + + owners = product_owners(sql_session, product) + assert owners == [user1, user2, user2, user2, user1] diff --git a/tests/queries/test_product_price.py b/tests/queries/test_product_price.py new file mode 100644 index 0000000..e94389a --- /dev/null +++ b/tests/queries/test_product_price.py @@ -0,0 +1,432 @@ +import math +from datetime import datetime +from pprint import pprint + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User +from dibbler.queries import product_price, product_price_log, joint_buy_product + +# TODO: see if we can use pytest_runtest_makereport to print the "product_price_log"s +# only on failures instead of inlining it in every test function + + +def insert_test_data(sql_session: Session) -> tuple[User, Product]: + user = User("Test User") + product = Product("1234567890123", "Test Product") + + sql_session.add(user) + sql_session.add(product) + sql_session.commit() + + return user, product + + +def test_product_price_uninitialized_product(sql_session: Session) -> None: + user = User("Test User") + sql_session.add(user) + sql_session.commit() + + product = Product("1234567890123", "Uninitialized Product") + + with pytest.raises(ValueError, match="Product must be persisted in the database."): + product_price(sql_session, product) + + +def test_product_price_no_transactions(sql_session: Session) -> None: + _, product = insert_test_data(sql_session) + + pprint(product_price_log(sql_session, product)) + + assert product_price(sql_session, product) == 0 + + +def test_product_price_basic_history(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + amount=27 * 2 - 1, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + assert product_price(sql_session, product) == 27 + + +def test_product_price_sold_out(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + amount=27 * 2 - 1, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 1), + product_count=2, + user_id=user.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + assert product_price(sql_session, product) == 27 + + +def test_product_price_interest(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 12, 0, 0), + interest_rate_percent=110, + user_id=user.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 1), + amount=27 * 2 - 1, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + product_price_ = product_price(sql_session, product) + product_price_interest = product_price(sql_session, product, include_interest=True) + + assert product_price_ == 27 + assert product_price_interest == math.ceil(27 * 1.1) + + +def test_product_price_changing_interest(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 12, 0, 0), + interest_rate_percent=110, + user_id=user.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 1), + amount=27 * 2 - 1, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 12, 0, 2), + interest_rate_percent=120, + user_id=user.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + product_price_interest = product_price(sql_session, product, include_interest=True) + assert product_price_interest == math.ceil(27 * 1.2) + + +def test_product_price_old_transaction(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 1), + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + # Price should be 27 + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 2), + amount=38 * 3, + per_product=38, + product_count=3, + user_id=user.id, + product_id=product.id, + ), + # price should be averaged upwards + ] + + sql_session.add_all(transactions) + sql_session.commit() + + until_transaction = transactions[0] + + pprint( + product_price_log( + sql_session, + product, + until=until_transaction, + ) + ) + + product_price_ = product_price( + sql_session, + product, + until=until_transaction, + ) + assert product_price_ == 27 + + +# Price goes up and gets rounded up to the next integer +def test_product_price_round_up_from_below(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 1), + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + # Price should be 27 + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 2), + amount=38 * 3, + per_product=38, + product_count=3, + user_id=user.id, + product_id=product.id, + ), + # price should be averaged upwards + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + product_price_ = product_price(sql_session, product) + assert product_price_ == math.ceil((27 * 2 + 38 * 3) / (2 + 3)) + + +# Price goes down and gets rounded up to the next integer +def test_product_price_round_up_from_above(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 1), + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + # Price should be 27 + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 2), + amount=20 * 3, + per_product=20, + product_count=3, + user_id=user.id, + product_id=product.id, + ), + # price should be averaged downwards + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + product_price_ = product_price(sql_session, product) + assert product_price_ == math.ceil((27 * 2 + 20 * 3) / (2 + 3)) + + +def test_product_price_with_negative_stock_single_addition(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 0), + amount=1, + per_product=10, + product_count=1, + user_id=user.id, + product_id=product.id, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 13, 0, 1), + product_count=10, + user_id=user.id, + product_id=product.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 2), + amount=22, + per_product=22, + product_count=1, + user_id=user.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + # Stock went subzero, price should be the last added product price + product1_price = product_price(sql_session, product) + assert product1_price == 22 + + +def test_product_price_with_negative_stock_multiple_additions(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 0), + amount=1, + per_product=10, + product_count=1, + user_id=user.id, + product_id=product.id, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 13, 0, 1), + product_count=10, + user_id=user.id, + product_id=product.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 2), + amount=22, + per_product=22, + product_count=1, + user_id=user.id, + product_id=product.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 3), + amount=29, + per_product=29, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + + # Stock went subzero, price should be the last added product price + product1_price = product_price(sql_session, product) + assert product1_price == math.ceil(29) + + +def test_product_price_joint_transactions(sql_session: Session) -> None: + user1, product = insert_test_data(sql_session) + user2 = User("Test User 2") + sql_session.add(user2) + sql_session.commit() + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + amount=30 * 3, + per_product=30, + product_count=3, + user_id=user1.id, + product_id=product.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 1), + amount=20 * 2, + per_product=20, + product_count=2, + user_id=user2.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + product_price_ = product_price(sql_session, product) + assert product_price_ == math.ceil((30 * 3 + 20 * 2) / (3 + 2)) + + joint_buy_product( + sql_session, + time=datetime(2023, 10, 1, 12, 0, 2), + instigator=user1, + users=[user1, user2], + product=product, + product_count=2, + ) + + pprint(product_price_log(sql_session, product)) + + old_product_price = product_price_ + product_price_ = product_price(sql_session, product) + assert product_price_ == old_product_price, ( + "Joint buy transactions should not affect product price" + ) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 3), + amount=25 * 4, + per_product=25, + product_count=4, + user_id=user1.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(product_price_log(sql_session, product)) + product_price_ = product_price(sql_session, product) + + # Expected state: + # Added products: + # Count: 3 + 2 = 5, Price: (30 * 3 + 20 * 2) / 5 = 26 + # Joint bought products: + # Count: 5 - 2 = 3, Price: n/a (should not affect price) + # Added products: + # Count: 3 + 4 = 7, Price: (26 * 3 + 25 * 4) / (3 + 4) = 25.57 -> 26 + + assert product_price_ == math.ceil((26 * 3 + 25 * 4) / (3 + 4)) + + +def test_product_price_until(sql_session: Session) -> None: ... diff --git a/tests/queries/test_product_stock.py b/tests/queries/test_product_stock.py new file mode 100644 index 0000000..bece79a --- /dev/null +++ b/tests/queries/test_product_stock.py @@ -0,0 +1,250 @@ +from datetime import datetime + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User +from dibbler.queries import joint_buy_product, product_stock + + +def insert_test_data(sql_session: Session) -> tuple[User, Product]: + user = User("Test User 1") + product = Product("1234567890123", "Test Product") + sql_session.add(user) + sql_session.add(product) + sql_session.commit() + return user, product + + +def test_product_stock_uninitialized_product(sql_session: Session) -> None: + user = User("Test User 1") + sql_session.add(user) + sql_session.commit() + + product = Product("1234567890123", "Uninitialized Product") + + with pytest.raises(ValueError): + product_stock(sql_session, product) + + +def test_product_stock_basic_history(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + sql_session.commit() + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + amount=10, + per_product=10, + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + assert product_stock(sql_session, product) == 1 + + +def test_product_stock_adjust_stock_up(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=50, + per_product=10, + product_count=5, + time=datetime(2024, 1, 1, 10, 0, 0), + ), + Transaction.adjust_stock( + user_id=user.id, + product_id=product.id, + product_count=2, + time=datetime(2024, 1, 2, 10, 0, 0), + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + assert product_stock(sql_session, product) == 5 + 2 + + +def test_product_stock_adjust_stock_down(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + user_id=user.id, + product_id=product.id, + amount=50, + per_product=10, + product_count=5, + time=datetime(2024, 1, 1, 10, 0, 0), + ), + Transaction.adjust_stock( + user_id=user.id, + product_id=product.id, + product_count=-2, + time=datetime(2024, 1, 2, 10, 0, 0), + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + assert product_stock(sql_session, product) == 5 - 2 + + +def test_product_stock_complex_history(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 0), + amount=27 * 2, + per_product=27, + user_id=user.id, + product_id=product.id, + product_count=2, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 13, 0, 1), + user_id=user.id, + product_id=product.id, + product_count=3, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 2), + amount=50 * 4, + per_product=50, + user_id=user.id, + product_id=product.id, + product_count=4, + ), + Transaction.adjust_stock( + time=datetime(2023, 10, 1, 15, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=3, + ), + Transaction.adjust_stock( + time=datetime(2023, 10, 1, 15, 0, 1), + user_id=user.id, + product_id=product.id, + product_count=-2, + ), + Transaction.throw_product( + time=datetime(2023, 10, 1, 15, 0, 2), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + assert product_stock(sql_session, product) == 2 - 3 + 4 + 3 - 2 - 1 + + +def test_product_stock_no_transactions(sql_session: Session) -> None: + _, product = insert_test_data(sql_session) + + assert product_stock(sql_session, product) == 0 + + +def test_negative_product_stock(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 14, 0, 0), + amount=50, + per_product=50, + user_id=user.id, + product_id=product.id, + product_count=1, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 14, 0, 1), + user_id=user.id, + product_id=product.id, + product_count=2, + ), + Transaction.adjust_stock( + time=datetime(2023, 10, 1, 16, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=-1, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + # The stock should be negative because we added and bought the product + assert product_stock(sql_session, product) == 1 - 2 - 1 + + +def test_product_stock_joint_transaction(sql_session: Session) -> None: + user1, product = insert_test_data(sql_session) + + user2 = User("Test User 2") + sql_session.add(user2) + sql_session.commit() + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 17, 0, 0), + amount=100, + per_product=100, + user_id=user1.id, + product_id=product.id, + product_count=5, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + joint_buy_product( + sql_session, + time=datetime(2023, 10, 1, 17, 0, 1), + instigator=user1, + users=[user1, user2], + product=product, + product_count=3, + ) + + assert product_stock(sql_session, product) == 5 - 3 + + +def test_product_stock_until(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + amount=10, + per_product=10, + user_id=user.id, + product_id=product.id, + product_count=1, + ), + Transaction.add_product( + time=datetime(2023, 10, 2, 12, 0, 0), + amount=20, + per_product=10, + user_id=user.id, + product_id=product.id, + product_count=2, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + assert product_stock(sql_session, product, until=datetime(2023, 10, 1, 23, 59, 59)) == 1 diff --git a/tests/queries/test_search_product.py b/tests/queries/test_search_product.py new file mode 100644 index 0000000..65e0f17 --- /dev/null +++ b/tests/queries/test_search_product.py @@ -0,0 +1,88 @@ +from sqlalchemy.orm import Session + +from dibbler.models import Product +from dibbler.queries import search_product + + +def insert_test_data(sql_session: Session) -> list[Product]: + products = [ + Product("1234567890123", "Test Product A"), + Product("2345678901234", "Test Product B"), + Product("3456789012345", "Another Product"), + Product("4567890123456", "Hidden Product", hidden=True), + ] + + sql_session.add_all(products) + sql_session.commit() + + return products + + +def test_search_product_no_products(sql_session: Session) -> None: + result = search_product("Nonexistent Product", sql_session) + + assert isinstance(result, list) + + assert len(result) == 0 + + +def test_search_product_name_exact_match(sql_session: Session) -> None: + insert_test_data(sql_session) + + result = search_product("Test Product A", sql_session) + assert isinstance(result, Product) + assert result.bar_code == "1234567890123" + + +def test_search_product_name_partial_match(sql_session: Session) -> None: + insert_test_data(sql_session) + + result = search_product("Test Product", sql_session) + assert isinstance(result, list) + assert len(result) == 2 + names = {product.name for product in result} + assert names == {"Test Product A", "Test Product B"} + + +def test_search_product_name_no_match(sql_session: Session) -> None: + insert_test_data(sql_session) + + result = search_product("Nonexistent", sql_session) + assert isinstance(result, list) + assert len(result) == 0 + + +def test_search_product_barcode_exact_match(sql_session: Session) -> None: + products = insert_test_data(sql_session) + + product = products[1] # Test Product B + + result = search_product(product.bar_code, sql_session) + assert isinstance(result, Product) + assert result.name == product.name + + +# Should not be able to find hidden products +def test_search_product_hidden_products(sql_session: Session) -> None: + insert_test_data(sql_session) + result = search_product("Hidden Product", sql_session) + assert isinstance(result, list) + assert len(result) == 0 + + +# Should be able to find hidden products if specified +def test_search_product_find_hidden_products(sql_session: Session) -> None: + insert_test_data(sql_session) + result = search_product("Hidden Product", sql_session, find_hidden_products=True) + assert isinstance(result, Product) + assert result.name == "Hidden Product" + + +# Should be able to find hidden products by barcode despite not specified +def test_search_product_hidden_products_by_barcode(sql_session: Session) -> None: + products = insert_test_data(sql_session) + hidden_product = products[3] # Hidden Product + + result = search_product(hidden_product.bar_code, sql_session) + assert isinstance(result, Product) + assert result.name == "Hidden Product" diff --git a/tests/queries/test_search_user.py b/tests/queries/test_search_user.py new file mode 100644 index 0000000..94b57c4 --- /dev/null +++ b/tests/queries/test_search_user.py @@ -0,0 +1,78 @@ +from sqlalchemy.orm import Session + +from dibbler.models import User +from dibbler.queries import search_user + +USER = [ + ("alice", 123), + ("bob", 125), + ("charlie", 126), + ("david", 127), + ("eve", 128), + ("evey", 129), + ("evy", 130), + ("-symbol-man", 131), + ("user_123", 132), +] + + +def setup_users(sql_session: Session) -> None: + for username, rfid in USER: + user = User(name=username, rfid=str(rfid)) + sql_session.add(user) + sql_session.commit() + + +def test_search_user_exact_match(sql_session: Session) -> None: + setup_users(sql_session) + + user = search_user("alice", sql_session) + assert user is not None + assert isinstance(user, User) + assert user.name == "alice" + + user = search_user("125", sql_session) + assert user is not None + assert isinstance(user, User) + assert user.name == "bob" + + +def test_search_user_partial_match(sql_session: Session) -> None: + setup_users(sql_session) + + users = search_user("ev", sql_session) + assert isinstance(users, list) + assert len(users) == 3 + names = {user.name for user in users} + assert names == {"eve", "evey", "evy"} + + users = search_user("user", sql_session) + assert isinstance(users, list) + assert len(users) == 1 + assert users[0].name == "user_123" + + +def test_search_user_no_match(sql_session: Session) -> None: + setup_users(sql_session) + + result = search_user("nonexistent", sql_session) + assert isinstance(result, list) + assert len(result) == 0 + + +def test_search_user_special_characters(sql_session: Session) -> None: + setup_users(sql_session) + + user = search_user("-symbol-man", sql_session) + assert user is not None + assert isinstance(user, User) + assert user.name == "-symbol-man" + + +def test_search_by_rfid(sql_session: Session) -> None: + setup_users(sql_session) + + user = search_user("130", sql_session) + assert user is not None + assert isinstance(user, User) + assert user.name == "evy" diff --git a/tests/queries/test_transaction_log.py b/tests/queries/test_transaction_log.py new file mode 100644 index 0000000..b8bf127 --- /dev/null +++ b/tests/queries/test_transaction_log.py @@ -0,0 +1,624 @@ +from datetime import datetime, timedelta + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import ( + Product, + Transaction, + TransactionType, + User, +) +from dibbler.queries import transaction_log + + +def insert_test_data(sql_session: Session) -> tuple[User, User, Product, Product]: + user1 = User("Test User 1") + user2 = User("Test User 2") + + product1 = Product("1234567890123", "Test Product 1") + product2 = Product("9876543210987", "Test Product 2") + sql_session.add_all([user1, user2, product1, product2]) + sql_session.commit() + + return user1, user2, product1, product2 + + +def insert_default_test_transactions( + sql_session: Session, + user1: User, + user2: User, + product1: Product, + product2: Product, +) -> list[Transaction]: + transactions = [ + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 0), + amount=100, + user_id=user1.id, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 1), + amount=50, + user_id=user2.id, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 2), + amount=-50, + user_id=user1.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user1.id, + product_id=product1.id, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 1), + product_count=1, + user_id=user2.id, + product_id=product2.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 2), + amount=15 * 1, + per_product=15, + product_count=1, + user_id=user2.id, + product_id=product2.id, + ), + Transaction.transfer( + time=datetime(2023, 10, 1, 14, 0, 0), + amount=30, + user_id=user1.id, + transfer_user_id=user2.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + return transactions + + +def test_transaction_log_uninitialized_user(sql_session: Session) -> None: + user = User("Uninitialized User") + + with pytest.raises(ValueError): + transaction_log(sql_session, user=user) + + +def test_transaction_log_uninitialized_product(sql_session: Session) -> None: + product = Product("1234567890123", "Uninitialized Product") + + with pytest.raises(ValueError): + transaction_log(sql_session, product=product) + + +def test_transaction_log_after_product_and_user_not_allowed(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + with pytest.raises(ValueError): + transaction_log( + sql_session, + user=user, + product=product, + after_time=datetime(2023, 10, 1, 11, 0, 0), + ) + + +def test_transaction_log_after_datetime_and_transaction_id_not_allowed( + sql_session: Session, +) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + with pytest.raises(ValueError): + transaction_log( + sql_session, + user=user, + after_time=datetime(2023, 10, 1, 11, 0, 0), + after_transaction_id=1, + ) + + +def test_user_transactions_before_product_and_user_not_allowed(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + with pytest.raises(ValueError): + transaction_log( + sql_session, + user=user, + product=product, + before_time=datetime(2023, 10, 1, 15, 0, 0), + ) + + +def test_transaction_log_before_datetime_and_transaction_id_not_allowed( + sql_session: Session, +) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + with pytest.raises(ValueError): + transaction_log( + sql_session, + user=user, + before_time=datetime(2023, 10, 1, 15, 0, 0), + before_transaction_id=1, + ) + + +def test_user_transactions_no_transactions(sql_session: Session) -> None: + insert_test_data(sql_session) + + transactions = transaction_log(sql_session) + + assert len(transactions) == 0 + + +def test_transaction_log_basic(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert len(transaction_log(sql_session)) == 7 + + +def test_transaction_log_filtered_by_user(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert len(transaction_log(sql_session, user=user)) == 4 + assert len(transaction_log(sql_session, user=user2)) == 3 + + +def test_transaction_log_filtered_by_product(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert len(transaction_log(sql_session, product=product)) == 1 + assert len(transaction_log(sql_session, product=product2)) == 2 + + +def test_transaction_log_after_datetime(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert ( + len( + transaction_log( + sql_session, + after_time=transactions[2].time, + ) + ) + == len(transactions) - 2 + ) + + +def test_transaction_log_after_datetime_no_transactions(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + assert ( + len( + transaction_log( + sql_session, + after_time=transactions[-1].time + timedelta(seconds=1), + ) + ) + == 0 + ) + + +def test_transaction_log_after_datetime_exclusive(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert ( + len( + transaction_log( + sql_session, + after_time=transactions[2].time, + exclusive_after=True, + ) + ) + == len(transactions) - 3 + ) + + +def test_transaction_log_after_transaction_id(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + first_transaction = transactions[0] + + assert len( + transaction_log( + sql_session, + after_transaction_id=first_transaction.id, + ) + ) == len(transactions) + + +def test_transaction_log_after_transaction_id_one_transaction(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + last_transaction = transactions[-1] + + assert ( + len( + transaction_log( + sql_session, + after_transaction_id=last_transaction.id, + ) + ) + == 1 + ) + + +def test_transaction_log_after_transaction_id_exclusive(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + third_transaction = transactions[2] + + assert ( + len( + transaction_log( + sql_session, + after_transaction_id=third_transaction.id, + exclusive_after=True, + ) + ) + == len(transactions) - 3 + ) + + +def test_transaction_log_before_datetime(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert ( + len( + transaction_log( + sql_session, + before_time=transactions[-3].time, + ) + ) + == len(transactions) - 2 + ) + + +def test_transaction_log_before_datetime_no_transactions(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert ( + len( + transaction_log( + sql_session, + before_time=transactions[0].time - timedelta(seconds=1), + ) + ) + == 0 + ) + + +def test_transaction_log_before_datetime_exclusive(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert ( + len( + transaction_log( + sql_session, + before_time=transactions[-3].time, + exclusive_before=True, + ) + ) + == len(transactions) - 3 + ) + + +def test_transaction_log_before_transaction_id(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + last_transaction = transactions[-3] + + assert ( + len( + transaction_log( + sql_session, + before_transaction_id=last_transaction.id, + ) + ) + == len(transactions) - 2 + ) + + +def test_transaction_log_before_transaction_id_one_transaction(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + first_transaction = transactions[0] + + assert ( + len( + transaction_log( + sql_session, + before_transaction_id=first_transaction.id, + ) + ) + == 1 + ) + + +def test_transaction_log_before_transaction_id_exclusive(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + last_transaction = transactions[-3] + + assert ( + len( + transaction_log( + sql_session, + before_transaction_id=last_transaction.id, + exclusive_before=True, + ) + ) + == len(transactions) - 3 + ) + + +def test_transaction_log_before_after_datetime_combined(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + fifth_transaction = transactions[4] + + assert ( + len( + transaction_log( + sql_session, + after_time=second_transaction.time, + before_time=fifth_transaction.time, + ) + ) + == 4 + ) + + +def test_transaction_log_before_after_transaction_id_combined(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + fifth_transaction = transactions[4] + + assert ( + len( + transaction_log( + sql_session, + after_transaction_id=second_transaction.id, + before_transaction_id=fifth_transaction.id, + ) + ) + == 4 + ) + + +def test_transaction_log_before_date_after_transaction_id(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + fifth_transaction = transactions[4] + + assert ( + len( + transaction_log( + sql_session, + before_time=fifth_transaction.time, + after_transaction_id=second_transaction.id, + ) + ) + == 4 + ) + + +def test_transaction_log_before_transaction_id_after_date(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + fifth_transaction = transactions[4] + + assert ( + len( + transaction_log( + sql_session, + before_transaction_id=fifth_transaction.id, + after_time=second_transaction.time, + ) + ) + == 4 + ) + + +def test_transaction_log_limit(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert len(transaction_log(sql_session, limit=3)) == 3 + assert len(transaction_log(sql_session, limit=len(transactions) + 3)) == len(transactions) + + +def test_transaction_log_filtered_by_transaction_type(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert ( + len( + transaction_log( + sql_session, + transaction_type=[TransactionType.ADJUST_BALANCE], + ) + ) + == 3 + ) + assert ( + len( + transaction_log( + sql_session, + transaction_type=[TransactionType.ADD_PRODUCT], + ) + ) + == 2 + ) + assert ( + len( + transaction_log( + sql_session, + transaction_type=[TransactionType.BUY_PRODUCT, TransactionType.ADD_PRODUCT], + ) + ) + == 3 + ) + + +def test_transaction_log_filtered_by_transaction_type_negated(sql_session: Session) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + assert ( + len( + transaction_log( + sql_session, + transaction_type=[TransactionType.ADJUST_BALANCE], + negate_transaction_type_filter=True, + ) + ) + == len(transactions) - 3 + ) + assert ( + len( + transaction_log( + sql_session, + transaction_type=[TransactionType.ADD_PRODUCT], + negate_transaction_type_filter=True, + ) + ) + == len(transactions) - 2 + ) + assert ( + len( + transaction_log( + sql_session, + transaction_type=[TransactionType.BUY_PRODUCT, TransactionType.ADD_PRODUCT], + negate_transaction_type_filter=True, + ) + ) + == len(transactions) - 3 + ) + + +def test_transaction_log_combined_filter_user_datetime_transaction_type_limit( + sql_session: Session, +) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + sixth_transaction = transactions[5] + + result = transaction_log( + sql_session, + user=user, + after_time=second_transaction.time, + before_time=sixth_transaction.time, + transaction_type=[TransactionType.ADJUST_BALANCE, TransactionType.ADD_PRODUCT], + limit=2, + ) + + assert len(result) == 2 + + +def test_transaction_log_combined_filter_user_transaction_id_transaction_type_limit( + sql_session: Session, +) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + sixth_transaction = transactions[5] + + result = transaction_log( + sql_session, + user=user, + after_transaction_id=second_transaction.id, + before_transaction_id=sixth_transaction.id, + transaction_type=[TransactionType.ADJUST_BALANCE, TransactionType.ADD_PRODUCT], + limit=2, + ) + + assert len(result) == 2 + + +def test_transaction_log_combined_filter_product_datetime_transaction_type_limit( + sql_session: Session, +) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + sixth_transaction = transactions[5] + + result = transaction_log( + sql_session, + product=product2, + after_time=second_transaction.time, + before_time=sixth_transaction.time, + transaction_type=[TransactionType.BUY_PRODUCT, TransactionType.ADD_PRODUCT], + limit=2, + ) + + assert len(result) == 2 + + +def test_transaction_log_combined_filter_product_transaction_id_transaction_type_limit( + sql_session: Session, +) -> None: + user, user2, product, product2 = insert_test_data(sql_session) + transactions = insert_default_test_transactions(sql_session, user, user2, product, product2) + + second_transaction = transactions[1] + sixth_transaction = transactions[5] + + result = transaction_log( + sql_session, + product=product2, + after_transaction_id=second_transaction.id, + before_transaction_id=sixth_transaction.id, + transaction_type=[TransactionType.BUY_PRODUCT, TransactionType.ADD_PRODUCT], + limit=2, + ) + + assert len(result) == 2 + + +# NOTE: see the corresponding TODO's above the function definition + + +@pytest.mark.skip(reason="Not yet implemented") +def test_transaction_log_filtered_by_user_joint_transactions(sql_session: Session) -> None: ... + + +@pytest.mark.skip(reason="Not yet implemented") +def test_transaction_log_filtered_by_user_throw_away_transactions(sql_session: Session) -> None: ... diff --git a/tests/queries/test_user_balance.py b/tests/queries/test_user_balance.py new file mode 100644 index 0000000..ee7489e --- /dev/null +++ b/tests/queries/test_user_balance.py @@ -0,0 +1,709 @@ +import math +from datetime import datetime +from pprint import pprint + +import pytest +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User +from dibbler.models.Transaction import DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE +from dibbler.queries import joint_buy_product, user_balance, user_balance_log +from dibbler.queries.user_balance import _joint_transaction_query, _non_joint_transaction_query + +# TODO: see if we can use pytest_runtest_makereport to print the "user_balance_log"s +# only on failures instead of inlining it in every test function + + +def insert_test_data(sql_session: Session) -> tuple[User, User, User, Product]: + user = User("Test User") + user2 = User("Test User 2") + user3 = User("Test User 3") + product = Product("1234567890123", "Test Product") + + sql_session.add_all([user, user2, user3, product]) + sql_session.commit() + + return user, user2, user3, product + + +def _product_cost( + per_product: int, + product_count: int, + interest_rate_percent: int, + apply_penalty: bool, + penalty_multiplier_percent: int, +) -> int: + base_cost = per_product * product_count + cost_with_interest = math.ceil(base_cost * (interest_rate_percent / 100)) + if apply_penalty: + total_cost = math.ceil(cost_with_interest * (penalty_multiplier_percent / 100)) + else: + total_cost = cost_with_interest + return total_cost + + +def test_non_joint_transaction_query(sql_session) -> None: + user1, user2, user3, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user1.id, + amount=100, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user2.id, + amount=50, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 1), + user_id=user2.id, + amount=70, + product_id=product.id, + product_count=3, + per_product=30, + ), + Transaction.transfer( + time=datetime(2023, 10, 1, 10, 0, 2), + user_id=user1.id, + transfer_user_id=user2.id, + amount=50, + ), + Transaction.transfer( + time=datetime(2023, 10, 1, 10, 0, 3), + user_id=user2.id, + transfer_user_id=user3.id, + amount=30, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + t = transactions + + result = [ + row[0] + for row in sql_session.execute( + _non_joint_transaction_query( + user_id=user1.id, + use_cache=False, + ), + ).all() + ] + assert result == [t[0].id, t[3].id] + + result = [ + row[0] + for row in sql_session.execute( + _non_joint_transaction_query( + user_id=user2.id, + use_cache=False, + ), + ).all() + ] + assert result == [t[1].id, t[2].id, t[3].id, t[4].id] + + result = [ + row[0] + for row in sql_session.execute( + _non_joint_transaction_query( + user_id=user3.id, + use_cache=False, + ), + ).all() + ] + assert result == [t[4].id] + + +def test_joint_transaction_query(sql_session: Session) -> None: + user1, user2, user3, product = insert_test_data(sql_session) + + j1 = joint_buy_product( + sql_session, + product=product, + product_count=3, + instigator=user1, + users=[user1, user2], + ) + + j2 = joint_buy_product( + sql_session, + product=product, + product_count=2, + instigator=user1, + users=[user1, user1, user2], + ) + + j3 = joint_buy_product( + sql_session, + product=product, + product_count=2, + instigator=user1, + users=[user1, user3, user3], + ) + + j4 = joint_buy_product( + sql_session, + product=product, + product_count=2, + instigator=user2, + users=[user2, user3, user3], + ) + + result = list( + sql_session.execute( + _joint_transaction_query( + user_id=user1.id, + use_cache=False, + ), + ).all(), + ) + assert list(result) == [ + (j1[0].id, 1, 2), + (j2[0].id, 2, 3), + (j3[0].id, 1, 3), + ] + + result = list( + sql_session.execute( + _joint_transaction_query( + user_id=user2.id, + use_cache=False, + ), + ).all(), + ) + assert list(result) == [ + (j1[0].id, 1, 2), + (j2[0].id, 1, 3), + (j4[0].id, 1, 3), + ] + + result = list( + sql_session.execute( + _joint_transaction_query( + user_id=user3.id, + use_cache=False, + ), + ).all(), + ) + assert list(result) == [ + (j3[0].id, 2, 3), + (j4[0].id, 2, 3), + ] + + +def test_user_balance_no_transactions(sql_session: Session) -> None: + user, *_ = insert_test_data(sql_session) + + pprint(user_balance_log(sql_session, user)) + + balance = user_balance(sql_session, user) + + assert balance == 0 + + +def test_user_balance_basic_history(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + amount=100, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 1), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(user_balance_log(sql_session, user)) + + balance = user_balance(sql_session, user) + + assert balance == 100 + 27 + + +def test_user_balance_with_transfers(sql_session: Session) -> None: + user1, user2, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user1.id, + amount=100, + ), + Transaction.transfer( + time=datetime(2023, 10, 1, 10, 0, 1), + user_id=user1.id, + transfer_user_id=user2.id, + amount=50, + ), + Transaction.transfer( + time=datetime(2023, 10, 1, 10, 0, 2), + user_id=user2.id, + transfer_user_id=user1.id, + amount=30, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(user_balance_log(sql_session, user1)) + + user1_balance = user_balance(sql_session, user1) + assert user1_balance == 100 - 50 + 30 + + pprint(user_balance_log(sql_session, user2)) + + user2_balance = user_balance(sql_session, user2) + assert user2_balance == 50 - 30 + + +@pytest.mark.skip(reason="Not yet implemented") +def test_user_balance_complex_history(sql_session: Session) -> None: ... + + +def test_user_balance_penalty(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + amount=-200, + ), + # Penalized, pays 2x the price (default penalty) + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == 27 - 200 - ( + 27 * (DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE // 100) + ) + + +def test_user_balance_changing_penalty(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + amount=-200, + ), + # Penalized, pays 2x the price (default penalty) + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + Transaction.adjust_penalty( + time=datetime(2023, 10, 1, 13, 0, 0), + user_id=user.id, + penalty_multiplier_percent=300, + penalty_threshold=-100, + ), + # Penalized, pays 3x the price + Transaction.buy_product( + time=datetime(2023, 10, 1, 14, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == 27 - 200 - ( + 27 * (DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE // 100) + ) - (27 * 3) + + +def test_user_balance_interest(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + interest_rate_percent=110, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == 27 - math.ceil(27 * 1.1) + + +def test_user_balance_changing_interest(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + interest_rate_percent=110, + ), + # Pays 1.1x the price + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 13, 0, 0), + user_id=user.id, + interest_rate_percent=120, + ), + # Pays 1.2x the price + Transaction.buy_product( + time=datetime(2023, 10, 1, 14, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == 27 * 3 - math.ceil(27 * 1.1) - math.ceil(27 * 1.2) + + +def test_user_balance_penalty_interest_combined(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + interest_rate_percent=110, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + amount=-200, + ), + # Penalized, pays 2x the price (default penalty) + # Pays 1.1x the price + Transaction.buy_product( + time=datetime(2023, 10, 1, 13, 0, 0), + user_id=user.id, + product_id=product.id, + product_count=1, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + 27 - 200 - math.ceil(27 * (DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE // 100) * 1.1) + ) + + +def test_user_balance_joint_transaction_single_user(sql_session: Session) -> None: + user, _, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + joint_buy_product( + sql_session, + instigator=user, + users=[user], + product=product, + product_count=2, + ) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == (27 * 3) - (27 * 2) + + +def test_user_balance_joint_transactions_multiple_users(sql_session: Session) -> None: + user, user2, user3, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user2, user3], + product=product, + product_count=2, + ) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == (27 * 3) - math.ceil((27 * 2) / 3) + + +def test_user_balance_joint_transactions_multiple_times_self(sql_session: Session) -> None: + user, user2, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user, user2], + product=product, + product_count=2, + ) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == (27 * 3) - math.ceil((27 * 2) * (2 / 3)) + + +def test_user_balance_joint_transactions_interest(sql_session: Session) -> None: + user, user2, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + interest_rate_percent=110, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user2], + product=product, + product_count=2, + ) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == (27 * 3) - math.ceil(math.ceil(27 * 2 / 2) * 1.1) + + +def test_user_balance_joint_transactions_changing_interest(sql_session: Session) -> None: + user, user2, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 4, + per_product=27, + product_count=4, + ), + # Pays 1.1x the price + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + interest_rate_percent=110, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user2], + product=product, + product_count=2, + ) + + transactions = [ + # Pays 1.2x the price + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + interest_rate_percent=120, + ) + ] + sql_session.add_all(transactions) + sql_session.commit() + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user2], + product=product, + product_count=1, + ) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + (27 * 4) - math.ceil(math.ceil(27 * 2 / 2) * 1.1) - math.ceil(math.ceil(27 * 1 / 2) * 1.2) + ) + + +def test_user_balance_joint_transactions_penalty(sql_session: Session) -> None: + user, user2, _, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 3, + per_product=27, + product_count=3, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + amount=-200, + ), + ] + sql_session.add_all(transactions) + sql_session.commit() + + joint_buy_product( + sql_session, + instigator=user, + users=[user, user2], + product=product, + product_count=2, + ) + + pprint(user_balance_log(sql_session, user)) + + assert user_balance(sql_session, user) == ( + (27 * 3) + - 200 + - math.ceil(math.ceil(27 * 2 / 2) * (DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE // 100)) + ) + + +@pytest.mark.skip(reason="Not yet implemented") +def test_user_balance_joint_transactions_changing_penalty(sql_session: Session) -> None: ... + + +@pytest.mark.skip(reason="Not yet implemented") +def test_user_balance_joint_transactions_penalty_interest_combined( + sql_session: Session, +) -> None: ... + + +@pytest.mark.skip(reason="Not yet implemented") +def test_user_balance_until(sql_session: Session) -> None: ... + + +@pytest.mark.skip(reason="Not yet implemented") +def test_user_balance_throw_away_products(sql_session: Session) -> None: ...