package fr.irisa.cairn.model.integerLinearAlgebra.tom;


import tom.library.sl.*;
import java.util.List;

  

import fr.irisa.cairn.model.integerLinearAlgebra.*;
import fr.irisa.cairn.model.integerLinearAlgebra.adapter.*;
import static fr.irisa.cairn.model.integerLinearAlgebra.factory.IntegerExpressionUserFactory.*;
import fr.irisa.cairn.model.integerLinearAlgebra.factory.IntegerExpressionUserFactory;
import java.util.List;
import tom.library.sl.VisitFailure;
 
   
/**
 * Algebraic simplification of {@code IntExpression} trees.
 *
 * <p>The rewrite rules live in the Tom strategy {@code Simplify}; {@link #simplify} applies them
 * with Tom's {@code InnermostId} combinator and then converts the result with {@code tryAsLinexp}.
 * The static {@code Math_*} helpers implement the integer operators used by the rules.
 */
public class SimplifyIntExpression  {

	%include { sl.tom }
	%include { integerlinearalgebra.tom}

	/** Returns the larger of {@code a} and {@code b}. */
	public static int Math_max(int a, int b) {
		return Math.max(a,b);
	}

	/** Returns the smaller of {@code a} and {@code b}. */
	public static int Math_min(int a, int b) {
		return Math.min(a,b);
	}

	/**
	 * Returns {@code a % b} with Java remainder semantics (the result has the sign of {@code a}).
	 *
	 * @throws ArithmeticException if {@code b} is zero
	 */
	public static int Math_mod(int a, int b) {
		return a%b;
	}

	/** Returns {@code a * b} (Java int multiplication, wraps on overflow). */
	public static int Math_mul(int a, int b) {
		return a*b;
	}

	/**
	 * Returns floor(a / b), i.e. the quotient rounded toward negative infinity.
	 *
	 * @throws ArithmeticException if {@code b} is zero
	 */
	public static int Math_floord(int a, int b) {
		int q = a / b;
		// Java '/' truncates toward zero: when the exact quotient is negative and not an
		// integer, truncation rounds it up, so step down by one to reach the floor.
		if (a % b != 0 && ((a < 0) != (b < 0))) {
			q--;
		}
		return q;
	}

	/**
	 * Returns ceil(a / b), i.e. the quotient rounded toward positive infinity.
	 * Unlike the {@code (a+b-1)/b} idiom this is exact for negative operands and cannot overflow.
	 *
	 * @throws ArithmeticException if {@code b} is zero
	 */
	public static int Math_ceild(int a, int b) {
		int q = a / b;
		// Java '/' truncates toward zero: when the exact quotient is positive and not an
		// integer, truncation rounds it down, so step up by one to reach the ceiling.
		if (a % b != 0 && ((a < 0) == (b < 0))) {
			q++;
		}
		return q;
	}

	/**
	 * Simplifies {@code expression} by applying the {@code Simplify} rules through Tom's
	 * {@code InnermostId} strategy, then converting the result with {@code tryAsLinexp}.
	 *
	 * @param expression the expression to simplify
	 * @return the simplified expression
	 * @throws RuntimeException wrapping any exception raised while visiting or converting
	 */
	public static IntExpression simplify(IntExpression expression) {
		try {
			IntExpressionAdapter adapter = (IntExpressionAdapter)
					`InnermostId(Simplify()).visitLight(new IntExpressionAdapter(expression));
			return tryAsLinexp(adapter.expr());
		} catch(Exception e) {
			// Chain the cause so the original stack trace reaches the caller.
			throw new RuntimeException("Visitor failure on "+expression+ ":"+e.getMessage(), e);
		}
	}

	// Rewrite rules; a rule whose action does not return falls through to the next match.
	%strategy Simplify() extends Identity() {
		visit IntExpression {

				// ---- sums ----
				// Merge two terms on the same variable.
				sum(list(x1*,term(a,v),x2*,term(b,v),x3*)) 	-> { return `sum(list(term(a+b,v),x1,x2,x3)); }
				// Fold two constants into one.
				sum(list(x1*,cst(a),x2*,cst(b),x3*)) 		-> { return `sum(list(x1,x2,x3,cst(a+b))); }
				sum(list(y))								-> { return `y;	}
				// Flatten nested sums.
				sum(list(x1*,sum(list(x2*)),x3*)) 			-> { return `sum(list(x1,x2,x3)); }
				sum(list(x1*,cst(0),x2*)) 					-> { return `sum(list(x1,x2)); }

				// ---- products ----
				prod(list(y))								-> { return `y;	}
				prod(list(x1*,cst(0),x2*)) 					-> { return `cst(0); }
				prod(list(x1*,cst(1),x2*)) 					-> { return `prod(list(x1,x2)); }
				prod(list(cst(x), term(a,b))) 				-> { return `term((x)*(a), b); }
				prod(list(term(a,b), cst(x))) 				-> { return `term((x)*(a), b); }

				// ---- min ----
				// NOTE(review): min(a*v, b*v) == min(a,b)*v only when v >= 0; confirm variables are non-negative here.
				min(list(x1*,term(a,v),x2*,term(b,v),x3*)) 	-> { return `min(list(term(Math_min(a,b),v),x1,x2,x3)); }
				// x3 is a list variable so the rule fires wherever the two constants appear.
				min(list(x1*,cst(a),x2*,cst(b),x3*)) 		-> { return `min(list(cst(Math_min(a,b)),x1,x2,x3)); }
				min(list(a))								-> { return `a; }
				// Flatten nested min.
				min(list(x1*,min(list(x2*)),x3*)) 			-> { return `min(list(x1,x2,x3)); }
				min(list(cst(a),cst(a))) 					-> { return `cst(a); }
				min(list(term(a,b),term(a,b))) 				-> { return `term(a,b); }

				// ---- max ----
				// NOTE(review): max(a*v, b*v) == max(a,b)*v only when v >= 0; confirm variables are non-negative here.
				max(list(x1*,term(a,v),x2*,term(b,v),x3*)) 	-> { return `max(list(term(Math_max(a,b),v),x1,x2,x3)); }
				// x3 is a list variable so the rule fires wherever the two constants appear.
				max(list(x1*,cst(a),x2*,cst(b),x3*)) 		-> { return `max(list(cst(Math_max(a,b)),x1,x2,x3)); }
				// Any singleton max collapses to its element, like the min rule.
				max(list(a))								-> { return `a; }
				// Flatten nested max only: max(x, min(y, z)) is NOT max(x, y, z).
				max(list(x1*,max(list(x2*)),x3*)) 			-> { return `max(list(x1,x2,x3)); }
				max(list(cst(a),cst(a))) 					-> { return `cst(a); }
				max(list(term(a,b),term(a,b))) 				-> { return `term(a,b); }

				// ---- scalar multiplication ----
				mul(1,a)									-> { return `a;}
				mul(0,a)									-> { return `cst(0);}
				mul(c,mul(a,v))								-> { int r=(`a)*(`c); return `mul(r,v);}
				mul(c,term(a,v))							-> { int r=(`a)*(`c); return `term(r,v);}

				mul(c,cst(a))				-> {
					int r=(`a)*(`c);
					return `cst(r);
				}

				mul(c, sum(list(sum(list(x*, term(a,b), y*))))) -> {
						int ca = (`c)*(`a);
						return `sum(list(term(ca, b), mul(c, sum(list(x, y)))));
				}

				// Distribute the scalar over a sum, one term/constant at a time.
				mul(c,sum(list(Y*,term(a,v),X*)))				-> {
					int r=(`a)*(`c);
					return `sum(list(term((r),v),mul(c,sum(list(Y,X)))));
				}

				mul(c,sum(list(Y*,cst(a),X*)))				-> {
					int r=(`a)*(`c);
					return `sum(list(cst(r),mul(c,sum(list(Y,X)))));
				}

				// ---- floor division ----
				// floord(a*v + rest, c) == (a/c)*v + floord(rest, c) when c divides a.
				floord(sum(list(X1*,term(a,v),X2*)),c)		-> {
					if(((`a)%(`c))==0) {
						int r=(`a)/(`c);
						return `sum(list(term((r),v),floord(sum(list(X1,X2)),c)));
					}
				}

				// floord(a*v, c) == (a/c)*v when c divides a.
				floord(term(a,v),c)		-> {
					if(((`a)%(`c))==0) {
						int r=(`a)/(`c);
						return `term((r),v);
					}
				}

				floord(cst(a),b)		-> {
					int r=Math_floord((`a),(`b));
					return `cst(r);
				}

				// ---- ceiling division ----
				// ceild(a*v, c) == (a/c)*v when c divides a; otherwise the variable must be kept.
				ceild(term(a,v),c)		-> {
					if(((`a)%(`c))==0) {
						int r=(`a)/(`c);
						return `term((r),v);
					}
				}

				ceild(cst(v),b)		-> {
					int r=Math_ceild((`v),(`b));
					return `cst(r);
				}

				// ---- modulo ----
				// A term whose coefficient is a multiple of c vanishes modulo c.
				mod(sum(list(X1*,term(a,v),X2*)),c)		-> {
					if(((`a)%(`c))==0) {
						return `sum(list(mod(sum(list(X1,X2)),c)));
					}
				}

				// ---- empty sums ----
				sum(list(cst(a)))							-> { return `cst(a); }
				sum(list())									-> { return `cst(0); }
				mod(sum(list()),c)							-> { return `cst(0); }
				floord(sum(list()),c)						-> { return `cst(0); }
				ceild(sum(list()),c)						-> { return `cst(0); }
				mul(c,sum(list()))							-> { return `cst(0); }

		}
	}

}