AST - Refactored function declarations as a separate struct

This commit is contained in:
Emmanuel BENOîT 2023-01-07 12:18:18 +01:00
parent 6252bed605
commit 1af25457d5
4 changed files with 158 additions and 154 deletions

View file

@ -8,17 +8,21 @@ use crate::tokens::Token;
#[derive(Default, Debug, Clone)] #[derive(Default, Debug, Clone)]
pub struct ProgramNode(pub Vec<StmtNode>); pub struct ProgramNode(pub Vec<StmtNode>);
/// A function declaration.
#[derive(Debug, Clone)]
pub struct FunDecl {
pub name: Token,
pub params: Vec<Token>,
pub body: Vec<StmtNode>,
}
/// An AST node that represents a statement. /// An AST node that represents a statement.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum StmtNode { pub enum StmtNode {
/// A variable declaration /// A variable declaration
VarDecl(Token, Option<ExprNode>), VarDecl(Token, Option<ExprNode>),
/// A function declaration /// A function declaration
FunDecl { FunDecl(FunDecl),
name: Token,
params: Vec<Token>,
body: Vec<StmtNode>,
},
/// An single expression /// An single expression
Expression(ExprNode), Expression(ExprNode),
/// The print statement /// The print statement
@ -149,15 +153,15 @@ impl AstDumper for StmtNode {
Self::VarDecl(name, Some(expr)) => format!("( var {} {} )", name.lexeme, expr.dump()), Self::VarDecl(name, Some(expr)) => format!("( var {} {} )", name.lexeme, expr.dump()),
Self::VarDecl(name, None) => format!("( var {} nil )", name.lexeme), Self::VarDecl(name, None) => format!("( var {} nil )", name.lexeme),
Self::FunDecl { name, params, body } => format!( Self::FunDecl(fun_decl) => format!(
"( fun {} ({}) {} )", "( fun {} ({}) {} )",
name.lexeme, fun_decl.name.lexeme,
params fun_decl.params
.iter() .iter()
.map(|token| &token.lexeme as &str) .map(|token| &token.lexeme as &str)
.collect::<Vec<&str>>() .collect::<Vec<&str>>()
.join(" "), .join(" "),
body.iter() fun_decl.body.iter()
.map(|stmt| stmt.dump()) .map(|stmt| stmt.dump())
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(" ") .join(" ")

View file

@ -1,7 +1,7 @@
use std::{cell::RefCell, rc::Rc}; use std::{cell::RefCell, rc::Rc};
use crate::{ use crate::{
ast, ast::{ExprNode, FunDecl, ProgramNode, StmtNode},
errors::{ErrorKind, SloxError, SloxResult}, errors::{ErrorKind, SloxError, SloxResult},
interpreter::{functions::Function, Environment, EnvironmentRef, Value}, interpreter::{functions::Function, Environment, EnvironmentRef, Value},
resolver::ResolvedVariables, resolver::ResolvedVariables,
@ -9,7 +9,7 @@ use crate::{
}; };
/// Evaluate an interpretable, returning its value. /// Evaluate an interpretable, returning its value.
pub fn evaluate(ast: &ast::ProgramNode, vars: ResolvedVariables) -> SloxResult<Value> { pub fn evaluate(ast: &ProgramNode, vars: ResolvedVariables) -> SloxResult<Value> {
let mut state = InterpreterState::new(&vars); let mut state = InterpreterState::new(&vars);
ast.interpret(&mut state).map(|v| v.result()) ast.interpret(&mut state).map(|v| v.result())
} }
@ -127,7 +127,7 @@ fn error<T>(token: &Token, message: &str) -> SloxResult<T> {
* INTERPRETER FOR PROGRAM NODES * * INTERPRETER FOR PROGRAM NODES *
* ----------------------------- */ * ----------------------------- */
impl Interpretable for ast::ProgramNode { impl Interpretable for ProgramNode {
fn interpret(&self, es: &mut InterpreterState) -> InterpreterResult { fn interpret(&self, es: &mut InterpreterState) -> InterpreterResult {
for stmt in self.0.iter() { for stmt in self.0.iter() {
stmt.interpret(es)?; stmt.interpret(es)?;
@ -140,39 +140,37 @@ impl Interpretable for ast::ProgramNode {
* INTERPRETER FOR STATEMENT NODES * * INTERPRETER FOR STATEMENT NODES *
* ------------------------------- */ * ------------------------------- */
impl Interpretable for ast::StmtNode { impl Interpretable for StmtNode {
fn interpret(&self, es: &mut InterpreterState) -> InterpreterResult { fn interpret(&self, es: &mut InterpreterState) -> InterpreterResult {
match self { match self {
ast::StmtNode::VarDecl(name, expr) => self.on_var_decl(es, name, expr), StmtNode::VarDecl(name, expr) => self.on_var_decl(es, name, expr),
ast::StmtNode::FunDecl { name, params, body } => { StmtNode::FunDecl(decl) => self.on_fun_decl(es, decl),
self.on_fun_decl(es, name, params, body) StmtNode::Expression(expr) => expr.interpret(es),
} StmtNode::Print(expr) => self.on_print(es, expr),
ast::StmtNode::Expression(expr) => expr.interpret(es), StmtNode::Block(statements) => self.on_block(es, statements),
ast::StmtNode::Print(expr) => self.on_print(es, expr), StmtNode::If {
ast::StmtNode::Block(statements) => self.on_block(es, statements),
ast::StmtNode::If {
condition, condition,
then_branch, then_branch,
else_branch, else_branch,
} => self.on_if_statement(es, condition, then_branch, else_branch), } => self.on_if_statement(es, condition, then_branch, else_branch),
ast::StmtNode::Loop { StmtNode::Loop {
label, label,
condition, condition,
body, body,
after_body, after_body,
} => self.on_loop_statement(es, label, condition, body, after_body), } => self.on_loop_statement(es, label, condition, body, after_body),
ast::StmtNode::LoopControl { StmtNode::LoopControl {
is_break, is_break,
loop_name, loop_name,
} => self.on_loop_control_statemement(*is_break, loop_name), } => self.on_loop_control_statemement(*is_break, loop_name),
ast::StmtNode::Return { token: _, value } => self.on_return_statement(es, value), StmtNode::Return { token: _, value } => self.on_return_statement(es, value),
} }
} }
} }
impl ast::StmtNode { impl StmtNode {
/// Handle the `print` statement. /// Handle the `print` statement.
fn on_print(&self, es: &mut InterpreterState, expr: &ast::ExprNode) -> InterpreterResult { fn on_print(&self, es: &mut InterpreterState, expr: &ExprNode) -> InterpreterResult {
let value = expr.interpret(es)?.result(); let value = expr.interpret(es)?.result();
let output = match value { let output = match value {
Value::Nil => String::from("nil"), Value::Nil => String::from("nil"),
@ -191,7 +189,7 @@ impl ast::StmtNode {
&self, &self,
es: &mut InterpreterState, es: &mut InterpreterState,
name: &Token, name: &Token,
initializer: &Option<ast::ExprNode>, initializer: &Option<ExprNode>,
) -> InterpreterResult { ) -> InterpreterResult {
let variable = match initializer { let variable = match initializer {
Some(expr) => Some(expr.interpret(es)?.result()), Some(expr) => Some(expr.interpret(es)?.result()),
@ -202,22 +200,21 @@ impl ast::StmtNode {
} }
/// Handle a function declaration. /// Handle a function declaration.
fn on_fun_decl( fn on_fun_decl(&self, es: &mut InterpreterState, decl: &FunDecl) -> InterpreterResult {
&self, let fun = Function::new(
es: &mut InterpreterState, Some(&decl.name),
name: &Token, &decl.params,
params: &[Token], &decl.body,
body: &[ast::StmtNode], es.environment.clone(),
) -> InterpreterResult { );
let fun = Function::new(Some(name), params, body, es.environment.clone());
es.environment es.environment
.borrow_mut() .borrow_mut()
.define(name, Some(Value::Callable(fun)))?; .define(&decl.name, Some(Value::Callable(fun)))?;
Ok(InterpreterFlowControl::default()) Ok(InterpreterFlowControl::default())
} }
/// Execute the contents of a block. /// Execute the contents of a block.
fn on_block(&self, es: &mut InterpreterState, stmts: &[ast::StmtNode]) -> InterpreterResult { fn on_block(&self, es: &mut InterpreterState, stmts: &[StmtNode]) -> InterpreterResult {
let mut child = InterpreterState::create_child(es); let mut child = InterpreterState::create_child(es);
for stmt in stmts.iter() { for stmt in stmts.iter() {
let result = stmt.interpret(&mut child)?; let result = stmt.interpret(&mut child)?;
@ -232,9 +229,9 @@ impl ast::StmtNode {
fn on_if_statement( fn on_if_statement(
&self, &self,
es: &mut InterpreterState, es: &mut InterpreterState,
condition: &ast::ExprNode, condition: &ExprNode,
then_branch: &ast::StmtNode, then_branch: &StmtNode,
else_branch: &Option<Box<ast::StmtNode>>, else_branch: &Option<Box<StmtNode>>,
) -> InterpreterResult { ) -> InterpreterResult {
if condition.interpret(es)?.result().is_truthy() { if condition.interpret(es)?.result().is_truthy() {
then_branch.interpret(es) then_branch.interpret(es)
@ -250,9 +247,9 @@ impl ast::StmtNode {
&self, &self,
es: &mut InterpreterState, es: &mut InterpreterState,
label: &Option<Token>, label: &Option<Token>,
condition: &ast::ExprNode, condition: &ExprNode,
body: &ast::StmtNode, body: &StmtNode,
after_body: &Option<Box<ast::StmtNode>>, after_body: &Option<Box<StmtNode>>,
) -> InterpreterResult { ) -> InterpreterResult {
let ln = label.as_ref().map(|token| token.lexeme.clone()); let ln = label.as_ref().map(|token| token.lexeme.clone());
while condition.interpret(es)?.result().is_truthy() { while condition.interpret(es)?.result().is_truthy() {
@ -294,7 +291,7 @@ impl ast::StmtNode {
fn on_return_statement( fn on_return_statement(
&self, &self,
es: &mut InterpreterState, es: &mut InterpreterState,
value: &Option<ast::ExprNode>, value: &Option<ExprNode>,
) -> InterpreterResult { ) -> InterpreterResult {
let rv = match value { let rv = match value {
None => Value::Nil, None => Value::Nil,
@ -308,48 +305,51 @@ impl ast::StmtNode {
* INTERPRETER FOR EXPRESSION NODES * * INTERPRETER FOR EXPRESSION NODES *
* -------------------------------- */ * -------------------------------- */
impl Interpretable for ast::ExprNode { impl Interpretable for ExprNode {
fn interpret(&self, es: &mut InterpreterState) -> InterpreterResult { fn interpret(&self, es: &mut InterpreterState) -> InterpreterResult {
match self { match self {
ast::ExprNode::Assignment { name, value, id } => { ExprNode::Assignment { name, value, id } => {
let value = value.interpret(es)?.result(); let value = value.interpret(es)?.result();
es.assign_var(name, id, value)?; es.assign_var(name, id, value)?;
Ok(InterpreterFlowControl::default()) Ok(InterpreterFlowControl::default())
} }
ast::ExprNode::Logical { ExprNode::Logical {
left, left,
operator, operator,
right, right,
} => self.on_logic(es, left, operator, right), } => self.on_logic(es, left, operator, right),
ast::ExprNode::Binary { ExprNode::Binary {
left, left,
operator, operator,
right, right,
} => self.on_binary(es, left, operator, right), } => self.on_binary(es, left, operator, right),
ast::ExprNode::Unary { operator, right } => self.on_unary(es, operator, right), ExprNode::Unary { operator, right } => self.on_unary(es, operator, right),
ast::ExprNode::Grouping { expression } => expression.interpret(es), ExprNode::Grouping { expression } => expression.interpret(es),
ast::ExprNode::Litteral { value } => self.on_litteral(value), ExprNode::Litteral { value } => self.on_litteral(value),
ast::ExprNode::Variable { name, id } => Ok(es.lookup_var(name, id)?.into()), ExprNode::Variable { name, id } => Ok(es.lookup_var(name, id)?.into()),
ast::ExprNode::Call { ExprNode::Call {
callee, callee,
right_paren, right_paren,
arguments, arguments,
} => self.on_call(es, callee, right_paren, arguments), } => self.on_call(es, callee, right_paren, arguments),
ast::ExprNode::Lambda { params, body } => { ExprNode::Lambda { params, body } => {
Ok(Value::Callable(Function::new(None, params, body, es.environment.clone())).into()) Ok(
Value::Callable(Function::new(None, params, body, es.environment.clone()))
.into(),
)
} }
} }
} }
} }
impl ast::ExprNode { impl ExprNode {
/// Evaluate a logical operator. /// Evaluate a logical operator.
fn on_logic( fn on_logic(
&self, &self,
es: &mut InterpreterState, es: &mut InterpreterState,
left: &ast::ExprNode, left: &ExprNode,
operator: &Token, operator: &Token,
right: &ast::ExprNode, right: &ExprNode,
) -> InterpreterResult { ) -> InterpreterResult {
let left_value = left.interpret(es)?.result(); let left_value = left.interpret(es)?.result();
if operator.token_type == TokenType::Or && left_value.is_truthy() if operator.token_type == TokenType::Or && left_value.is_truthy()
@ -365,9 +365,9 @@ impl ast::ExprNode {
fn on_binary( fn on_binary(
&self, &self,
es: &mut InterpreterState, es: &mut InterpreterState,
left: &ast::ExprNode, left: &ExprNode,
operator: &Token, operator: &Token,
right: &ast::ExprNode, right: &ExprNode,
) -> InterpreterResult { ) -> InterpreterResult {
let left_value = left.interpret(es)?.result(); let left_value = left.interpret(es)?.result();
let right_value = right.interpret(es)?.result(); let right_value = right.interpret(es)?.result();
@ -437,7 +437,7 @@ impl ast::ExprNode {
&self, &self,
es: &mut InterpreterState, es: &mut InterpreterState,
operator: &Token, operator: &Token,
right: &ast::ExprNode, right: &ExprNode,
) -> InterpreterResult { ) -> InterpreterResult {
let right_value = right.interpret(es)?.result(); let right_value = right.interpret(es)?.result();
match operator.token_type { match operator.token_type {
@ -475,9 +475,9 @@ impl ast::ExprNode {
fn on_call( fn on_call(
&self, &self,
es: &mut InterpreterState, es: &mut InterpreterState,
callee: &ast::ExprNode, callee: &ExprNode,
right_paren: &Token, right_paren: &Token,
arguments: &Vec<ast::ExprNode>, arguments: &Vec<ExprNode>,
) -> InterpreterResult { ) -> InterpreterResult {
let callee = callee.interpret(es)?.result(); let callee = callee.interpret(es)?.result();
let arg_values = { let arg_values = {

View file

@ -1,7 +1,7 @@
use std::collections::HashSet; use std::collections::HashSet;
use crate::{ use crate::{
ast, ast::{ExprNode, FunDecl, ProgramNode, StmtNode},
errors::{ErrorHandler, ErrorKind, SloxError, SloxResult}, errors::{ErrorHandler, ErrorKind, SloxError, SloxResult},
tokens::{Token, TokenType}, tokens::{Token, TokenType},
}; };
@ -76,7 +76,7 @@ impl Parser {
/// Parse the tokens into an AST and return it, or return nothing if a /// Parse the tokens into an AST and return it, or return nothing if a
/// parser error occurs. /// parser error occurs.
pub fn parse(mut self, err_hdl: &mut ErrorHandler) -> SloxResult<ast::ProgramNode> { pub fn parse(mut self, err_hdl: &mut ErrorHandler) -> SloxResult<ProgramNode> {
self.loop_state.push(LoopParsingState::None); self.loop_state.push(LoopParsingState::None);
let result = self.parse_program(err_hdl); let result = self.parse_program(err_hdl);
self.loop_state.pop(); self.loop_state.pop();
@ -113,8 +113,8 @@ impl Parser {
/// ``` /// ```
/// program := statement* /// program := statement*
/// ``` /// ```
fn parse_program(&mut self, err_hdl: &mut ErrorHandler) -> ast::ProgramNode { fn parse_program(&mut self, err_hdl: &mut ErrorHandler) -> ProgramNode {
let mut stmts: Vec<ast::StmtNode> = Vec::new(); let mut stmts: Vec<StmtNode> = Vec::new();
while !self.is_at_end() { while !self.is_at_end() {
match self.parse_statement() { match self.parse_statement() {
Ok(node) => stmts.push(node), Ok(node) => stmts.push(node),
@ -124,7 +124,7 @@ impl Parser {
} }
} }
} }
ast::ProgramNode(stmts) ProgramNode(stmts)
} }
/// Parse the following rule: /// Parse the following rule:
@ -141,7 +141,7 @@ impl Parser {
/// statement := loop_control_statement /// statement := loop_control_statement
/// statement := return_statement /// statement := return_statement
/// ``` /// ```
fn parse_statement(&mut self) -> SloxResult<ast::StmtNode> { fn parse_statement(&mut self) -> SloxResult<StmtNode> {
if self.expect(&[TokenType::Var]).is_some() { if self.expect(&[TokenType::Var]).is_some() {
self.parse_var_declaration() self.parse_var_declaration()
} else if self.expect(&[TokenType::Fun]).is_some() { } else if self.expect(&[TokenType::Fun]).is_some() {
@ -169,7 +169,7 @@ impl Parser {
} else if self.expect(&[TokenType::Print]).is_some() { } else if self.expect(&[TokenType::Print]).is_some() {
let expression = self.parse_expression()?; let expression = self.parse_expression()?;
self.consume(&TokenType::Semicolon, "expected ';' after value")?; self.consume(&TokenType::Semicolon, "expected ';' after value")?;
Ok(ast::StmtNode::Print(expression)) Ok(StmtNode::Print(expression))
} else { } else {
self.parse_expression_stmt() self.parse_expression_stmt()
} }
@ -179,10 +179,10 @@ impl Parser {
/// ``` /// ```
/// expression_stmt := expression ";" /// expression_stmt := expression ";"
/// ``` /// ```
fn parse_expression_stmt(&mut self) -> SloxResult<ast::StmtNode> { fn parse_expression_stmt(&mut self) -> SloxResult<StmtNode> {
let expression = self.parse_expression()?; let expression = self.parse_expression()?;
self.consume(&TokenType::Semicolon, "expected ';' after expression")?; self.consume(&TokenType::Semicolon, "expected ';' after expression")?;
Ok(ast::StmtNode::Expression(expression)) Ok(StmtNode::Expression(expression))
} }
/// Parse the following rule: /// Parse the following rule:
@ -190,12 +190,12 @@ impl Parser {
/// var_declaration := "var" IDENTIFIER ";" /// var_declaration := "var" IDENTIFIER ";"
/// var_declaration := "var" IDENTIFIER "=" expression ";" /// var_declaration := "var" IDENTIFIER "=" expression ";"
/// ``` /// ```
fn parse_var_declaration(&mut self) -> SloxResult<ast::StmtNode> { fn parse_var_declaration(&mut self) -> SloxResult<StmtNode> {
let name = match self.peek().token_type { let name = match self.peek().token_type {
TokenType::Identifier(_) => self.advance().clone(), TokenType::Identifier(_) => self.advance().clone(),
_ => return self.error("expected variable name"), _ => return self.error("expected variable name"),
}; };
let initializer: Option<ast::ExprNode> = match self.expect(&[TokenType::Equal]) { let initializer: Option<ExprNode> = match self.expect(&[TokenType::Equal]) {
Some(_) => Some(self.parse_expression()?), Some(_) => Some(self.parse_expression()?),
None => None, None => None,
}; };
@ -203,7 +203,7 @@ impl Parser {
&TokenType::Semicolon, &TokenType::Semicolon,
"expected ';' after variable declaration", "expected ';' after variable declaration",
)?; )?;
Ok(ast::StmtNode::VarDecl(name, initializer)) Ok(StmtNode::VarDecl(name, initializer))
} }
/// Parse the following rule: /// Parse the following rule:
@ -212,18 +212,18 @@ impl Parser {
/// function := IDENTIFIER function_info /// function := IDENTIFIER function_info
/// ``` /// ```
/// The `kind` parameter is used to generate error messages. /// The `kind` parameter is used to generate error messages.
fn parse_fun_declaration(&mut self, kind: FunctionKind) -> SloxResult<ast::StmtNode> { fn parse_fun_declaration(&mut self, kind: FunctionKind) -> SloxResult<StmtNode> {
// Read the name // Read the name
let name = match self.peek().token_type { let name = match self.peek().token_type {
TokenType::Identifier(_) => self.advance().clone(), TokenType::Identifier(_) => self.advance().clone(),
_ => return self.error_mv(format!("expected {} name", kind.name())), _ => return self.error_mv(format!("expected {} name", kind.name())),
}; };
let (params, block) = self.parse_function_info(kind)?; let (params, block) = self.parse_function_info(kind)?;
Ok(ast::StmtNode::FunDecl { Ok(StmtNode::FunDecl(FunDecl {
name, name,
params, params,
body: block, body: block,
}) }))
} }
/// Parse the following rules: /// Parse the following rules:
@ -234,7 +234,7 @@ impl Parser {
fn parse_function_info( fn parse_function_info(
&mut self, &mut self,
kind: FunctionKind, kind: FunctionKind,
) -> SloxResult<(Vec<Token>, Vec<ast::StmtNode>)> { ) -> SloxResult<(Vec<Token>, Vec<StmtNode>)> {
// Read the list of parameter names // Read the list of parameter names
self.consume( self.consume(
&TokenType::LeftParen, &TokenType::LeftParen,
@ -286,13 +286,13 @@ impl Parser {
/// ``` /// ```
/// block := "{" statement* "}" /// block := "{" statement* "}"
/// ``` /// ```
fn parse_block(&mut self) -> SloxResult<ast::StmtNode> { fn parse_block(&mut self) -> SloxResult<StmtNode> {
let mut stmts: Vec<ast::StmtNode> = Vec::new(); let mut stmts: Vec<StmtNode> = Vec::new();
while !(self.check(&TokenType::RightBrace) || self.is_at_end()) { while !(self.check(&TokenType::RightBrace) || self.is_at_end()) {
stmts.push(self.parse_statement()?); stmts.push(self.parse_statement()?);
} }
self.consume(&TokenType::RightBrace, "expected '}' after block.")?; self.consume(&TokenType::RightBrace, "expected '}' after block.")?;
Ok(ast::StmtNode::Block(stmts)) Ok(StmtNode::Block(stmts))
} }
/// Parse the following rule: /// Parse the following rule:
@ -300,7 +300,7 @@ impl Parser {
/// if_statement := "if" "(" expression ")" statement /// if_statement := "if" "(" expression ")" statement
/// if_statement := "if" "(" expression ")" statement "else" statement /// if_statement := "if" "(" expression ")" statement "else" statement
/// ``` /// ```
fn parse_if_statement(&mut self) -> SloxResult<ast::StmtNode> { fn parse_if_statement(&mut self) -> SloxResult<StmtNode> {
self.consume(&TokenType::LeftParen, "expected '(' after 'if'")?; self.consume(&TokenType::LeftParen, "expected '(' after 'if'")?;
let expression = self.parse_expression()?; let expression = self.parse_expression()?;
self.consume( self.consume(
@ -312,7 +312,7 @@ impl Parser {
Some(_) => Some(Box::new(self.parse_statement()?)), Some(_) => Some(Box::new(self.parse_statement()?)),
None => None, None => None,
}; };
Ok(ast::StmtNode::If { Ok(StmtNode::If {
condition: expression, condition: expression,
then_branch, then_branch,
else_branch, else_branch,
@ -324,7 +324,7 @@ impl Parser {
/// labelled_loop := "@" IDENTIFIER while_statement /// labelled_loop := "@" IDENTIFIER while_statement
/// labelled_loop := "@" IDENTIFIER for_statement /// labelled_loop := "@" IDENTIFIER for_statement
/// ``` /// ```
fn parse_labelled_loop(&mut self) -> SloxResult<ast::StmtNode> { fn parse_labelled_loop(&mut self) -> SloxResult<StmtNode> {
let name_token = match self.peek().token_type { let name_token = match self.peek().token_type {
TokenType::Identifier(_) => self.advance().clone(), TokenType::Identifier(_) => self.advance().clone(),
_ => return self.error("identifier expected after '@'"), _ => return self.error("identifier expected after '@'"),
@ -343,7 +343,7 @@ impl Parser {
/// ``` /// ```
/// while_statement := "while" "(" expression ")" statement /// while_statement := "while" "(" expression ")" statement
/// ``` /// ```
fn parse_while_statement(&mut self, label: Option<Token>) -> SloxResult<ast::StmtNode> { fn parse_while_statement(&mut self, label: Option<Token>) -> SloxResult<StmtNode> {
self.consume(&TokenType::LeftParen, "expected '(' after 'while'")?; self.consume(&TokenType::LeftParen, "expected '(' after 'while'")?;
let condition = self.parse_expression()?; let condition = self.parse_expression()?;
self.consume( self.consume(
@ -356,7 +356,7 @@ impl Parser {
self.loop_state.pop(); self.loop_state.pop();
result? result?
}); });
Ok(ast::StmtNode::Loop { Ok(StmtNode::Loop {
label, label,
condition, condition,
body, body,
@ -371,7 +371,7 @@ impl Parser {
/// for_initializer := expression /// for_initializer := expression
/// for_initializer := /// for_initializer :=
/// ``` /// ```
fn parse_for_statement(&mut self, label: Option<Token>) -> SloxResult<ast::StmtNode> { fn parse_for_statement(&mut self, label: Option<Token>) -> SloxResult<StmtNode> {
self.consume(&TokenType::LeftParen, "expected '(' after 'for'")?; self.consume(&TokenType::LeftParen, "expected '(' after 'for'")?;
let initializer = if self.expect(&[TokenType::Semicolon]).is_some() { let initializer = if self.expect(&[TokenType::Semicolon]).is_some() {
@ -383,7 +383,7 @@ impl Parser {
}; };
let condition = if self.check(&TokenType::Semicolon) { let condition = if self.check(&TokenType::Semicolon) {
ast::ExprNode::Litteral { ExprNode::Litteral {
value: Token { value: Token {
token_type: TokenType::True, token_type: TokenType::True,
lexeme: String::from("true"), lexeme: String::from("true"),
@ -416,14 +416,14 @@ impl Parser {
self.loop_state.pop(); self.loop_state.pop();
result? result?
}; };
let while_stmt = ast::StmtNode::Loop { let while_stmt = StmtNode::Loop {
label, label,
condition, condition,
body: Box::new(body_stmt), body: Box::new(body_stmt),
after_body: increment.map(|incr| Box::new(ast::StmtNode::Expression(incr))), after_body: increment.map(|incr| Box::new(StmtNode::Expression(incr))),
}; };
if let Some(init_stmt) = initializer { if let Some(init_stmt) = initializer {
Ok(ast::StmtNode::Block(vec![init_stmt, while_stmt])) Ok(StmtNode::Block(vec![init_stmt, while_stmt]))
} else { } else {
Ok(while_stmt) Ok(while_stmt)
} }
@ -434,7 +434,7 @@ impl Parser {
/// loop_control_statement := "break" ( IDENTIFIER )? ";" /// loop_control_statement := "break" ( IDENTIFIER )? ";"
/// loop_control_statement := "continue" ( IDENTIFIER )? ";" /// loop_control_statement := "continue" ( IDENTIFIER )? ";"
/// ``` /// ```
fn parse_loop_control_statement(&mut self, stmt_token: &Token) -> SloxResult<ast::StmtNode> { fn parse_loop_control_statement(&mut self, stmt_token: &Token) -> SloxResult<StmtNode> {
if self.loop_state() == &LoopParsingState::None { if self.loop_state() == &LoopParsingState::None {
return Err(SloxError::with_token( return Err(SloxError::with_token(
ErrorKind::Parse, ErrorKind::Parse,
@ -465,7 +465,7 @@ impl Parser {
&TokenType::Semicolon, &TokenType::Semicolon,
"';' expected after loop control statement", "';' expected after loop control statement",
)?; )?;
Ok(ast::StmtNode::LoopControl { Ok(StmtNode::LoopControl {
is_break: stmt_token.token_type == TokenType::Break, is_break: stmt_token.token_type == TokenType::Break,
loop_name, loop_name,
}) })
@ -475,7 +475,7 @@ impl Parser {
/// ``` /// ```
/// return_statement := "return" expression? ";" /// return_statement := "return" expression? ";"
/// ``` /// ```
fn parse_return_statement(&mut self, ret_token: &Token) -> SloxResult<ast::StmtNode> { fn parse_return_statement(&mut self, ret_token: &Token) -> SloxResult<StmtNode> {
if self.can_use_return() { if self.can_use_return() {
let value = if self.check(&TokenType::Semicolon) { let value = if self.check(&TokenType::Semicolon) {
None None
@ -483,7 +483,7 @@ impl Parser {
Some(self.parse_expression()?) Some(self.parse_expression()?)
}; };
self.consume(&TokenType::Semicolon, "';' expected after return statement")?; self.consume(&TokenType::Semicolon, "';' expected after return statement")?;
Ok(ast::StmtNode::Return { Ok(StmtNode::Return {
token: ret_token.clone(), token: ret_token.clone(),
value, value,
}) })
@ -500,7 +500,7 @@ impl Parser {
/// ``` /// ```
/// expression := assignment /// expression := assignment
/// ``` /// ```
fn parse_expression(&mut self) -> SloxResult<ast::ExprNode> { fn parse_expression(&mut self) -> SloxResult<ExprNode> {
self.parse_assignment() self.parse_assignment()
} }
@ -509,12 +509,12 @@ impl Parser {
/// assignment := IDENTIFIER "=" equality /// assignment := IDENTIFIER "=" equality
/// assignment := equality /// assignment := equality
/// ``` /// ```
fn parse_assignment(&mut self) -> SloxResult<ast::ExprNode> { fn parse_assignment(&mut self) -> SloxResult<ExprNode> {
let expr = self.parse_logic_or()?; let expr = self.parse_logic_or()?;
if let Some(equals) = self.expect(&[TokenType::Equal]) { if let Some(equals) = self.expect(&[TokenType::Equal]) {
let value = self.parse_assignment()?; let value = self.parse_assignment()?;
if let ast::ExprNode::Variable { name, id: _ } = expr { if let ExprNode::Variable { name, id: _ } = expr {
Ok(ast::ExprNode::Assignment { Ok(ExprNode::Assignment {
name, name,
value: Box::new(value), value: Box::new(value),
id: self.make_id(), id: self.make_id(),
@ -535,11 +535,11 @@ impl Parser {
/// ``` /// ```
/// logic_or := logic_and ( "or" logic_and )* /// logic_or := logic_and ( "or" logic_and )*
/// ``` /// ```
fn parse_logic_or(&mut self) -> SloxResult<ast::ExprNode> { fn parse_logic_or(&mut self) -> SloxResult<ExprNode> {
let mut expr = self.parse_logic_and()?; let mut expr = self.parse_logic_and()?;
while let Some(operator) = self.expect(&[TokenType::Or]) { while let Some(operator) = self.expect(&[TokenType::Or]) {
let right = self.parse_logic_and()?; let right = self.parse_logic_and()?;
expr = ast::ExprNode::Logical { expr = ExprNode::Logical {
left: Box::new(expr), left: Box::new(expr),
operator: operator.clone(), operator: operator.clone(),
right: Box::new(right), right: Box::new(right),
@ -552,11 +552,11 @@ impl Parser {
/// ``` /// ```
/// logic_and := equality ( "and" equality )* /// logic_and := equality ( "and" equality )*
/// ``` /// ```
fn parse_logic_and(&mut self) -> SloxResult<ast::ExprNode> { fn parse_logic_and(&mut self) -> SloxResult<ExprNode> {
let mut expr = self.parse_equality()?; let mut expr = self.parse_equality()?;
while let Some(operator) = self.expect(&[TokenType::And]) { while let Some(operator) = self.expect(&[TokenType::And]) {
let right = self.parse_equality()?; let right = self.parse_equality()?;
expr = ast::ExprNode::Logical { expr = ExprNode::Logical {
left: Box::new(expr), left: Box::new(expr),
operator: operator.clone(), operator: operator.clone(),
right: Box::new(right), right: Box::new(right),
@ -570,11 +570,11 @@ impl Parser {
/// equality := comparison "==" comparison /// equality := comparison "==" comparison
/// equality := comparison "!=" comparison /// equality := comparison "!=" comparison
/// ``` /// ```
fn parse_equality(&mut self) -> SloxResult<ast::ExprNode> { fn parse_equality(&mut self) -> SloxResult<ExprNode> {
let mut expr = self.parse_comparison()?; let mut expr = self.parse_comparison()?;
while let Some(operator) = self.expect(&[TokenType::BangEqual, TokenType::EqualEqual]) { while let Some(operator) = self.expect(&[TokenType::BangEqual, TokenType::EqualEqual]) {
let right = self.parse_comparison()?; let right = self.parse_comparison()?;
expr = ast::ExprNode::Binary { expr = ExprNode::Binary {
left: Box::new(expr), left: Box::new(expr),
operator: operator.clone(), operator: operator.clone(),
right: Box::new(right), right: Box::new(right),
@ -588,7 +588,7 @@ impl Parser {
/// comparison := term comparison_operator term /// comparison := term comparison_operator term
/// comparison_operator := "<" | "<=" | ">" | ">=" /// comparison_operator := "<" | "<=" | ">" | ">="
/// ``` /// ```
fn parse_comparison(&mut self) -> SloxResult<ast::ExprNode> { fn parse_comparison(&mut self) -> SloxResult<ExprNode> {
let mut expr = self.parse_term()?; let mut expr = self.parse_term()?;
while let Some(operator) = self.expect(&[ while let Some(operator) = self.expect(&[
TokenType::Greater, TokenType::Greater,
@ -597,7 +597,7 @@ impl Parser {
TokenType::LessEqual, TokenType::LessEqual,
]) { ]) {
let right = self.parse_term()?; let right = self.parse_term()?;
expr = ast::ExprNode::Binary { expr = ExprNode::Binary {
left: Box::new(expr), left: Box::new(expr),
operator: operator.clone(), operator: operator.clone(),
right: Box::new(right), right: Box::new(right),
@ -611,11 +611,11 @@ impl Parser {
/// term := factor ( "+" factor )* /// term := factor ( "+" factor )*
/// term := factor ( "-" factor )* /// term := factor ( "-" factor )*
/// ``` /// ```
fn parse_term(&mut self) -> SloxResult<ast::ExprNode> { fn parse_term(&mut self) -> SloxResult<ExprNode> {
let mut expr = self.parse_factor()?; let mut expr = self.parse_factor()?;
while let Some(operator) = self.expect(&[TokenType::Minus, TokenType::Plus]) { while let Some(operator) = self.expect(&[TokenType::Minus, TokenType::Plus]) {
let right = self.parse_factor()?; let right = self.parse_factor()?;
expr = ast::ExprNode::Binary { expr = ExprNode::Binary {
left: Box::new(expr), left: Box::new(expr),
operator: operator.clone(), operator: operator.clone(),
right: Box::new(right), right: Box::new(right),
@ -629,11 +629,11 @@ impl Parser {
/// factor := unary ( "*" unary )* /// factor := unary ( "*" unary )*
/// factor := unary ( "/" unary )* /// factor := unary ( "/" unary )*
/// ``` /// ```
fn parse_factor(&mut self) -> SloxResult<ast::ExprNode> { fn parse_factor(&mut self) -> SloxResult<ExprNode> {
let mut expr = self.parse_unary()?; let mut expr = self.parse_unary()?;
while let Some(operator) = self.expect(&[TokenType::Slash, TokenType::Star]) { while let Some(operator) = self.expect(&[TokenType::Slash, TokenType::Star]) {
let right = self.parse_unary()?; let right = self.parse_unary()?;
expr = ast::ExprNode::Binary { expr = ExprNode::Binary {
left: Box::new(expr), left: Box::new(expr),
operator: operator.clone(), operator: operator.clone(),
right: Box::new(right), right: Box::new(right),
@ -648,9 +648,9 @@ impl Parser {
/// unary := "!" unary /// unary := "!" unary
/// unary := primary call_arguments* /// unary := primary call_arguments*
/// ``` /// ```
fn parse_unary(&mut self) -> SloxResult<ast::ExprNode> { fn parse_unary(&mut self) -> SloxResult<ExprNode> {
if let Some(operator) = self.expect(&[TokenType::Bang, TokenType::Minus]) { if let Some(operator) = self.expect(&[TokenType::Bang, TokenType::Minus]) {
Ok(ast::ExprNode::Unary { Ok(ExprNode::Unary {
operator, operator,
right: Box::new(self.parse_unary()?), right: Box::new(self.parse_unary()?),
}) })
@ -670,26 +670,26 @@ impl Parser {
/// primary := IDENTIFIER /// primary := IDENTIFIER
/// primary := "fun" function_info /// primary := "fun" function_info
/// ``` /// ```
fn parse_primary(&mut self) -> SloxResult<ast::ExprNode> { fn parse_primary(&mut self) -> SloxResult<ExprNode> {
if self.expect(&[TokenType::LeftParen]).is_some() { if self.expect(&[TokenType::LeftParen]).is_some() {
let expr = self.parse_expression()?; let expr = self.parse_expression()?;
self.consume(&TokenType::RightParen, "expected ')' after expression")?; self.consume(&TokenType::RightParen, "expected ')' after expression")?;
Ok(ast::ExprNode::Grouping { Ok(ExprNode::Grouping {
expression: Box::new(expr), expression: Box::new(expr),
}) })
} else if self.expect(&[TokenType::Fun]).is_some() { } else if self.expect(&[TokenType::Fun]).is_some() {
let (params, body) = self.parse_function_info(FunctionKind::Lambda)?; let (params, body) = self.parse_function_info(FunctionKind::Lambda)?;
Ok(ast::ExprNode::Lambda { params, body }) Ok(ExprNode::Lambda { params, body })
} else if let Some(token) = } else if let Some(token) =
self.expect(&[TokenType::False, TokenType::True, TokenType::Nil]) self.expect(&[TokenType::False, TokenType::True, TokenType::Nil])
{ {
Ok(ast::ExprNode::Litteral { value: token }) Ok(ExprNode::Litteral { value: token })
} else { } else {
match &self.peek().token_type { match &self.peek().token_type {
TokenType::Number(_) | &TokenType::String(_) => Ok(ast::ExprNode::Litteral { TokenType::Number(_) | &TokenType::String(_) => Ok(ExprNode::Litteral {
value: self.advance().clone(), value: self.advance().clone(),
}), }),
TokenType::Identifier(_) => Ok(ast::ExprNode::Variable { TokenType::Identifier(_) => Ok(ExprNode::Variable {
name: self.advance().clone(), name: self.advance().clone(),
id: self.make_id(), id: self.make_id(),
}), }),
@ -703,7 +703,7 @@ impl Parser {
/// call := expression "(" arguments? ")" /// call := expression "(" arguments? ")"
/// arguments := expression ( "," expression )* /// arguments := expression ( "," expression )*
/// ``` /// ```
fn parse_call_arguments(&mut self, callee: ast::ExprNode) -> Result<ast::ExprNode, SloxError> { fn parse_call_arguments(&mut self, callee: ExprNode) -> Result<ExprNode, SloxError> {
let mut arguments = Vec::new(); let mut arguments = Vec::new();
if !self.check(&TokenType::RightParen) { if !self.check(&TokenType::RightParen) {
loop { loop {
@ -719,7 +719,7 @@ impl Parser {
let right_paren = self let right_paren = self
.consume(&TokenType::RightParen, "')' expected after arguments")? .consume(&TokenType::RightParen, "')' expected after arguments")?
.clone(); .clone();
Ok(ast::ExprNode::Call { Ok(ExprNode::Call {
callee: Box::new(callee), callee: Box::new(callee),
right_paren, right_paren,
arguments, arguments,

View file

@ -1,7 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use crate::{ use crate::{
ast, ast::{ExprNode, ProgramNode, StmtNode},
errors::{ErrorKind, SloxError, SloxResult}, errors::{ErrorKind, SloxError, SloxResult},
tokens::Token, tokens::Token,
}; };
@ -12,7 +12,7 @@ use crate::{
pub type ResolvedVariables = HashMap<usize, usize>; pub type ResolvedVariables = HashMap<usize, usize>;
/// Resolve all variables in a program's AST. /// Resolve all variables in a program's AST.
pub fn resolve_variables(program: &ast::ProgramNode) -> SloxResult<ResolvedVariables> { pub fn resolve_variables(program: &ProgramNode) -> SloxResult<ResolvedVariables> {
let mut state = ResolverState::default(); let mut state = ResolverState::default();
state state
.with_scope(|rs| program.resolve(rs)) .with_scope(|rs| program.resolve(rs))
@ -193,7 +193,7 @@ impl<'a> ResolverState<'a> {
fn resolve_function<'a, 'b>( fn resolve_function<'a, 'b>(
rs: &mut ResolverState<'a>, rs: &mut ResolverState<'a>,
params: &'b [Token], params: &'b [Token],
body: &'b Vec<ast::StmtNode>, body: &'b Vec<StmtNode>,
) -> ResolverResult ) -> ResolverResult
where where
'b: 'a, 'b: 'a,
@ -215,7 +215,7 @@ trait VarResolver {
'a: 'b; 'a: 'b;
} }
impl VarResolver for ast::ProgramNode { impl VarResolver for ProgramNode {
fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult
where where
'a: 'b, 'a: 'b,
@ -224,7 +224,7 @@ impl VarResolver for ast::ProgramNode {
} }
} }
impl VarResolver for Vec<ast::StmtNode> { impl VarResolver for Vec<StmtNode> {
fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult
where where
'a: 'b, 'a: 'b,
@ -236,37 +236,37 @@ impl VarResolver for Vec<ast::StmtNode> {
} }
} }
impl VarResolver for ast::StmtNode { impl VarResolver for StmtNode {
fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult
where where
'a: 'b, 'a: 'b,
{ {
match self { match self {
ast::StmtNode::Block(stmts) => rs.with_scope(|rs| stmts.resolve(rs)), StmtNode::Block(stmts) => rs.with_scope(|rs| stmts.resolve(rs)),
ast::StmtNode::VarDecl(name, None) => { StmtNode::VarDecl(name, None) => {
rs.declare(name, SymKind::Variable)?; rs.declare(name, SymKind::Variable)?;
Ok(()) Ok(())
} }
ast::StmtNode::VarDecl(name, Some(init)) => { StmtNode::VarDecl(name, Some(init)) => {
rs.declare(name, SymKind::Variable)?; rs.declare(name, SymKind::Variable)?;
init.resolve(rs)?; init.resolve(rs)?;
rs.define(name); rs.define(name);
Ok(()) Ok(())
} }
ast::StmtNode::FunDecl { name, params, body } => { StmtNode::FunDecl(decl) => {
rs.declare(name, SymKind::Function)?; rs.declare(&decl.name, SymKind::Function)?;
rs.define(name); rs.define(&decl.name);
rs.with_scope(|rs| resolve_function(rs, params, body)) rs.with_scope(|rs| resolve_function(rs, &decl.params, &decl.body))
} }
ast::StmtNode::If { StmtNode::If {
condition, condition,
then_branch, then_branch,
else_branch: None, else_branch: None,
} => condition.resolve(rs).and_then(|_| then_branch.resolve(rs)), } => condition.resolve(rs).and_then(|_| then_branch.resolve(rs)),
ast::StmtNode::If { StmtNode::If {
condition, condition,
then_branch, then_branch,
else_branch: Some(else_branch), else_branch: Some(else_branch),
@ -275,7 +275,7 @@ impl VarResolver for ast::StmtNode {
.and_then(|_| then_branch.resolve(rs)) .and_then(|_| then_branch.resolve(rs))
.and_then(|_| else_branch.resolve(rs)), .and_then(|_| else_branch.resolve(rs)),
ast::StmtNode::Loop { StmtNode::Loop {
label: _, label: _,
condition, condition,
body, body,
@ -291,18 +291,18 @@ impl VarResolver for ast::StmtNode {
}) })
.and_then(|_| body.resolve(rs)), .and_then(|_| body.resolve(rs)),
ast::StmtNode::Return { StmtNode::Return {
token: _, token: _,
value: None, value: None,
} => Ok(()), } => Ok(()),
ast::StmtNode::Return { StmtNode::Return {
token: _, token: _,
value: Some(expr), value: Some(expr),
} => expr.resolve(rs), } => expr.resolve(rs),
ast::StmtNode::Expression(expr) => expr.resolve(rs), StmtNode::Expression(expr) => expr.resolve(rs),
ast::StmtNode::Print(expr) => expr.resolve(rs), StmtNode::Print(expr) => expr.resolve(rs),
ast::StmtNode::LoopControl { StmtNode::LoopControl {
is_break: _, is_break: _,
loop_name: _, loop_name: _,
} => Ok(()), } => Ok(()),
@ -310,37 +310,37 @@ impl VarResolver for ast::StmtNode {
} }
} }
impl VarResolver for ast::ExprNode { impl VarResolver for ExprNode {
fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult
where where
'a: 'b, 'a: 'b,
{ {
match self { match self {
ast::ExprNode::Variable { name, id } => rs.resolve_use(id, name), ExprNode::Variable { name, id } => rs.resolve_use(id, name),
ast::ExprNode::Assignment { name, value, id } => { ExprNode::Assignment { name, value, id } => {
value.resolve(rs)?; value.resolve(rs)?;
rs.resolve_assignment(id, name) rs.resolve_assignment(id, name)
} }
ast::ExprNode::Lambda { params, body } => { ExprNode::Lambda { params, body } => {
rs.with_scope(|rs| resolve_function(rs, params, body)) rs.with_scope(|rs| resolve_function(rs, params, body))
} }
ast::ExprNode::Logical { ExprNode::Logical {
left, left,
operator: _, operator: _,
right, right,
} => left.resolve(rs).and_then(|_| right.resolve(rs)), } => left.resolve(rs).and_then(|_| right.resolve(rs)),
ast::ExprNode::Binary { ExprNode::Binary {
left, left,
operator: _, operator: _,
right, right,
} => left.resolve(rs).and_then(|_| right.resolve(rs)), } => left.resolve(rs).and_then(|_| right.resolve(rs)),
ast::ExprNode::Unary { operator: _, right } => right.resolve(rs), ExprNode::Unary { operator: _, right } => right.resolve(rs),
ast::ExprNode::Grouping { expression } => expression.resolve(rs), ExprNode::Grouping { expression } => expression.resolve(rs),
ast::ExprNode::Litteral { value: _ } => Ok(()), ExprNode::Litteral { value: _ } => Ok(()),
ast::ExprNode::Call { ExprNode::Call {
callee, callee,
right_paren: _, right_paren: _,
arguments, arguments,
@ -349,7 +349,7 @@ impl VarResolver for ast::ExprNode {
} }
} }
impl VarResolver for Vec<ast::ExprNode> { impl VarResolver for Vec<ExprNode> {
fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult
where where
'a: 'b, 'a: 'b,