From 0433d79f0c1a2f05b2f3aa4ba3799c9e4e1e7c5e Mon Sep 17 00:00:00 2001 From: h7x4 Date: Tue, 6 May 2025 17:09:30 +0200 Subject: [PATCH] WIP --- dibbler/models/Base.py | 12 +- dibbler/models/Product.py | 140 +++++++++-- dibbler/models/Purchase.py | 70 ------ dibbler/models/PurchaseEntry.py | 37 --- dibbler/models/Transaction.py | 336 ++++++++++++++++++++++++-- dibbler/models/TransactionType.py | 12 + dibbler/models/User.py | 135 +++++++++-- dibbler/models/UserProducts.py | 31 --- dibbler/models/__init__.py | 7 +- dibbler/subcommands/seed_test_data.py | 47 ++-- example-config.ini | 2 +- nix/dibbler.nix | 8 + nix/shell.nix | 1 + pyproject.toml | 5 + tests/__init__.py | 0 tests/conftest.py | 27 +++ tests/test_product.py | 96 ++++++++ tests/test_transaction.py | 64 +++++ tests/test_user.py | 87 +++++++ 19 files changed, 875 insertions(+), 242 deletions(-) delete mode 100644 dibbler/models/Purchase.py delete mode 100644 dibbler/models/PurchaseEntry.py create mode 100644 dibbler/models/TransactionType.py delete mode 100644 dibbler/models/UserProducts.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_product.py create mode 100644 tests/test_transaction.py create mode 100644 tests/test_user.py diff --git a/dibbler/models/Base.py b/dibbler/models/Base.py index f0764fe..e472426 100644 --- a/dibbler/models/Base.py +++ b/dibbler/models/Base.py @@ -10,12 +10,18 @@ from sqlalchemy.orm.collections import ( ) +def _pascal_case_to_snake_case(name: str) -> str: + return "".join( + ["_" + i.lower() if i.isupper() else i for i in name] + ).lstrip("_") + + class Base(DeclarativeBase): metadata = MetaData( naming_convention={ - "ix": "ix_%(column_0_label)s", + "ix": "ix_%(table_name)s_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s", - "ck": "ck_%(table_name)s_`%(constraint_name)s`", + "ck": "ck_%(table_name)s_%(constraint_name)s", "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", "pk": "pk_%(table_name)s", } @@ -23,7 +29,7 @@ class Base(DeclarativeBase): @declared_attr.directive def __tablename__(cls) -> str: - return cls.__name__ + return _pascal_case_to_snake_case(cls.__name__) def __repr__(self) -> str: columns = ", ".join( diff --git a/dibbler/models/Product.py b/dibbler/models/Product.py index 48e2f26..5f313fc 100644 --- a/dibbler/models/Product.py +++ b/dibbler/models/Product.py @@ -1,47 +1,137 @@ from __future__ import annotations -from typing import TYPE_CHECKING + +from typing import Self + +import math from sqlalchemy import ( Boolean, Integer, String, + func, + select, ) from sqlalchemy.orm import ( Mapped, + Session, mapped_column, - relationship, ) -from .Base import Base +import dibbler.models.User as user -if TYPE_CHECKING: - from .PurchaseEntry import PurchaseEntry - from .UserProducts import UserProducts +from .Base import Base +from .Transaction import Transaction +from .TransactionType import TransactionType + +# if TYPE_CHECKING: +# from .PurchaseEntry import PurchaseEntry +# from .UserProducts import UserProducts class Product(Base): - __tablename__ = "products" + id: Mapped[int] = mapped_column(Integer, primary_key=True) - product_id: Mapped[int] = mapped_column(Integer, primary_key=True) - bar_code: Mapped[str] = mapped_column(String(13)) + bar_code: Mapped[str] = mapped_column(String(13), unique=True) name: Mapped[str] = mapped_column(String(45)) - price: Mapped[int] = mapped_column(Integer) - stock: Mapped[int] = mapped_column(Integer) - hidden: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + # price: Mapped[int] = mapped_column(Integer) + # stock: Mapped[int] = mapped_column(Integer) + hidden: Mapped[bool] = mapped_column(Boolean, default=False) - purchases: Mapped[set[PurchaseEntry]] = relationship(back_populates="product") - users: Mapped[set[UserProducts]] = relationship(back_populates="product") - - bar_code_re = r"[0-9]+" - name_re = r".+" - name_length = 45 - - def __init__(self, bar_code, name, price, stock=0, hidden=False): - self.name = name + def __init__( + self: Self, + bar_code: str, + name: str, + hidden: bool = False, + ) -> None: self.bar_code = bar_code - self.price = price - self.stock = stock + self.name = name self.hidden = hidden - def __str__(self): - return self.name + # - count (virtual) + def stock(self, sql_session: Session) -> int: + """ + Returns the number of products in stock. + """ + + added_products = sql_session.scalars( + select(func.sum(Transaction.product_count)).where( + Transaction.type == TransactionType.ADD_PRODUCT, + Transaction.product_id == self.id, + ) + ).one_or_none() + + bought_products = sql_session.scalars( + select(func.sum(Transaction.product_count)).where( + Transaction.type == TransactionType.BUY_PRODUCT, + Transaction.product_id == self.id, + ) + ).one_or_none() + + return (added_products or 0) - (bought_products or 0) + + def remaining_with_exact_price(self, sql_session: Session) -> list[int]: + """ + Retrieves the remaining products with their exact price as they were bought. + """ + + stock = self.stock(sql_session) + + # TODO: only retrieve as many transactions as exists in the stock + last_added = sql_session.scalars( + select(Transaction) + .where( + Transaction.type == TransactionType.ADD_PRODUCT, + Transaction.product_id == self.id, + ) + .order_by(Transaction.time.desc()) + ).all() + + # result = [] + # while stock > 0 and last_added: + + ... + + def price(self, sql_session: Session) -> int: + """ + Returns the price of the product. + + Average price over the last bought products. + """ + + buy_add_transactions = sql_session.scalars( + select(Transaction) + .where( + Transaction.type.in_([TransactionType.BUY_PRODUCT, TransactionType.ADD_PRODUCT]), + Transaction.product_id == self.id, + ) + .order_by(Transaction.time.asc()) + ).all() + + price = 0 + product_count = 0 + for trx in buy_add_transactions: + if trx.type == TransactionType.BUY_PRODUCT: + product_count -= trx.product_count + + if product_count < 0: + raise ValueError( + f"Product {self.name} (ID: {self.id}) has negative stock {product_count} after buying {trx.product_count} products." + ) + + elif trx.type == TransactionType.ADD_PRODUCT: + product_count += trx.product_count + price = math.ceil((trx.per_product * trx.product_count) / product_count) + + if price < 0: + raise ValueError( + f"Product {self.name} (ID: {self.id}) has a negative price of {price}." + ) + + return price + + def owned_by_user(self, sql_session: Session) -> dict[user.User, int]: + """ + Returns an overview of how many of the remaining products are owned by which user. + """ + + ... diff --git a/dibbler/models/Purchase.py b/dibbler/models/Purchase.py deleted file mode 100644 index b725f96..0000000 --- a/dibbler/models/Purchase.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from datetime import datetime -import math - -from sqlalchemy import ( - DateTime, - Integer, -) -from sqlalchemy.orm import ( - Mapped, - mapped_column, - relationship, -) - -from .Base import Base -from .Transaction import Transaction - -if TYPE_CHECKING: - from .PurchaseEntry import PurchaseEntry - - -class Purchase(Base): - __tablename__ = "purchases" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - time: Mapped[datetime] = mapped_column(DateTime) - price: Mapped[int] = mapped_column(Integer) - - transactions: Mapped[set[Transaction]] = relationship( - back_populates="purchase", order_by="Transaction.user_name" - ) - entries: Mapped[set[PurchaseEntry]] = relationship(back_populates="purchase") - - def __init__(self): - pass - - def is_complete(self): - return len(self.transactions) > 0 and len(self.entries) > 0 - - def price_per_transaction(self, round_up=True): - if round_up: - return int(math.ceil(float(self.price) / len(self.transactions))) - else: - return int(math.floor(float(self.price) / len(self.transactions))) - - def set_price(self, round_up=True): - self.price = 0 - for entry in self.entries: - self.price += entry.amount * entry.product.price - if len(self.transactions) > 0: - for t in self.transactions: - t.amount = self.price_per_transaction(round_up=round_up) - - def perform_purchase(self, ignore_penalty=False, round_up=True): - self.time = datetime.datetime.now() - self.set_price(round_up=round_up) - for t in self.transactions: - t.perform_transaction(ignore_penalty=ignore_penalty) - for entry in self.entries: - entry.product.stock -= entry.amount - - def perform_soft_purchase(self, price, round_up=True): - self.time = datetime.datetime.now() - self.price = price - for t in self.transactions: - t.amount = self.price_per_transaction(round_up=round_up) - for t in self.transactions: - t.perform_transaction() diff --git a/dibbler/models/PurchaseEntry.py b/dibbler/models/PurchaseEntry.py deleted file mode 100644 index 8484b32..0000000 --- a/dibbler/models/PurchaseEntry.py +++ /dev/null @@ -1,37 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from sqlalchemy import ( - Integer, - ForeignKey, -) -from sqlalchemy.orm import ( - Mapped, - mapped_column, - relationship, -) - -from .Base import Base - -if TYPE_CHECKING: - from .Product import Product - from .Purchase import Purchase - - -class PurchaseEntry(Base): - __tablename__ = "purchase_entries" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - amount: Mapped[int] = mapped_column(Integer) - - product_id: Mapped[int] = mapped_column(ForeignKey("products.product_id")) - purchase_id: Mapped[int] = mapped_column(ForeignKey("purchases.id")) - - product: Mapped[Product] = relationship(lazy="joined") - purchase: Mapped[Purchase] = relationship(lazy="joined") - - def __init__(self, purchase, product, amount): - self.product = product - self.product_bar_code = product.bar_code - self.purchase = purchase - self.amount = amount diff --git a/dibbler/models/Transaction.py b/dibbler/models/Transaction.py index df1155c..19fa1f3 100644 --- a/dibbler/models/Transaction.py +++ b/dibbler/models/Transaction.py @@ -1,52 +1,346 @@ from __future__ import annotations -from typing import TYPE_CHECKING from datetime import datetime +from typing import TYPE_CHECKING, Self from sqlalchemy import ( + Boolean, + CheckConstraint, DateTime, ForeignKey, Integer, - String, + Text, +) +from sqlalchemy import ( + Enum as SQLEnum, ) from sqlalchemy.orm import ( Mapped, mapped_column, relationship, ) +from sqlalchemy.sql.schema import Index from .Base import Base +from .TransactionType import TransactionType if TYPE_CHECKING: + from .Product import Product from .User import User - from .Purchase import Purchase + +# TODO: allow for joint transactions? +# dibbler allows joint transactions (e.g. buying more than one product at once, several people buying the same product, etc.) +# instead of having the software split the transactions up, making them hard to reconnect, +# maybe we should add some sort of joint transaction id field to allow multiple transactions to be grouped together? class Transaction(Base): - __tablename__ = "transactions" + __table_args__ = ( + # TODO: embed everything from _validate_by_transaction_type into the constraints + CheckConstraint( + f"type != '{TransactionType.TRANSFER}' OR transfer_user_id IS NOT NULL", + name="trx_type_transfer_required_fields", + ), + CheckConstraint( + f"type != '{TransactionType.ADD_PRODUCT}' OR (product_id IS NOT NULL AND per_product IS NOT NULL AND product_count IS NOT NULL)", + name="trx_type_add_product_required_fields", + ), + CheckConstraint( + f"type != '{TransactionType.BUY_PRODUCT}' OR (product_id IS NOT NULL AND product_count IS NOT NULL)", + name="trx_type_buy_product_required_fields", + ), + # Speed up product count calculation + Index("product_user_time", "product_id", "user_id", "time"), + # Speed up product owner calculation + Index("user_product_time", "user_id", "product_id", "time"), + # Speed up user transaction list / credit calculation + Index("user_time", "user_id", "time"), + ) id: Mapped[int] = mapped_column(Integer, primary_key=True) - time: Mapped[datetime] = mapped_column(DateTime) + message: Mapped[str | None] = mapped_column(Text, nullable=True) + + # The type of transaction + type: Mapped[TransactionType] = mapped_column(SQLEnum(TransactionType)) + + # The amount of money being added or subtracted from the user's credit amount: Mapped[int] = mapped_column(Integer) - penalty: Mapped[int] = mapped_column(Integer) - description: Mapped[str | None] = mapped_column(String(50)) - user_name: Mapped[str] = mapped_column(ForeignKey("users.name")) - purchase_id: Mapped[int | None] = mapped_column(ForeignKey("purchases.id")) + # If buying products, is the user penalized for having too low credit? + penalty: Mapped[Boolean] = mapped_column(Boolean, default=False) - user: Mapped[User] = relationship(lazy="joined") - purchase: Mapped[Purchase] = relationship(lazy="joined") + # If adding products, how much is each product worth + per_product: Mapped[int | None] = mapped_column(Integer) - def __init__(self, user, amount=0, description=None, purchase=None, penalty=1): - self.user = user + # The user who performs the transaction + user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id")) + user: Mapped[User] = relationship(lazy="joined", foreign_keys=[user_id]) + + # Receiving user when moving credit from one user to another + transfer_user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id")) + transfer_user: Mapped[User | None] = relationship( + lazy="joined", foreign_keys=[transfer_user_id] + ) + + # The product that is either being added or bought + product_id: Mapped[int | None] = mapped_column(ForeignKey("product.id")) + product: Mapped[Product | None] = relationship(lazy="joined") + + # The amount of products being added or bought + product_count: Mapped[int | None] = mapped_column(Integer) + + def __init__( + self: Self, + time: datetime, + type: TransactionType, + amount: int, + user_id: int, + message: str | None = None, + product_id: int | None = None, + transfer_user_id: int | None = None, + per_product: int | None = None, + product_count: int | None = None, + # penalty: bool = False + ) -> None: + self.time = time + self.message = message + self.type = type self.amount = amount - self.description = description - self.purchase = purchase - self.penalty = penalty + self.user_id = user_id + self.product_id = product_id + self.transfer_user_id = transfer_user_id + self.per_product = per_product + self.product_count = product_count + # self.penalty = penalty - def perform_transaction(self, ignore_penalty=False): - self.time = datetime.datetime.now() - if not ignore_penalty: - self.amount *= self.penalty - self.user.credit -= self.amount + self._validate_by_transaction_type() + + def _validate_by_transaction_type(self: Self) -> None: + """ + Validates the transaction based on its type. + Raises ValueError if the transaction is invalid. + """ + match self.type: + case TransactionType.ADJUST_BALANCE: + if self.amount == 0: + raise ValueError("Amount must not be zero for ADJUST_BALANCE transactions.") + + if self.user_id is None: + raise ValueError("ADJUST_BALANCE transactions must have a user.") + + if self.product_id is not None: + raise ValueError("ADJUST_BALANCE transactions must not have a product.") + + if self.product_count is not None: + raise ValueError("ADJUST_BALANCE transactions must not have a product count.") + + if self.transfer_user_id is not None: + raise ValueError("ADJUST_BALANCE transactions must not have a transfer user.") + + if self.per_product is not None: + raise ValueError( + "ADJUST_BALANCE transactions must not have a per_product value." + ) + + case TransactionType.ADJUST_STOCK: + if self.amount == 0: + raise ValueError("Amount must not be zero for ADJUST_STOCK transactions.") + + if self.product_id is None: + raise ValueError("ADJUST_STOCK transactions must have a product.") + + if self.product_count is None: + raise ValueError("ADJUST_STOCK transactions must have a product count.") + + if self.transfer_user_id is not None: + raise ValueError("ADJUST_STOCK transactions must not have a transfer user.") + + if self.per_product is not None: + raise ValueError("ADJUST_STOCK transactions must not have a per_product value.") + + case TransactionType.TRANSFER: + if self.amount == 0: + raise ValueError("Amount must not be zero for TRANSFER transactions.") + + if self.user_id is None: + raise ValueError("TRANSFER transactions must have a user.") + + if self.product_id is not None: + raise ValueError("TRANSFER transactions must not have a product.") + + if self.product_count is not None: + raise ValueError("TRANSFER transactions must not have a product count.") + + if self.transfer_user_id is None: + raise ValueError("TRANSFER transactions must have a transfer user.") + + if self.per_product is not None: + raise ValueError("TRANSFER transactions must not have a per_product value.") + + case TransactionType.ADD_PRODUCT: + # TODO: do we allow free products? + if self.amount == 0: + raise ValueError("Amount must not be zero for ADD_PRODUCT transactions.") + + if self.user_id is None: + raise ValueError("ADD_PRODUCT transactions must have a user.") + + if self.product_id is None: + raise ValueError("ADD_PRODUCT transactions must have a product.") + + if self.product_count is None: + raise ValueError("ADD_PRODUCT transactions must have a product count.") + + if self.transfer_user_id is not None: + raise ValueError("ADD_PRODUCT transactions must not have a transfer user.") + + if self.per_product is None: + raise ValueError("ADD_PRODUCT transactions must have a per_product value.") + + if self.per_product <= 0: + raise ValueError("per_product must be greater than zero.") + + if self.product_count <= 0: + raise ValueError("product_count must be greater than zero.") + + if self.amount > self.per_product * self.product_count: + raise ValueError( + "The real amount of the transaction must be less than the total value of the products." + ) + + case TransactionType.BUY_PRODUCT: + if self.amount == 0: + raise ValueError("Amount must not be zero for BUY_PRODUCT transactions.") + + if self.user_id is None: + raise ValueError("BUY_PRODUCT transactions must have a user.") + + if self.product_id is None: + raise ValueError("BUY_PRODUCT transactions must have a product.") + + if self.product_count is None: + raise ValueError("BUY_PRODUCT transactions must have a product count.") + + if self.transfer_user_id is not None: + raise ValueError("BUY_PRODUCT transactions must not have a transfer user.") + + if self.per_product is not None: + raise ValueError("BUY_PRODUCT transactions must not have a per_product value.") + + case _: + raise ValueError(f"Unknown transaction type: {self.type}") + + def economy_difference(self: Self) -> int: + """ + Returns the difference in economy caused by this transaction. + """ + if self.type == TransactionType.ADJUST_BALANCE: + return self.amount + elif self.type == TransactionType.ADJUST_STOCK: + return -self.amount + elif self.type == TransactionType.TRANSFER: + return 0 + elif self.type == TransactionType.ADD_PRODUCT: + product_value = self.per_product * self.product_count + return product_value - self.amount + elif self.type == TransactionType.BUY_PRODUCT: + return 0 + else: + raise ValueError(f"Unknown transaction type: {self.type}") + + def adjust_balance( + self: Self, + amount: int, + user_id: int, + time: datetime | None = None, + message: str | None = None, + ) -> Transaction: + """ + Creates an ADJUST transaction. + """ + if time is None: + time = datetime.now() + + return Transaction( + time=time, + type=TransactionType.ADJUST_BALANCE, + amount=amount, + user_id=user_id, + message=message, + ) + + def transfer( + self: Self, + amount: int, + user_id: int, + transfer_user_id: int, + time: datetime | None = None, + message: str | None = None, + ) -> Transaction: + """ + Creates a TRANSFER transaction. + """ + if time is None: + time = datetime.now() + + return Transaction( + time=time, + type=TransactionType.TRANSFER, + amount=amount, + user_id=user_id, + transfer_user_id=transfer_user_id, + message=message, + ) + + def add_product( + self: Self, + amount: int, + user_id: int, + product_id: int, + per_product: int, + product_count: int, + time: datetime | None = None, + message: str | None = None, + ) -> Transaction: + """ + Creates an ADD_PRODUCT transaction. + """ + if time is None: + time = datetime.now() + + return Transaction( + time=time, + type=TransactionType.ADD_PRODUCT, + amount=amount, + user_id=user_id, + product_id=product_id, + per_product=per_product, + product_count=product_count, + message=message, + ) + + def buy_product( + self: Self, + amount: int, + user_id: int, + product_id: int, + product_count: int, + time: datetime | None = None, + message: str | None = None, + ) -> Transaction: + """ + Creates a BUY_PRODUCT transaction. + """ + if time is None: + time = datetime.now() + + return Transaction( + time=time, + type=TransactionType.BUY_PRODUCT, + amount=amount, + user_id=user_id, + product_id=product_id, + product_count=product_count, + message=message, + ) diff --git a/dibbler/models/TransactionType.py b/dibbler/models/TransactionType.py new file mode 100644 index 0000000..3ec5af9 --- /dev/null +++ b/dibbler/models/TransactionType.py @@ -0,0 +1,12 @@ + +from enum import Enum + +class TransactionType(Enum): + """ + Enum for transaction types. + """ + ADJUST_BALANCE = "adjust_balance" + ADJUST_STOCK = "adjust_stock" + TRANSFER = "transfer" + ADD_PRODUCT = "add_product" + BUY_PRODUCT = "buy_product" diff --git a/dibbler/models/User.py b/dibbler/models/User.py index d93e7fb..2182dc5 100644 --- a/dibbler/models/User.py +++ b/dibbler/models/User.py @@ -1,49 +1,134 @@ from __future__ import annotations -from typing import TYPE_CHECKING + +from typing import Self from sqlalchemy import ( Integer, String, + func, + select, ) from sqlalchemy.orm import ( Mapped, + Session, mapped_column, - relationship, ) -from .Base import Base +import dibbler.models.Product as product -if TYPE_CHECKING: - from .UserProducts import UserProducts - from .Transaction import Transaction +from .Base import Base +from .Transaction import Transaction +from .TransactionType import TransactionType class User(Base): - __tablename__ = "users" - name: Mapped[str] = mapped_column(String(10), primary_key=True) - credit: Mapped[str] = mapped_column(Integer) + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + name: Mapped[str] = mapped_column(String(20), unique=True) card: Mapped[str | None] = mapped_column(String(20)) rfid: Mapped[str | None] = mapped_column(String(20)) - products: Mapped[set[UserProducts]] = relationship(back_populates="user") - transactions: Mapped[set[Transaction]] = relationship(back_populates="user") + # name_re = r"[a-z]+" + # card_re = r"(([Nn][Tt][Nn][Uu])?[0-9]+)?" + # rfid_re = r"[0-9a-fA-F]*" - name_re = r"[a-z]+" - card_re = r"(([Nn][Tt][Nn][Uu])?[0-9]+)?" - rfid_re = r"[0-9a-fA-F]*" - - def __init__(self, name, card, rfid=None, credit=0): + def __init__(self: Self, name: str, card: str | None = None, rfid: str | None = None) -> None: self.name = name - if card == "": - card = None self.card = card - if rfid == "": - rfid = None self.rfid = rfid - self.credit = credit - def __str__(self): - return self.name + # def __str__(self): + # return self.name - def is_anonymous(self): - return self.card == "11122233" + # def is_anonymous(self): + # return self.card == "11122233" + + def credit(self, sql_session: Session) -> int: + """ + Returns the current credit of the user. + """ + + balance_adjustments = ( + select(func.coalesce(func.sum(Transaction.amount).label("balance_adjustments"), 0)) + .where( + Transaction.user_id == self.id, + Transaction.type == TransactionType.ADJUST_BALANCE, + ) + .scalar_subquery() + ) + + transfers_to_other_users = ( + select(func.coalesce(func.sum(Transaction.amount).label("transfers_to_other_users"), 0)) + .where( + Transaction.user_id == self.id, + Transaction.type == TransactionType.TRANSFER, + Transaction.transfer_user_id != self.id, + ) + .scalar_subquery() + ) + + transfers_to_self = ( + select(func.coalesce(func.sum(Transaction.amount).label("transfers_to_self"), 0)) + .where( + Transaction.transfer_user_id == self.id, + Transaction.type == TransactionType.TRANSFER, + Transaction.user_id != self.id, + ) + .scalar_subquery() + ) + + add_products = ( + select(func.coalesce(func.sum(Transaction.amount).label("add_products"), 0)) + .where( + Transaction.user_id == self.id, + Transaction.type == TransactionType.ADD_PRODUCT, + ) + .scalar_subquery() + ) + + buy_products = ( + select(func.coalesce(func.sum(Transaction.amount).label("buy_products"), 0)) + .where( + Transaction.user_id == self.id, + Transaction.type == TransactionType.BUY_PRODUCT, + ) + .scalar_subquery() + ) + + result = sql_session.scalar( + select( + # TODO: clearly define and fix the sign of the amount + ( + 0 + + balance_adjustments + - transfers_to_other_users + + transfers_to_self + + add_products + - buy_products + ).label("credit") + ) + ) + + assert result is not None, "Credit calculation returned None, please file a bug report." + + return result + + def products(self, sql_session: Session) -> list[tuple[product.Product, int]]: + """ + Returns the products that the user has put into the system (and has not been purchased yet) + """ + + ... + + def transactions(self, sql_session: Session) -> list[Transaction]: + """ + Returns the transactions of the user. + """ + + return list( + sql_session.scalars( + select(Transaction) + .where(Transaction.user_id == self.id) + .order_by(Transaction.time.desc()) + ).all() + ) diff --git a/dibbler/models/UserProducts.py b/dibbler/models/UserProducts.py deleted file mode 100644 index 17a8f13..0000000 --- a/dibbler/models/UserProducts.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from sqlalchemy import ( - Integer, - ForeignKey, -) -from sqlalchemy.orm import ( - Mapped, - mapped_column, - relationship, -) - -from .Base import Base - -if TYPE_CHECKING: - from .User import User - from .Product import Product - - -class UserProducts(Base): - __tablename__ = "user_products" - - user_name: Mapped[str] = mapped_column(ForeignKey("users.name"), primary_key=True) - product_id: Mapped[int] = mapped_column(ForeignKey("products.product_id"), primary_key=True) - - count: Mapped[int] = mapped_column(Integer) - sign: Mapped[int] = mapped_column(Integer) - - user: Mapped[User] = relationship() - product: Mapped[Product] = relationship() diff --git a/dibbler/models/__init__.py b/dibbler/models/__init__.py index 9cd0325..1c47c1c 100644 --- a/dibbler/models/__init__.py +++ b/dibbler/models/__init__.py @@ -1,17 +1,12 @@ __all__ = [ "Base", "Product", - "Purchase", - "PurchaseEntry", "Transaction", "User", - "UserProducts", ] from .Base import Base from .Product import Product -from .Purchase import Purchase -from .PurchaseEntry import PurchaseEntry from .Transaction import Transaction +from .TransactionType import TransactionType from .User import User -from .UserProducts import UserProducts diff --git a/dibbler/subcommands/seed_test_data.py b/dibbler/subcommands/seed_test_data.py index 07454ea..1d8aa1d 100644 --- a/dibbler/subcommands/seed_test_data.py +++ b/dibbler/subcommands/seed_test_data.py @@ -19,30 +19,31 @@ def clear_db(session): def main(): session = Session() clear_db(session) - product_items = [] - user_items = [] - with open(JSON_FILE) as f: - json_obj = json.load(f) + # product_items = [] + # user_items = [] - for product in json_obj["products"]: - product_item = Product( - bar_code=product["bar_code"], - name=product["name"], - price=product["price"], - stock=product["stock"], - ) - product_items.append(product_item) + # with open(JSON_FILE) as f: + # json_obj = json.load(f) - for user in json_obj["users"]: - user_item = User( - name=user["name"], - card=user["card"], - rfid=user["rfid"], - credit=user["credit"], - ) - user_items.append(user_item) + # for product in json_obj["products"]: + # product_item = Product( + # bar_code=product["bar_code"], + # name=product["name"], + # price=product["price"], + # stock=product["stock"], + # ) + # product_items.append(product_item) - session.add_all(product_items) - session.add_all(user_items) - session.commit() + # for user in json_obj["users"]: + # user_item = User( + # name=user["name"], + # card=user["card"], + # rfid=user["rfid"], + # credit=user["credit"], + # ) + # user_items.append(user_item) + + # session.add_all(product_items) + # session.add_all(user_items) + # session.commit() diff --git a/example-config.ini b/example-config.ini index 7abacb0..324b7fd 100644 --- a/example-config.ini +++ b/example-config.ini @@ -6,7 +6,7 @@ input_encoding = 'utf8' [database] # url = "postgresql://robertem@127.0.0.1/pvvvv" -url = "sqlite:///test.db" +url = sqlite:///test.db [limits] low_credit_warning_limit = -100 diff --git a/nix/dibbler.nix b/nix/dibbler.nix index 24dc21d..610b2dd 100644 --- a/nix/dibbler.nix +++ b/nix/dibbler.nix @@ -13,6 +13,14 @@ python3Packages.buildPythonApplication { # https://github.com/NixOS/nixpkgs/issues/285234 dontCheckRuntimeDeps = true; + pythonImportsCheck = []; + + doCheck = true; + nativeCheckInputs = with python3Packages; [ + pytest + pytestCheckHook + ]; + nativeBuildInputs = with python3Packages; [ setuptools ]; propagatedBuildInputs = with python3Packages; [ brother-ql diff --git a/nix/shell.nix b/nix/shell.nix index 7c93f0f..0c5fcae 100644 --- a/nix/shell.nix +++ b/nix/shell.nix @@ -15,6 +15,7 @@ mkShell { psycopg2 python-barcode sqlalchemy + pytest ])) ]; } diff --git a/pyproject.toml b/pyproject.toml index 3179a6f..a0b531d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,11 @@ dependencies = [ ] dynamic = ["version"] +[project.optional-dependencies] +dev = [ + "pytest" +] + [tool.setuptools.packages.find] include = ["dibbler*"] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..cc952e3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,27 @@ +import pytest + +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from dibbler.models import Base + +def pytest_addoption(parser): + parser.addoption( + "--echo", + action="store_true", + help="Enable SQLAlchemy echo mode for debugging", + ) + +@pytest.fixture(scope="function") +def session(request): + """Create a new SQLAlchemy session for testing.""" + + echo = request.config.getoption("--echo") + + engine = create_engine( + "sqlite:///:memory:", + echo=echo, + ) + Base.metadata.create_all(engine) + with Session(engine) as session: + yield session diff --git a/tests/test_product.py b/tests/test_product.py new file mode 100644 index 0000000..0867f22 --- /dev/null +++ b/tests/test_product.py @@ -0,0 +1,96 @@ +from datetime import datetime + +import pytest +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, TransactionType, User + + +def insert_test_data(session: Session) -> None: + # Add users + user1 = User("Test User 1") + user2 = User("Test User 2") + + session.add(user1) + session.add(user2) + session.commit() + + # Add products + product1 = Product("1234567890123", "Test Product 1") + product2 = Product("9876543210987", "Test Product 2") + session.add(product1) + session.add(product2) + session.commit() + + # Add transactions + transactions = [ + Transaction( + time=datetime(2023, 10, 1, 10, 0, 0), + type=TransactionType.ADJUST_BALANCE, + amount=100, + user_id=user1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 10, 0, 0), + type=TransactionType.ADJUST_BALANCE, + amount=50, + user_id=user2.id, + ), + Transaction( + time=datetime(2023, 10, 1, 10, 0, 1), + type=TransactionType.ADJUST_BALANCE, + amount=-50, + user_id=user1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 0), + type=TransactionType.ADD_PRODUCT, + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user1.id, + product_id=product1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 1), + type=TransactionType.BUY_PRODUCT, + amount=27, + product_count=1, + user_id=user2.id, + product_id=product1.id, + ), + ] + + session.add_all(transactions) + session.commit() + + +def test_no_duplicate_products(session: Session): + insert_test_data(session) + + product1 = Product("1234567890123", "Test Product 1") + session.add(product1) + + with pytest.raises(IntegrityError): + session.commit() + + +def test_product_stock(session: Session): + insert_test_data(session) + + product1 = session.scalars(select(Product).where(Product.name == "Test Product 1")).one() + product2 = session.scalars(select(Product).where(Product.name == "Test Product 2")).one() + + assert product1.stock(session) == 1 + assert product2.stock(session) == 0 + +def test_product_price(session: Session): + insert_test_data(session) + + product1 = session.scalars(select(Product).where(Product.name == "Test Product 1")).one() + product2 = session.scalars(select(Product).where(Product.name == "Test Product 2")).one() + + assert product1.price(session) == 27 + assert product2.price(session) == 0 diff --git a/tests/test_transaction.py b/tests/test_transaction.py new file mode 100644 index 0000000..92d0232 --- /dev/null +++ b/tests/test_transaction.py @@ -0,0 +1,64 @@ +from datetime import datetime + +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, TransactionType, User + + +def insert_test_data(session: Session) -> None: + # Add users + user1 = User("Test User 1") + user2 = User("Test User 2") + + session.add(user1) + session.add(user2) + session.commit() + + # Add products + product1 = Product("1234567890123", "Test Product 1") + product2 = Product("9876543210987", "Test Product 2") + session.add(product1) + session.add(product2) + session.commit() + + # Add transactions + transactions = [ + Transaction( + time=datetime(2023, 10, 1, 10, 0, 0), + type=TransactionType.ADJUST_BALANCE, + amount=100, + user_id=user1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 10, 0, 0), + type=TransactionType.ADJUST_BALANCE, + amount=50, + user_id=user2.id, + ), + Transaction( + time=datetime(2023, 10, 1, 10, 0, 1), + type=TransactionType.ADJUST_BALANCE, + amount=-50, + user_id=user1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 0), + type=TransactionType.ADD_PRODUCT, + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user1.id, + product_id=product1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 1), + type=TransactionType.BUY_PRODUCT, + amount=27, + product_count=1, + user_id=user2.id, + product_id=product1.id, + ), + ] + + session.add_all(transactions) + session.commit() diff --git a/tests/test_user.py b/tests/test_user.py new file mode 100644 index 0000000..0566b9f --- /dev/null +++ b/tests/test_user.py @@ -0,0 +1,87 @@ +from datetime import datetime + +import pytest +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, TransactionType, User + + +def insert_test_data(session: Session) -> None: + # Add users + user1 = User("Test User 1") + user2 = User("Test User 2") + + session.add(user1) + session.add(user2) + session.commit() + + # Add products + product1 = Product("1234567890123", "Test Product 1") + product2 = Product("9876543210987", "Test Product 2") + session.add(product1) + session.add(product2) + session.commit() + + # Add transactions + transactions = [ + Transaction( + time=datetime(2023, 10, 1, 10, 0, 0), + type=TransactionType.ADJUST_BALANCE, + amount=100, + user_id=user1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 10, 0, 0), + type=TransactionType.ADJUST_BALANCE, + amount=50, + user_id=user2.id, + ), + Transaction( + time=datetime(2023, 10, 1, 10, 0, 1), + type=TransactionType.ADJUST_BALANCE, + amount=-50, + user_id=user1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 0), + type=TransactionType.ADD_PRODUCT, + amount=27 * 2, + per_product=27, + product_count=2, + user_id=user1.id, + product_id=product1.id, + ), + Transaction( + time=datetime(2023, 10, 1, 12, 0, 1), + type=TransactionType.BUY_PRODUCT, + amount=27, + product_count=1, + user_id=user2.id, + product_id=product1.id, + ), + ] + + session.add_all(transactions) + session.commit() + + +def test_ensure_no_duplicate_users(session: Session): + insert_test_data(session) + + user1 = User("Test User 1") + session.add(user1) + + with pytest.raises(IntegrityError): + session.commit() + + +def test_user_credit(session: Session): + insert_test_data(session) + + user1 = session.scalars(select(User).where(User.name == "Test User 1")).one() + user2 = session.scalars(select(User).where(User.name == "Test User 2")).one() + + assert user1.credit(session) == 100 - 50 + 27 * 2 + assert user2.credit(session) == 50 - 27