WIP: caching
All checks were successful
Run tests / run-tests (push) Successful in 3m6s
Run benchmarks / run-tests (push) Successful in 35m23s

This commit is contained in:
2026-01-08 13:52:52 +09:00
parent 159d389268
commit 5d6d47c664
14 changed files with 769 additions and 14 deletions

View File

@@ -27,6 +27,9 @@ class Base(DeclarativeBase):
@declared_attr.directive
def __tablename__(cls) -> str:
    """Derive the SQL table name for a mapped class.

    An explicit ``__table_name__`` attribute on the class wins; otherwise
    the class name is converted from PascalCase to snake_case.
    """
    if hasattr(cls, "__table_name__"):
        # Guard against a non-string override before using it as the table name.
        assert isinstance(cls.__table_name__, str)
        return cls.__table_name__
    return _pascal_case_to_snake_case(cls.__name__)
# NOTE: This is the default implementation of __repr__ for all tables,

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import Integer, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from dibbler.models import Base
if TYPE_CHECKING:
from dibbler.models import Transaction
class LastCacheTransaction(Base):
    """Tracks the last transaction that affected various caches."""

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    """Internal database ID"""

    transaction_id: Mapped[int | None] = mapped_column(ForeignKey("trx.id"), index=True)
    """The ID of the last transaction that affected the cache(s)."""

    # Eagerly loaded ("joined") so cache-maintenance code can read the
    # transaction without an extra round trip.
    transaction: Mapped[Transaction | None] = relationship(
        lazy="joined",
        foreign_keys=[transaction_id],
    )
    """The last transaction that affected the cache(s)."""

View File

@@ -1,16 +1,30 @@
from datetime import datetime
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import Integer, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import Integer, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from dibbler.models import Base
if TYPE_CHECKING:
from dibbler.models import LastCacheTransaction, Product
class ProductCache(Base):
    """Cached price and stock values for a product."""

    # NOTE(review): this span is a rendered diff that appears to mix removed
    # and added lines (two primary-key declarations and two product_id
    # declarations below) — confirm the final schema against the repository.
    product_id: Mapped[int] = mapped_column(Integer, primary_key=True)
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    """Internal database ID"""

    product_id: Mapped[int] = mapped_column(ForeignKey('product.id'))
    product: Mapped[Product] = relationship(
        lazy="joined",
        foreign_keys=[product_id],
    )

    price: Mapped[int] = mapped_column(Integer)
    # NOTE(review): the *_timestamp lines look like they belong to the removed
    # version of this file (DateTime is dropped from the new imports) — verify.
    price_timestamp: Mapped[datetime] = mapped_column(DateTime)
    stock: Mapped[int] = mapped_column(Integer)
    stock_timestamp: Mapped[datetime] = mapped_column(DateTime)

    last_cache_transaction_id: Mapped[int | None] = mapped_column(ForeignKey("last_cache_transaction.id"), nullable=True)
    last_cache_transaction: Mapped[LastCacheTransaction | None] = relationship(
        lazy="joined",
        foreign_keys=[last_cache_transaction_id],
    )

View File

@@ -87,6 +87,7 @@ def _transaction_type_field_constraints(
class Transaction(Base):
__tablename__ = "trx"
__table_args__ = (
*[
_transaction_type_field_constraints(transaction_type, expected_fields)
@@ -210,7 +211,7 @@ class Transaction(Base):
"""
joint_transaction_id: Mapped[int | None] = mapped_column(
ForeignKey("transaction.id"),
ForeignKey("trx.id"),
index=True,
)
"""

View File

@@ -1,14 +1,30 @@
from datetime import datetime
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import Integer, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import Integer, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from dibbler.models import Base
if TYPE_CHECKING:
from dibbler.models import LastCacheTransaction, User
# Cache of per-user balance values.
class UserBalanceCache(Base):
user_id: Mapped[int] = mapped_column(Integer, primary_key=True)
class UserCache(Base):
    """Cached balance value for a user."""

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    """Internal database ID"""

    user_id: Mapped[int] = mapped_column(ForeignKey('user.id'))
    user: Mapped[User] = relationship(
        lazy="joined",
        foreign_keys=[user_id],
    )

    balance: Mapped[int] = mapped_column(Integer)
    # NOTE(review): rendered diff — the timestamp line below appears to belong
    # to the removed version of this file (DateTime is dropped from the new
    # imports); confirm against the repository.
    timestamp: Mapped[datetime] = mapped_column(DateTime)

    last_cache_transaction_id: Mapped[int | None] = mapped_column(ForeignKey("last_cache_transaction.id"), nullable=True)
    last_cache_transaction: Mapped[LastCacheTransaction | None] = relationship(
        lazy="joined",
        foreign_keys=[last_cache_transaction_id],
    )

View File

@@ -1,13 +1,19 @@
__all__ = [
"Base",
"LastCacheTransaction",
"Product",
"ProductCache",
"Transaction",
"TransactionType",
"User",
"UserCache",
]
from .Base import Base
from .LastCacheTransaction import LastCacheTransaction
from .Product import Product
from .ProductCache import ProductCache
from .Transaction import Transaction
from .TransactionType import TransactionType
from .User import User
from .UserCache import UserCache

View File

@@ -4,6 +4,8 @@ __all__ = [
"adjust_interest",
"adjust_penalty",
"adjust_stock",
"affected_products",
"affected_users",
"create_product",
"create_user",
"current_interest",
@@ -19,6 +21,7 @@ __all__ = [
"throw_product",
"transaction_log",
"transfer",
"update_cache",
"user_balance",
"user_balance_log",
"user_products",
@@ -29,6 +32,8 @@ from .adjust_balance import adjust_balance
from .adjust_interest import adjust_interest
from .adjust_penalty import adjust_penalty
from .adjust_stock import adjust_stock
from .affected_products import affected_products
from .affected_users import affected_users
from .create_product import create_product
from .create_user import create_user
from .current_interest import current_interest
@@ -42,5 +47,6 @@ from .search_user import search_user
from .throw_product import throw_product
from .transaction_log import transaction_log
from .transfer import transfer
from .update_cache import update_cache
from .user_balance import user_balance, user_balance_log
from .user_products import user_products

View File

@@ -0,0 +1,88 @@
from datetime import datetime
from sqlalchemy import BindParameter, select
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, TransactionType
from dibbler.queries.query_helpers import until_filter, after_filter
def affected_products(
    sql_session: Session,
    until_time: BindParameter[datetime] | datetime | None = None,
    until_transaction: BindParameter[Transaction] | Transaction | None = None,
    until_inclusive: bool = True,
    after_time: BindParameter[datetime] | datetime | None = None,
    after_transaction: Transaction | None = None,
    after_inclusive: bool = True,
) -> set[Product]:
    """
    Get all products where attributes were affected over a given interval.

    A product counts as affected when a transaction of a stock- or
    price-relevant type (add/adjust-stock/buy/joint/throw) references it
    within the interval bounded by the ``after_*`` / ``until_*`` arguments.

    :param sql_session: Active database session.
    :param until_time: Upper bound on transaction time (mutually exclusive
        with ``until_transaction``).
    :param until_transaction: Transaction serving as the upper bound.
    :param until_inclusive: Whether the upper bound itself is included.
    :param after_time: Lower bound on transaction time (mutually exclusive
        with ``after_transaction``).
    :param after_transaction: Transaction serving as the lower bound.
    :param after_inclusive: Whether the lower bound itself is included.
    :raises ValueError: if both bounds of the same end are given, if a
        transaction argument is not persisted, or if the interval is inverted.
    """
    # Reject conflicting bounds up front, symmetrically for both ends.
    if not (until_time is None or until_transaction is None):
        raise ValueError("Cannot filter by both until_time and until_transaction.")
    if not (after_time is None or after_transaction is None):
        raise ValueError("Cannot filter by both after_time and after_transaction.")

    if isinstance(until_time, datetime):
        until_time = BindParameter("until_time", value=until_time)
    if isinstance(until_transaction, Transaction):
        if until_transaction.id is None:
            raise ValueError("until_transaction must be persisted in the database.")
        until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
    else:
        until_transaction_id = None

    if isinstance(after_time, datetime):
        after_time = BindParameter("after_time", value=after_time)
    if isinstance(after_transaction, Transaction):
        if after_transaction.id is None:
            raise ValueError("after_transaction must be persisted in the database.")
        after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
    else:
        after_transaction_id = None

    # Reject inverted intervals before hitting the database.
    if after_time is not None and until_time is not None:
        assert isinstance(after_time.value, datetime)
        assert isinstance(until_time.value, datetime)
        if after_time.value > until_time.value:
            raise ValueError("after_time cannot be after until_time.")
    if after_transaction is not None and until_transaction is not None:
        # NOTE(review): until_transaction may also be a BindParameter per the
        # signature; this cross-check only supports ORM instances — confirm.
        assert after_transaction.time is not None
        assert until_transaction.time is not None
        if after_transaction.time > until_transaction.time:
            raise ValueError("after_transaction cannot be after until_transaction.")

    # No ORDER BY: the result is collected into a set, so ordering is
    # meaningless — and DISTINCT + ORDER BY on a column that is not in the
    # select list is rejected by some backends (e.g. PostgreSQL).
    result = sql_session.scalars(
        select(Product)
        .distinct()
        .join(Transaction, Product.id == Transaction.product_id)
        .where(
            Transaction.type_.in_(
                [
                    TransactionType.ADD_PRODUCT.as_literal_column(),
                    TransactionType.ADJUST_STOCK.as_literal_column(),
                    TransactionType.BUY_PRODUCT.as_literal_column(),
                    TransactionType.JOINT.as_literal_column(),
                    TransactionType.THROW_PRODUCT.as_literal_column(),
                ]
            ),
            until_filter(
                until_time=until_time,
                until_transaction_id=until_transaction_id,
                until_inclusive=until_inclusive,
            ),
            after_filter(
                after_time=after_time,
                after_transaction_id=after_transaction_id,
                after_inclusive=after_inclusive,
            ),
        )
    ).all()
    return set(result)

View File

@@ -0,0 +1,86 @@
from datetime import datetime
from sqlalchemy import BindParameter, select, or_
from sqlalchemy.orm import Session
from dibbler.models import Transaction, TransactionType, User
from dibbler.queries.query_helpers import until_filter, after_filter
def affected_users(
    sql_session: Session,
    until_time: BindParameter[datetime] | datetime | None = None,
    until_transaction: BindParameter[Transaction] | Transaction | None = None,
    until_inclusive: bool = True,
    after_time: BindParameter[datetime] | datetime | None = None,
    after_transaction: Transaction | None = None,
    after_inclusive: bool = True,
) -> set[User]:
    """
    Get all users where attributes were affected over a given interval.

    A user counts as affected when a balance-relevant transaction
    (add-product/adjust-balance/buy/transfer) references them — either as
    the acting user or as the transfer counterparty — within the interval
    bounded by the ``after_*`` / ``until_*`` arguments.

    :param sql_session: Active database session.
    :param until_time: Upper bound on transaction time (mutually exclusive
        with ``until_transaction``).
    :param until_transaction: Transaction serving as the upper bound.
    :param until_inclusive: Whether the upper bound itself is included.
    :param after_time: Lower bound on transaction time (mutually exclusive
        with ``after_transaction``).
    :param after_transaction: Transaction serving as the lower bound.
    :param after_inclusive: Whether the lower bound itself is included.
    :raises ValueError: if both bounds of the same end are given, if a
        transaction argument is not persisted, or if the interval is inverted.
    """
    # Reject conflicting bounds up front, symmetrically for both ends.
    if not (until_time is None or until_transaction is None):
        raise ValueError("Cannot filter by both until_time and until_transaction.")
    if not (after_time is None or after_transaction is None):
        raise ValueError("Cannot filter by both after_time and after_transaction.")

    if isinstance(until_time, datetime):
        until_time = BindParameter("until_time", value=until_time)
    if isinstance(until_transaction, Transaction):
        if until_transaction.id is None:
            raise ValueError("until_transaction must be persisted in the database.")
        until_transaction_id = BindParameter("until_transaction_id", value=until_transaction.id)
    else:
        until_transaction_id = None

    if isinstance(after_time, datetime):
        after_time = BindParameter("after_time", value=after_time)
    if isinstance(after_transaction, Transaction):
        if after_transaction.id is None:
            raise ValueError("after_transaction must be persisted in the database.")
        after_transaction_id = BindParameter("after_transaction_id", value=after_transaction.id)
    else:
        after_transaction_id = None

    # Reject inverted intervals before hitting the database.
    if after_time is not None and until_time is not None:
        assert isinstance(after_time.value, datetime)
        assert isinstance(until_time.value, datetime)
        if after_time.value > until_time.value:
            raise ValueError("after_time cannot be after until_time.")
    if after_transaction is not None and until_transaction is not None:
        # NOTE(review): until_transaction may also be a BindParameter per the
        # signature; this cross-check only supports ORM instances — confirm.
        assert after_transaction.time is not None
        assert until_transaction.time is not None
        if after_transaction.time > until_transaction.time:
            raise ValueError("after_transaction cannot be after until_transaction.")

    # No ORDER BY: the result is collected into a set, so ordering is
    # meaningless — and DISTINCT + ORDER BY on a column that is not in the
    # select list is rejected by some backends (e.g. PostgreSQL).
    result = sql_session.scalars(
        select(User)
        .distinct()
        .join(Transaction, or_(User.id == Transaction.user_id, User.id == Transaction.transfer_user_id))
        .where(
            Transaction.type_.in_(
                [
                    TransactionType.ADD_PRODUCT.as_literal_column(),
                    TransactionType.ADJUST_BALANCE.as_literal_column(),
                    TransactionType.BUY_PRODUCT.as_literal_column(),
                    TransactionType.TRANSFER.as_literal_column(),
                ]
            ),
            until_filter(
                until_time=until_time,
                until_transaction_id=until_transaction_id,
                until_inclusive=until_inclusive,
            ),
            after_filter(
                after_time=after_time,
                after_transaction_id=after_transaction_id,
                after_inclusive=after_inclusive,
            ),
        )
    ).all()
    return set(result)

View File

@@ -78,3 +78,42 @@ def until_filter(
)
return CONST_TRUE
def after_filter(
    after_time: BindParameter[datetime] | None = None,
    after_transaction_id: BindParameter[int] | None = None,
    after_inclusive: bool = True,
    transaction_time: QueryableAttribute = Transaction.time,
) -> ColumnExpressionArgument[bool]:
    """
    Build a filter expression selecting transactions at or after a lower bound.

    The lower bound is either a concrete time (`after_time`) or the time of a
    referenced transaction (`after_transaction_id`); at most one may be given.
    With neither bound set, a constant-true expression is returned.
    """
    assert not (after_time is not None and after_transaction_id is not None), (
        "Cannot filter by both after_time and after_transaction_id."
    )
    if isinstance(after_time, BindParameter) and after_transaction_id is None:
        if after_inclusive:
            return transaction_time >= after_time
        return transaction_time > after_time
    if isinstance(after_transaction_id, BindParameter) and after_time is None:
        # Resolve the referenced transaction's time via a scalar subquery.
        bound = (
            select(Transaction.time)
            .where(Transaction.id == after_transaction_id)
            .scalar_subquery()
        )
        if after_inclusive:
            return transaction_time >= bound
        return transaction_time > bound
    return CONST_TRUE

View File

@@ -0,0 +1,118 @@
from sqlalchemy import insert, select
from sqlalchemy.orm import Session
from dibbler.models import LastCacheTransaction, ProductCache, Transaction, UserCache
from dibbler.queries.affected_products import affected_products
from dibbler.queries.affected_users import affected_users
from dibbler.queries.product_price import product_price
from dibbler.queries.product_stock import product_stock
from dibbler.queries.user_balance import user_balance
def update_cache(
    sql_session: Session,
    use_cache: bool = True,
) -> None:
    """
    Update the cached per-user balances and per-product stock/price values.

    Recomputes cache rows for every user and product affected by transactions
    newer than the previous cache checkpoint, then records a new
    ``LastCacheTransaction`` checkpoint pointing at the newest transaction.
    If `use_cache` is False, the cache will be rebuilt from scratch.

    Commits the session on success; returns without writing anything when
    there are no transactions or the cache is already up to date.
    """
    last_transaction = sql_session.scalars(
        select(Transaction).order_by(Transaction.time.desc()).limit(1)
    ).one_or_none()
    if last_transaction is None:
        # No transactions exist, nothing to update
        return

    # Find the transaction the cache was last brought up to date with.
    last_cache_transaction = None
    if use_cache:
        last_checkpoint = sql_session.scalars(
            select(LastCacheTransaction)
            .join(Transaction, LastCacheTransaction.transaction_id == Transaction.id)
            .order_by(Transaction.time.desc())
            .limit(1)
        ).one_or_none()
        if last_checkpoint is not None:
            last_cache_transaction = last_checkpoint.transaction

    if last_cache_transaction is not None and last_cache_transaction.id == last_transaction.id:
        # Cache is already up to date
        return

    # Only entities touched since the last checkpoint need recomputation.
    users = affected_users(
        sql_session,
        after_transaction=last_cache_transaction,
        after_inclusive=False,
        until_transaction=last_transaction,
    )
    products = affected_products(
        sql_session,
        after_transaction=last_cache_transaction,
        after_inclusive=False,
        until_transaction=last_transaction,
    )

    # Recompute each affected value as of last_transaction.
    user_balances = {
        user.id: user_balance(
            sql_session,
            user,
            use_cache=use_cache,
            until_transaction=last_transaction,
        )
        for user in users
    }
    product_stocks = {}
    product_prices = {}
    for product in products:
        product_stocks[product.id] = product_stock(
            sql_session,
            product,
            use_cache=use_cache,
            until_transaction=last_transaction,
        )
        product_prices[product.id] = product_price(
            sql_session,
            product,
            use_cache=use_cache,
            until_transaction=last_transaction,
        )

    # Record the new checkpoint; flush so its ID is available for the
    # cache rows inserted below.
    next_cache_transaction = LastCacheTransaction(transaction_id=last_transaction.id)
    sql_session.add(next_cache_transaction)
    sql_session.flush()

    if users:
        sql_session.execute(
            insert(UserCache),
            [
                {
                    "user_id": user.id,
                    "balance": user_balances[user.id],
                    "last_cache_transaction_id": next_cache_transaction.id,
                }
                for user in users
            ],
        )
    if products:
        sql_session.execute(
            insert(ProductCache),
            [
                {
                    "product_id": product.id,
                    "stock": product_stocks[product.id],
                    "price": product_prices[product.id],
                    "last_cache_transaction_id": next_cache_transaction.id,
                }
                for product in products
            ],
        )
    sql_session.commit()

View File

@@ -0,0 +1,74 @@
from datetime import datetime, timedelta
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
from dibbler.queries import affected_products
from tests.helpers import assert_id_order_similar_to_time_order, assign_times
def insert_test_data(sql_session: Session) -> tuple[User, list[Product]]:
    """Create and commit one user and ten products; return them."""
    user = User("Test User")
    products = [
        Product(f"12345678901{index:02d}", f"Test Product {index}")
        for index in range(10)
    ]
    sql_session.add(user)
    sql_session.add_all(products)
    sql_session.commit()
    return user, products
def test_affected_products_no_history(sql_session: Session) -> None:
    """With no transactions at all, no product is reported as affected."""
    insert_test_data(sql_session)
    assert affected_products(sql_session) == set()
def test_affected_products_basic_history(sql_session: Session) -> None:
    """Products referenced by add/buy transactions are reported as affected."""
    user, products = insert_test_data(sql_session)
    # Stock the first five products, then buy from the first three.
    added = [
        Transaction.add_product(
            amount=10,
            per_product=10,
            user_id=user.id,
            product_id=products[index].id,
            product_count=1,
        )
        for index in range(5)
    ]
    bought = [
        Transaction.buy_product(
            user_id=user.id,
            product_id=products[index].id,
            product_count=1,
        )
        for index in range(3)
    ]
    transactions = added + bought
    assign_times(transactions)
    sql_session.add_all(transactions)
    sql_session.commit()
    assert_id_order_similar_to_time_order(transactions)

    # Only the five stocked products were ever touched by a transaction.
    assert affected_products(sql_session) == set(products[:5])
# def test_affected_products_after(sql_session: Session) -> None:
# def test_affected_products_until(sql_session: Session) -> None:
# def test_affected_products_after_until(sql_session: Session) -> None:
# def test_affected_products_after_inclusive(sql_session: Session) -> None:
# def test_affected_products_until_inclusive(sql_session: Session) -> None:
# def test_affected_products_after_until_inclusive(sql_session: Session) -> None:

View File

@@ -0,0 +1,74 @@
from datetime import datetime, timedelta
import pytest
from sqlalchemy.orm import Session
from dibbler.models import Product, Transaction, User
from dibbler.queries import affected_users
from tests.helpers import assert_id_order_similar_to_time_order, assign_times
def insert_test_data(sql_session: Session) -> tuple[list[User], Product]:
    """Create and commit ten users and one product; return them."""
    users = [User(f"Test User {index + 1}") for index in range(10)]
    product = Product("1234567890123", "Test Product")
    sql_session.add_all(users)
    sql_session.add(product)
    sql_session.commit()
    return users, product
def test_affected_users_no_history(sql_session: Session) -> None:
    """With no transactions at all, no user is reported as affected."""
    insert_test_data(sql_session)
    assert affected_users(sql_session) == set()
def test_affected_users_basic_history(sql_session: Session) -> None:
    """Users referenced by add/buy transactions are reported as affected."""
    users, product = insert_test_data(sql_session)
    # The first five users stock the product; the first three also buy it.
    added = [
        Transaction.add_product(
            amount=10,
            per_product=10,
            user_id=users[index].id,
            product_id=product.id,
            product_count=1,
        )
        for index in range(5)
    ]
    bought = [
        Transaction.buy_product(
            user_id=users[index].id,
            product_id=product.id,
            product_count=1,
        )
        for index in range(3)
    ]
    transactions = added + bought
    assign_times(transactions)
    sql_session.add_all(transactions)
    sql_session.commit()
    assert_id_order_similar_to_time_order(transactions)

    # Only the five stocking users were ever touched by a transaction.
    assert affected_users(sql_session) == set(users[:5])
# def test_affected_users_after(sql_session: Session) -> None:
# def test_affected_users_until(sql_session: Session) -> None:
# def test_affected_users_after_until(sql_session: Session) -> None:
# def test_affected_users_after_inclusive(sql_session: Session) -> None:
# def test_affected_users_until_inclusive(sql_session: Session) -> None:
# def test_affected_users_after_until_inclusive(sql_session: Session) -> None:

View File

@@ -0,0 +1,204 @@
import pytest
from sqlalchemy import select
from sqlalchemy.orm import Session
from dibbler.models import Product, ProductCache, Transaction, User, UserCache
from dibbler.models.LastCacheTransaction import LastCacheTransaction
from dibbler.queries import update_cache
from tests.helpers import assert_id_order_similar_to_time_order, assign_times
def insert_test_data(sql_session: Session) -> tuple[User, User, Product, Product]:
    """Create and commit two users and two products; return them."""
    fixtures = (
        User("Test User"),
        User("Another User"),
        Product("1234567890123", "Test Product 1"),
        Product("9876543210987", "Test Product 2"),
    )
    sql_session.add_all(fixtures)
    sql_session.commit()
    return fixtures
def get_cache_entries(sql_session: Session) -> tuple[list[UserCache], list[ProductCache]]:
    """Return all user and product cache rows.

    Rows are joined to their checkpoint's transaction and ordered by that
    transaction's time, then by entity ID — so rows from earlier cache
    updates come first.
    """
    user_cache = sql_session.scalars(
        select(UserCache)
        .join(LastCacheTransaction, UserCache.last_cache_transaction_id == LastCacheTransaction.id)
        .join(Transaction, LastCacheTransaction.transaction_id == Transaction.id)
        .order_by(Transaction.time.asc(), UserCache.user_id)
    ).all()
    product_cache = sql_session.scalars(
        select(ProductCache)
        .join(LastCacheTransaction, ProductCache.last_cache_transaction_id == LastCacheTransaction.id)
        .join(Transaction, LastCacheTransaction.transaction_id == Transaction.id)
        .order_by(Transaction.time.asc(), ProductCache.product_id)
    ).all()
    return list(user_cache), list(product_cache)
def test_affected_update_cache_no_history(sql_session: Session) -> None:
    """update_cache with no transactions must simply return without error."""
    insert_test_data(sql_session)
    update_cache(sql_session)
def test_affected_update_cache_basic_history(sql_session: Session) -> None:
    """A single cache update produces correct balance/stock/price rows."""
    user1, user2, product1, product2 = insert_test_data(sql_session)
    transactions = [
        Transaction.add_product(
            amount=10,
            per_product=10,
            user_id=user1.id,
            product_id=product1.id,
            product_count=1,
        ),
        Transaction.add_product(
            amount=20,
            per_product=10,
            user_id=user2.id,
            product_id=product2.id,
            product_count=2,
        ),
    ]
    assign_times(transactions)
    sql_session.add_all(transactions)
    sql_session.commit()
    assert_id_order_similar_to_time_order(transactions)

    update_cache(sql_session)

    user_rows = sql_session.scalars(select(UserCache).order_by(UserCache.user_id)).all()
    product_rows = sql_session.scalars(
        select(ProductCache).order_by(ProductCache.product_id)
    ).all()
    assert (user_rows[0].user_id, user_rows[0].balance) == (user1.id, 10)
    assert (user_rows[1].user_id, user_rows[1].balance) == (user2.id, 20)
    assert (product_rows[0].product_id, product_rows[0].stock, product_rows[0].price) == (
        product1.id,
        1,
        10,
    )
    assert (product_rows[1].product_id, product_rows[1].stock, product_rows[1].price) == (
        product2.id,
        2,
        10,
    )
def test_update_cache_multiple_times_no_changes(sql_session: Session) -> None:
    """Calling update_cache twice with no new transactions must be a no-op.

    The second call must not create duplicate cache rows — only one row per
    affected user/product may exist afterwards.
    """
    user1, user2, product1, product2 = insert_test_data(sql_session)
    transactions = [
        Transaction.add_product(
            amount=10,
            per_product=10,
            user_id=user1.id,
            product_id=product1.id,
            product_count=1,
        ),
        Transaction.add_product(
            amount=20,
            per_product=10,
            user_id=user2.id,
            product_id=product2.id,
            product_count=2,
        ),
    ]
    assign_times(transactions)
    sql_session.add_all(transactions)
    sql_session.commit()
    assert_id_order_similar_to_time_order(transactions)

    update_cache(sql_session)
    update_cache(sql_session)

    user_cache, product_cache = get_cache_entries(sql_session)
    # The second call must not have appended a new checkpoint's rows.
    assert len(user_cache) == 2
    assert len(product_cache) == 2
    assert user_cache[0].user_id == user1.id
    assert user_cache[0].balance == 10
    assert user_cache[1].user_id == user2.id
    assert user_cache[1].balance == 20
def test_update_cache_multiple_times(sql_session: Session) -> None:
    """Incremental cache updates append new rows only for changed entities."""
    user1, user2, product1, product2 = insert_test_data(sql_session)
    # First batch: one add_product per user/product pair.
    transactions = [
        Transaction.add_product(
            amount=10,
            per_product=10,
            user_id=user1.id,
            product_id=product1.id,
            product_count=1,
        ),
        Transaction.add_product(
            amount=20,
            per_product=10,
            user_id=user2.id,
            product_id=product2.id,
            product_count=2,
        ),
    ]
    assign_times(transactions)
    sql_session.add_all(transactions)
    sql_session.commit()
    assert_id_order_similar_to_time_order(transactions)
    update_cache(sql_session)
    # Second batch: more activity on user1/product1 only, timed after the
    # first batch so it lands in the next cache interval.
    transactions_more = [
        Transaction.add_product(
            amount=30,
            per_product=10,
            user_id=user1.id,
            product_id=product1.id,
            product_count=3,
        ),
        Transaction.buy_product(
            user_id=user1.id,
            product_id=product1.id,
            product_count=1,
        ),
    ]
    assign_times(transactions_more, start_time=transactions[-1].time)
    sql_session.add_all(transactions_more)
    sql_session.commit()
    assert_id_order_similar_to_time_order(transactions_more)
    update_cache(sql_session)
    # get_cache_entries orders by checkpoint time then entity ID: rows 0-1
    # belong to the first checkpoint, row 2 to the second.
    user_cache, product_cache = get_cache_entries(sql_session)
    assert user_cache[0].user_id == user1.id
    assert user_cache[0].balance == 10
    assert user_cache[1].user_id == user2.id
    assert user_cache[1].balance == 20
    assert product_cache[0].product_id == product1.id
    assert product_cache[0].stock == 1
    assert product_cache[0].price == 10
    assert product_cache[1].product_id == product2.id
    assert product_cache[1].stock == 2
    assert product_cache[1].price == 10
    # Second checkpoint reflects the extra add (+30 balance, +3 stock) and
    # the purchase (-10 balance, -1 stock) for user1/product1.
    assert user_cache[2].user_id == user1.id
    assert user_cache[2].balance == 10 + 30 - 10
    assert product_cache[2].product_id == product1.id
    assert product_cache[2].stock == 1 + 3 - 1
    assert product_cache[2].price == 10