diff --git a/assignment2/csp.py b/assignment2/csp.py index f01cc63..3ee6774 100644 --- a/assignment2/csp.py +++ b/assignment2/csp.py @@ -56,6 +56,20 @@ class CSP: bool False if a domain becomes empty, otherwise True """ + + def revise(csp, xi, xj) -> bool: + revised = False + for x in set(csp.domains[xi]): + if not any( + [ + (x, y) in csp.binary_constraints[(xi, xj)] + for y in csp.domains[xj] + ] + ): + csp.domains[xi].remove(x) + revised = True + return revised + queue = Queue() for edge in self.binary_constraints.keys(): queue.put(edge) @@ -63,7 +77,7 @@ class CSP: while not queue.empty(): (xi, xj) = queue.get() - if self._revise(xi, xj): + if revise(self, xi, xj): if len(self.domains[xi]) == 0: return False for neighboring_edge in [ @@ -74,53 +88,6 @@ class CSP: queue.put(neighboring_edge) return True - def _revise(self, xi, xj) -> bool: - """Internal of ac_3 - Makes an arc consistent by shrinking xi's domain. - - Parameters - ---------- - xi, xj - Nodes of the arc - - Returns - ------- - bool - True if a change was made to the domain of xi, False otherwise - - """ - revised = False - for x in set(self.domains[xi]): - if not any( - [(x, y) in self.binary_constraints[(xi, xj)] for y in self.domains[xj]] - ): - self.domains[xi].remove(x) - revised = True - return revised - - def _consistent(self, var, value, assignment) -> bool: - for v in assignment: - if (var, v) in self.binary_constraints.keys(): - if (value, assignment[v]) not in self.binary_constraints[(var, v)]: - return False - if (v, var) in self.binary_constraints.keys(): - if (value, assignment[v]) not in self.binary_constraints[(v, var)]: - return False - return True - - def _backtrack(self, assignment: dict[str, Any]) -> dict | None: - if len(assignment) == len(self.variables): - return assignment # base-case - var = self._select_unassigned_variable(assignment) - for value in self._order_domain_values(var, assignment): - if not self._consistent(var, value, assignment): - continue - assignment[var] = value - if result := self._backtrack(assignment): - return result - assignment.pop(var) - return None # failure - def backtracking_search(self) -> None | dict[str, Any]: """Performs backtracking search on the CSP. @@ -130,32 +97,53 @@ class CSP: A solution if any exists, otherwise None """ - return self._backtrack({}) - - def _select_unassigned_variable(self, assignment) -> str: - for v in self.variables: - if v not in assignment.keys(): - return v - return "this shouldn't happen" - - # TODO: may be wrong - def _order_domain_values(self, var, assignment) -> list[Any]: - """least constrained value""" - - def compare(value) -> int: - s = 0 - for neighbor in [ - b for (a, b) in self.binary_constraints.keys() if a == var - ]: - if neighbor in assignment.keys(): + def backtrack(csp, assignment: dict[str, Any]) -> dict | None: + if len(assignment) == len(csp.variables): + return assignment # base-case + var = select_unassigned_variable(csp, assignment) + for value in order_domain_values(csp, var, assignment): + if not consistent(csp, var, value, assignment): continue - for neighbor_value in self.domains[neighbor]: - s += (value, neighbor_value) not in self.binary_constraints[ - (var, neighbor) - ] - return s + assignment[var] = value + if result := backtrack(csp, assignment): + return result + assignment.pop(var) + return None # failure - return sorted(list(self.domains[var]), key=compare) + def consistent(csp, var, value, assignment) -> bool: + for v in assignment: + if (var, v) in csp.binary_constraints.keys(): + if (value, assignment[v]) not in csp.binary_constraints[(var, v)]: + return False + if (v, var) in csp.binary_constraints.keys(): + if (value, assignment[v]) not in csp.binary_constraints[(v, var)]: + return False + return True + + def select_unassigned_variable(csp, assignment) -> str: + for v in csp.variables: + if v not in assignment.keys(): + return v + return "this shouldn't happen" + + # choose least constrained value + def order_domain_values(csp, var, assignment) -> list[Any]: + def compare(value) -> int: + s = 0 + for neighbor in [ + b for (a, b) in csp.binary_constraints.keys() if a == var + ]: + if neighbor in assignment.keys(): + continue + for neighbor_value in csp.domains[neighbor]: + s += (value, neighbor_value) not in csp.binary_constraints[ + (var, neighbor) + ] + return s + + return sorted(list(csp.domains[var]), key=compare) + + return backtrack(self, {}) def alldiff(variables: list[str]) -> list[tuple[str, str]]: