From c563654aceda7dee50148f10c3ed5829d5cfa199 Mon Sep 17 00:00:00 2001 From: RunasSudo Date: Sun, 27 Jun 2021 17:44:30 +1000 Subject: [PATCH] Implement constraint matrix logic and test case --- Cargo.lock | 38 +++++ Cargo.toml | 1 + build_docs.sh | 11 ++ src/constraints.rs | 346 +++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 + src/main.rs | 9 +- src/stv/meek.rs | 2 +- src/stv/wasm.rs | 2 +- tests/aec.rs | 2 +- tests/meek.rs | 4 +- 10 files changed, 408 insertions(+), 9 deletions(-) create mode 100755 build_docs.sh create mode 100644 src/constraints.rs diff --git a/Cargo.lock b/Cargo.lock index 92f5c6e..6da0914 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -365,6 +365,15 @@ dependencies = [ "cfg-if 1.0.0", ] +[[package]] +name = "matrixmultiply" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a8a15b776d9dfaecd44b03c5828c2199cddff5247215858aac14624f8d6b741" +dependencies = [ + "rawpointer", +] + [[package]] name = "memchr" version = "2.4.0" @@ -381,6 +390,19 @@ dependencies = [ "autocfg", ] +[[package]] +name = "ndarray" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08e854964160a323e65baa19a0b1a027f76d590faba01f05c0cbc3187221a8c9" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "normalize-line-endings" version = "0.3.0" @@ -398,6 +420,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +dependencies = [ + "num-traits", +] + [[package]] name = "num-integer" version = "0.1.44" @@ -449,6 +480,7 @@ dependencies = [ "ibig", "itertools", "js-sys", + "ndarray", "num-bigint", "num-rational", "num-traits", @@ -564,6 +596,12 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34cf66eb183df1c5876e2dcf6b13d57340741e8dc255b48e40a26de954d06ae7" +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "regex" version = "1.5.4" diff --git a/Cargo.toml b/Cargo.toml index a4b8051..6a4c09d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ derive_more = "0.99.14" git-version = "0.3.4" ibig = "0.3.2" itertools = "0.10.1" +ndarray = "0.15.3" predicates = "1.0.8" num-traits = "0.2" sha2 = "0.9.5" diff --git a/build_docs.sh b/build_docs.sh new file mode 100755 index 0000000..fe88ab9 --- /dev/null +++ b/build_docs.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# Gen docs +cargo doc --no-deps --document-private-items + +# Remove wasm-bindgen items from HTML +find target/doc/opentally -type f -print0 | xargs -0 perl -0777 -i -pe 's#. + */ + +use ndarray::{Array, Dimension, IxDyn}; + +use std::fmt; +use std::ops; + +#[derive(Debug)] +enum ConstraintError { + NoConformantResult, +} + +#[derive(Clone)] +struct ConstraintMatrixCell { + elected: usize, + min: usize, + max: usize, + cands: usize, +} + +struct ConstraintMatrix(Array); + +impl ConstraintMatrix { + pub fn new(constraints: &mut [usize]) -> Self { + // Add 1 to dimensions for totals cells + for c in constraints.iter_mut() { + *c += 1; + } + + return Self(Array::from_elem( + IxDyn(constraints), + ConstraintMatrixCell { + elected: 0, + min: 0, + max: 0, + cands: 0, + } + )); + } + + pub fn init(&mut self) { + let indices: Vec = ndarray::indices(self.0.shape()).into_iter().collect(); + + // Compute candidate totals + self.recount_cands(); + + // Initialise max for grand total cell + let idx = IxDyn(&vec![0; self.0.ndim()][..]); + self.0[&idx].max = self.0[&idx].cands; + + // Initialise max for inner cells (>=2 zeroes) + for idx in indices.iter() { + if (0..idx.ndim()).fold(0, |acc, d| if idx[d] != 0 { acc + 1 } else { acc }) < 2 { + continue; + } + self.0[idx].max = self.0[idx].cands; + } + + // NB: Bounds on min, max, etc. will be further refined in initial step() calls + } + + pub fn recount_cands(&mut self) { + let shape = Vec::from(self.0.shape()); + let indices: Vec = ndarray::indices(self.0.shape()).into_iter().collect(); + + // Compute cands/elected totals + for nzeroes in 1..self.0.ndim()+1 { + for idx in indices.iter() { + // First compute totals cells with 1 zero, then 2 zeroes, ... then grand total (all zeroes) + if (0..idx.ndim()).fold(0, |acc, d| if idx[d] == 0 { acc + 1 } else { acc }) != nzeroes { + continue; + } + + self.0[idx].cands = 0; + self.0[idx].elected = 0; + + // The axis along which to sum - if multiple, just pick the first, as these should agree + let zero_axis = (0..idx.ndim()).filter(|d| idx[*d] == 0).next().unwrap(); + + // Traverse along the axis and sum the candidates + let mut idx2 = idx.clone(); + for coord in 1..shape[zero_axis] { + idx2[zero_axis] = coord; + self.0[idx].cands += self.0[&idx2].cands; + self.0[idx].elected += self.0[&idx2].elected; + } + } + } + } + + pub fn step(&mut self) -> Result { + let shape = Vec::from(self.0.shape()); + let indices: Vec = ndarray::indices(self.0.shape()).into_iter().collect(); + + for idx in indices.iter() { + let cell = &mut self.0[idx]; + + // Rule 1: Ensure elected < min < max < cands + if cell.min < cell.elected { + cell.min = cell.elected; + return Ok(false); + } + if cell.max > cell.cands { + cell.max = cell.cands; + return Ok(false); + } + if cell.min > cell.max { + return Err(ConstraintError::NoConformantResult); + } + + let nzeroes = (0..idx.ndim()).fold(0, |acc, d| if idx[d] == 0 { acc + 1 } else { acc }); + + // Rule 2/3: Ensure min/max is possible in inner cells + if nzeroes == 0 { + for axis in 0..self.0.ndim() { + let mut idx2 = idx.clone(); + + // What is the min/max number of candidates that can be elected from other cells in this axis? + let (other_max, other_min) = (1..shape[axis]).fold((0, 0), |(acc_max, acc_min), coord| { + if coord == idx[axis] { + return (acc_max, acc_min); + } + idx2[axis] = coord; + return (acc_max + self.0[&idx2].max, acc_min + self.0[&idx2].min); + }); + + // What is the min/max number of candidates that can be elected along this axis? + idx2[axis] = 0; + let axis_max = self.0[&idx2].max; + let axis_min = self.0[&idx2].min; + + // How many must be elected from this cell? + let this_max = (axis_max as i32) - (other_min as i32); + let this_min = (axis_min as i32) - (other_max as i32); + + if this_max < (self.0[idx].max as i32) { + self.0[idx].max = this_max as usize; + return Ok(false); + } + if this_min > (self.0[idx].min as i32) { + self.0[idx].min = this_min as usize; + return Ok(false); + } + } + } + + // Rule 4/5: Ensure min/max is possible in totals cells + if nzeroes > 0 { + for axis in 0..self.0.ndim() { + if idx[axis] != 0 { + continue; + } + + // What is the total min/max along this axis? + let mut idx2 = idx.clone(); + let (axis_max, axis_min) = (1..shape[axis]).fold((0, 0), |(acc_max, acc_min), coord| { + idx2[axis] = coord; + return (acc_max + self.0[&idx2].max, acc_min + self.0[&idx2].min); + }); + + if axis_max < self.0[idx].max { + self.0[idx].max = axis_max; + return Ok(false); + } + if axis_min > self.0[idx].min { + self.0[idx].min = axis_min; + return Ok(false); + } + } + } + } + + return Ok(true); + } +} + +impl fmt::Display for ConstraintMatrix { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + let shape = self.0.shape(); + + let mut result = String::new(); + + // TODO: ≠2 dimensions + for y in 0..shape[1] { + result.push_str("+"); + for _ in 0..shape[0] { + result.push_str(if y == 1 { "=============+" } else { "-------------+" }); + } + result.push_str("\n"); + + result.push_str("|"); + for x in 0..shape[0] { + result.push_str(&format!(" Elected: {:2}", self[&[x, y]].elected)); + result.push_str(if x == 0 { " ‖" } else { " |" }); + } + result.push_str("\n"); + + result.push_str("|"); + for x in 0..shape[0] { + result.push_str(&format!(" Min: {:2}", self[&[x, y]].min)); + result.push_str(if x == 0 { " ‖" } else { " |" }); + } + result.push_str("\n"); + + result.push_str("|"); + for x in 0..shape[0] { + result.push_str(&format!(" Max: {:2}", self[&[x, y]].max)); + result.push_str(if x == 0 { " ‖" } else { " |" }); + } + result.push_str("\n"); + + result.push_str("|"); + for x in 0..shape[0] { + result.push_str(&format!(" Cands: {:2}", self[&[x, y]].cands)); + result.push_str(if x == 0 { " ‖" } else { " |" }); + } + result.push_str("\n"); + } + + result.push_str("+"); + for _ in 0..shape[0] { + result.push_str("-------------+"); + } + result.push_str("\n"); + + return f.write_str(&result); + } +} + +impl ops::Index<&[usize]> for ConstraintMatrix { + type Output = ConstraintMatrixCell; + fn index(&self, index: &[usize]) -> &Self::Output { &self.0[index] } +} +impl ops::IndexMut<&[usize]> for ConstraintMatrix { + fn index_mut(&mut self, index: &[usize]) -> &mut Self::Output { &mut self.0[index] } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_cell(cm: &ConstraintMatrix, idx: &[usize], elected: usize, min: usize, max: usize, cands: usize) { + assert_eq!(cm[idx].elected, elected, "Failed to validate elected at {:?}", idx); + assert_eq!(cm[idx].min, min, "Failed to validate min at {:?}", idx); + assert_eq!(cm[idx].max, max, "Failed to validate min at {:?}", idx); + assert_eq!(cm[idx].cands, cands, "Failed to validate cands at {:?}", idx); + } + + #[test] + fn cm_otten() { + let mut cm = ConstraintMatrix::new(&mut [3, 2]); + + // Totals + let c = &mut cm[&[0, 1]]; c.min = 7; c.max = 7; + let c = &mut cm[&[0, 2]]; c.min = 7; c.max = 7; + let c = &mut cm[&[1, 0]]; c.min = 7; c.max = 7; + let c = &mut cm[&[2, 0]]; c.min = 6; c.max = 6; + let c = &mut cm[&[3, 0]]; c.min = 1; c.max = 1; + + // Candidates + let c = &mut cm[&[1, 1]]; c.cands = 4; + let c = &mut cm[&[2, 1]]; c.cands = 11; + let c = &mut cm[&[3, 1]]; c.cands = 2; + let c = &mut cm[&[1, 2]]; c.cands = 7; + let c = &mut cm[&[2, 2]]; c.cands = 3; + let c = &mut cm[&[3, 2]]; c.cands = 1; + + // Init + cm.init(); + while !cm.step().expect("No conformant result") {} + println!("{}", cm); + + assert_cell(&cm, &[1, 1], 0, 0, 4, 4); + assert_cell(&cm, &[2, 1], 0, 3, 6, 11); + assert_cell(&cm, &[3, 1], 0, 0, 1, 2); + assert_cell(&cm, &[0, 1], 0, 7, 7, 17); + assert_cell(&cm, &[1, 2], 0, 3, 7, 7); + assert_cell(&cm, &[2, 2], 0, 0, 3, 3); + assert_cell(&cm, &[3, 2], 0, 0, 1, 1); + assert_cell(&cm, &[0, 2], 0, 7, 7, 11); + assert_cell(&cm, &[1, 0], 0, 7, 7, 11); + assert_cell(&cm, &[2, 0], 0, 6, 6, 14); + assert_cell(&cm, &[3, 0], 0, 1, 1, 3); + assert_cell(&cm, &[0, 0], 0, 14, 14, 28); + + // Election of Welsh man + cm[&[3, 1]].elected += 1; + cm.recount_cands(); + while !cm.step().expect("No conformant result") {} + println!("{}", cm); + + assert_cell(&cm, &[1, 1], 0, 0, 3, 4); + assert_cell(&cm, &[2, 1], 0, 3, 6, 11); + assert_cell(&cm, &[3, 1], 1, 1, 1, 2); + assert_cell(&cm, &[0, 1], 1, 7, 7, 17); // Error in Otten paper + assert_cell(&cm, &[1, 2], 0, 4, 7, 7); + assert_cell(&cm, &[2, 2], 0, 0, 3, 3); + assert_cell(&cm, &[3, 2], 0, 0, 0, 1); + assert_cell(&cm, &[0, 2], 0, 7, 7, 11); + assert_cell(&cm, &[1, 0], 0, 7, 7, 11); + assert_cell(&cm, &[2, 0], 0, 6, 6, 14); + assert_cell(&cm, &[3, 0], 1, 1, 1, 3); + assert_cell(&cm, &[0, 0], 1, 14, 14, 28); + + // Remaining Welsh man, Welsh woman doomed + cm[&[3, 1]].cands -= 1; + cm[&[3, 2]].cands -= 1; + + // Election of 2 English men, 2 English women + // Exclusion of 1 Scottish woman + cm[&[1, 1]].elected += 2; + cm[&[1, 2]].elected += 2; + cm[&[2, 2]].cands -= 1; + cm.recount_cands(); + while !cm.step().expect("No conformant result") {} + println!("{}", cm); + + assert_cell(&cm, &[1, 1], 2, 2, 2, 4); + assert_cell(&cm, &[2, 1], 0, 4, 4, 11); + assert_cell(&cm, &[3, 1], 1, 1, 1, 1); + assert_cell(&cm, &[0, 1], 3, 7, 7, 16); // Error in Otten paper + assert_cell(&cm, &[1, 2], 2, 5, 5, 7); + assert_cell(&cm, &[2, 2], 0, 2, 2, 2); + assert_cell(&cm, &[3, 2], 0, 0, 0, 0); + assert_cell(&cm, &[0, 2], 2, 7, 7, 9); + assert_cell(&cm, &[1, 0], 4, 7, 7, 11); + assert_cell(&cm, &[2, 0], 0, 6, 6, 13); + assert_cell(&cm, &[3, 0], 1, 1, 1, 1); + assert_cell(&cm, &[0, 0], 5, 14, 14, 25); + } +} diff --git a/src/lib.rs b/src/lib.rs index 02888ea..fe69e89 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,8 @@ //! Open source counting software for various preferential voting election systems +/// Data types and logic for constraints on elections +pub mod constraints; /// Data types for representing abstract elections pub mod election; /// Smart logging framework diff --git a/src/main.rs b/src/main.rs index 0fcdc08..260815d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -184,18 +184,19 @@ fn main() { // Create and count election according to --numbers if cmd_opts.numbers == "rational" { let election: Election = Election::from_blt(lines.map(|r| r.expect("IO Error").to_string()).into_iter()); - count_election(election, cmd_opts); + // Must specify :: here and in a few other places because ndarray causes E0275 otherwise + count_election::(election, cmd_opts); } else if cmd_opts.numbers == "float64" { let election: Election = Election::from_blt(lines.map(|r| r.expect("IO Error").to_string()).into_iter()); - count_election(election, cmd_opts); + count_election::(election, cmd_opts); } else if cmd_opts.numbers == "fixed" { Fixed::set_dps(cmd_opts.decimals); let election: Election = Election::from_blt(lines.map(|r| r.expect("IO Error").to_string()).into_iter()); - count_election(election, cmd_opts); + count_election::(election, cmd_opts); } else if cmd_opts.numbers == "gfixed" { GuardedFixed::set_dps(cmd_opts.decimals); let election: Election = Election::from_blt(lines.map(|r| r.expect("IO Error").to_string()).into_iter()); - count_election(election, cmd_opts); + count_election::(election, cmd_opts); } } diff --git a/src/stv/meek.rs b/src/stv/meek.rs index 03e532e..2277056 100644 --- a/src/stv/meek.rs +++ b/src/stv/meek.rs @@ -210,7 +210,7 @@ fn recompute_keep_values<'s, N: Number>(state: &mut CountState<'s, N>, opts: &ST } } -/// Determine if the specified surpluses should be distributed, according to [STVOptions::meek_quota_tolerance] +/// Determine if the specified surpluses should be distributed, according to [STVOptions::meek_surplus_tolerance] fn should_distribute_surpluses(state: &CountState, has_surplus: &Vec<&Candidate>, opts: &STVOptions) -> bool where for<'r> &'r N: ops::Sub<&'r N, Output=N>, diff --git a/src/stv/wasm.rs b/src/stv/wasm.rs index 1e96fdb..8f1e4aa 100644 --- a/src/stv/wasm.rs +++ b/src/stv/wasm.rs @@ -77,7 +77,7 @@ macro_rules! impl_type { #[wasm_bindgen] #[allow(non_snake_case)] pub fn [](state: &mut [], opts: &STVOptions) -> Result { - match stv::count_one_stage(&mut state.0, &opts.0) { + match stv::count_one_stage::<[<$type>]>(&mut state.0, &opts.0) { Ok(v) => Ok(v), Err(stv::STVError::RequireInput) => Err("RequireInput".into()), Err(stv::STVError::UnresolvedTie) => Err("UnresolvedTie".into()), diff --git a/tests/aec.rs b/tests/aec.rs index 3add705..3b74cd9 100644 --- a/tests/aec.rs +++ b/tests/aec.rs @@ -77,5 +77,5 @@ fn aec_tas19_rational() { meek_immediate_elect: false, pp_decimals: 2, }; - utils::validate_election(stages, records, election, stv_opts, None, &["exhausted", "lbf"]); + utils::validate_election::(stages, records, election, stv_opts, None, &["exhausted", "lbf"]); } diff --git a/tests/meek.rs b/tests/meek.rs index 275a759..7115b97 100644 --- a/tests/meek.rs +++ b/tests/meek.rs @@ -93,7 +93,7 @@ fn meek06_ers97_fixed12() { // Count to completion stv::count_init(&mut state, &stv_opts); - while !stv::count_one_stage(&mut state, &stv_opts).unwrap() {} + while !stv::count_one_stage::(&mut state, &stv_opts).unwrap() {} // Check states and keep values for (candidate, count_card) in state.candidates.iter() { @@ -167,7 +167,7 @@ fn meeknz_ers97_fixed12() { // Count to completion stv::count_init(&mut state, &stv_opts); - while !stv::count_one_stage(&mut state, &stv_opts).unwrap() {} + while !stv::count_one_stage::(&mut state, &stv_opts).unwrap() {} // Check states and keep values for (candidate, count_card) in state.candidates.iter() {