diff --git a/src/env.rs b/src/env.rs index 507660a..5fd2f75 100644 --- a/src/env.rs +++ b/src/env.rs @@ -57,7 +57,7 @@ pub fn env_get(env: &Env, sym: &str) -> MalRet { pub fn env_binds(outer: Env, binds: &MalType, exprs: &[MalType]) -> Result { let env = env_new(Some(outer)); - let binds = binds.if_vec()?; + let binds = binds.if_list()?; let binl = binds.len(); let expl = exprs.len(); if binl < expl { @@ -150,13 +150,25 @@ pub fn call_func(func: &MalType, args: &[MalType]) -> CallRet { } pub fn any_zero(list: &[MalType]) -> Result<&[MalType], MalErr> { - if list - .iter() - .any(|x| matches!(x, M::Num(v) if v.exact_zero())) - { - return Err(MalErr::unrecoverable("Attempting division by 0")); + match list.len() { + 1 => { + if list[0].if_number()?.get_num() == 0 { + Err(MalErr::unrecoverable("Attempting division by 0")) + } else { + Ok(list) + } + } + _ => { + if list[1..list.len()] + .iter() + .any(|x| matches!(x, M::Num(v) if v.exact_zero())) + { + Err(MalErr::unrecoverable("Attempting division by 0")) + } else { + Ok(list) + } + } } - Ok(list) } pub fn arithmetic_op(set: isize, f: fn(Frac, Frac) -> Frac, args: &[MalType]) -> MalRet { diff --git a/src/eval.rs b/src/eval.rs index ce3beb0..032c8ea 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -75,7 +75,7 @@ fn let_star_form(list: &[MalType], env: Env) -> Result<(MalType, Env), MalErr> { let inner_env = env_new(Some(env.clone())); // change the inner environment let (car, cdr) = car_cdr(list)?; - let list = car.if_vec()?; + let list = car.if_list()?; if list.len() % 2 != 0 { return Err(MalErr::unrecoverable( "let* form, number of arguments must be even", @@ -114,7 +114,7 @@ fn if_form(list: &[MalType], env: Env) -> MalRet { fn fn_star_form(list: &[MalType], env: Env) -> MalRet { let (binds, exprs) = car_cdr(list)?; - binds.if_vec()?; + binds.if_list()?; Ok(M::MalFun { // eval: eval_ast, params: Rc::new(binds.clone()), diff --git a/src/types.rs b/src/types.rs index 3530588..dc9aab4 100644 --- a/src/types.rs +++ b/src/types.rs @@ -192,9 +192,9 @@ impl MalType { pub fn if_list(&self) -> Result<&[MalType], MalErr> { match self { - Self::List(list) => Ok(list), + Self::List(list) | Self::Vector(list) => Ok(list), _ => Err(MalErr::unrecoverable( - format!("{:?} is not a list", prt(self)).as_str(), + format!("{:?} is not an iterable", prt(self)).as_str(), )), } }