refactor
This commit is contained in:
@@ -56,6 +56,20 @@ class CSP:
|
||||
bool
|
||||
False if a domain becomes empty, otherwise True
|
||||
"""
|
||||
|
||||
def revise(csp, xi, xj) -> bool:
|
||||
revised = False
|
||||
for x in set(csp.domains[xi]):
|
||||
if not any(
|
||||
[
|
||||
(x, y) in csp.binary_constraints[(xi, xj)]
|
||||
for y in csp.domains[xj]
|
||||
]
|
||||
):
|
||||
csp.domains[xi].remove(x)
|
||||
revised = True
|
||||
return revised
|
||||
|
||||
queue = Queue()
|
||||
for edge in self.binary_constraints.keys():
|
||||
queue.put(edge)
|
||||
@@ -63,7 +77,7 @@ class CSP:
|
||||
while not queue.empty():
|
||||
(xi, xj) = queue.get()
|
||||
|
||||
if self._revise(xi, xj):
|
||||
if revise(self, xi, xj):
|
||||
if len(self.domains[xi]) == 0:
|
||||
return False
|
||||
for neighboring_edge in [
|
||||
@@ -74,53 +88,6 @@ class CSP:
|
||||
queue.put(neighboring_edge)
|
||||
return True
|
||||
|
||||
def _revise(self, xi, xj) -> bool:
|
||||
"""Internal of ac_3
|
||||
Makes an arc consistent by shrinking xi's domain.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
xi, xj
|
||||
Nodes of the arc
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if a change was made to the domain of xi, False otherwise
|
||||
|
||||
"""
|
||||
revised = False
|
||||
for x in set(self.domains[xi]):
|
||||
if not any(
|
||||
[(x, y) in self.binary_constraints[(xi, xj)] for y in self.domains[xj]]
|
||||
):
|
||||
self.domains[xi].remove(x)
|
||||
revised = True
|
||||
return revised
|
||||
|
||||
def _consistent(self, var, value, assignment) -> bool:
|
||||
for v in assignment:
|
||||
if (var, v) in self.binary_constraints.keys():
|
||||
if (value, assignment[v]) not in self.binary_constraints[(var, v)]:
|
||||
return False
|
||||
if (v, var) in self.binary_constraints.keys():
|
||||
if (value, assignment[v]) not in self.binary_constraints[(v, var)]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _backtrack(self, assignment: dict[str, Any]) -> dict | None:
|
||||
if len(assignment) == len(self.variables):
|
||||
return assignment # base-case
|
||||
var = self._select_unassigned_variable(assignment)
|
||||
for value in self._order_domain_values(var, assignment):
|
||||
if not self._consistent(var, value, assignment):
|
||||
continue
|
||||
assignment[var] = value
|
||||
if result := self._backtrack(assignment):
|
||||
return result
|
||||
assignment.pop(var)
|
||||
return None # failure
|
||||
|
||||
def backtracking_search(self) -> None | dict[str, Any]:
|
||||
"""Performs backtracking search on the CSP.
|
||||
|
||||
@@ -130,32 +97,53 @@ class CSP:
|
||||
A solution if any exists, otherwise None
|
||||
"""
|
||||
|
||||
return self._backtrack({})
|
||||
|
||||
def _select_unassigned_variable(self, assignment) -> str:
|
||||
for v in self.variables:
|
||||
if v not in assignment.keys():
|
||||
return v
|
||||
return "this shouldn't happen"
|
||||
|
||||
# TODO: may be wrong
|
||||
def _order_domain_values(self, var, assignment) -> list[Any]:
|
||||
"""least constrained value"""
|
||||
|
||||
def compare(value) -> int:
|
||||
s = 0
|
||||
for neighbor in [
|
||||
b for (a, b) in self.binary_constraints.keys() if a == var
|
||||
]:
|
||||
if neighbor in assignment.keys():
|
||||
def backtrack(csp, assignment: dict[str, Any]) -> dict | None:
|
||||
if len(assignment) == len(csp.variables):
|
||||
return assignment # base-case
|
||||
var = select_unassigned_variable(csp, assignment)
|
||||
for value in order_domain_values(csp, var, assignment):
|
||||
if not consistent(csp, var, value, assignment):
|
||||
continue
|
||||
for neighbor_value in self.domains[neighbor]:
|
||||
s += (value, neighbor_value) not in self.binary_constraints[
|
||||
(var, neighbor)
|
||||
]
|
||||
return s
|
||||
assignment[var] = value
|
||||
if result := backtrack(csp, assignment):
|
||||
return result
|
||||
assignment.pop(var)
|
||||
return None # failure
|
||||
|
||||
return sorted(list(self.domains[var]), key=compare)
|
||||
def consistent(csp, var, value, assignment) -> bool:
|
||||
for v in assignment:
|
||||
if (var, v) in csp.binary_constraints.keys():
|
||||
if (value, assignment[v]) not in csp.binary_constraints[(var, v)]:
|
||||
return False
|
||||
if (v, var) in csp.binary_constraints.keys():
|
||||
if (value, assignment[v]) not in csp.binary_constraints[(v, var)]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def select_unassigned_variable(csp, assignment) -> str:
|
||||
for v in csp.variables:
|
||||
if v not in assignment.keys():
|
||||
return v
|
||||
return "this shouldn't happen"
|
||||
|
||||
# choose least constrained value
|
||||
def order_domain_values(csp, var, assignment) -> list[Any]:
|
||||
def compare(value) -> int:
|
||||
s = 0
|
||||
for neighbor in [
|
||||
b for (a, b) in csp.binary_constraints.keys() if a == var
|
||||
]:
|
||||
if neighbor in assignment.keys():
|
||||
continue
|
||||
for neighbor_value in csp.domains[neighbor]:
|
||||
s += (value, neighbor_value) not in csp.binary_constraints[
|
||||
(var, neighbor)
|
||||
]
|
||||
return s
|
||||
|
||||
return sorted(list(csp.domains[var]), key=compare)
|
||||
|
||||
return backtrack(self, {})
|
||||
|
||||
|
||||
def alldiff(variables: list[str]) -> list[tuple[str, str]]:
|
||||
|
||||
Reference in New Issue
Block a user