From 8d65c288a3957c770930d5be6059f40b9b1d5ebc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Emmanuel=20Beno=C3=AEt?= <tseeker@nocternity.net>
Date: Sat, 7 Jan 2023 10:51:34 +0100
Subject: [PATCH] Resolver - Refactored entering/exiting scopes

---
 src/resolver.rs | 38 +++++++++++++++-----------------------
 1 file changed, 15 insertions(+), 23 deletions(-)

diff --git a/src/resolver.rs b/src/resolver.rs
index 5d122ce..b6b2844 100644
--- a/src/resolver.rs
+++ b/src/resolver.rs
@@ -55,14 +55,16 @@ struct ResolverState {
 }
 
 impl ResolverState {
-    /// Enter a new scope.
-    fn begin_scope(&mut self) {
+    /// Execute some function with a new scope. The scope will be disposed
+    /// of after the function has been executed.
+    fn with_scope<F>(&mut self, f: F) -> ResolverResult
+    where
+        F: FnOnce(&mut Self) -> ResolverResult,
+    {
         self.scopes.push(HashMap::new());
-    }
-
-    /// End the current scope.
-    fn end_scope(&mut self) {
+        let result = f(self);
         self.scopes.pop();
+        result
     }
 
     /// Try to declare a symbol. If the scope already contains a declaration
@@ -156,18 +158,13 @@ fn resolve_function(
     params: &[Token],
     body: &Vec<ast::StmtNode>,
 ) -> ResolverResult {
-    rs.begin_scope();
     for param in params {
         rs.declare(param, SymKind::Variable)?;
         rs.define(param);
     }
     // Unlike the original Lox, function arguments and function bodies do
     // not use the same environment.
-    rs.begin_scope();
-    let result = body.resolve(rs);
-    rs.end_scope();
-    rs.end_scope();
-    result
+    rs.with_scope(|rs| body.resolve(rs))
 }
 
 /// Helper trait used to visit the various AST nodes with the resolver.
@@ -194,12 +191,7 @@ impl VarResolver for Vec<ast::StmtNode> {
 impl VarResolver for ast::StmtNode {
     fn resolve(&self, rs: &mut ResolverState) -> ResolverResult {
         match self {
-            ast::StmtNode::Block(stmts) => {
-                rs.begin_scope();
-                let result = stmts.resolve(rs);
-                rs.end_scope();
-                result
-            }
+            ast::StmtNode::Block(stmts) => rs.with_scope(|rs| stmts.resolve(rs)),
 
             ast::StmtNode::VarDecl(name, None) => {
                 rs.declare(name, SymKind::Variable)?;
@@ -215,7 +207,7 @@ impl VarResolver for ast::StmtNode {
             ast::StmtNode::FunDecl { name, params, body } => {
                 rs.declare(name, SymKind::Function)?;
                 rs.define(name);
-                resolve_function(rs, params, body)
+                rs.with_scope(|rs| resolve_function(rs, params, body))
             }
 
             ast::StmtNode::If {
@@ -270,16 +262,16 @@ impl VarResolver for ast::StmtNode {
 impl VarResolver for ast::ExprNode {
     fn resolve(&self, rs: &mut ResolverState) -> ResolverResult {
         match self {
-            ast::ExprNode::Variable { name, id } => {
-                rs.resolve_local(id, name, false)
-            }
+            ast::ExprNode::Variable { name, id } => rs.resolve_local(id, name, false),
 
             ast::ExprNode::Assignment { name, value, id } => {
                 value.resolve(rs)?;
                 rs.resolve_local(id, name, true)
             }
 
-            ast::ExprNode::Lambda { params, body } => resolve_function(rs, params, body),
+            ast::ExprNode::Lambda { params, body } => {
+                rs.with_scope(|rs| resolve_function(rs, params, body))
+            }
 
             ast::ExprNode::Logical {
                 left,