diff --git a/opast.cc b/opast.cc index 08a95b0..462e48f 100644 --- a/opast.cc +++ b/opast.cc @@ -135,7 +135,7 @@ A_Node* opast::ASTVisitorBrowser( break; } - // Conditional instruction + // Conditional instruction & associated technical nodes case A_Node::OP_COND: { auto& n( (T_CondInstrNode&) node ); @@ -158,6 +158,22 @@ A_Node* opast::ASTVisitorBrowser( break; } + case A_Node::TN_CONDITION: + if ( child == 0 ) { + return &( (T_CondInstrNode::T_Expression&) node ).expression( ); + } + break; + case A_Node::TN_CASE: + if ( child == 0 ) { + return &( (T_CondInstrNode::T_ValuedCase&) node ).instructions( ); + } + break; + case A_Node::TN_DEFAULT: + if ( child == 0 ) { + return &( (T_CondInstrNode::T_DefaultCase&) node ).instructions( ); + } + break; + // Set instruction case A_Node::OP_SET: if ( child == 0 ) { @@ -576,6 +592,7 @@ T_Optional< T_SRDLocation > T_SamplerInstrNode::setLOD( return {}; } + /*= T_LocalsInstrNode ===========================================================*/ T_Optional< T_SRDLocation > T_LocalsInstrNode::addVariable( @@ -589,3 +606,39 @@ T_Optional< T_SRDLocation > T_LocalsInstrNode::addVariable( varLocs_.add( token.location( ) ); return {}; } + + +/*= T_CondInstrNode ============================================================*/ + +void T_CondInstrNode::setExpression( + P_ExpressionNode expression ) noexcept +{ + if ( !expression ) { + expression_.clear( ); + return; + } + expression_ = NewOwned< T_Expression >( + *this , std::move( expression ) ); +} + +void T_CondInstrNode::setCase( + const int64_t value , + P_InstrListNode instrList ) noexcept +{ + cases_.remove( value ); + if ( instrList ) { + cases_.add( NewOwned< T_ValuedCase >( *this , value , + std::move( instrList ) ) ); + } +} + +void T_CondInstrNode::setDefaultCase( + P_InstrListNode defaultCase ) noexcept +{ + if ( !defaultCase ) { + defaultCase_.clear( ); + return; + } + defaultCase_ = NewOwned< T_DefaultCase >( *this , + std::move( defaultCase ) ); +} diff --git a/opast.hh b/opast.hh index 4921f78..4c2320a 100644 --- a/opast.hh +++ b/opast.hh @@ -62,6 +62,10 @@ class A_Node EXPR_ID , // Variable access EXPR_INPUT , // Input value access EXPR_CONST , // Numeric constant + // Technical nodes + TN_CONDITION , // Expression for a conditional instruction + TN_CASE , // Valued case for a conditional instruction + TN_DEFAULT , // Default case for a conditional instruction }; private: @@ -403,46 +407,101 @@ class T_CallInstrNode : public A_InstructionNode // Conditional instruction class T_CondInstrNode : public A_InstructionNode { + public: + class T_Expression : public A_Node + { + private: + P_ExpressionNode expression_; + + public: + T_Expression( T_CondInstrNode& parent , + P_ExpressionNode expr ) + : A_Node( TN_CONDITION , &parent ) , + expression_( std::move( expr ) ) + { } + + A_ExpressionNode& expression( ) const noexcept + { return *expression_; } + }; + + class T_ValuedCase : public A_Node + { + private: + int64_t value_; + P_InstrListNode instructions_; + + public: + T_ValuedCase( T_CondInstrNode& parent , + const int64_t value , + P_InstrListNode il ) + : A_Node( TN_CASE , &parent ) , + value_( value ) , + instructions_( std::move( il ) ) + { } + + int64_t value( ) const noexcept + { return value_; } + T_InstrListNode& instructions( ) const noexcept + { return *instructions_; } + }; + + class T_DefaultCase : public A_Node + { + private: + P_InstrListNode instructions_; + + public: + T_DefaultCase( T_CondInstrNode& parent , + P_InstrListNode il ) + : A_Node( TN_DEFAULT , &parent ) , + instructions_( std::move( il ) ) + { } + + T_InstrListNode& instructions( ) const noexcept + { return *instructions_; } + }; + private: - P_ExpressionNode expression_; - T_KeyValueTable< int64_t , P_InstrListNode > cases_; - P_InstrListNode defaultCase_; + T_OwnPtr< T_Expression > expression_; + T_ObjectTable< int64_t , T_OwnPtr< T_ValuedCase > > cases_{ + []( T_OwnPtr< T_ValuedCase > const& c ) -> int64_t { + return c->value( ); + } + }; + T_OwnPtr< T_DefaultCase > defaultCase_; public: explicit T_CondInstrNode( T_InstrListNode& parent ) noexcept : A_InstructionNode( OP_COND , parent ) { } - void setExpression( P_ExpressionNode expression ) noexcept - { expression_ = std::move( expression ); } + void setExpression( P_ExpressionNode expression ) noexcept; bool hasExpression( ) const noexcept { return bool( expression_ ); } - A_ExpressionNode& expression( ) const noexcept + T_Expression& expression( ) const noexcept { return *expression_; } void setCase( const int64_t value , - P_InstrListNode instrList ) noexcept - { - if ( instrList ) { - cases_.set( value , std::move( instrList ) ); - } - } - + P_InstrListNode instrList ) noexcept; void rmCase( const int64_t value ) noexcept { cases_.remove( value ); } - T_Array< int64_t > const& cases( ) const noexcept + uint32_t nCases( ) const noexcept + { return cases_.size( ); } + T_Array< int64_t > cases( ) const noexcept { return cases_.keys( ); } bool hasCase( const int64_t value ) const noexcept { return cases_.contains( value ); } - T_InstrListNode& getCase( const int64_t value ) const noexcept + T_ValuedCase& getCase( const int64_t value ) const noexcept { return **cases_.get( value ); } + T_ValuedCase& getCaseByIndex( + const uint32_t index ) const noexcept + { return *cases_[ index ]; } - void setDefaultCase( P_InstrListNode defaultCase ) noexcept - { defaultCase_ = std::move( defaultCase ); } + void setDefaultCase( P_InstrListNode defaultCase ) noexcept; bool hasDefaultCase( ) const noexcept { return bool( defaultCase_ ); } - T_InstrListNode& defaultCase( ) const noexcept + T_DefaultCase& defaultCase( ) const noexcept { return *defaultCase_; } }; diff --git a/opparser.cc b/opparser.cc index dd7a8b3..f22c790 100644 --- a/opparser.cc +++ b/opparser.cc @@ -1185,6 +1185,9 @@ M_INSTR_( If ) T_CondInstrNode& cond{ instructions.add< T_CondInstrNode >( ) }; cond.location( ) = input[ 0 ].location( ); cond.setExpression( parseExpression( cond , input[ 1 ] ) ); + if ( cond.hasExpression( ) ) { + cond.expression( ).location( ) = input[ 1 ].location( ); + } if ( input.size( ) == 2 ) { errors.addNew( "'then' block expected" , @@ -1193,8 +1196,15 @@ M_INSTR_( If ) } cond.setCase( 1 , parseBlock( cond , input[ 2 ] ) ); + if ( cond.hasCase( 1 ) ) { + cond.getCase( 1 ).location( ) = input[ 2 ].location( ); + } + if ( input.size( ) > 3 ) { cond.setDefaultCase( parseBlock( cond , input[ 3 ] ) ); + if ( cond.hasDefaultCase( ) ) { + cond.defaultCase( ).location( ) = input[ 3 ].location( ); + } if ( input.size( ) > 4 ) { errors.addNew( "too many arguments" , input[ 4 ].location( ) ); }