/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treelikelihood;

import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.DataType;
import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.branchmodel.BranchModel;
import dr.evomodel.branchmodel.EpochBranchModel;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evomodel.substmodel.MarkovJumpsSubstitutionModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evomodel.substmodel.UniformizedSubstitutionModel;
import dr.evomodel.tipstatesmodel.TipStatesModel;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treelikelihood.AncestralStateBeagleTreeLikelihood;
import dr.evomodel.treelikelihood.MarkovJumpsTraitProvider;
import dr.evomodel.treelikelihood.PartialsRescalingScheme;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.NumberColumn;
import dr.inference.markovjumps.MarkovJumpsRegisterAcceptor;
import dr.inference.markovjumps.MarkovJumpsType;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class MarkovJumpsBeagleTreeLikelihood
extends AncestralStateBeagleTreeLikelihood
implements MarkovJumpsRegisterAcceptor,
MarkovJumpsTraitProvider,
Citable {
    public static final String ALL_HISTORY = "history_all";
    public static final String HISTORY = "history";
    public static final String TOTAL_COUNTS = "allTransitions";
    private List<MarkovJumpsSubstitutionModel> markovjumps;
    private List<Integer> branchModelNumber;
    private List<Parameter> registerParameter;
    private List<String> jumpTag;
    private List<double[][]> expectedJumps;
    private boolean logHistory = false;
    private boolean useCompactHistory = false;
    private String[][] histories = null;
    private boolean[] scaleByTime;
    private double[] tmpProbabilities;
    private double[][] condJumps;
    private int numRegisters;
    private int historyRegisterNumber = -1;
    private final boolean useUniformization;
    private final int nSimulants;
    private final boolean reportUnconditionedColumns;

    public MarkovJumpsBeagleTreeLikelihood(PatternList patternList, MutableTreeModel mutableTreeModel, BranchModel branchModel, SiteRateModel siteRateModel, BranchRateModel branchRateModel, TipStatesModel tipStatesModel, boolean bl, PartialsRescalingScheme partialsRescalingScheme, boolean bl2, Map<Set<String>, Parameter> map, DataType dataType, String string, boolean bl3, boolean bl4, boolean bl5, boolean bl6, int n, boolean bl7) {
        super(patternList, mutableTreeModel, branchModel, siteRateModel, branchRateModel, tipStatesModel, bl, partialsRescalingScheme, bl2, map, dataType, string, bl3, bl4, bl7);
        this.useUniformization = bl5;
        this.reportUnconditionedColumns = bl6;
        this.nSimulants = n;
        this.markovjumps = new ArrayList<MarkovJumpsSubstitutionModel>();
        this.branchModelNumber = new ArrayList<Integer>();
        this.registerParameter = new ArrayList<Parameter>();
        this.jumpTag = new ArrayList<String>();
        this.expectedJumps = new ArrayList<double[][]>();
        this.tmpProbabilities = new double[this.stateCount * this.stateCount * this.categoryCount];
        this.condJumps = new double[this.categoryCount][this.stateCount * this.stateCount];
    }

    public MarkovJumpsBeagleTreeLikelihood(PatternList patternList, TreeModel treeModel, BranchModel branchModel, SiteRateModel siteRateModel, BranchRateModel branchRateModel, TipStatesModel tipStatesModel, boolean bl, PartialsRescalingScheme partialsRescalingScheme, boolean bl2, Map<Set<String>, Parameter> map, DataType dataType, String string, boolean bl3, boolean bl4, boolean bl5, boolean bl6, int n) {
        this(patternList, treeModel, branchModel, siteRateModel, branchRateModel, tipStatesModel, bl, partialsRescalingScheme, bl2, map, dataType, string, bl3, bl4, bl5, bl6, n, false);
    }

    @Override
    public void addRegister(Parameter parameter, MarkovJumpsType markovJumpsType, boolean bl) {
        if (markovJumpsType == MarkovJumpsType.COUNTS && parameter.getDimension() != this.stateCount * this.stateCount || markovJumpsType == MarkovJumpsType.REWARDS && parameter.getDimension() != this.stateCount) {
            throw new RuntimeException("Register parameter of wrong dimension");
        }
        this.addVariable(parameter);
        final String string = parameter.getId();
        for (int i = 0; i < this.substitutionModelDelegate.getSubstitutionModelCount(); ++i) {
            MarkovJumpsSubstitutionModel markovJumpsSubstitutionModel;
            boolean bl2 = this.branchModel instanceof EpochBranchModel;
            this.registerParameter.add(parameter);
            SubstitutionModel substitutionModel = this.substitutionModelDelegate.getSubstitutionModel(i);
            if (this.useUniformization) {
                markovJumpsSubstitutionModel = new UniformizedSubstitutionModel(substitutionModel, markovJumpsType, this.nSimulants);
            } else {
                if (markovJumpsType == MarkovJumpsType.HISTORY) {
                    throw new RuntimeException("Can only report complete history using uniformization");
                }
                markovJumpsSubstitutionModel = new MarkovJumpsSubstitutionModel(substitutionModel, markovJumpsType);
            }
            this.markovjumps.add(markovJumpsSubstitutionModel);
            this.branchModelNumber.add(i);
            this.addModel(markovJumpsSubstitutionModel);
            this.setupRegistration(this.numRegisters);
            final String string2 = this.substitutionModelDelegate.getSubstitutionModelCount() == 1 ? string : string + i;
            this.jumpTag.add(string2);
            this.expectedJumps.add(new double[this.treeModel.getNodeCount()][this.patternCount]);
            boolean[] blArray = this.scaleByTime;
            int n = blArray == null ? 0 : blArray.length;
            this.scaleByTime = new boolean[n + 1];
            if (n > 0) {
                System.arraycopy(blArray, 0, this.scaleByTime, 0, n);
            }
            this.scaleByTime[n] = bl;
            if (markovJumpsType != MarkovJumpsType.HISTORY) {
                TreeTrait.DA dA = new TreeTrait.DA(){
                    final int registerNumber;
                    {
                        this.registerNumber = MarkovJumpsBeagleTreeLikelihood.this.numRegisters;
                    }

                    @Override
                    public String getTraitName() {
                        return string2;
                    }

                    @Override
                    public TreeTrait.Intent getIntent() {
                        return TreeTrait.Intent.BRANCH;
                    }

                    @Override
                    public double[] getTrait(Tree tree, NodeRef nodeRef) {
                        return MarkovJumpsBeagleTreeLikelihood.this.getMarkovJumpsForNodeAndRegister(tree, nodeRef, this.registerNumber);
                    }
                };
                this.treeTraits.addTrait(string2 + "_base", dA);
                String string3 = parameter.getId();
                if (this.substitutionModelDelegate.getSubstitutionModelCount() > 1) {
                    string3 = string3 + i;
                }
                this.treeTraits.addTrait(string3, new TreeTrait.SumAcrossArrayD(new TreeTrait.SumOverTreeDA(dA)));
                this.treeTraits.addTrait(string2 + "_sum", new TreeTrait.SumAcrossArrayD(dA));
            } else {
                if (i == 0 || !bl2) {
                    if (this.histories != null) {
                        throw new RuntimeException("Only one complete history per markovJumpTreeLikelihood is allowed");
                    }
                    this.histories = new String[this.treeModel.getNodeCount()][this.patternCount];
                    if (this.nSimulants > 1) {
                        throw new RuntimeException("Only one simulant allowed when saving complete history");
                    }
                    TreeTrait.DA dA = new TreeTrait.DA(){
                        final int registerNumber;
                        {
                            this.registerNumber = MarkovJumpsBeagleTreeLikelihood.this.numRegisters;
                        }

                        @Override
                        public String getTraitName() {
                            return string;
                        }

                        @Override
                        public TreeTrait.Intent getIntent() {
                            return TreeTrait.Intent.BRANCH;
                        }

                        @Override
                        public double[] getTrait(Tree tree, NodeRef nodeRef) {
                            return MarkovJumpsBeagleTreeLikelihood.this.getMarkovJumpsForNodeAndRegister(tree, nodeRef, this.registerNumber);
                        }
                    };
                    this.treeTraits.addTrait(parameter.getId(), new TreeTrait.SumOverTreeDA(dA));
                    this.historyRegisterNumber = this.numRegisters;
                    ((UniformizedSubstitutionModel)markovJumpsSubstitutionModel).setSaveCompleteHistory(true);
                    if (this.useCompactHistory && this.logHistory) {
                        this.treeTraits.addTrait(ALL_HISTORY, new TreeTrait.SA(){

                            @Override
                            public String getTraitName() {
                                return MarkovJumpsBeagleTreeLikelihood.ALL_HISTORY;
                            }

                            @Override
                            public TreeTrait.Intent getIntent() {
                                return TreeTrait.Intent.BRANCH;
                            }

                            @Override
                            public boolean getFormatAsArray() {
                                return true;
                            }

                            @Override
                            public String[] getTrait(Tree tree, NodeRef nodeRef) {
                                ArrayList<String> arrayList = new ArrayList<String>();
                                for (int i = 0; i < MarkovJumpsBeagleTreeLikelihood.this.patternCount; ++i) {
                                    String string = MarkovJumpsBeagleTreeLikelihood.this.getHistoryForNode(tree, nodeRef, i);
                                    if (string == null || string.compareTo("{}") == 0) continue;
                                    if ((string = string.substring(1, string.length() - 1)).contains("},{")) {
                                        String[] stringArray;
                                        for (String string2 : stringArray = string.split("(?<=\\}),(?=\\{)")) {
                                            arrayList.add(string2);
                                        }
                                        continue;
                                    }
                                    arrayList.add(string);
                                }
                                String[] stringArray = new String[arrayList.size()];
                                arrayList.toArray(stringArray);
                                return stringArray;
                            }

                            @Override
                            public boolean getLoggable() {
                                return true;
                            }
                        });
                    }
                    int n2 = 0;
                    while (n2 < this.patternCount) {
                        final String string4 = this.patternCount == 1 ? HISTORY : "history_" + (n2 + 1);
                        final int n3 = n2++;
                        this.treeTraits.addTrait(string4, new TreeTrait.S(){

                            @Override
                            public String getTraitName() {
                                return string4;
                            }

                            @Override
                            public TreeTrait.Intent getIntent() {
                                return TreeTrait.Intent.BRANCH;
                            }

                            @Override
                            public String getTrait(Tree tree, NodeRef nodeRef) {
                                String string = MarkovJumpsBeagleTreeLikelihood.this.getHistoryForNode(tree, nodeRef, n3);
                                return string.compareTo("{}") != 0 ? string : null;
                            }

                            @Override
                            public boolean getLoggable() {
                                return MarkovJumpsBeagleTreeLikelihood.this.logHistory && !MarkovJumpsBeagleTreeLikelihood.this.useCompactHistory;
                            }
                        });
                    }
                }
                if (bl2) {
                    for (int j = 0; j < this.markovjumps.size(); ++j) {
                        ((UniformizedSubstitutionModel)this.markovjumps.get(j)).setSaveCompleteHistory(true);
                    }
                }
            }
            ++this.numRegisters;
        }
    }

    public void setLogHistories(boolean bl) {
        this.logHistory = bl;
    }

    public void setUseCompactHistory(boolean bl) {
        this.useCompactHistory = bl;
    }

    public double[] getMarkovJumpsForNodeAndRegister(Tree tree, NodeRef nodeRef, int n) {
        return this.getMarkovJumpsForRegister(tree, n)[nodeRef.getNumber()];
    }

    private void refresh(Tree tree) {
        if (tree != this.treeModel) {
            throw new RuntimeException("Must call with internal tree");
        }
        if (!this.likelihoodKnown) {
            this.calculateLogLikelihood();
            this.likelihoodKnown = true;
        }
        if (!this.areStatesRedrawn) {
            this.redrawAncestralStates();
        }
    }

    public double[][] getMarkovJumpsForRegister(Tree tree, int n) {
        this.refresh(tree);
        return this.expectedJumps.get(n);
    }

    public String getHistoryForNode(Tree tree, NodeRef nodeRef, int n) {
        return this.getHistory(tree)[nodeRef.getNumber()][n];
    }

    public String[][] getHistory(Tree tree) {
        this.refresh(tree);
        return this.histories;
    }

    private void setupRegistration(int n) {
        double[] dArray = this.registerParameter.get(n).getParameterValues();
        this.markovjumps.get(n).setRegistration(dArray);
        this.areStatesRedrawn = false;
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        for (int i = 0; i < this.numRegisters; ++i) {
            if (variable != this.registerParameter.get(i)) continue;
            this.setupRegistration(i);
            return;
        }
        super.handleVariableChangedEvent(variable, n, changeType);
    }

    @Override
    protected void hookCalculation(Tree tree, NodeRef nodeRef, NodeRef nodeRef2, int[] nArray, int[] nArray2, double[] dArray, int[] nArray3) {
        int n = nodeRef2.getNumber();
        double[] dArray2 = dArray;
        if (dArray2 == null) {
            this.getMatrix(n, this.tmpProbabilities);
            dArray2 = this.tmpProbabilities;
        }
        double d = this.branchRateModel.getBranchRate(tree, nodeRef2);
        double d2 = tree.getNodeHeight(nodeRef);
        double d3 = tree.getNodeHeight(nodeRef2);
        double d4 = d2 - d3;
        for (int i = 0; i < this.markovjumps.size(); ++i) {
            BranchModel.Mapping mapping;
            MarkovJumpsSubstitutionModel markovJumpsSubstitutionModel = this.markovjumps.get(i);
            int n2 = this.branchModelNumber.get(i);
            if (n2 == (mapping = this.branchModel.getBranchModelMapping(nodeRef2)).getOrder()[0]) {
                if (this.useUniformization) {
                    this.computeSampledMarkovJumpsForBranch((UniformizedSubstitutionModel)markovJumpsSubstitutionModel, d4, d, n, nArray, nArray2, d2, d3, dArray2, this.scaleByTime[i], this.expectedJumps.get(i), nArray3, this.branchModel instanceof EpochBranchModel || i == this.historyRegisterNumber);
                    continue;
                }
                this.computeIntegratedMarkovJumpsForBranch(markovJumpsSubstitutionModel, d4, d, n, nArray, nArray2, dArray2, this.condJumps, this.scaleByTime[i], this.expectedJumps.get(i), nArray3);
                continue;
            }
            double[] dArray3 = this.expectedJumps.get(i)[n];
            Arrays.fill(dArray3, 0.0);
        }
    }

    private void computeSampledMarkovJumpsForBranch(UniformizedSubstitutionModel uniformizedSubstitutionModel, double d, double d2, int n, int[] nArray, int[] nArray2, double d3, double d4, double[] dArray, boolean bl, double[][] dArray2, int[] nArray3, boolean bl2) {
        for (int i = 0; i < this.patternCount; ++i) {
            int n2 = nArray3 == null ? 0 : nArray3[i];
            double d5 = this.siteRateModel.getRateForCategory(n2);
            int n3 = n2 * this.stateCount * this.stateCount;
            double d6 = uniformizedSubstitutionModel.computeCondStatMarkovJumps(nArray[i], nArray2[i], d * d2 * d5, dArray[n3 + nArray[i] * this.stateCount + nArray2[i]]);
            if (bl) {
                d6 /= d2 * d5;
            }
            dArray2[n][i] = d6;
            if (!bl2) continue;
            int n4 = this.useCompactHistory ? i + 1 : -1;
            this.histories[n][i] = uniformizedSubstitutionModel.getCompleteHistory(n4, d3, d4);
        }
    }

    private void computeIntegratedMarkovJumpsForBranch(MarkovJumpsSubstitutionModel markovJumpsSubstitutionModel, double d, double d2, int n, int[] nArray, int[] nArray2, double[] dArray, double[][] dArray2, boolean bl, double[][] dArray3, int[] nArray3) {
        int n2;
        for (n2 = 0; n2 < this.categoryCount; ++n2) {
            double d3 = this.siteRateModel.getRateForCategory(n2);
            if (d3 > 0.0) {
                if (this.categoryCount == 1) {
                    markovJumpsSubstitutionModel.computeCondStatMarkovJumps(d * d2 * d3, dArray, dArray2[n2]);
                } else {
                    System.arraycopy(dArray, n2 * this.stateCount * this.stateCount, this.tmpProbabilities, 0, this.stateCount * this.stateCount);
                    markovJumpsSubstitutionModel.computeCondStatMarkovJumps(d * d2 * d3, this.tmpProbabilities, dArray2[n2]);
                }
                if (!bl) continue;
                double d4 = d2 * d3;
                int n3 = 0;
                while (n3 < dArray2[n2].length) {
                    double[] dArray4 = dArray2[n2];
                    int n4 = n3++;
                    dArray4[n4] = dArray4[n4] / d4;
                }
                continue;
            }
            Arrays.fill(dArray2[n2], 0.0);
            if (markovJumpsSubstitutionModel.getType() != MarkovJumpsType.REWARDS || !bl) continue;
            for (int i = 0; i < this.stateCount; ++i) {
                dArray2[n2][i * this.stateCount + i] = d;
            }
        }
        for (n2 = 0; n2 < this.patternCount; ++n2) {
            int n5 = nArray3 == null ? 0 : nArray3[n2];
            dArray3[n][n2] = dArray2[n5][nArray[n2] * this.stateCount + nArray2[n2]];
        }
    }

    @Override
    public LogColumn[] getColumns() {
        int n = this.patternCount * this.numRegisters;
        if (this.reportUnconditionedColumns) {
            n = this.categoryCount == 1 ? (n += this.numRegisters) : (n *= 2);
        }
        int n2 = 0;
        LogColumn[] logColumnArray = new LogColumn[n];
        for (int i = 0; i < this.numRegisters; ++i) {
            for (int j = 0; j < this.patternCount; ++j) {
                logColumnArray[n2++] = new ConditionedCountColumn(this.jumpTag.get(i), i, j);
                if (!this.reportUnconditionedColumns || this.categoryCount <= 1) continue;
                logColumnArray[n2++] = new UnconditionedCountColumn(this.jumpTag.get(i), i, j, this.rateCategory);
            }
            if (!this.reportUnconditionedColumns || this.categoryCount != 1) continue;
            logColumnArray[n2++] = new UnconditionedCountColumn(this.jumpTag.get(i), i);
        }
        return logColumnArray;
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.COUNTING_PROCESSES;
    }

    @Override
    public String getDescription() {
        return "MarkovJumps inference techniques";
    }

    @Override
    public List<Citation> getCitations() {
        return Collections.singletonList(CommonCitations.MININ_2008_COUNTING);
    }

    protected class ConditionedCountColumn
    extends CountColumn {
        public ConditionedCountColumn(String string, int n, int n2) {
            super("c_" + string, n, n2);
        }

        @Override
        public double getDoubleValue() {
            double d = 0.0;
            double[][] dArray = MarkovJumpsBeagleTreeLikelihood.this.getMarkovJumpsForRegister(MarkovJumpsBeagleTreeLikelihood.this.treeModel, this.indexRegistration);
            for (int i = 0; i < MarkovJumpsBeagleTreeLikelihood.this.treeModel.getNodeCount(); ++i) {
                d += dArray[i][this.indexSite];
            }
            return d;
        }
    }

    protected class UnconditionedCountColumn
    extends CountColumn {
        int[] rateCategory;

        public UnconditionedCountColumn(String string, int n, int n2, int[] nArray) {
            super("u_" + string, n, n2);
            this.rateCategory = nArray;
        }

        public UnconditionedCountColumn(String string, int n) {
            this(string, n, -1, null);
        }

        @Override
        public double getDoubleValue() {
            double d = ((MarkovJumpsSubstitutionModel)MarkovJumpsBeagleTreeLikelihood.this.markovjumps.get(this.indexRegistration)).getMarginalRate() * this.getExpectedTreeLength();
            if (this.rateCategory != null) {
                d *= MarkovJumpsBeagleTreeLikelihood.this.siteRateModel.getRateForCategory(this.rateCategory[this.indexSite]);
            }
            return d;
        }

        private double getExpectedTreeLength() {
            double d = 0.0;
            for (int i = 0; i < MarkovJumpsBeagleTreeLikelihood.this.treeModel.getNodeCount(); ++i) {
                NodeRef nodeRef = MarkovJumpsBeagleTreeLikelihood.this.treeModel.getNode(i);
                if (MarkovJumpsBeagleTreeLikelihood.this.treeModel.isRoot(nodeRef)) continue;
                d += MarkovJumpsBeagleTreeLikelihood.this.branchRateModel.getBranchRate(MarkovJumpsBeagleTreeLikelihood.this.treeModel, nodeRef) * MarkovJumpsBeagleTreeLikelihood.this.treeModel.getBranchLength(nodeRef);
            }
            return d;
        }
    }

    protected abstract class CountColumn
    extends NumberColumn {
        protected int indexRegistration;
        protected int indexSite;

        public CountColumn(String string, int n, int n2) {
            super(string + (n2 >= 0 ? "[" + (n2 + 1) + "]" : ""));
            this.indexRegistration = n;
            this.indexSite = n2;
        }

        @Override
        public abstract double getDoubleValue();
    }
}

