From 745db277ec6b0bf75e6b7de51d15fadfe9d1b48f Mon Sep 17 00:00:00 2001 From: h7x4 Date: Tue, 10 Jun 2025 20:59:38 +0200 Subject: [PATCH] WIP --- dibbler/models/Base.py | 10 +- dibbler/models/InterestRate.py | 10 + dibbler/models/Product.py | 132 +++++- dibbler/models/ProductPriceCache.py | 11 + dibbler/models/Purchase.py | 70 --- dibbler/models/PurchaseEntry.py | 37 -- dibbler/models/Transaction.py | 591 +++++++++++++++++++++++- dibbler/models/TransactionType.py | 13 + dibbler/models/User.py | 76 ++- dibbler/models/UserBalanceCache.py | 11 + dibbler/models/UserProducts.py | 31 -- dibbler/models/__init__.py | 7 +- dibbler/subcommands/repopulate_cache.py | 0 dibbler/subcommands/seed_test_data.py | 129 ++++-- example-config.ini | 2 +- tests/__init__.py | 0 tests/conftest.py | 27 ++ tests/test_product.py | 216 +++++++++ tests/test_transaction.py | 97 ++++ tests/test_user.py | 108 +++++ 20 files changed, 1324 insertions(+), 254 deletions(-) create mode 100644 dibbler/models/InterestRate.py create mode 100644 dibbler/models/ProductPriceCache.py delete mode 100644 dibbler/models/Purchase.py delete mode 100644 dibbler/models/PurchaseEntry.py create mode 100644 dibbler/models/TransactionType.py create mode 100644 dibbler/models/UserBalanceCache.py delete mode 100644 dibbler/models/UserProducts.py create mode 100644 dibbler/subcommands/repopulate_cache.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_product.py create mode 100644 tests/test_transaction.py create mode 100644 tests/test_user.py diff --git a/dibbler/models/Base.py b/dibbler/models/Base.py index f0764fe..f48fb10 100644 --- a/dibbler/models/Base.py +++ b/dibbler/models/Base.py @@ -10,12 +10,16 @@ from sqlalchemy.orm.collections import ( ) +def _pascal_case_to_snake_case(name: str) -> str: + return "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_") + + class Base(DeclarativeBase): metadata = MetaData( naming_convention={ - "ix": "ix_%(column_0_label)s", + "ix": "ix_%(table_name)s_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s", - "ck": "ck_%(table_name)s_`%(constraint_name)s`", + "ck": "ck_%(table_name)s_%(constraint_name)s", "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", "pk": "pk_%(table_name)s", } @@ -23,7 +27,7 @@ class Base(DeclarativeBase): @declared_attr.directive def __tablename__(cls) -> str: - return cls.__name__ + return _pascal_case_to_snake_case(cls.__name__) def __repr__(self) -> str: columns = ", ".join( diff --git a/dibbler/models/InterestRate.py b/dibbler/models/InterestRate.py new file mode 100644 index 0000000..c9ffccb --- /dev/null +++ b/dibbler/models/InterestRate.py @@ -0,0 +1,10 @@ +from datetime import datetime + +from sqlalchemy import Integer, DateTime +from sqlalchemy.orm import Mapped, mapped_column + +from dibbler.models import Base + +class InterestRate(Base): + timestamp: Mapped[datetime] = mapped_column(DateTime) + percentage: Mapped[int] = mapped_column(Integer) diff --git a/dibbler/models/Product.py b/dibbler/models/Product.py index 48e2f26..b5c22e9 100644 --- a/dibbler/models/Product.py +++ b/dibbler/models/Product.py @@ -1,47 +1,129 @@ from __future__ import annotations -from typing import TYPE_CHECKING + +from typing import Self from sqlalchemy import ( Boolean, Integer, String, + case, + func, + select, ) from sqlalchemy.orm import ( Mapped, + Session, mapped_column, - relationship, ) -from .Base import Base +import dibbler.models.User as user -if TYPE_CHECKING: - from .PurchaseEntry import PurchaseEntry - from .UserProducts import UserProducts +from .Base import Base +from .Transaction import Transaction +from .TransactionType import TransactionType + +# if TYPE_CHECKING: +# from .PurchaseEntry import PurchaseEntry +# from .UserProducts import UserProducts class Product(Base): - __tablename__ = "products" + id: Mapped[int] = mapped_column(Integer, primary_key=True) - product_id: Mapped[int] = mapped_column(Integer, primary_key=True) - bar_code: Mapped[str] = mapped_column(String(13)) + bar_code: Mapped[str] = mapped_column(String(13), unique=True) name: Mapped[str] = mapped_column(String(45)) - price: Mapped[int] = mapped_column(Integer) - stock: Mapped[int] = mapped_column(Integer) - hidden: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + # price: Mapped[int] = mapped_column(Integer) + # stock: Mapped[int] = mapped_column(Integer) + hidden: Mapped[bool] = mapped_column(Boolean, default=False) - purchases: Mapped[set[PurchaseEntry]] = relationship(back_populates="product") - users: Mapped[set[UserProducts]] = relationship(back_populates="product") - - bar_code_re = r"[0-9]+" - name_re = r".+" - name_length = 45 - - def __init__(self, bar_code, name, price, stock=0, hidden=False): - self.name = name + def __init__( + self: Self, + bar_code: str, + name: str, + hidden: bool = False, + ) -> None: self.bar_code = bar_code - self.price = price - self.stock = stock + self.name = name self.hidden = hidden - def __str__(self): - return self.name + # - count (virtual) + def stock(self: Self, sql_session: Session) -> int: + """ + Returns the number of products in stock. + """ + + result = sql_session.scalars( + select( + func.sum( + case( + ( + Transaction.type_ == TransactionType.ADD_PRODUCT, + Transaction.product_count, + ), + ( + Transaction.type_ == TransactionType.BUY_PRODUCT, + -Transaction.product_count, + ), + ( + Transaction.type_ == TransactionType.ADJUST_STOCK, + Transaction.product_count, + ), + else_=0, + ) + ) + ).where( + Transaction.type_.in_( + [ + TransactionType.BUY_PRODUCT, + TransactionType.ADD_PRODUCT, + TransactionType.ADJUST_STOCK, + ] + ), + Transaction.product_id == self.id, + ) + ).one_or_none() + + return result or 0 + + def remaining_with_exact_price(self: Self, sql_session: Session) -> list[int]: + """ + Retrieves the remaining products with their exact price as they were bought. + """ + + stock = self.stock(sql_session) + + # TODO: only retrieve as many transactions as exists in the stock + last_added = sql_session.scalars( + select( + func.row_number(), + Transaction.time, + Transaction.per_product, + Transaction.product_count, + ) + .where( + Transaction.type_ == TransactionType.ADD_PRODUCT, + Transaction.product_id == self.id, + ) + .order_by(Transaction.time.desc()) + ).all() + + # result = [] + # while stock > 0 and last_added: + + ... + + def price(self: Self, sql_session: Session) -> int: + """ + Returns the price of the product. + + Average price over the last bought products. + """ + + return Transaction.product_price(sql_session=sql_session, product=self) + + def owned_by_user(self: Self, sql_session: Session) -> dict[user.User, int]: + """ + Returns an overview of how many of the remaining products are owned by which user. + """ + + ... diff --git a/dibbler/models/ProductPriceCache.py b/dibbler/models/ProductPriceCache.py new file mode 100644 index 0000000..d468f87 --- /dev/null +++ b/dibbler/models/ProductPriceCache.py @@ -0,0 +1,11 @@ +from datetime import datetime + +from sqlalchemy import Integer, DateTime +from sqlalchemy.orm import Mapped, mapped_column + +from dibbler.models import Base + +class ProductPriceCache(Base): + product_id: Mapped[int] = mapped_column(Integer, primary_key=True) + timestamp: Mapped[datetime] = mapped_column(DateTime) + price: Mapped[int] = mapped_column(Integer) diff --git a/dibbler/models/Purchase.py b/dibbler/models/Purchase.py deleted file mode 100644 index b725f96..0000000 --- a/dibbler/models/Purchase.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from datetime import datetime -import math - -from sqlalchemy import ( - DateTime, - Integer, -) -from sqlalchemy.orm import ( - Mapped, - mapped_column, - relationship, -) - -from .Base import Base -from .Transaction import Transaction - -if TYPE_CHECKING: - from .PurchaseEntry import PurchaseEntry - - -class Purchase(Base): - __tablename__ = "purchases" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - time: Mapped[datetime] = mapped_column(DateTime) - price: Mapped[int] = mapped_column(Integer) - - transactions: Mapped[set[Transaction]] = relationship( - back_populates="purchase", order_by="Transaction.user_name" - ) - entries: Mapped[set[PurchaseEntry]] = relationship(back_populates="purchase") - - def __init__(self): - pass - - def is_complete(self): - return len(self.transactions) > 0 and len(self.entries) > 0 - - def price_per_transaction(self, round_up=True): - if round_up: - return int(math.ceil(float(self.price) / len(self.transactions))) - else: - return int(math.floor(float(self.price) / len(self.transactions))) - - def set_price(self, round_up=True): - self.price = 0 - for entry in self.entries: - self.price += entry.amount * entry.product.price - if len(self.transactions) > 0: - for t in self.transactions: - t.amount = self.price_per_transaction(round_up=round_up) - - def perform_purchase(self, ignore_penalty=False, round_up=True): - self.time = datetime.datetime.now() - self.set_price(round_up=round_up) - for t in self.transactions: - t.perform_transaction(ignore_penalty=ignore_penalty) - for entry in self.entries: - entry.product.stock -= entry.amount - - def perform_soft_purchase(self, price, round_up=True): - self.time = datetime.datetime.now() - self.price = price - for t in self.transactions: - t.amount = self.price_per_transaction(round_up=round_up) - for t in self.transactions: - t.perform_transaction() diff --git a/dibbler/models/PurchaseEntry.py b/dibbler/models/PurchaseEntry.py deleted file mode 100644 index 8484b32..0000000 --- a/dibbler/models/PurchaseEntry.py +++ /dev/null @@ -1,37 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from sqlalchemy import ( - Integer, - ForeignKey, -) -from sqlalchemy.orm import ( - Mapped, - mapped_column, - relationship, -) - -from .Base import Base - -if TYPE_CHECKING: - from .Product import Product - from .Purchase import Purchase - - -class PurchaseEntry(Base): - __tablename__ = "purchase_entries" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - amount: Mapped[int] = mapped_column(Integer) - - product_id: Mapped[int] = mapped_column(ForeignKey("products.product_id")) - purchase_id: Mapped[int] = mapped_column(ForeignKey("purchases.id")) - - product: Mapped[Product] = relationship(lazy="joined") - purchase: Mapped[Purchase] = relationship(lazy="joined") - - def __init__(self, purchase, product, amount): - self.product = product - self.product_bar_code = product.bar_code - self.purchase = purchase - self.amount = amount diff --git a/dibbler/models/Transaction.py b/dibbler/models/Transaction.py index df1155c..b07f79d 100644 --- a/dibbler/models/Transaction.py +++ b/dibbler/models/Transaction.py @@ -1,52 +1,601 @@ from __future__ import annotations -from typing import TYPE_CHECKING from datetime import datetime +from typing import TYPE_CHECKING, Self from sqlalchemy import ( + CheckConstraint, DateTime, ForeignKey, Integer, - String, + Text, + asc, + case, + cast, + func, + literal, + select, +) +from sqlalchemy import ( + Enum as SQLEnum, ) from sqlalchemy.orm import ( Mapped, + Session, mapped_column, relationship, ) +from sqlalchemy.sql.schema import Index from .Base import Base +from .TransactionType import TransactionType if TYPE_CHECKING: + from .Product import Product from .User import User - from .Purchase import Purchase + +# TODO: allow for joint transactions? +# dibbler allows joint transactions (e.g. buying more than one product at once, several people buying the same product, etc.) +# instead of having the software split the transactions up, making them hard to reconnect, +# maybe we should add some sort of joint transaction id field to allow multiple transactions to be grouped together? + +_DYNAMIC_FIELDS: set[str] = { + "per_product", + "user_id", + "transfer_user_id", + "product_id", + "product_count", +} + +_EXPECTED_FIELDS: dict[TransactionType, set[str]] = { + TransactionType.ADJUST_BALANCE: {"user_id"}, + TransactionType.ADJUST_STOCK: {"user_id", "product_id", "product_count"}, + TransactionType.TRANSFER: {"user_id", "transfer_user_id"}, + TransactionType.ADD_PRODUCT: {"user_id", "product_id", "per_product", "product_count"}, + TransactionType.BUY_PRODUCT: {"user_id", "product_id", "product_count"}, +} + + +def _transaction_type_field_constraints( + transaction_type: TransactionType, + expected_fields: set[str], +) -> CheckConstraint: + unexpected_fields = _DYNAMIC_FIELDS - expected_fields + + expected_constraints = ["{} IS NOT NULL".format(field) for field in expected_fields] + unexpected_constraints = ["{} IS NULL".format(field) for field in unexpected_fields] + + constraints = expected_constraints + unexpected_constraints + + # TODO: use sqlalchemy's `and_` and `or_` to build the constraints + return CheckConstraint( + f"type <> '{transaction_type}' OR ({' AND '.join(constraints)})", + name=f"trx_type_{transaction_type.value}_expected_fields", + ) class Transaction(Base): - __tablename__ = "transactions" + __table_args__ = ( + *[ + _transaction_type_field_constraints(transaction_type, expected_fields) + for transaction_type, expected_fields in _EXPECTED_FIELDS.items() + ], + # Speed up product count calculation + Index("product_user_time", "product_id", "user_id", "time"), + # Speed up product owner calculation + Index("user_product_time", "user_id", "product_id", "time"), + # Speed up user transaction list / credit calculation + Index("user_time", "user_id", "time"), + ) id: Mapped[int] = mapped_column(Integer, primary_key=True) + time: Mapped[datetime] = mapped_column(DateTime, unique=True) + message: Mapped[str | None] = mapped_column(Text, nullable=True) - time: Mapped[datetime] = mapped_column(DateTime) + # The type of transaction + type_: Mapped[TransactionType] = mapped_column(SQLEnum(TransactionType), name="type") + + # TODO: this should be inferred + # If buying products, is the user penalized for having too low credit? + # penalty: Mapped[Boolean] = mapped_column(Boolean, default=False) + + # The amount of money being added or subtracted from the user's credit + # This amount means different things depending on the transaction type: + # - ADJUST_BALANCE: The amount of credit to add or subtract from the user's balance + # - ADJUST_STOCK: The amount of money which disappeared with this stock adjustment + # (i.e. current price * product_count) + # - TRANSFER: The amount of credit to transfer to another user + # - ADD_PRODUCT: The real amount spent on the products + # (i.e. not per_product * product_count, which should be rounded up) + # - BUY_PRODUCT: The amount of credit spent on the product amount: Mapped[int] = mapped_column(Integer) - penalty: Mapped[int] = mapped_column(Integer) - description: Mapped[str | None] = mapped_column(String(50)) - user_name: Mapped[str] = mapped_column(ForeignKey("users.name")) - purchase_id: Mapped[int | None] = mapped_column(ForeignKey("purchases.id")) + # If adding products, how much is each product worth + per_product: Mapped[int | None] = mapped_column(Integer) - user: Mapped[User] = relationship(lazy="joined") - purchase: Mapped[Purchase] = relationship(lazy="joined") + # The user who performs the transaction + user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id")) + user: Mapped[User | None] = relationship( + lazy="joined", + foreign_keys=[user_id], + ) - def __init__(self, user, amount=0, description=None, purchase=None, penalty=1): - self.user = user + # Receiving user when moving credit from one user to another + transfer_user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id")) + transfer_user: Mapped[User | None] = relationship( + lazy="joined", + foreign_keys=[transfer_user_id], + ) + + # The product that is either being added or bought + product_id: Mapped[int | None] = mapped_column(ForeignKey("product.id")) + product: Mapped[Product | None] = relationship(lazy="joined") + + # The amount of products being added or bought + product_count: Mapped[int | None] = mapped_column(Integer) + + def __init__( + self: Self, + type_: TransactionType, + user_id: int, + amount: int, + time: datetime | None = None, + message: str | None = None, + product_id: int | None = None, + transfer_user_id: int | None = None, + per_product: int | None = None, + product_count: int | None = None, + # penalty: bool = False + ) -> None: + if time is None: + time = datetime.now() + + self.time = time + self.message = message + self.type_ = type_ self.amount = amount - self.description = description - self.purchase = purchase - self.penalty = penalty + self.user_id = user_id + self.product_id = product_id + self.transfer_user_id = transfer_user_id + self.per_product = per_product + self.product_count = product_count + # self.penalty = penalty - def perform_transaction(self, ignore_penalty=False): - self.time = datetime.datetime.now() - if not ignore_penalty: - self.amount *= self.penalty - self.user.credit -= self.amount + self._validate_by_transaction_type() + + def _validate_by_transaction_type(self: Self) -> None: + """ + Validates the transaction based on its type. + Raises ValueError if the transaction is invalid. + """ + # TODO: do we allow free products? + if self.amount == 0: + raise ValueError("Amount must not be zero.") + + for field in _EXPECTED_FIELDS[self.type_]: + if getattr(self, field) is None: + raise ValueError(f"{field} must not be None for {self.type_.value} transactions.") + + for field in _DYNAMIC_FIELDS - _EXPECTED_FIELDS[self.type_]: + if getattr(self, field) is not None: + raise ValueError(f"{field} must be None for {self.type_.value} transactions.") + + if self.per_product is not None and self.per_product <= 0: + raise ValueError("per_product must be greater than zero.") + + if ( + self.per_product is not None + and self.product_count is not None + and self.amount > self.per_product * self.product_count + ): + raise ValueError( + "The real amount of the transaction must be less than the total value of the products." + ) + + ################### + # FACTORY METHODS # + ################### + + @classmethod + def adjust_balance( + cls: type[Self], + amount: int, + user_id: int, + time: datetime | None = None, + message: str | None = None, + ) -> Transaction: + """ + Creates an ADJUST transaction. + """ + return cls( + time=time, + type_=TransactionType.ADJUST_BALANCE, + amount=amount, + user_id=user_id, + message=message, + ) + + @classmethod + def adjust_stock( + cls: type[Self], + amount: int, + user_id: int, + product_id: int, + product_count: int, + time: datetime | None = None, + message: str | None = None, + ) -> Transaction: + """ + Creates an ADJUST_STOCK transaction. + """ + return cls( + time=time, + type_=TransactionType.ADJUST_STOCK, + amount=amount, + user_id=user_id, + product_id=product_id, + product_count=product_count, + message=message, + ) + + @classmethod + def adjust_stock_auto_amount( + cls: type[Self], + sql_session: Session, + user_id: int, + product_id: int, + product_count: int, + time: datetime | None = None, + message: str | None = None, + ) -> Transaction: + """ + Creates an ADJUST_STOCK transaction with the amount automatically calculated based on the product's current price. + """ + from .Product import Product + + product = sql_session.scalar(select(Product).where(Product.id == product_id)) + if product is None: + raise ValueError(f"Product with id {product_id} does not exist.") + + price = product.price(sql_session) + + return cls( + time=time, + type_=TransactionType.ADJUST_STOCK, + amount=price * product_count, + user_id=user_id, + product_id=product_id, + product_count=product_count, + message=message, + ) + + @classmethod + def transfer( + cls: type[Self], + amount: int, + user_id: int, + transfer_user_id: int, + time: datetime | None = None, + message: str | None = None, + ) -> Transaction: + """ + Creates a TRANSFER transaction. + """ + return cls( + time=time, + type_=TransactionType.TRANSFER, + amount=amount, + user_id=user_id, + transfer_user_id=transfer_user_id, + message=message, + ) + + @classmethod + def add_product( + cls: type[Self], + amount: int, + user_id: int, + product_id: int, + per_product: int, + product_count: int, + time: datetime | None = None, + message: str | None = None, + ) -> Transaction: + """ + Creates an ADD_PRODUCT transaction. + """ + return cls( + time=time, + type_=TransactionType.ADD_PRODUCT, + amount=amount, + user_id=user_id, + product_id=product_id, + per_product=per_product, + product_count=product_count, + message=message, + ) + + @classmethod + def buy_product( + cls: type[Self], + amount: int, + user_id: int, + product_id: int, + product_count: int, + time: datetime | None = None, + message: str | None = None, + ) -> Transaction: + """ + Creates a BUY_PRODUCT transaction. + """ + return cls( + time=time, + type_=TransactionType.BUY_PRODUCT, + amount=amount, + user_id=user_id, + product_id=product_id, + product_count=product_count, + message=message, + ) + + @classmethod + def buy_product_auto_amount( + cls: type[Self], + sql_session: Session, + user_id: int, + product_id: int, + product_count: int, + time: datetime | None = None, + message: str | None = None, + ) -> Transaction: + """ + Creates a BUY_PRODUCT transaction with the amount automatically calculated based on the product's current price. + """ + from .Product import Product + + product = sql_session.scalar(select(Product).where(Product.id == product_id)) + if product is None: + raise ValueError(f"Product with id {product_id} does not exist.") + + price = product.price(sql_session) + + return cls( + time=time, + type_=TransactionType.BUY_PRODUCT, + amount=price * product_count, + user_id=user_id, + product_id=product_id, + product_count=product_count, + message=message, + ) + + ############################ + # USER BALANCE CALCULATION # + ############################ + + @staticmethod + def _user_balance_query( + user: User, + # until: datetime | None = None, + ): + """ + The inner query for calculating the user's balance. + This is used both directly via user_balance() and in Transaction CHECK constraints. + """ + + balance_adjustments = ( + select(func.coalesce(func.sum(Transaction.amount).label("balance_adjustments"), 0)) + .where( + Transaction.user_id == user.id, + Transaction.type_ == TransactionType.ADJUST_BALANCE, + ) + .scalar_subquery() + ) + + transfers_to_other_users = ( + select(func.coalesce(func.sum(Transaction.amount).label("transfers_to_other_users"), 0)) + .where( + Transaction.user_id == user.id, + Transaction.type_ == TransactionType.TRANSFER, + ) + .scalar_subquery() + ) + + transfers_to_self = ( + select(func.coalesce(func.sum(Transaction.amount).label("transfers_to_self"), 0)) + .where( + Transaction.transfer_user_id == user.id, + Transaction.type_ == TransactionType.TRANSFER, + ) + .scalar_subquery() + ) + + add_products = ( + select(func.coalesce(func.sum(Transaction.amount).label("add_products"), 0)) + .where( + Transaction.user_id == user.id, + Transaction.type_ == TransactionType.ADD_PRODUCT, + ) + .scalar_subquery() + ) + + buy_products = ( + select(func.coalesce(func.sum(Transaction.amount).label("buy_products"), 0)) + .where( + Transaction.user_id == user.id, + Transaction.type_ == TransactionType.BUY_PRODUCT, + ) + .scalar_subquery() + ) + + query = select( + # TODO: clearly define and fix the sign of the amount + ( + 0 + + balance_adjustments + - transfers_to_other_users + + transfers_to_self + + add_products + - buy_products + ).label("credit") + ) + + return query + + @staticmethod + def user_balance( + sql_session: Session, + user: User, + # Optional: calculate the balance until a certain transaction. + # until: Transaction | None = None, + ) -> int: + """ + Calculates the balance of a user. + """ + + query = Transaction._user_balance_query(user) # , until=until) + + result = sql_session.scalar(query) + + if result is None: + # If there are no transactions for this user, the query should return 0, not None. + raise RuntimeError( + f"Something went wrong while calculating the balance for user {user.name} (ID: {user.id})." + ) + + return result + + ############################# + # PRODUCT PRICE CALCULATION # + ############################# + + @staticmethod + def _product_price_query( + product: Product, + # until: datetime | None = None, + ): + """ + The inner query for calculating the product price. + + This is used both directly via product_price() and in Transaction CHECK constraints. + """ + initial_element = select( + literal(0).label("i"), + literal(0).label("time"), + literal(0).label("price"), + literal(0).label("product_count"), + ) + + recursive_cte = initial_element.cte(name="rec_cte", recursive=True) + + # Subset of transactions that we'll want to iterate over. + trx_subset = ( + select( + func.row_number().over(order_by=asc(Transaction.time)).label("i"), + Transaction.time, + Transaction.type_, + Transaction.product_count, + Transaction.per_product, + ) + .where( + Transaction.type_.in_( + [ + TransactionType.BUY_PRODUCT, + TransactionType.ADD_PRODUCT, + TransactionType.ADJUST_STOCK, + ] + ), + Transaction.product_id == product.id, + # TODO: + # If we have a transaction to limit the price calculation to, use it. + # If not, use all transactions for the product. + # (Transaction.time <= until.time) if until else True, + ) + .order_by(Transaction.time.asc()) + .alias("trx_subset") + ) + + recursive_elements = ( + select( + trx_subset.c.i, + trx_subset.c.time, + case( + # Someone buys the product -> price remains the same. + (trx_subset.c.type_ == TransactionType.BUY_PRODUCT, recursive_cte.c.price), + # Someone adds the product -> price is recalculated based on + # product count, previous price, and new price. + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT, + cast( + func.ceil( + (trx_subset.c.per_product * trx_subset.c.product_count) + / ( + # The running product count can be negative if the accounting is bad. + # This ensures that we never end up with negative prices or zero divisions + # and other disastrous phenomena. + func.min(recursive_cte.c.product_count, 0) + + trx_subset.c.product_count + ) + ), + Integer, + ), + ), + # Someone adjusts the stock -> price remains the same. + (trx_subset.c.type_ == TransactionType.ADJUST_STOCK, recursive_cte.c.price), + # Should never happen + else_=recursive_cte.c.price, + ).label("price"), + case( + # Someone buys the product -> product count is reduced. + ( + trx_subset.c.type_ == TransactionType.BUY_PRODUCT, + recursive_cte.c.product_count - trx_subset.c.product_count, + ), + # Someone adds the product -> product count is increased. + ( + trx_subset.c.type_ == TransactionType.ADD_PRODUCT, + recursive_cte.c.product_count + trx_subset.c.product_count, + ), + # Someone adjusts the stock -> product count is adjusted. + ( + trx_subset.c.type_ == TransactionType.ADJUST_STOCK, + recursive_cte.c.product_count + trx_subset.c.product_count, + ), + # Should never happen + else_=recursive_cte.c.product_count, + ).label("product_count"), + ) + .select_from(trx_subset) + .where(trx_subset.c.i == recursive_cte.c.i + 1) + ) + + return recursive_cte.union_all(recursive_elements) + + @staticmethod + def product_price( + sql_session: Session, + product: Product, + # Optional: calculate the price until a certain transaction. + # until: Transaction | None = None, + ) -> int: + """ + Calculates the price of a product. + """ + + recursive_cte = Transaction._product_price_query(product) # , until=until) + + # TODO: optionally verify subresults: + # - product_count should never be negative (but this happens sometimes, so just a warning) + # - price should never be negative + + result = sql_session.scalar( + select(recursive_cte.c.price).order_by(recursive_cte.c.i.desc()).limit(1) + ) + + if result is None: + # If there are no transactions for this product, the query should return 0, not None. + raise RuntimeError( + f"Something went wrong while calculating the price for product {product.name} (ID: {product.id})." + ) + + return result diff --git a/dibbler/models/TransactionType.py b/dibbler/models/TransactionType.py new file mode 100644 index 0000000..a0af64a --- /dev/null +++ b/dibbler/models/TransactionType.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class TransactionType(Enum): + """ + Enum for transaction types. + """ + + ADJUST_BALANCE = "adjust_balance" + ADJUST_STOCK = "adjust_stock" + TRANSFER = "transfer" + ADD_PRODUCT = "add_product" + BUY_PRODUCT = "buy_product" diff --git a/dibbler/models/User.py b/dibbler/models/User.py index d93e7fb..950dca1 100644 --- a/dibbler/models/User.py +++ b/dibbler/models/User.py @@ -1,49 +1,75 @@ from __future__ import annotations -from typing import TYPE_CHECKING + +from typing import Self from sqlalchemy import ( Integer, String, + select, ) from sqlalchemy.orm import ( Mapped, + Session, mapped_column, - relationship, ) -from .Base import Base +import dibbler.models.Product as product -if TYPE_CHECKING: - from .UserProducts import UserProducts - from .Transaction import Transaction +from .Base import Base +from .Transaction import Transaction class User(Base): - __tablename__ = "users" - name: Mapped[str] = mapped_column(String(10), primary_key=True) - credit: Mapped[str] = mapped_column(Integer) + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + name: Mapped[str] = mapped_column(String(20), unique=True) card: Mapped[str | None] = mapped_column(String(20)) rfid: Mapped[str | None] = mapped_column(String(20)) - products: Mapped[set[UserProducts]] = relationship(back_populates="user") - transactions: Mapped[set[Transaction]] = relationship(back_populates="user") + # name_re = r"[a-z]+" + # card_re = r"(([Nn][Tt][Nn][Uu])?[0-9]+)?" + # rfid_re = r"[0-9a-fA-F]*" - name_re = r"[a-z]+" - card_re = r"(([Nn][Tt][Nn][Uu])?[0-9]+)?" - rfid_re = r"[0-9a-fA-F]*" - - def __init__(self, name, card, rfid=None, credit=0): + def __init__(self: Self, name: str, card: str | None = None, rfid: str | None = None) -> None: self.name = name - if card == "": - card = None self.card = card - if rfid == "": - rfid = None self.rfid = rfid - self.credit = credit - def __str__(self): - return self.name + # def __str__(self): + # return self.name - def is_anonymous(self): - return self.card == "11122233" + # def is_anonymous(self): + # return self.card == "11122233" + + # TODO: rename to 'balance' everywhere + def credit(self, sql_session: Session) -> int: + """ + Returns the current credit of the user. + """ + + result = Transaction.user_balance( + sql_session=sql_session, + user=self, + ) + + return result + + def products(self, sql_session: Session) -> list[tuple[product.Product, int]]: + """ + Returns the products that the user has put into the system (and has not been purchased yet) + """ + + ... + + def transactions(self, sql_session: Session) -> list[Transaction]: + """ + Returns the transactions of the user in chronological order. + """ + + return list( + sql_session.scalars( + select(Transaction) + .where(Transaction.user_id == self.id) + .order_by(Transaction.time.asc()) + ).all() + ) diff --git a/dibbler/models/UserBalanceCache.py b/dibbler/models/UserBalanceCache.py new file mode 100644 index 0000000..b42994e --- /dev/null +++ b/dibbler/models/UserBalanceCache.py @@ -0,0 +1,11 @@ +from datetime import datetime + +from sqlalchemy import Integer, DateTime +from sqlalchemy.orm import Mapped, mapped_column + +from dibbler.models import Base + +class UserBalanceCache(Base): + user_id: Mapped[int] = mapped_column(Integer, primary_key=True) + timestamp: Mapped[datetime] = mapped_column(DateTime) + balance: Mapped[int] = mapped_column(Integer) diff --git a/dibbler/models/UserProducts.py b/dibbler/models/UserProducts.py deleted file mode 100644 index 17a8f13..0000000 --- a/dibbler/models/UserProducts.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from sqlalchemy import ( - Integer, - ForeignKey, -) -from sqlalchemy.orm import ( - Mapped, - mapped_column, - relationship, -) - -from .Base import Base - -if TYPE_CHECKING: - from .User import User - from .Product import Product - - -class UserProducts(Base): - __tablename__ = "user_products" - - user_name: Mapped[str] = mapped_column(ForeignKey("users.name"), primary_key=True) - product_id: Mapped[int] = mapped_column(ForeignKey("products.product_id"), primary_key=True) - - count: Mapped[int] = mapped_column(Integer) - sign: Mapped[int] = mapped_column(Integer) - - user: Mapped[User] = relationship() - product: Mapped[Product] = relationship() diff --git a/dibbler/models/__init__.py b/dibbler/models/__init__.py index 9cd0325..1c47c1c 100644 --- a/dibbler/models/__init__.py +++ b/dibbler/models/__init__.py @@ -1,17 +1,12 @@ __all__ = [ "Base", "Product", - "Purchase", - "PurchaseEntry", "Transaction", "User", - "UserProducts", ] from .Base import Base from .Product import Product -from .Purchase import Purchase -from .PurchaseEntry import PurchaseEntry from .Transaction import Transaction +from .TransactionType import TransactionType from .User import User -from .UserProducts import UserProducts diff --git a/dibbler/subcommands/repopulate_cache.py b/dibbler/subcommands/repopulate_cache.py new file mode 100644 index 0000000..e69de29 diff --git a/dibbler/subcommands/seed_test_data.py b/dibbler/subcommands/seed_test_data.py index 07454ea..69fab24 100644 --- a/dibbler/subcommands/seed_test_data.py +++ b/dibbler/subcommands/seed_test_data.py @@ -1,48 +1,107 @@ -import json -from dibbler.db import Session - +from datetime import datetime from pathlib import Path -from dibbler.models.Product import Product - -from dibbler.models.User import User +from dibbler.db import Session +from dibbler.models import Product, Transaction, TransactionType, User JSON_FILE = Path(__file__).parent.parent.parent / "mock_data.json" -def clear_db(session): - session.query(Product).delete() - session.query(User).delete() - session.commit() +# TODO: integrate this as a part of create-db, either asking interactively +# whether to seed test data, or by using command line arguments for +# automatating the answer. + +def clear_db(sql_session): + sql_session.query(Product).delete() + sql_session.query(User).delete() + sql_session.commit() def main(): - session = Session() - clear_db(session) - product_items = [] - user_items = [] + # TODO: There is some leftover json data in the mock_data.json file. + # It should be dealt with before merging this PR, either by removing + # it or using it here. + sql_session = Session() + clear_db(sql_session) - with open(JSON_FILE) as f: - json_obj = json.load(f) + # Add users + user1 = User("Test User 1") + user2 = User("Test User 2") - for product in json_obj["products"]: - product_item = Product( - bar_code=product["bar_code"], - name=product["name"], - price=product["price"], - stock=product["stock"], - ) - product_items.append(product_item) + sql_session.add(user1) + sql_session.add(user2) + sql_session.commit() - for user in json_obj["users"]: - user_item = User( - name=user["name"], - card=user["card"], - rfid=user["rfid"], - credit=user["credit"], - ) - user_items.append(user_item) + # Add products + product1 = Product("1234567890123", "Test Product 1") + product2 = Product("9876543210987", "Test Product 2") + sql_session.add(product1) + sql_session.add(product2) + sql_session.commit() - session.add_all(product_items) - session.add_all(user_items) - session.commit() + # Add transactions + transactions = [ + Transaction( + time=datetime(2023, 10, 1, 10, 0, 0), + type_=TransactionType.ADJUST_BALANCE, + amount=100, + user_id=user1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 10, 0, 1), + type_=TransactionType.ADJUST_BALANCE, + amount=50, + user_id=user2.id, + ), + Transaction( + time=datetime(2023, 10, 1, 10, 0, 2), + type_=TransactionType.ADJUST_BALANCE, + amount=-50, + user_id=user1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 0), + type_=TransactionType.ADD_PRODUCT, + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user1.id, + product_id=product1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 1), + type_=TransactionType.BUY_PRODUCT, + amount=27, + product_count=1, + user_id=user2.id, + product_id=product1.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + # Note: These constructors depend on the content of the previous transactions, + # so they cannot be part of the initial transaction list. + + transaction = Transaction.adjust_stock_auto_amount( + sql_session=sql_session, + time=datetime(2023, 10, 1, 12, 0, 2), + product_count=3, + user_id=user1.id, + product_id=product1.id, + ) + + sql_session.add(transaction) + sql_session.commit() + + transaction = Transaction.adjust_stock_auto_amount( + sql_session=sql_session, + time=datetime(2023, 10, 1, 12, 0, 3), + product_count=-2, + user_id=user1.id, + product_id=product1.id, + ) + + sql_session.add(transaction) + sql_session.commit() diff --git a/example-config.ini b/example-config.ini index 7abacb0..324b7fd 100644 --- a/example-config.ini +++ b/example-config.ini @@ -6,7 +6,7 @@ input_encoding = 'utf8' [database] # url = "postgresql://robertem@127.0.0.1/pvvvv" -url = "sqlite:///test.db" +url = sqlite:///test.db [limits] low_credit_warning_limit = -100 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..f94d4d7 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,27 @@ +import pytest + +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from dibbler.models import Base + +def pytest_addoption(parser): + parser.addoption( + "--echo", + action="store_true", + help="Enable SQLAlchemy echo mode for debugging", + ) + +@pytest.fixture(scope="function") +def sql_session(request): + """Create a new SQLAlchemy session for testing.""" + + echo = request.config.getoption("--echo") + + engine = create_engine( + "sqlite:///:memory:", + echo=echo, + ) + Base.metadata.create_all(engine) + with Session(engine) as sql_session: + yield sql_session diff --git a/tests/test_product.py b/tests/test_product.py new file mode 100644 index 0000000..f424b95 --- /dev/null +++ b/tests/test_product.py @@ -0,0 +1,216 @@ +from datetime import datetime + +import pytest +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, TransactionType, User + + +def insert_test_data(sql_session: Session) -> None: + # Add users + user1 = User("Test User 1") + user2 = User("Test User 2") + + sql_session.add_all([user1, user2]) + sql_session.commit() + + # Add products + product1 = Product("1234567890123", "Test Product 1") + product2 = Product("9876543210987", "Test Product 2") + product3 = Product("1111111111111", "Test Product 3") + sql_session.add_all([product1, product2, product3]) + sql_session.commit() + + # Add transactions + transactions = [ + Transaction( + time=datetime(2023, 10, 1, 10, 0, 0), + type_=TransactionType.ADJUST_BALANCE, + amount=100, + user_id=user1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 10, 0, 1), + type_=TransactionType.ADJUST_BALANCE, + amount=50, + user_id=user2.id, + ), + Transaction( + time=datetime(2023, 10, 1, 10, 0, 2), + type_=TransactionType.ADJUST_BALANCE, + amount=-50, + user_id=user1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 0), + type_=TransactionType.ADD_PRODUCT, + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user1.id, + product_id=product1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 1), + type_=TransactionType.BUY_PRODUCT, + amount=27, + product_count=1, + user_id=user2.id, + product_id=product1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 2), + type_=TransactionType.ADD_PRODUCT, + amount=50, + per_product=50, + product_count=1, + user_id=user1.id, + product_id=product3.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 3), + type_=TransactionType.BUY_PRODUCT, + amount=50, + product_count=1, + user_id=user1.id, + product_id=product3.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 4), + type_=TransactionType.ADJUST_BALANCE, + amount=1000, + user_id=user1.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + # Note: These constructors depend on the content of the previous transactions, + # so they cannot be part of the initial transaction list. + + transaction = Transaction.adjust_stock_auto_amount( + sql_session=sql_session, + time=datetime(2023, 10, 1, 13, 0, 0), + product_count=3, + user_id=user1.id, + product_id=product1.id, + ) + + sql_session.add(transaction) + sql_session.commit() + + transaction = Transaction.adjust_stock_auto_amount( + sql_session=sql_session, + time=datetime(2023, 10, 1, 13, 0, 1), + product_count=-2, + user_id=user1.id, + product_id=product1.id, + ) + + sql_session.add(transaction) + sql_session.commit() + + +def test_no_duplicate_products(sql_session: Session): + insert_test_data(sql_session) + + product1 = Product("1234567890123", "Test Product 1") + sql_session.add(product1) + + with pytest.raises(IntegrityError): + sql_session.commit() + + +def test_product_stock(sql_session: Session): + insert_test_data(sql_session) + + product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one() + product2 = sql_session.scalars(select(Product).where(Product.name == "Test Product 2")).one() + + assert product1.stock(sql_session) == 2 - 1 + 3 - 2 + assert product2.stock(sql_session) == 0 + +def test_product_price(sql_session: Session): + insert_test_data(sql_session) + + product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one() + assert product1.price(sql_session) == 27 + + +def test_product_no_transactions_price(sql_session: Session): + insert_test_data(sql_session) + + product2 = sql_session.scalars(select(Product).where(Product.name == "Test Product 2")).one() + assert product2.price(sql_session) == 0 + + +def test_product_sold_out_price(sql_session: Session): + insert_test_data(sql_session) + + product3 = sql_session.scalars(select(Product).where(Product.name == "Test Product 3")).one() + assert product3.price(sql_session) == 50 + +def test_allowed_to_buy_more_than_stock(sql_session: Session): + insert_test_data(sql_session) + + product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one() + user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() + + transaction = Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 6), + amount = 27 * 5, + product_count=10, + user_id=user1.id, + product_id=product1.id, + ) + + sql_session.add(transaction) + sql_session.commit() + + product1_stock = product1.stock(sql_session) + assert product1_stock < 0 # Should be negative, as we bought more than available stock + + product1_price = product1.price(sql_session) + assert product1_price == 27 # Price should remain the same, as it is based on previous transactions + + transaction = Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 8), + amount=22, + per_product=22, + product_count=1, + user_id=user1.id, + product_id=product1.id, + ) + + sql_session.add(transaction) + sql_session.commit() + + product1_price = product1.price(sql_session) + assert product1_price == 22 # Price should now be updated to the new price of the added product + + +def test_not_allowed_to_buy_with_incorrect_amount(sql_session: Session): + insert_test_data(sql_session) + + product1 = sql_session.scalars(select(Product).where(Product.name == "Test Product 1")).one() + user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() + + product1_price = product1.price(sql_session) + + with pytest.raises(IntegrityError): + transaction = Transaction.buy_product( + time=datetime(2023, 10, 1, 12, 0, 7), + amount= product1_price * 4 + 1, # Incorrect amount + product_count=4, + user_id=user1.id, + product_id=product1.id, + ) + sql_session.add(transaction) + sql_session.commit() + + +def test_not_allowed_to_buy_with_too_little_balance(sql_session: Session): + ... diff --git a/tests/test_transaction.py b/tests/test_transaction.py new file mode 100644 index 0000000..df10e5f --- /dev/null +++ b/tests/test_transaction.py @@ -0,0 +1,97 @@ +from datetime import datetime + +import pytest + +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, TransactionType, User + + +def insert_test_data(sql_session: Session) -> None: + # Add users + user1 = User("Test User 1") + user2 = User("Test User 2") + + sql_session.add(user1) + sql_session.add(user2) + sql_session.commit() + + # Add products + product1 = Product("1234567890123", "Test Product 1") + product2 = Product("9876543210987", "Test Product 2") + sql_session.add(product1) + sql_session.add(product2) + sql_session.commit() + + # Add transactions + transactions = [ + Transaction( + time=datetime(2023, 10, 1, 10, 0, 0), + type_=TransactionType.ADJUST_BALANCE, + amount=100, + user_id=user1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 10, 0, 1), + type_=TransactionType.ADJUST_BALANCE, + amount=50, + user_id=user2.id, + ), + Transaction( + time=datetime(2023, 10, 1, 10, 0, 2), + type_=TransactionType.ADJUST_BALANCE, + amount=-50, + user_id=user1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 0), + type_=TransactionType.ADD_PRODUCT, + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user1.id, + product_id=product1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 1), + type_=TransactionType.BUY_PRODUCT, + amount=27, + product_count=1, + user_id=user2.id, + product_id=product1.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + +def test_no_duplicate_timestamps(sql_session: Session): + """ + Ensure that no two transactions have the same timestamp. + """ + # Insert test data + insert_test_data(sql_session) + + user1 = sql_session.scalar( + select(User).where(User.name == "Test User 1") + ) + + assert user1 is not None, "Test User 1 should exist" + + transaction_to_duplicate = sql_session.scalar( + select(Transaction).limit(1) + ) + + assert transaction_to_duplicate is not None, "There should be at least one transaction" + + duplicate_timestamp_transaction = Transaction.adjust_balance( + time=transaction_to_duplicate.time, # Use the same timestamp as an existing transaction + amount=50, + user_id=user1.id, + ) + + with pytest.raises(IntegrityError): + sql_session.add(duplicate_timestamp_transaction) + sql_session.commit() diff --git a/tests/test_user.py b/tests/test_user.py new file mode 100644 index 0000000..9d9da67 --- /dev/null +++ b/tests/test_user.py @@ -0,0 +1,108 @@ +from datetime import datetime + +import pytest +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, TransactionType, User + + +def insert_test_data(sql_session: Session) -> None: + # Add users + user1 = User("Test User 1") + user2 = User("Test User 2") + + sql_session.add(user1) + sql_session.add(user2) + sql_session.commit() + + # Add products + product1 = Product("1234567890123", "Test Product 1") + product2 = Product("9876543210987", "Test Product 2") + sql_session.add(product1) + sql_session.add(product2) + sql_session.commit() + + # Add transactions + transactions = [ + Transaction( + time=datetime(2023, 10, 1, 10, 0, 0), + type_=TransactionType.ADJUST_BALANCE, + amount=100, + user_id=user1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 10, 0, 1), + type_=TransactionType.ADJUST_BALANCE, + amount=50, + user_id=user2.id, + ), + Transaction( + time=datetime(2023, 10, 1, 10, 0, 2), + type_=TransactionType.ADJUST_BALANCE, + amount=-50, + user_id=user1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 0), + type_=TransactionType.ADD_PRODUCT, + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user1.id, + product_id=product1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 1), + type_=TransactionType.BUY_PRODUCT, + amount=27, + product_count=1, + user_id=user2.id, + product_id=product1.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + +def test_ensure_no_duplicate_users(sql_session: Session): + insert_test_data(sql_session) + + user1 = User("Test User 1") + sql_session.add(user1) + + with pytest.raises(IntegrityError): + sql_session.commit() + + +def test_user_credit(sql_session: Session): + insert_test_data(sql_session) + + user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() + user2 = sql_session.scalars(select(User).where(User.name == "Test User 2")).one() + + assert user1.credit(sql_session) == 100 - 50 + 27 * 2 + assert user2.credit(sql_session) == 50 - 27 + +def test_user_transactions(sql_session: Session): + insert_test_data(sql_session) + + user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() + user2 = sql_session.scalars(select(User).where(User.name == "Test User 2")).one() + + user1_transactions = user1.transactions(sql_session) + user2_transactions = user2.transactions(sql_session) + + assert len(user1_transactions) == 3 + assert len(user2_transactions) == 2 + +def test_user_not_allowed_to_transfer_to_self(sql_session: Session): + insert_test_data(sql_session) + ... + + # user1 = sql_session.scalars(select(User).where(User.name == "Test User 1")).one() + + # with pytest.raises(ValueError, match="Cannot transfer to self"): + # user1.transfer(sql_session, user1, 10) # Attempting to transfer to self