/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.bayes.net.estimate;

import weka.classifiers.bayes.net.search.local.Scoreable;
import weka.core.RevisionUtils;
import weka.core.Statistics;
import weka.core.Utils;
import weka.estimators.DiscreteEstimator;
import weka.estimators.Estimator;

/**
 * Discrete estimator over a fixed number of symbols whose counts all start at
 * a constant prior (pseudo-count). Besides the {@link Estimator} operations it
 * implements {@link Scoreable}, so it can report a log score computed from its
 * counts.
 */
public class DiscreteEstimatorBayes
extends Estimator
implements Scoreable {
    static final long serialVersionUID = 4215400230843212684L;
    /** Per-symbol counts; every entry starts at {@link #m_fPrior}. */
    protected double[] m_Counts;
    /** Sum of all counts, initialised to {@code m_fPrior * m_nSymbols} and updated by {@link #addValue}. */
    protected double m_SumOfCounts;
    /** Number of symbols (the length of {@link #m_Counts}). */
    protected int m_nSymbols = 0;
    /** Initial count given to every symbol at construction time. */
    protected double m_fPrior = 0.0;

    /**
     * Creates an estimator with {@code nSymbols} symbols, each starting with a
     * count of {@code fPrior}.
     *
     * @param nSymbols number of distinct symbols
     * @param fPrior   initial count assigned to every symbol
     */
    public DiscreteEstimatorBayes(int nSymbols, double fPrior) {
        this.m_fPrior = fPrior;
        this.m_nSymbols = nSymbols;
        this.m_Counts = new double[nSymbols];
        for (int iSymbol = 0; iSymbol < nSymbols; iSymbol++) {
            this.m_Counts[iSymbol] = fPrior;
        }
        this.m_SumOfCounts = fPrior * (double) nSymbols;
    }

    /**
     * Adds {@code weight} to the count of symbol {@code (int) data} and to the
     * total count.
     *
     * @param data   symbol index (truncated to an int)
     * @param weight amount to add
     */
    @Override
    public void addValue(double data, double weight) {
        this.m_Counts[(int) data] += weight;
        this.m_SumOfCounts += weight;
    }

    /**
     * Returns the relative frequency of symbol {@code (int) data}.
     *
     * @param data symbol index (truncated to an int)
     * @return the symbol's count divided by the total count, or 0 if the total is 0
     */
    @Override
    public double getProbability(double data) {
        if (this.m_SumOfCounts == 0.0) {
            return 0.0;
        }
        return this.m_Counts[(int) data] / this.m_SumOfCounts;
    }

    /**
     * Returns the stored count of symbol {@code (int) data}; this count
     * includes the prior assigned at construction.
     *
     * @param data symbol index (truncated to an int)
     * @return the symbol's count, or 0 if the total count is 0
     */
    public double getCount(double data) {
        if (this.m_SumOfCounts == 0.0) {
            return 0.0;
        }
        return this.m_Counts[(int) data];
    }

    /**
     * @return the number of count slots, or 0 if the count array is null
     */
    public int getNumSymbols() {
        return this.m_Counts == null ? 0 : this.m_Counts.length;
    }

    /**
     * Computes a log score from the current counts (which include the prior).
     *
     * <ul>
     * <li>Type 0: sum_i lnGamma(N_i) - lnGamma(N), plus, when the prior is
     *     non-zero, lnGamma(n * prior) - n * lnGamma(prior). NOTE(review):
     *     presumably {@code Scoreable.BAYES} - confirm.</li>
     * <li>Type 1: sum_i lnGamma(N_i) - lnGamma(N) - n * lnGamma(1 / (n * nCardinality))
     *     + lnGamma(1 / nCardinality). NOTE(review): presumably {@code Scoreable.BDeu} - confirm.</li>
     * <li>Types 2, 3, 4: the log-likelihood sum_i N_i * log(N_i / N), treating
     *     zero counts as contributing 0.</li>
     * <li>Any other type: 0.0.</li>
     * </ul>
     * Here N_i are the per-symbol counts, N their sum and n the number of symbols.
     *
     * @param nType        selects the score formula (see above)
     * @param nCardinality used only by score type 1
     * @return the log score
     */
    @Override
    public double logScore(int nType, int nCardinality) {
        double fScore = 0.0;
        switch (nType) {
            case 0: {
                fScore = this.lnGammaOfCounts();
                if (this.m_fPrior != 0.0) {
                    fScore -= (double) this.m_nSymbols * Statistics.lnGamma(this.m_fPrior);
                    fScore += Statistics.lnGamma((double) this.m_nSymbols * this.m_fPrior);
                }
                break;
            }
            case 1: {
                fScore = this.lnGammaOfCounts();
                // Multiply in double so a large symbol count times cardinality cannot overflow int.
                fScore -= (double) this.m_nSymbols
                        * Statistics.lnGamma(1.0 / ((double) this.m_nSymbols * nCardinality));
                fScore += Statistics.lnGamma(1.0 / (double) nCardinality);
                break;
            }
            case 2:
            case 3:
            case 4: {
                for (int iSymbol = 0; iSymbol < this.m_nSymbols; iSymbol++) {
                    double fCount = this.m_Counts[iSymbol];
                    if (fCount == 0.0) {
                        // c * log(c / N) -> 0 as c -> 0; evaluating it directly gives
                        // 0 * -Infinity = NaN and would poison the whole score.
                        continue;
                    }
                    fScore += fCount * Math.log(this.getProbability(iSymbol));
                }
                break;
            }
            default:
                break;
        }
        return fScore;
    }

    /**
     * Returns sum_i lnGamma(m_Counts[i]) - lnGamma(m_SumOfCounts), the term
     * shared by score types 0 and 1.
     */
    private double lnGammaOfCounts() {
        double fSum = 0.0;
        for (int iSymbol = 0; iSymbol < this.m_nSymbols; iSymbol++) {
            fSum += Statistics.lnGamma(this.m_Counts[iSymbol]);
        }
        fSum -= Statistics.lnGamma(this.m_SumOfCounts);
        return fSum;
    }

    /**
     * Lists every count followed by the total. When the total exceeds 1 the
     * values are formatted to two decimals; otherwise they are printed in full.
     */
    @Override
    public String toString() {
        boolean rounded = this.m_SumOfCounts > 1.0;
        StringBuilder text = new StringBuilder("Discrete Estimator. Counts = ");
        for (double count : this.m_Counts) {
            text.append(' ')
                .append(rounded ? Utils.doubleToString(count, 2) : String.valueOf(count));
        }
        text.append("  (Total = ")
            .append(rounded ? Utils.doubleToString(this.m_SumOfCounts, 2) : String.valueOf(this.m_SumOfCounts))
            .append(")\n");
        return text.toString();
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.7 $");
    }

    /**
     * Command-line demo. Each argument is a symbol index. The estimator is
     * sized to the largest index. For each argument in turn the demo prints
     * the estimator, prints the probability of that symbol, and then adds the
     * symbol with weight 1.
     *
     * @param argv symbol indices as integers
     */
    public static void main(String[] argv) {
        try {
            if (argv.length == 0) {
                System.out.println("Please specify a set of instances.");
                return;
            }
            int max = Integer.parseInt(argv[0]);
            for (int i = 1; i < argv.length; i++) {
                max = Math.max(max, Integer.parseInt(argv[i]));
            }
            // Demonstrate this class; a prior of 1.0 starts every symbol at a count of one.
            DiscreteEstimatorBayes newEst = new DiscreteEstimatorBayes(max + 1, 1.0);
            for (String arg : argv) {
                int current = Integer.parseInt(arg);
                System.out.println(newEst);
                System.out.println("Prediction for " + current + " = " + newEst.getProbability(current));
                newEst.addValue(current, 1.0);
            }
        }
        catch (Exception e) {
            System.out.println(e.getMessage());
        }
    }
}

