package fr.irisa.cairn.model.integerLinearAlgebra.factory;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import fr.irisa.cairn.model.integerLinearAlgebra.IVariable;
import fr.irisa.cairn.model.integerLinearAlgebra.IntConstraint;
import fr.irisa.cairn.model.integerLinearAlgebra.IntConstraintSystem;
import fr.irisa.cairn.model.integerLinearAlgebra.IntDivExpression;
import fr.irisa.cairn.model.integerLinearAlgebra.IntExpression;
import fr.irisa.cairn.model.integerLinearAlgebra.IntLinearConstraint;
import fr.irisa.cairn.model.integerLinearAlgebra.IntLinearConstraintSystem;
import fr.irisa.cairn.model.integerLinearAlgebra.IntLinearExpression;
import fr.irisa.cairn.model.integerLinearAlgebra.IntMax;
import fr.irisa.cairn.model.integerLinearAlgebra.IntMaxExpression;
import fr.irisa.cairn.model.integerLinearAlgebra.IntMin;
import fr.irisa.cairn.model.integerLinearAlgebra.IntMinExpression;
import fr.irisa.cairn.model.integerLinearAlgebra.IntMulExpression;
import fr.irisa.cairn.model.integerLinearAlgebra.IntProdExpression;
import fr.irisa.cairn.model.integerLinearAlgebra.IntSumExpression;
import fr.irisa.cairn.model.integerLinearAlgebra.IntTermExpression;
import fr.irisa.cairn.model.integerLinearAlgebra.IntegerLinearAlgebraFactory;
import fr.irisa.cairn.model.integerLinearAlgebra.Operator;
import fr.irisa.cairn.model.integerLinearAlgebra.Scope;
import fr.irisa.cairn.model.integerLinearAlgebra.SymbolVariable;
import fr.irisa.cairn.model.integerLinearAlgebra.tom.Isolate;

/**
 * Static helpers for building integer linear algebra model objects
 * (expressions, constraints, constraint systems) through
 * {@link IntegerLinearAlgebraFactory}.
 * <p>
 * Arguments that already belong to a container ({@code eContainer() != null})
 * are copied before being attached to a newly built object. Presumably this is
 * so that attaching them does not detach them from their current owner (EMF
 * containment semantics). Arguments without a container are attached as-is.
 */
public class IntegerExpressionUserFactory {

	/** Factory used to instantiate every model object built by this class. */
	public static IntegerLinearAlgebraFactory factory = IntegerLinearAlgebraFactory.eINSTANCE;

	/**
	 * Builds the sum {@code a + b}.
	 *
	 * @return a new {@link IntSumExpression} whose terms are {@code a} then {@code b}
	 */
	public static IntExpression add(IntExpression a, IntExpression b) {
		IntSumExpression res = factory.createIntSumExpression();
		res.getTerms().add(copyWhenContained(a));
		res.getTerms().add(copyWhenContained(b));
		return res;
	}

	/** Returns a copy of {@code a} if it already has a container, otherwise {@code a} itself. */
	private static IntExpression copyWhenContained(IntExpression a) {
		return a.eContainer() != null ? a.copy() : a;
	}

	/** Returns a copy of {@code a} if it already has a container, otherwise {@code a} itself. */
	private static IntConstraint copyWhenContained(IntConstraint a) {
		return a.eContainer() != null ? a.copy() : a;
	}

	/** Returns a copy of {@code a} if it already has a container, otherwise {@code a} itself. */
	private static IntLinearConstraint copyWhenContained(IntLinearConstraint a) {
		return a.eContainer() != null ? a.copy() : a;
	}

	/** Builds {@code a + b} where {@code b} is an integer constant. */
	public static IntExpression add(IntExpression a, int b) {
		return add(a, term(b));
	}

	/**
	 * Builds {@code b + a} where {@code b} is an integer constant. The resulting
	 * sum lists {@code a} first and the constant second.
	 */
	public static IntExpression add(int b, IntExpression a) {
		return add(a, term(b));
	}

	/**
	 * Builds a linear expression whose terms are the given term expressions.
	 *
	 * @param a the terms; contained terms are copied
	 * @return a new {@link IntLinearExpression}
	 * @throws RuntimeException if an element is not an {@link IntTermExpression}
	 *             (given the parameter type, this only happens for {@code null})
	 */
	public static IntLinearExpression linexp(IntTermExpression... a) {
		IntLinearExpression res = factory.createIntLinearExpression();
		for (IntExpression intExpression : a) {
			if (!(intExpression instanceof IntTermExpression)) {
				// Arrays.toString prints the elements rather than the array identity.
				throw new RuntimeException("Cannot build linear expression ("
						+ Arrays.toString(a) + ") out of non terminals " + intExpression);
			}
			res.getTerms().add(copyWhenContained(intExpression));
		}
		return res;
	}

	/**
	 * Parses {@code string} with {@link IntExpressionParser} against the given
	 * symbols and converts the result with {@link #linexp(IntExpression)}.
	 */
	public static IntLinearExpression linexp(String string,
			SymbolVariable... vars) {
		IntExpressionParser parser = new IntExpressionParser();
		return linexp(parser.parseIntExpression(string, vars));
	}

	/** List-based variant of {@link #linexp(String, SymbolVariable...)}. */
	public static IntLinearExpression linexp(String string,
			List<SymbolVariable> vars) {
		SymbolVariable[] arrayOfSymbols = vars.toArray(new SymbolVariable[] {});
		return linexp(string, arrayOfSymbols);
	}

	/**
	 * Tries to view {@code b} as a linear expression.
	 * <ul>
	 * <li>An {@link IntLinearExpression} is returned unchanged.</li>
	 * <li>A single {@link IntTermExpression} is wrapped in a new linear
	 * expression; it is copied first if it already has a container.</li>
	 * <li>An {@link IntSumExpression} made only of term expressions becomes a new
	 * linear expression built from copies of those terms.</li>
	 * <li>Anything else (including a sum containing a non-term) is returned
	 * unchanged.</li>
	 * </ul>
	 */
	public static IntExpression tryAsLinexp(IntExpression b) {
		if (b instanceof IntLinearExpression) {
			return b;
		}
		IntLinearExpression res = factory.createIntLinearExpression();
		if (b instanceof IntTermExpression) {
			// Copy when contained so b is not stolen from its current owner.
			res.getTerms().add(copyWhenContained(b));
			return res;
		}
		if (b instanceof IntSumExpression) {
			IntSumExpression sumexp = (IntSumExpression) b;
			for (IntExpression intExpression : sumexp.getTerms()) {
				if (!(intExpression instanceof IntTermExpression)) {
					return b;
				}
				res.getTerms().add(intExpression.copy());
			}
			return res;
		}
		return b;
	}

	/**
	 * Converts {@code a} to a linear expression by calling {@code a.simplify()}.
	 *
	 * @throws RuntimeException if the simplified expression is not an
	 *             {@link IntLinearExpression}
	 */
	public static IntLinearExpression linexp(IntExpression a) {
		IntExpression simplified = a.simplify();
		// Check the actual type instead of casting blindly, so a non-linear
		// simplification result yields this error rather than a ClassCastException.
		if (simplified instanceof IntLinearExpression) {
			return (IntLinearExpression) simplified;
		}
		throw new RuntimeException("Cannot convert current expression " + a
				+ " to linear expression");
	}

	/**
	 * Builds a linear expression from a list of term expressions.
	 *
	 * @throws RuntimeException if any element is not an {@link IntTermExpression}
	 */
	public static IntLinearExpression linexp(List<IntExpression> a) {
		IntLinearExpression res = factory.createIntLinearExpression();
		for (IntExpression intExpression : a) {
			if (!(intExpression instanceof IntTermExpression)) {
				throw new RuntimeException(
						"Cannot build linear expression out of non flattened integer expression "
								+ a);
			}
			res.getTerms().add(copyWhenContained(intExpression));
		}
		return res;
	}

	/** Builds an empty linear expression. */
	public static IntLinearExpression linexp() {
		return factory.createIntLinearExpression();
	}

	/** Builds an {@link IntSumExpression} over the given operands. */
	public static IntExpression sum(IntExpression... a) {
		return sum(Arrays.asList(a));
	}

	/**
	 * Builds an {@link IntSumExpression} over the given operands (contained
	 * operands are copied). No attempt is made to produce a linear expression;
	 * use {@link #tryAsLinexp(IntExpression)} for that.
	 */
	public static IntExpression sum(List<IntExpression> a) {
		IntSumExpression res = factory.createIntSumExpression();
		for (IntExpression intExpression : a) {
			res.getTerms().add(copyWhenContained(intExpression));
		}
		return res;
	}

	/** Builds {@code a - b} as {@code a + (-1 * b)}. */
	public static IntExpression sub(IntExpression a, IntExpression b) {
		return add(a, mul(-1, b));
	}

	/** Builds an {@link IntProdExpression} over the given operands. */
	public static IntExpression prod(IntExpression... a) {
		return prod(Arrays.asList(a));
	}

	/** Builds an {@link IntProdExpression} over the given operands (contained ones are copied). */
	public static IntExpression prod(List<IntExpression> a) {
		IntProdExpression res = factory.createIntProdExpression();
		for (IntExpression intExpression : a) {
			res.getTerms().add(copyWhenContained(intExpression));
		}
		return res;
	}

	/** Builds an {@link IntMinExpression} over the given operands (contained ones are copied). */
	public static IntExpression min(IntExpression... a) {
		IntMinExpression res = factory.createIntMinExpression();
		for (IntExpression intExpression : a) {
			res.getTerms().add(copyWhenContained(intExpression));
		}
		return res;
	}

	/** Builds an {@link IntMaxExpression} over the given operands (contained ones are copied). */
	public static IntExpression max(IntExpression... a) {
		IntMaxExpression res = factory.createIntMaxExpression();
		for (IntExpression intExpression : a) {
			res.getTerms().add(copyWhenContained(intExpression));
		}
		return res;
	}

	/** Builds {@code scalar * a} as an {@link IntMulExpression}. */
	public static IntExpression mul(int scalar, IntExpression a) {
		IntMulExpression t = factory.createIntMulExpression();
		t.setExpr(copyWhenContained(a));
		t.setFactor(scalar);
		return t;
	}

	/** Builds a floor-division expression of {@code a} by the constant {@code denum}. */
	public static IntExpression floord(IntExpression a, int denum) {
		IntDivExpression t = factory.createIntFloorDExpression();
		t.setDenum(denum);
		t.setExpr(copyWhenContained(a));
		return t;
	}

	/** Builds a ceiling-division expression of {@code a} by the constant {@code denum}. */
	public static IntExpression ceild(IntExpression a, int denum) {
		IntDivExpression t = factory.createIntCeilDExpression();
		t.setDenum(denum);
		t.setExpr(copyWhenContained(a));
		return t;
	}

	/** Builds a modulo expression of {@code a} by the constant {@code scale}. */
	public static IntExpression mod(IntExpression a, int scale) {
		IntDivExpression t = factory.createIntModExpression();
		t.setDenum(scale);
		t.setExpr(copyWhenContained(a));
		return t;
	}

	/** Builds the term {@code scale * a}. */
	public static IntTermExpression term(int scale, IVariable a) {
		IntTermExpression t = factory.createIntTermExpression();
		t.setValue(scale);
		t.setVar(a);
		return t;
	}

	/** Builds the term {@code 1 * a}. */
	public static IntTermExpression term(IVariable a) {
		return term(1, a);
	}

	/** Builds a constant term (no variable) with value {@code scale}. */
	public static IntTermExpression term(int scale) {
		IntTermExpression t = factory.createIntTermExpression();
		t.setValue(scale);
		return t;
	}

	/** Creates a new {@link SymbolVariable} named {@code name}. */
	public static SymbolVariable var(String name) {
		SymbolVariable v = factory.createSymbolVariable();
		v.setName(name);
		return v;
	}

	/** Parses a linear constraint, resolving variable names through {@code varMap}. */
	public static IntLinearConstraint constraint(String constraint,
			Map<String, IVariable> varMap) {
		IntExpressionParser parser = new IntExpressionParser();
		return parser.parseIntLinearConstraint(constraint, varMap);
	}

	/** Parses a linear constraint over the given symbols. */
	public static IntLinearConstraint constraint(String string,
			SymbolVariable... vars) {
		IntExpressionParser parser = new IntExpressionParser();
		return parser.parseIntLinearConstraint(string, vars);
	}

	/** List-based variant of {@link #constraint(String, SymbolVariable...)}. */
	public static IntLinearConstraint constraint(String string, List<SymbolVariable> vars) {
		SymbolVariable[] arrayOfSymbols = vars.toArray(new SymbolVariable[] {});
		return constraint(string, arrayOfSymbols);
	}

	/** Creates an empty {@link IntLinearConstraintSystem}. */
	public static IntLinearConstraintSystem linConstraintSystem() {
		return factory.createIntLinearConstraintSystem();
	}

	/** Creates an empty {@link IntConstraintSystem}. */
	public static IntConstraintSystem intConstraintSystem() {
		return factory.createIntConstraintSystem();
	}

	/**
	 * Builds a constraint system from {@code constraints}. If every element is an
	 * {@link IntLinearConstraint}, an {@link IntLinearConstraintSystem} is
	 * returned; otherwise a plain {@link IntConstraintSystem}. In both cases,
	 * constraints that already have a container are copied.
	 */
	public static IntConstraintSystem intConstraintSystem(List<IntConstraint> constraints) {
		List<IntLinearConstraint> linearConstraints =
				new ArrayList<IntLinearConstraint>(constraints.size());
		for (IntConstraint intConstraint : constraints) {
			if (!(intConstraint instanceof IntLinearConstraint)) {
				IntConstraintSystem res = factory.createIntConstraintSystem();
				for (IntConstraint c : constraints) {
					// Copy like the linear branch does, instead of moving the
					// caller's contained constraints into the new system.
					res.getConstraints().add(copyWhenContained(c));
				}
				return res;
			}
			linearConstraints.add((IntLinearConstraint) intConstraint);
		}
		return linConstraintSystem(linearConstraints);
	}

	/** Builds a linear constraint system from the given constraints (contained ones are copied). */
	public static IntLinearConstraintSystem linConstraintSystem(IntLinearConstraint... constraints) {
		return linConstraintSystem(Arrays.asList(constraints));
	}

	/** Builds a linear constraint system from the given constraints (contained ones are copied). */
	public static IntLinearConstraintSystem linConstraintSystem(List<IntLinearConstraint> constraints) {
		IntLinearConstraintSystem res = factory.createIntLinearConstraintSystem();
		for (IntLinearConstraint constr : constraints) {
			res.getLinearConstraints().add(copyWhenContained(constr));
		}
		return res;
	}

	/** Creates an empty {@link IntConstraintSystem}. */
	public static IntConstraintSystem constraintSystem() {
		return factory.createIntConstraintSystem();
	}

	/** Varargs variant of {@link #constraintSystem(List)}. */
	public static IntConstraintSystem constraintSystem(IntConstraint... constraints) {
		return constraintSystem(Arrays.asList(constraints));
	}

	/**
	 * Builds a constraint system from simplified copies of {@code constraints}.
	 * The input constraints are not modified. If every simplified constraint
	 * reports {@code isAffine()}, the result is an {@link IntLinearConstraintSystem};
	 * otherwise it is a plain {@link IntConstraintSystem}.
	 */
	public static IntConstraintSystem constraintSystem(List<IntConstraint> constraints) {
		boolean linear = true;
		List<IntConstraint> simplified = new ArrayList<IntConstraint>(constraints.size());
		for (IntConstraint intConstraint : constraints) {
			IntConstraint s = intConstraint.copy().simplify();
			simplified.add(s);
			if (!s.isAffine()) {
				linear = false;
			}
		}

		IntConstraintSystem res;
		if (linear) {
			res = factory.createIntLinearConstraintSystem();
		} else {
			res = factory.createIntConstraintSystem();
		}
		for (IntConstraint constr : simplified) {
			res.getConstraints().add(copyWhenContained(constr));
		}
		return res;
	}

	/** Alias of {@link #linConstraintSystem(List)}. */
	public static IntLinearConstraintSystem linearConstraintSystem(List<IntLinearConstraint> constraints) {
		return linConstraintSystem(constraints);
	}

	/** Creates a new {@link Scope}. */
	public static Scope scope() {
		return factory.createScope();
	}

	/**
	 * Returns a new system in which every constraint of {@code system} is
	 * replaced by its negation (see {@link #negate(IntLinearConstraint)}). The
	 * input is not modified.
	 * <p>
	 * Note: the result is the conjunction of the negated constraints. When the
	 * system has more than one constraint, this is not the logical complement of
	 * the system; {@link #negate(List)} builds the complement as a union of systems.
	 */
	public static IntLinearConstraintSystem negate(
			IntLinearConstraintSystem system) {
		IntLinearConstraintSystem res = linConstraintSystem();
		for (IntLinearConstraint constraint : system.getLinearConstraints()) {
			res.getLinearConstraints().add(negate(constraint));
		}
		return res;
	}

	/**
	 * Computes the negation of a union of constraint systems as a union of
	 * linear constraint systems. (The union reading of the input is an
	 * assumption based on the algorithm; confirm with callers.)
	 * <p>
	 * Each output system picks one constraint from every input system and
	 * negates it. An equality {@code e == 0} is handled as two choices,
	 * {@code e >= 0} and {@code e <= 0}, whose negations are the strict
	 * inequalities {@code e < 0} and {@code e > 0}. Every combination is
	 * enumerated. The input systems are not modified.
	 * <p>
	 * An empty input list yields a single empty system. An input system with no
	 * constraints yields no output systems.
	 */
	public static List<? extends IntConstraintSystem> negate(List<? extends IntConstraintSystem> subSystems) {
		List<IntLinearConstraintSystem> res = new ArrayList<IntLinearConstraintSystem>();
		int n = subSystems.size();
		negateRec(subSystems, res, 0, new int[n], new Operator[n]);
		return res;
	}

	/**
	 * Recursive enumeration for {@link #negate(List)}: {@code pos[i]} and
	 * {@code ops[i]} hold the constraint index and the operator chosen for
	 * system {@code i}, for every {@code i < depth}.
	 */
	private static void negateRec(List<? extends IntConstraintSystem> subSystems,
			List<IntLinearConstraintSystem> res, int depth, int[] pos, Operator[] ops) {
		if (depth >= subSystems.size()) {
			IntLinearConstraintSystem newSys = linConstraintSystem();
			for (int i = 0; i < depth; i++) {
				IntConstraint c = subSystems.get(i).getConstraints().get(pos[i]).copy();
				// Apply the chosen operator to the copy only; the caller's
				// constraint is never mutated.
				c.setComparisonOperator(ops[i]);
				newSys.getLinearConstraints().add(negate(linConstraint(c)));
			}
			res.add(newSys);
			return;
		}
		IntConstraintSystem sys = subSystems.get(depth);
		for (int i = 0; i < sys.getConstraints().size(); i++) {
			pos[depth] = i;
			Operator op = sys.getConstraints().get(i).getComparisonOperator();
			if (op == Operator.EQ) {
				// not(e == 0) is (e < 0) or (e > 0): branch on both halves.
				ops[depth] = Operator.GE;
				negateRec(subSystems, res, depth + 1, pos, ops);
				ops[depth] = Operator.LE;
				negateRec(subSystems, res, depth + 1, pos, ops);
			} else {
				ops[depth] = op;
				negateRec(subSystems, res, depth + 1, pos, ops);
			}
		}
	}

	/**
	 * Returns the logical negation of a comparison operator
	 * (GE&lt;-&gt;LT, LE&lt;-&gt;GT, EQ&lt;-&gt;NE).
	 *
	 * @throws UnsupportedOperationException for any other operator
	 */
	private static Operator negate(Operator comparisonOperator) {
		switch (comparisonOperator) {
		case GE:
			return Operator.LT;
		case LE:
			return Operator.GT;
		case GT:
			return Operator.LE;
		case LT:
			return Operator.GE;
		case EQ:
			return Operator.NE;
		case NE:
			return Operator.EQ;
		default:
			throw new UnsupportedOperationException("Not yet implemented");
		}
	}

	/**
	 * Returns a new constraint with a copy of the same expression and the negated
	 * comparison operator. The input is not modified.
	 */
	public static IntLinearConstraint negate(IntLinearConstraint constraint) {
		Operator newOp = negate(constraint.getComparisonOperator());
		return linConstraint((IntLinearExpression) constraint.getExpr().copy(), newOp);
	}

	/**
	 * Multiplies both sides of {@code constraint} by {@code value}. When
	 * {@code value} is negative, the inequality direction is reversed; other
	 * operators are kept as-is. Note that {@code value == 0} is not rejected,
	 * and it produces a constraint on the constant 0.
	 */
	public static IntLinearConstraint mul(IntLinearConstraint constraint,
			int value) {
		IntLinearExpression exp = linexp(mul(value, constraint.getExpr()));
		Operator op = constraint.getComparisonOperator();
		if (value < 0) {
			switch (op) {
			case GE:
				op = Operator.LE;
				break;
			case LE:
				op = Operator.GE;
				break;
			case GT:
				op = Operator.LT;
				break;
			case LT:
				op = Operator.GT;
				break;
			default:
				// EQ / NE are unaffected by a sign flip.
				break;
			}
		}
		return linConstraint(exp, op);
	}

	/**
	 * Build an expression corresponding to a variable isolation. The variable is
	 * removed and the other terms are negated. The input expression isn't modified.
	 *
	 * @param exp
	 *            expression to rewrite
	 * @param variable
	 *            variable to isolate
	 * @return a new expression
	 */
	public static IntExpression isolate(IntExpression exp, IVariable variable) {
		return Isolate.isolate(exp, variable);
	}

	/**
	 * Converts a generic constraint into a linear one by linearizing its
	 * expression with {@link #linexp(IntExpression)}.
	 *
	 * @throws RuntimeException if the expression cannot be made linear
	 */
	public static IntLinearConstraint linConstraint(IntConstraint c) {
		return linConstraint(linexp(c.getExpr()), c.getComparisonOperator());
	}

	/** Builds the linear constraint {@code a <newOperator> 0} (the expression is copied when contained). */
	public static IntLinearConstraint linConstraint(IntLinearExpression a, Operator newOperator) {
		IntLinearConstraint res = factory.createIntLinearConstraint();
		res.setExpr((IntLinearExpression) copyWhenContained(a));
		res.setComparisonOperator(newOperator);
		return res;
	}

	/**
	 * Builds {@code a <newOperator> 0}. If {@code a} can be viewed as a linear
	 * expression (see {@link #tryAsLinexp(IntExpression)}), an
	 * {@link IntLinearConstraint} is returned; otherwise a generic
	 * {@link IntConstraint}.
	 */
	public static IntConstraint constraint(IntExpression a, Operator newOperator) {
		IntExpression exp = tryAsLinexp(copyWhenContained(a));
		if (exp instanceof IntLinearExpression) {
			return linConstraint((IntLinearExpression) exp, newOperator);
		}
		// In this branch, exp is the uncontained (possibly copied) input
		// returned unchanged by tryAsLinexp, so no second copy is needed.
		IntConstraint res = factory.createIntConstraint();
		res.setExpr(exp);
		res.setComparisonOperator(newOperator);
		return res;
	}

	/** Builds {@code lhs <newOperator> rhs} as {@code (lhs - rhs) <newOperator> 0}. */
	public static IntConstraint constraint(IntExpression lhs, Operator newOperator, IntExpression rhs) {
		return constraint(sub(lhs, rhs), newOperator);
	}

	/** Creates an {@link IntMax} with width 32 and signed set to true. */
	public static IntMax intMax() {
		return intMax(32, true);
	}

	/** Creates an {@link IntMax} with the given width and signedness. */
	public static IntMax intMax(int width, boolean signed) {
		IntMax max = factory.createIntMax();
		max.setWidth(width);
		max.setSigned(signed);
		return max;
	}

	/** Creates an {@link IntMin} with width 32 and signed set to true. */
	public static IntMin intMin() {
		return intMin(32, true);
	}

	/** Creates an {@link IntMin} with the given width and signedness. */
	public static IntMin intMin(int width, boolean signed) {
		IntMin min = factory.createIntMin();
		min.setWidth(width);
		min.setSigned(signed);
		return min;
	}
}
