#include <RcppArmadillo.h>
#include "ieTest.h"

#include <cmath>
#include <stdexcept>

using namespace Rcpp;

// Function that performs the original S-test by Berger. 

//' S test for Indirect Effect for a single mediator
//'
//' This function takes 
//' the estimate of the effect of the independent variable on the mediator  
//' and the effect of the mediator on the effect as well as their variances and 
//' performs the S test. Alternative null hypothesis can be specified as well. 
//' Additionally, covariances of the parameters can be specified for cases 
//' involving missing data where the estimates may be correlated.
//'
//' @param alpha Significance level for the test of significance
//' @param x1 Numeric value of the estimated first effect of interest
//' @param s11 Numeric value of the estimated first effect variance
//' @param df1 Degrees of freedom for estimate x1 
//' @param x2 Numeric value of estimated second effect of interest 
//' @param s22 Numeric value of the estimated second effect variance 
//' @param df2 Degrees of freedom for estimate x2. Often the same as x1 
//' @param x10 Optional numeric value of alternative null hypothesis value for the first effect 
//' @param x20 Optional numeric value of alternative null hypothesis value for the second effect 
//' @param s12 Specification of covariance between x1 and x2. Typically 0, but may be non-zero in the prescence of missing data 
//' @returns Boolean True/False value of whether the test rejects the Null hypothesis
//' @references Berger, Roger L. Likelihood Ratio Tests and Intersection-Union Tests. Advances in Statistical Decision Theory and Applications, 2011.
//' @note The function for the S-test does not incorporate interactions between the independent and mediating variables. 
//' The user must first calculate the mean and variance of the second product term to be used in the function call. 
//' @export
//' @examples
//' sTest_one(0.05, .5, 1, 100, -.25, .1, 100)
// [[Rcpp::export]]
bool sTest_one(double alpha, double x1, double s11, int df1, double x2, double s22, int df2, double x10 = 0, double x20 = 0, double s12 = 0){
 // Test and correction if variables are not independent.
 // This occurs with missing data and EM algorithm implimented in the mmed package;
  if(s12 != 0){
    arma::mat tempCM = arma::zeros(2, 2);
    tempCM(0, 0) = s11;
    tempCM(1, 1) = s22;
    tempCM(0, 1) = s12; tempCM(1, 0) = s12;
    arma::mat cholM = chol(tempCM, "lower");
    arma::mat decomp = arma::inv(cholM);

    arma::vec tempMeans = arma::zeros(2);
    arma::vec tempNull = arma::zeros(2);
    tempMeans(0) = x1; tempMeans(1) = x2;
    tempNull(0) = x10; tempNull(1) = x20;

    arma::vec newMeans = decomp * tempMeans;
    arma::vec newNull = decomp * tempNull;

    x1 = newMeans(0); x2 = newMeans(1);
    x10 = newNull(0); x20 = newNull(1);
    s11 = 1; s22 = 1;
  }

  double z1 = (x1 - x10) / std::sqrt(s11);
  double z2 = (x2 - x20) / std::sqrt(s22);
  double U1 = R::pt(fabs(z1), df1, 1, 0 );
  double U2 = R::pt(fabs(z2), df2, 1, 0 );

  bool sq1 = ((U1 > (1.0-alpha/2.0)) |  (U1 < (alpha/2.0))) &
    ((U2 > (1.0-alpha/2.0)) |  (U2 < (alpha/2.0)));

  bool sq2 = (U1 >= (alpha/2.0)) & (U1 <= (1.0 - alpha/2.0)) & (U2 >= (alpha/2.0)) & (U2 <= (1.0 - alpha/2.0)) &
    (
        ((U2 >= (U1 - alpha/4.0)) & (U2 <= (U1 + alpha/4.0))) |
          ((U2 >= (1 - U1 - alpha/4.0)) & (U2 <= (1 - U1 + alpha/4.0)))
    );

  bool sq3 = (U1 >= (alpha/2.0)) & (U1 <= (1.0 - alpha/2.0)) & (U2 >= (alpha/2.0)) & (U2 <= (1.0 - alpha/2.0)) &
    (
        (U2 > (fabs(U1 - 0.5) + 1.0 - 3.0*alpha/4.0)) |
          (U2 < (-fabs(U1 - 0.5) + 3.0*alpha/4.0)) |
          (U1 > (fabs(U2 - 0.5) + 1.0 - 3.0*alpha/4.0)) |
          (U1 < (-fabs(U2 - 0.5) + 3.0*alpha/4.0))
    );

  return(sq1 | sq2 | sq3);
}

