package um.matrix.AC3BAC;

import choco.*;
import choco.integer.*;
import choco.integer.constraints.*;
import choco.util.*;
import java.util.*;

/**
 * Bounds-filtering constraint over the strict upper triangle of a square
 * matrix of integer variables.
 *
 * <p>Only the cells {@code mat[i][j]} with {@code i < j} are registered with
 * the solver (see {@link #getUseful}); the diagonal and lower triangle of the
 * input matrix are never read. Whenever a registered cell (i, j) changes, every
 * triple of cells (i, j), (i, k), (j, k) with k distinct from i and j is
 * filtered by {@link #fixBounds()}.
 *
 * <p>NOTE(review): judging from the filtering rules in fixBounds, the enforced
 * property appears to be "in every triple the smallest value occurs at least
 * twice" (e.g. an ultrametric expressed as LCA depths) -- confirm against
 * BACUM.java.
 */
public class AC3BACMatrix extends AbstractLargeIntConstraint {
    /** When true, propagation steps are traced to standard output. */
    private final boolean debug;
    /** rowOf[p] is the matrix row of constraint variable p. */
    private final int[] rowOf;
    /** colOf[p] is the matrix column of constraint variable p. */
    private final int[] colOf;
    /** Source matrix; only its strict upper triangle (row < column) is read. */
    private final IntDomainVar[][] mat;
    /** Order of the source matrix. */
    private final int n;
    // The triple currently examined by fixBounds(). Kept as fields (rather
    // than parameters) so the code shared with BACUM.java stays paste-compatible.
    private IntDomainVar v0;
    private IntDomainVar v1;
    private IntDomainVar v2;

    //BEGIN support code shared with BACUM.java

    /** Three variables ordered s <= m <= l by some bound. */
    private static final class Triple {
        final IntDomainVar s, m, l;

        Triple(IntDomainVar s, IntDomainVar m, IntDomainVar l) {
            this.s = s;
            this.m = m;
            this.l = l;
        }
    }

    /** Returns true if and only if the bound intervals of v and w do not overlap. */
    private static boolean nullIntersection(IntDomainVar v, IntDomainVar w) {
        return v.getInf() > w.getSup() || v.getSup() < w.getInf();
    }

    /**
     * Orders v0, v1, v2 by the keys k0, k1, k2 (key kX belongs to vX) using at
     * most three comparisons; see page 173 of CLR, Introduction to Algorithms.
     */
    private Triple sort3(int k0, int k1, int k2) {
        if (k0 <= k1) {
            if (k1 <= k2) {
                return new Triple(v0, v1, v2);
            }
            if (k0 <= k2) {
                return new Triple(v0, v2, v1);
            }
            return new Triple(v2, v0, v1);
        }
        if (k0 <= k2) {
            return new Triple(v1, v0, v2);
        }
        if (k1 <= k2) {
            return new Triple(v1, v2, v0);
        }
        return new Triple(v2, v1, v0);
    }

    /** Returns v0, v1, v2 ordered by infimum. */
    private Triple sortOnInf() {
        return sort3(v0.getInf(), v1.getInf(), v2.getInf());
    }

    /** Returns v0, v1, v2 ordered by supremum. */
    private Triple sortOnSup() {
        return sort3(v0.getSup(), v1.getSup(), v2.getSup());
    }

    /**
     * Filters the bounds of the current triple (v0, v1, v2), i.e. the triple
     * most recently selected by doPropagate.
     *
     * <p>Infs: if the smallest inf is strictly below the middle inf, the
     * smallest variable's inf is raised to the middle inf. Sups: if the
     * smallest sup is strictly below the middle sup and the smallest-sup
     * variable's interval is disjoint from one of the other two, the
     * remaining variable's sup is lowered to the smallest sup.
     *
     * @throws ContradictionException if a bound update empties a domain
     */
    public void fixBounds() throws ContradictionException {
        if (debug) {
            System.out.println("fixBounds()");
            System.out.println("vars: " + v0.pretty() + " " +
                               v1.pretty() + " " + v2.pretty());
        }
        // FIX UP THE INFS
        Triple si = sortOnInf();
        int sInf = si.s.getInf();
        int mInf = si.m.getInf();
        int lInf = si.l.getInf();
        // Case 1 (all infs different) and case 4 (only the smallest inf is
        // distinct) take the same action, so they share one branch.
        if (sInf != mInf) {
            if (debug) {
                System.out.println(mInf != lInf
                        ? "case 1: infs all different:"
                        : "case 4: smallest inf is distinct:");
            }
            si.s.setInf(mInf);
        }
        // FIX UP THE SUPS
        Triple ss = sortOnSup();
        int sSup = ss.s.getSup();
        int mSup = ss.m.getSup();
        int lSup = ss.l.getSup();
        if (sSup != mSup && mSup != lSup) {
            // case 1: each sup is different
            if (debug) {
                System.out.println("case 1: sups all different");
            }
            if (nullIntersection(ss.l, ss.s)) {
                if (debug) {
                    System.out.println("null intersection of s and l");
                }
                ss.m.setSup(sSup);
            } else if (nullIntersection(ss.m, ss.s)) {
                if (debug) {
                    System.out.println("null intersection of s and m");
                }
                ss.l.setSup(sSup);
            }
        } else if (sSup != mSup) {
            // case 3: the two largest sups are equal, the smallest differs
            if (debug) {
                System.out.println("case 3: 2 largest sups are equal");
            }
            if (nullIntersection(ss.m, ss.s)) {
                if (debug) {
                    System.out.println("null intersection of s and m");
                }
                ss.l.setSup(sSup);
            } else if (nullIntersection(ss.l, ss.s)) {
                if (debug) {
                    System.out.println("null intersection of s and l");
                }
                ss.m.setSup(sSup);
            }
        }
        if (debug) {
            System.out.println("end vars: " + v0.pretty() + " " +
                               v1.pretty() + " " + v2.pretty());
            System.out.println("fixBounds() ends");
        }
    }

    //END support code shared with BACUM.java

    /**
     * Flattens the strict upper triangle of {@code mat} in row-major order
     * ((0,1), (0,2), ..., (0,n-1), (1,2), ...).
     *
     * @param mat square matrix; only cells with row < column are read
     * @return the n*(n-1)/2 upper-triangle cells
     */
    public static IntDomainVar[] getUseful(IntDomainVar[][] mat) {
        int n = mat.length;
        IntDomainVar[] vars = new IntDomainVar[n * (n - 1) / 2];
        int count = 0;
        for (int i = 0; i < n - 1; i++) {
            for (int j = i + 1; j < n; j++) {
                vars[count++] = mat[i][j];
            }
        }
        return vars;
    }

    /**
     * Creates the constraint over the strict upper triangle of {@code mat}.
     *
     * @param mat   square matrix of variables; diagonal and lower triangle are ignored
     * @param debug whether to trace propagation to standard output
     */
    public AC3BACMatrix(IntDomainVar[][] mat, boolean debug) {
        // The superclass constructor must come first, so the matrix is
        // flattened inside a static call.
        super(getUseful(mat));
        this.n = mat.length;
        int numUseful = n * (n - 1) / 2;
        this.rowOf = new int[numUseful];
        this.colOf = new int[numUseful];
        // Same traversal order as getUseful, so position p here describes
        // constraint variable p (no hash lookup on the variable objects needed).
        int p = 0;
        for (int i = 0; i < n - 1; i++) {
            for (int j = i + 1; j < n; j++) {
                rowOf[p] = i;
                colOf[p] = j;
                p++;
            }
        }
        this.mat = mat;
        this.debug = debug;
    }

    /**
     * Returns the registered cell for the unordered pair {a, b} (a != b),
     * always read from the upper triangle so that only variables known to
     * this constraint are filtered.
     */
    private IntDomainVar cell(int a, int b) {
        return a < b ? mat[a][b] : mat[b][a];
    }

    /** Filters every triple that contains constraint variable {@code idx}. */
    private void doPropagate(int idx) throws ContradictionException {
        if (debug) {
            System.out.println("doPropagate(" + idx + ")");
        }
        v0 = getIntVar(idx);
        int i = rowOf[idx];
        int j = colOf[idx];
        if (debug) {
            System.out.println("idx=[" + i + "," + j + "]");
        }
        for (int k = 0; k < n; k++) {
            if (k != i && k != j) {
                v1 = cell(i, k);
                v2 = cell(j, k);
                fixBounds();
            }
        }
        if (debug) {
            System.out.println("doPropagate() ends");
        }
    }

    /** Runs doPropagate for every constraint variable. */
    private void propagateAll() throws ContradictionException {
        for (int idx = 0; idx < getNbVars(); idx++) {
            doPropagate(idx);
        }
    }

    public void awakeOnInf(int idx) throws ContradictionException {
        if (debug) {
            System.out.println("awakeOnInf(" + idx + ")");
        }
        doPropagate(idx);
        if (debug) {
            System.out.println("awakeOnInf() ends");
        }
    }

    public void awakeOnSup(int idx) throws ContradictionException {
        if (debug) {
            System.out.println("awakeOnSup(" + idx + ")");
        }
        doPropagate(idx);
        if (debug) {
            System.out.println("awakeOnSup() ends");
        }
    }

    /** Initial propagation: filters every triple of the matrix. */
    public void awake() throws ContradictionException {
        if (debug) {
            System.out.println("awake()");
        }
        propagateAll();
        if (debug) {
            System.out.println("awake() ends");
        }
    }

    // Interior value removals are ignored: the filtering only reasons on bounds.
    // NOTE(review): assumes bound changes are always reported via
    // awakeOnInf/awakeOnSup/awakeOnInst -- confirm for the choco version in use.
    public void awakeOnRem(int idx, int x) throws ContradictionException {
        if (debug) {
            System.out.println("awakeOnRem(" + idx + "," + x + ")");
        }
    }

    public void awakeOnRemovals(int idx, IntIterator deltaDom)
            throws ContradictionException {
        if (debug) {
            System.out.println("awakeOnRemovals(" + idx + ")");
        }
    }

    /** Full re-propagation: filters every triple of the matrix (previously a no-op). */
    public void propagate() throws ContradictionException {
        if (debug) {
            System.out.println("propagate()");
        }
        propagateAll();
    }

    public void awakeOnInst(int idx) throws ContradictionException {
        if (debug) {
            System.out.println("awakeOnInst(" + idx + ")");
            System.out.println(" setting to " + getVar(idx));
        }
        doPropagate(idx);
        if (debug) {
            System.out.println("awakeOnInst() ends");
        }
    }

    /**
     * Always returns true: the triple property is not checked here and is
     * enforced only through bound filtering.
     * NOTE(review): a real check on instantiated values may be wanted for
     * solution verification.
     */
    public boolean isSatisfied() {
        return true;
    }
}
