AST - Refactored function declarations as a separate struct

This commit is contained in:
Emmanuel BENOîT 2023-01-07 12:18:18 +01:00
parent 6252bed605
commit 1af25457d5
4 changed files with 158 additions and 154 deletions

View file

@ -8,17 +8,21 @@ use crate::tokens::Token;
#[derive(Default, Debug, Clone)]
pub struct ProgramNode(pub Vec<StmtNode>);
/// A function declaration.
#[derive(Debug, Clone)]
pub struct FunDecl {
pub name: Token,
pub params: Vec<Token>,
pub body: Vec<StmtNode>,
}
/// An AST node that represents a statement.
#[derive(Debug, Clone)]
pub enum StmtNode {
/// A variable declaration
VarDecl(Token, Option<ExprNode>),
/// A function declaration
FunDecl {
name: Token,
params: Vec<Token>,
body: Vec<StmtNode>,
},
FunDecl(FunDecl),
/// An single expression
Expression(ExprNode),
/// The print statement
@ -149,15 +153,15 @@ impl AstDumper for StmtNode {
Self::VarDecl(name, Some(expr)) => format!("( var {} {} )", name.lexeme, expr.dump()),
Self::VarDecl(name, None) => format!("( var {} nil )", name.lexeme),
Self::FunDecl { name, params, body } => format!(
Self::FunDecl(fun_decl) => format!(
"( fun {} ({}) {} )",
name.lexeme,
params
fun_decl.name.lexeme,
fun_decl.params
.iter()
.map(|token| &token.lexeme as &str)
.collect::<Vec<&str>>()
.join(" "),
body.iter()
fun_decl.body.iter()
.map(|stmt| stmt.dump())
.collect::<Vec<String>>()
.join(" ")

View file

@ -1,7 +1,7 @@
use std::{cell::RefCell, rc::Rc};
use crate::{
ast,
ast::{ExprNode, FunDecl, ProgramNode, StmtNode},
errors::{ErrorKind, SloxError, SloxResult},
interpreter::{functions::Function, Environment, EnvironmentRef, Value},
resolver::ResolvedVariables,
@ -9,7 +9,7 @@ use crate::{
};
/// Evaluate an interpretable, returning its value.
pub fn evaluate(ast: &ast::ProgramNode, vars: ResolvedVariables) -> SloxResult<Value> {
pub fn evaluate(ast: &ProgramNode, vars: ResolvedVariables) -> SloxResult<Value> {
let mut state = InterpreterState::new(&vars);
ast.interpret(&mut state).map(|v| v.result())
}
@ -127,7 +127,7 @@ fn error<T>(token: &Token, message: &str) -> SloxResult<T> {
* INTERPRETER FOR PROGRAM NODES *
* ----------------------------- */
impl Interpretable for ast::ProgramNode {
impl Interpretable for ProgramNode {
fn interpret(&self, es: &mut InterpreterState) -> InterpreterResult {
for stmt in self.0.iter() {
stmt.interpret(es)?;
@ -140,39 +140,37 @@ impl Interpretable for ast::ProgramNode {
* INTERPRETER FOR STATEMENT NODES *
* ------------------------------- */
impl Interpretable for ast::StmtNode {
impl Interpretable for StmtNode {
fn interpret(&self, es: &mut InterpreterState) -> InterpreterResult {
match self {
ast::StmtNode::VarDecl(name, expr) => self.on_var_decl(es, name, expr),
ast::StmtNode::FunDecl { name, params, body } => {
self.on_fun_decl(es, name, params, body)
}
ast::StmtNode::Expression(expr) => expr.interpret(es),
ast::StmtNode::Print(expr) => self.on_print(es, expr),
ast::StmtNode::Block(statements) => self.on_block(es, statements),
ast::StmtNode::If {
StmtNode::VarDecl(name, expr) => self.on_var_decl(es, name, expr),
StmtNode::FunDecl(decl) => self.on_fun_decl(es, decl),
StmtNode::Expression(expr) => expr.interpret(es),
StmtNode::Print(expr) => self.on_print(es, expr),
StmtNode::Block(statements) => self.on_block(es, statements),
StmtNode::If {
condition,
then_branch,
else_branch,
} => self.on_if_statement(es, condition, then_branch, else_branch),
ast::StmtNode::Loop {
StmtNode::Loop {
label,
condition,
body,
after_body,
} => self.on_loop_statement(es, label, condition, body, after_body),
ast::StmtNode::LoopControl {
StmtNode::LoopControl {
is_break,
loop_name,
} => self.on_loop_control_statemement(*is_break, loop_name),
ast::StmtNode::Return { token: _, value } => self.on_return_statement(es, value),
StmtNode::Return { token: _, value } => self.on_return_statement(es, value),
}
}
}
impl ast::StmtNode {
impl StmtNode {
/// Handle the `print` statement.
fn on_print(&self, es: &mut InterpreterState, expr: &ast::ExprNode) -> InterpreterResult {
fn on_print(&self, es: &mut InterpreterState, expr: &ExprNode) -> InterpreterResult {
let value = expr.interpret(es)?.result();
let output = match value {
Value::Nil => String::from("nil"),
@ -191,7 +189,7 @@ impl ast::StmtNode {
&self,
es: &mut InterpreterState,
name: &Token,
initializer: &Option<ast::ExprNode>,
initializer: &Option<ExprNode>,
) -> InterpreterResult {
let variable = match initializer {
Some(expr) => Some(expr.interpret(es)?.result()),
@ -202,22 +200,21 @@ impl ast::StmtNode {
}
/// Handle a function declaration.
fn on_fun_decl(
&self,
es: &mut InterpreterState,
name: &Token,
params: &[Token],
body: &[ast::StmtNode],
) -> InterpreterResult {
let fun = Function::new(Some(name), params, body, es.environment.clone());
fn on_fun_decl(&self, es: &mut InterpreterState, decl: &FunDecl) -> InterpreterResult {
let fun = Function::new(
Some(&decl.name),
&decl.params,
&decl.body,
es.environment.clone(),
);
es.environment
.borrow_mut()
.define(name, Some(Value::Callable(fun)))?;
.define(&decl.name, Some(Value::Callable(fun)))?;
Ok(InterpreterFlowControl::default())
}
/// Execute the contents of a block.
fn on_block(&self, es: &mut InterpreterState, stmts: &[ast::StmtNode]) -> InterpreterResult {
fn on_block(&self, es: &mut InterpreterState, stmts: &[StmtNode]) -> InterpreterResult {
let mut child = InterpreterState::create_child(es);
for stmt in stmts.iter() {
let result = stmt.interpret(&mut child)?;
@ -232,9 +229,9 @@ impl ast::StmtNode {
fn on_if_statement(
&self,
es: &mut InterpreterState,
condition: &ast::ExprNode,
then_branch: &ast::StmtNode,
else_branch: &Option<Box<ast::StmtNode>>,
condition: &ExprNode,
then_branch: &StmtNode,
else_branch: &Option<Box<StmtNode>>,
) -> InterpreterResult {
if condition.interpret(es)?.result().is_truthy() {
then_branch.interpret(es)
@ -250,9 +247,9 @@ impl ast::StmtNode {
&self,
es: &mut InterpreterState,
label: &Option<Token>,
condition: &ast::ExprNode,
body: &ast::StmtNode,
after_body: &Option<Box<ast::StmtNode>>,
condition: &ExprNode,
body: &StmtNode,
after_body: &Option<Box<StmtNode>>,
) -> InterpreterResult {
let ln = label.as_ref().map(|token| token.lexeme.clone());
while condition.interpret(es)?.result().is_truthy() {
@ -294,7 +291,7 @@ impl ast::StmtNode {
fn on_return_statement(
&self,
es: &mut InterpreterState,
value: &Option<ast::ExprNode>,
value: &Option<ExprNode>,
) -> InterpreterResult {
let rv = match value {
None => Value::Nil,
@ -308,48 +305,51 @@ impl ast::StmtNode {
* INTERPRETER FOR EXPRESSION NODES *
* -------------------------------- */
impl Interpretable for ast::ExprNode {
impl Interpretable for ExprNode {
fn interpret(&self, es: &mut InterpreterState) -> InterpreterResult {
match self {
ast::ExprNode::Assignment { name, value, id } => {
ExprNode::Assignment { name, value, id } => {
let value = value.interpret(es)?.result();
es.assign_var(name, id, value)?;
Ok(InterpreterFlowControl::default())
}
ast::ExprNode::Logical {
ExprNode::Logical {
left,
operator,
right,
} => self.on_logic(es, left, operator, right),
ast::ExprNode::Binary {
ExprNode::Binary {
left,
operator,
right,
} => self.on_binary(es, left, operator, right),
ast::ExprNode::Unary { operator, right } => self.on_unary(es, operator, right),
ast::ExprNode::Grouping { expression } => expression.interpret(es),
ast::ExprNode::Litteral { value } => self.on_litteral(value),
ast::ExprNode::Variable { name, id } => Ok(es.lookup_var(name, id)?.into()),
ast::ExprNode::Call {
ExprNode::Unary { operator, right } => self.on_unary(es, operator, right),
ExprNode::Grouping { expression } => expression.interpret(es),
ExprNode::Litteral { value } => self.on_litteral(value),
ExprNode::Variable { name, id } => Ok(es.lookup_var(name, id)?.into()),
ExprNode::Call {
callee,
right_paren,
arguments,
} => self.on_call(es, callee, right_paren, arguments),
ast::ExprNode::Lambda { params, body } => {
Ok(Value::Callable(Function::new(None, params, body, es.environment.clone())).into())
ExprNode::Lambda { params, body } => {
Ok(
Value::Callable(Function::new(None, params, body, es.environment.clone()))
.into(),
)
}
}
}
}
impl ast::ExprNode {
impl ExprNode {
/// Evaluate a logical operator.
fn on_logic(
&self,
es: &mut InterpreterState,
left: &ast::ExprNode,
left: &ExprNode,
operator: &Token,
right: &ast::ExprNode,
right: &ExprNode,
) -> InterpreterResult {
let left_value = left.interpret(es)?.result();
if operator.token_type == TokenType::Or && left_value.is_truthy()
@ -365,9 +365,9 @@ impl ast::ExprNode {
fn on_binary(
&self,
es: &mut InterpreterState,
left: &ast::ExprNode,
left: &ExprNode,
operator: &Token,
right: &ast::ExprNode,
right: &ExprNode,
) -> InterpreterResult {
let left_value = left.interpret(es)?.result();
let right_value = right.interpret(es)?.result();
@ -437,7 +437,7 @@ impl ast::ExprNode {
&self,
es: &mut InterpreterState,
operator: &Token,
right: &ast::ExprNode,
right: &ExprNode,
) -> InterpreterResult {
let right_value = right.interpret(es)?.result();
match operator.token_type {
@ -475,9 +475,9 @@ impl ast::ExprNode {
fn on_call(
&self,
es: &mut InterpreterState,
callee: &ast::ExprNode,
callee: &ExprNode,
right_paren: &Token,
arguments: &Vec<ast::ExprNode>,
arguments: &Vec<ExprNode>,
) -> InterpreterResult {
let callee = callee.interpret(es)?.result();
let arg_values = {

View file

@ -1,7 +1,7 @@
use std::collections::HashSet;
use crate::{
ast,
ast::{ExprNode, FunDecl, ProgramNode, StmtNode},
errors::{ErrorHandler, ErrorKind, SloxError, SloxResult},
tokens::{Token, TokenType},
};
@ -76,7 +76,7 @@ impl Parser {
/// Parse the tokens into an AST and return it, or return nothing if a
/// parser error occurs.
pub fn parse(mut self, err_hdl: &mut ErrorHandler) -> SloxResult<ast::ProgramNode> {
pub fn parse(mut self, err_hdl: &mut ErrorHandler) -> SloxResult<ProgramNode> {
self.loop_state.push(LoopParsingState::None);
let result = self.parse_program(err_hdl);
self.loop_state.pop();
@ -113,8 +113,8 @@ impl Parser {
/// ```
/// program := statement*
/// ```
fn parse_program(&mut self, err_hdl: &mut ErrorHandler) -> ast::ProgramNode {
let mut stmts: Vec<ast::StmtNode> = Vec::new();
fn parse_program(&mut self, err_hdl: &mut ErrorHandler) -> ProgramNode {
let mut stmts: Vec<StmtNode> = Vec::new();
while !self.is_at_end() {
match self.parse_statement() {
Ok(node) => stmts.push(node),
@ -124,7 +124,7 @@ impl Parser {
}
}
}
ast::ProgramNode(stmts)
ProgramNode(stmts)
}
/// Parse the following rule:
@ -141,7 +141,7 @@ impl Parser {
/// statement := loop_control_statement
/// statement := return_statement
/// ```
fn parse_statement(&mut self) -> SloxResult<ast::StmtNode> {
fn parse_statement(&mut self) -> SloxResult<StmtNode> {
if self.expect(&[TokenType::Var]).is_some() {
self.parse_var_declaration()
} else if self.expect(&[TokenType::Fun]).is_some() {
@ -169,7 +169,7 @@ impl Parser {
} else if self.expect(&[TokenType::Print]).is_some() {
let expression = self.parse_expression()?;
self.consume(&TokenType::Semicolon, "expected ';' after value")?;
Ok(ast::StmtNode::Print(expression))
Ok(StmtNode::Print(expression))
} else {
self.parse_expression_stmt()
}
@ -179,10 +179,10 @@ impl Parser {
/// ```
/// expression_stmt := expression ";"
/// ```
fn parse_expression_stmt(&mut self) -> SloxResult<ast::StmtNode> {
fn parse_expression_stmt(&mut self) -> SloxResult<StmtNode> {
let expression = self.parse_expression()?;
self.consume(&TokenType::Semicolon, "expected ';' after expression")?;
Ok(ast::StmtNode::Expression(expression))
Ok(StmtNode::Expression(expression))
}
/// Parse the following rule:
@ -190,12 +190,12 @@ impl Parser {
/// var_declaration := "var" IDENTIFIER ";"
/// var_declaration := "var" IDENTIFIER "=" expression ";"
/// ```
fn parse_var_declaration(&mut self) -> SloxResult<ast::StmtNode> {
fn parse_var_declaration(&mut self) -> SloxResult<StmtNode> {
let name = match self.peek().token_type {
TokenType::Identifier(_) => self.advance().clone(),
_ => return self.error("expected variable name"),
};
let initializer: Option<ast::ExprNode> = match self.expect(&[TokenType::Equal]) {
let initializer: Option<ExprNode> = match self.expect(&[TokenType::Equal]) {
Some(_) => Some(self.parse_expression()?),
None => None,
};
@ -203,7 +203,7 @@ impl Parser {
&TokenType::Semicolon,
"expected ';' after variable declaration",
)?;
Ok(ast::StmtNode::VarDecl(name, initializer))
Ok(StmtNode::VarDecl(name, initializer))
}
/// Parse the following rule:
@ -212,18 +212,18 @@ impl Parser {
/// function := IDENTIFIER function_info
/// ```
/// The `kind` parameter is used to generate error messages.
fn parse_fun_declaration(&mut self, kind: FunctionKind) -> SloxResult<ast::StmtNode> {
fn parse_fun_declaration(&mut self, kind: FunctionKind) -> SloxResult<StmtNode> {
// Read the name
let name = match self.peek().token_type {
TokenType::Identifier(_) => self.advance().clone(),
_ => return self.error_mv(format!("expected {} name", kind.name())),
};
let (params, block) = self.parse_function_info(kind)?;
Ok(ast::StmtNode::FunDecl {
Ok(StmtNode::FunDecl(FunDecl {
name,
params,
body: block,
})
}))
}
/// Parse the following rules:
@ -234,7 +234,7 @@ impl Parser {
fn parse_function_info(
&mut self,
kind: FunctionKind,
) -> SloxResult<(Vec<Token>, Vec<ast::StmtNode>)> {
) -> SloxResult<(Vec<Token>, Vec<StmtNode>)> {
// Read the list of parameter names
self.consume(
&TokenType::LeftParen,
@ -286,13 +286,13 @@ impl Parser {
/// ```
/// block := "{" statement* "}"
/// ```
fn parse_block(&mut self) -> SloxResult<ast::StmtNode> {
let mut stmts: Vec<ast::StmtNode> = Vec::new();
fn parse_block(&mut self) -> SloxResult<StmtNode> {
let mut stmts: Vec<StmtNode> = Vec::new();
while !(self.check(&TokenType::RightBrace) || self.is_at_end()) {
stmts.push(self.parse_statement()?);
}
self.consume(&TokenType::RightBrace, "expected '}' after block.")?;
Ok(ast::StmtNode::Block(stmts))
Ok(StmtNode::Block(stmts))
}
/// Parse the following rule:
@ -300,7 +300,7 @@ impl Parser {
/// if_statement := "if" "(" expression ")" statement
/// if_statement := "if" "(" expression ")" statement "else" statement
/// ```
fn parse_if_statement(&mut self) -> SloxResult<ast::StmtNode> {
fn parse_if_statement(&mut self) -> SloxResult<StmtNode> {
self.consume(&TokenType::LeftParen, "expected '(' after 'if'")?;
let expression = self.parse_expression()?;
self.consume(
@ -312,7 +312,7 @@ impl Parser {
Some(_) => Some(Box::new(self.parse_statement()?)),
None => None,
};
Ok(ast::StmtNode::If {
Ok(StmtNode::If {
condition: expression,
then_branch,
else_branch,
@ -324,7 +324,7 @@ impl Parser {
/// labelled_loop := "@" IDENTIFIER while_statement
/// labelled_loop := "@" IDENTIFIER for_statement
/// ```
fn parse_labelled_loop(&mut self) -> SloxResult<ast::StmtNode> {
fn parse_labelled_loop(&mut self) -> SloxResult<StmtNode> {
let name_token = match self.peek().token_type {
TokenType::Identifier(_) => self.advance().clone(),
_ => return self.error("identifier expected after '@'"),
@ -343,7 +343,7 @@ impl Parser {
/// ```
/// while_statement := "while" "(" expression ")" statement
/// ```
fn parse_while_statement(&mut self, label: Option<Token>) -> SloxResult<ast::StmtNode> {
fn parse_while_statement(&mut self, label: Option<Token>) -> SloxResult<StmtNode> {
self.consume(&TokenType::LeftParen, "expected '(' after 'while'")?;
let condition = self.parse_expression()?;
self.consume(
@ -356,7 +356,7 @@ impl Parser {
self.loop_state.pop();
result?
});
Ok(ast::StmtNode::Loop {
Ok(StmtNode::Loop {
label,
condition,
body,
@ -371,7 +371,7 @@ impl Parser {
/// for_initializer := expression
/// for_initializer :=
/// ```
fn parse_for_statement(&mut self, label: Option<Token>) -> SloxResult<ast::StmtNode> {
fn parse_for_statement(&mut self, label: Option<Token>) -> SloxResult<StmtNode> {
self.consume(&TokenType::LeftParen, "expected '(' after 'for'")?;
let initializer = if self.expect(&[TokenType::Semicolon]).is_some() {
@ -383,7 +383,7 @@ impl Parser {
};
let condition = if self.check(&TokenType::Semicolon) {
ast::ExprNode::Litteral {
ExprNode::Litteral {
value: Token {
token_type: TokenType::True,
lexeme: String::from("true"),
@ -416,14 +416,14 @@ impl Parser {
self.loop_state.pop();
result?
};
let while_stmt = ast::StmtNode::Loop {
let while_stmt = StmtNode::Loop {
label,
condition,
body: Box::new(body_stmt),
after_body: increment.map(|incr| Box::new(ast::StmtNode::Expression(incr))),
after_body: increment.map(|incr| Box::new(StmtNode::Expression(incr))),
};
if let Some(init_stmt) = initializer {
Ok(ast::StmtNode::Block(vec![init_stmt, while_stmt]))
Ok(StmtNode::Block(vec![init_stmt, while_stmt]))
} else {
Ok(while_stmt)
}
@ -434,7 +434,7 @@ impl Parser {
/// loop_control_statement := "break" ( IDENTIFIER )? ";"
/// loop_control_statement := "continue" ( IDENTIFIER )? ";"
/// ```
fn parse_loop_control_statement(&mut self, stmt_token: &Token) -> SloxResult<ast::StmtNode> {
fn parse_loop_control_statement(&mut self, stmt_token: &Token) -> SloxResult<StmtNode> {
if self.loop_state() == &LoopParsingState::None {
return Err(SloxError::with_token(
ErrorKind::Parse,
@ -465,7 +465,7 @@ impl Parser {
&TokenType::Semicolon,
"';' expected after loop control statement",
)?;
Ok(ast::StmtNode::LoopControl {
Ok(StmtNode::LoopControl {
is_break: stmt_token.token_type == TokenType::Break,
loop_name,
})
@ -475,7 +475,7 @@ impl Parser {
/// ```
/// return_statement := "return" expression? ";"
/// ```
fn parse_return_statement(&mut self, ret_token: &Token) -> SloxResult<ast::StmtNode> {
fn parse_return_statement(&mut self, ret_token: &Token) -> SloxResult<StmtNode> {
if self.can_use_return() {
let value = if self.check(&TokenType::Semicolon) {
None
@ -483,7 +483,7 @@ impl Parser {
Some(self.parse_expression()?)
};
self.consume(&TokenType::Semicolon, "';' expected after return statement")?;
Ok(ast::StmtNode::Return {
Ok(StmtNode::Return {
token: ret_token.clone(),
value,
})
@ -500,7 +500,7 @@ impl Parser {
/// ```
/// expression := assignment
/// ```
fn parse_expression(&mut self) -> SloxResult<ast::ExprNode> {
fn parse_expression(&mut self) -> SloxResult<ExprNode> {
self.parse_assignment()
}
@ -509,12 +509,12 @@ impl Parser {
/// assignment := IDENTIFIER "=" equality
/// assignment := equality
/// ```
fn parse_assignment(&mut self) -> SloxResult<ast::ExprNode> {
fn parse_assignment(&mut self) -> SloxResult<ExprNode> {
let expr = self.parse_logic_or()?;
if let Some(equals) = self.expect(&[TokenType::Equal]) {
let value = self.parse_assignment()?;
if let ast::ExprNode::Variable { name, id: _ } = expr {
Ok(ast::ExprNode::Assignment {
if let ExprNode::Variable { name, id: _ } = expr {
Ok(ExprNode::Assignment {
name,
value: Box::new(value),
id: self.make_id(),
@ -535,11 +535,11 @@ impl Parser {
/// ```
/// logic_or := logic_and ( "or" logic_and )*
/// ```
fn parse_logic_or(&mut self) -> SloxResult<ast::ExprNode> {
fn parse_logic_or(&mut self) -> SloxResult<ExprNode> {
let mut expr = self.parse_logic_and()?;
while let Some(operator) = self.expect(&[TokenType::Or]) {
let right = self.parse_logic_and()?;
expr = ast::ExprNode::Logical {
expr = ExprNode::Logical {
left: Box::new(expr),
operator: operator.clone(),
right: Box::new(right),
@ -552,11 +552,11 @@ impl Parser {
/// ```
/// logic_and := equality ( "and" equality )*
/// ```
fn parse_logic_and(&mut self) -> SloxResult<ast::ExprNode> {
fn parse_logic_and(&mut self) -> SloxResult<ExprNode> {
let mut expr = self.parse_equality()?;
while let Some(operator) = self.expect(&[TokenType::And]) {
let right = self.parse_equality()?;
expr = ast::ExprNode::Logical {
expr = ExprNode::Logical {
left: Box::new(expr),
operator: operator.clone(),
right: Box::new(right),
@ -570,11 +570,11 @@ impl Parser {
/// equality := comparison "==" comparison
/// equality := comparison "!=" comparison
/// ```
fn parse_equality(&mut self) -> SloxResult<ast::ExprNode> {
fn parse_equality(&mut self) -> SloxResult<ExprNode> {
let mut expr = self.parse_comparison()?;
while let Some(operator) = self.expect(&[TokenType::BangEqual, TokenType::EqualEqual]) {
let right = self.parse_comparison()?;
expr = ast::ExprNode::Binary {
expr = ExprNode::Binary {
left: Box::new(expr),
operator: operator.clone(),
right: Box::new(right),
@ -588,7 +588,7 @@ impl Parser {
/// comparison := term comparison_operator term
/// comparison_operator := "<" | "<=" | ">" | ">="
/// ```
fn parse_comparison(&mut self) -> SloxResult<ast::ExprNode> {
fn parse_comparison(&mut self) -> SloxResult<ExprNode> {
let mut expr = self.parse_term()?;
while let Some(operator) = self.expect(&[
TokenType::Greater,
@ -597,7 +597,7 @@ impl Parser {
TokenType::LessEqual,
]) {
let right = self.parse_term()?;
expr = ast::ExprNode::Binary {
expr = ExprNode::Binary {
left: Box::new(expr),
operator: operator.clone(),
right: Box::new(right),
@ -611,11 +611,11 @@ impl Parser {
/// term := factor ( "+" factor )*
/// term := factor ( "-" factor )*
/// ```
fn parse_term(&mut self) -> SloxResult<ast::ExprNode> {
fn parse_term(&mut self) -> SloxResult<ExprNode> {
let mut expr = self.parse_factor()?;
while let Some(operator) = self.expect(&[TokenType::Minus, TokenType::Plus]) {
let right = self.parse_factor()?;
expr = ast::ExprNode::Binary {
expr = ExprNode::Binary {
left: Box::new(expr),
operator: operator.clone(),
right: Box::new(right),
@ -629,11 +629,11 @@ impl Parser {
/// factor := unary ( "*" unary )*
/// factor := unary ( "/" unary )*
/// ```
fn parse_factor(&mut self) -> SloxResult<ast::ExprNode> {
fn parse_factor(&mut self) -> SloxResult<ExprNode> {
let mut expr = self.parse_unary()?;
while let Some(operator) = self.expect(&[TokenType::Slash, TokenType::Star]) {
let right = self.parse_unary()?;
expr = ast::ExprNode::Binary {
expr = ExprNode::Binary {
left: Box::new(expr),
operator: operator.clone(),
right: Box::new(right),
@ -648,9 +648,9 @@ impl Parser {
/// unary := "!" unary
/// unary := primary call_arguments*
/// ```
fn parse_unary(&mut self) -> SloxResult<ast::ExprNode> {
fn parse_unary(&mut self) -> SloxResult<ExprNode> {
if let Some(operator) = self.expect(&[TokenType::Bang, TokenType::Minus]) {
Ok(ast::ExprNode::Unary {
Ok(ExprNode::Unary {
operator,
right: Box::new(self.parse_unary()?),
})
@ -670,26 +670,26 @@ impl Parser {
/// primary := IDENTIFIER
/// primary := "fun" function_info
/// ```
fn parse_primary(&mut self) -> SloxResult<ast::ExprNode> {
fn parse_primary(&mut self) -> SloxResult<ExprNode> {
if self.expect(&[TokenType::LeftParen]).is_some() {
let expr = self.parse_expression()?;
self.consume(&TokenType::RightParen, "expected ')' after expression")?;
Ok(ast::ExprNode::Grouping {
Ok(ExprNode::Grouping {
expression: Box::new(expr),
})
} else if self.expect(&[TokenType::Fun]).is_some() {
let (params, body) = self.parse_function_info(FunctionKind::Lambda)?;
Ok(ast::ExprNode::Lambda { params, body })
Ok(ExprNode::Lambda { params, body })
} else if let Some(token) =
self.expect(&[TokenType::False, TokenType::True, TokenType::Nil])
{
Ok(ast::ExprNode::Litteral { value: token })
Ok(ExprNode::Litteral { value: token })
} else {
match &self.peek().token_type {
TokenType::Number(_) | &TokenType::String(_) => Ok(ast::ExprNode::Litteral {
TokenType::Number(_) | &TokenType::String(_) => Ok(ExprNode::Litteral {
value: self.advance().clone(),
}),
TokenType::Identifier(_) => Ok(ast::ExprNode::Variable {
TokenType::Identifier(_) => Ok(ExprNode::Variable {
name: self.advance().clone(),
id: self.make_id(),
}),
@ -703,7 +703,7 @@ impl Parser {
/// call := expression "(" arguments? ")"
/// arguments := expression ( "," expression )*
/// ```
fn parse_call_arguments(&mut self, callee: ast::ExprNode) -> Result<ast::ExprNode, SloxError> {
fn parse_call_arguments(&mut self, callee: ExprNode) -> Result<ExprNode, SloxError> {
let mut arguments = Vec::new();
if !self.check(&TokenType::RightParen) {
loop {
@ -719,7 +719,7 @@ impl Parser {
let right_paren = self
.consume(&TokenType::RightParen, "')' expected after arguments")?
.clone();
Ok(ast::ExprNode::Call {
Ok(ExprNode::Call {
callee: Box::new(callee),
right_paren,
arguments,

View file

@ -1,7 +1,7 @@
use std::collections::HashMap;
use crate::{
ast,
ast::{ExprNode, ProgramNode, StmtNode},
errors::{ErrorKind, SloxError, SloxResult},
tokens::Token,
};
@ -12,7 +12,7 @@ use crate::{
pub type ResolvedVariables = HashMap<usize, usize>;
/// Resolve all variables in a program's AST.
pub fn resolve_variables(program: &ast::ProgramNode) -> SloxResult<ResolvedVariables> {
pub fn resolve_variables(program: &ProgramNode) -> SloxResult<ResolvedVariables> {
let mut state = ResolverState::default();
state
.with_scope(|rs| program.resolve(rs))
@ -193,7 +193,7 @@ impl<'a> ResolverState<'a> {
fn resolve_function<'a, 'b>(
rs: &mut ResolverState<'a>,
params: &'b [Token],
body: &'b Vec<ast::StmtNode>,
body: &'b Vec<StmtNode>,
) -> ResolverResult
where
'b: 'a,
@ -215,7 +215,7 @@ trait VarResolver {
'a: 'b;
}
impl VarResolver for ast::ProgramNode {
impl VarResolver for ProgramNode {
fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult
where
'a: 'b,
@ -224,7 +224,7 @@ impl VarResolver for ast::ProgramNode {
}
}
impl VarResolver for Vec<ast::StmtNode> {
impl VarResolver for Vec<StmtNode> {
fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult
where
'a: 'b,
@ -236,37 +236,37 @@ impl VarResolver for Vec<ast::StmtNode> {
}
}
impl VarResolver for ast::StmtNode {
impl VarResolver for StmtNode {
fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult
where
'a: 'b,
{
match self {
ast::StmtNode::Block(stmts) => rs.with_scope(|rs| stmts.resolve(rs)),
StmtNode::Block(stmts) => rs.with_scope(|rs| stmts.resolve(rs)),
ast::StmtNode::VarDecl(name, None) => {
StmtNode::VarDecl(name, None) => {
rs.declare(name, SymKind::Variable)?;
Ok(())
}
ast::StmtNode::VarDecl(name, Some(init)) => {
StmtNode::VarDecl(name, Some(init)) => {
rs.declare(name, SymKind::Variable)?;
init.resolve(rs)?;
rs.define(name);
Ok(())
}
ast::StmtNode::FunDecl { name, params, body } => {
rs.declare(name, SymKind::Function)?;
rs.define(name);
rs.with_scope(|rs| resolve_function(rs, params, body))
StmtNode::FunDecl(decl) => {
rs.declare(&decl.name, SymKind::Function)?;
rs.define(&decl.name);
rs.with_scope(|rs| resolve_function(rs, &decl.params, &decl.body))
}
ast::StmtNode::If {
StmtNode::If {
condition,
then_branch,
else_branch: None,
} => condition.resolve(rs).and_then(|_| then_branch.resolve(rs)),
ast::StmtNode::If {
StmtNode::If {
condition,
then_branch,
else_branch: Some(else_branch),
@ -275,7 +275,7 @@ impl VarResolver for ast::StmtNode {
.and_then(|_| then_branch.resolve(rs))
.and_then(|_| else_branch.resolve(rs)),
ast::StmtNode::Loop {
StmtNode::Loop {
label: _,
condition,
body,
@ -291,18 +291,18 @@ impl VarResolver for ast::StmtNode {
})
.and_then(|_| body.resolve(rs)),
ast::StmtNode::Return {
StmtNode::Return {
token: _,
value: None,
} => Ok(()),
ast::StmtNode::Return {
StmtNode::Return {
token: _,
value: Some(expr),
} => expr.resolve(rs),
ast::StmtNode::Expression(expr) => expr.resolve(rs),
ast::StmtNode::Print(expr) => expr.resolve(rs),
ast::StmtNode::LoopControl {
StmtNode::Expression(expr) => expr.resolve(rs),
StmtNode::Print(expr) => expr.resolve(rs),
StmtNode::LoopControl {
is_break: _,
loop_name: _,
} => Ok(()),
@ -310,37 +310,37 @@ impl VarResolver for ast::StmtNode {
}
}
impl VarResolver for ast::ExprNode {
impl VarResolver for ExprNode {
fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult
where
'a: 'b,
{
match self {
ast::ExprNode::Variable { name, id } => rs.resolve_use(id, name),
ExprNode::Variable { name, id } => rs.resolve_use(id, name),
ast::ExprNode::Assignment { name, value, id } => {
ExprNode::Assignment { name, value, id } => {
value.resolve(rs)?;
rs.resolve_assignment(id, name)
}
ast::ExprNode::Lambda { params, body } => {
ExprNode::Lambda { params, body } => {
rs.with_scope(|rs| resolve_function(rs, params, body))
}
ast::ExprNode::Logical {
ExprNode::Logical {
left,
operator: _,
right,
} => left.resolve(rs).and_then(|_| right.resolve(rs)),
ast::ExprNode::Binary {
ExprNode::Binary {
left,
operator: _,
right,
} => left.resolve(rs).and_then(|_| right.resolve(rs)),
ast::ExprNode::Unary { operator: _, right } => right.resolve(rs),
ast::ExprNode::Grouping { expression } => expression.resolve(rs),
ast::ExprNode::Litteral { value: _ } => Ok(()),
ast::ExprNode::Call {
ExprNode::Unary { operator: _, right } => right.resolve(rs),
ExprNode::Grouping { expression } => expression.resolve(rs),
ExprNode::Litteral { value: _ } => Ok(()),
ExprNode::Call {
callee,
right_paren: _,
arguments,
@ -349,7 +349,7 @@ impl VarResolver for ast::ExprNode {
}
}
impl VarResolver for Vec<ast::ExprNode> {
impl VarResolver for Vec<ExprNode> {
fn resolve<'a, 'b>(&'a self, rs: &mut ResolverState<'b>) -> ResolverResult
where
'a: 'b,