Split `db.py` and make declarative models

This commit is contained in:
Oystein Kristoffer Tveit 2023-08-29 22:50:53 +02:00
parent c25e5cec27
commit cde79ccb34
Signed by: oysteikt
GPG Key ID: 9F2F7D8250F35146
30 changed files with 415 additions and 224 deletions

2
.gitignore vendored
View File

@ -2,3 +2,5 @@ result
result-* result-*
dist dist
test.db

View File

@ -1,24 +0,0 @@
import argparse
from dibbler.conf import config
parser = argparse.ArgumentParser()
parser.add_argument(
"-c",
"--config",
help="Path to the config file",
type=str,
required=False,
)
def main():
args = parser.parse_args()
config.read(args.config)
import dibbler.text_based as text_based
text_based.main()
if __name__ == "__main__":
main()

7
dibbler/db.py Normal file
View File

@ -0,0 +1,7 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from dibbler.conf import config
engine = create_engine(config.get('database', 'url'))
Session = sessionmaker(bind=engine)

View File

@ -5,7 +5,7 @@ import signal
from sqlalchemy import or_, and_ from sqlalchemy import or_, and_
from .models.db import * from .models import User, Product
def search_user(string, session, ignorethisflag=None): def search_user(string, session, ignorethisflag=None):
string = string.lower() string = string.lower()

51
dibbler/main.py Normal file
View File

@ -0,0 +1,51 @@
import argparse
from dibbler.conf import config
parser = argparse.ArgumentParser()
parser.add_argument(
"-c",
"--config",
help="Path to the config file",
type=str,
required=False,
)
subparsers = parser.add_subparsers(
title='subcommands',
dest='subcommand',
required=True,
)
subparsers.add_parser(
'loop',
help='Run the dibbler loop'
)
subparsers.add_parser(
'create-db',
help='Create the database'
)
subparsers.add_parser(
'slabbedasker',
help='Find out who is slabbedasker'
)
def main():
args = parser.parse_args()
config.read(args.config)
if args.subcommand == 'loop':
import dibbler.text_based as text_based
text_based.main()
elif args.subcommand == 'create-db':
import dibbler.scripts.makedb as makedb
makedb.main()
elif args.subcommand == 'slabbedasker':
import dibbler.scripts.slabbedasker as slabbedasker
slabbedasker.main()
if __name__ == "__main__":
main()

View File

@ -1,4 +0,0 @@
#!/usr/bin/python
from .models.db import db
db.Base.metadata.create_all(db.engine)

View File

@ -2,7 +2,7 @@ from math import ceil
import sqlalchemy import sqlalchemy
from dibbler.models.db import ( from dibbler.models import (
Product, Product,
Purchase, Purchase,
PurchaseEntry, PurchaseEntry,

View File

@ -1,7 +1,7 @@
import sqlalchemy import sqlalchemy
from dibbler.conf import config from dibbler.conf import config
from dibbler.models.db import ( from dibbler.models import (
Product, Product,
Purchase, Purchase,
PurchaseEntry, PurchaseEntry,

View File

@ -1,6 +1,6 @@
import sqlalchemy import sqlalchemy
from dibbler.models.db import User, Product from dibbler.models import User, Product
from .helpermenus import Menu, Selector from .helpermenus import Menu, Selector
__all__ = ["AddUserMenu", "AddProductMenu", "EditProductMenu", "AdjustStockMenu", "CleanupStockMenu", "EditUserMenu"] __all__ = ["AddUserMenu", "AddProductMenu", "EditProductMenu", "AdjustStockMenu", "CleanupStockMenu", "EditUserMenu"]

View File

@ -5,7 +5,8 @@ import re
import sys import sys
from select import select from select import select
from dibbler.models.db import User, Session from dibbler.db import Session
from dibbler.models import User
from dibbler.helpers import ( from dibbler.helpers import (
search_user, search_user,
search_product, search_product,

View File

@ -4,7 +4,7 @@ import os
import random import random
import sys import sys
from dibbler.models.db import Session from dibbler.db import Session
from . import faq_commands, restart_commands from . import faq_commands, restart_commands
from .buymenu import BuyMenu from .buymenu import BuyMenu

View File

@ -1,7 +1,7 @@
import sqlalchemy import sqlalchemy
from dibbler.conf import config from dibbler.conf import config
from dibbler.models.db import Transaction, Product, User from dibbler.models import Transaction, Product, User
from dibbler.helpers import less from dibbler.helpers import less
from .helpermenus import Menu, Selector from .helpermenus import Menu, Selector

View File

@ -1,7 +1,7 @@
import re import re
from dibbler.conf import config from dibbler.conf import config
from dibbler.models.db import Product, User from dibbler.models import Product, User
from dibbler.printer_helpers import print_bar_code, print_name_label from dibbler.printer_helpers import print_bar_code, print_name_label
from .helpermenus import Menu from .helpermenus import Menu

View File

@ -1,7 +1,7 @@
from sqlalchemy import desc, func from sqlalchemy import desc, func
from dibbler.helpers import less from dibbler.helpers import less
from dibbler.models.db import PurchaseEntry, Product, User from dibbler.models import PurchaseEntry, Product, User
from dibbler.statistikkHelpers import statisticsTextOnly from dibbler.statistikkHelpers import statisticsTextOnly
from .helpermenus import Menu from .helpermenus import Menu

40
dibbler/models/Base.py Normal file
View File

@ -0,0 +1,40 @@
from sqlalchemy import MetaData
from sqlalchemy.orm import (
DeclarativeBase,
declared_attr,
)
from sqlalchemy.orm.collections import (
InstrumentedDict,
InstrumentedList,
InstrumentedSet,
)
class Base(DeclarativeBase):
metadata = MetaData(
naming_convention={
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_`%(constraint_name)s`",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s"
}
)
@declared_attr.directive
def __tablename__(cls) -> str:
return cls.__name__
def __repr__(self) -> str:
columns = ", ".join(
f"{k}={repr(v)}" for k, v in self.__dict__.items() if not any([
k.startswith("_"),
# Ensure that we don't try to print out the entire list of
# relationships, which could create an infinite loop
isinstance(v, Base),
isinstance(v, InstrumentedList),
isinstance(v, InstrumentedSet),
isinstance(v, InstrumentedDict),
])
)
return f"<{self.__class__.__name__}({columns})>"

45
dibbler/models/Product.py Normal file
View File

@ -0,0 +1,45 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import (
Boolean,
Integer,
String,
)
from sqlalchemy.orm import (
Mapped,
mapped_column,
relationship,
)
from .Base import Base
if TYPE_CHECKING:
from .PurchaseEntry import PurchaseEntry
from .UserProducts import UserProducts
class Product(Base):
__tablename__ = 'products'
product_id: Mapped[int] = mapped_column(Integer, primary_key=True)
bar_code: Mapped[str] = mapped_column(String(13))
name: Mapped[str] = mapped_column(String(45))
price: Mapped[int] = mapped_column(Integer)
stock: Mapped[int] = mapped_column(Integer)
hidden: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
purchases: Mapped[set[PurchaseEntry]] = relationship(back_populates="product")
users: Mapped[set[UserProducts]] = relationship(back_populates="product")
bar_code_re = r"[0-9]+"
name_re = r".+"
name_length = 45
def __init__(self, bar_code, name, price, stock=0, hidden = False):
self.name = name
self.bar_code = bar_code
self.price = price
self.stock = stock
self.hidden = hidden
def __str__(self):
return self.name

View File

@ -0,0 +1,69 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from datetime import datetime
import math
from sqlalchemy import (
Boolean,
DateTime,
Integer,
String,
)
from sqlalchemy.orm import (
Mapped,
mapped_column,
relationship,
)
from .Base import Base
from .Transaction import Transaction
if TYPE_CHECKING:
from .PurchaseEntry import PurchaseEntry
class Purchase(Base):
__tablename__ = 'purchases'
id: Mapped[int] = mapped_column(Integer, primary_key=True)
time: Mapped[datetime] = mapped_column(DateTime)
price: Mapped[int] = mapped_column(Integer)
transactions: Mapped[set[Transaction]] = relationship(back_populates="purchase", order_by='Transaction.user_name')
entries: Mapped[set[PurchaseEntry]] = relationship(back_populates="purchase")
def __init__(self):
pass
def is_complete(self):
return len(self.transactions) > 0 and len(self.entries) > 0
def price_per_transaction(self, round_up=True):
if round_up:
return int(math.ceil(float(self.price)/len(self.transactions)))
else:
return int(math.floor(float(self.price)/len(self.transactions)))
def set_price(self, round_up=True):
self.price = 0
for entry in self.entries:
self.price += entry.amount*entry.product.price
if len(self.transactions) > 0:
for t in self.transactions:
t.amount = self.price_per_transaction(round_up=round_up)
def perform_purchase(self, ignore_penalty=False, round_up=True):
self.time = datetime.datetime.now()
self.set_price(round_up=round_up)
for t in self.transactions:
t.perform_transaction(ignore_penalty=ignore_penalty)
for entry in self.entries:
entry.product.stock -= entry.amount
def perform_soft_purchase(self, price, round_up=True):
self.time = datetime.datetime.now()
self.price = price
for t in self.transactions:
t.amount = self.price_per_transaction(round_up=round_up)
for t in self.transactions:
t.perform_transaction()

View File

@ -0,0 +1,36 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import (
Integer,
ForeignKey,
)
from sqlalchemy.orm import (
Mapped,
mapped_column,
relationship,
)
from .Base import Base
if TYPE_CHECKING:
from .Product import Product
from .Purchase import Purchase
class PurchaseEntry(Base):
__tablename__ = 'purchase_entries'
id: Mapped[int] = mapped_column(Integer, primary_key=True)
amount: Mapped[int] = mapped_column(Integer)
product_id: Mapped[int] = mapped_column(ForeignKey("products.product_id"))
purchase_id: Mapped[int] = mapped_column(ForeignKey("purchases.id"))
product: Mapped[Product] = relationship(lazy='joined')
purchase: Mapped[Purchase] = relationship(lazy='joined')
def __init__(self, purchase, product, amount):
self.product = product
self.product_bar_code = product.bar_code
self.purchase = purchase
self.amount = amount

View File

@ -0,0 +1,51 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from datetime import datetime
from sqlalchemy import (
DateTime,
ForeignKey,
Integer,
String,
)
from sqlalchemy.orm import (
Mapped,
mapped_column,
relationship,
)
from .Base import Base
if TYPE_CHECKING:
from .User import User
from .Purchase import Purchase
class Transaction(Base):
__tablename__ = 'transactions'
id: Mapped[int] = mapped_column(Integer, primary_key=True)
time: Mapped[datetime] = mapped_column(DateTime)
amount: Mapped[int] = mapped_column(Integer)
penalty: Mapped[int] = mapped_column(Integer)
description: Mapped[str | None] = mapped_column(String(50))
user_name: Mapped[str] = mapped_column(ForeignKey('users.name'))
purchase_id: Mapped[int | None] = mapped_column(ForeignKey('purchases.id'))
user: Mapped[User] = relationship(lazy='joined')
purchase: Mapped[Purchase] = relationship(lazy='joined')
def __init__(self, user, amount=0, description=None, purchase=None, penalty=1):
self.user = user
self.amount = amount
self.description = description
self.purchase = purchase
self.penalty = penalty
def perform_transaction(self, ignore_penalty=False):
self.time = datetime.datetime.now()
if not ignore_penalty:
self.amount *= self.penalty
self.user.credit -= self.amount

47
dibbler/models/User.py Normal file
View File

@ -0,0 +1,47 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import (
Integer,
String,
)
from sqlalchemy.orm import (
Mapped,
mapped_column,
relationship,
)
from .Base import Base
if TYPE_CHECKING:
from .UserProducts import UserProducts
from .Transaction import Transaction
class User(Base):
__tablename__ = 'users'
name: Mapped[str] = mapped_column(String(10), primary_key=True)
credit: Mapped[str] = mapped_column(Integer)
card: Mapped[str | None] = mapped_column(String(20))
rfid: Mapped[str | None] = mapped_column(String(20))
products: Mapped[set[UserProducts]] = relationship(back_populates="user")
transactions: Mapped[set[Transaction]] = relationship(back_populates="user")
name_re = r"[a-z]+"
card_re = r"(([Nn][Tt][Nn][Uu])?[0-9]+)?"
rfid_re = r"[0-9a-fA-F]*"
def __init__(self, name, card, rfid=None, credit=0):
self.name = name
if card == '':
card = None
self.card = card
if rfid == '':
rfid = None
self.rfid = rfid
self.credit = credit
def __str__(self):
return self.name
def is_anonymous(self):
return self.card == '11122233'

View File

@ -0,0 +1,31 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import (
Integer,
ForeignKey,
)
from sqlalchemy.orm import (
Mapped,
mapped_column,
relationship,
)
from .Base import Base
if TYPE_CHECKING:
from .User import User
from .Product import Product
class UserProducts(Base):
__tablename__ = 'user_products'
user_name: Mapped[str] = mapped_column(ForeignKey('users.name'), primary_key=True)
product_id: Mapped[int] = mapped_column(ForeignKey("products.product_id"), primary_key=True)
count: Mapped[int] = mapped_column(Integer)
sign: Mapped[int] = mapped_column(Integer)
user: Mapped[User] = relationship()
product: Mapped[Product] = relationship()

View File

@ -0,0 +1,7 @@
from .Base import Base
from .Product import Product
from .Purchase import Purchase
from .PurchaseEntry import PurchaseEntry
from .Transaction import Transaction
from .User import User
from .UserProducts import UserProducts

View File

@ -1,177 +0,0 @@
from math import ceil, floor
import datetime
from sqlalchemy import Column, Integer, String, ForeignKey, create_engine, DateTime, Boolean
from sqlalchemy.orm import sessionmaker, relationship, backref
from sqlalchemy.ext.declarative import declarative_base
from dibbler.conf import config
engine = create_engine(config.get('database', 'url'))
Base = declarative_base()
Session = sessionmaker(bind=engine)
class User(Base):
__tablename__ = 'users'
name = Column(String(10), primary_key=True)
card = Column(String(20))
rfid = Column(String(20))
credit = Column(Integer)
name_re = r"[a-z]+"
card_re = r"(([Nn][Tt][Nn][Uu])?[0-9]+)?"
rfid_re = r"[0-9a-fA-F]*"
def __init__(self, name, card, rfid=None, credit=0):
self.name = name
if card == '':
card = None
self.card = card
if rfid == '':
rfid = None
self.rfid = rfid
self.credit = credit
def __repr__(self):
return f"<User('{self.name}')>"
def __str__(self):
return self.name
def is_anonymous(self):
return self.card == '11122233'
class Product(Base):
__tablename__ = 'products'
product_id = Column(Integer, primary_key=True)
bar_code = Column(String(13))
name = Column(String(45))
price = Column(Integer)
stock = Column(Integer)
hidden = Column(Boolean, nullable=False, default=False)
bar_code_re = r"[0-9]+"
name_re = r".+"
name_length = 45
def __init__(self, bar_code, name, price, stock=0, hidden = False):
self.name = name
self.bar_code = bar_code
self.price = price
self.stock = stock
self.hidden = hidden
def __repr__(self):
return f"<Product('{self.name}', '{self.bar_code}', '{self.price}', '{self.stock}', '{self.hidden}')>"
def __str__(self):
return self.name
class UserProducts(Base):
__tablename__ = 'user_products'
user_name = Column(String(10), ForeignKey('users.name'), primary_key=True)
product_id = Column(Integer, ForeignKey("products.product_id"), primary_key=True)
count = Column(Integer)
sign = Column(Integer)
user = relationship(User, backref=backref('products', order_by=count.desc()), lazy='joined')
product = relationship(Product, backref="users", lazy='joined')
class PurchaseEntry(Base):
__tablename__ = 'purchase_entries'
id = Column(Integer, primary_key=True)
purchase_id = Column(Integer, ForeignKey("purchases.id"))
product_id = Column(Integer, ForeignKey("products.product_id"))
amount = Column(Integer)
product = relationship(Product, backref="purchases")
def __init__(self, purchase, product, amount):
self.product = product
self.product_bar_code = product.bar_code
self.purchase = purchase
self.amount = amount
def __repr__(self):
return f"<PurchaseEntry('{self.product.name}', '{self.amount}')>"
class Transaction(Base):
__tablename__ = 'transactions'
id = Column(Integer, primary_key=True)
time = Column(DateTime)
user_name = Column(String(10), ForeignKey('users.name'))
amount = Column(Integer)
description = Column(String(50))
purchase_id = Column(Integer, ForeignKey('purchases.id'))
penalty = Column(Integer)
user = relationship(User, backref=backref('transactions', order_by=time))
def __init__(self, user, amount=0, description=None, purchase=None, penalty=1):
self.user = user
self.amount = amount
self.description = description
self.purchase = purchase
self.penalty = penalty
def perform_transaction(self, ignore_penalty=False):
self.time = datetime.datetime.now()
if not ignore_penalty:
self.amount *= self.penalty
self.user.credit -= self.amount
class Purchase(Base):
__tablename__ = 'purchases'
id = Column(Integer, primary_key=True)
time = Column(DateTime)
price = Column(Integer)
transactions = relationship(Transaction, order_by=Transaction.user_name, backref='purchase')
entries = relationship(PurchaseEntry, backref=backref("purchase"))
def __init__(self):
pass
def __repr__(self):
return f"<Purchase({int(self.id):d}, {self.price:d}, '{self.time.strftime('%c')}')>"
def is_complete(self):
return len(self.transactions) > 0 and len(self.entries) > 0
def price_per_transaction(self, round_up=True):
if round_up:
return int(ceil(float(self.price)/len(self.transactions)))
else:
return int(floor(float(self.price)/len(self.transactions)))
def set_price(self, round_up=True):
self.price = 0
for entry in self.entries:
self.price += entry.amount*entry.product.price
if len(self.transactions) > 0:
for t in self.transactions:
t.amount = self.price_per_transaction(round_up=round_up)
def perform_purchase(self, ignore_penalty=False, round_up=True):
self.time = datetime.datetime.now()
self.set_price(round_up=round_up)
for t in self.transactions:
t.perform_transaction(ignore_penalty=ignore_penalty)
for entry in self.entries:
entry.product.stock -= entry.amount
def perform_soft_purchase(self, price, round_up=True):
self.time = datetime.datetime.now()
self.price = price
for t in self.transactions:
t.amount = self.price_per_transaction(round_up=round_up)
for t in self.transactions:
t.perform_transaction()

View File

@ -0,0 +1,9 @@
#!/usr/bin/python
from dibbler.models import Base
from dibbler.db import engine
def main():
Base.metadata.create_all(engine)
if __name__ == "__main__":
main()

View File

@ -1,6 +1,7 @@
#!/usr/bin/python #!/usr/bin/python
from dibbler.models.db import * from dibbler.db import Session
from dibbler.models import User
def main(): def main():
# Start an SQL session # Start an SQL session

View File

@ -5,7 +5,8 @@ import datetime
from collections import defaultdict from collections import defaultdict
from .helpers import * from .helpers import *
from .models.db import *; from .models import Transaction
from .db import Session
def getUser(): def getUser():
while 1: while 1:

View File

@ -29,7 +29,6 @@ from .conf import config
random.seed() random.seed()
def main(): def main():
if not config.getboolean('general', 'stop_allowed'): if not config.getboolean('general', 'stop_allowed'):
signal.signal(signal.SIGQUIT, signal.SIG_IGN) signal.signal(signal.SIGQUIT, signal.SIG_IGN)

View File

@ -5,7 +5,8 @@ show_tracebacks = true
input_encoding = 'utf8' input_encoding = 'utf8'
[database] [database]
url = postgresql://robertem@127.0.0.1/pvvvv ; url = postgresql://robertem@127.0.0.1/pvvvv
url = sqlite:///test.db
[limits] [limits]
low_credit_warning_limit = -100 low_credit_warning_limit = -100

View File

@ -32,7 +32,7 @@ in {
wantedBy = [ "default.target" ]; wantedBy = [ "default.target" ];
serviceConfig = { serviceConfig = {
ExecStartPre = "-${screen} -X -S dibbler kill"; ExecStartPre = "-${screen} -X -S dibbler kill";
ExecStart = "${screen} -dmS dibbler -O -l ${cfg.package}/bin/dibbler --config ${cfg.config}"; ExecStart = "${screen} -dmS dibbler -O -l ${cfg.package}/bin/dibbler --config ${cfg.config} loop";
ExecStartPost = "${screen} -X -S dibbler width 42 80"; ExecStartPost = "${screen} -X -S dibbler width 42 80";
User = "dibbler"; User = "dibbler";
Group = "dibbler"; Group = "dibbler";

View File

@ -25,6 +25,4 @@ dynamic = ["version"]
include = ["dibbler*"] include = ["dibbler*"]
[project.scripts] [project.scripts]
dibbler = "dibbler.cli:main" dibbler = "dibbler.main:main"
slabbedasker = "dibbler.scripts.slabbedasker:main"
statistikk = "dibbler.scripts.statistikk:main"