Optimizer - Errors when folding bad operations

Errors will be added to an array if the optimizer finds invalid
operations involving constants (e.g. (inv 0))
This commit is contained in:
Emmanuel BENOîT 2017-12-02 10:07:14 +01:00
parent e5a7d2e722
commit f1b51f564d
2 changed files with 120 additions and 98 deletions

View file

@ -1,6 +1,8 @@
#include "externals.hh"
#include "c-opcomp.hh"
#include "c-opopt.hh"
#include "c-ops.hh"
#include "c-sync.hh"
using namespace ebcl;
@ -8,26 +10,15 @@ using namespace opast;
using namespace opopt;
// Macro to use either the specified visitor or a visitor created just for the
// occation
#define M_VISITOR_( E ) \
T_OwnPtr< T_Visitor< A_Node > > localVisitor_; \
if ( !(E) ) { \
localVisitor_ = NewOwned< T_Visitor< A_Node > >( ASTVisitorBrowser ); \
} \
T_Visitor< A_Node >& visitor{ (E) ? *(E) : *localVisitor_ }
/*= CONSTANT FOLDING =========================================================*/
namespace {
struct T_ConstantFolder_
{
// Input
T_Optional< std::pair< uint32_t , uint32_t > > fixedSize;
T_SyncCurves const* curves;
T_ConstantFolder_( T_OptData& data ) noexcept
: oData{ data }
{}
// Result
bool didFold{ false };
@ -35,6 +26,8 @@ struct T_ConstantFolder_
bool operator()( A_Node& node , bool exit ) noexcept;
private:
T_OptData& oData;
template<
typename T
> void handleParentNode(
@ -52,14 +45,12 @@ struct T_ConstantFolder_
// Transform an unary operator applied to a constant into a constant.
P_ExpressionNode doUnaryOp(
A_Node& parent ,
T_UnaryOperatorNode::E_Operator op ,
T_UnaryOperatorNode& node ,
double value ) const noexcept;
// Transform a binary operator applied to a constant into a constant.
P_ExpressionNode doBinaryOp(
A_Node& parent ,
T_BinaryOperatorNode::E_Operator op ,
T_BinaryOperatorNode& node ,
double left ,
double right ) const noexcept;
};
@ -135,12 +126,12 @@ P_ExpressionNode T_ConstantFolder_::checkExpression(
// Replace inputs with value if no curve/constant curve
if ( node.type( ) == A_Node::EXPR_INPUT ) {
if ( !curves ) {
if ( !oData.curves ) {
return {};
}
auto& n{ (T_InputExprNode&) node };
auto const* const curve{ curves->curves.get( n.id( ) ) };
auto const* const curve{ oData.curves->curves.get( n.id( ) ) };
if ( curve ) {
// Curve present, check if it's constant
const auto cval{ curve->isConstant( ) };
@ -163,8 +154,7 @@ P_ExpressionNode T_ConstantFolder_::checkExpression(
[]( auto& n , P_ExpressionNode e ) { n.setArgument( std::move( e ) ); } );
if ( asUnary->argument( ).type( ) == A_Node::EXPR_CONST ) {
auto const& cn{ (T_ConstantExprNode const&) asUnary->argument( ) };
return doUnaryOp( asUnary->parent( ) , asUnary->op( ) ,
cn.floatValue( ) );
return doUnaryOp( *asUnary , cn.floatValue( ) );
}
return {};
}
@ -183,8 +173,7 @@ P_ExpressionNode T_ConstantFolder_::checkExpression(
&& asBinary->right( ).type( ) == A_Node::EXPR_CONST ) {
auto const& l{ (T_ConstantExprNode const&) asBinary->left( ) };
auto const& r{ (T_ConstantExprNode const&) asBinary->right( ) };
return doBinaryOp( asBinary->parent( ) , asBinary->op( ),
l.floatValue( ) , r.floatValue( ) );
return doBinaryOp( *asBinary , l.floatValue( ) , r.floatValue( ) );
}
return {};
}
@ -194,98 +183,132 @@ P_ExpressionNode T_ConstantFolder_::checkExpression(
P_ExpressionNode T_ConstantFolder_::doIdExpr(
T_IdentifierExprNode& node ) noexcept
{
if ( !fixedSize ) {
if ( !oData.fixedSize ) {
return {};
}
if ( node.id( ) == "width" ) {
return NewOwned< T_ConstantExprNode >( node.parent( ) ,
double( fixedSize->first ) );
double( oData.fixedSize->first ) );
}
if ( node.id( ) == "height" ) {
return NewOwned< T_ConstantExprNode >( node.parent( ) ,
float( fixedSize->second ) );
float( oData.fixedSize->second ) );
}
return {};
}
P_ExpressionNode T_ConstantFolder_::doUnaryOp(
A_Node& parent ,
const T_UnaryOperatorNode::E_Operator op ,
T_UnaryOperatorNode& node ,
const double value ) const noexcept
{
const double rVal{ []( const auto op , const auto value ) {
switch ( op ) {
const double rVal{ [this]( auto& node , const auto value ) {
switch ( node.op( ) ) {
case T_UnaryOperatorNode::NEG:
return -value;
case T_UnaryOperatorNode::NOT:
return value ? 0. : 1.;
case T_UnaryOperatorNode::INV:
// TODO check if 0
if ( value == 0 ) {
oData.errors.addNew( "math - 1/x, x=0" , node.location( ) );
return 0.;
}
return 1. / value;
case T_UnaryOperatorNode::COS:
return cos( value );
case T_UnaryOperatorNode::SIN:
return sin( value );
case T_UnaryOperatorNode::TAN:
// TODO check if valid
if ( fabs( value - M_PI / 2 ) <= 1e-6 ) {
oData.errors.addNew( "math - tan(x), x=~PI/2" ,
node.location( ) , E_SRDErrorType::WARNING );
}
return tan( value );
case T_UnaryOperatorNode::SQRT:
// TODO check if >= 0
if ( value < 0 ) {
oData.errors.addNew( "math - sqrt(x), x<0" , node.location( ) );
return 0.;
}
return sqrt( value );
case T_UnaryOperatorNode::LN:
// TODO check if > 0
if ( value <= 0 ) {
oData.errors.addNew( "math - ln(x), x<=0" , node.location( ) );
return 0.;
}
return log( value );
case T_UnaryOperatorNode::EXP:
return exp( value );
}
fprintf( stderr , "invalid operator %d\n" , int( op ) );
std::abort( );
}( op , value ) };
return NewOwned< T_ConstantExprNode >( parent , rVal );
fprintf( stderr , "invalid operator %d\n" , int( node.op( ) ) );
std::abort( );
}( node , value ) };
return NewOwned< T_ConstantExprNode >( node.parent( ) , rVal );
}
P_ExpressionNode T_ConstantFolder_::doBinaryOp(
A_Node& parent ,
const T_BinaryOperatorNode::E_Operator op ,
T_BinaryOperatorNode& node ,
const double left ,
const double right ) const noexcept
{
const double rVal{ []( const auto op , const auto l , const auto r ) {
switch ( op ) {
const double rVal{ [this]( auto& node , const auto l , const auto r ) {
switch ( node.op( ) ) {
case T_BinaryOperatorNode::ADD:
return l + r;
case T_BinaryOperatorNode::SUB:
return l - r;
case T_BinaryOperatorNode::MUL:
return l * r;
case T_BinaryOperatorNode::DIV:
// TODO: check r != 0
return l / r;
case T_BinaryOperatorNode::POW:
// TODO check operands
return pow( l , r );
case T_BinaryOperatorNode::CMP_EQ:
return ( l == r ) ? 1. : 0.;
case T_BinaryOperatorNode::CMP_NE:
return ( l != r ) ? 1. : 0.;
case T_BinaryOperatorNode::CMP_GT:
return ( l > r ) ? 1. : 0.;
case T_BinaryOperatorNode::CMP_GE:
return ( l >= r ) ? 1. : 0.;
case T_BinaryOperatorNode::CMP_LT:
return ( l < r ) ? 1. : 0.;
case T_BinaryOperatorNode::CMP_LE:
return ( l <= r ) ? 1. : 0.;
}
fprintf( stderr , "invalid operator %d\n" , int( op ) );
std::abort( );
}( op , left , right ) };
return NewOwned< T_ConstantExprNode >( parent , rVal );
case T_BinaryOperatorNode::DIV:
if ( r == 0 ) {
oData.errors.addNew( "math - l/r, r=0" , node.location( ) );
return 0.;
}
return l / r;
case T_BinaryOperatorNode::POW:
if ( l == 0 && r == 0 ) {
oData.errors.addNew( "math - l^r, l=r=0" , node.location( ) );
return 0.;
}
if ( l == 0 && r < 0 ) {
oData.errors.addNew( "math - l^r, l=0, r<0" , node.location( ) );
return 0.;
}
if ( l < 0 && fmod( r , 1. ) != 0. ) {
oData.errors.addNew( "math - l^r, l<0, r not integer" , node.location( ) );
return 0.;
}
return pow( l , r );
case T_BinaryOperatorNode::CMP_EQ: return ( l == r ) ? 1. : 0.;
case T_BinaryOperatorNode::CMP_NE: return ( l != r ) ? 1. : 0.;
case T_BinaryOperatorNode::CMP_GT: return ( l > r ) ? 1. : 0.;
case T_BinaryOperatorNode::CMP_GE: return ( l >= r ) ? 1. : 0.;
case T_BinaryOperatorNode::CMP_LT: return ( l < r ) ? 1. : 0.;
case T_BinaryOperatorNode::CMP_LE: return ( l <= r ) ? 1. : 0.;
}
fprintf( stderr , "invalid operator %d\n" , int( node.op( ) ) );
std::abort( );
}( node , left , right ) };
return NewOwned< T_ConstantExprNode >( node.parent( ) , rVal );
}
} // namespace <anon>
@ -294,27 +317,20 @@ P_ExpressionNode T_ConstantFolder_::doBinaryOp(
bool opopt::FoldConstants(
T_RootNode& root ,
const T_Optional< std::pair< uint32_t , uint32_t > > fixedSize ,
T_SyncCurves const* curves ,
T_Visitor< A_Node >* const extVisitor ) noexcept
T_OpsParserOutput& program ,
T_OptData& oData ) noexcept
{
M_VISITOR_( extVisitor );
T_ConstantFolder_ folder;
folder.fixedSize = fixedSize;
folder.curves = curves;
visitor.visit( root , folder );
T_ConstantFolder_ folder{ oData };
oData.visitor.visit( program.root , folder );
return folder.didFold;
}
/*= DEAD CODE REMOVAL ========================================================*/
bool opopt::RemoveDeadCode( opast::T_RootNode& root ,
ebcl::T_Visitor< opast::A_Node >* extVisitor ) noexcept
bool opopt::RemoveDeadCode(
T_OpsParserOutput& program ,
T_OptData& oData ) noexcept
{
M_VISITOR_( extVisitor );
#warning blargh
return false;
}

View file

@ -3,35 +3,41 @@
#include <ebcl/Algorithms.hh>
struct T_OpsParserOutput;
struct T_SyncCurves;
namespace opopt {
// Persistent data for the various stages of the optimizer.
struct T_OptData
{
// List of errors generated by the optimizer
T_Array< ebcl::T_SRDError > errors;
// If the size of the ouput is fixed, this field contains it as a
// <width,height> pair.
T_Optional< std::pair< uint32_t , uint32_t > > fixedSize;
// The curves that will be bound to the inputs.
T_SyncCurves const* curves;
// A visitor to be used for the tree
ebcl::T_Visitor< opast::A_Node > visitor{ opast::ASTVisitorBrowser };
};
/*----------------------------------------------------------------------------*/
// Attempts to fold constant expressions into single constants. Returns true if
// transformations were made, false if not.
//
// Parameters:
// root the root node
// fixedSize the size of the output, if it is fixed
// curves the curves that will be bound to the inputs
// extVisitor a node visitor instance to be used instead of creating
// one
//
bool FoldConstants( opast::T_RootNode& root ,
T_Optional< std::pair< uint32_t , uint32_t > > fixedSize = { } ,
T_SyncCurves const* curves = nullptr ,
ebcl::T_Visitor< opast::A_Node >* extVisitor = nullptr ) noexcept;
bool FoldConstants( T_OpsParserOutput& program ,
T_OptData& optData ) noexcept;
// Attempt to remove blocks of code that will not be executed because of
// constant conditions. Returns true if transformations were made, false if not.
//
// Parameters:
// root the root node
// extVisitor a node visitor instance to be used instead of creating
// one
//
bool RemoveDeadCode( opast::T_RootNode& root ,
ebcl::T_Visitor< opast::A_Node >* extVisitor = nullptr ) noexcept;
bool RemoveDeadCode( T_OpsParserOutput& program ,
T_OptData& optData ) noexcept;
} // namespace opopt