diff --git a/c-opopt.cc b/c-opopt.cc index b1c9bb8..ef8198c 100644 --- a/c-opopt.cc +++ b/c-opopt.cc @@ -35,16 +35,33 @@ struct T_ConstantFolder_ bool operator()( A_Node& node , bool exit ) noexcept; private: - using F_ExprGet_ = std::function< A_ExpressionNode&( A_Node& ) >; - using F_ExprSet_ = std::function< void( A_Node& , P_ExpressionNode ) >; - - void handleParentNode( + template< + typename T + > void handleParentNode( A_Node& node , - F_ExprGet_ get , - F_ExprSet_ set ) noexcept; + std::function< A_ExpressionNode&( T& ) > get , + std::function< void( T& , P_ExpressionNode ) > set ) noexcept; P_ExpressionNode checkExpression( - A_ExpressionNode const& node ) noexcept; + A_ExpressionNode& node ) noexcept; + + // Handle identifiers. If the size is fixed and the identifier is + // either width or height, replace it with the appropriate value. + P_ExpressionNode doIdExpr( + T_IdentifierExprNode& node ) noexcept; + + // Transform an unary operator applied to a constant into a constant. + P_ExpressionNode doUnaryOp( + A_Node& parent , + T_UnaryOperatorNode::E_Operator op , + double value ) const noexcept; + + // Transform a binary operator applied to a constant into a constant. + P_ExpressionNode doBinaryOp( + A_Node& parent , + T_BinaryOperatorNode::E_Operator op , + double left , + double right ) const noexcept; }; /*----------------------------------------------------------------------------*/ @@ -60,33 +77,24 @@ bool T_ConstantFolder_::operator()( switch ( node.type( ) ) { case A_Node::TN_ARG: - handleParentNode( node , - []( A_Node& n ) -> A_ExpressionNode& { - return ((T_ArgumentNode&)n).expression( ); - } , - []( A_Node& n , P_ExpressionNode e ) { - ((T_ArgumentNode&)n).expression( std::move( e ) ); - } ); + handleParentNode< T_ArgumentNode >( + node , + []( auto& n ) -> A_ExpressionNode& { return n.expression( ); } , + []( auto& n , P_ExpressionNode e ) { n.expression( std::move( e ) ); } + ); return false; case A_Node::TN_CONDITION: - handleParentNode( node , - []( A_Node& n ) -> A_ExpressionNode& { - return ((T_CondInstrNode::T_Expression&)n).expression( ); - } , - []( A_Node& n , P_ExpressionNode e ) { - ((T_CondInstrNode::T_Expression&)n).expression( std::move( e ) ); - } ); + handleParentNode< T_CondInstrNode::T_Expression >( node , + []( auto& n ) -> A_ExpressionNode& { return n.expression( ); } , + []( auto& n , P_ExpressionNode e ) { n.expression( std::move( e ) ); } + ); return false; case A_Node::OP_SET: - handleParentNode( node , - []( A_Node& n ) -> A_ExpressionNode& { - return ((T_SetInstrNode&)n).expression( ); - } , - []( A_Node& n , P_ExpressionNode e ) { - ((T_SetInstrNode&)n).setExpression( std::move( e ) ); - } ); + handleParentNode< T_SetInstrNode >( node , + []( auto& n ) -> A_ExpressionNode& { return n.expression( ); } , + []( auto& n , P_ExpressionNode e ) { n.setExpression( std::move( e ) ); } ); return false; default: @@ -94,56 +102,185 @@ bool T_ConstantFolder_::operator()( } } -void T_ConstantFolder_::handleParentNode( - A_Node& node , - F_ExprGet_ get , - F_ExprSet_ set ) noexcept +/*----------------------------------------------------------------------------*/ + +template< + typename T +> void T_ConstantFolder_::handleParentNode( + A_Node& n , + std::function< A_ExpressionNode&( T& ) > get , + std::function< void( T& , P_ExpressionNode ) > set ) noexcept { + auto& node{ (T&) n }; auto r{ checkExpression( get( node ) ) }; if ( r ) { + r->location( ) = node.location( ); set( node , std::move( r ) ); didFold = true; } } -/*----------------------------------------------------------------------------*/ - P_ExpressionNode T_ConstantFolder_::checkExpression( - A_ExpressionNode const& node ) noexcept + A_ExpressionNode& node ) noexcept { -#warning TODO optimize the fuck - // 1/ Replace inputs with value if no curve/constant curve - // Replace $width/$height with value if fixedSize - // 2/ Replace UnOp( Cnst ) with result - // Replace BinOp( Cnst , Cnst ) with result - // 3/ Try to find other optimisations, e.g. for Add( Cnst , Add( Cnst , Expr ) ) + // Already a constant + if ( node.type( ) == A_Node::EXPR_CONST ) { + return {}; + } + // Replace $width/$height with value if fixedSize if ( node.type( ) == A_Node::EXPR_ID ) { - if ( !fixedSize ) { - return {}; - } + return doIdExpr( (T_IdentifierExprNode&) node ); + } - T_IdentifierExprNode& n{ (T_IdentifierExprNode&) node }; - if ( n.id( ) == "width" ) { - auto c{ NewOwned< T_ConstantExprNode >( n.parent( ) , - double( fixedSize->first ) ) }; - c->location( ) = n.location( ); - return c; - } - if ( n.id( ) == "height" ) { - auto c{ NewOwned< T_ConstantExprNode >( n.parent( ) , - float( fixedSize->second ) ) }; - c->location( ) = n.location( ); - return c; + // Replace inputs with value if no curve/constant curve + if ( node.type( ) == A_Node::EXPR_INPUT ) { + // TODO: may be replaced with either a constant or a variable. + // * If the curve exists and describes a constant, it's a + // constant + // * If there is no curve and only one default value in the + // whole program then it's also a constant + // * No curve, multiple defaults -> variable + return {}; + } + + // Replace UnOp( Cnst ) with result + auto* const asUnary{ dynamic_cast< T_UnaryOperatorNode* >( &node ) }; + if ( asUnary ) { + handleParentNode< T_UnaryOperatorNode >( *asUnary , + []( auto& n ) -> A_ExpressionNode& { return n.argument( ); } , + []( auto& n , P_ExpressionNode e ) { n.setArgument( std::move( e ) ); } ); + if ( asUnary->argument( ).type( ) == A_Node::EXPR_CONST ) { + auto const& cn{ (T_ConstantExprNode const&) asUnary->argument( ) }; + return doUnaryOp( asUnary->parent( ) , asUnary->op( ) , + cn.floatValue( ) ); } return {}; } + // Replace BinOp( Cnst , Cnst ) with result + auto* const asBinary{ dynamic_cast< T_BinaryOperatorNode* >( &node ) }; + assert( asBinary && "Missing support for some expr subtype" ); + handleParentNode< T_BinaryOperatorNode >( *asBinary , + []( auto& n ) -> A_ExpressionNode& { return n.left( ); } , + []( auto& n , P_ExpressionNode e ) { n.setLeft( std::move( e ) ); } ); + handleParentNode< T_BinaryOperatorNode >( *asBinary , + []( auto& n ) -> A_ExpressionNode& { return n.right( ); } , + []( auto& n , P_ExpressionNode e ) { n.setRight( std::move( e ) ); } ); + + if ( asBinary->left( ).type( ) == A_Node::EXPR_CONST + && asBinary->right( ).type( ) == A_Node::EXPR_CONST ) { + auto const& l{ (T_ConstantExprNode const&) asBinary->left( ) }; + auto const& r{ (T_ConstantExprNode const&) asBinary->right( ) }; + return doBinaryOp( asBinary->parent( ) , asBinary->op( ), + l.floatValue( ) , r.floatValue( ) ); + } + return {}; +} + +/*----------------------------------------------------------------------------*/ + +P_ExpressionNode T_ConstantFolder_::doIdExpr( + T_IdentifierExprNode& node ) noexcept +{ + if ( !fixedSize ) { + return {}; + } + + if ( node.id( ) == "width" ) { + return NewOwned< T_ConstantExprNode >( node.parent( ) , + double( fixedSize->first ) ); + } + + if ( node.id( ) == "height" ) { + return NewOwned< T_ConstantExprNode >( node.parent( ) , + float( fixedSize->second ) ); + } + return {}; } +P_ExpressionNode T_ConstantFolder_::doUnaryOp( + A_Node& parent , + const T_UnaryOperatorNode::E_Operator op , + const double value ) const noexcept +{ + const double rVal{ []( const auto op , const auto value ) { + switch ( op ) { + case T_UnaryOperatorNode::NEG: + return -value; + case T_UnaryOperatorNode::NOT: + return value ? 0. : 1.; + case T_UnaryOperatorNode::INV: + // TODO check if 0 + return 1. / value; + case T_UnaryOperatorNode::COS: + return cos( value ); + case T_UnaryOperatorNode::SIN: + return sin( value ); + case T_UnaryOperatorNode::TAN: + // TODO check if valid + return tan( value ); + case T_UnaryOperatorNode::SQRT: + // TODO check if >= 0 + return sqrt( value ); + case T_UnaryOperatorNode::LN: + // TODO check if > 0 + return log( value ); + case T_UnaryOperatorNode::EXP: + return exp( value ); + } + fprintf( stderr , "invalid operator %d\n" , int( op ) ); + std::abort( ); + }( op , value ) }; + + return NewOwned< T_ConstantExprNode >( parent , rVal ); +} + +P_ExpressionNode T_ConstantFolder_::doBinaryOp( + A_Node& parent , + const T_BinaryOperatorNode::E_Operator op , + const double left , + const double right ) const noexcept +{ + const double rVal{ []( const auto op , const auto l , const auto r ) { + switch ( op ) { + case T_BinaryOperatorNode::ADD: + return l + r; + case T_BinaryOperatorNode::SUB: + return l - r; + case T_BinaryOperatorNode::MUL: + return l * r; + case T_BinaryOperatorNode::DIV: + // TODO: check r != 0 + return l / r; + case T_BinaryOperatorNode::POW: + // TODO check operands + return pow( l , r ); + case T_BinaryOperatorNode::CMP_EQ: + return ( l == r ) ? 1. : 0.; + case T_BinaryOperatorNode::CMP_NE: + return ( l != r ) ? 1. : 0.; + case T_BinaryOperatorNode::CMP_GT: + return ( l > r ) ? 1. : 0.; + case T_BinaryOperatorNode::CMP_GE: + return ( l >= r ) ? 1. : 0.; + case T_BinaryOperatorNode::CMP_LT: + return ( l < r ) ? 1. : 0.; + case T_BinaryOperatorNode::CMP_LE: + return ( l <= r ) ? 1. : 0.; + } + fprintf( stderr , "invalid operator %d\n" , int( op ) ); + std::abort( ); + }( op , left , right ) }; + + return NewOwned< T_ConstantExprNode >( parent , rVal ); +} + } // namespace +/*----------------------------------------------------------------------------*/ + bool opopt::FoldConstants( T_RootNode& root ,