/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.coalescent.hmc;

import dr.evolution.coalescent.IntervalType;
import dr.evolution.coalescent.TreeIntervals;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.coalescent.GMRFMultilocusSkyrideLikelihood;
import dr.evomodel.tree.TreeModel;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.xml.Reportable;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;

/**
 * Gradient and diagonal-Hessian provider for a {@link GMRFMultilocusSkyrideLikelihood}.
 *
 * <p>The parameter the derivatives are taken with respect to is selected by a
 * {@link WrtParameter} constant; every derivative request is delegated to that constant
 * together with the wrapped likelihood.
 */
public class GMRFGradient
        implements GradientWrtParameterProvider, HessianWrtParameterProvider, Reportable {

    /** Tolerance handed to the gradient and Hessian report checks in {@link #getReport()}. */
    private static final Double TOLERANCE = 1.0E-4;

    private final GMRFMultilocusSkyrideLikelihood skygridLikelihood;
    private final WrtParameter wrtParameter;
    private final Parameter parameter;

    /**
     * Creates a provider for the given likelihood and parameter selector.
     *
     * @param skygridLikelihood likelihood whose derivatives are reported
     * @param wrtParameter      selector of the parameter to differentiate with respect to;
     *                          it is also asked for the {@link Parameter} returned by
     *                          {@link #getParameter()}
     */
    public GMRFGradient(GMRFMultilocusSkyrideLikelihood skygridLikelihood, WrtParameter wrtParameter) {
        this.skygridLikelihood = skygridLikelihood;
        this.wrtParameter = wrtParameter;
        this.parameter = wrtParameter.getParameter(skygridLikelihood);
    }

    /** Returns the wrapped likelihood. */
    @Override
    public Likelihood getLikelihood() {
        return skygridLikelihood;
    }

    /** Returns the parameter resolved from the selector at construction time. */
    @Override
    public Parameter getParameter() {
        return parameter;
    }

    /** Returns the dimension of {@link #getParameter()}. */
    @Override
    public int getDimension() {
        return parameter.getDimension();
    }

    /** Returns the gradient computed by the selector for the wrapped likelihood. */
    @Override
    public double[] getGradientLogDensity() {
        return wrtParameter.getGradientLogDensity(skygridLikelihood);
    }

    /** Returns the diagonal Hessian computed by the selector for the wrapped likelihood. */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        return wrtParameter.getDiagonalHessianLogDensity(skygridLikelihood);
    }

    /**
     * The full Hessian is not available.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[][] getHessianLogDensity() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Builds a report consisting of a header line ({@code <likelihood>.<selector name>}),
     * the gradient check report and the Hessian check report, in that order.
     *
     * <p>The gradient check is given the selector's lower bound and an unbounded upper
     * bound.
     */
    @Override
    public String getReport() {
        String header = skygridLikelihood + "." + wrtParameter.name + "\n";
        String gradientReport = GradientWrtParameterProvider.getReportAndCheckForError(
                this, wrtParameter.getParameterLowerBound(), Double.POSITIVE_INFINITY, TOLERANCE);
        String hessianReport = HessianWrtParameterProvider.getReportAndCheckForError(this, TOLERANCE);
        return header + gradientReport + " \n" + hessianReport;
    }

    /**
     * Selects which parameter of the likelihood the derivatives refer to and knows how to
     * obtain that parameter, its gradient, its diagonal Hessian and its lower bound.
     */
    public enum WrtParameter {

        /** Derivatives with respect to the likelihood's population-size parameter. */
        LOG_POPULATION_SIZES("logPopulationSizes") {
            @Override
            Parameter getParameter(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getPopSizeParameter();
            }

            @Override
            double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getGradientWrtLogPopulationSize();
            }

            @Override
            double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getDiagonalHessianWrtLogPopulationSize();
            }

            @Override
            double getParameterLowerBound() {
                return Double.NEGATIVE_INFINITY;
            }

            @Override
            public void getWarning(GMRFMultilocusSkyrideLikelihood likelihood) {
                // Nothing to check for this selector.
            }
        },

        /** Derivatives with respect to the likelihood's precision parameter. */
        PRECISION("precision") {
            @Override
            Parameter getParameter(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getPrecisionParameter();
            }

            @Override
            double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getGradientWrtPrecision();
            }

            @Override
            double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getDiagonalHessianWrtPrecision();
            }

            @Override
            double getParameterLowerBound() {
                return 0.0;
            }

            @Override
            public void getWarning(GMRFMultilocusSkyrideLikelihood likelihood) {
                // Nothing to check for this selector.
            }
        },

        /** Derivatives with respect to the single regression-coefficient parameter. */
        REGRESSION_COEFFICIENTS("regressionCoefficients") {
            /**
             * Returns the only entry of the likelihood's coefficient list.
             *
             * @throws RuntimeException         if the list holds more than one parameter
             * @throws IllegalArgumentException if the list is empty
             */
            @Override
            Parameter getParameter(GMRFMultilocusSkyrideLikelihood likelihood) {
                List<Parameter> coefficients = likelihood.getBetaListParameter();
                if (coefficients.size() > 1) {
                    throw new RuntimeException("This is not the correct way of handling multidimensional parameters");
                }
                if (coefficients.isEmpty()) {
                    // Fail with an explicit message instead of a bare IndexOutOfBoundsException.
                    throw new IllegalArgumentException(
                            "The likelihood provides no regression coefficient parameter");
                }
                return coefficients.get(0);
            }

            @Override
            double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getGradientWrtRegressionCoefficients();
            }

            @Override
            double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return likelihood.getDiagonalHessianWrtRegressionCoefficients();
            }

            @Override
            double getParameterLowerBound() {
                return Double.NEGATIVE_INFINITY;
            }

            @Override
            public void getWarning(GMRFMultilocusSkyrideLikelihood likelihood) {
                // Nothing to check for this selector.
            }
        },

        /** Derivatives with respect to the internal node heights of the first tree (index 0). */
        NODE_HEIGHT("nodeHeight") {
            /**
             * Node-height parameters already created, keyed by tree identity. An enum
             * constant is shared by every user of the enum, so a single cached parameter
             * would be handed to likelihoods built on a different tree.
             */
            private final Map<TreeModel, Parameter> parametersByTree = new IdentityHashMap<>();

            /**
             * Returns the node-height parameter of the likelihood's first tree, creating it
             * on first use and reusing it for later requests on the same tree.
             */
            @Override
            Parameter getParameter(GMRFMultilocusSkyrideLikelihood likelihood) {
                TreeModel treeModel = (TreeModel) likelihood.getTree(0);
                Parameter nodeHeights = parametersByTree.get(treeModel);
                if (nodeHeights == null) {
                    nodeHeights = treeModel.createNodeHeightsParameter(true, true, false);
                    parametersByTree.put(treeModel, nodeHeights);
                }
                return nodeHeights;
            }

            @Override
            double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return getGradientWrtNodeHeights(likelihood);
            }

            /** Returns an all-zero array with one entry per internal node of the first tree. */
            @Override
            double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood likelihood) {
                return new double[likelihood.getTree(0).getInternalNodeCount()];
            }

            @Override
            double getParameterLowerBound() {
                return 0.0;
            }

            /**
             * Rejects likelihoods with more than one locus; all computations of this
             * selector read tree, intervals and population factor at index 0 only.
             *
             * @throws UnsupportedOperationException if the likelihood reports more than one locus
             */
            @Override
            public void getWarning(GMRFMultilocusSkyrideLikelihood likelihood) {
                if (likelihood.nLoci() > 1) {
                    throw new UnsupportedOperationException("Not yet implemented for multiple loci.");
                }
            }

            /**
             * Computes one gradient entry per internal node of the first tree.
             *
             * <p>For every coalescent interval {@code i} with lineage count {@code k}, the
             * entry of its coalescent node receives {@code -exp(-g[j]) * k * (k - 1)}, where
             * {@code g} holds the population-size parameter values and {@code j} is the
             * node's grid index. Unless the node is the root, the analogous term built from
             * the lineage count of interval {@code i + 1} and {@code g[j + 1]} is subtracted.
             * Finally every entry is multiplied by {@code 0.5 / populationFactor(0)}.
             */
            private double[] getGradientWrtNodeHeights(GMRFMultilocusSkyrideLikelihood likelihood) {
                // Result intentionally discarded; presumably this brings the likelihood's
                // internal state up to date before it is read below — TODO confirm.
                likelihood.getLogLikelihood();

                Tree tree = likelihood.getTree(0);
                double[] gradient = new double[tree.getInternalNodeCount()];
                double[] logPopSizes = likelihood.getPopSizeParameter().getParameterValues();
                double inversePopulationFactor = 1.0 / likelihood.getPopulationFactor(0);
                TreeIntervals intervals = likelihood.getTreeIntervals(0);
                int[] gridIndices = getGridIndexForInternalNodes(likelihood, 0);

                for (int i = 0; i < intervals.getIntervalCount(); ++i) {
                    if (intervals.getIntervalType(i) != IntervalType.COALESCENT) {
                        continue;
                    }
                    NodeRef node = intervals.getCoalescentNode(i);
                    int nodeIndex = getNodeHeightParameterIndex(node, tree);
                    int lineages = intervals.getLineageCount(i);
                    gradient[nodeIndex] +=
                            -Math.exp(-logPopSizes[gridIndices[nodeIndex]]) * lineages * (lineages - 1);

                    if (tree.isRoot(node)) {
                        continue;
                    }
                    int nextLineages = intervals.getLineageCount(i + 1);
                    // NOTE(review): reads the population-size entry one past the node's grid
                    // index; confirm this is intended and cannot exceed the parameter dimension.
                    gradient[nodeIndex] -=
                            -Math.exp(-logPopSizes[gridIndices[nodeIndex] + 1]) * nextLineages * (nextLineages - 1);
                }

                double scale = 0.5 * inversePopulationFactor;
                for (int k = 0; k < gradient.length; ++k) {
                    gradient[k] = gradient[k] * scale;
                }
                return gradient;
            }

            /** Maps an internal node to its slot: node number minus the external node count. */
            private int getNodeHeightParameterIndex(NodeRef nodeRef, Tree tree) {
                return nodeRef.getNumber() - tree.getExternalNodeCount();
            }

            /**
             * Assigns a grid index to every internal node of the tree at {@code treeIndex}.
             *
             * <p>Intervals are visited in order; for each coalescent interval the running
             * grid index advances while the current grid point is smaller than the value of
             * {@code getInterval(i)}, and the resulting index is stored for the coalescent
             * node. The running index never moves backwards.
             */
            private int[] getGridIndexForInternalNodes(GMRFMultilocusSkyrideLikelihood likelihood, int treeIndex) {
                Tree tree = likelihood.getTree(treeIndex);
                TreeIntervals intervals = likelihood.getTreeIntervals(treeIndex);
                int[] gridIndices = new int[tree.getInternalNodeCount()];
                double[] gridPoints = likelihood.getGridPoints();

                int gridIndex = 0;
                for (int i = 0; i < intervals.getIntervalCount(); ++i) {
                    if (intervals.getIntervalType(i) != IntervalType.COALESCENT) {
                        continue;
                    }
                    // NOTE(review): the scan has no upper bound on gridIndex and compares grid
                    // points against getInterval(i); confirm that value is on the same scale
                    // as the grid points and never exceeds the last one.
                    while (gridPoints[gridIndex] < intervals.getInterval(i)) {
                        ++gridIndex;
                    }
                    gridIndices[getNodeHeightParameterIndex(intervals.getCoalescentNode(i), tree)] = gridIndex;
                }
                return gridIndices;
            }
        };

        /** Name matched, ignoring case, by {@link #factory(String)} and used in reports. */
        private final String name;

        WrtParameter(String name) {
            this.name = name;
        }

        /** Returns the parameter of {@code likelihood} this selector refers to. */
        abstract Parameter getParameter(GMRFMultilocusSkyrideLikelihood likelihood);

        /** Returns the gradient of {@code likelihood} with respect to this selector's parameter. */
        abstract double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood likelihood);

        /** Returns the diagonal Hessian of {@code likelihood} for this selector's parameter. */
        abstract double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood likelihood);

        /** Returns the lower bound passed to the gradient report check. */
        abstract double getParameterLowerBound();

        /** Throws if this selector cannot be used with {@code likelihood}; otherwise does nothing. */
        public abstract void getWarning(GMRFMultilocusSkyrideLikelihood likelihood);

        /**
         * Looks up a selector by its name, ignoring case.
         *
         * @param string name to look up; may be {@code null}
         * @return the matching selector, or {@code null} when no selector has that name
         *         (including when {@code string} is {@code null})
         */
        public static WrtParameter factory(String string) {
            for (WrtParameter candidate : WrtParameter.values()) {
                // Receiver is the non-null constant name, so a null argument is simply "no match".
                if (candidate.name.equalsIgnoreCase(string)) {
                    return candidate;
                }
            }
            return null;
        }
    }
}

