/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.types;

import cc.mallet.classify.Classification;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.logging.Logger;

public class ExpGain
extends RankedFeatureVector {
    private static Logger logger = MalletLogger.getLogger(ExpGain.class.getName());
    boolean usingHyperbolicPrior = false;
    double hyperbolicSlope = 0.2;
    double hyperbolicSharpness = 10.0;

    private static double[] calcExpGains(InstanceList ilist, LabelVector[] classifications, double gaussianPriorVariance) {
        int li;
        FeatureVector fv;
        int fli;
        int numInstances = ilist.size();
        int numClasses = ilist.getTargetAlphabet().size();
        int numFeatures = ilist.getDataAlphabet().size();
        assert (ilist.size() > 0);
        double[][] p = new double[numClasses][numFeatures];
        double[][] q = new double[numClasses][numFeatures];
        double[][] alphas = new double[numClasses][numFeatures];
        logger.info("Starting klgains, #instances=" + numInstances);
        double trueLabelWeightSum = 0.0;
        double modelLabelWeightSum = 0.0;
        for (int i = 0; i < numInstances; ++i) {
            assert (classifications[i].getLabelAlphabet() == ilist.getTargetAlphabet());
            Instance inst = (Instance)ilist.get(i);
            Labeling labeling = inst.getLabeling();
            FeatureVector fv2 = (FeatureVector)inst.getData();
            double perInstanceModelLabelWeight = 0.0;
            for (int li2 = 0; li2 < numClasses; ++li2) {
                double trueLabelWeight = labeling.value(li2);
                double modelLabelWeight = classifications[i].value(li2);
                trueLabelWeightSum += trueLabelWeight;
                modelLabelWeightSum += modelLabelWeight;
                perInstanceModelLabelWeight += modelLabelWeight;
                if (trueLabelWeight == 0.0 && modelLabelWeight == 0.0) continue;
                for (int fl = 0; fl < fv2.numLocations(); ++fl) {
                    fli = fv2.indexAtLocation(fl);
                    assert (fv2.valueAtLocation(fl) == 1.0);
                    double[] dArray = p[li2];
                    int n = fli;
                    dArray[n] = dArray[n] + trueLabelWeight;
                    double[] dArray2 = q[li2];
                    int n2 = fli;
                    dArray2[n2] = dArray2[n2] + modelLabelWeight;
                }
            }
            assert (Math.abs(perInstanceModelLabelWeight - 1.0) < 0.001);
        }
        assert (Math.abs(trueLabelWeightSum / (double)numInstances - 1.0) < 0.001) : "trueLabelWeightSum should be 1.0, it was " + trueLabelWeightSum;
        assert (Math.abs(modelLabelWeightSum / (double)numInstances - 1.0) < 0.001) : "modelLabelWeightSum should be 1.0, it was " + modelLabelWeightSum;
        double[][] dalphas = new double[numClasses][numFeatures];
        double[][] alphaChangeOld = new double[numClasses][numFeatures];
        double[][] alphaMax = new double[numClasses][numFeatures];
        double[][] alphaMin = new double[numClasses][numFeatures];
        double[][] ddalphas = new double[numClasses][numFeatures];
        for (int i = 0; i < numClasses; ++i) {
            for (int j = 0; j < numFeatures; ++j) {
                alphaMax[i][j] = Double.POSITIVE_INFINITY;
                alphaMin[i][j] = Double.NEGATIVE_INFINITY;
            }
        }
        double maxAlphachange = 0.0;
        double maxDalpha = 99.0;
        int maxNewtonSteps = 50;
        for (int newton = 0; maxDalpha > 1.0E-8 && newton < maxNewtonSteps; ++newton) {
            int i;
            for (i = 0; i < numClasses; ++i) {
                for (int j = 0; j < numFeatures; ++j) {
                    dalphas[i][j] = p[i][j] - alphas[i][j] / gaussianPriorVariance;
                    ddalphas[i][j] = -1.0 / gaussianPriorVariance;
                }
            }
            for (i = 0; i < ilist.size(); ++i) {
                assert (classifications[i].getLabelAlphabet() == ilist.getTargetAlphabet());
                Instance inst = (Instance)ilist.get(i);
                Labeling labeling = inst.getLabeling();
                fv = (FeatureVector)inst.getData();
                for (int fl = 0; fl < fv.numLocations(); ++fl) {
                    fli = fv.indexAtLocation(fl);
                    for (li = 0; li < numClasses; ++li) {
                        double modelLabelWeight = classifications[i].value(li);
                        double expalpha = Math.exp(alphas[li][fli]);
                        double numerator = modelLabelWeight * expalpha;
                        double denominator = numerator + (1.0 - modelLabelWeight);
                        double[] dArray = dalphas[li];
                        int n = fli;
                        dArray[n] = dArray[n] - numerator / denominator;
                        double[] dArray3 = ddalphas[li];
                        int n3 = fli;
                        dArray3[n3] = dArray3[n3] + (numerator * numerator / (denominator * denominator) - numerator / denominator);
                    }
                }
            }
            maxDalpha = 0.0;
            maxAlphachange = 0.0;
            for (int i2 = 0; i2 < numClasses; ++i2) {
                for (int j = 0; j < numFeatures; ++j) {
                    double alphachange = -(dalphas[i2][j] / ddalphas[i2][j]);
                    if (p[i2][j] == 0.0 && q[i2][j] == 0.0) continue;
                    if (Double.isNaN(alphas[i2][j]) || Double.isNaN(alphachange)) {
                        logger.info("alpha[" + i2 + "][" + j + "]=" + alphas[i2][j] + " p=" + p[i2][j] + " q=" + q[i2][j] + " dalpha=" + dalphas[i2][j] + " ddalpha=" + ddalphas[i2][j] + " alphachange=" + alphachange + " min=" + alphaMin[i2][j] + " max=" + alphaMax[i2][j]);
                    }
                    if (Double.isNaN(alphas[i2][j]) || Double.isNaN(dalphas[i2][j]) || Double.isNaN(ddalphas[i2][j]) || Double.isInfinite(alphas[i2][j]) || Double.isInfinite(dalphas[i2][j]) || Double.isInfinite(ddalphas[i2][j])) {
                        alphachange = 0.0;
                    }
                    double oldalpha = alphas[i2][j];
                    double newalpha = Math.abs(alphachange + alphaChangeOld[i2][j]) / Math.abs(alphachange) < 0.01 ? alphas[i2][j] + alphachange / 2.0 : alphas[i2][j] + alphachange;
                    if (alphachange < 0.0 && alphaMax[i2][j] > alphas[i2][j]) {
                        alphaMax[i2][j] = alphas[i2][j];
                    }
                    if (alphachange > 0.0 && alphaMin[i2][j] < alphas[i2][j]) {
                        alphaMin[i2][j] = alphas[i2][j];
                    }
                    if (newalpha <= alphaMax[i2][j] && newalpha >= alphaMin[i2][j]) {
                        alphas[i2][j] = newalpha;
                    } else {
                        assert (alphaMax[i2][j] != Double.POSITIVE_INFINITY);
                        assert (alphaMin[i2][j] != Double.NEGATIVE_INFINITY);
                        alphas[i2][j] = alphaMin[i2][j] + (alphaMax[i2][j] - alphaMin[i2][j]) / 2.0;
                    }
                    alphachange = alphas[i2][j] - oldalpha;
                    if (Math.abs(alphachange) > maxAlphachange) {
                        maxAlphachange = Math.abs(alphachange);
                    }
                    if (Math.abs(dalphas[i2][j]) > maxDalpha) {
                        maxDalpha = Math.abs(dalphas[i2][j]);
                    }
                    alphaChangeOld[i2][j] = alphachange;
                }
            }
            logger.info("After " + newton + " Newton iterations, maximum alphachange=" + maxAlphachange + " dalpha=" + maxDalpha);
        }
        alphaMin = alphaMax = (double[][])null;
        alphaChangeOld = alphaMax;
        dalphas = alphaMax;
        ddalphas = alphaMax;
        double[][] qeag = new double[numClasses][numFeatures];
        for (int i = 0; i < ilist.size(); ++i) {
            assert (classifications[i].getLabelAlphabet() == ilist.getTargetAlphabet());
            Instance inst = (Instance)ilist.get(i);
            Labeling labeling = inst.getLabeling();
            fv = (FeatureVector)inst.getData();
            int fvMaxLocation = fv.numLocations() - 1;
            for (li = 0; li < numClasses; ++li) {
                double modelLabelWeight = classifications[i].value(li);
                for (int fl = 0; fl < fv.numLocations(); ++fl) {
                    fli = fv.indexAtLocation(fl);
                    double[] dArray = qeag[li];
                    int n = fli;
                    dArray[n] = dArray[n] + Math.log(modelLabelWeight * Math.exp(alphas[li][fli]) + (1.0 - modelLabelWeight));
                }
            }
        }
        double[] klgains = new double[numFeatures];
        for (int i = 0; i < numClasses; ++i) {
            for (int j = 0; j < numFeatures; ++j) {
                double klgainIncr;
                assert (!Double.isInfinite(alphas[i][j]));
                double alpha = alphas[i][j];
                if (alpha == 0.0 || (klgainIncr = alpha * p[i][j] - qeag[i][j] - alpha * alpha / (2.0 * gaussianPriorVariance)) < 0.0) continue;
                int n = j;
                klgains[n] = klgains[n] + klgainIncr;
            }
        }
        return klgains;
    }

    public ExpGain(InstanceList ilist, LabelVector[] classifications, double gaussianPriorVariance) {
        super(ilist.getDataAlphabet(), ExpGain.calcExpGains(ilist, classifications, gaussianPriorVariance));
    }

    private static LabelVector[] getLabelVectorsFromClassifications(Classification[] c) {
        LabelVector[] ret = new LabelVector[c.length];
        for (int i = 0; i < c.length; ++i) {
            ret[i] = c[i].getLabelVector();
        }
        return ret;
    }

    public ExpGain(InstanceList ilist, Classification[] classifications, double gaussianPriorVariance) {
        super(ilist.getDataAlphabet(), ExpGain.calcExpGains(ilist, ExpGain.getLabelVectorsFromClassifications(classifications), gaussianPriorVariance));
    }

    public static class Factory
    implements RankedFeatureVector.Factory {
        LabelVector[] classifications;
        double gaussianPriorVariance = 10.0;
        private static final long serialVersionUID = 1L;
        private static final int CURRENT_SERIAL_VERSION = 0;

        public Factory(LabelVector[] classifications) {
            this.classifications = classifications;
        }

        public Factory(LabelVector[] classifications, double gaussianPriorVariance) {
            this.classifications = classifications;
            this.gaussianPriorVariance = gaussianPriorVariance;
        }

        public RankedFeatureVector newRankedFeatureVector(InstanceList ilist) {
            assert (ilist.getTargetAlphabet() == this.classifications[0].getAlphabet());
            return new ExpGain(ilist, this.classifications, this.gaussianPriorVariance);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(0);
            out.writeInt(this.classifications.length);
            for (int i = 0; i < this.classifications.length; ++i) {
                out.writeObject(this.classifications[i]);
            }
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
            int n = in.readInt();
            this.classifications = new LabelVector[n];
            for (int i = 0; i < n; ++i) {
                this.classifications[i] = (LabelVector)in.readObject();
            }
        }
    }
}

