From 3c9bca8b55b3ce7ed3579a44b158ab6d8876de08 Mon Sep 17 00:00:00 2001 From: h7x4 Date: Wed, 11 Jun 2025 20:01:11 +0200 Subject: [PATCH] fixup! WIP --- dibbler/lib/helpers.py | 76 +-- dibbler/models/InterestRate.py | 10 - dibbler/models/Product.py | 116 +--- .../{ProductPriceCache.py => ProductCache.py} | 8 +- dibbler/models/Transaction.py | 522 ++++++------------ dibbler/models/TransactionType.py | 8 +- dibbler/models/User.py | 36 +- .../{UserBalanceCache.py => UserCache.py} | 4 +- dibbler/queries/__init__.py | 0 dibbler/queries/add_product.py | 0 dibbler/queries/adjust_interest.py | 2 + dibbler/queries/adjust_penalty.py | 2 + dibbler/queries/buy_product.py | 37 ++ dibbler/queries/product_price.py | 181 ++++++ dibbler/queries/product_stock.py | 52 ++ dibbler/queries/products_owned_by_user.py | 0 dibbler/queries/search_product.py | 53 ++ dibbler/queries/search_user.py | 28 + dibbler/queries/user_balance.py | 102 ++++ dibbler/queries/users_owning_product.py | 0 dibbler/subcommands/seed_test_data.py | 25 - tests/conftest.py | 2 + tests/models/__init__.py | 0 tests/models/test_product.py | 32 ++ tests/models/test_transaction.py | 73 +++ tests/models/test_user.py | 72 +++ tests/queries/__init__.py | 0 tests/queries/test_buy_product.py | 184 ++++++ tests/queries/test_product_price.py | 173 ++++++ tests/queries/test_product_stock.py | 143 +++++ tests/queries/test_transfer_balance.py | 12 + tests/queries/test_user_balance.py | 103 ++++ tests/test_product.py | 216 -------- tests/test_transaction.py | 97 ---- tests/test_user.py | 108 ---- 35 files changed, 1467 insertions(+), 1010 deletions(-) delete mode 100644 dibbler/models/InterestRate.py rename dibbler/models/{ProductPriceCache.py => ProductCache.py} (57%) rename dibbler/models/{UserBalanceCache.py => UserCache.py} (87%) create mode 100644 dibbler/queries/__init__.py create mode 100644 dibbler/queries/add_product.py create mode 100644 dibbler/queries/adjust_interest.py create mode 100644 dibbler/queries/adjust_penalty.py create mode 100644 dibbler/queries/buy_product.py create mode 100644 dibbler/queries/product_price.py create mode 100644 dibbler/queries/product_stock.py create mode 100644 dibbler/queries/products_owned_by_user.py create mode 100644 dibbler/queries/search_product.py create mode 100644 dibbler/queries/search_user.py create mode 100644 dibbler/queries/user_balance.py create mode 100644 dibbler/queries/users_owning_product.py create mode 100644 tests/models/__init__.py create mode 100644 tests/models/test_product.py create mode 100644 tests/models/test_transaction.py create mode 100644 tests/models/test_user.py create mode 100644 tests/queries/__init__.py create mode 100644 tests/queries/test_buy_product.py create mode 100644 tests/queries/test_product_price.py create mode 100644 tests/queries/test_product_stock.py create mode 100644 tests/queries/test_transfer_balance.py create mode 100644 tests/queries/test_user_balance.py delete mode 100644 tests/test_product.py delete mode 100644 tests/test_transaction.py delete mode 100644 tests/test_user.py diff --git a/dibbler/lib/helpers.py b/dibbler/lib/helpers.py index 30926a3..0aab88e 100644 --- a/dibbler/lib/helpers.py +++ b/dibbler/lib/helpers.py @@ -1,79 +1,7 @@ -import pwd -import subprocess import os +import pwd import signal - -from sqlalchemy import or_, and_ - -from ..models import User, Product - - -def search_user(string, session, ignorethisflag=None): - string = string.lower() - exact_match = ( - session.query(User) - .filter(or_(User.name == string, User.card == string, User.rfid == string)) - .first() - ) - if exact_match: - return exact_match - user_list = ( - session.query(User) - .filter( - or_( - User.name.ilike(f"%{string}%"), - User.card.ilike(f"%{string}%"), - User.rfid.ilike(f"%{string}%"), - ) - ) - .all() - ) - return user_list - - -def search_product(string, session, find_hidden_products=True): - if find_hidden_products: - exact_match = ( - session.query(Product) - .filter(or_(Product.bar_code == string, Product.name == string)) - .first() - ) - else: - exact_match = ( - session.query(Product) - .filter( - or_( - Product.bar_code == string, - and_(Product.name == string, Product.hidden is False), - ) - ) - .first() - ) - if exact_match: - return exact_match - if find_hidden_products: - product_list = ( - session.query(Product) - .filter( - or_( - Product.bar_code.ilike(f"%{string}%"), - Product.name.ilike(f"%{string}%"), - ) - ) - .all() - ) - else: - product_list = ( - session.query(Product) - .filter( - or_( - Product.bar_code.ilike(f"%{string}%"), - and_(Product.name.ilike(f"%{string}%"), Product.hidden is False), - ) - ) - .all() - ) - return product_list +import subprocess def system_user_exists(username): diff --git a/dibbler/models/InterestRate.py b/dibbler/models/InterestRate.py deleted file mode 100644 index c9ffccb..0000000 --- a/dibbler/models/InterestRate.py +++ /dev/null @@ -1,10 +0,0 @@ -from datetime import datetime - -from sqlalchemy import Integer, DateTime -from sqlalchemy.orm import Mapped, mapped_column - -from dibbler.models import Base - -class InterestRate(Base): - timestamp: Mapped[datetime] = mapped_column(DateTime) - percentage: Mapped[int] = mapped_column(Integer) diff --git a/dibbler/models/Product.py b/dibbler/models/Product.py index b5c22e9..dd10fd9 100644 --- a/dibbler/models/Product.py +++ b/dibbler/models/Product.py @@ -6,35 +6,41 @@ from sqlalchemy import ( Boolean, Integer, String, - case, - func, - select, ) from sqlalchemy.orm import ( Mapped, - Session, mapped_column, ) -import dibbler.models.User as user - from .Base import Base -from .Transaction import Transaction -from .TransactionType import TransactionType - -# if TYPE_CHECKING: -# from .PurchaseEntry import PurchaseEntry -# from .UserProducts import UserProducts class Product(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) + """Internal database ID""" bar_code: Mapped[str] = mapped_column(String(13), unique=True) + """ + The bar code of the product. + + This is a unique identifier for the product, typically a 13-digit + EAN-13 code. + """ + name: Mapped[str] = mapped_column(String(45)) - # price: Mapped[int] = mapped_column(Integer) - # stock: Mapped[int] = mapped_column(Integer) + """ + The name of the product. + + Please don't write fanfics here, this is not a place for that. + """ + hidden: Mapped[bool] = mapped_column(Boolean, default=False) + """ + Whether the product is hidden from the user interface. + + Hidden products are not shown in the product list, but can still be + used in transactions. + """ def __init__( self: Self, @@ -45,85 +51,3 @@ class Product(Base): self.bar_code = bar_code self.name = name self.hidden = hidden - - # - count (virtual) - def stock(self: Self, sql_session: Session) -> int: - """ - Returns the number of products in stock. - """ - - result = sql_session.scalars( - select( - func.sum( - case( - ( - Transaction.type_ == TransactionType.ADD_PRODUCT, - Transaction.product_count, - ), - ( - Transaction.type_ == TransactionType.BUY_PRODUCT, - -Transaction.product_count, - ), - ( - Transaction.type_ == TransactionType.ADJUST_STOCK, - Transaction.product_count, - ), - else_=0, - ) - ) - ).where( - Transaction.type_.in_( - [ - TransactionType.BUY_PRODUCT, - TransactionType.ADD_PRODUCT, - TransactionType.ADJUST_STOCK, - ] - ), - Transaction.product_id == self.id, - ) - ).one_or_none() - - return result or 0 - - def remaining_with_exact_price(self: Self, sql_session: Session) -> list[int]: - """ - Retrieves the remaining products with their exact price as they were bought. - """ - - stock = self.stock(sql_session) - - # TODO: only retrieve as many transactions as exists in the stock - last_added = sql_session.scalars( - select( - func.row_number(), - Transaction.time, - Transaction.per_product, - Transaction.product_count, - ) - .where( - Transaction.type_ == TransactionType.ADD_PRODUCT, - Transaction.product_id == self.id, - ) - .order_by(Transaction.time.desc()) - ).all() - - # result = [] - # while stock > 0 and last_added: - - ... - - def price(self: Self, sql_session: Session) -> int: - """ - Returns the price of the product. - - Average price over the last bought products. - """ - - return Transaction.product_price(sql_session=sql_session, product=self) - - def owned_by_user(self: Self, sql_session: Session) -> dict[user.User, int]: - """ - Returns an overview of how many of the remaining products are owned by which user. - """ - - ... diff --git a/dibbler/models/ProductPriceCache.py b/dibbler/models/ProductCache.py similarity index 57% rename from dibbler/models/ProductPriceCache.py rename to dibbler/models/ProductCache.py index d468f87..b2ac8a4 100644 --- a/dibbler/models/ProductPriceCache.py +++ b/dibbler/models/ProductCache.py @@ -5,7 +5,11 @@ from sqlalchemy.orm import Mapped, mapped_column from dibbler.models import Base -class ProductPriceCache(Base): +class ProductCache(Base): product_id: Mapped[int] = mapped_column(Integer, primary_key=True) - timestamp: Mapped[datetime] = mapped_column(DateTime) + price: Mapped[int] = mapped_column(Integer) + price_timestamp: Mapped[datetime] = mapped_column(DateTime) + + stock: Mapped[int] = mapped_column(Integer) + stock_timestamp: Mapped[datetime] = mapped_column(DateTime) diff --git a/dibbler/models/Transaction.py b/dibbler/models/Transaction.py index b07f79d..50d2d5b 100644 --- a/dibbler/models/Transaction.py +++ b/dibbler/models/Transaction.py @@ -9,19 +9,12 @@ from sqlalchemy import ( ForeignKey, Integer, Text, - asc, - case, - cast, - func, - literal, - select, ) from sqlalchemy import ( Enum as SQLEnum, ) from sqlalchemy.orm import ( Mapped, - Session, mapped_column, relationship, ) @@ -40,21 +33,32 @@ if TYPE_CHECKING: # maybe we should add some sort of joint transaction id field to allow multiple transactions to be grouped together? _DYNAMIC_FIELDS: set[str] = { + "amount", + "interest_rate_percent", + "penalty_multiplier_percent", + "penalty_threshold", "per_product", - "user_id", - "transfer_user_id", - "product_id", "product_count", + "product_id", + "transfer_user_id", } _EXPECTED_FIELDS: dict[TransactionType, set[str]] = { - TransactionType.ADJUST_BALANCE: {"user_id"}, - TransactionType.ADJUST_STOCK: {"user_id", "product_id", "product_count"}, - TransactionType.TRANSFER: {"user_id", "transfer_user_id"}, - TransactionType.ADD_PRODUCT: {"user_id", "product_id", "per_product", "product_count"}, - TransactionType.BUY_PRODUCT: {"user_id", "product_id", "product_count"}, + TransactionType.ADD_PRODUCT: {"amount", "per_product", "product_count", "product_id"}, + TransactionType.ADJUST_BALANCE: {"amount"}, + TransactionType.ADJUST_INTEREST: {"interest_rate_percent"}, + TransactionType.ADJUST_PENALTY: {"penalty_multiplier_percent", "penalty_threshold"}, + TransactionType.ADJUST_STOCK: {"product_count", "product_id"}, + # TODO: remove amount from BUY_PRODUCT + # this requires modifications to user credit calculations + TransactionType.BUY_PRODUCT: {"amount", "product_count", "product_id"}, + TransactionType.TRANSFER: {"amount", "transfer_user_id"}, } +assert all(x <= _DYNAMIC_FIELDS for x in _EXPECTED_FIELDS.values()), ( + "All expected fields must be part of _DYNAMIC_FIELDS." +) + def _transaction_type_field_constraints( transaction_type: TransactionType, @@ -89,64 +93,147 @@ class Transaction(Base): ) id: Mapped[int] = mapped_column(Integer, primary_key=True) + """ + A unique identifier for the transaction. + + Not used for anything else than identifying the transaction in the database. + """ + time: Mapped[datetime] = mapped_column(DateTime, unique=True) + """ + The time when the transaction took place. + + This is used to order transactions chronologically, and to calculate + all kinds of state. + """ + message: Mapped[str | None] = mapped_column(Text, nullable=True) + """ + A message that can be set by the user to describe the reason + behind the transaction (or potentially a place to write som fan fiction). + + This is not used for any calculations, but can be useful for debugging. + """ - # The type of transaction type_: Mapped[TransactionType] = mapped_column(SQLEnum(TransactionType), name="type") + """ + Which type of transaction this is. - # TODO: this should be inferred - # If buying products, is the user penalized for having too low credit? - # penalty: Mapped[Boolean] = mapped_column(Boolean, default=False) + The type determines which fields are expected to be set. + """ - # The amount of money being added or subtracted from the user's credit - # This amount means different things depending on the transaction type: - # - ADJUST_BALANCE: The amount of credit to add or subtract from the user's balance - # - ADJUST_STOCK: The amount of money which disappeared with this stock adjustment - # (i.e. current price * product_count) - # - TRANSFER: The amount of credit to transfer to another user - # - ADD_PRODUCT: The real amount spent on the products - # (i.e. not per_product * product_count, which should be rounded up) - # - BUY_PRODUCT: The amount of credit spent on the product - amount: Mapped[int] = mapped_column(Integer) + amount: Mapped[int | None] = mapped_column(Integer) + """ + This field means different things depending on the transaction type: + + - `ADD_PRODUCT`: The real amount spent on the products. + + - `ADJUST_BALANCE`: The amount of credit to add or subtract from the user's balance. + + - `BUY_PRODUCT`: The amount of credit spent on the product. + Note that this includes any penalties and interest that the user + had to pay as well. + + - `TRANSFER`: The amount of balance to transfer to another user. + """ - # If adding products, how much is each product worth per_product: Mapped[int | None] = mapped_column(Integer) + """ + If adding products, how much is each product worth - # The user who performs the transaction - user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id")) - user: Mapped[User | None] = relationship( + Note that this is distinct from the total amount of the transaction, + because this gets rounded up to the nearest integer, while the total amount + that the user paid in the store would be stored in the `amount` field. + """ + + user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) + """The user who performs the transaction. See `user` for more details.""" + user: Mapped[User] = relationship( lazy="joined", foreign_keys=[user_id], ) + """ + The user who performs the transaction. + + For some transaction types, like `TRANSFER` and `ADD_PRODUCT`, this is a + functional field with "real world consequences" for price calculations. + + For others, like `ADJUST_PENALTY` and `ADJUST_STOCK`, this is just a record of who + performed the transaction, and does not affect any state calculations. + """ # Receiving user when moving credit from one user to another transfer_user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id")) + """The user who receives money in a `TRANSFER` transaction.""" transfer_user: Mapped[User | None] = relationship( lazy="joined", foreign_keys=[transfer_user_id], ) + """The user who receives money in a `TRANSFER` transaction.""" # The product that is either being added or bought product_id: Mapped[int | None] = mapped_column(ForeignKey("product.id")) + """The product being added or bought.""" product: Mapped[Product | None] = relationship(lazy="joined") + """The product being added or bought.""" # The amount of products being added or bought product_count: Mapped[int | None] = mapped_column(Integer) + """ + The amount of products being added or bought. + """ + + penalty_threshold: Mapped[int | None] = mapped_column(Integer, nullable=True) + """ + On `ADJUST_PENALTY` transactions, this is the threshold in krs for when the user + should start getting penalized for low credit. + + See also `penalty_multiplier`. + """ + + penalty_multiplier_percent: Mapped[int | None] = mapped_column(Integer, nullable=True) + """ + On `ADJUST_PENALTY` transactions, this is the multiplier for the amount of + money the user has to pay when they have too low credit. + + The multiplier is a percentage, so `100` means the user has to pay the full + price of the product, `200` means they have to pay double, etc. + + See also `penalty_threshold`. + """ + + # TODO: this should be inferred + # Assuming this is a BUY_PRODUCT transaction, was the user penalized for having + # too low credit in this transaction? + # is_penalized: Mapped[Boolean] = mapped_column(Boolean, default=False) + + interest_rate_percent: Mapped[int | None] = mapped_column(Integer, nullable=True) + """ + On `ADJUST_INTEREST` transactions, this is the interest rate in percent + that the user has to pay on their balance. + + The interest rate is a percentage, so `100` means the user has to pay the full + price of the product, `200` means they have to pay double, etc. + """ def __init__( self: Self, type_: TransactionType, user_id: int, - amount: int, + amount: int | None = None, time: datetime | None = None, message: str | None = None, product_id: int | None = None, transfer_user_id: int | None = None, per_product: int | None = None, product_count: int | None = None, - # penalty: bool = False + penalty_threshold: int | None = None, + penalty_multiplier_percent: int | None = None, + interest_rate_percent: int | None = None, ) -> None: + """ + Please do not call this constructor directly, use the factory methods instead. + """ if time is None: time = datetime.now() @@ -159,14 +246,16 @@ class Transaction(Base): self.transfer_user_id = transfer_user_id self.per_product = per_product self.product_count = product_count - # self.penalty = penalty + self.penalty_threshold = penalty_threshold + self.penalty_multiplier_percent = penalty_multiplier_percent + self.interest_rate_percent = interest_rate_percent self._validate_by_transaction_type() def _validate_by_transaction_type(self: Self) -> None: """ - Validates the transaction based on its type. - Raises ValueError if the transaction is invalid. + Validates the transaction's fields based on its type. + Raises `ValueError` if the transaction is invalid. """ # TODO: do we allow free products? if self.amount == 0: @@ -186,6 +275,7 @@ class Transaction(Base): if ( self.per_product is not None and self.product_count is not None + and self.amount is not None and self.amount > self.per_product * self.product_count ): raise ValueError( @@ -204,9 +294,6 @@ class Transaction(Base): time: datetime | None = None, message: str | None = None, ) -> Transaction: - """ - Creates an ADJUST transaction. - """ return cls( time=time, type_=TransactionType.ADJUST_BALANCE, @@ -215,81 +302,58 @@ class Transaction(Base): message=message, ) + @classmethod + def adjust_interest( + cls: type[Self], + interest_rate_percent: int, + user_id: int, + time: datetime | None = None, + message: str | None = None, + ) -> Transaction: + return cls( + time=time, + type_=TransactionType.ADJUST_INTEREST, + interest_rate_percent=interest_rate_percent, + user_id=user_id, + message=message, + ) + + @classmethod + def adjust_penalty( + cls: type[Self], + penalty_multiplier_percent: int, + penalty_threshold: int, + user_id: int, + time: datetime | None = None, + message: str | None = None, + ) -> Transaction: + return cls( + time=time, + type_=TransactionType.ADJUST_PENALTY, + penalty_multiplier_percent=penalty_multiplier_percent, + penalty_threshold=penalty_threshold, + user_id=user_id, + message=message, + ) + @classmethod def adjust_stock( cls: type[Self], - amount: int, user_id: int, product_id: int, product_count: int, time: datetime | None = None, message: str | None = None, ) -> Transaction: - """ - Creates an ADJUST_STOCK transaction. - """ return cls( time=time, type_=TransactionType.ADJUST_STOCK, - amount=amount, user_id=user_id, product_id=product_id, product_count=product_count, message=message, ) - @classmethod - def adjust_stock_auto_amount( - cls: type[Self], - sql_session: Session, - user_id: int, - product_id: int, - product_count: int, - time: datetime | None = None, - message: str | None = None, - ) -> Transaction: - """ - Creates an ADJUST_STOCK transaction with the amount automatically calculated based on the product's current price. - """ - from .Product import Product - - product = sql_session.scalar(select(Product).where(Product.id == product_id)) - if product is None: - raise ValueError(f"Product with id {product_id} does not exist.") - - price = product.price(sql_session) - - return cls( - time=time, - type_=TransactionType.ADJUST_STOCK, - amount=price * product_count, - user_id=user_id, - product_id=product_id, - product_count=product_count, - message=message, - ) - - @classmethod - def transfer( - cls: type[Self], - amount: int, - user_id: int, - transfer_user_id: int, - time: datetime | None = None, - message: str | None = None, - ) -> Transaction: - """ - Creates a TRANSFER transaction. - """ - return cls( - time=time, - type_=TransactionType.TRANSFER, - amount=amount, - user_id=user_id, - transfer_user_id=transfer_user_id, - message=message, - ) - @classmethod def add_product( cls: type[Self], @@ -301,9 +365,6 @@ class Transaction(Base): time: datetime | None = None, message: str | None = None, ) -> Transaction: - """ - Creates an ADD_PRODUCT transaction. - """ return cls( time=time, type_=TransactionType.ADD_PRODUCT, @@ -325,9 +386,6 @@ class Transaction(Base): time: datetime | None = None, message: str | None = None, ) -> Transaction: - """ - Creates a BUY_PRODUCT transaction. - """ return cls( time=time, type_=TransactionType.BUY_PRODUCT, @@ -339,263 +397,19 @@ class Transaction(Base): ) @classmethod - def buy_product_auto_amount( + def transfer( cls: type[Self], - sql_session: Session, + amount: int, user_id: int, - product_id: int, - product_count: int, + transfer_user_id: int, time: datetime | None = None, message: str | None = None, ) -> Transaction: - """ - Creates a BUY_PRODUCT transaction with the amount automatically calculated based on the product's current price. - """ - from .Product import Product - - product = sql_session.scalar(select(Product).where(Product.id == product_id)) - if product is None: - raise ValueError(f"Product with id {product_id} does not exist.") - - price = product.price(sql_session) - return cls( time=time, - type_=TransactionType.BUY_PRODUCT, - amount=price * product_count, + type_=TransactionType.TRANSFER, + amount=amount, user_id=user_id, - product_id=product_id, - product_count=product_count, + transfer_user_id=transfer_user_id, message=message, ) - - ############################ - # USER BALANCE CALCULATION # - ############################ - - @staticmethod - def _user_balance_query( - user: User, - # until: datetime | None = None, - ): - """ - The inner query for calculating the user's balance. - This is used both directly via user_balance() and in Transaction CHECK constraints. - """ - - balance_adjustments = ( - select(func.coalesce(func.sum(Transaction.amount).label("balance_adjustments"), 0)) - .where( - Transaction.user_id == user.id, - Transaction.type_ == TransactionType.ADJUST_BALANCE, - ) - .scalar_subquery() - ) - - transfers_to_other_users = ( - select(func.coalesce(func.sum(Transaction.amount).label("transfers_to_other_users"), 0)) - .where( - Transaction.user_id == user.id, - Transaction.type_ == TransactionType.TRANSFER, - ) - .scalar_subquery() - ) - - transfers_to_self = ( - select(func.coalesce(func.sum(Transaction.amount).label("transfers_to_self"), 0)) - .where( - Transaction.transfer_user_id == user.id, - Transaction.type_ == TransactionType.TRANSFER, - ) - .scalar_subquery() - ) - - add_products = ( - select(func.coalesce(func.sum(Transaction.amount).label("add_products"), 0)) - .where( - Transaction.user_id == user.id, - Transaction.type_ == TransactionType.ADD_PRODUCT, - ) - .scalar_subquery() - ) - - buy_products = ( - select(func.coalesce(func.sum(Transaction.amount).label("buy_products"), 0)) - .where( - Transaction.user_id == user.id, - Transaction.type_ == TransactionType.BUY_PRODUCT, - ) - .scalar_subquery() - ) - - query = select( - # TODO: clearly define and fix the sign of the amount - ( - 0 - + balance_adjustments - - transfers_to_other_users - + transfers_to_self - + add_products - - buy_products - ).label("credit") - ) - - return query - - @staticmethod - def user_balance( - sql_session: Session, - user: User, - # Optional: calculate the balance until a certain transaction. - # until: Transaction | None = None, - ) -> int: - """ - Calculates the balance of a user. - """ - - query = Transaction._user_balance_query(user) # , until=until) - - result = sql_session.scalar(query) - - if result is None: - # If there are no transactions for this user, the query should return 0, not None. - raise RuntimeError( - f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})." - ) - - return result - - ############################# - # PRODUCT PRICE CALCULATION # - ############################# - - @staticmethod - def _product_price_query( - product: Product, - # until: datetime | None = None, - ): - """ - The inner query for calculating the product price. - - This is used both directly via product_price() and in Transaction CHECK constraints. - """ - initial_element = select( - literal(0).label("i"), - literal(0).label("time"), - literal(0).label("price"), - literal(0).label("product_count"), - ) - - recursive_cte = initial_element.cte(name="rec_cte", recursive=True) - - # Subset of transactions that we'll want to iterate over. - trx_subset = ( - select( - func.row_number().over(order_by=asc(Transaction.time)).label("i"), - Transaction.time, - Transaction.type_, - Transaction.product_count, - Transaction.per_product, - ) - .where( - Transaction.type_.in_( - [ - TransactionType.BUY_PRODUCT, - TransactionType.ADD_PRODUCT, - TransactionType.ADJUST_STOCK, - ] - ), - Transaction.product_id == product.id, - # TODO: - # If we have a transaction to limit the price calculation to, use it. - # If not, use all transactions for the product. - # (Transaction.time <= until.time) if until else True, - ) - .order_by(Transaction.time.asc()) - .alias("trx_subset") - ) - - recursive_elements = ( - select( - trx_subset.c.i, - trx_subset.c.time, - case( - # Someone buys the product -> price remains the same. - (trx_subset.c.type_ == TransactionType.BUY_PRODUCT, recursive_cte.c.price), - # Someone adds the product -> price is recalculated based on - # product count, previous price, and new price. - ( - trx_subset.c.type_ == TransactionType.ADD_PRODUCT, - cast( - func.ceil( - (trx_subset.c.per_product * trx_subset.c.product_count) - / ( - # The running product count can be negative if the accounting is bad. - # This ensures that we never end up with negative prices or zero divisions - # and other disastrous phenomena. - func.min(recursive_cte.c.product_count, 0) - + trx_subset.c.product_count - ) - ), - Integer, - ), - ), - # Someone adjusts the stock -> price remains the same. - (trx_subset.c.type_ == TransactionType.ADJUST_STOCK, recursive_cte.c.price), - # Should never happen - else_=recursive_cte.c.price, - ).label("price"), - case( - # Someone buys the product -> product count is reduced. - ( - trx_subset.c.type_ == TransactionType.BUY_PRODUCT, - recursive_cte.c.product_count - trx_subset.c.product_count, - ), - # Someone adds the product -> product count is increased. - ( - trx_subset.c.type_ == TransactionType.ADD_PRODUCT, - recursive_cte.c.product_count + trx_subset.c.product_count, - ), - # Someone adjusts the stock -> product count is adjusted. - ( - trx_subset.c.type_ == TransactionType.ADJUST_STOCK, - recursive_cte.c.product_count + trx_subset.c.product_count, - ), - # Should never happen - else_=recursive_cte.c.product_count, - ).label("product_count"), - ) - .select_from(trx_subset) - .where(trx_subset.c.i == recursive_cte.c.i + 1) - ) - - return recursive_cte.union_all(recursive_elements) - - @staticmethod - def product_price( - sql_session: Session, - product: Product, - # Optional: calculate the price until a certain transaction. - # until: Transaction | None = None, - ) -> int: - """ - Calculates the price of a product. - """ - - recursive_cte = Transaction._product_price_query(product) # , until=until) - - # TODO: optionally verify subresults: - # - product_count should never be negative (but this happens sometimes, so just a warning) - # - price should never be negative - - result = sql_session.scalar( - select(recursive_cte.c.price).order_by(recursive_cte.c.i.desc()).limit(1) - ) - - if result is None: - # If there are no transactions for this product, the query should return 0, not None. - raise RuntimeError( - f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})." - ) - - return result diff --git a/dibbler/models/TransactionType.py b/dibbler/models/TransactionType.py index a0af64a..4f4fa29 100644 --- a/dibbler/models/TransactionType.py +++ b/dibbler/models/TransactionType.py @@ -6,8 +6,10 @@ class TransactionType(Enum): Enum for transaction types. """ - ADJUST_BALANCE = "adjust_balance" - ADJUST_STOCK = "adjust_stock" - TRANSFER = "transfer" ADD_PRODUCT = "add_product" + ADJUST_BALANCE = "adjust_balance" + ADJUST_INTEREST = "adjust_interest" + ADJUST_PENALTY = "adjust_penalty" + ADJUST_STOCK = "adjust_stock" BUY_PRODUCT = "buy_product" + TRANSFER = "transfer" diff --git a/dibbler/models/User.py b/dibbler/models/User.py index 950dca1..d359d13 100644 --- a/dibbler/models/User.py +++ b/dibbler/models/User.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Self +from typing import TYPE_CHECKING, Self from sqlalchemy import ( Integer, @@ -13,16 +13,21 @@ from sqlalchemy.orm import ( mapped_column, ) -import dibbler.models.Product as product - from .Base import Base -from .Transaction import Transaction + +if TYPE_CHECKING: + from .Transaction import Transaction class User(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) + """Internal database ID""" name: Mapped[str] = mapped_column(String(20), unique=True) + """ + The PVV username of the user. + """ + card: Mapped[str | None] = mapped_column(String(20)) rfid: Mapped[str | None] = mapped_column(String(20)) @@ -41,31 +46,14 @@ class User(Base): # def is_anonymous(self): # return self.card == "11122233" - # TODO: rename to 'balance' everywhere - def credit(self, sql_session: Session) -> int: - """ - Returns the current credit of the user. - """ - - result = Transaction.user_balance( - sql_session=sql_session, - user=self, - ) - - return result - - def products(self, sql_session: Session) -> list[tuple[product.Product, int]]: - """ - Returns the products that the user has put into the system (and has not been purchased yet) - """ - - ... - + # TODO: move to 'queries' def transactions(self, sql_session: Session) -> list[Transaction]: """ Returns the transactions of the user in chronological order. """ + from .Transaction import Transaction # Import here to avoid circular import + return list( sql_session.scalars( select(Transaction) diff --git a/dibbler/models/UserBalanceCache.py b/dibbler/models/UserCache.py similarity index 87% rename from dibbler/models/UserBalanceCache.py rename to dibbler/models/UserCache.py index b42994e..31ea604 100644 --- a/dibbler/models/UserBalanceCache.py +++ b/dibbler/models/UserCache.py @@ -5,7 +5,9 @@ from sqlalchemy.orm import Mapped, mapped_column from dibbler.models import Base +# More like user balance cash money flow, amirite? class UserBalanceCache(Base): user_id: Mapped[int] = mapped_column(Integer, primary_key=True) - timestamp: Mapped[datetime] = mapped_column(DateTime) + balance: Mapped[int] = mapped_column(Integer) + timestamp: Mapped[datetime] = mapped_column(DateTime) diff --git a/dibbler/queries/__init__.py b/dibbler/queries/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dibbler/queries/add_product.py b/dibbler/queries/add_product.py new file mode 100644 index 0000000..e69de29 diff --git a/dibbler/queries/adjust_interest.py b/dibbler/queries/adjust_interest.py new file mode 100644 index 0000000..b81ae3e --- /dev/null +++ b/dibbler/queries/adjust_interest.py @@ -0,0 +1,2 @@ +# NOTE: this type of transaction should be password protected. +# the password can be set as a string literal in the config file. diff --git a/dibbler/queries/adjust_penalty.py b/dibbler/queries/adjust_penalty.py new file mode 100644 index 0000000..b81ae3e --- /dev/null +++ b/dibbler/queries/adjust_penalty.py @@ -0,0 +1,2 @@ +# NOTE: this type of transaction should be password protected. +# the password can be set as a string literal in the config file. diff --git a/dibbler/queries/buy_product.py b/dibbler/queries/buy_product.py new file mode 100644 index 0000000..99605ea --- /dev/null +++ b/dibbler/queries/buy_product.py @@ -0,0 +1,37 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import ( + Transaction, + TransactionType, + User, + Product, +) + +from .product_price import product_price + + +def buy_product( + sql_session: Session, + user: User, + product: Product, + product_count: int, + time: datetime | None = None, + message: str | None = None, +) -> Transaction: + """ + Creates a BUY_PRODUCT transaction with the amount automatically calculated based on the product's current price. + """ + + price = product_price(sql_session, product) + + return Transaction( + time=time, + type_=TransactionType.BUY_PRODUCT, + amount=price * product_count, + user_id=user.id, + product_id=product.id, + product_count=product_count, + message=message, + ) diff --git a/dibbler/queries/product_price.py b/dibbler/queries/product_price.py new file mode 100644 index 0000000..621efc4 --- /dev/null +++ b/dibbler/queries/product_price.py @@ -0,0 +1,181 @@ +from datetime import datetime + +from sqlalchemy import ( + Integer, + asc, + case, + cast, + func, + literal, + select, +) + +from sqlalchemy.orm import Session + +from dibbler.models import ( + Product, + Transaction, + TransactionType, +) + +def _product_price_query( + product: Product, + # use_cache: bool = True, + # until: datetime | None = None, +): + """ + The inner query for calculating the product price. + """ + initial_element = select( + literal(0).label("i"), + literal(0).label("time"), + literal(0).label("price"), + literal(0).label("product_count"), + ) + + recursive_cte = initial_element.cte(name="rec_cte", recursive=True) + + # Subset of transactions that we'll want to iterate over. + trx_subset = ( + select( + func.row_number().over(order_by=asc(Transaction.time)).label("i"), + Transaction.time, + Transaction.type_, + Transaction.product_count, + Transaction.per_product, + ) + .where( + Transaction.type_.in_( + [ + TransactionType.BUY_PRODUCT, + TransactionType.ADD_PRODUCT, + TransactionType.ADJUST_STOCK, + ] + ), + Transaction.product_id == product.id, + # TODO: + # If we have a transaction to limit the price calculation to, use it. + # If not, use all transactions for the product. + # (Transaction.time <= until.time) if until else True, + ) + .order_by(Transaction.time.asc()) + .alias("trx_subset") + ) + + recursive_elements = ( + select( + trx_subset.c.i, + trx_subset.c.time, + case( + # Someone buys the product -> price remains the same. + (trx_subset.c.type_ == TransactionType.BUY_PRODUCT, recursive_cte.c.price), + # Someone adds the product -> price is recalculated based on + # product count, previous price, and new price. + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT, + cast( + func.ceil( + (trx_subset.c.per_product * trx_subset.c.product_count) + / ( + # The running product count can be negative if the accounting is bad. + # This ensures that we never end up with negative prices or zero divisions + # and other disastrous phenomena. + func.max(recursive_cte.c.product_count, 0) + + trx_subset.c.product_count + ) + ), + Integer, + ), + ), + # Someone adjusts the stock -> price remains the same. + (trx_subset.c.type_ == TransactionType.ADJUST_STOCK, recursive_cte.c.price), + # Should never happen + else_=recursive_cte.c.price, + ).label("price"), + case( + # Someone buys the product -> product count is reduced. + ( + trx_subset.c.type_ == TransactionType.BUY_PRODUCT, + recursive_cte.c.product_count - trx_subset.c.product_count, + ), + # Someone adds the product -> product count is increased. + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT, + recursive_cte.c.product_count + trx_subset.c.product_count, + ), + # Someone adjusts the stock -> product count is adjusted. + ( + trx_subset.c.type_ == TransactionType.ADJUST_STOCK, + recursive_cte.c.product_count + trx_subset.c.product_count, + ), + # Should never happen + else_=recursive_cte.c.product_count, + ).label("product_count"), + ) + .select_from(trx_subset) + .where(trx_subset.c.i == recursive_cte.c.i + 1) + ) + + return recursive_cte.union_all(recursive_elements) + + +def product_price_log( + sql_session: Session, + product: Product, + # use_cache: bool = True, + # Optional: calculate the price until a certain transaction. + # until: Transaction | None = None, +) -> list[tuple[int, datetime, int, int]]: + """ + Calculates the price of a product and returns a log of the price changes. + """ + + recursive_cte = _product_price_query(product) + + result = sql_session.execute( + select( + recursive_cte.c.i, + recursive_cte.c.time, + recursive_cte.c.price, + recursive_cte.c.product_count, + ).order_by(recursive_cte.c.i.asc()) + ).all() + + if not result: + # If there are no transactions for this product, the query should return an empty list, not None. + raise RuntimeError( + f"Something went wrong while calculating the price log for product {product.name} (ID: {product.id})." + ) + + return [(row.i, row.time, row.price, row.product_count) for row in result] + + +@staticmethod +def product_price( + sql_session: Session, + product: Product, + # use_cache: bool = True, + # Optional: calculate the price until a certain transaction. + # until: Transaction | None = None, +) -> int: + """ + Calculates the price of a product. + """ + + recursive_cte = _product_price_query(product) # , until=until) + + # TODO: optionally verify subresults: + # - product_count should never be negative (but this happens sometimes, so just a warning) + # - price should never be negative + + result = sql_session.scalar( + select(recursive_cte.c.price).order_by(recursive_cte.c.i.desc()).limit(1) + ) + + if result is None: + # If there are no transactions for this product, the query should return 0, not None. + raise RuntimeError( + f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})." + ) + + return result diff --git a/dibbler/queries/product_stock.py b/dibbler/queries/product_stock.py new file mode 100644 index 0000000..cae94ec --- /dev/null +++ b/dibbler/queries/product_stock.py @@ -0,0 +1,52 @@ +from sqlalchemy import case, func, select +from sqlalchemy.orm import Session + +from dibbler.models import ( + Product, + Transaction, + TransactionType, +) + + +def product_stock( + sql_session: Session, + product: Product, + # use_cache: bool = True, + # until: datetime | None = None, +) -> int: + """ + Returns the number of products in stock. + """ + + result = sql_session.scalars( + select( + func.sum( + case( + ( + Transaction.type_ == TransactionType.ADD_PRODUCT, + Transaction.product_count, + ), + ( + Transaction.type_ == TransactionType.BUY_PRODUCT, + -Transaction.product_count, + ), + ( + Transaction.type_ == TransactionType.ADJUST_STOCK, + Transaction.product_count, + ), + else_=0, + ) + ) + ).where( + Transaction.type_.in_( + [ + TransactionType.BUY_PRODUCT, + TransactionType.ADD_PRODUCT, + TransactionType.ADJUST_STOCK, + ] + ), + Transaction.product_id == product.id, + ) + ).one_or_none() + + return result or 0 diff --git a/dibbler/queries/products_owned_by_user.py b/dibbler/queries/products_owned_by_user.py new file mode 100644 index 0000000..e69de29 diff --git a/dibbler/queries/search_product.py b/dibbler/queries/search_product.py new file mode 100644 index 0000000..98161e3 --- /dev/null +++ b/dibbler/queries/search_product.py @@ -0,0 +1,53 @@ +from sqlalchemy import and_, or_ +from sqlalchemy.orm import Session + +from dibbler.models import Product + + +def search_product( + string: str, + session: Session, + find_hidden_products=True, +) -> Product | list[Product]: + if find_hidden_products: + exact_match = ( + session.query(Product) + .filter(or_(Product.bar_code == string, Product.name == string)) + .first() + ) + else: + exact_match = ( + session.query(Product) + .filter( + or_( + Product.bar_code == string, + and_(Product.name == string, not Product.hidden), + ) + ) + .first() + ) + if exact_match: + return exact_match + if find_hidden_products: + product_list = ( + session.query(Product) + .filter( + or_( + Product.bar_code.ilike(f"%{string}%"), + Product.name.ilike(f"%{string}%"), + ) + ) + .all() + ) + else: + product_list = ( + session.query(Product) + .filter( + or_( + Product.bar_code.ilike(f"%{string}%"), + and_(Product.name.ilike(f"%{string}%"), not Product.hidden), + ) + ) + .all() + ) + return product_list diff --git a/dibbler/queries/search_user.py b/dibbler/queries/search_user.py new file mode 100644 index 0000000..62d8c53 --- /dev/null +++ b/dibbler/queries/search_user.py @@ -0,0 +1,28 @@ +from sqlalchemy import or_ +from sqlalchemy.orm import Session + +from dibbler.models import User + + +# TODO: modernize queries to use SQLAlchemy 2.0 style +def search_user(string: str, session: Session, ignorethisflag=None) -> User | list[User]: + string = string.lower() + exact_match = ( + session.query(User) + .filter(or_(User.name == string, User.card == string, User.rfid == string)) + .first() + ) + if exact_match: + return exact_match + user_list = ( + session.query(User) + .filter( + or_( + User.name.ilike(f"%{string}%"), + User.card.ilike(f"%{string}%"), + User.rfid.ilike(f"%{string}%"), + ) + ) + .all() + ) + return user_list diff --git a/dibbler/queries/user_balance.py b/dibbler/queries/user_balance.py new file mode 100644 index 0000000..dcedd5e --- /dev/null +++ b/dibbler/queries/user_balance.py @@ -0,0 +1,102 @@ +from sqlalchemy import func, select +from sqlalchemy.orm import Session + +from dibbler.models import ( + Transaction, + TransactionType, + User, +) + +# TODO: rename to 'balance' everywhere + +def _user_balance_query( + user: User, + # use_cache: bool = True, + # until: datetime | None = None, +): + """ + The inner query for calculating the user's balance. + """ + + balance_adjustments = ( + select(func.coalesce(func.sum(Transaction.amount).label("balance_adjustments"), 0)) + .where( + Transaction.user_id == user.id, + Transaction.type_ == TransactionType.ADJUST_BALANCE, + ) + .scalar_subquery() + ) + + transfers_to_other_users = ( + select(func.coalesce(func.sum(Transaction.amount).label("transfers_to_other_users"), 0)) + .where( + Transaction.user_id == user.id, + Transaction.type_ == TransactionType.TRANSFER, + ) + .scalar_subquery() + ) + + transfers_to_self = ( + select(func.coalesce(func.sum(Transaction.amount).label("transfers_to_self"), 0)) + .where( + Transaction.transfer_user_id == user.id, + Transaction.type_ == TransactionType.TRANSFER, + ) + .scalar_subquery() + ) + + add_products = ( + select(func.coalesce(func.sum(Transaction.amount).label("add_products"), 0)) + .where( + Transaction.user_id == user.id, + Transaction.type_ == TransactionType.ADD_PRODUCT, + ) + .scalar_subquery() + ) + + buy_products = ( + select(func.coalesce(func.sum(Transaction.amount).label("buy_products"), 0)) + .where( + Transaction.user_id == user.id, + Transaction.type_ == TransactionType.BUY_PRODUCT, + ) + .scalar_subquery() + ) + + query = select( + # TODO: clearly define and fix the sign of the amount + ( + 0 + + balance_adjustments + - transfers_to_other_users + + transfers_to_self + + add_products + - buy_products + ).label("balance") + ) + + return query + + +def user_balance( + sql_session: Session, + user: User, + # use_cache: bool = True, + # Optional: calculate the balance until a certain transaction. + # until: Transaction | None = None, +) -> int: + """ + Calculates the balance of a user. + """ + + query = _user_balance_query(user) # , until=until) + + result = sql_session.scalar(query) + + if result is None: + # If there are no transactions for this user, the query should return 0, not None. + raise RuntimeError( + f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})." + ) + + return result diff --git a/dibbler/queries/users_owning_product.py b/dibbler/queries/users_owning_product.py new file mode 100644 index 0000000..e69de29 diff --git a/dibbler/subcommands/seed_test_data.py b/dibbler/subcommands/seed_test_data.py index 69fab24..1cb9720 100644 --- a/dibbler/subcommands/seed_test_data.py +++ b/dibbler/subcommands/seed_test_data.py @@ -80,28 +80,3 @@ def main(): sql_session.add_all(transactions) sql_session.commit() - - # Note: These constructors depend on the content of the previous transactions, - # so they cannot be part of the initial transaction list. - - transaction = Transaction.adjust_stock_auto_amount( - sql_session=sql_session, - time=datetime(2023, 10, 1, 12, 0, 2), - product_count=3, - user_id=user1.id, - product_id=product1.id, - ) - - sql_session.add(transaction) - sql_session.commit() - - transaction = Transaction.adjust_stock_auto_amount( - sql_session=sql_session, - time=datetime(2023, 10, 1, 12, 0, 3), - product_count=-2, - user_id=user1.id, - product_id=product1.id, - ) - - sql_session.add(transaction) - sql_session.commit() diff --git a/tests/conftest.py b/tests/conftest.py index f94d4d7..293439f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import Session from dibbler.models import Base + def pytest_addoption(parser): parser.addoption( "--echo", @@ -12,6 +13,7 @@ def pytest_addoption(parser): help="Enable SQLAlchemy echo mode for debugging", ) + @pytest.fixture(scope="function") def sql_session(request): """Create a new SQLAlchemy session for testing.""" diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/test_product.py b/tests/models/test_product.py new file mode 100644 index 0000000..5e40f00 --- /dev/null +++ b/tests/models/test_product.py @@ -0,0 +1,32 @@ +import pytest +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from dibbler.models import Product + + +def insert_test_data(sql_session: Session) -> Product: + product = Product("1234567890123", "Test Product") + sql_session.add(product) + sql_session.commit() + return product + + +def test_product_no_duplicate_barcodes(sql_session: Session): + product = insert_test_data(sql_session) + + duplicate_product = Product(product.bar_code, "Hehe >:)") + sql_session.add(duplicate_product) + + with pytest.raises(IntegrityError): + sql_session.commit() + + +def test_product_no_duplicate_names(sql_session: Session): + product = insert_test_data(sql_session) + + duplicate_product = Product("1918238911928", product.name) + sql_session.add(duplicate_product) + + with pytest.raises(IntegrityError): + sql_session.commit() diff --git a/tests/models/test_transaction.py b/tests/models/test_transaction.py new file mode 100644 index 0000000..cbcc303 --- /dev/null +++ b/tests/models/test_transaction.py @@ -0,0 +1,73 @@ +from datetime import datetime + +import pytest +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User + + +def insert_test_data(sql_session: Session) -> tuple[User, Product]: + user = User("Test User") + product = Product("1234567890123", "Test Product") + + sql_session.add(user) + sql_session.add(product) + sql_session.commit() + + return user, product + + +def test_transaction_no_duplicate_timestamps(sql_session: Session): + user, _ = insert_test_data(sql_session) + + transaction1 = Transaction.adjust_balance( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + amount=100, + ) + + sql_session.add(transaction1) + sql_session.commit() + + transaction2 = Transaction.adjust_balance( + time=transaction1.time, + user_id=user.id, + amount=-50, + ) + + sql_session.add(transaction2) + + with pytest.raises(IntegrityError): + sql_session.commit() + + +def test_transaction_buy_product_wrong_amount(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + # Set price by adding a product + transaction = Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ) + + sql_session.add(transaction) + sql_session.commit() + + # Attempt to buy product with wrong amount + transaction2 = Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 1), + user_id=user.id, + product_id=product.id, + amount=(27 * 2) + 1, # Wrong amount + product_count=2, + ) + + sql_session.add(transaction2) + + with pytest.raises(ValueError): + sql_session.commit() diff --git a/tests/models/test_user.py b/tests/models/test_user.py new file mode 100644 index 0000000..dda5930 --- /dev/null +++ b/tests/models/test_user.py @@ -0,0 +1,72 @@ +from datetime import datetime + +import pytest +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User + + +def insert_test_data(sql_session: Session) -> User: + user = User("Test User") + sql_session.add(user) + sql_session.commit() + + return user + + +def test_ensure_no_duplicate_user_names(sql_session: Session): + user = insert_test_data(sql_session) + + user2 = User(user.name) + sql_session.add(user2) + + with pytest.raises(IntegrityError): + sql_session.commit() + + +def test_user_transactions(sql_session: Session): + user = insert_test_data(sql_session) + + product = Product("1234567890123", "Test Product") + user2 = User("Test User 2") + sql_session.add_all([product, user2]) + sql_session.commit() + + transactions = [ + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 0), + amount=100, + user_id=user.id, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 1), + amount=50, + user_id=user2.id, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 2), + amount=-50, + user_id=user.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user.id, + product_id=product.id, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 1), + amount=27, + product_count=1, + user_id=user2.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + + assert len(user.transactions(sql_session)) == 3 + assert len(user2.transactions(sql_session)) == 2 diff --git a/tests/queries/__init__.py b/tests/queries/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/queries/test_buy_product.py b/tests/queries/test_buy_product.py new file mode 100644 index 0000000..ff80ddd --- /dev/null +++ b/tests/queries/test_buy_product.py @@ -0,0 +1,184 @@ +import math +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User +from dibbler.queries.buy_product import buy_product +from dibbler.queries.product_stock import product_stock +from dibbler.queries.user_balance import user_balance + + +def insert_test_data(sql_session: Session) -> tuple[User, Product]: + user = User("Test User") + product = Product("1234567890123", "Test Product") + + sql_session.add(user) + sql_session.add(product) + sql_session.commit() + + transactions = [ + Transaction.adjust_penalty( + time=datetime(2023, 10, 1, 10, 0, 0), + user_id=user.id, + penalty_multiplier_percent=200, + penalty_threshold=-100, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 1), + user_id=user.id, + amount=100, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 10, 0, 2), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + return user, product + + +def test_buy_product_basic(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transaction = buy_product( + sql_session=sql_session, + time=datetime(2023, 10, 1, 12, 0, 0), + user=user, + product=product, + product_count=1, + ) + + sql_session.add(transaction) + sql_session.commit() + + +def test_buy_product_with_penalty(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + amount=-200, + ) + ] + sql_session.add_all(transactions) + sql_session.commit() + + transaction = buy_product( + sql_session=sql_session, + time=datetime(2023, 10, 1, 12, 0, 0), + user=user, + product=product, + product_count=1, + ) + sql_session.add(transaction) + sql_session.commit() + + assert user_balance(sql_session, user) == 100 + 27 - 200 - (27 * 2) + + +def test_buy_product_with_interest(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_interest( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + interest_rate_percent=110, + ) + ] + sql_session.add_all(transactions) + sql_session.commit() + + transaction = buy_product( + sql_session=sql_session, + time=datetime(2023, 10, 1, 12, 0, 0), + user=user, + product=product, + product_count=1, + ) + sql_session.add(transaction) + sql_session.commit() + + assert user_balance(sql_session, user) == 100 + 27 - math.ceil(27 * 1.1) + + +def test_buy_product_with_changing_penalty(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 11, 0, 0), + user_id=user.id, + amount=-200, + ) + ] + sql_session.add_all(transactions) + sql_session.commit() + + transaction = buy_product( + sql_session=sql_session, + time=datetime(2023, 10, 1, 12, 0, 0), + user=user, + product=product, + product_count=1, + ) + sql_session.add(transaction) + sql_session.commit() + + assert user_balance(sql_session, user) == 100 + 27 - 200 - (27 * 2) + + adjust_penalty = Transaction.adjust_penalty( + time=datetime(2023, 10, 1, 13, 0, 0), + user_id=user.id, + penalty_multiplier_percent=300, + penalty_threshold=-100, + ) + sql_session.add(adjust_penalty) + sql_session.commit() + + transaction = buy_product( + sql_session=sql_session, + time=datetime(2023, 10, 1, 14, 0, 0), + user=user, + product=product, + product_count=1, + ) + sql_session.add(transaction) + sql_session.commit() + + assert user_balance(sql_session, user) == 100 + 27 - 200 - (27 * 2) - (27 * 3) + + +def test_buy_product_with_changing_interest(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") + + +def test_buy_product_with_penalty_interest_combined(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") + + +def test_buy_product_more_than_stock(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transaction = buy_product( + sql_session=sql_session, + time=datetime(2023, 10, 1, 13, 0, 0), + product_count=10, + user=user, + product=product, + ) + + sql_session.add(transaction) + sql_session.commit() + + assert product_stock(sql_session, product) == 1 - 10 diff --git a/tests/queries/test_product_price.py b/tests/queries/test_product_price.py new file mode 100644 index 0000000..c2a247c --- /dev/null +++ b/tests/queries/test_product_price.py @@ -0,0 +1,173 @@ +from datetime import datetime + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User +from dibbler.queries.product_price import product_price + + +def insert_test_data(sql_session: Session) -> None: + # Add users + user1 = User("Test User 1") + user2 = User("Test User 2") + + sql_session.add_all([user1, user2]) + sql_session.commit() + + # Add products + product1 = Product("1234567890123", "Test Product 1") + product2 = Product("9876543210987", "Test Product 2") + product3 = Product("1111111111111", "Test Product 3") + sql_session.add_all([product1, product2, product3]) + sql_session.commit() + + # Add transactions + transactions = [ + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 0), + amount=100, + user_id=user1.id, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 1), + amount=50, + user_id=user2.id, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 2), + amount=-50, + user_id=user1.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user1.id, + product_id=product1.id, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 1), + amount=27, + product_count=1, + user_id=user2.id, + product_id=product1.id, + ), + Transaction.adjust_stock( + time=datetime(2023, 10, 1, 12, 0, 2), + product_count=3, + user_id=user1.id, + product_id=product1.id, + ), + Transaction.adjust_stock( + time=datetime(2023, 10, 1, 12, 0, 3), + product_count=-2, + user_id=user1.id, + product_id=product1.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 4), + amount=50, + per_product=50, + product_count=1, + user_id=user1.id, + product_id=product3.id, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 5), + amount=50, + product_count=1, + user_id=user1.id, + product_id=product3.id, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 12, 0, 6), + amount=1000, + user_id=user1.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + +def test_product_price(sql_session: Session) -> None: + insert_test_data(sql_session) + + product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one() + assert product_price(sql_session, product1) == 27 + + +def test_product_price_no_transactions(sql_session: Session) -> None: + insert_test_data(sql_session) + + product2 = sql_session.scalars(select(Product).where(Product.name == "Test Product 2")).one() + assert product_price(sql_session, product2) == 0 + + +def test_product_price_sold_out(sql_session: Session) -> None: + insert_test_data(sql_session) + + product3 = sql_session.scalars(select(Product).where(Product.name == "Test Product 3")).one() + assert product_price(sql_session, product3) == 50 + + +def test_product_price_interest(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") + + +def test_product_price_changing_interest(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") + + +# Price goes up and gets rounded up to the next integer +def test_product_price_round_up_from_below(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") + + +# Price goes down and gets rounded up to the next integer +def test_product_price_round_up_from_above(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") + + +def test_product_price_with_negative_stock_single_addition(sql_session: Session) -> None: + insert_test_data(sql_session) + + product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one() + user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() + + transaction = Transaction.buy_product( + time=datetime(2023, 10, 1, 13, 0, 0), + amount=27 * 5, + product_count=10, + user_id=user1.id, + product_id=product1.id, + ) + + sql_session.add(transaction) + sql_session.commit() + + product1_price = product_price(sql_session, product1) + assert product1_price == 27 + + transaction = Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 1), + amount=22, + per_product=22, + product_count=1, + user_id=user1.id, + product_id=product1.id, + ) + + sql_session.add(transaction) + sql_session.commit() + + # Stock went subzero, price should be the last added product price + product1_price = product_price(sql_session, product1) + assert product1_price == 22 + + +# TODO: what happens when stock is still negative and yet new products are added? +def test_product_price_with_negative_stock_multiple_additions(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") diff --git a/tests/queries/test_product_stock.py b/tests/queries/test_product_stock.py new file mode 100644 index 0000000..14f2d3f --- /dev/null +++ b/tests/queries/test_product_stock.py @@ -0,0 +1,143 @@ +from datetime import datetime + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User +from dibbler.queries.product_stock import product_stock + + +def insert_test_data(sql_session: Session) -> None: + user1 = User("Test User 1") + + sql_session.add(user1) + sql_session.commit() + + +def test_product_stock_basic_history(sql_session: Session) -> None: + insert_test_data(sql_session) + + user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() + + product = Product("1234567890123", "Test Product") + sql_session.add(product) + sql_session.commit() + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + amount=10, + per_product=10, + user_id=user1.id, + product_id=product.id, + product_count=1, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + assert product_stock(sql_session, product) == 1 + + +def test_product_stock_complex_history(sql_session: Session) -> None: + insert_test_data(sql_session) + + user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() + + product = Product("1234567890123", "Test Product") + sql_session.add(product) + sql_session.commit() + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 0), + amount=27 * 2, + per_product=27, + user_id=user1.id, + product_id=product.id, + product_count=2, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 13, 0, 1), + amount=27 * 3, + user_id=user1.id, + product_id=product.id, + product_count=3, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 13, 0, 2), + amount=50 * 4, + per_product=50, + user_id=user1.id, + product_id=product.id, + product_count=4, + ), + Transaction.adjust_stock( + time=datetime(2023, 10, 1, 15, 0, 0), + user_id=user1.id, + product_id=product.id, + product_count=3, + ), + Transaction.adjust_stock( + time=datetime(2023, 10, 1, 15, 0, 1), + user_id=user1.id, + product_id=product.id, + product_count=-2, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + assert product_stock(sql_session, product) == 2 - 3 + 4 + 3 - 2 + + +def test_product_stock_no_transactions(sql_session: Session) -> None: + insert_test_data(sql_session) + + product = Product("1234567890123", "Test Product") + sql_session.add(product) + sql_session.commit() + + assert product_stock(sql_session, product) == 0 + + +def test_negative_product_stock(sql_session: Session) -> None: + insert_test_data(sql_session) + + user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() + + product = Product("1234567890123", "Test Product") + sql_session.add(product) + sql_session.commit() + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 14, 0, 0), + amount=50, + per_product=50, + user_id=user1.id, + product_id=product.id, + product_count=1, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 14, 0, 1), + amount=50, + user_id=user1.id, + product_id=product.id, + product_count=2, + ), + Transaction.adjust_stock( + time=datetime(2023, 10, 1, 16, 0, 0), + user_id=user1.id, + product_id=product.id, + product_count=-1, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + # The stock should be negative because we added and bought the product + assert product_stock(sql_session, product) == 1 - 2 - 1 diff --git a/tests/queries/test_transfer_balance.py b/tests/queries/test_transfer_balance.py new file mode 100644 index 0000000..389e973 --- /dev/null +++ b/tests/queries/test_transfer_balance.py @@ -0,0 +1,12 @@ +from sqlalchemy.orm import Session + + +def test_user_not_allowed_to_transfer_to_self(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") +# insert_test_data(sql_session) +# ... + +# user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() + +# with pytest.raises(ValueError, match="Cannot transfer to self"): +# user1.transfer(sql_session, user1, 10) # Attempting to transfer to self diff --git a/tests/queries/test_user_balance.py b/tests/queries/test_user_balance.py new file mode 100644 index 0000000..3d5b313 --- /dev/null +++ b/tests/queries/test_user_balance.py @@ -0,0 +1,103 @@ +from datetime import datetime + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User +from dibbler.queries.user_balance import user_balance + + +def insert_test_data(sql_session: Session) -> None: + # Add users + user1 = User("Test User 1") + user2 = User("Test User 2") + + sql_session.add(user1) + sql_session.add(user2) + sql_session.commit() + + # Add products + product1 = Product("1234567890123", "Test Product 1") + product2 = Product("9876543210987", "Test Product 2") + sql_session.add(product1) + sql_session.add(product2) + sql_session.commit() + + # Add transactions + transactions = [ + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 0), + amount=100, + user_id=user1.id, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 1), + amount=50, + user_id=user2.id, + ), + Transaction.adjust_balance( + time=datetime(2023, 10, 1, 10, 0, 2), + amount=-50, + user_id=user1.id, + ), + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user1.id, + product_id=product1.id, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 1), + amount=27, + product_count=1, + user_id=user2.id, + product_id=product1.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + +def test_user_balance_basic_history(sql_session: Session) -> None: + insert_test_data(sql_session) + + user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() + user2 = sql_session.scalars(select(User).where(User.name == "Test User 2")).one() + + assert user_balance(sql_session, user1) == 100 - 50 + 27 * 2 + assert user_balance(sql_session, user2) == 50 - 27 + + +def test_user_balance_no_transactions(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") + + +def test_user_balance_complex_history(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") + + +def test_user_balance_with_tranfers(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") + + +def test_user_balance_penalty(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") + + +def test_user_balance_changing_penalty(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") + + +def test_user_balance_interest(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") + + +def test_user_balance_changing_interest(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") + + +def test_user_balance_penalty_interest_combined(sql_session: Session) -> None: + raise NotImplementedError("This test is not implemented yet.") diff --git a/tests/test_product.py b/tests/test_product.py deleted file mode 100644 index f424b95..0000000 --- a/tests/test_product.py +++ /dev/null @@ -1,216 +0,0 @@ -from datetime import datetime - -import pytest -from sqlalchemy import select -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session - -from dibbler.models import Product, Transaction, TransactionType, User - - -def insert_test_data(sql_session: Session) -> None: - # Add users - user1 = User("Test User 1") - user2 = User("Test User 2") - - sql_session.add_all([user1, user2]) - sql_session.commit() - - # Add products - product1 = Product("1234567890123", "Test Product 1") - product2 = Product("9876543210987", "Test Product 2") - product3 = Product("1111111111111", "Test Product 3") - sql_session.add_all([product1, product2, product3]) - sql_session.commit() - - # Add transactions - transactions = [ - Transaction( - time=datetime(2023, 10, 1, 10, 0, 0), - type_=TransactionType.ADJUST_BALANCE, - amount=100, - user_id=user1.id, - ), - Transaction( - time=datetime(2023, 10, 1, 10, 0, 1), - type_=TransactionType.ADJUST_BALANCE, - amount=50, - user_id=user2.id, - ), - Transaction( - time=datetime(2023, 10, 1, 10, 0, 2), - type_=TransactionType.ADJUST_BALANCE, - amount=-50, - user_id=user1.id, - ), - Transaction( - time=datetime(2023, 10, 1, 12, 0, 0), - type_=TransactionType.ADD_PRODUCT, - amount=27 * 2, - per_product=27, - product_count=2, - user_id=user1.id, - product_id=product1.id, - ), - Transaction( - time=datetime(2023, 10, 1, 12, 0, 1), - type_=TransactionType.BUY_PRODUCT, - amount=27, - product_count=1, - user_id=user2.id, - product_id=product1.id, - ), - Transaction( - time=datetime(2023, 10, 1, 12, 0, 2), - type_=TransactionType.ADD_PRODUCT, - amount=50, - per_product=50, - product_count=1, - user_id=user1.id, - product_id=product3.id, - ), - Transaction( - time=datetime(2023, 10, 1, 12, 0, 3), - type_=TransactionType.BUY_PRODUCT, - amount=50, - product_count=1, - user_id=user1.id, - product_id=product3.id, - ), - Transaction( - time=datetime(2023, 10, 1, 12, 0, 4), - type_=TransactionType.ADJUST_BALANCE, - amount=1000, - user_id=user1.id, - ), - ] - - sql_session.add_all(transactions) - sql_session.commit() - - # Note: These constructors depend on the content of the previous transactions, - # so they cannot be part of the initial transaction list. - - transaction = Transaction.adjust_stock_auto_amount( - sql_session=sql_session, - time=datetime(2023, 10, 1, 13, 0, 0), - product_count=3, - user_id=user1.id, - product_id=product1.id, - ) - - sql_session.add(transaction) - sql_session.commit() - - transaction = Transaction.adjust_stock_auto_amount( - sql_session=sql_session, - time=datetime(2023, 10, 1, 13, 0, 1), - product_count=-2, - user_id=user1.id, - product_id=product1.id, - ) - - sql_session.add(transaction) - sql_session.commit() - - -def test_no_duplicate_products(sql_session: Session): - insert_test_data(sql_session) - - product1 = Product("1234567890123", "Test Product 1") - sql_session.add(product1) - - with pytest.raises(IntegrityError): - sql_session.commit() - - -def test_product_stock(sql_session: Session): - insert_test_data(sql_session) - - product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one() - product2 = sql_session.scalars(select(Product).where(Product.name == "Test Product 2")).one() - - assert product1.stock(sql_session) == 2 - 1 + 3 - 2 - assert product2.stock(sql_session) == 0 - -def test_product_price(sql_session: Session): - insert_test_data(sql_session) - - product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one() - assert product1.price(sql_session) == 27 - - -def test_product_no_transactions_price(sql_session: Session): - insert_test_data(sql_session) - - product2 = sql_session.scalars(select(Product).where(Product.name == "Test Product 2")).one() - assert product2.price(sql_session) == 0 - - -def test_product_sold_out_price(sql_session: Session): - insert_test_data(sql_session) - - product3 = sql_session.scalars(select(Product).where(Product.name == "Test Product 3")).one() - assert product3.price(sql_session) == 50 - -def test_allowed_to_buy_more_than_stock(sql_session: Session): - insert_test_data(sql_session) - - product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one() - user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() - - transaction = Transaction.buy_product( - time=datetime(2023, 10, 1, 12, 0, 6), - amount = 27 * 5, - product_count=10, - user_id=user1.id, - product_id=product1.id, - ) - - sql_session.add(transaction) - sql_session.commit() - - product1_stock = product1.stock(sql_session) - assert product1_stock < 0 # Should be negative, as we bought more than available stock - - product1_price = product1.price(sql_session) - assert product1_price == 27 # Price should remain the same, as it is based on previous transactions - - transaction = Transaction.add_product( - time=datetime(2023, 10, 1, 12, 0, 8), - amount=22, - per_product=22, - product_count=1, - user_id=user1.id, - product_id=product1.id, - ) - - sql_session.add(transaction) - sql_session.commit() - - product1_price = product1.price(sql_session) - assert product1_price == 22 # Price should now be updated to the new price of the added product - - -def test_not_allowed_to_buy_with_incorrect_amount(sql_session: Session): - insert_test_data(sql_session) - - product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one() - user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() - - product1_price = product1.price(sql_session) - - with pytest.raises(IntegrityError): - transaction = Transaction.buy_product( - time=datetime(2023, 10, 1, 12, 0, 7), - amount= product1_price * 4 + 1, # Incorrect amount - product_count=4, - user_id=user1.id, - product_id=product1.id, - ) - sql_session.add(transaction) - sql_session.commit() - - -def test_not_allowed_to_buy_with_too_little_balance(sql_session: Session): - ... diff --git a/tests/test_transaction.py b/tests/test_transaction.py deleted file mode 100644 index df10e5f..0000000 --- a/tests/test_transaction.py +++ /dev/null @@ -1,97 +0,0 @@ -from datetime import datetime - -import pytest - -from sqlalchemy import select -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session - -from dibbler.models import Product, Transaction, TransactionType, User - - -def insert_test_data(sql_session: Session) -> None: - # Add users - user1 = User("Test User 1") - user2 = User("Test User 2") - - sql_session.add(user1) - sql_session.add(user2) - sql_session.commit() - - # Add products - product1 = Product("1234567890123", "Test Product 1") - product2 = Product("9876543210987", "Test Product 2") - sql_session.add(product1) - sql_session.add(product2) - sql_session.commit() - - # Add transactions - transactions = [ - Transaction( - time=datetime(2023, 10, 1, 10, 0, 0), - type_=TransactionType.ADJUST_BALANCE, - amount=100, - user_id=user1.id, - ), - Transaction( - time=datetime(2023, 10, 1, 10, 0, 1), - type_=TransactionType.ADJUST_BALANCE, - amount=50, - user_id=user2.id, - ), - Transaction( - time=datetime(2023, 10, 1, 10, 0, 2), - type_=TransactionType.ADJUST_BALANCE, - amount=-50, - user_id=user1.id, - ), - Transaction( - time=datetime(2023, 10, 1, 12, 0, 0), - type_=TransactionType.ADD_PRODUCT, - amount=27 * 2, - per_product=27, - product_count=2, - user_id=user1.id, - product_id=product1.id, - ), - Transaction( - time=datetime(2023, 10, 1, 12, 0, 1), - type_=TransactionType.BUY_PRODUCT, - amount=27, - product_count=1, - user_id=user2.id, - product_id=product1.id, - ), - ] - - sql_session.add_all(transactions) - sql_session.commit() - -def test_no_duplicate_timestamps(sql_session: Session): - """ - Ensure that no two transactions have the same timestamp. - """ - # Insert test data - insert_test_data(sql_session) - - user1 = sql_session.scalar( - select(User).where(User.name == "Test User 1") - ) - - assert user1 is not None, "Test User 1 should exist" - - transaction_to_duplicate = sql_session.scalar( - select(Transaction).limit(1) - ) - - assert transaction_to_duplicate is not None, "There should be at least one transaction" - - duplicate_timestamp_transaction = Transaction.adjust_balance( - time=transaction_to_duplicate.time, # Use the same timestamp as an existing transaction - amount=50, - user_id=user1.id, - ) - - with pytest.raises(IntegrityError): - sql_session.add(duplicate_timestamp_transaction) - sql_session.commit() diff --git a/tests/test_user.py b/tests/test_user.py deleted file mode 100644 index 9d9da67..0000000 --- a/tests/test_user.py +++ /dev/null @@ -1,108 +0,0 @@ -from datetime import datetime - -import pytest -from sqlalchemy import select -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session - -from dibbler.models import Product, Transaction, TransactionType, User - - -def insert_test_data(sql_session: Session) -> None: - # Add users - user1 = User("Test User 1") - user2 = User("Test User 2") - - sql_session.add(user1) - sql_session.add(user2) - sql_session.commit() - - # Add products - product1 = Product("1234567890123", "Test Product 1") - product2 = Product("9876543210987", "Test Product 2") - sql_session.add(product1) - sql_session.add(product2) - sql_session.commit() - - # Add transactions - transactions = [ - Transaction( - time=datetime(2023, 10, 1, 10, 0, 0), - type_=TransactionType.ADJUST_BALANCE, - amount=100, - user_id=user1.id, - ), - Transaction( - time=datetime(2023, 10, 1, 10, 0, 1), - type_=TransactionType.ADJUST_BALANCE, - amount=50, - user_id=user2.id, - ), - Transaction( - time=datetime(2023, 10, 1, 10, 0, 2), - type_=TransactionType.ADJUST_BALANCE, - amount=-50, - user_id=user1.id, - ), - Transaction( - time=datetime(2023, 10, 1, 12, 0, 0), - type_=TransactionType.ADD_PRODUCT, - amount=27 * 2, - per_product=27, - product_count=2, - user_id=user1.id, - product_id=product1.id, - ), - Transaction( - time=datetime(2023, 10, 1, 12, 0, 1), - type_=TransactionType.BUY_PRODUCT, - amount=27, - product_count=1, - user_id=user2.id, - product_id=product1.id, - ), - ] - - sql_session.add_all(transactions) - sql_session.commit() - - -def test_ensure_no_duplicate_users(sql_session: Session): - insert_test_data(sql_session) - - user1 = User("Test User 1") - sql_session.add(user1) - - with pytest.raises(IntegrityError): - sql_session.commit() - - -def test_user_credit(sql_session: Session): - insert_test_data(sql_session) - - user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() - user2 = sql_session.scalars(select(User).where(User.name == "Test User 2")).one() - - assert user1.credit(sql_session) == 100 - 50 + 27 * 2 - assert user2.credit(sql_session) == 50 - 27 - -def test_user_transactions(sql_session: Session): - insert_test_data(sql_session) - - user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() - user2 = sql_session.scalars(select(User).where(User.name == "Test User 2")).one() - - user1_transactions = user1.transactions(sql_session) - user2_transactions = user2.transactions(sql_session) - - assert len(user1_transactions) == 3 - assert len(user2_transactions) == 2 - -def test_user_not_allowed_to_transfer_to_self(sql_session: Session): - insert_test_data(sql_session) - ... - - # user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() - - # with pytest.raises(ValueError, match="Cannot transfer to self"): - # user1.transfer(sql_session, user1, 10) # Attempting to transfer to self