package fr.irisa.cairn.jnimap.isl.jni.extra;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import fr.irisa.cairn.jnimap.isl.jni.ISLFactory;
import fr.irisa.cairn.jnimap.isl.jni.ISLPrettyPrinter.ISL_FORMAT;
import fr.irisa.cairn.jnimap.isl.jni.JNIISLAffine;
import fr.irisa.cairn.jnimap.isl.jni.JNIISLAffineList;
import fr.irisa.cairn.jnimap.isl.jni.JNIISLBasicSet;
import fr.irisa.cairn.jnimap.isl.jni.JNIISLDim;
import fr.irisa.cairn.jnimap.isl.jni.JNIISLDimType;
import fr.irisa.cairn.jnimap.isl.jni.JNIISLMap;
import fr.irisa.cairn.jnimap.isl.jni.JNIISLSet;
import fr.irisa.cairn.jnimap.isl.jni.JNIISLSpace;
import fr.irisa.cairn.jnimap.isl.jni.JNIISLUnionMap;
import fr.irisa.cairn.jnimap.isl.jni.JNIISLUnionSet;

/**
 * Static utilities built on top of the JNI ISL bindings (sets, maps, spaces,
 * affine expressions).
 * <p>
 * Several helpers work by generating an ISL textual description such as
 * {@code [N] -> { S[i] -> S[j] : j = i }} and parsing it back through
 * {@link ISLFactory}; the others compose existing JNI operations.
 * <p>
 * Ownership: throughout this class, an argument that must stay usable after a
 * call is passed as {@code copy()}. NOTE(review): this presumably mirrors ISL's
 * {@code __isl_take} convention in the JNI binding — confirm there.
 */
public class JNIISLTools {

	/** Returns {@code tupleName}, or the empty string for an anonymous ({@code null}) tuple. */
	private static String tupleNameOrEmpty(String tupleName) {
		return (tupleName == null) ? "" : tupleName;
	}

	/** Appends the elements of {@code items} to {@code sb}, separated by {@code separator}. */
	private static void appendJoined(StringBuilder sb, List<String> items, String separator) {
		for (int i = 0; i < items.size(); i++) {
			if (i > 0) sb.append(separator);
			sb.append(items.get(i));
		}
	}

	/**
	 * Returns the first {@code count} names of {@code names}, replacing each
	 * {@code null} (unnamed dimension) at position j by {@code "i" + j}.
	 */
	private static List<String> namesOrDefault(List<String> names, long count) {
		List<String> res = new ArrayList<String>();
		for (int j = 0; j < count; j++) {
			String name = names.get(j);
			res.add((name == null) ? "i" + j : name);
		}
		return res;
	}

	/**
	 * Parses the map
	 * {@code params -> { inTuple[inNames] -> outTuple[outNames] : out_0 = in_0 & ... }}.
	 * The constraint clause is omitted when there are no dimensions, which would
	 * otherwise leave a dangling ':' with no constraint after it.
	 */
	private static JNIISLMap buildEqualityMap(List<String> parametersNames, String inTupleName,
			List<String> inNames, String outTupleName, List<String> outNames) {
		StringBuilder text = new StringBuilder();
		text.append(parametersNames).append(" -> { ")
			.append(tupleNameOrEmpty(inTupleName)).append(inNames)
			.append(" -> ")
			.append(tupleNameOrEmpty(outTupleName)).append(outNames);
		for (int j = 0; j < outNames.size(); j++) {
			text.append(j == 0 ? " : " : " & ");
			text.append(outNames.get(j)).append(" = ").append(inNames.get(j));
		}
		text.append(" }");
		return ISLFactory.islMap(text.toString());
	}

	/**
	 * Builds a map from {@code tupleName[indicesName]} to
	 * {@code tupleName[newIndicesNames]} that equates each new index with the
	 * old index at the same position, i.e. a map that only renames indices.
	 * <p>
	 * NOTE(review): if a new name also occurs among the old names at another
	 * position, ISL's parser resolves it to the existing input variable rather
	 * than to a fresh output dimension — confirm callers never do this.
	 *
	 * @param newIndicesNames names of the target indices
	 * @param tupleName tuple name used on both sides ({@code null} = anonymous)
	 * @param parametersNames parameter names of the map
	 * @param indicesName names of the source indices
	 * @return the renaming map
	 * @throws UnsupportedOperationException if the two index lists differ in size
	 */
	public static JNIISLMap buildRenameMapString(List<String> newIndicesNames, String tupleName, List<String> parametersNames,
			List<String> indicesName) {
		if (indicesName.size() != newIndicesNames.size()) {
			throw new UnsupportedOperationException("Not yet implemented");
		}
		return buildEqualityMap(parametersNames, tupleName, indicesName, tupleName, newIndicesNames);
	}

	/**
	 * Renames the input or output dimensions of {@code map}.
	 *
	 * @param map the map to rename; only copies of it are passed on
	 * @param type {@code isl_dim_in} or {@code isl_dim_out}
	 * @param names the new names, one per dimension of the selected kind
	 * @return a new map whose selected dimensions carry {@code names}
	 * @throws UnsupportedOperationException for any other dimension type, or if
	 *         {@code names} does not have one entry per selected dimension
	 */
	public static JNIISLMap renameDimensions(JNIISLMap map, JNIISLDimType type, List<String> names) {
		List<String> parametersNames = map.getParametersNames();
		switch (type.getValue()) {
			case JNIISLDimType.ISL_DIM_IN: {
				String tupleName = map.getTupleName(JNIISLDimType.isl_dim_in);
				JNIISLMap renaming = buildRenameMapString(names, tupleName, parametersNames, map.getDomainNames());
				return JNIISLMap.applyDomain(map.copy(), renaming);
			}
			case JNIISLDimType.ISL_DIM_OUT: {
				String tupleName = map.getTupleName(JNIISLDimType.isl_dim_out);
				JNIISLMap renaming = buildRenameMapString(names, tupleName, parametersNames, map.getRangeNames());
				return JNIISLMap.applyRange(map.copy(), renaming);
			}
			default:
				throw new UnsupportedOperationException("Not yet implemented");
		}
	}

	/**
	 * Renames the set dimensions of {@code set} to {@code names} by applying a
	 * renaming map built with {@link #buildRenameMapString}.
	 *
	 * @param set the set to rename; only a copy of it is consumed
	 * @param names the new names, one per set dimension
	 * @return the renamed set
	 */
	public static JNIISLSet renameDimensions(JNIISLSet set, List<String> names) {
		List<String> parametersNames = set.getParametersNames();
		String tupleName = set.getTupleName();
		List<String> indicesName = set.getIndicesNames();
		JNIISLMap renaming = buildRenameMapString(names, tupleName, parametersNames, indicesName);
		return set.copy().apply(renaming);
	}

	/**
	 * Builds, in the space of {@code map}, the map equating the j-th output
	 * dimension with the j-th input dimension (using the dimension names of
	 * {@code map}).
	 *
	 * @param map provides parameters, tuple names and dimension names
	 * @return the identity-like map between the domain and range tuples of {@code map}
	 * @throws UnsupportedOperationException if domain and range sizes differ
	 */
	public static JNIISLMap identity(JNIISLMap map) {
		List<String> domainNames = map.getDomainNames();
		List<String> rangeNames = map.getRangeNames();
		if (domainNames.size() != rangeNames.size()) {
			throw new UnsupportedOperationException("Not yet implemented");
		}
		return buildEqualityMap(map.getParametersNames(),
				map.getTupleName(JNIISLDimType.isl_dim_in), domainNames,
				map.getTupleName(JNIISLDimType.isl_dim_out), rangeNames);
	}

	/**
	 * Placeholder for shifting dimension {@code dimId} of {@code set} by {@code offset}.
	 * <p>
	 * NOTE(review): not implemented — {@code dimId} and {@code offset} are
	 * ignored; the result is {@code JNIISLMap.fromDomainAndRange(set, set)}
	 * (presumably the cross product of {@code set} with itself).
	 */
	public static JNIISLMap shift(JNIISLSet set, int dimId, int offset) {
		// TODO implement the actual shift of dimension dimId by offset.
		return JNIISLMap.fromDomainAndRange(set.copy(), set.copy());
	}

	/**
	 * Builds {@code "[p0,p1] -> { T[d0,...,new0,...]"}: the header of a set in
	 * the space of the original tuple extended with {@code newDims}.
	 */
	private static String extendedTupleHeader(List<String> parametersNames, String tupleName,
			List<String> currentNames, long nOrigDims, List<String> newDims) {
		List<String> allDims = namesOrDefault(currentNames, nOrigDims);
		allDims.addAll(newDims);
		StringBuilder sb = new StringBuilder("[");
		appendJoined(sb, parametersNames, ",");
		sb.append("] -> { ").append(tupleName).append('[');
		appendJoined(sb, allDims, ",");
		sb.append(']');
		return sb.toString();
	}

	/**
	 * Returns the constraint part of the ISL textual form of {@code bs}: the
	 * text after the first ':' with closing braces removed. Returns the empty
	 * string when the printed form contains no ':' (presumably an unconstrained
	 * basic set).
	 */
	private static String constraintsOf(JNIISLBasicSet bs) {
		String text = bs.toString(ISL_FORMAT.ISL);
		int colon = text.indexOf(':');
		if (colon < 0) return "";
		return text.substring(colon + 1).replace("}", "").trim();
	}

	/**
	 * Simply adds dimensions, without changing current ones.
	 * <p>
	 * Each basic set of {@code set} is re-parsed with the extra (unconstrained)
	 * dimensions {@code dims} appended after the existing ones; the results are
	 * united. Unnamed existing dimensions are named {@code "i" + position}.
	 *
	 * @param set the set to extend; only copies of it are consumed
	 * @param dims names of the dimensions to append
	 * @return the extended set (a copy of {@code set} when {@code dims} is empty;
	 *         an empty set in the extended space when {@code set} has no basic set)
	 */
	public static JNIISLSet extendDimsWith(JNIISLSet set, List<String> dims) {
		if (dims.isEmpty()) return set.copy();
		set = set.copy().makeDisjoint();
		JNIISLSpace dimsOrig = set.getSpace();
		long nOrigDims = dimsOrig.getSize(JNIISLDimType.isl_dim_set);
		String tupleName = tupleNameOrEmpty(set.getTupleName());

		JNIISLSet res = null;
		int nbBset = set.getNumberOfBasicSet();
		for (int i = 0; i < nbBset; i++) {
			JNIISLBasicSet bs = set.getBasicSetAt(i);
			String header = extendedTupleHeader(bs.getParametersNames(), tupleName,
					bs.getIndicesNames(), nOrigDims, dims);
			String constraints = constraintsOf(bs);
			// an unconstrained basic set must not produce a dangling ':'
			String bstring = constraints.isEmpty()
					? header + " }"
					: header + "  : " + constraints + "}";
			JNIISLSet s = ISLFactory.islSet(bstring);
			res = (res == null) ? s : JNIISLSet.union(res, s);
		}
		if (res == null) {
			// no basic set at all: the set is empty, so return an empty set in the extended space
			String header = extendedTupleHeader(set.getParametersNames(), tupleName,
					set.getIndicesNames(), nOrigDims, dims);
			res = ISLFactory.islSet(header + " : 1 = 0 }");
		}
		return res;
	}

	/**
	 * Renames the set dimensions of {@code set} to {@code dims}, keeping its
	 * tuple name and parameters, then coalesces the result.
	 * <p>
	 * NOTE(review): the renaming map equates old and new names textually, so a
	 * new name that coincides with a different old name is resolved by ISL to
	 * the existing variable — confirm callers never permute names.
	 *
	 * @param set the set to rename; only copies of it are consumed
	 * @param dims the new names, one per set dimension
	 * @return the renamed, coalesced set
	 * @throws RuntimeException if {@code dims} does not have one entry per set dimension
	 */
	public static JNIISLSet renameIndices(JNIISLSet set, List<String> dims) {
		set = set.copy().makeDisjoint();
		JNIISLSpace dimsOrig = set.getSpace();

		List<String> origNames = dimsOrig.getNameList(JNIISLDimType.isl_dim_set);
		if (origNames.size() != dims.size()) {
			throw new RuntimeException("renameIndices: expected " + origNames.size()
					+ " dimension names but got " + dims.size());
		}

		String tupleName = tupleNameOrEmpty(set.getTupleName());
		// current dimension names, with "i<j>" for unnamed ones
		List<String> dimIn = namesOrDefault(set.getIndicesNames(), dimsOrig.getSize(JNIISLDimType.isl_dim_set));

		StringBuilder text = new StringBuilder();
		text.append(set.getParametersNames()).append(" -> { ").append(tupleName).append('[');
		appendJoined(text, dimIn, ",");
		text.append("] -> ").append(tupleName).append('[');
		appendJoined(text, dims, ",");
		text.append(']');
		for (int i = 0; i < dimIn.size(); i++) {
			text.append(i == 0 ? "  : " : " and ");
			text.append(dimIn.get(i)).append(" = ").append(dims.get(i));
		}
		text.append(" }");

		JNIISLMap renaming = ISLFactory.islMap(text.toString());
		// set is already a private copy, so it can be handed over directly
		JNIISLSet res = set.apply(renaming);
		return res.coalesce();
	}

	/**
	 * Infers the parameter context domain of a set, that is the parameter
	 * values for which the set contains at least one point: every set
	 * dimension is projected out.
	 *
	 * @param set the set; only a copy of it is consumed
	 * @return the parameter context (a 0-dimensional set)
	 */
	public static JNIISLSet inferParameterContextDomain(JNIISLSet set) {
		JNIISLSet contextSet = set.copy();
		long nDim = set.getNDim();
		if (nDim > 0) {
			contextSet = contextSet.projectOut(JNIISLDimType.isl_dim_set, 0, nDim);
		} else {
			// NOTE(review): a 0-dimensional set is passed through the reverse of the
			// identity on its own space; the intent is not visible here (it looks
			// like a no-op) — confirm before changing.
			JNIISLMap identity = JNIISLMap.identity(contextSet.copy().getSpace()).reverse();
			contextSet = contextSet.apply(identity);
		}
		return contextSet;
	}

	/**
	 * Infers the parameter context domain of every set of {@code uset} and
	 * returns their coalesced union.
	 *
	 * @param uset the union set
	 * @return the union of the parameter contexts of its sets
	 * @throws IllegalArgumentException if {@code uset} contains no set
	 */
	public static JNIISLSet inferParameterContextDomain(JNIISLUnionSet uset) {
		JNIISLSet res = null;
		for (JNIISLSet set : uset.getSets()) {
			JNIISLSet paramCtx = inferParameterContextDomain(set);
			res = (res == null) ? paramCtx : JNIISLSet.union(res, paramCtx);
		}
		if (res == null) {
			throw new IllegalArgumentException(
					"inferParameterContextDomain: the union set contains no set");
		}
		return res.coalesce();
	}

	/**
	 * Expands the dimensions of {@code input} to match those of {@code target}
	 * (extra dimensions are inserted after the existing ones), moves it into the
	 * space of {@code target} through an identity map, and returns the
	 * intersection of the two aligned domains.
	 * <p>
	 * Assumes {@code target} has at least as many dimensions as {@code input}.
	 *
	 * @param input the set to expand; only copies of it are consumed
	 * @param target the reference set; only copies of it are consumed
	 * @return the intersection of the expanded input with {@code target}
	 */
	public static JNIISLSet expandTo(JNIISLSet input, JNIISLSet target) {
		long nExtraDims = target.getNDim() - input.getNDim();
		JNIISLSet expandedSet = input.copy().insertDim(JNIISLDimType.isl_dim_set, input.getNDim(), nExtraDims);
		JNIISLMap map = JNIISLMap.fromDomainAndRange(expandedSet.copy(), target.copy());
		map = JNIISLMap.identity(map.getSpace().copy());
		expandedSet = expandedSet.apply(map.copy());
		expandedSet = JNIISLSet.intersect(expandedSet, target.copy());
		return expandedSet;
	}

	/**
	 * Basic-set variant of {@link #expandTo(JNIISLSet, JNIISLSet)}.
	 *
	 * @param input the basic set to expand; only a copy of it is consumed
	 * @param target the reference basic set; only a copy of it is consumed
	 * @return the expanded basic set
	 * @throws RuntimeException if the expansion does not yield exactly one basic set
	 */
	public static JNIISLBasicSet expandTo(JNIISLBasicSet input, JNIISLBasicSet target) {
		// copy so that the caller's basic sets are not consumed (as in allDisjointBasic)
		JNIISLSet inputSet = JNIISLSet.fromBasicSet(input.copy());
		JNIISLSet targetSet = JNIISLSet.fromBasicSet(target.copy());
		JNIISLSet expandedSet = expandTo(inputSet, targetSet);
		if (expandedSet.getNumberOfBasicSet() != 1) {
			throw new RuntimeException("expandTo: expected exactly one basic set but got "
					+ expandedSet.getNumberOfBasicSet());
		}
		return expandedSet.getBasicSetAt(0);
	}

	/**
	 * Moves {@code input} into the space of {@code target} through an identity
	 * map (both must have the same number of dimensions).
	 *
	 * @param input the set to move; only copies of it are consumed
	 * @param target the reference set; only a copy of it is consumed
	 * @return {@code input} expressed in the space of {@code target}
	 * @throws UnsupportedOperationException if the dimension counts differ
	 */
	public static JNIISLSet renameTo(JNIISLSet input, JNIISLSet target) {
		long nExtraDims = target.getNDim() - input.getNDim();
		if (nExtraDims != 0) {
			throw new UnsupportedOperationException("Not yet implemented");
		}
		JNIISLSet renamed = input.copy();
		JNIISLMap map = JNIISLMap.fromDomainAndRange(renamed.copy(), target.copy());
		map = JNIISLMap.identity(map.getSpace().copy());
		renamed = renamed.apply(map.copy());
		return renamed;
	}

	/**
	 * Moves the range of {@code input} into the space {@code targetDim} through
	 * an identity map (both must have the same number of dimensions).
	 *
	 * @param input the map; only copies of it are consumed
	 * @param targetDim the target range space
	 * @return {@code input} with its range expressed in {@code targetDim}
	 * @throws UnsupportedOperationException if the dimension counts differ
	 */
	public static JNIISLMap renameRangeTo(JNIISLMap input, JNIISLSpace targetDim) {
		JNIISLSet target = JNIISLSet.buildEmpty(targetDim);
		JNIISLSet range = input.copy().getRange();
		long nExtraDims = target.getNDim() - range.getNDim();
		if (nExtraDims != 0) {
			throw new UnsupportedOperationException("Not yet implemented");
		}
		JNIISLMap map = JNIISLMap.fromDomainAndRange(range.copy(), target.copy());
		map = JNIISLMap.identity(map.getSpace().copy());
		map = JNIISLMap.applyRange(input.copy(), map.copy());
		return map;
	}

	/**
	 * Basic-set variant of {@link #allDisjoint(Collection)}; works on copies.
	 *
	 * @param sets the basic sets to check
	 * @return {@code true} if the basic sets are pairwise disjoint
	 */
	public static boolean allDisjointBasic(Collection<JNIISLBasicSet> sets) {
		Collection<JNIISLSet> asSets = new ArrayList<JNIISLSet>(sets.size());
		for (JNIISLBasicSet bset : sets) {
			asSets.add(JNIISLSet.fromBasicSet(bset.copy()));
		}
		return allDisjoint(asSets);
	}

	/**
	 * Checks that the sets are pairwise disjoint, by intersecting each set with
	 * the union of all previous ones. Sets whose spaces are not compatible with
	 * that union make the method return {@code false}.
	 *
	 * @param sets the sets to check; they are not consumed
	 * @return {@code true} if the sets are pairwise disjoint (or there are none)
	 */
	public static boolean allDisjoint(Collection<JNIISLSet> sets) {
		Iterator<JNIISLSet> iterator = sets.iterator();
		if (!iterator.hasNext()) return true;
		// running union of the sets seen so far; always a private copy
		JNIISLSet seen = iterator.next().copy();
		while (iterator.hasNext()) {
			JNIISLSet next = iterator.next();
			if (next.getSpace().isCompatibleWith(seen.getSpace()) == 0) {
				return false;
			}
			JNIISLSet overlap = JNIISLSet.intersect(seen.copy(), next.copy());
			if (overlap.isEmpty() == 0) {
				return false;
			}
			// copy: union must not consume the caller's set
			seen = JNIISLSet.union(seen, next.copy());
		}
		return true;
	}

	/**
	 * Returns the list of range dimension indices that are scalar (constant)
	 * in every map of {@code umap}.
	 * <p>
	 * All maps must share a single range space, each map must consist of exactly
	 * one basic map whose lexmax is a single affine piece with one expression per
	 * range dimension.
	 *
	 * @param umap the union map (e.g. a schedule)
	 * @return the indices of the scalar dimensions; empty if {@code umap} is empty
	 * @throws RuntimeException if one of the above structural assumptions fails
	 */
	public static List<Integer> scalarDims(JNIISLUnionMap umap) {
		if (umap.isEmpty() != 0) {
			return new ArrayList<Integer>(0);
		}
		JNIISLUnionSet scheduleRange = umap.copy().getRange().coalesce();
		// all the statements should be scheduled in the same space
		if (scheduleRange.getNbSet() != 1) {
			throw new RuntimeException("scalarDims: the range should lie in exactly one space but has "
					+ scheduleRange.getNbSet() + " sets");
		}
		JNIISLSet scheduledDomain = scheduleRange.getSetAt(0);
		int nbDims = (int) scheduledDomain.getNDim();

		// start with every dimension as a candidate and drop the non-scalar ones
		List<Integer> res = new ArrayList<Integer>(nbDims);
		for (int i = 0; i < nbDims; i++) {
			res.add(i);
		}

		for (JNIISLMap map : umap.getMaps()) {
			// exactly one schedule per statement
			if (map.getNumberOfBasicMap() != 1) {
				throw new RuntimeException("scalarDims: expected exactly one basic map per statement but got "
						+ map.getNumberOfBasicMap());
			}
			Map<JNIISLBasicSet, JNIISLAffineList> m = map.getBasicMapAt(0).lexmaxV2();
			// the schedule of a given statement should be bijective
			if (m.size() != 1) {
				throw new RuntimeException("scalarDims: expected a single lexmax piece but got " + m.size());
			}
			JNIISLAffineList l = m.values().iterator().next();
			// one affine expression per dimension of the scheduled domain
			if (l.getSize() != nbDims) {
				throw new RuntimeException("scalarDims: schedule has " + l.getSize()
						+ " dimensions but the scheduled domain has " + nbDims);
			}
			for (int i = 0; i < l.getSize(); i++) {
				Integer dim = Integer.valueOf(i);
				// already known to be non-scalar
				if (!res.contains(dim)) continue;
				if (!isConstant(l.getAffineAt(i))) res.remove(dim);
			}
		}
		return res;
	}

	/**
	 * Returns {@code true} if every coefficient of {@code aff} (output, input,
	 * parameter and div dimensions) is zero, i.e. the expression is a constant.
	 * <p>
	 * NOTE(review): in the ISL C API, coefficients are not defined for the
	 * output dimension of an isl_aff — confirm how the JNI
	 * {@code getCoefficientAt} behaves for {@code isl_dim_out}.
	 *
	 * @param aff the affine expression
	 * @return whether {@code aff} is constant
	 */
	public static boolean isConstant(JNIISLAffine aff) {
		JNIISLDimType[] dimTypes = new JNIISLDimType[] {
				JNIISLDimType.isl_dim_out,
				JNIISLDimType.isl_dim_in,
				JNIISLDimType.isl_dim_param,
				// we assume that if a div is a constant, then it has been
				// simplified away.
				// XXX : should we assume that?
				JNIISLDimType.isl_dim_div };
		for (JNIISLDimType dimType : dimTypes) {
			for (int i = 0; i < aff.getNbDim(dimType); i++) {
				if (aff.getCoefficientAt(dimType, i) != 0) return false;
			}
		}
		return true;
	}
}
