diff --git a/dibbler/models/Base.py b/dibbler/models/Base.py index f0764fe..e0bcc3d 100644 --- a/dibbler/models/Base.py +++ b/dibbler/models/Base.py @@ -10,12 +10,16 @@ from sqlalchemy.orm.collections import ( ) +def _pascal_case_to_snake_case(name: str) -> str: + return "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_") + + class Base(DeclarativeBase): metadata = MetaData( naming_convention={ - "ix": "ix_%(column_0_label)s", + "ix": "ix_%(table_name)s_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s", - "ck": "ck_%(table_name)s_`%(constraint_name)s`", + "ck": "ck_%(table_name)s_%(constraint_name)s", "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", "pk": "pk_%(table_name)s", } @@ -23,8 +27,12 @@ class Base(DeclarativeBase): @declared_attr.directive def __tablename__(cls) -> str: - return cls.__name__ + return _pascal_case_to_snake_case(cls.__name__) + # NOTE: This is the default implementation of __repr__ for all tables, + # but it is preferable to override it for each table to get a nicer + # looking representation. This trades a bit of messiness for a complete + # output of all relevant fields. def __repr__(self) -> str: columns = ", ".join( f"{k}={repr(v)}" diff --git a/dibbler/models/Product.py b/dibbler/models/Product.py index 48e2f26..c69171f 100644 --- a/dibbler/models/Product.py +++ b/dibbler/models/Product.py @@ -1,5 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING + +from typing import Self from sqlalchemy import ( Boolean, @@ -9,39 +10,44 @@ from sqlalchemy import ( from sqlalchemy.orm import ( Mapped, mapped_column, - relationship, ) from .Base import Base -if TYPE_CHECKING: - from .PurchaseEntry import PurchaseEntry - from .UserProducts import UserProducts - class Product(Base): - __tablename__ = "products" + id: Mapped[int] = mapped_column(Integer, primary_key=True) + """Internal database ID""" - product_id: Mapped[int] = mapped_column(Integer, primary_key=True) - bar_code: Mapped[str] = mapped_column(String(13)) - name: Mapped[str] = mapped_column(String(45)) - price: Mapped[int] = mapped_column(Integer) - stock: Mapped[int] = mapped_column(Integer) - hidden: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + bar_code: Mapped[str] = mapped_column(String(13), unique=True) + """ + The bar code of the product. - purchases: Mapped[set[PurchaseEntry]] = relationship(back_populates="product") - users: Mapped[set[UserProducts]] = relationship(back_populates="product") + This is a unique identifier for the product, typically a 13-digit + EAN-13 code. + """ - bar_code_re = r"[0-9]+" - name_re = r".+" - name_length = 45 + name: Mapped[str] = mapped_column(String(45), unique=True) + """ + The name of the product. - def __init__(self, bar_code, name, price, stock=0, hidden=False): - self.name = name + Please don't write fanfics here, this is not a place for that. + """ + + hidden: Mapped[bool] = mapped_column(Boolean, default=False) + """ + Whether the product is hidden from the user interface. + + Hidden products are not shown in the product list, but can still be + used in transactions. + """ + + def __init__( + self: Self, + bar_code: str, + name: str, + hidden: bool = False, + ) -> None: self.bar_code = bar_code - self.price = price - self.stock = stock + self.name = name self.hidden = hidden - - def __str__(self): - return self.name diff --git a/dibbler/models/ProductCache.py b/dibbler/models/ProductCache.py new file mode 100644 index 0000000..f6bcb66 --- /dev/null +++ b/dibbler/models/ProductCache.py @@ -0,0 +1,16 @@ +from datetime import datetime + +from sqlalchemy import Integer, DateTime +from sqlalchemy.orm import Mapped, mapped_column + +from dibbler.models import Base + + +class ProductCache(Base): + product_id: Mapped[int] = mapped_column(Integer, primary_key=True) + + price: Mapped[int] = mapped_column(Integer) + price_timestamp: Mapped[datetime] = mapped_column(DateTime) + + stock: Mapped[int] = mapped_column(Integer) + stock_timestamp: Mapped[datetime] = mapped_column(DateTime) diff --git a/dibbler/models/Purchase.py b/dibbler/models/Purchase.py deleted file mode 100644 index b725f96..0000000 --- a/dibbler/models/Purchase.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from datetime import datetime -import math - -from sqlalchemy import ( - DateTime, - Integer, -) -from sqlalchemy.orm import ( - Mapped, - mapped_column, - relationship, -) - -from .Base import Base -from .Transaction import Transaction - -if TYPE_CHECKING: - from .PurchaseEntry import PurchaseEntry - - -class Purchase(Base): - __tablename__ = "purchases" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - time: Mapped[datetime] = mapped_column(DateTime) - price: Mapped[int] = mapped_column(Integer) - - transactions: Mapped[set[Transaction]] = relationship( - back_populates="purchase", order_by="Transaction.user_name" - ) - entries: Mapped[set[PurchaseEntry]] = relationship(back_populates="purchase") - - def __init__(self): - pass - - def is_complete(self): - return len(self.transactions) > 0 and len(self.entries) > 0 - - def price_per_transaction(self, round_up=True): - if round_up: - return int(math.ceil(float(self.price) / len(self.transactions))) - else: - return int(math.floor(float(self.price) / len(self.transactions))) - - def set_price(self, round_up=True): - self.price = 0 - for entry in self.entries: - self.price += entry.amount * entry.product.price - if len(self.transactions) > 0: - for t in self.transactions: - t.amount = self.price_per_transaction(round_up=round_up) - - def perform_purchase(self, ignore_penalty=False, round_up=True): - self.time = datetime.datetime.now() - self.set_price(round_up=round_up) - for t in self.transactions: - t.perform_transaction(ignore_penalty=ignore_penalty) - for entry in self.entries: - entry.product.stock -= entry.amount - - def perform_soft_purchase(self, price, round_up=True): - self.time = datetime.datetime.now() - self.price = price - for t in self.transactions: - t.amount = self.price_per_transaction(round_up=round_up) - for t in self.transactions: - t.perform_transaction() diff --git a/dibbler/models/PurchaseEntry.py b/dibbler/models/PurchaseEntry.py deleted file mode 100644 index 8484b32..0000000 --- a/dibbler/models/PurchaseEntry.py +++ /dev/null @@ -1,37 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from sqlalchemy import ( - Integer, - ForeignKey, -) -from sqlalchemy.orm import ( - Mapped, - mapped_column, - relationship, -) - -from .Base import Base - -if TYPE_CHECKING: - from .Product import Product - from .Purchase import Purchase - - -class PurchaseEntry(Base): - __tablename__ = "purchase_entries" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - amount: Mapped[int] = mapped_column(Integer) - - product_id: Mapped[int] = mapped_column(ForeignKey("products.product_id")) - purchase_id: Mapped[int] = mapped_column(ForeignKey("purchases.id")) - - product: Mapped[Product] = relationship(lazy="joined") - purchase: Mapped[Purchase] = relationship(lazy="joined") - - def __init__(self, purchase, product, amount): - self.product = product - self.product_bar_code = product.bar_code - self.purchase = purchase - self.amount = amount diff --git a/dibbler/models/Transaction.py b/dibbler/models/Transaction.py index df1155c..853d210 100644 --- a/dibbler/models/Transaction.py +++ b/dibbler/models/Transaction.py @@ -1,52 +1,571 @@ from __future__ import annotations -from typing import TYPE_CHECKING from datetime import datetime +from typing import TYPE_CHECKING, Self from sqlalchemy import ( + CheckConstraint, DateTime, ForeignKey, Integer, - String, + Text, + and_, + column, + func, + or_, ) from sqlalchemy.orm import ( Mapped, mapped_column, relationship, ) +from sqlalchemy.orm.collections import ( + InstrumentedDict, + InstrumentedList, + InstrumentedSet, +) +from sqlalchemy.sql.schema import Index from .Base import Base +from .TransactionType import TransactionType, TransactionTypeSQL if TYPE_CHECKING: + from .Product import Product from .User import User - from .Purchase import Purchase + +# TODO: rename to *_PERCENT +# NOTE: these only matter when there are no adjustments made in the database. +DEFAULT_INTEREST_RATE_PERCENTAGE = 100 +DEFAULT_PENALTY_THRESHOLD = -100 +DEFAULT_PENALTY_MULTIPLIER_PERCENTAGE = 200 + +_DYNAMIC_FIELDS: set[str] = { + "amount", + "interest_rate_percent", + "joint_transaction_id", + "penalty_multiplier_percent", + "penalty_threshold", + "per_product", + "product_count", + "product_id", + "transfer_user_id", +} + +EXPECTED_FIELDS: dict[TransactionType, set[str]] = { + TransactionType.ADD_PRODUCT: {"amount", "per_product", "product_count", "product_id"}, + TransactionType.ADJUST_BALANCE: {"amount"}, + TransactionType.ADJUST_INTEREST: {"interest_rate_percent"}, + TransactionType.ADJUST_PENALTY: {"penalty_multiplier_percent", "penalty_threshold"}, + TransactionType.ADJUST_STOCK: {"product_count", "product_id"}, + TransactionType.BUY_PRODUCT: {"product_count", "product_id"}, + TransactionType.JOINT: {"product_count", "product_id"}, + TransactionType.JOINT_BUY_PRODUCT: {"joint_transaction_id"}, + TransactionType.THROW_PRODUCT: {"product_count", "product_id"}, + TransactionType.TRANSFER: {"amount", "transfer_user_id"}, +} + +assert all(x <= _DYNAMIC_FIELDS for x in EXPECTED_FIELDS.values()), ( + "All expected fields must be part of _DYNAMIC_FIELDS." +) + + +def _transaction_type_field_constraints( + transaction_type: TransactionType, + expected_fields: set[str], +) -> CheckConstraint: + unexpected_fields = _DYNAMIC_FIELDS - expected_fields + + return CheckConstraint( + or_( + column("type") != transaction_type.value, + and_( + *[column(field).is_not(None) for field in expected_fields], + *[column(field).is_(None) for field in unexpected_fields], + ), + ), + name=f"trx_type_{transaction_type.value}_expected_fields", + ) class Transaction(Base): - __tablename__ = "transactions" + __table_args__ = ( + *[ + _transaction_type_field_constraints(transaction_type, expected_fields) + for transaction_type, expected_fields in EXPECTED_FIELDS.items() + ], + CheckConstraint( + or_( + column("type") != TransactionType.TRANSFER.value, + column("user_id") != column("transfer_user_id"), + ), + name="trx_type_transfer_no_self_transfers", + ), + CheckConstraint( + func.coalesce(column("product_count"), 1) != 0, + name="trx_product_count_non_zero", + ), + CheckConstraint( + func.coalesce(column("penalty_multiplier_percent"), 100) >= 100, + name="trx_penalty_multiplier_percent_min_100", + ), + CheckConstraint( + func.coalesce(column("interest_rate_percent"), 0) >= 0, + name="trx_interest_rate_percent_non_negative", + ), + CheckConstraint( + func.coalesce(column("amount"), 1) != 0, + name="trx_amount_non_zero", + ), + CheckConstraint( + func.coalesce(column("per_product"), 1) > 0, + name="trx_per_product_positive", + ), + CheckConstraint( + func.coalesce(column("penalty_threshold"), 0) <= 0, + name="trx_penalty_threshold_max_0", + ), + CheckConstraint( + or_( + column("joint_transaction_id").is_(None), + column("joint_transaction_id") != column("id"), + ), + name="trx_joint_transaction_id_not_self", + ), + # Speed up product count calculation + Index("product_user_time", "product_id", "user_id", "time"), + # Speed up product owner calculation + Index("user_product_time", "user_id", "product_id", "time"), + # Speed up user transaction list / credit calculation + Index("user_time", "user_id", "time"), + ) id: Mapped[int] = mapped_column(Integer, primary_key=True) + """ + A unique identifier for the transaction. + + Not used for anything else than identifying the transaction in the database. + """ time: Mapped[datetime] = mapped_column(DateTime) - amount: Mapped[int] = mapped_column(Integer) - penalty: Mapped[int] = mapped_column(Integer) - description: Mapped[str | None] = mapped_column(String(50)) + """ + The time when the transaction took place. - user_name: Mapped[str] = mapped_column(ForeignKey("users.name")) - purchase_id: Mapped[int | None] = mapped_column(ForeignKey("purchases.id")) + This is used to order transactions chronologically, and to calculate + all kinds of state. + """ - user: Mapped[User] = relationship(lazy="joined") - purchase: Mapped[Purchase] = relationship(lazy="joined") + message: Mapped[str | None] = mapped_column(Text, nullable=True) + """ + A message that can be set by the user to describe the reason + behind the transaction (or potentially a place to write som fan fiction). + + This is not used for any calculations, but can be useful for debugging. + """ + + type_: Mapped[TransactionType] = mapped_column(TransactionTypeSQL, name="type") + """ + Which type of transaction this is. + + The type determines which fields are expected to be set. + """ + + amount: Mapped[int | None] = mapped_column(Integer) + """ + This field means different things depending on the transaction type: + + - `ADD_PRODUCT`: The real amount spent on the products. + + - `ADJUST_BALANCE`: The amount of credit to add or subtract from the user's balance. + + - `TRANSFER`: The amount of balance to transfer to another user. + """ + + per_product: Mapped[int | None] = mapped_column(Integer) + """ + If adding products, how much is each product worth + + Note that this is distinct from the total amount of the transaction, + because this gets rounded up to the nearest integer, while the total amount + that the user paid in the store would be stored in the `amount` field. + """ + + user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) + """The user who performs the transaction. See `user` for more details.""" + user: Mapped[User] = relationship( + lazy="joined", + foreign_keys=[user_id], + ) + """ + The user who performs the transaction. + + For some transaction types, like `TRANSFER` and `ADD_PRODUCT`, this is a + functional field with "real world consequences" for price calculations. + + For others, like `ADJUST_PENALTY` and `ADJUST_STOCK`, this is just a record of who + performed the transaction, and does not affect any state calculations. + + In the case of `JOINT` transactions, this is the user who initiated the joint transaction. + """ + + joint_transaction_id: Mapped[int | None] = mapped_column(ForeignKey("transaction.id")) + """ + An optional ID to group multiple transactions together as part of a joint transaction. + + This is used for `JOINT` and `JOINT_BUY_PRODUCT` transactions, where multiple users + are involved in a single transaction. + """ + joint_transaction: Mapped[Transaction | None] = relationship( + lazy="joined", + foreign_keys=[joint_transaction_id], + ) + """ + The joint transaction that this transaction is part of, if any. + """ + + # Receiving user when moving credit from one user to another + transfer_user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id")) + """The user who receives money in a `TRANSFER` transaction.""" + transfer_user: Mapped[User | None] = relationship( + lazy="joined", + foreign_keys=[transfer_user_id], + ) + """The user who receives money in a `TRANSFER` transaction.""" + + # The product that is either being added or bought + product_id: Mapped[int | None] = mapped_column(ForeignKey("product.id")) + """The product being added or bought.""" + product: Mapped[Product | None] = relationship(lazy="joined") + """The product being added or bought.""" + + # The amount of products being added or bought + product_count: Mapped[int | None] = mapped_column(Integer) + """ + The amount of products being added or bought. + + This is always relative to the existing stock. + + - `ADD_PRODUCT` increases the stock by this amount. + + - `BUY_PRODUCT` decreases the stock by this amount. + + - `ADJUST_STOCK` increases or decreases the stock by this amount, + depending on whether the amount is positive or negative. + """ + + penalty_threshold: Mapped[int | None] = mapped_column(Integer, nullable=True) + """ + On `ADJUST_PENALTY` transactions, this is the threshold in krs for when the user + should start getting penalized for low credit. + + See also `penalty_multiplier`. + """ + + penalty_multiplier_percent: Mapped[int | None] = mapped_column(Integer, nullable=True) + """ + On `ADJUST_PENALTY` transactions, this is the multiplier for the amount of + money the user has to pay when they have too low credit. + + The multiplier is a percentage, so `100` means the user has to pay the full + price of the product, `200` means they have to pay double, etc. + + See also `penalty_threshold`. + """ + + interest_rate_percent: Mapped[int | None] = mapped_column(Integer, nullable=True) + """ + On `ADJUST_INTEREST` transactions, this is the interest rate in percent + that the user has to pay on their balance. + + The interest rate is a percentage, so `100` means the user has to pay the full + price of the product, `200` means they have to pay double, etc. + """ + + economy_spec_version: Mapped[int] = mapped_column(Integer, default=1) + """ + The version of the economy specification that this transaction adheres to. + + This is used to handle changes in the economy rules over time. + """ + + def __init__( + self: Self, + type_: TransactionType, + user_id: int, + amount: int | None = None, + interest_rate_percent: int | None = None, + joint_transaction_id: int | None = None, + message: str | None = None, + penalty_multiplier_percent: int | None = None, + penalty_threshold: int | None = None, + per_product: int | None = None, + product_count: int | None = None, + product_id: int | None = None, + time: datetime | None = None, + transfer_user_id: int | None = None, + ) -> None: + """ + Please do not call this constructor directly, use the factory methods instead. + """ + if time is None: + time = datetime.now() - def __init__(self, user, amount=0, description=None, purchase=None, penalty=1): - self.user = user self.amount = amount - self.description = description - self.purchase = purchase - self.penalty = penalty + self.interest_rate_percent = interest_rate_percent + self.joint_transaction_id = joint_transaction_id + self.message = message + self.penalty_multiplier_percent = penalty_multiplier_percent + self.penalty_threshold = penalty_threshold + self.per_product = per_product + self.product_count = product_count + self.product_id = product_id + self.time = time + self.transfer_user_id = transfer_user_id + self.type_ = type_ + self.user_id = user_id - def perform_transaction(self, ignore_penalty=False): - self.time = datetime.datetime.now() - if not ignore_penalty: - self.amount *= self.penalty - self.user.credit -= self.amount + self._validate_by_transaction_type() + + def _validate_by_transaction_type(self: Self) -> None: + """ + Validates the transaction's fields based on its type. + Raises `ValueError` if the transaction is invalid. + """ + # TODO: do we allow free products? + if self.amount == 0: + raise ValueError("Amount must not be zero.") + + for field in EXPECTED_FIELDS[self.type_]: + if getattr(self, field) is None: + raise ValueError(f"{field} must not be None for {self.type_.value} transactions.") + + for field in _DYNAMIC_FIELDS - EXPECTED_FIELDS[self.type_]: + if getattr(self, field) is not None: + raise ValueError(f"{field} must be None for {self.type_.value} transactions.") + + if self.per_product is not None and self.per_product <= 0: + raise ValueError("per_product must be greater than zero.") + + if ( + self.per_product is not None + and self.product_count is not None + and self.amount is not None + and self.amount > self.per_product * self.product_count + ): + raise ValueError( + "The real amount of the transaction must be less than the total value of the products." + ) + + # TODO: improve printing further + + def __repr__(self) -> str: + sort_order = [ + "id", + "time", + ] + + columns = ", ".join( + f"{k}={repr(v)}" + for k, v in sorted( + self.__dict__.items(), + key=lambda item: chr(sort_order.index(item[0])) + if item[0] in sort_order + else item[0], + ) + if not any( + [ + k == "type_", + (k == "message" and v is None), + k.startswith("_"), + # Ensure that we don't try to print out the entire list of + # relationships, which could create an infinite loop + isinstance(v, Base), + isinstance(v, InstrumentedList), + isinstance(v, InstrumentedSet), + isinstance(v, InstrumentedDict), + *[k in (_DYNAMIC_FIELDS - EXPECTED_FIELDS[self.type_])], + ] + ) + ) + return f"{self.type_.upper()}({columns})" + + ################### + # FACTORY METHODS # + ################### + + @classmethod + def adjust_balance( + cls: type[Self], + amount: int, + user_id: int, + time: datetime | None = None, + message: str | None = None, + ) -> Self: + return cls( + time=time, + type_=TransactionType.ADJUST_BALANCE, + amount=amount, + user_id=user_id, + message=message, + ) + + @classmethod + def adjust_interest( + cls: type[Self], + interest_rate_percent: int, + user_id: int, + time: datetime | None = None, + message: str | None = None, + ) -> Self: + return cls( + time=time, + type_=TransactionType.ADJUST_INTEREST, + interest_rate_percent=interest_rate_percent, + user_id=user_id, + message=message, + ) + + @classmethod + def adjust_penalty( + cls: type[Self], + penalty_multiplier_percent: int, + penalty_threshold: int, + user_id: int, + time: datetime | None = None, + message: str | None = None, + ) -> Self: + return cls( + time=time, + type_=TransactionType.ADJUST_PENALTY, + penalty_multiplier_percent=penalty_multiplier_percent, + penalty_threshold=penalty_threshold, + user_id=user_id, + message=message, + ) + + @classmethod + def adjust_stock( + cls: type[Self], + user_id: int, + product_id: int, + product_count: int, + time: datetime | None = None, + message: str | None = None, + ) -> Self: + return cls( + time=time, + type_=TransactionType.ADJUST_STOCK, + user_id=user_id, + product_id=product_id, + product_count=product_count, + message=message, + ) + + @classmethod + def add_product( + cls: type[Self], + amount: int, + user_id: int, + product_id: int, + per_product: int, + product_count: int, + time: datetime | None = None, + message: str | None = None, + ) -> Self: + return cls( + time=time, + type_=TransactionType.ADD_PRODUCT, + amount=amount, + user_id=user_id, + product_id=product_id, + per_product=per_product, + product_count=product_count, + message=message, + ) + + @classmethod + def buy_product( + cls: type[Self], + user_id: int, + product_id: int, + product_count: int, + time: datetime | None = None, + message: str | None = None, + ) -> Self: + return cls( + time=time, + type_=TransactionType.BUY_PRODUCT, + user_id=user_id, + product_id=product_id, + product_count=product_count, + message=message, + ) + + @classmethod + def joint( + cls: type[Self], + user_id: int, + product_id: int, + product_count: int, + time: datetime | None = None, + message: str | None = None, + ) -> Self: + return cls( + time=time, + type_=TransactionType.JOINT, + user_id=user_id, + product_id=product_id, + product_count=product_count, + message=message, + ) + + @classmethod + def joint_buy_product( + cls: type[Self], + joint_transaction_id: int, + user_id: int, + time: datetime | None = None, + message: str | None = None, + ) -> Self: + return cls( + time=time, + type_=TransactionType.JOINT_BUY_PRODUCT, + joint_transaction_id=joint_transaction_id, + user_id=user_id, + message=message, + ) + + @classmethod + def transfer( + cls: type[Self], + amount: int, + user_id: int, + transfer_user_id: int, + time: datetime | None = None, + message: str | None = None, + ) -> Self: + return cls( + time=time, + type_=TransactionType.TRANSFER, + amount=amount, + user_id=user_id, + transfer_user_id=transfer_user_id, + message=message, + ) + + @classmethod + def throw_product( + cls: type[Self], + user_id: int, + product_id: int, + product_count: int, + time: datetime | None = None, + message: str | None = None, + ) -> Self: + return cls( + time=time, + type_=TransactionType.THROW_PRODUCT, + user_id=user_id, + product_id=product_id, + product_count=product_count, + message=message, + ) diff --git a/dibbler/models/TransactionType.py b/dibbler/models/TransactionType.py new file mode 100644 index 0000000..3b8f0f1 --- /dev/null +++ b/dibbler/models/TransactionType.py @@ -0,0 +1,40 @@ +from enum import StrEnum, auto + +from sqlalchemy import Enum as SQLEnum + + +class TransactionType(StrEnum): + """ + Enum for transaction types. + """ + + ADD_PRODUCT = auto() + ADJUST_BALANCE = auto() + ADJUST_INTEREST = auto() + ADJUST_PENALTY = auto() + ADJUST_STOCK = auto() + BUY_PRODUCT = auto() + JOINT = auto() + JOINT_BUY_PRODUCT = auto() + THROW_PRODUCT = auto() + TRANSFER = auto() + + def as_literal_column(self): + """ + Return the transaction type as a SQL literal column. + + This is useful to avoid too many `?` bind parameters in SQL queries, + when the input value is known to be safe. + """ + from sqlalchemy import literal_column + + return literal_column(f"'{self.value}'") + + +TransactionTypeSQL = SQLEnum( + TransactionType, + native_enum=True, + create_constraint=True, + validate_strings=True, + values_callable=lambda x: [i.value for i in x], +) diff --git a/dibbler/models/User.py b/dibbler/models/User.py index d93e7fb..0f86bc7 100644 --- a/dibbler/models/User.py +++ b/dibbler/models/User.py @@ -1,5 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING + +from typing import Self from sqlalchemy import ( Integer, @@ -8,42 +9,35 @@ from sqlalchemy import ( from sqlalchemy.orm import ( Mapped, mapped_column, - relationship, ) from .Base import Base -if TYPE_CHECKING: - from .UserProducts import UserProducts - from .Transaction import Transaction - class User(Base): - __tablename__ = "users" - name: Mapped[str] = mapped_column(String(10), primary_key=True) - credit: Mapped[str] = mapped_column(Integer) + id: Mapped[int] = mapped_column(Integer, primary_key=True) + """Internal database ID""" + + name: Mapped[str] = mapped_column(String(20), unique=True) + """The PVV username of the user.""" + card: Mapped[str | None] = mapped_column(String(20)) + """The NTNU card number of the user.""" + rfid: Mapped[str | None] = mapped_column(String(20)) + """The RFID tag of the user (if they have any, rare these days).""" - products: Mapped[set[UserProducts]] = relationship(back_populates="user") - transactions: Mapped[set[Transaction]] = relationship(back_populates="user") + # name_re = r"[a-z]+" + # card_re = r"(([Nn][Tt][Nn][Uu])?[0-9]+)?" + # rfid_re = r"[0-9a-fA-F]*" - name_re = r"[a-z]+" - card_re = r"(([Nn][Tt][Nn][Uu])?[0-9]+)?" - rfid_re = r"[0-9a-fA-F]*" - - def __init__(self, name, card, rfid=None, credit=0): + def __init__(self: Self, name: str, card: str | None = None, rfid: str | None = None) -> None: self.name = name - if card == "": - card = None self.card = card - if rfid == "": - rfid = None self.rfid = rfid - self.credit = credit - def __str__(self): - return self.name + # def __str__(self): + # return self.name - def is_anonymous(self): - return self.card == "11122233" + # def is_anonymous(self): + # return self.card == "11122233" diff --git a/dibbler/models/UserCache.py b/dibbler/models/UserCache.py new file mode 100644 index 0000000..7b35e52 --- /dev/null +++ b/dibbler/models/UserCache.py @@ -0,0 +1,14 @@ +from datetime import datetime + +from sqlalchemy import Integer, DateTime +from sqlalchemy.orm import Mapped, mapped_column + +from dibbler.models import Base + + +# More like user balance cash money flow, amirite? +class UserBalanceCache(Base): + user_id: Mapped[int] = mapped_column(Integer, primary_key=True) + + balance: Mapped[int] = mapped_column(Integer) + timestamp: Mapped[datetime] = mapped_column(DateTime) diff --git a/dibbler/models/UserProducts.py b/dibbler/models/UserProducts.py deleted file mode 100644 index 17a8f13..0000000 --- a/dibbler/models/UserProducts.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING - -from sqlalchemy import ( - Integer, - ForeignKey, -) -from sqlalchemy.orm import ( - Mapped, - mapped_column, - relationship, -) - -from .Base import Base - -if TYPE_CHECKING: - from .User import User - from .Product import Product - - -class UserProducts(Base): - __tablename__ = "user_products" - - user_name: Mapped[str] = mapped_column(ForeignKey("users.name"), primary_key=True) - product_id: Mapped[int] = mapped_column(ForeignKey("products.product_id"), primary_key=True) - - count: Mapped[int] = mapped_column(Integer) - sign: Mapped[int] = mapped_column(Integer) - - user: Mapped[User] = relationship() - product: Mapped[Product] = relationship() diff --git a/dibbler/models/__init__.py b/dibbler/models/__init__.py index 9cd0325..083bd50 100644 --- a/dibbler/models/__init__.py +++ b/dibbler/models/__init__.py @@ -1,17 +1,13 @@ __all__ = [ "Base", "Product", - "Purchase", - "PurchaseEntry", "Transaction", + "TransactionType", "User", - "UserProducts", ] from .Base import Base from .Product import Product -from .Purchase import Purchase -from .PurchaseEntry import PurchaseEntry from .Transaction import Transaction +from .TransactionType import TransactionType from .User import User -from .UserProducts import UserProducts diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/test_product.py b/tests/models/test_product.py new file mode 100644 index 0000000..5e40f00 --- /dev/null +++ b/tests/models/test_product.py @@ -0,0 +1,32 @@ +import pytest +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from dibbler.models import Product + + +def insert_test_data(sql_session: Session) -> Product: + product = Product("1234567890123", "Test Product") + sql_session.add(product) + sql_session.commit() + return product + + +def test_product_no_duplicate_barcodes(sql_session: Session): + product = insert_test_data(sql_session) + + duplicate_product = Product(product.bar_code, "Hehe >:)") + sql_session.add(duplicate_product) + + with pytest.raises(IntegrityError): + sql_session.commit() + + +def test_product_no_duplicate_names(sql_session: Session): + product = insert_test_data(sql_session) + + duplicate_product = Product("1918238911928", product.name) + sql_session.add(duplicate_product) + + with pytest.raises(IntegrityError): + sql_session.commit() diff --git a/tests/models/test_transaction.py b/tests/models/test_transaction.py new file mode 100644 index 0000000..28a0958 --- /dev/null +++ b/tests/models/test_transaction.py @@ -0,0 +1,157 @@ +from datetime import datetime + +import pytest +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User +from dibbler.queries import product_stock + + +def insert_test_data(sql_session: Session) -> tuple[User, Product]: + user = User("Test User") + product = Product("1234567890123", "Test Product") + + sql_session.add(user) + sql_session.add(product) + sql_session.commit() + + return user, product + + +def test_user_not_allowed_to_transfer_to_self(sql_session: Session) -> None: + user, _ = insert_test_data(sql_session) + + transaction = Transaction.transfer( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + transfer_user_id=user.id, + amount=50, + ) + + sql_session.add(transaction) + + with pytest.raises(IntegrityError): + sql_session.commit() + + +def test_product_foreign_key_constraint(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transaction = Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ) + + sql_session.add(transaction) + sql_session.commit() + + # Attempt to add a transaction with a non-existent product + invalid_transaction = Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 1), + user_id=user.id, + product_id=9999, # Non-existent product ID + amount=27, + per_product=27, + product_count=1, + ) + + sql_session.add(invalid_transaction) + + with pytest.raises(IntegrityError): + sql_session.commit() + + +def test_user_foreign_key_constraint(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transaction = Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ) + + sql_session.add(transaction) + sql_session.commit() + + # Attempt to add a transaction with a non-existent user + invalid_transaction = Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 1), + user_id=9999, # Non-existent user ID + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ) + + sql_session.add(invalid_transaction) + + with pytest.raises(IntegrityError): + sql_session.commit() + + +def test_transaction_buy_product_more_than_stock(sql_session: Session) -> None: + user, product = insert_test_data(sql_session) + + transactions = [ + Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27, + per_product=27, + product_count=1, + ), + Transaction.buy_product( + time=datetime(2023, 10, 1, 13, 0, 0), + product_count=10, + user_id=user.id, + product_id=product.id, + ), + ] + + sql_session.add_all(transactions) + sql_session.commit() + + assert product_stock(sql_session, product) == 1 - 10 + + +def test_transaction_add_product_deny_amount_over_per_product_times_product_count( + sql_session: Session, +) -> None: + user, product = insert_test_data(sql_session) + + with pytest.raises(ValueError): + _transaction = Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 2 + 1, # Invalid amount + per_product=27, + product_count=2, + ) + + +def test_transaction_add_product_allow_amount_under_per_product_times_product_count( + sql_session: Session, +) -> None: + user, product = insert_test_data(sql_session) + + transaction = Transaction.add_product( + time=datetime(2023, 10, 1, 12, 0, 0), + user_id=user.id, + product_id=product.id, + amount=27 * 2 - 1, # Valid amount + per_product=27, + product_count=2, + ) + + sql_session.add(transaction) + sql_session.commit() diff --git a/tests/models/test_user.py b/tests/models/test_user.py new file mode 100644 index 0000000..7e8e852 --- /dev/null +++ b/tests/models/test_user.py @@ -0,0 +1,25 @@ +from datetime import datetime + +import pytest +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from dibbler.models import Product, Transaction, User + + +def insert_test_data(sql_session: Session) -> User: + user = User("Test User") + sql_session.add(user) + sql_session.commit() + + return user + + +def test_ensure_no_duplicate_user_names(sql_session: Session): + user = insert_test_data(sql_session) + + user2 = User(user.name) + sql_session.add(user2) + + with pytest.raises(IntegrityError): + sql_session.commit()