package org.polymodel.verifier;

import static fr.irisa.cairn.model.integerLinearAlgebra.factory.IntegerExpressionUserFactory.linConstraint;
import static fr.irisa.cairn.model.integerLinearAlgebra.factory.IntegerExpressionUserFactory.linexp;
import static fr.irisa.cairn.model.integerLinearAlgebra.factory.IntegerExpressionUserFactory.sub;
import static fr.irisa.cairn.model.integerLinearAlgebra.factory.IntegerExpressionUserFactory.term;

import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

import org.eclipse.emf.ecore.util.EcoreUtil;
import org.polymodel.verifier.factory.VerifierUserFactory;
import org.polymodel.verifier.message.CAUSALITY_VIOLATION_TYPE;
import org.polymodel.verifier.message.MISMATCHED_DIMENSION_TYPE;
import org.polymodel.verifier.message.factory.VerifierMessageUserFactory;

import fr.irisa.cairn.model.integerLinearAlgebra.IVariable;
import fr.irisa.cairn.model.integerLinearAlgebra.IntLinearConstraint;
import fr.irisa.cairn.model.integerLinearAlgebra.IntLinearExpression;
import fr.irisa.cairn.model.integerLinearAlgebra.IntTermExpression;
import fr.irisa.cairn.model.integerLinearAlgebra.Operator;
import fr.irisa.cairn.model.integerLinearAlgebra.factory.IntegerExpressionUserFactory;
import fr.irisa.cairn.model.polymodel.AffineMapping;
import fr.irisa.cairn.model.polymodel.PolyhedralDomain;
import fr.irisa.cairn.model.polymodel.factory.PolyModelDefaultFactory;
import fr.irisa.cairn.model.polymodel.prdg.PRDG;
import fr.irisa.cairn.model.polymodel.prdg.PRDGNode;
import fr.irisa.cairn.model.polymodel.util.DomainOperations;
import fr.irisa.cairn.model.polymodel.util.DomainOperations.Constraint;

/**
 * Verifies a polyhedral schedule (per-node space-time mappings) and memory
 * allocation (per-node memory maps) against the dependences of a program.
 * <p>
 * NOTE: the PRDG given should have edges from consumer to producer, i.e. the
 * source of an edge is the consumer and its destination is the producer.
 *
 * @author yuki
 */
public class Verifier {

	protected VerifierOutput output;
	protected VERBOSITY verbosity_;

	/** Outcome of checking a single dependence edge. */
	protected enum STATUS { SATISFIED, VIOLATED }

	protected PolyModelDefaultFactory PMfactory;
	protected DomainOperations domainOps;

	private static final boolean DEBUG = false;

	/** Memory spaces keyed by the target object of the memory maps writing into them. */
	protected final Map<Object, MemorySpace> memorySpaces;
	/** Reverse index from each node's memory map to the node owning it. */
	protected final Map<MemoryMap, VerifierNode> memoryMapToNode;

	protected final VerifierInput verifierInput;

	/**
	 * Creates a verifier for the given input and groups its nodes by the
	 * memory space (memory-map target) they write to.
	 *
	 * @param PMfactory factory used to build polyhedral domains and affine mappings
	 * @param input     nodes and edges to verify
	 */
	protected Verifier(PolyModelDefaultFactory PMfactory, VerifierInput input) {
		this.verifierInput = input;
		this.PMfactory = PMfactory;
		this.domainOps = new DomainOperations(PMfactory);
		this.verbosity_ = VERBOSITY.MAX;

		memorySpaces = new HashMap<Object, MemorySpace>();
		memoryMapToNode = new HashMap<MemoryMap, VerifierNode>();

		// Build the map from memory spaces to the memory maps (and nodes) targeting them.
		for (VerifierNode node : verifierInput.getNodes()) {
			MemoryMap map = node.getMemoryMap();
			MemorySpace space = memorySpaces.get(map.getTarget());
			if (space == null) {
				space = new MemorySpace(map.getTarget());
				memorySpaces.put(map.getTarget(), space);
			}
			space.maps.add(map);
			space.nodes.add(node);
			memoryMapToNode.put(map, node);
		}
	}

	/**
	 * Verifies a prepared {@link VerifierInput}.
	 *
	 * @param PMfactory factory used to build polyhedral objects
	 * @param input     nodes and edges to verify
	 * @return the verification result, including any error messages
	 */
	public static VerifierOutput verify(PolyModelDefaultFactory PMfactory, VerifierInput input) {
		Verifier verifier = new Verifier(PMfactory, input);
		return verifier.run();
	}

	/**
	 * Builds a {@link VerifierInput} from a PRDG and per-node specifications, then verifies it.
	 *
	 * @param PMfactory     factory used to build polyhedral objects
	 * @param prdg          dependence graph, with edges from consumer to producer
	 * @param spaceTimeMaps space-time mapping of each node
	 * @param memoryMaps    memory map of each node
	 * @param dimTypes      type of each space-time dimension of each node
	 * @return the verification result, including any error messages
	 */
	public static VerifierOutput verify(PolyModelDefaultFactory PMfactory, PRDG prdg, Map<PRDGNode, AffineMapping> spaceTimeMaps, Map<PRDGNode, MemoryMap> memoryMaps, Map<PRDGNode, List<DIM_TYPE>> dimTypes) {
		VerifierInput input = VerifierUserFactory.createVerifierInput(prdg, spaceTimeMaps, memoryMaps, dimTypes);
		Verifier verifier = new Verifier(PMfactory, input);
		return verifier.run();
	}

	/**
	 * Runs all checks in order. Each stage runs only if the previous stages
	 * produced no messages: specifications, then space-time mappings, then memory maps.
	 */
	private VerifierOutput run() {
		output = VerifierUserFactory.initializeOutput();

		// verify specifications given; when they are not valid no further check is possible
		boolean specValid = verifySpecs(); //FIXME
		if (!specValid) {
			output.setValid(false);
			return output;
		}

		// verify space time
		for (VerifierEdge edge : verifierInput.getEdges()) {
			verifySpaceTimeMapping(edge);
		}

		//FIXME same as below, messages may not necessarily be errors in the future
		if (output.getMessages().size() > 0) {
			output.setValid(false);
			return output;
		}

		// verify memory map
		verifyMemoryMaps(); //FIXME

		//FIXME : you may want to have messages that are not errors but warnings,
		//and in such case the valid flag must be handled in a more explicit manner
		output.setValid(output.getMessages().size() == 0);

		return output;
	}

	/**
	 * Checks that every node declares the same number of dimension types.
	 * Adds a mismatched-dimensionality message on the first mismatch.
	 *
	 * @return true when valid
	 */
	protected boolean verifySpecs() {
		int dim = -1;
		for (VerifierNode node : verifierInput.getNodes()) {
			List<DIM_TYPE> dims = node.getDimensionTypes();
			if (dim == -1) {
				dim = dims.size();
				continue;
			}
			if (dim != dims.size()) {
				output.getMessages().add(VerifierMessageUserFactory.createMismatchedDimensionalityError(MISMATCHED_DIMENSION_TYPE.DIM_TYPES));
				return false;
			}
		}

		return true;
	}

	/**
	 * Checks one dependence edge against the space-time mappings of its endpoints,
	 * dimension by dimension (outermost first).
	 * <ul>
	 * <li>PARALLEL dimensions: the consumer and producer must be mapped to the same
	 * value for every point of the edge domain.</li>
	 * <li>SEQUENTIAL / ORDERING dimensions: if the consumer is strictly after the
	 * producer everywhere, the edge is satisfied. If it is after-or-equal, the next
	 * dimension decides. Otherwise it is violated.</li>
	 * </ul>
	 * A message is added to {@link #output} for every violation found.
	 *
	 * @param edge dependence edge, source = consumer, destination = producer
	 * @return SATISFIED when the edge is satisfied (or the producer has no space-time
	 *         map, i.e. is an input). VIOLATED otherwise. That includes the case where
	 *         no dimension strictly satisfies the edge; in that case no message is added.
	 */
	protected STATUS verifySpaceTimeMapping(VerifierEdge edge) {

		if (DEBUG) {
			System.out.println("====  Verifying : " + edge + " ====");
		}

		AffineMapping srcSTmap = edge.getSource().getSpaceTimeMap();
		AffineMapping dstSTmap = edge.getDestination().getSpaceTimeMap();

		//FIXME only works for isl based PRDG + sch
		if (dstSTmap == null) {
			// dst is input
			return STATUS.SATISFIED;
		}

		List<DIM_TYPE> srcDims = edge.getSource().getDimensionTypes();
		List<DIM_TYPE> dstDims = edge.getDestination().getDimensionTypes();

		for (int dim = 0; dim < srcDims.size(); dim++) {
			// make sure the dimensions match
			if (srcDims.get(dim) != dstDims.get(dim)) {
				output.getMessages().add(VerifierMessageUserFactory.createMismatchedDimensionsError(edge, dim));
				return STATUS.VIOLATED;
			}

			// The check performed at each dimension depends on the dimension type.
			// z  = points in the source (consumer)
			// z' = points in the destination (producer) such that z' = dep(z)
			// dep = dependence function
			// stmap = space-time mapping of each variable

			// mappings that only keep the dim-th output dimension
			AffineMapping srcSTmapDim = projectToADimension(srcSTmap, dim);
			AffineMapping dstSTmapDim = projectToADimension(dstSTmap, dim);

			if (DEBUG) {
				System.out.println("===   " + dim + "th dimension : " + srcDims.get(dim) + " ===");
				System.out.println("srcSTmapDim : " + srcSTmapDim);
				System.out.println("dstSTmapDim : " + dstSTmapDim);
			}

			if (srcDims.get(dim) == DIM_TYPE.PARALLEL) {
				// Dependences must not cross the processor dimension:
				// stmap(z) == stmap(z') in this dimension for all points of the domain.
				PolyhedralDomain LTdomain = EcoreUtil.copy(edge.getDomain());
				PolyhedralDomain GTdomain = EcoreUtil.copy(edge.getDomain());
				domainOps.addConstraintsRelatingTwoSetsOfIndices(LTdomain, new Constraint(srcSTmapDim, dstSTmapDim, Operator.LT));
				domainOps.addConstraintsRelatingTwoSetsOfIndices(GTdomain, new Constraint(srcSTmapDim, dstSTmapDim, Operator.GT));

				if (DEBUG) {
					System.out.println("LTdomain:" + LTdomain);
					System.out.println("GTdomain:" + GTdomain);
				}

				// check < || > instead of != (disequality is not a convex constraint)
				if (!LTdomain.isEmpty() || !GTdomain.isEmpty()) {
					if (DEBUG) {
						System.out.println("VIOLATED : " + LTdomain.union(GTdomain));
					}
					output.getMessages().add(VerifierMessageUserFactory.createCausalityError(edge, CAUSALITY_VIOLATION_TYPE.PROCESSOR, dim, LTdomain.union(GTdomain)));
					return STATUS.VIOLATED;
				}

			} else if (srcDims.get(dim) == DIM_TYPE.SEQUENTIAL || srcDims.get(dim) == DIM_TYPE.ORDERING) {
				// The producer must be scheduled before the consumer:
				// stmap(z) > stmap(z') in this dimension for all points of the domain.
				// The check is the same for ordering dimensions. They are constants, so
				// polyhedral operations are not strictly needed.
				//FIXME using the same check for simplicity but this is inefficient

				PolyhedralDomain LEdomain = EcoreUtil.copy(edge.getDomain());
				PolyhedralDomain LTdomain = EcoreUtil.copy(edge.getDomain());

				domainOps.addConstraintsRelatingTwoSetsOfIndices(LEdomain, new Constraint(srcSTmapDim, dstSTmapDim, Operator.LE));
				domainOps.addConstraintsRelatingTwoSetsOfIndices(LTdomain, new Constraint(srcSTmapDim, dstSTmapDim, Operator.LT));

				if (DEBUG) {
					System.out.println("LTdomain:" + LTdomain);
					System.out.println("LEdomain:" + LEdomain);
				}

				// no point with stmap(z) <= stmap(z'): strictly satisfied here
				if (LEdomain.isEmpty()) {
					if (DEBUG) {
						System.out.println("SATISFIED");
					}
					return STATUS.SATISFIED;
				}
				// only equalities remain: not satisfied, not violated; the next dimension decides
				if (LTdomain.isEmpty()) {
					continue;
				}

				// some point has stmap(z) < stmap(z'): the consumer runs before the producer
				output.getMessages().add(VerifierMessageUserFactory.createCausalityError(edge, CAUSALITY_VIOLATION_TYPE.TIME, dim, LTdomain));
				if (DEBUG) {
					System.out.println("VIOLATED : " + LTdomain);
				}
				return STATUS.VIOLATED;

			} else {
				throw new RuntimeException("Unexpected dimension type : " + srcDims.get(dim));
			}
		}

		// No dimension strictly satisfied the dependence.
		return STATUS.VIOLATED;
	}

	/** Checks every memory space whose memory maps are mutually consistent for write conflicts. */
	protected void verifyMemoryMaps() {
		for (MemorySpace space : memorySpaces.values()) {
			// NOTE(review): inconsistent spaces are silently skipped (no message is added)
			if (isValidMemorySpace(space)) {
				checkForWriteConflicts(space);
			}
		}
	}

	/**
	 * All memory maps in a memory space must have the same number of output
	 * (RHS) dimensions.
	 *
	 * @return true when the space is empty or all its maps agree on dimRHS
	 */
	private boolean isValidMemorySpace(MemorySpace space) {
		if (space.maps.isEmpty()) {
			return true;
		}
		int expectedDimRHS = space.maps.get(0).getMapping().getDimRHS();
		for (MemoryMap map : space.maps) {
			if (map.getMapping().getDimRHS() != expectedDimRHS) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Checks each node of the memory space against itself for write conflicts.
	 * Conflicts are reported to {@link #output} by {@link #hasWriteConflict}.
	 */
	protected void checkForWriteConflicts(MemorySpace space) {

		//FIXME test for bijection

		//FIXME node^2 check is necessary if the PRDG is from alphabets programs
		//this fix is unimplemented since its not necessary for ompVerify

		for (VerifierNode nodeA : space.nodes) {
			// the error message (if any) is added inside hasWriteConflict
			hasWriteConflict(nodeA, nodeA);
		}
	}

	/**
	 * Checks whether two distinct iteration points of a node write the same memory
	 * location at conflicting space-time coordinates.
	 * <p>
	 * This method builds a domain restricted to pairs of points that write the same
	 * memory location. For each lexicographic position it then asks
	 * {@link #findConflictingDimensions} whether such a pair can be mapped to
	 * different space-time coordinates. Only the first conflicting dimension found
	 * decides the outcome. If it is PARALLEL, a write-conflict message is added
	 * and true is returned. If it is SEQUENTIAL/ORDERING, false is returned.
	 *
	 * @param nodeA first node
	 * @param nodeB second node; only {@code nodeA.equals(nodeB)} is supported
	 * @return true when a write conflict across a parallel dimension was found
	 * @throws RuntimeException when the two nodes differ
	 */
	protected boolean hasWriteConflict(VerifierNode nodeA, VerifierNode nodeB) {
		// merge the two domains and relate their memory locations
		PolyhedralDomain probDomain = domainOps.mergeDomains(nodeA.getDomain(), nodeB.getDomain());
		domainOps.addConstraintsRelatingTwoSetsOfIndices(probDomain, new Constraint(nodeA.getMemoryMap().getMapping(), nodeB.getMemoryMap().getMapping(), Operator.EQ));

		if (DEBUG) {
			System.out.println("nodeA(" + nodeA.getName() + "):" + nodeA.getDomain());
			System.out.println("nodeB(" + nodeB.getName() + "):" + nodeB.getDomain());
			System.out.println("nodeAB:" + probDomain);
			System.out.println("mapA:" + nodeA.getMemoryMap().getMapping());
			System.out.println("mapB:" + nodeB.getMemoryMap().getMapping());
			System.out.println("probDomain:" + probDomain);
		}

		if (!nodeA.equals(nodeB)) {
			throw new RuntimeException("Write conflict across different variabels are unsupported : not sure if its needed");
		}

		// Same node: exclude the pair (z, z) by enumerating lexicographic
		// orderings z > z' position by position (z_0 > z'_0, then z_0 == z'_0 && z_1 > z'_1, ...).
		int nIndices = nodeA.getDomain().getNIndices();
		for (int dim = 0; dim < nIndices; dim++) {
			IntTermExpression termZ = term(probDomain.getIndices().get(dim));
			IntTermExpression termZp = term(probDomain.getIndices().get(dim + nIndices));
			IntLinearConstraint gtConstraint = linConstraint(linexp(sub(termZ, termZp)), Operator.GT);
			IntLinearConstraint eqConstraint = linConstraint(linexp(sub(termZ, termZp)), Operator.EQ);

			PolyhedralDomain problem = EcoreUtil.copy(probDomain);
			problem.addConstraint(gtConstraint);
			if (DEBUG) {
				System.out.println("lexProb:" + problem);
				System.out.flush();
			}

			int[] dims = findConflictingDimensions(problem, nodeA, nodeB);
			for (int i = 0; i < dims.length; i++) {
				switch (nodeA.getDimensionTypes().get(dims[i])) {
					case PARALLEL:
						// report the conflicting space-time dimension, not the array index
						output.getMessages().add(VerifierMessageUserFactory.createWriteConflictError(nodeA, nodeB, dims[i]));
						return true;
					case SEQUENTIAL:
					case ORDERING:
						return false;
				}
			}

			// fix this position as equal before examining the next one
			probDomain.addConstraint(eqConstraint);
		}

		return false;
	}

	/**
	 * Finds the outermost space-time dimension in which two iteration points that
	 * write the same memory location (as constrained by {@code probDomain}) may be
	 * mapped to different values.
	 *
	 * @param probDomain domain over the merged indices of nodeA and nodeB
	 * @param nodeA      first node
	 * @param nodeB      second node
	 * @return a one-element array holding that dimension, or an empty array if none
	 */
	protected int[] findConflictingDimensions(PolyhedralDomain probDomain, VerifierNode nodeA, VerifierNode nodeB) {
		AffineMapping STmapA = nodeA.getSpaceTimeMap();
		AffineMapping STmapB = nodeB.getSpaceTimeMap();

		// Find the first dimension where the two space-time coordinates can differ.
		for (int dim = 0; dim < STmapA.getDimRHS(); dim++) {
			AffineMapping STmapAdim = projectToADimension(STmapA, dim);
			AffineMapping STmapBdim = projectToADimension(STmapB, dim);
			// perform emptiness check twice for LT and GT to test for !=
			PolyhedralDomain LTdomain = EcoreUtil.copy(probDomain);
			PolyhedralDomain GTdomain = EcoreUtil.copy(probDomain);
			domainOps.addConstraintsRelatingTwoSetsOfIndices(LTdomain, new Constraint(STmapAdim, STmapBdim, Operator.LT));
			domainOps.addConstraintsRelatingTwoSetsOfIndices(GTdomain, new Constraint(STmapAdim, STmapBdim, Operator.GT));
			if (DEBUG) {
				System.err.println("LT:" + LTdomain);
				System.err.println("GT:" + GTdomain);
			}
			if (!LTdomain.isEmpty() || !GTdomain.isEmpty()) {
				if (DEBUG) {
					System.out.println("CONFLICT at " + dim);
				}
				return new int[]{dim};
			}
		}

		return new int[]{};
	}

	/**
	 * Restricts an affine mapping to a single output dimension.
	 *
	 * @param map the mapping to project
	 * @param dim output dimension to keep, in {@code [0, map.getDimRHS())}
	 * @return a mapping with the same input as {@code map} and one output, {@code map}'s dim-th output
	 * @throws RuntimeException when {@code dim} is out of range
	 */
	protected AffineMapping projectToADimension(AffineMapping map, int dim) {
		// dim must be a valid 0-based output index (the previous check let -1 through)
		if (dim < 0 || dim >= map.getDimRHS()) {
			throw new RuntimeException("Dimension out of range.");
		}

		List<String> indices = new LinkedList<String>();
		for (IVariable iv : map.getParams()) {
			indices.add(iv.toString());
		}
		for (int d = 0; d < map.getDimRHS(); d++) {
			indices.add("proj" + d);
		}
		List<String> expr = new LinkedList<String>();
		expr.add("proj" + dim);

		return PMfactory.affineMappingFromString(indices, map.getParams().size(), expr).compose(map);
	}

	/**
	 * Placeholder: always returns false.
	 */
	protected boolean hasWriteConflicts(VerifierNode node) {
		return false;
	}

	/**
	 * Tests whether the domain becomes empty once {@code affineForm} is added.
	 * The given domain is not modified.
	 */
	protected boolean isEmpty(PolyhedralDomain domain, IntLinearConstraint affineForm) {
		PolyhedralDomain constrained = EcoreUtil.copy(domain);
		constrained.addConstraint(affineForm);
		return constrained.isEmpty();
	}

	/**
	 * Tests whether every point of the domain satisfies {@code affineForm},
	 * i.e. whether the domain has no point in the complement of the constraint.
	 * The given domain is not modified.
	 */
	protected boolean satisfies(PolyhedralDomain domain, IntLinearConstraint affineForm) {
		PolyhedralDomain legalDomain = PMfactory.polyhedralDomain(IntegerExpressionUserFactory.linConstraintSystem(affineForm));
		PolyhedralDomain violating = EcoreUtil.copy(domain).intersection(legalDomain.complement());
		return violating.isEmpty();
	}

	/**
	 * Builds, for each output dimension d, the expression {@code f1(z)[d] - f2(z')[d]}
	 * over the indices of {@code inputDom}, whose indices are z followed by z'.
	 *
	 * @param inputDom domain over P + z + z'
	 * @param f1       function of z
	 * @param f2       function of z'
	 * @return one difference expression per RHS dimension
	 * @throws RuntimeException when the RHS dimensions differ, or the LHS dimensions
	 *         do not add up to the number of indices of {@code inputDom}
	 */
	public List<IntLinearExpression> getIntLinearExpressionList(PolyhedralDomain inputDom, AffineMapping f1, AffineMapping f2) {
		// check on inputs
		if (f1.getDimRHS() != f2.getDimRHS()) {
			throw new RuntimeException("Dimensionality of RHS of the two functions given must be equal.");
		}
		if (f1.getDimLHS() + f2.getDimLHS() != inputDom.getNIndices()) {
			throw new RuntimeException("Number of indices used in two functions given must match the number of indices in input PolyhedralDomain.");
		}

		// The input domain has dimensions P + z + z', so we need
		//   fa(z+z') = z and fb(z+z') = z'
		// to compose with the two functions given, f1(z) and f2(z').
		List<String> indexNames = new LinkedList<String>();
		for (IVariable iv : inputDom.getIndices()) {
			indexNames.add(iv.toString());
		}

		// first function: indices [0, dimLHS(f1))
		List<String> sublistA = indexNames.subList(0, f1.getDimLHS());
		AffineMapping fa = PMfactory.affineMapping(inputDom.getParams(), inputDom.getIndices(), sublistA);

		// second function: indices [dimLHS(f1), dimLHS(f1) + dimLHS(f2))
		List<String> sublistB = indexNames.subList(f1.getDimLHS(), f1.getDimLHS() + f2.getDimLHS());
		AffineMapping fb = PMfactory.affineMapping(inputDom.getParams(), inputDom.getIndices(), sublistB);

		AffineMapping domToZ = f1.compose(fa);
		AffineMapping domToZp = f2.compose(fb);

		// one difference expression per RHS dimension
		List<IntLinearExpression> exp_list = new LinkedList<IntLinearExpression>();
		for (int d = 0; d < f1.getDimRHS(); d++) {
			exp_list.add(IntegerExpressionUserFactory.linexp(
					IntegerExpressionUserFactory.sub(
							domToZ.getFunctions().get(d), domToZp.getFunctions().get(d))));
		}

		return exp_list;
	}

	/**
	 * Nodes and memory maps that share the same memory-map target.
	 * Static: it never needs the enclosing Verifier instance.
	 */
	private static class MemorySpace {
		public final Object target;
		public final List<VerifierNode> nodes = new LinkedList<VerifierNode>();
		public final List<MemoryMap> maps = new LinkedList<MemoryMap>();

		protected MemorySpace(Object target) {
			this.target = target;
		}
	}

}