repl: better typing
This commit is contained in:
@@ -1,49 +1,63 @@
|
||||
from cmd import Cmd
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, TypeVar, Generic, Self
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
def prompt_yes_no(question: str, default: bool | None = None) -> bool:
|
||||
prompt = {
|
||||
None: '[y/n]',
|
||||
True: '[Y/n]',
|
||||
False: '[y/N]',
|
||||
None: "[y/n]",
|
||||
True: "[Y/n]",
|
||||
False: "[y/N]",
|
||||
}[default]
|
||||
|
||||
while not any([
|
||||
(answer := input(f'{question} {prompt} ').lower()) in ('y','n'),
|
||||
(default != None and answer.strip() == '')
|
||||
]):
|
||||
while not any(
|
||||
[
|
||||
(answer := input(f"{question} {prompt} ").lower()) in ("y", "n"),
|
||||
(default is not None and answer.strip() == ""),
|
||||
]
|
||||
):
|
||||
pass
|
||||
|
||||
return {
|
||||
'y': True,
|
||||
'n': False,
|
||||
'': default,
|
||||
"y": True,
|
||||
"n": False,
|
||||
"": default,
|
||||
}[answer]
|
||||
|
||||
|
||||
def format_date(date: datetime):
|
||||
def format_date(date: datetime) -> str:
|
||||
return date.strftime("%a %b %d, %Y")
|
||||
|
||||
|
||||
class InteractiveItemSelector(Cmd):
|
||||
InteractiveItemSelectorType = TypeVar("InteractiveItemSelectorType")
|
||||
|
||||
|
||||
class InteractiveItemSelector(Generic[InteractiveItemSelectorType], Cmd):
|
||||
sql_session: Session
|
||||
execute_selection: Callable[[Session, type, str], list[Any]]
|
||||
complete_selection: Callable[[Session, type, str], list[str]]
|
||||
default_item: InteractiveItemSelectorType | None
|
||||
result: InteractiveItemSelectorType | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cls: type,
|
||||
sql_session: Session,
|
||||
execute_selection: Callable[[Session, type, str], list[Any]] = lambda session, cls, arg: session.scalars(
|
||||
select(cls)
|
||||
.where(cls.name == arg),
|
||||
execute_selection: Callable[
|
||||
[Session, type, str], list[InteractiveItemSelectorType]
|
||||
] = lambda session, cls, arg: session.scalars(
|
||||
select(cls).where(cls.name == arg),
|
||||
).all(),
|
||||
complete_selection: Callable[[Session, type, str], list[str]] = lambda session, cls, text: session.scalars(
|
||||
select(cls.name)
|
||||
.where(cls.name.istartswith(text)),
|
||||
complete_selection: Callable[[Session, type, str], list[str]] = lambda session,
|
||||
cls,
|
||||
text: session.scalars(
|
||||
select(cls.name).where(cls.name.istartswith(text)),
|
||||
).all(),
|
||||
default: Any | None = None,
|
||||
):
|
||||
default: InteractiveItemSelectorType | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
This is a utility class for prompting the user to select an
|
||||
item from the database. The default functions assumes that
|
||||
@@ -61,10 +75,9 @@ class InteractiveItemSelector(Cmd):
|
||||
self.result = None
|
||||
|
||||
if default is not None:
|
||||
self.prompt = f'Select {cls.__name__} [{default.name}]> '
|
||||
self.prompt = f"Select {cls.__name__} [{default.name}]> "
|
||||
else:
|
||||
self.prompt = f'Select {cls.__name__}> '
|
||||
|
||||
self.prompt = f"Select {cls.__name__}> "
|
||||
|
||||
def emptyline(self) -> bool:
|
||||
if self.default_item is not None:
|
||||
@@ -72,15 +85,14 @@ class InteractiveItemSelector(Cmd):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def default(self, arg: str):
|
||||
def default(self, arg: str) -> bool | None:
|
||||
try:
|
||||
result = self.execute_selection(self.sql_session, self.cls, arg)
|
||||
result = self.execute_selection(self.sql_session, self.cls, arg)
|
||||
except Exception as e:
|
||||
print(f'Error executing selection: {e}')
|
||||
print(f"Error executing selection: {e}")
|
||||
|
||||
if len(result) != 1:
|
||||
print(f'No such {self.cls.__name__} found: {arg}')
|
||||
print(f"No such {self.cls.__name__} found: {arg}")
|
||||
return
|
||||
|
||||
self.result = result[0]
|
||||
@@ -132,72 +144,77 @@ class NumberedCmd(Cmd):
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
prompt_header: str | None = None
|
||||
funcs: dict[int, dict[str, str | Callable[[Any, str], bool | None]]]
|
||||
funcs: dict[
|
||||
int,
|
||||
dict[
|
||||
str,
|
||||
str | Callable[[Self, str], bool | None],
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
def _generate_usage_list(self) -> str:
|
||||
result = ''
|
||||
result = ""
|
||||
for i, func in self.funcs.items():
|
||||
if i == 0:
|
||||
i = '*'
|
||||
result += f'{i}) {func["doc"]}\n'
|
||||
i = "*"
|
||||
result += f"{i}) {func['doc']}\n"
|
||||
return result
|
||||
|
||||
|
||||
def _default(self, arg: str):
|
||||
def _default(self, arg: str) -> bool | None:
|
||||
try:
|
||||
i = int(arg)
|
||||
self.funcs[i]
|
||||
except (ValueError, KeyError):
|
||||
return
|
||||
|
||||
return self.funcs[i]['f'](self, arg)
|
||||
return self.funcs[i]["f"](self, arg)
|
||||
|
||||
|
||||
def default(self, arg: str):
|
||||
def default(self, arg: str) -> bool | None:
|
||||
return self._default(arg)
|
||||
|
||||
|
||||
def _postcmd(self, stop: bool, _: str) -> bool:
|
||||
if not stop:
|
||||
print()
|
||||
print('-----------------')
|
||||
print()
|
||||
print()
|
||||
print("-----------------")
|
||||
print()
|
||||
return stop
|
||||
|
||||
|
||||
def postcmd(self, stop: bool, line: str) -> bool:
|
||||
return self._postcmd(stop, line)
|
||||
|
||||
|
||||
@property
|
||||
def prompt(self):
|
||||
result = ''
|
||||
def prompt(self) -> str:
|
||||
result = ""
|
||||
|
||||
if self.prompt_header != None:
|
||||
result += self.prompt_header + '\n'
|
||||
if self.prompt_header is not None:
|
||||
result += self.prompt_header + "\n"
|
||||
|
||||
result += self._generate_usage_list()
|
||||
|
||||
if self.lastcmd == '':
|
||||
result += f'> '
|
||||
if self.lastcmd == "":
|
||||
result += "> "
|
||||
else:
|
||||
result += f'[{self.lastcmd}]> '
|
||||
result += f"[{self.lastcmd}]> "
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class NumberedItemSelector(NumberedCmd):
|
||||
NumberedItemSelectorType = TypeVar("NumberedItemSelectorType")
|
||||
|
||||
|
||||
class NumberedItemSelector(Generic[NumberedItemSelectorType], NumberedCmd):
|
||||
items: list[NumberedItemSelectorType]
|
||||
stringify: Callable[[NumberedItemSelectorType], str]
|
||||
result: NumberedItemSelectorType | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
items: list[Any],
|
||||
stringify: Callable[[Any], str] = lambda x: str(x),
|
||||
items: list[NumberedItemSelectorType],
|
||||
stringify: Callable[[NumberedItemSelectorType], str] = lambda x: str(x),
|
||||
):
|
||||
super().__init__()
|
||||
self.items = items
|
||||
@@ -205,13 +222,12 @@ class NumberedItemSelector(NumberedCmd):
|
||||
self.result = None
|
||||
self.funcs = {
|
||||
i: {
|
||||
'f': self._select_item,
|
||||
'doc': self.stringify(item),
|
||||
"f": self._select_item,
|
||||
"doc": self.stringify(item),
|
||||
}
|
||||
for i, item in enumerate(items, start=1)
|
||||
}
|
||||
|
||||
|
||||
def _select_item(self, *a):
|
||||
self.result = self.items[int(self.lastcmd)-1]
|
||||
def _select_item(self, *a) -> bool:
|
||||
self.result = self.items[int(self.lastcmd) - 1]
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user