WIP: caching
This commit is contained in:
@@ -27,6 +27,9 @@ class Base(DeclarativeBase):
|
||||
|
||||
@declared_attr.directive
|
||||
def __tablename__(cls) -> str:
|
||||
if hasattr(cls, "__table_name__"):
|
||||
assert isinstance(cls.__table_name__, str)
|
||||
return cls.__table_name__
|
||||
return _pascal_case_to_snake_case(cls.__name__)
|
||||
|
||||
# NOTE: This is the default implementation of __repr__ for all tables,
|
||||
|
||||
26
dibbler/models/LastCacheTransaction.py
Normal file
26
dibbler/models/LastCacheTransaction.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Integer, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from dibbler.models import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dibbler.models import Transaction
|
||||
|
||||
|
||||
class LastCacheTransaction(Base):
|
||||
"""Tracks the last transaction that affected various caches."""
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
"""Internal database ID"""
|
||||
|
||||
transaction_id: Mapped[int | None] = mapped_column(ForeignKey("trx.id"), index=True)
|
||||
"""The ID of the last transaction that affected the cache(s)."""
|
||||
|
||||
transaction: Mapped[Transaction | None] = relationship(
|
||||
lazy="joined",
|
||||
foreign_keys=[transaction_id],
|
||||
)
|
||||
"""The last transaction that affected the cache(s)."""
|
||||
@@ -1,16 +1,30 @@
|
||||
from datetime import datetime
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Integer, DateTime
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import Integer, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from dibbler.models import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dibbler.models import LastCacheTransaction, Product
|
||||
|
||||
|
||||
class ProductCache(Base):
|
||||
product_id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
"""Internal database ID"""
|
||||
|
||||
product_id: Mapped[int] = mapped_column(ForeignKey('product.id'))
|
||||
product: Mapped[Product] = relationship(
|
||||
lazy="joined",
|
||||
foreign_keys=[product_id],
|
||||
)
|
||||
|
||||
price: Mapped[int] = mapped_column(Integer)
|
||||
price_timestamp: Mapped[datetime] = mapped_column(DateTime)
|
||||
|
||||
stock: Mapped[int] = mapped_column(Integer)
|
||||
stock_timestamp: Mapped[datetime] = mapped_column(DateTime)
|
||||
|
||||
last_cache_transaction_id: Mapped[int | None] = mapped_column(ForeignKey("last_cache_transaction.id"), nullable=True)
|
||||
last_cache_transaction: Mapped[LastCacheTransaction | None] = relationship(
|
||||
lazy="joined",
|
||||
foreign_keys=[last_cache_transaction_id],
|
||||
)
|
||||
|
||||
@@ -87,6 +87,7 @@ def _transaction_type_field_constraints(
|
||||
|
||||
|
||||
class Transaction(Base):
|
||||
__tablename__ = "trx"
|
||||
__table_args__ = (
|
||||
*[
|
||||
_transaction_type_field_constraints(transaction_type, expected_fields)
|
||||
@@ -210,7 +211,7 @@ class Transaction(Base):
|
||||
"""
|
||||
|
||||
joint_transaction_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("transaction.id"),
|
||||
ForeignKey("trx.id"),
|
||||
index=True,
|
||||
)
|
||||
"""
|
||||
|
||||
@@ -1,14 +1,30 @@
|
||||
from datetime import datetime
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Integer, DateTime
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy import Integer, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from dibbler.models import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dibbler.models import LastCacheTransaction, User
|
||||
|
||||
|
||||
# More like user balance cash money flow, amirite?
|
||||
class UserBalanceCache(Base):
|
||||
user_id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
class UserCache(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
"""internal database id"""
|
||||
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey('user.id'))
|
||||
user: Mapped[User] = relationship(
|
||||
lazy="joined",
|
||||
foreign_keys=[user_id],
|
||||
)
|
||||
|
||||
balance: Mapped[int] = mapped_column(Integer)
|
||||
timestamp: Mapped[datetime] = mapped_column(DateTime)
|
||||
|
||||
last_cache_transaction_id: Mapped[int | None] = mapped_column(ForeignKey("last_cache_transaction.id"), nullable=True)
|
||||
last_cache_transaction: Mapped[LastCacheTransaction | None] = relationship(
|
||||
lazy="joined",
|
||||
foreign_keys=[last_cache_transaction_id],
|
||||
)
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
__all__ = [
|
||||
"Base",
|
||||
"LastCacheTransaction",
|
||||
"Product",
|
||||
"ProductCache",
|
||||
"Transaction",
|
||||
"TransactionType",
|
||||
"User",
|
||||
"UserCache",
|
||||
]
|
||||
|
||||
from .Base import Base
|
||||
from .LastCacheTransaction import LastCacheTransaction
|
||||
from .Product import Product
|
||||
from .ProductCache import ProductCache
|
||||
from .Transaction import Transaction
|
||||
from .TransactionType import TransactionType
|
||||
from .User import User
|
||||
from .UserCache import UserCache
|
||||
|
||||
@@ -4,6 +4,8 @@ __all__ = [
|
||||
"adjust_interest",
|
||||
"adjust_penalty",
|
||||
"adjust_stock",
|
||||
"affected_products",
|
||||
"affected_users",
|
||||
"create_product",
|
||||
"create_user",
|
||||
"current_interest",
|
||||
@@ -19,6 +21,7 @@ __all__ = [
|
||||
"throw_product",
|
||||
"transaction_log",
|
||||
"transfer",
|
||||
"update_cache",
|
||||
"user_balance",
|
||||
"user_balance_log",
|
||||
"user_products",
|
||||
@@ -29,6 +32,8 @@ from .adjust_balance import adjust_balance
|
||||
from .adjust_interest import adjust_interest
|
||||
from .adjust_penalty import adjust_penalty
|
||||
from .adjust_stock import adjust_stock
|
||||
from .affected_products import affected_products
|
||||
from .affected_users import affected_users
|
||||
from .create_product import create_product
|
||||
from .create_user import create_user
|
||||
from .current_interest import current_interest
|
||||
@@ -42,5 +47,6 @@ from .search_user import search_user
|
||||
from .throw_product import throw_product
|
||||
from .transaction_log import transaction_log
|
||||
from .transfer import transfer
|
||||
from .update_cache import update_cache
|
||||
from .user_balance import user_balance, user_balance_log
|
||||
from .user_products import user_products
|
||||
|
||||
88
dibbler/queries/affected_products.py
Normal file
88
dibbler/queries/affected_products.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BindParameter, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Product, Transaction, TransactionType
|
||||
from dibbler.queries.query_helpers import until_filter, after_filter
|
||||
|
||||
|
||||
def affected_products(
|
||||
sql_session: Session,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: BindParameter[Transaction] | Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
after_time: BindParameter[datetime] | datetime | None = None,
|
||||
after_transaction: Transaction | None = None,
|
||||
after_inclusive: bool = True,
|
||||
) -> set[Product]:
|
||||
"""
|
||||
Get all products where attributes were affected over a given interval.
|
||||
"""
|
||||
|
||||
if isinstance(until_time, datetime):
|
||||
until_time = BindParameter("until_time", value=until_time)
|
||||
|
||||
if isinstance(until_transaction, Transaction):
|
||||
if until_transaction.id is None:
|
||||
raise ValueError("until_transaction must be persisted in the database.")
|
||||
until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
|
||||
else:
|
||||
until_transaction_id = None
|
||||
|
||||
if not (after_time is None or after_transaction is None):
|
||||
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
|
||||
|
||||
if isinstance(after_time, datetime):
|
||||
after_time = BindParameter("after_time", value=after_time)
|
||||
|
||||
if isinstance(after_transaction, Transaction):
|
||||
if after_transaction.id is None:
|
||||
raise ValueError("after_transaction must be persisted in the database.")
|
||||
after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
|
||||
else:
|
||||
after_transaction_id = None
|
||||
|
||||
if after_time is not None and until_time is not None:
|
||||
assert isinstance(after_time.value, datetime)
|
||||
assert isinstance(until_time.value, datetime)
|
||||
|
||||
if after_time.value > until_time.value:
|
||||
raise ValueError("after_time cannot be after until_time.")
|
||||
|
||||
if after_transaction is not None and until_transaction is not None:
|
||||
assert after_transaction.time is not None
|
||||
assert until_transaction.time is not None
|
||||
|
||||
if after_transaction.time > until_transaction.time:
|
||||
raise ValueError("after_transaction cannot be after until_transaction.")
|
||||
|
||||
result = sql_session.scalars(
|
||||
select(Product)
|
||||
.distinct()
|
||||
.join(Transaction, Product.id == Transaction.product_id)
|
||||
.where(
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
TransactionType.ADJUST_STOCK.as_literal_column(),
|
||||
TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
TransactionType.JOINT.as_literal_column(),
|
||||
TransactionType.THROW_PRODUCT.as_literal_column(),
|
||||
]
|
||||
),
|
||||
until_filter(
|
||||
until_time=until_time,
|
||||
until_transaction_id=until_transaction_id,
|
||||
until_inclusive=until_inclusive,
|
||||
),
|
||||
after_filter(
|
||||
after_time=after_time,
|
||||
after_transaction_id=after_transaction_id,
|
||||
after_inclusive=after_inclusive,
|
||||
),
|
||||
)
|
||||
.order_by(Transaction.time.desc())
|
||||
).all()
|
||||
|
||||
return set(result)
|
||||
86
dibbler/queries/affected_users.py
Normal file
86
dibbler/queries/affected_users.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BindParameter, select, or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Transaction, TransactionType, User
|
||||
from dibbler.queries.query_helpers import until_filter, after_filter
|
||||
|
||||
def affected_users(
|
||||
sql_session: Session,
|
||||
until_time: BindParameter[datetime] | datetime | None = None,
|
||||
until_transaction: BindParameter[Transaction] | Transaction | None = None,
|
||||
until_inclusive: bool = True,
|
||||
after_time: BindParameter[datetime] | datetime | None = None,
|
||||
after_transaction: Transaction | None = None,
|
||||
after_inclusive: bool = True,
|
||||
) -> set[User]:
|
||||
"""
|
||||
Get all users where attributes were affected over a given interval.
|
||||
"""
|
||||
|
||||
if isinstance(until_time, datetime):
|
||||
until_time = BindParameter("until_time", value=until_time)
|
||||
|
||||
if isinstance(until_transaction, Transaction):
|
||||
if until_transaction.id is None:
|
||||
raise ValueError("until_transaction must be persisted in the database.")
|
||||
until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
|
||||
else:
|
||||
until_transaction_id = None
|
||||
|
||||
if not (after_time is None or after_transaction is None):
|
||||
raise ValueError("Cannot filter by both after_time and after_transaction_id.")
|
||||
|
||||
if isinstance(after_time, datetime):
|
||||
after_time = BindParameter("after_time", value=after_time)
|
||||
|
||||
if isinstance(after_transaction, Transaction):
|
||||
if after_transaction.id is None:
|
||||
raise ValueError("after_transaction must be persisted in the database.")
|
||||
after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
|
||||
else:
|
||||
after_transaction_id = None
|
||||
|
||||
if after_time is not None and until_time is not None:
|
||||
assert isinstance(after_time.value, datetime)
|
||||
assert isinstance(until_time.value, datetime)
|
||||
|
||||
if after_time.value > until_time.value:
|
||||
raise ValueError("after_time cannot be after until_time.")
|
||||
|
||||
if after_transaction is not None and until_transaction is not None:
|
||||
assert after_transaction.time is not None
|
||||
assert until_transaction.time is not None
|
||||
|
||||
if after_transaction.time > until_transaction.time:
|
||||
raise ValueError("after_transaction cannot be after until_transaction.")
|
||||
|
||||
result = sql_session.scalars(
|
||||
select(User)
|
||||
.distinct()
|
||||
.join(Transaction, or_(User.id == Transaction.user_id, User.id == Transaction.transfer_user_id))
|
||||
.where(
|
||||
Transaction.type_.in_(
|
||||
[
|
||||
TransactionType.ADD_PRODUCT.as_literal_column(),
|
||||
TransactionType.ADJUST_BALANCE.as_literal_column(),
|
||||
TransactionType.BUY_PRODUCT.as_literal_column(),
|
||||
TransactionType.TRANSFER.as_literal_column(),
|
||||
]
|
||||
),
|
||||
until_filter(
|
||||
until_time=until_time,
|
||||
until_transaction_id=until_transaction_id,
|
||||
until_inclusive=until_inclusive,
|
||||
),
|
||||
after_filter(
|
||||
after_time=after_time,
|
||||
after_transaction_id=after_transaction_id,
|
||||
after_inclusive=after_inclusive,
|
||||
),
|
||||
)
|
||||
.order_by(Transaction.time.desc())
|
||||
).all()
|
||||
|
||||
return set(result)
|
||||
@@ -78,3 +78,42 @@ def until_filter(
|
||||
)
|
||||
|
||||
return CONST_TRUE
|
||||
|
||||
|
||||
def after_filter(
|
||||
after_time: BindParameter[datetime] | None = None,
|
||||
after_transaction_id: BindParameter[int] | None = None,
|
||||
after_inclusive: bool = True,
|
||||
transaction_time: QueryableAttribute = Transaction.time,
|
||||
) -> ColumnExpressionArgument[bool]:
|
||||
"""
|
||||
Create a filter condition for transactions after a given time or transaction.
|
||||
|
||||
Only one of `after_time` or `after_transaction_id` may be specified.
|
||||
"""
|
||||
|
||||
assert not (after_time is not None and after_transaction_id is not None), (
|
||||
"Cannot filter by both after_time and after_transaction_id."
|
||||
)
|
||||
|
||||
match (after_time, after_transaction_id, after_inclusive):
|
||||
case (BindParameter(), None, True):
|
||||
return transaction_time >= after_time
|
||||
case (BindParameter(), None, False):
|
||||
return transaction_time > after_time
|
||||
case (None, BindParameter(), True):
|
||||
return (
|
||||
transaction_time
|
||||
>= select(Transaction.time)
|
||||
.where(Transaction.id == after_transaction_id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
case (None, BindParameter(), False):
|
||||
return (
|
||||
transaction_time
|
||||
> select(Transaction.time)
|
||||
.where(Transaction.id == after_transaction_id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
|
||||
return CONST_TRUE
|
||||
|
||||
118
dibbler/queries/update_cache.py
Normal file
118
dibbler/queries/update_cache.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from sqlalchemy import insert, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import LastCacheTransaction, ProductCache, Transaction, UserCache
|
||||
from dibbler.queries.affected_products import affected_products
|
||||
from dibbler.queries.affected_users import affected_users
|
||||
from dibbler.queries.product_price import product_price
|
||||
from dibbler.queries.product_stock import product_stock
|
||||
from dibbler.queries.user_balance import user_balance
|
||||
|
||||
|
||||
def update_cache(
|
||||
sql_session: Session,
|
||||
use_cache: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Update the cache used for searching products.
|
||||
|
||||
If `use_cache` is False, the cache will be rebuilt from scratch.
|
||||
"""
|
||||
|
||||
last_transaction = sql_session.scalars(
|
||||
select(Transaction).order_by(Transaction.time.desc()).limit(1)
|
||||
).one_or_none()
|
||||
|
||||
print(last_transaction)
|
||||
|
||||
if last_transaction is None:
|
||||
# No transactions exist, nothing to update
|
||||
return
|
||||
|
||||
if use_cache:
|
||||
last_cache_transaction = sql_session.scalars(
|
||||
select(LastCacheTransaction)
|
||||
.join(Transaction, LastCacheTransaction.transaction_id == Transaction.id)
|
||||
.order_by(Transaction.time.desc())
|
||||
.limit(1)
|
||||
).one_or_none()
|
||||
if last_cache_transaction is not None:
|
||||
last_cache_transaction = last_cache_transaction.transaction
|
||||
else:
|
||||
last_cache_transaction = None
|
||||
|
||||
if last_cache_transaction is not None and last_cache_transaction.id == last_transaction.id:
|
||||
# Cache is already up to date
|
||||
return
|
||||
|
||||
users = affected_users(
|
||||
sql_session,
|
||||
after_transaction=last_cache_transaction,
|
||||
after_inclusive=False,
|
||||
until_transaction=last_transaction,
|
||||
)
|
||||
products = affected_products(
|
||||
sql_session,
|
||||
after_transaction=last_cache_transaction,
|
||||
after_inclusive=False,
|
||||
until_transaction=last_transaction,
|
||||
)
|
||||
|
||||
user_balances = {}
|
||||
for user in users:
|
||||
x = user_balance(
|
||||
sql_session,
|
||||
user,
|
||||
use_cache=use_cache,
|
||||
until_transaction=last_transaction,
|
||||
)
|
||||
user_balances[user.id] = x
|
||||
|
||||
product_stocks = {}
|
||||
product_prices = {}
|
||||
for product in products:
|
||||
product_stocks[product.id] = product_stock(
|
||||
sql_session,
|
||||
product,
|
||||
use_cache=use_cache,
|
||||
until_transaction=last_transaction,
|
||||
)
|
||||
product_prices[product.id] = product_price(
|
||||
sql_session,
|
||||
product,
|
||||
use_cache=use_cache,
|
||||
until_transaction=last_transaction,
|
||||
)
|
||||
|
||||
next_cache_transaction = LastCacheTransaction(transaction_id=last_transaction.id)
|
||||
sql_session.add(next_cache_transaction)
|
||||
sql_session.flush()
|
||||
|
||||
if not len(users) == 0:
|
||||
sql_session.execute(
|
||||
insert(UserCache),
|
||||
[
|
||||
{
|
||||
"user_id": user.id,
|
||||
"balance": user_balances[user.id],
|
||||
"last_cache_transaction_id": next_cache_transaction.id,
|
||||
}
|
||||
for user in users
|
||||
],
|
||||
)
|
||||
|
||||
if not len(products) == 0:
|
||||
sql_session.execute(
|
||||
insert(ProductCache),
|
||||
[
|
||||
{
|
||||
"product_id": product.id,
|
||||
"stock": product_stocks[product.id],
|
||||
"price": product_prices[product.id],
|
||||
"last_cache_transaction_id": next_cache_transaction.id,
|
||||
}
|
||||
for product in products
|
||||
],
|
||||
)
|
||||
|
||||
sql_session.commit()
|
||||
74
tests/queries/test_affected_products.py
Normal file
74
tests/queries/test_affected_products.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Product, Transaction, User
|
||||
from dibbler.queries import affected_products
|
||||
from tests.helpers import assert_id_order_similar_to_time_order, assign_times
|
||||
|
||||
|
||||
def insert_test_data(sql_session: Session) -> tuple[User, list[Product]]:
|
||||
user = User("Test User")
|
||||
|
||||
products = []
|
||||
for i in range(10):
|
||||
product = Product(f"12345678901{i:02d}", f"Test Product {i}")
|
||||
products.append(product)
|
||||
|
||||
sql_session.add(user)
|
||||
sql_session.add_all(products)
|
||||
sql_session.commit()
|
||||
|
||||
return user, products
|
||||
|
||||
|
||||
def test_affected_products_no_history(sql_session: Session) -> None:
|
||||
insert_test_data(sql_session)
|
||||
|
||||
result = affected_products(sql_session)
|
||||
|
||||
assert result == set()
|
||||
|
||||
|
||||
def test_affected_products_basic_history(sql_session: Session) -> None:
|
||||
user, products = insert_test_data(sql_session)
|
||||
|
||||
transactions = [
|
||||
Transaction.add_product(
|
||||
amount=10,
|
||||
per_product=10,
|
||||
user_id=user.id,
|
||||
product_id=products[i].id,
|
||||
product_count=1,
|
||||
)
|
||||
for i in range(5)
|
||||
] + [
|
||||
Transaction.buy_product(
|
||||
user_id=user.id,
|
||||
product_id=products[i].id,
|
||||
product_count=1,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
assign_times(transactions)
|
||||
|
||||
sql_session.add_all(transactions)
|
||||
sql_session.commit()
|
||||
|
||||
assert_id_order_similar_to_time_order(transactions)
|
||||
|
||||
result = affected_products(sql_session)
|
||||
|
||||
expected_products = {products[i] for i in range(5)}
|
||||
|
||||
assert result == expected_products
|
||||
|
||||
|
||||
# def test_affected_products_after(sql_session: Session) -> None:
|
||||
# def test_affected_products_until(sql_session: Session) -> None:
|
||||
# def test_affected_products_after_until(sql_session: Session) -> None:
|
||||
# def test_affected_products_after_inclusive(sql_session: Session) -> None:
|
||||
# def test_affected_products_until_inclusive(sql_session: Session) -> None:
|
||||
# def test_affected_products_after_until_inclusive(sql_session: Session) -> None:
|
||||
74
tests/queries/test_affected_users.py
Normal file
74
tests/queries/test_affected_users.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Product, Transaction, User
|
||||
from dibbler.queries import affected_users
|
||||
from tests.helpers import assert_id_order_similar_to_time_order, assign_times
|
||||
|
||||
|
||||
def insert_test_data(sql_session: Session) -> tuple[list[User], Product]:
|
||||
users = []
|
||||
for i in range(10):
|
||||
user = User(f"Test User {i + 1}")
|
||||
users.append(user)
|
||||
|
||||
product = Product("1234567890123", "Test Product")
|
||||
|
||||
sql_session.add_all(users)
|
||||
sql_session.add(product)
|
||||
sql_session.commit()
|
||||
|
||||
return users, product
|
||||
|
||||
|
||||
def test_affected_users_no_history(sql_session: Session) -> None:
|
||||
insert_test_data(sql_session)
|
||||
|
||||
result = affected_users(sql_session)
|
||||
|
||||
assert result == set()
|
||||
|
||||
|
||||
def test_affected_users_basic_history(sql_session: Session) -> None:
|
||||
users, product = insert_test_data(sql_session)
|
||||
|
||||
transactions = [
|
||||
Transaction.add_product(
|
||||
amount=10,
|
||||
per_product=10,
|
||||
user_id=users[i].id,
|
||||
product_id=product.id,
|
||||
product_count=1,
|
||||
)
|
||||
for i in range(5)
|
||||
] + [
|
||||
Transaction.buy_product(
|
||||
user_id=users[i].id,
|
||||
product_id=product.id,
|
||||
product_count=1,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
assign_times(transactions)
|
||||
|
||||
sql_session.add_all(transactions)
|
||||
sql_session.commit()
|
||||
|
||||
assert_id_order_similar_to_time_order(transactions)
|
||||
|
||||
result = affected_users(sql_session)
|
||||
|
||||
expected_users = {users[i] for i in range(5)}
|
||||
|
||||
assert result == expected_users
|
||||
|
||||
|
||||
# def test_affected_users_after(sql_session: Session) -> None:
|
||||
# def test_affected_users_until(sql_session: Session) -> None:
|
||||
# def test_affected_users_after_until(sql_session: Session) -> None:
|
||||
# def test_affected_users_after_inclusive(sql_session: Session) -> None:
|
||||
# def test_affected_users_until_inclusive(sql_session: Session) -> None:
|
||||
# def test_affected_users_after_until_inclusive(sql_session: Session) -> None:
|
||||
204
tests/queries/test_update_cache.py
Normal file
204
tests/queries/test_update_cache.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.models import Product, ProductCache, Transaction, User, UserCache
|
||||
from dibbler.models.LastCacheTransaction import LastCacheTransaction
|
||||
from dibbler.queries import update_cache
|
||||
from tests.helpers import assert_id_order_similar_to_time_order, assign_times
|
||||
|
||||
|
||||
def insert_test_data(sql_session: Session) -> tuple[User, User, Product, Product]:
|
||||
user1 = User("Test User")
|
||||
user2 = User("Another User")
|
||||
product1 = Product("1234567890123", "Test Product 1")
|
||||
product2 = Product("9876543210987", "Test Product 2")
|
||||
|
||||
sql_session.add_all([user1, user2, product1, product2])
|
||||
sql_session.commit()
|
||||
|
||||
return user1, user2, product1, product2
|
||||
|
||||
|
||||
def get_cache_entries(sql_session: Session) -> tuple[list[UserCache], list[ProductCache]]:
|
||||
user_cache = sql_session.scalars(
|
||||
select(UserCache)
|
||||
.join(LastCacheTransaction, UserCache.last_cache_transaction_id == LastCacheTransaction.id)
|
||||
.join(Transaction, LastCacheTransaction.transaction_id == Transaction.id)
|
||||
.order_by(Transaction.time.asc(), UserCache.user_id)
|
||||
).all()
|
||||
|
||||
product_cache = sql_session.scalars(
|
||||
select(ProductCache)
|
||||
.join(LastCacheTransaction, ProductCache.last_cache_transaction_id == LastCacheTransaction.id)
|
||||
.join(Transaction, LastCacheTransaction.transaction_id == Transaction.id)
|
||||
.order_by(Transaction.time.asc(), ProductCache.product_id)
|
||||
).all()
|
||||
|
||||
return list(user_cache), list(product_cache)
|
||||
|
||||
|
||||
def test_affected_update_cache_no_history(sql_session: Session) -> None:
|
||||
insert_test_data(sql_session)
|
||||
|
||||
update_cache(sql_session)
|
||||
|
||||
|
||||
def test_affected_update_cache_basic_history(sql_session: Session) -> None:
|
||||
user1, user2, product1, product2 = insert_test_data(sql_session)
|
||||
|
||||
transactions = [
|
||||
Transaction.add_product(
|
||||
amount=10,
|
||||
per_product=10,
|
||||
user_id=user1.id,
|
||||
product_id=product1.id,
|
||||
product_count=1,
|
||||
),
|
||||
Transaction.add_product(
|
||||
amount=20,
|
||||
per_product=10,
|
||||
user_id=user2.id,
|
||||
product_id=product2.id,
|
||||
product_count=2,
|
||||
),
|
||||
]
|
||||
|
||||
assign_times(transactions)
|
||||
|
||||
sql_session.add_all(transactions)
|
||||
sql_session.commit()
|
||||
|
||||
assert_id_order_similar_to_time_order(transactions)
|
||||
|
||||
update_cache(sql_session)
|
||||
|
||||
user_cache = sql_session.scalars(select(UserCache).order_by(UserCache.user_id)).all()
|
||||
product_cache = sql_session.scalars(
|
||||
select(ProductCache).order_by(ProductCache.product_id)
|
||||
).all()
|
||||
|
||||
assert user_cache[0].user_id == user1.id
|
||||
assert user_cache[0].balance == 10
|
||||
assert user_cache[1].user_id == user2.id
|
||||
assert user_cache[1].balance == 20
|
||||
|
||||
assert product_cache[0].product_id == product1.id
|
||||
assert product_cache[0].stock == 1
|
||||
assert product_cache[0].price == 10
|
||||
|
||||
assert product_cache[1].product_id == product2.id
|
||||
assert product_cache[1].stock == 2
|
||||
assert product_cache[1].price == 10
|
||||
|
||||
|
||||
def test_update_cache_multiple_times_no_changes(sql_session: Session) -> None:
|
||||
user1, user2, product1, product2 = insert_test_data(sql_session)
|
||||
|
||||
transactions = [
|
||||
Transaction.add_product(
|
||||
amount=10,
|
||||
per_product=10,
|
||||
user_id=user1.id,
|
||||
product_id=product1.id,
|
||||
product_count=1,
|
||||
),
|
||||
Transaction.add_product(
|
||||
amount=20,
|
||||
per_product=10,
|
||||
user_id=user2.id,
|
||||
product_id=product2.id,
|
||||
product_count=2,
|
||||
),
|
||||
]
|
||||
|
||||
assign_times(transactions)
|
||||
|
||||
sql_session.add_all(transactions)
|
||||
sql_session.commit()
|
||||
|
||||
assert_id_order_similar_to_time_order(transactions)
|
||||
|
||||
update_cache(sql_session)
|
||||
|
||||
update_cache(sql_session)
|
||||
|
||||
user_cache, product_cache = get_cache_entries(sql_session)
|
||||
|
||||
assert user_cache[0].user_id == user1.id
|
||||
assert user_cache[0].balance == 10
|
||||
assert user_cache[1].user_id == user2.id
|
||||
assert user_cache[1].balance == 20
|
||||
|
||||
|
||||
def test_update_cache_multiple_times(sql_session: Session) -> None:
|
||||
user1, user2, product1, product2 = insert_test_data(sql_session)
|
||||
|
||||
transactions = [
|
||||
Transaction.add_product(
|
||||
amount=10,
|
||||
per_product=10,
|
||||
user_id=user1.id,
|
||||
product_id=product1.id,
|
||||
product_count=1,
|
||||
),
|
||||
Transaction.add_product(
|
||||
amount=20,
|
||||
per_product=10,
|
||||
user_id=user2.id,
|
||||
product_id=product2.id,
|
||||
product_count=2,
|
||||
),
|
||||
]
|
||||
|
||||
assign_times(transactions)
|
||||
|
||||
sql_session.add_all(transactions)
|
||||
sql_session.commit()
|
||||
|
||||
assert_id_order_similar_to_time_order(transactions)
|
||||
|
||||
update_cache(sql_session)
|
||||
|
||||
transactions_more = [
|
||||
Transaction.add_product(
|
||||
amount=30,
|
||||
per_product=10,
|
||||
user_id=user1.id,
|
||||
product_id=product1.id,
|
||||
product_count=3,
|
||||
),
|
||||
Transaction.buy_product(
|
||||
user_id=user1.id,
|
||||
product_id=product1.id,
|
||||
product_count=1,
|
||||
),
|
||||
]
|
||||
|
||||
assign_times(transactions_more, start_time=transactions[-1].time)
|
||||
|
||||
sql_session.add_all(transactions_more)
|
||||
sql_session.commit()
|
||||
|
||||
assert_id_order_similar_to_time_order(transactions_more)
|
||||
|
||||
update_cache(sql_session)
|
||||
|
||||
user_cache, product_cache = get_cache_entries(sql_session)
|
||||
|
||||
assert user_cache[0].user_id == user1.id
|
||||
assert user_cache[0].balance == 10
|
||||
assert user_cache[1].user_id == user2.id
|
||||
assert user_cache[1].balance == 20
|
||||
assert product_cache[0].product_id == product1.id
|
||||
assert product_cache[0].stock == 1
|
||||
assert product_cache[0].price == 10
|
||||
assert product_cache[1].product_id == product2.id
|
||||
assert product_cache[1].stock == 2
|
||||
assert product_cache[1].price == 10
|
||||
|
||||
assert user_cache[2].user_id == user1.id
|
||||
assert user_cache[2].balance == 10 + 30 - 10
|
||||
assert product_cache[2].product_id == product1.id
|
||||
assert product_cache[2].stock == 1 + 3 - 1
|
||||
assert product_cache[2].price == 10
|
||||
Reference in New Issue
Block a user