/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions.pace;

import java.util.Random;
import weka.classifiers.functions.pace.MixtureDistribution;
import weka.classifiers.functions.pace.PaceMatrix;
import weka.core.matrix.DoubleVector;
import weka.core.matrix.Maths;

public class NormalMixture
extends MixtureDistribution {
    protected double separatingThreshold = 0.05;
    protected double trimingThreshold = 0.7;
    protected double fittingIntervalLength = 3.0;

    public double getSeparatingThreshold() {
        return this.separatingThreshold;
    }

    public void setSeparatingThreshold(double d) {
        this.separatingThreshold = d;
    }

    public double getTrimingThreshold() {
        return this.trimingThreshold;
    }

    public void setTrimingThreshold(double d) {
        this.trimingThreshold = d;
    }

    public boolean separable(DoubleVector doubleVector, int n, int n2, double d) {
        double d2 = 0.0;
        for (int i = n; i <= n2; ++i) {
            d2 += Maths.pnorm(-Math.abs(d - doubleVector.get(i)));
        }
        return d2 < this.separatingThreshold;
    }

    public DoubleVector supportPoints(DoubleVector doubleVector, int n) {
        if (doubleVector.size() < 2) {
            throw new IllegalArgumentException("data size < 2");
        }
        return doubleVector.copy();
    }

    public PaceMatrix fittingIntervals(DoubleVector doubleVector) {
        DoubleVector doubleVector2 = doubleVector.cat(doubleVector.minus(this.fittingIntervalLength));
        DoubleVector doubleVector3 = doubleVector.plus(this.fittingIntervalLength).cat(doubleVector);
        PaceMatrix paceMatrix = new PaceMatrix(doubleVector2.size(), 2);
        paceMatrix.setMatrix(0, doubleVector2.size() - 1, 0, doubleVector2);
        paceMatrix.setMatrix(0, doubleVector3.size() - 1, 1, doubleVector3);
        return paceMatrix;
    }

    public PaceMatrix probabilityMatrix(DoubleVector doubleVector, PaceMatrix paceMatrix) {
        int n = doubleVector.size();
        int n2 = paceMatrix.getRowDimension();
        PaceMatrix paceMatrix2 = new PaceMatrix(n2, n);
        for (int i = 0; i < n2; ++i) {
            for (int j = 0; j < n; ++j) {
                paceMatrix2.set(i, j, Maths.pnorm(paceMatrix.get(i, 1), doubleVector.get(j), 1.0) - Maths.pnorm(paceMatrix.get(i, 0), doubleVector.get(j), 1.0));
            }
        }
        return paceMatrix2;
    }

    public double empiricalBayesEstimate(double d) {
        if (Math.abs(d) > 10.0) {
            return d;
        }
        DoubleVector doubleVector = Maths.dnormLog(d, this.mixingDistribution.getPointValues(), 1.0);
        doubleVector.minusEquals(doubleVector.max());
        doubleVector = doubleVector.map("java.lang.Math", "exp");
        doubleVector.timesEquals(this.mixingDistribution.getFunctionValues());
        return this.mixingDistribution.getPointValues().innerProduct(doubleVector) / doubleVector.sum();
    }

    public DoubleVector empiricalBayesEstimate(DoubleVector doubleVector) {
        DoubleVector doubleVector2 = new DoubleVector(doubleVector.size());
        for (int i = 0; i < doubleVector.size(); ++i) {
            doubleVector2.set(i, this.empiricalBayesEstimate(doubleVector.get(i)));
        }
        this.trim(doubleVector2);
        return doubleVector2;
    }

    public DoubleVector nestedEstimate(DoubleVector doubleVector) {
        int n;
        DoubleVector doubleVector2 = new DoubleVector(doubleVector.size());
        for (n = 0; n < doubleVector.size(); ++n) {
            doubleVector2.set(n, this.hf(doubleVector.get(n)));
        }
        doubleVector2.cumulateInPlace();
        n = doubleVector2.indexOfMax();
        DoubleVector doubleVector3 = doubleVector.copy();
        if (n < doubleVector.size() - 1) {
            doubleVector3.set(n + 1, doubleVector.size() - 1, 0.0);
        }
        this.trim(doubleVector3);
        return doubleVector3;
    }

    public DoubleVector subsetEstimate(DoubleVector doubleVector) {
        DoubleVector doubleVector2 = this.h(doubleVector);
        DoubleVector doubleVector3 = doubleVector.copy();
        for (int i = 0; i < doubleVector.size(); ++i) {
            if (!(doubleVector2.get(i) <= 0.0)) continue;
            doubleVector3.set(i, 0.0);
        }
        this.trim(doubleVector3);
        return doubleVector3;
    }

    public void trim(DoubleVector doubleVector) {
        for (int i = 0; i < doubleVector.size(); ++i) {
            if (!(Math.abs(doubleVector.get(i)) <= this.trimingThreshold)) continue;
            doubleVector.set(i, 0.0);
        }
    }

    public double hf(double d) {
        DoubleVector doubleVector = this.mixingDistribution.getPointValues();
        DoubleVector doubleVector2 = this.mixingDistribution.getFunctionValues();
        DoubleVector doubleVector3 = Maths.dnormLog(d, doubleVector, 1.0);
        doubleVector3.minusEquals(doubleVector3.max());
        doubleVector3 = doubleVector3.map("java.lang.Math", "exp");
        doubleVector3.timesEquals(doubleVector2);
        return doubleVector.times(2.0 * d).minusEquals(d * d).innerProduct(doubleVector3) / doubleVector3.sum();
    }

    public double h(double d) {
        DoubleVector doubleVector = this.mixingDistribution.getPointValues();
        DoubleVector doubleVector2 = this.mixingDistribution.getFunctionValues();
        DoubleVector doubleVector3 = Maths.dnorm(d, doubleVector, 1.0).timesEquals(doubleVector2);
        return doubleVector.times(2.0 * d).minusEquals(d * d).innerProduct(doubleVector3);
    }

    public DoubleVector h(DoubleVector doubleVector) {
        DoubleVector doubleVector2 = new DoubleVector(doubleVector.size());
        for (int i = 0; i < doubleVector.size(); ++i) {
            doubleVector2.set(i, this.h(doubleVector.get(i)));
        }
        return doubleVector2;
    }

    public double f(double d) {
        DoubleVector doubleVector = this.mixingDistribution.getPointValues();
        DoubleVector doubleVector2 = this.mixingDistribution.getFunctionValues();
        return Maths.dchisq(d, doubleVector).timesEquals(doubleVector2).sum();
    }

    public DoubleVector f(DoubleVector doubleVector) {
        DoubleVector doubleVector2 = new DoubleVector(doubleVector.size());
        for (int i = 0; i < doubleVector.size(); ++i) {
            doubleVector2.set(i, this.h(doubleVector2.get(i)));
        }
        return doubleVector2;
    }

    public String toString() {
        return this.mixingDistribution.toString();
    }

    public static void main(String[] stringArray) {
        int n = 50;
        int n2 = 50;
        double d = 0.0;
        double d2 = 5.0;
        DoubleVector doubleVector = Maths.rnorm(n, d, 1.0, new Random());
        doubleVector = doubleVector.cat(Maths.rnorm(n2, d2, 1.0, new Random()));
        DoubleVector doubleVector2 = new DoubleVector(n, d).cat(new DoubleVector(n2, d2));
        System.out.println("==========================================================");
        System.out.println("This is to test the estimation of the mixing\ndistribution of the mixture of unit variance normal\ndistributions. The example mixture used is of the form: \n\n   0.5 * N(mu1, 1) + 0.5 * N(mu2, 1)\n");
        System.out.println("It also tests three estimators: the subset\nselector, the nested model selector, and the empirical Bayes\nestimator. Quadratic losses of the estimators are given, \nand are taken as the measure of their performance.");
        System.out.println("==========================================================");
        System.out.println("mu1 = " + d + " mu2 = " + d2 + "\n");
        System.out.println(doubleVector.size() + " observations are: \n\n" + doubleVector);
        System.out.println("\nQuadratic loss of the raw data (i.e., the MLE) = " + doubleVector.sum2(doubleVector2));
        System.out.println("==========================================================");
        NormalMixture normalMixture = new NormalMixture();
        normalMixture.fit(doubleVector, 1);
        System.out.println("The estimated mixing distribution is:\n" + normalMixture);
        DoubleVector doubleVector3 = normalMixture.nestedEstimate(doubleVector.rev()).rev();
        System.out.println("\nThe Nested Estimate = \n" + doubleVector3);
        System.out.println("Quadratic loss = " + doubleVector3.sum2(doubleVector2));
        doubleVector3 = normalMixture.subsetEstimate(doubleVector);
        System.out.println("\nThe Subset Estimate = \n" + doubleVector3);
        System.out.println("Quadratic loss = " + doubleVector3.sum2(doubleVector2));
        doubleVector3 = normalMixture.empiricalBayesEstimate(doubleVector);
        System.out.println("\nThe Empirical Bayes Estimate = \n" + doubleVector3);
        System.out.println("Quadratic loss = " + doubleVector3.sum2(doubleVector2));
    }
}

