/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.evomodel.treedatalikelihood.discrete.MaskProvider;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.Parameter;
import dr.inference.operators.AdaptationMode;
import dr.inference.operators.hmc.HamiltonianMonteCarloOperator;
import dr.inference.operators.hmc.MassPreconditionScheduler;
import dr.inference.operators.hmc.MassPreconditioner;
import dr.util.Transform;
import dr.xml.Reportable;
import java.util.Arrays;

/**
 * Hamiltonian Monte Carlo operator whose leap-frog engine applies a {@link Transform} to the
 * gradient and then multiplies it element-wise by the mask supplied by a {@link MaskProvider}.
 *
 * <p>The mask handed to the superclass machinery (see {@link #buildMask(Parameter)}) and the mask
 * array of the engine itself are both all ones; the provider's mask is applied only to the
 * transformed gradient inside {@code MaskedMultivariateTransform#updateMomentum}.
 */
public class TransformedMultivariateHamiltonianMonteCarloOperator
extends HamiltonianMonteCarloOperator
implements Reportable {
    /** Source of the mask that is applied to the transformed gradient. */
    private final MaskProvider maskProvider;

    /**
     * Creates the operator. All arguments except {@code maskProvider} are forwarded unchanged to
     * the {@link HamiltonianMonteCarloOperator} constructor; {@code maskProvider.getMask()} is
     * forwarded in place of the provider itself.
     *
     * @param adaptationMode forwarded to the superclass
     * @param d forwarded to the superclass (NOTE(review): presumably the operator weight - confirm)
     * @param gradientWrtParameterProvider forwarded to the superclass
     * @param parameter forwarded to the superclass
     * @param transform forwarded to the superclass and used to build the leap-frog engine
     * @param maskProvider supplies the mask applied to the transformed gradient; must not be null
     * @param options forwarded to the superclass
     * @param massPreconditioner forwarded to the superclass
     * @param type forwarded to the superclass
     */
    public TransformedMultivariateHamiltonianMonteCarloOperator(AdaptationMode adaptationMode, double d, GradientWrtParameterProvider gradientWrtParameterProvider, Parameter parameter, Transform transform, MaskProvider maskProvider, HamiltonianMonteCarloOperator.Options options, MassPreconditioner massPreconditioner, MassPreconditionScheduler.Type type) {
        super(adaptationMode, d, gradientWrtParameterProvider, parameter, transform, maskProvider.getMask(), options, massPreconditioner, type);
        this.maskProvider = maskProvider;
        // (Re)build the engine now that this.maskProvider is assigned.
        // NOTE(review): if the superclass constructor already invoked the overridden
        // constructLeapFrogEngine, that engine would have captured a null maskProvider - confirm.
        this.leapFrogEngine = this.constructLeapFrogEngine(transform);
    }

    /**
     * Returns an all-ones mask with one entry per gradient dimension.
     *
     * @param parameter ignored by this override
     * @return a new array of length {@code gradientProvider.getDimension()} filled with 1.0
     */
    @Override
    protected double[] buildMask(Parameter parameter) {
        double[] ones = new double[this.gradientProvider.getDimension()];
        Arrays.fill(ones, 1.0);
        return ones;
    }

    /**
     * Builds the masked, transformed leap-frog engine used by this operator.
     *
     * @param transform the transform handed to the engine
     * @return a new {@code MaskedMultivariateTransform} wired to this operator's parameter,
     *         gradient provider, default instability handler, preconditioning and mask provider
     */
    @Override
    protected HamiltonianMonteCarloOperator.LeapFrogEngine constructLeapFrogEngine(Transform transform) {
        return new MaskedMultivariateTransform(this.parameter, this.gradientProvider, this.getDefaultInstabilityHandler(), this.preconditioning, this.maskProvider, transform);
    }

    /**
     * This operator produces no report.
     *
     * @return always {@code null}
     */
    @Override
    public String getReport() {
        return null;
    }

    /**
     * Leap-frog engine that transforms the gradient and multiplies it element-wise by the
     * provider's mask before delegating the momentum update to the superclass.
     */
    class MaskedMultivariateTransform
    extends HamiltonianMonteCarloOperator.LeapFrogEngine.WithTransform {
        /** Source of the mask applied to the transformed gradient. */
        private final MaskProvider maskProvider;

        /**
         * @param parameter forwarded to the superclass
         * @param gradientWrtParameterProvider only its dimension is used, to size the mask array
         * @param instabilityHandler forwarded to the superclass
         * @param massPreconditioner forwarded to the superclass
         * @param maskProvider supplies the mask applied in {@link #updateMomentum}
         * @param transform forwarded to the superclass
         */
        MaskedMultivariateTransform(Parameter parameter, GradientWrtParameterProvider gradientWrtParameterProvider, HamiltonianMonteCarloOperator.InstabilityHandler instabilityHandler, MassPreconditioner massPreconditioner, MaskProvider maskProvider, Transform transform) {
            super(parameter, transform, instabilityHandler, massPreconditioner, new double[gradientWrtParameterProvider.getDimension()]);
            this.maskProvider = maskProvider;
            this.setMaskUntransformedSpace();
        }

        /** Sets every entry of the engine's own mask array to 1.0. */
        private void setMaskUntransformedSpace() {
            Arrays.fill(this.mask, 1.0);
        }

        /** Delegates to {@link MaskProvider#updateMask()}. */
        @Override
        public void updateMask() {
            this.maskProvider.updateMask();
        }

        /**
         * Transforms the gradient, masks it, and then performs the superclass momentum update.
         *
         * @param dArray forwarded unchanged to the superclass
         * @param dArray2 forwarded unchanged to the superclass
         * @param dArray3 gradient; replaced by the result of
         *        {@code transform.updateGradientLogDensity} over the whole untransformed position
         *        and then masked in place before being forwarded
         * @param d forwarded unchanged to the superclass
         * @throws HamiltonianMonteCarloOperator.NumericInstabilityException propagated from the
         *         superclass update
         */
        @Override
        public void updateMomentum(double[] dArray, double[] dArray2, double[] dArray3, double d) throws HamiltonianMonteCarloOperator.NumericInstabilityException {
            dArray3 = this.transform.updateGradientLogDensity(dArray3, this.unTransformedPosition, 0, this.unTransformedPosition.length);
            this.mask(dArray3);
            super.updateMomentum(dArray, dArray2, dArray3, d);
        }

        /**
         * Multiplies each entry of {@code dArray} in place by the corresponding mask value.
         *
         * @param dArray values to mask; expected to have the same length as the mask's dimension
         */
        private void mask(double[] dArray) {
            // Fetch the mask once instead of once per element.
            final Parameter maskParameter = this.maskProvider.getMask();
            assert (maskParameter.getDimension() == dArray.length);
            for (int i = 0; i < dArray.length; ++i) {
                dArray[i] *= maskParameter.getParameterValue(i);
            }
        }
    }
}

