diff --git a/src/resolver.rs b/src/resolver.rs index de8d466..644a6a2 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use crate::{ - ast::{ExprNode, FunDecl, ProgramNode, StmtNode, VariableExpr}, + ast::{ClassMemberDecl, ExprNode, FunDecl, ProgramNode, StmtNode, VariableExpr}, errors::{ErrorKind, SloxError, SloxResult}, tokens::Token, }; @@ -266,20 +266,23 @@ where } /// Process all method definitions in a class. -fn resolve_class<'a, 'b>(rs: &mut ResolverState<'a>, methods: &'b [FunDecl]) -> ResolverResult +fn resolve_class<'a, 'b>( + rs: &mut ResolverState<'a>, + methods: &'b [ClassMemberDecl], +) -> ResolverResult where 'b: 'a, { rs.define_this(); - methods.iter().try_for_each(|method| { - rs.with_scope( + methods.iter().try_for_each(|member| match member { + ClassMemberDecl::Method(method) | ClassMemberDecl::StaticMethod(method) => rs.with_scope( |rs| resolve_function(rs, &method.params, &method.body), if method.name.lexeme == "init" { ScopeType::Initializer } else { ScopeType::Method }, - ) + ), }) } @@ -343,7 +346,7 @@ impl VarResolver for StmtNode { StmtNode::ClassDecl(decl) => { rs.declare(&decl.name, SymKind::Class)?; rs.define(&decl.name); - rs.with_scope(|rs| resolve_class(rs, &decl.methods), rs.current_type()) + rs.with_scope(|rs| resolve_class(rs, &decl.members), rs.current_type()) } StmtNode::If {