diff --git a/dibbler/db.py b/dibbler/db.py deleted file mode 100644 index 7f52b43..0000000 --- a/dibbler/db.py +++ /dev/null @@ -1,5 +0,0 @@ -from sqlalchemy.engine.base import Engine -from sqlalchemy.orm import Session - -engine: Engine = None -session: Session = None diff --git a/dibbler/lib/helpers.py b/dibbler/lib/helpers.py index 30926a3..ed04de4 100644 --- a/dibbler/lib/helpers.py +++ b/dibbler/lib/helpers.py @@ -4,21 +4,22 @@ import os import signal from sqlalchemy import or_, and_ +from sqlalchemy.orm import Session from ..models import User, Product -def search_user(string, session, ignorethisflag=None): +def search_user(string, sql_session: Session, ignorethisflag=None): string = string.lower() exact_match = ( - session.query(User) + sql_session.query(User) .filter(or_(User.name == string, User.card == string, User.rfid == string)) .first() ) if exact_match: return exact_match user_list = ( - session.query(User) + sql_session.query(User) .filter( or_( User.name.ilike(f"%{string}%"), @@ -31,16 +32,16 @@ def search_user(string, session, ignorethisflag=None): return user_list -def search_product(string, session, find_hidden_products=True): +def search_product(string, sql_session: Session, find_hidden_products=True): if find_hidden_products: exact_match = ( - session.query(Product) + sql_session.query(Product) .filter(or_(Product.bar_code == string, Product.name == string)) .first() ) else: exact_match = ( - session.query(Product) + sql_session.query(Product) .filter( or_( Product.bar_code == string, @@ -53,7 +54,7 @@ def search_product(string, session, find_hidden_products=True): return exact_match if find_hidden_products: product_list = ( - session.query(Product) + sql_session.query(Product) .filter( or_( Product.bar_code.ilike(f"%{string}%"), @@ -64,7 +65,7 @@ def search_product(string, session, find_hidden_products=True): ) else: product_list = ( - session.query(Product) + sql_ession.query(Product) .filter( or_( Product.bar_code.ilike(f"%{string}%"), diff --git a/dibbler/lib/statistikkHelpers.py b/dibbler/lib/statistikkHelpers.py index f17da25..cffaf52 100644 --- a/dibbler/lib/statistikkHelpers.py +++ b/dibbler/lib/statistikkHelpers.py @@ -4,17 +4,17 @@ import datetime from collections import defaultdict +from sqlalchemy.orm import Session + from .helpers import * from ..models import Transaction -from ..db import session as create_session -def getUser(): +def getUser(sql_session: Session): while 1: string = input("user? ") - session = create_session() - user = search_user(string, session) - session.close() + user = search_user(string, sql_session) + sql_session.close() if not isinstance(user, list): return user.name i = 0 @@ -37,12 +37,10 @@ def getUser(): return user[n].name -def getProduct(): +def getProduct(sql_session: Session): while 1: string = input("product? ") - session = create_session() - product = search_product(string, session) - session.close() + product = search_product(string, sql_session) if not isinstance(product, list): return product.name i = 0 @@ -238,12 +236,11 @@ def addLineToDatabase(database, inputLine): return database -def buildDatabaseFromDb(inputType, inputProduct, inputUser): +def buildDatabaseFromDb(inputType, inputProduct, inputUser, sql_session: Session): sdate = input("enter start date (yyyy-mm-dd)? ") edate = input("enter end date (yyyy-mm-dd)? ") print("building database...") - session = create_session() - transaction_list = session.query(Transaction).all() + transaction_list = sql_session.query(Transaction).all() inputLine = InputLine(inputUser, inputProduct, inputType) startDate = getDateDb(transaction_list[0].time, sdate) endDate = getDateDb(transaction_list[-1].time, edate) @@ -275,7 +272,7 @@ def buildDatabaseFromDb(inputType, inputProduct, inputUser): print("saving as default.dibblerlog...", end=" ") f = open("default.dibblerlog", "w") line_format = "%s|%s|%s|%s|%s|%s\n" - transaction_list = session.query(Transaction).all() + transaction_list = sql_session.query(Transaction).all() for transaction in transaction_list: if transaction.purchase: products = "ยค".join([ent.product.name for ent in transaction.purchase.entries]) @@ -290,7 +287,6 @@ def buildDatabaseFromDb(inputType, inputProduct, inputUser): transaction.description, ) f.write(line.encode("utf8")) - session.close() f.close # bygg database.pengebeholdning if (inputType == 3) or (inputType == 4): @@ -466,7 +462,7 @@ def printGlobal(database, dateLine, n): ) -def alt4menuTextOnly(database, dateLine): +def alt4menuTextOnly(database, dateLine, sql_session: Session): n = 10 while 1: print( @@ -477,12 +473,12 @@ def alt4menuTextOnly(database, dateLine): break elif inp == "1": try: - printUser(database, dateLine, getUser(), n) + printUser(database, dateLine, getUser(sql_session), n) except: print("\n\nSomething is not right, (last date prior to first date?)") elif inp == "2": try: - printProduct(database, dateLine, getProduct(), n) + printProduct(database, dateLine, getProduct(sql_session), n) except: print("\n\nSomething is not right, (last date prior to first date?)") elif inp == "3": @@ -494,15 +490,15 @@ def alt4menuTextOnly(database, dateLine): n = int(input("set number to show ")) -def statisticsTextOnly(): +def statisticsTextOnly(sql_session: Session): inputType = 4 product = "" user = "" print("\n0: from file, 1: from database, q:quit") inp = input("") if inp == "1": - database, dateLine = buildDatabaseFromDb(inputType, product, user) + database, dateLine = buildDatabaseFromDb(inputType, product, user, sql_session) elif inp == "0" or inp == "": database, dateLine = buildDatabaseFromFile("default.dibblerlog", inputType, product, user) if not inp == "q": - alt4menuTextOnly(database, dateLine) + alt4menuTextOnly(database, dateLine, sql_session) diff --git a/dibbler/main.py b/dibbler/main.py index de63975..8c56cb2 100644 --- a/dibbler/main.py +++ b/dibbler/main.py @@ -2,10 +2,9 @@ import argparse from pathlib import Path from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from dibbler.conf import load_config, config_db_string -from dibbler.db import engine, session parser = argparse.ArgumentParser() @@ -34,29 +33,28 @@ def main(): load_config(args.config) - global engine, session engine = create_engine(config_db_string()) - session = sessionmaker(bind=engine) + sql_session = Session(engine) if args.subcommand == "loop": import dibbler.subcommands.loop as loop - loop.main() + loop.main(sql_session) elif args.subcommand == "create-db": import dibbler.subcommands.makedb as makedb - makedb.main() + makedb.main(engine) elif args.subcommand == "slabbedasker": import dibbler.subcommands.slabbedasker as slabbedasker - slabbedasker.main() + slabbedasker.main(sql_session) elif args.subcommand == "seed-data": import dibbler.subcommands.seed_test_data as seed_test_data - seed_test_data.main() + seed_test_data.main(sql_session) if __name__ == "__main__": diff --git a/dibbler/menus/addstock.py b/dibbler/menus/addstock.py index acd7034..a3130df 100644 --- a/dibbler/menus/addstock.py +++ b/dibbler/menus/addstock.py @@ -1,6 +1,7 @@ from math import ceil import sqlalchemy +from sqlalchemy.orm import Session from dibbler.models import ( Product, @@ -13,8 +14,8 @@ from .helpermenus import Menu class AddStockMenu(Menu): - def __init__(self): - Menu.__init__(self, "Add stock and adjust credit", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Add stock and adjust credit", sql_session=sql_session, uses_db=True) self.help_text = """ Enter what you have bought for PVVVV here, along with your user name and how much money you're due in credits for the purchase when prompted.\n""" @@ -151,10 +152,10 @@ much money you're due in credits for the purchase when prompted.\n""" PurchaseEntry(purchase, product, -self.products[product][0]) purchase.perform_soft_purchase(-self.price, round_up=False) - self.session.add(purchase) + self.sql_session.add(purchase) try: - self.session.commit() + self.sql_session.commit() print("Success! Transaction performed:") # self.print_info() for user in self.users: diff --git a/dibbler/menus/buymenu.py b/dibbler/menus/buymenu.py index 66fa340..7d09be1 100644 --- a/dibbler/menus/buymenu.py +++ b/dibbler/menus/buymenu.py @@ -1,4 +1,5 @@ import sqlalchemy +from sqlalchemy.orm import Session from dibbler.conf import config from dibbler.models import ( @@ -13,10 +14,8 @@ from .helpermenus import Menu class BuyMenu(Menu): - def __init__(self, session=None): - Menu.__init__(self, "Buy", uses_db=True) - if session: - self.session = session + def __init__(self, sql_session: Session): + Menu.__init__(self, "Buy", sql_session=sql_session, uses_db=True) self.superfast_mode = False self.help_text = """ Each purchase may contain one or more products and one or more buyers. @@ -167,9 +166,9 @@ When finished, write an empty line to confirm the purchase.\n""" break self.purchase.perform_purchase() - self.session.add(self.purchase) + self.sql_session.add(self.purchase) try: - self.session.commit() + self.sql_session.commit() except sqlalchemy.exc.SQLAlchemyError as e: print(f"Could not store purchase: {e}") else: diff --git a/dibbler/menus/editing.py b/dibbler/menus/editing.py index 1d8d930..bcb1991 100644 --- a/dibbler/menus/editing.py +++ b/dibbler/menus/editing.py @@ -1,5 +1,7 @@ import sqlalchemy +from sqlalchemy.orm import Session + from dibbler.models import User, Product from .helpermenus import Menu, Selector @@ -14,8 +16,8 @@ __all__ = [ class AddUserMenu(Menu): - def __init__(self): - Menu.__init__(self, "Add user", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Add user", sql_session=sql_session, uses_db=True) def _execute(self): self.print_header() @@ -28,9 +30,9 @@ class AddUserMenu(Menu): cardnum = cardnum.lower() rfid = self.input_str("RFID (optional)", regex=User.rfid_re, length_range=(0, 10)) user = User(username, cardnum, rfid) - self.session.add(user) + self.sql_session.add(user) try: - self.session.commit() + self.sql_session.commit() print(f"User {username} stored") except sqlalchemy.exc.IntegrityError as e: print(f"Could not store user {username}: {e}") @@ -38,8 +40,8 @@ class AddUserMenu(Menu): class EditUserMenu(Menu): - def __init__(self): - Menu.__init__(self, "Edit user", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Edit user", sql_session=sql_session, uses_db=True) self.help_text = """ The only editable part of a user is its card number and rfid. @@ -69,7 +71,7 @@ user, then rfid (write an empty line to remove the card number or rfid). empty_string_is_none=True, ) try: - self.session.commit() + self.sql_session.commit() print(f"User {user.name} stored") except sqlalchemy.exc.SQLAlchemyError as e: print(f"Could not store user {user.name}: {e}") @@ -77,8 +79,8 @@ user, then rfid (write an empty line to remove the card number or rfid). class AddProductMenu(Menu): - def __init__(self): - Menu.__init__(self, "Add product", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Add product", sql_session=sql_session, uses_db=True) def _execute(self): self.print_header() @@ -86,9 +88,9 @@ class AddProductMenu(Menu): name = self.input_str("Name", regex=Product.name_re, length_range=(1, Product.name_length)) price = self.input_int("Price", 1, 100000) product = Product(bar_code, name, price) - self.session.add(product) + self.sql_session.add(product) try: - self.session.commit() + self.sql_session.commit() print(f"Product {name} stored") except sqlalchemy.exc.SQLAlchemyError as e: print(f"Could not store product {name}: {e}") @@ -96,8 +98,8 @@ class AddProductMenu(Menu): class EditProductMenu(Menu): - def __init__(self): - Menu.__init__(self, "Edit product", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Edit product", sql_session=sql_session, uses_db=True) def _execute(self): self.print_header() @@ -135,7 +137,7 @@ class EditProductMenu(Menu): product.hidden = self.confirm(f"Hidden(currently {product.hidden})", default=False) elif what == "store": try: - self.session.commit() + self.sql_session.commit() print(f"Product {product.name} stored") except sqlalchemy.exc.SQLAlchemyError as e: print(f"Could not store product {product.name}: {e}") @@ -149,8 +151,8 @@ class EditProductMenu(Menu): class AdjustStockMenu(Menu): - def __init__(self): - Menu.__init__(self, "Adjust stock", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Adjust stock", sql_session=sql_session, uses_db=True) def _execute(self): self.print_header() @@ -168,7 +170,7 @@ class AdjustStockMenu(Menu): product.stock += add_stock try: - self.session.commit() + self.sql_session.commit() print("Stock is now stored") self.pause() except sqlalchemy.exc.SQLAlchemyError as e: @@ -179,13 +181,13 @@ class AdjustStockMenu(Menu): class CleanupStockMenu(Menu): - def __init__(self): - Menu.__init__(self, "Stock Cleanup", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Stock Cleanup", sql_session=sql_session, uses_db=True) def _execute(self): self.print_header() - products = self.session.query(Product).filter(Product.stock != 0).all() + products = self.sql_session.query(Product).filter(Product.stock != 0).all() print("Every product in stock will be printed.") print("Entering no value will keep current stock or set it to 0 if it is negative.") @@ -199,12 +201,12 @@ class CleanupStockMenu(Menu): for product in products: oldstock = product.stock product.stock = self.input_int(product.name, 0, 10000, default=max(0, oldstock)) - self.session.add(product) + self.sql_session.add(product) if oldstock != product.stock: changed_products.append((product, oldstock)) try: - self.session.commit() + self.sql_session.commit() print("New stocks are now stored.") self.pause() except sqlalchemy.exc.SQLAlchemyError as e: diff --git a/dibbler/menus/helpermenus.py b/dibbler/menus/helpermenus.py index 2270d8c..ff9a0b7 100644 --- a/dibbler/menus/helpermenus.py +++ b/dibbler/menus/helpermenus.py @@ -5,7 +5,8 @@ import re import sys from select import select -from dibbler.db import session as create_session +from sqlalchemy.orm import Session + from dibbler.models import User from dibbler.lib.helpers import ( search_user, @@ -37,6 +38,7 @@ class Menu(object): exit_disallowed_msg=None, help_text=None, uses_db=False, + sql_session: Session | None=None, ): self.name = name self.items = items if items is not None else [] @@ -49,7 +51,9 @@ class Menu(object): self.help_text = help_text self.context = None self.uses_db = uses_db - self.session = None + self.sql_session: Session | None = sql_session + + assert not (self.uses_db and self.sql_session is None) def exit_menu(self): if self.exit_disallowed_msg is not None: @@ -338,7 +342,7 @@ class Menu(object): results = {} result_values = {} for thing in permitted_things: - results[thing] = search_fun[thing](search_str, self.session, find_hidden_products) + results[thing] = search_fun[thing](search_str, self.sql_session, find_hidden_products) result_values[thing] = self.search_result_value(results[thing]) selected_thing = argmax(result_values) if not results[selected_thing]: @@ -373,7 +377,7 @@ class Menu(object): print(f'"{string}" looks like a username, but no such user exists.') if self.confirm(f"Create user {string}?"): user = User(string, None) - self.session.add(user) + self.sql_session.add(user) return user return None if type_guess == "card": @@ -392,7 +396,7 @@ class Menu(object): (1, 10), ) user = User(username, string) - self.session.add(user) + self.sql_session.add(user) return user if selection == "set": user = self.input_user("User to set card number for") @@ -406,7 +410,7 @@ class Menu(object): return None def search_ui(self, search_fun, search_str, thing): - result = search_fun(search_str, self.session) + result = search_fun(search_str, self.sql_session) return self.search_ui2(search_str, result, thing) def search_ui2(self, search_str, result, thing): @@ -484,16 +488,13 @@ class Menu(object): def execute(self, **kwargs): self.set_context(None) try: - if self.uses_db and not self.session: - self.session = create_session() return self._execute(**kwargs) except ExitMenu: self.at_exit() return None finally: - if self.session is not None: - self.session.close() - self.session = None + if self.sql_session is not None: + self.sql_session = None def _execute(self, **kwargs): while True: diff --git a/dibbler/menus/mainmenu.py b/dibbler/menus/mainmenu.py index da1c4a6..a10d97e 100644 --- a/dibbler/menus/mainmenu.py +++ b/dibbler/menus/mainmenu.py @@ -3,7 +3,6 @@ import os import random import sys -from dibbler.db import session as create_session from .buymenu import BuyMenu from .faq import FAQMenu @@ -28,7 +27,7 @@ class MainMenu(Menu): else: num = 1 item_name = in_str - buy_menu = BuyMenu(create_session()) + buy_menu = BuyMenu(self.sql_session) thing = buy_menu.search_for_thing(item_name, find_hidden_products=False) if thing: buy_menu.execute(initial_contents=[(thing, num)]) diff --git a/dibbler/menus/miscmenus.py b/dibbler/menus/miscmenus.py index 218842b..975b3fc 100644 --- a/dibbler/menus/miscmenus.py +++ b/dibbler/menus/miscmenus.py @@ -1,4 +1,5 @@ import sqlalchemy +from sqlalchemy.orm import Session from dibbler.conf import config from dibbler.models import Transaction, Product, User @@ -8,8 +9,8 @@ from .helpermenus import Menu, Selector class TransferMenu(Menu): - def __init__(self): - Menu.__init__(self, "Transfer credit between users", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Transfer credit between users", sql_session=sql_session, uses_db=True) def _execute(self): self.print_header() @@ -26,10 +27,10 @@ class TransferMenu(Menu): t2 = Transaction(user2, -amount, f'transfer from {user1.name} "{comment}"') t1.perform_transaction() t2.perform_transaction() - self.session.add(t1) - self.session.add(t2) + self.sql_session.add(t1) + self.sql_session.add(t2) try: - self.session.commit() + self.sql_session.commit() print(f"Transferred {amount:d} kr from {user1} to {user2}") print(f"User {user1}'s credit is now {user1.credit:d} kr") print(f"User {user2}'s credit is now {user2.credit:d} kr") @@ -40,8 +41,8 @@ class TransferMenu(Menu): class ShowUserMenu(Menu): - def __init__(self): - Menu.__init__(self, "Show user", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Show user", sql_session=sql_session, uses_db=True) def _execute(self): self.print_header() @@ -123,13 +124,13 @@ class ShowUserMenu(Menu): class UserListMenu(Menu): - def __init__(self): - Menu.__init__(self, "User list", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "User list", sql_session=sql_session, uses_db=True) def _execute(self): self.print_header() - user_list = self.session.query(User).all() - total_credit = self.session.query(sqlalchemy.func.sum(User.credit)).first()[0] + user_list = self.sql_session.query(User).all() + total_credit = self.sql_session.query(sqlalchemy.func.sum(User.credit)).first()[0] line_format = "%-12s | %6s\n" hline = "---------------------\n" @@ -144,8 +145,8 @@ class UserListMenu(Menu): class AdjustCreditMenu(Menu): - def __init__(self): - Menu.__init__(self, "Adjust credit", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Adjust credit", sql_session=sql_session, uses_db=True) def _execute(self): self.print_header() @@ -164,9 +165,9 @@ class AdjustCreditMenu(Menu): description = "manually adjusted credit" transaction = Transaction(user, -amount, description) transaction.perform_transaction() - self.session.add(transaction) + self.sql_session.add(transaction) try: - self.session.commit() + self.sql_session.commit() print(f"User {user.name}'s credit is now {user.credit:d} kr") except sqlalchemy.exc.SQLAlchemyError as e: print(f"Could not store transaction: {e}") @@ -174,14 +175,14 @@ class AdjustCreditMenu(Menu): class ProductListMenu(Menu): - def __init__(self): - Menu.__init__(self, "Product list", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Product list", sql_session=sql_session, uses_db=True) def _execute(self): self.print_header() text = "" product_list = ( - self.session.query(Product) + self.sql_session.query(Product) .filter(Product.hidden.is_(False)) .order_by(Product.stock.desc()) ) @@ -204,8 +205,8 @@ class ProductListMenu(Menu): class ProductSearchMenu(Menu): - def __init__(self): - Menu.__init__(self, "Product search", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Product search", sql_session=sql_session, uses_db=True) def _execute(self): self.print_header() diff --git a/dibbler/menus/printermenu.py b/dibbler/menus/printermenu.py index fe521f6..afe0c04 100644 --- a/dibbler/menus/printermenu.py +++ b/dibbler/menus/printermenu.py @@ -1,5 +1,7 @@ import re +from sqlalchemy.orm import Session + from dibbler.conf import config from dibbler.models import Product, User from dibbler.lib.printer_helpers import print_bar_code, print_name_label @@ -8,8 +10,8 @@ from .helpermenus import Menu class PrintLabelMenu(Menu): - def __init__(self): - Menu.__init__(self, "Print a label", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Print a label", sql_session=sql_session, uses_db=True) self.help_text = """ Prints out a product bar code on the printer diff --git a/dibbler/menus/stats.py b/dibbler/menus/stats.py index 4100ada..ee0ed0f 100644 --- a/dibbler/menus/stats.py +++ b/dibbler/menus/stats.py @@ -1,4 +1,5 @@ from sqlalchemy import desc, func +from sqlalchemy.orm import Session from dibbler.lib.helpers import less from dibbler.models import PurchaseEntry, Product, User @@ -15,14 +16,14 @@ __all__ = [ class ProductPopularityMenu(Menu): - def __init__(self): - Menu.__init__(self, "Products by popularity", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Products by popularity", sql_session=sql_session, uses_db=True) def _execute(self): self.print_header() text = "" sub = ( - self.session.query( + self.sql_session.query( PurchaseEntry.product_id, func.sum(PurchaseEntry.amount).label("purchase_count"), ) @@ -31,7 +32,7 @@ class ProductPopularityMenu(Menu): .subquery() ) product_list = ( - self.session.query(Product, sub.c.purchase_count) + self.sql_session.query(Product, sub.c.purchase_count) .outerjoin((sub, Product.product_id == sub.c.product_id)) .order_by(desc(sub.c.purchase_count)) .filter(sub.c.purchase_count is not None) @@ -48,14 +49,14 @@ class ProductPopularityMenu(Menu): class ProductRevenueMenu(Menu): - def __init__(self): - Menu.__init__(self, "Products by revenue", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Products by revenue", sql_session=sql_session, uses_db=True) def _execute(self): self.print_header() text = "" sub = ( - self.session.query( + self.sql_session.query( PurchaseEntry.product_id, func.sum(PurchaseEntry.amount).label("purchase_count"), ) @@ -64,7 +65,7 @@ class ProductRevenueMenu(Menu): .subquery() ) product_list = ( - self.session.query(Product, sub.c.purchase_count) + self.sql_session.query(Product, sub.c.purchase_count) .outerjoin((sub, Product.product_id == sub.c.product_id)) .order_by(desc(sub.c.purchase_count * Product.price)) .filter(sub.c.purchase_count is not None) @@ -86,22 +87,22 @@ class ProductRevenueMenu(Menu): class BalanceMenu(Menu): - def __init__(self): - Menu.__init__(self, "Total balance of PVVVV", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Total balance of PVVVV", sql_session=sql_session, uses_db=True) def _execute(self): self.print_header() text = "" total_value = 0 - product_list = self.session.query(Product).filter(Product.stock > 0).all() + product_list = self.sql_session.query(Product).filter(Product.stock > 0).all() for p in product_list: total_value += p.stock * p.price total_positive_credit = ( - self.session.query(func.sum(User.credit)).filter(User.credit > 0).first()[0] + self.sql_session.query(func.sum(User.credit)).filter(User.credit > 0).first()[0] ) total_negative_credit = ( - self.session.query(func.sum(User.credit)).filter(User.credit < 0).first()[0] + self.sql_session.query(func.sum(User.credit)).filter(User.credit < 0).first()[0] ) total_credit = total_positive_credit + total_negative_credit @@ -119,8 +120,8 @@ class BalanceMenu(Menu): class LoggedStatisticsMenu(Menu): - def __init__(self): - Menu.__init__(self, "Statistics from log", uses_db=True) + def __init__(self, sql_session: Session): + Menu.__init__(self, "Statistics from log", sql_session=sql_session, uses_db=True) def _execute(self): - statisticsTextOnly() + statisticsTextOnly(self.sql_session) diff --git a/dibbler/subcommands/loop.py b/dibbler/subcommands/loop.py index 2faff90..61fc81a 100755 --- a/dibbler/subcommands/loop.py +++ b/dibbler/subcommands/loop.py @@ -5,6 +5,8 @@ import random import sys import traceback +from sqlalchemy.orm import Session + from ..conf import config from ..lib.helpers import * from ..menus import * @@ -12,7 +14,7 @@ from ..menus import * random.seed() -def main(): +def main(sql_session: Session): if not config["general"]["stop_allowed"]: signal.signal(signal.SIGQUIT, signal.SIG_IGN) @@ -22,36 +24,36 @@ def main(): main = MainMenu( "Dibbler main menu", items=[ - BuyMenu(), - ProductListMenu(), - ShowUserMenu(), - UserListMenu(), - AdjustCreditMenu(), - TransferMenu(), - AddStockMenu(), + BuyMenu(sql_session), + ProductListMenu(sql_session), + ShowUserMenu(sql_session), + UserListMenu(sql_session), + AdjustCreditMenu(sql_session), + TransferMenu(sql_session), + AddStockMenu(sql_session), Menu( "Add/edit", items=[ - AddUserMenu(), - EditUserMenu(), - AddProductMenu(), - EditProductMenu(), - AdjustStockMenu(), - CleanupStockMenu(), + AddUserMenu(sql_session), + EditUserMenu(sql_session), + AddProductMenu(sql_session), + EditProductMenu(sql_session), + AdjustStockMenu(sql_session), + CleanupStockMenu(sql_session), ], ), - ProductSearchMenu(), + ProductSearchMenu(sql_session), Menu( "Statistics", items=[ - ProductPopularityMenu(), - ProductRevenueMenu(), - BalanceMenu(), - LoggedStatisticsMenu(), + ProductPopularityMenu(sql_session), + ProductRevenueMenu(sql_session), + BalanceMenu(sql_session), + LoggedStatisticsMenu(sql_session), ], ), FAQMenu(), - PrintLabelMenu(), + PrintLabelMenu(sql_session), ], exit_msg="happy happy joy joy", exit_confirm_msg="Really quit Dibbler?", @@ -73,7 +75,3 @@ def main(): else: break print("Restarting main menu.") - - -if __name__ == "__main__": - main() diff --git a/dibbler/subcommands/makedb.py b/dibbler/subcommands/makedb.py index 74a6826..3280516 100644 --- a/dibbler/subcommands/makedb.py +++ b/dibbler/subcommands/makedb.py @@ -1,11 +1,9 @@ #!/usr/bin/python + +from sqlalchemy.engine import Engine + from dibbler.models import Base -from dibbler.db import engine -def main(): +def main(engine: Engine): Base.metadata.create_all(engine) - - -if __name__ == "__main__": - main() diff --git a/dibbler/subcommands/seed_test_data.py b/dibbler/subcommands/seed_test_data.py index 9ffd02a..9225d26 100644 --- a/dibbler/subcommands/seed_test_data.py +++ b/dibbler/subcommands/seed_test_data.py @@ -1,24 +1,23 @@ import json -from dibbler.db import session as create_session from pathlib import Path -from dibbler.models.Product import Product +from sqlalchemy.orm import Session +from dibbler.models.Product import Product from dibbler.models.User import User JSON_FILE = Path(__file__).parent.parent.parent / "mock_data.json" -def clear_db(session): - session.query(Product).delete() - session.query(User).delete() - session.commit() +def clear_db(sql_session: Session): + sql_session.query(Product).delete() + sql_session.query(User).delete() + sql_session.commit() -def main(): - session = create_session() - clear_db(session) +def main(sql_session: Session): + clear_db(sql_session) product_items = [] user_items = [] @@ -43,6 +42,6 @@ def main(): ) user_items.append(user_item) - session.add_all(product_items) - session.add_all(user_items) - session.commit() + sql_session.add_all(product_items) + sql_session.add_all(user_items) + sql_session.commit() diff --git a/dibbler/subcommands/slabbedasker.py b/dibbler/subcommands/slabbedasker.py index c66df30..f668070 100644 --- a/dibbler/subcommands/slabbedasker.py +++ b/dibbler/subcommands/slabbedasker.py @@ -1,18 +1,13 @@ #!/usr/bin/python -from dibbler.db import session as create_session +from sqlalchemy.orm import Session + from dibbler.models import User -def main(): - # Start an SQL session - session = create_session() +def main(sql_session: Session): # Let's find all users with a negative credit - slabbedasker = session.query(User).filter(User.credit < 0).all() + slabbedasker = sql_session.query(User).filter(User.credit < 0).all() for slubbert in slabbedasker: print(f"{slubbert.name}, {slubbert.credit}") - - -if __name__ == "__main__": - main()