diff --git a/src/main.rs b/src/main.rs index fdf5865..18f453e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,21 +6,39 @@ const ATOMIC_MASS_CONSTANT: f64 = 1.66053906660e-27f64; const WIEN_WAVELENGTH_DISPLACEMENT_LAW_CONSTANT: f64 = 2.897771955e-3f64; const SPEED_OF_LIGHT_IN_VACUUM: f64 = 299792458f64; +#[derive(Debug, PartialEq)] +#[allow(unused)] +enum Error { + Token(TokenizeError), + Conversion(InfixToPostfixError), + Calculate(CalculateError), + Implicit, +} + #[derive(Debug, PartialEq)] enum TokenizeError { NumberParseError(String), } +#[derive(Debug, PartialEq)] +enum PrecedenceError { + UnexpectedNumber, + UnexpectedSeparator, +} + #[derive(Debug, PartialEq)] enum InfixToPostfixError { ExpectedOperator, ExpectedLeftParenthesis, + Precedence(PrecedenceError), } #[derive(Debug, PartialEq)] enum CalculateError { PostfixExpectedNumbers, + TooFewArgs, EmptyStack, + UnexpectedToken(Token), } #[derive(PartialEq)] @@ -71,7 +89,7 @@ fn main() { continue; } - let result = compute(&input); + let result = handle_err(compute(&input)); print_result(result); } } else { @@ -82,11 +100,28 @@ fn main() { std::process::exit(1); } }; - let result = compute(&arg); + let result = handle_err(compute(&arg)); print_result(result); } } +fn handle_err(result: Result) -> f64 { + match result { + Ok(v) => v, + Err(e) => match e { + Error::Token(_) => print_err("failed to tokenize", e), + Error::Conversion(_) => print_err("failed to convert from infix to postfix", e), + Error::Calculate(_) => print_err("failed to calculate result", e), + Error::Implicit => print_err("expected at least two tokens", e), + }, + } +} + +fn print_err(msg: &str, error: Error) -> ! { + eprintln!("Error: {msg}: {:?}", error); + std::process::exit(1); +} + impl Token { fn is_operator(&self) -> bool { use Token::*; @@ -98,45 +133,18 @@ impl Token { } fn print_result(result: f64) { - if result > 10_000. || (result < 0.0001 && result > -0.0001) && result != 0. { + if result > 1E4 || (result < 1E-4 && result > -1E-4) && result != 0. { println!("{:E}", result); } else { println!("{result}"); } } -fn compute(input: &str) -> f64 { - let tokens = match tokenize(input) { - Ok(v) => v, - Err(e) => { - eprintln!("failed to parse tokens: {:?}", e); - std::process::exit(1); - } - }; - - let tokens = match implicit_operations(tokens) { - Some(v) => v, - None => { - eprintln!("expected at least two tokens"); - std::process::exit(1); - } - }; - - let tokens = match infix_to_postfix(tokens) { - Ok(v) => v, - Err(e) => { - eprintln!("failed to convert infix to postfix: {:?}", e); - std::process::exit(1); - } - }; - - match calculate(tokens) { - Ok(v) => v, - Err(e) => { - eprintln!("failed to calculate result: {:?}", e); - std::process::exit(1); - } - } +fn compute(input: &str) -> Result { + let tokens = tokenize(input).map_err(Error::Token)?; + let tokens = implicit_operations(tokens).ok_or(Error::Implicit)?; + let tokens = infix_to_postfix(tokens).map_err(Error::Conversion)?; + calculate(tokens).map_err(Error::Calculate) } impl std::fmt::Display for Token { @@ -279,7 +287,9 @@ fn implicit_operations(tokens: Vec) -> Option> { _ => (), }, Token::RightParenthesis => match token { - Token::LeftParenthesis | Token::Number(_) => new_tokens.push(Token::Multiply), + Token::Function(_) | Token::LeftParenthesis | Token::Number(_) => { + new_tokens.push(Token::Multiply) + } _ => (), }, _ => (), @@ -295,8 +305,8 @@ fn implicit_operations(tokens: Vec) -> Option> { if new_tokens[0] == Token::Subtract { if let Token::Number(n) = new_tokens[1] { - new_tokens.pop(); - new_tokens.pop(); + new_tokens.pop()?; + new_tokens.pop()?; new_tokens.push(Token::Number(-n)); } else if new_tokens[1] == Token::LeftParenthesis || matches!(new_tokens[1], Token::Function(_)) @@ -317,7 +327,7 @@ fn implicit_operations(tokens: Vec) -> Option> { _ => { if *b == Token::Subtract { if let Token::Number(n) = c { - if a.is_operator() { + if a.is_operator() || *a == Token::LeftParenthesis { new_tokens.pop(); new_tokens.push(Token::Number(-n)); } else { @@ -383,8 +393,8 @@ fn infix_to_postfix(tokens: Vec) -> Result, InfixToPostfixErro op => { while let Some(op2) = op_stack.last() { if op2 != &Token::LeftParenthesis - && (precedence(op2) > precedence(op) - || (precedence(op) == precedence(op2) + && (precedence(op2)? > precedence(op)? + || (precedence(op)? == precedence(op2)? && associativity(op) .ok_or(InfixToPostfixError::ExpectedOperator)? == Associativity::Left)) @@ -408,14 +418,18 @@ fn infix_to_postfix(tokens: Vec) -> Result, InfixToPostfixErro Ok(output) } -fn precedence(token: &Token) -> u8 { +fn precedence(token: &Token) -> Result { match token { - Token::Add | Token::Subtract => 0, - Token::Multiply | Token::Divide | Token::Modulus => 1, - Token::Power => 2, - Token::LeftParenthesis | Token::RightParenthesis | Token::Function(_) => 3, - Token::Number(_) => unreachable!(), - Token::Separator => unreachable!(), + Token::Add | Token::Subtract => Ok(0), + Token::Multiply | Token::Divide | Token::Modulus => Ok(1), + Token::Power => Ok(2), + Token::LeftParenthesis | Token::RightParenthesis | Token::Function(_) => Ok(3), + Token::Number(_) => Err(InfixToPostfixError::Precedence( + PrecedenceError::UnexpectedNumber, + )), + Token::Separator => Err(InfixToPostfixError::Precedence( + PrecedenceError::UnexpectedSeparator, + )), } } @@ -463,7 +477,7 @@ fn calculate(tokens: Vec) -> Result { let b = stack.pop().ok_or(CalculateError::PostfixExpectedNumbers)?; let (a, b) = match (a, b) { (Token::Number(a), Token::Number(b)) => (b, a), - _ => unreachable!(), + _ => return Err(CalculateError::TooFewArgs), }; let n = match token { @@ -475,7 +489,7 @@ fn calculate(tokens: Vec) -> Result { Token::Modulus => a % b, Token::Function(FunctionType::Max) => a.max(b), Token::Function(FunctionType::Min) => a.min(b), - _ => unreachable!(), + _ => return Err(CalculateError::UnexpectedToken(token.clone())), }; stack.push(Token::Number(n)); } @@ -771,29 +785,31 @@ mod tests { #[test] fn test_compute() { - assert_eq!(compute("(6 - 2)"), 4.); - assert_eq!(compute("3E2"), 300.); - assert_eq!(compute("4(3 - 1)^2"), 16.); - assert_eq!(compute("(6 - -2)"), 8.); - assert_eq!(compute("sin(3)"), 3_f64.sin()); - assert_eq!(compute("sin(3 - 1)"), 2_f64.sin()); - assert_eq!(compute("cos(3 - 1)"), 2_f64.cos()); + assert_eq!(compute("(6 - 2)"), Ok(4.)); + assert_eq!(compute("3E2"), Ok(300.)); + assert_eq!(compute("4(3 - 1)^2"), Ok(16.)); + assert_eq!(compute("(6 - -2)"), Ok(8.)); + assert_eq!(compute("sin(3)"), Ok(3_f64.sin())); + assert_eq!(compute("sin(3 - 1)"), Ok(2_f64.sin())); + assert_eq!(compute("cos(3 - 1)"), Ok(2_f64.cos())); assert_eq!( compute("3 + 4 × 2 ÷ ( 1 − 5 ) ^ 2 ^ 3"), - 3. + (4. * 2.) / (1_f64 - 5.).powf((2_f64).powf(3.)) + Ok(3. + (4. * 2.) / (1_f64 - 5.).powf((2_f64).powf(3.))) ); assert_eq!( compute("sin ( max ( 2, 3 ) ÷ 3 × π )"), - std::f64::consts::PI.sin() + Ok(consts::PI.sin()) ); - assert_eq!(compute("min(2, 3)"), 2.); - assert_eq!(compute("max(2, 3)"), 3.); - assert_eq!(compute("3 +- 2"), 1.); - assert_eq!(compute("3 -+ 2"), 1.); - assert_eq!(compute("1E-3"), 0.001); - assert_eq!(compute("3sin(3)"), 3. * 3_f64.sin()); - assert_eq!(compute("(1 + 2) - (1)"), 2.); - assert_eq!(compute("(1 + 2) - 1"), 2.); - assert_eq!(compute("sqrt(4)"), 2.); + assert_eq!(compute("min(2, 3)"), Ok(2.)); + assert_eq!(compute("max(2, 3)"), Ok(3.)); + assert_eq!(compute("3 +- 2"), Ok(1.)); + assert_eq!(compute("3 -+ 2"), Ok(1.)); + assert_eq!(compute("1E-3"), Ok(0.001)); + assert_eq!(compute("3sin(3)"), Ok(3. * 3_f64.sin())); + assert_eq!(compute("(1 + 2) - (1)"), Ok(2.)); + assert_eq!(compute("(1 + 2) - 1"), Ok(2.)); + assert_eq!(compute("sqrt(4)"), Ok(2.)); + assert_eq!(compute("sin(-1)"), Ok((-1f64).sin())); + assert_eq!(compute("sin(1) sin(1)"), Ok((1f64).sin().powf(2.))); } }