# -----------------------------------------------------------------------------
# File: checks.R
# Purpose: Internal functions for checking model specifications and parameters
# Author: Steffen Maletz
# Last modified: 2025-12-28
# -----------------------------------------------------------------------------

# Function to check whether data fits to family
# Input:
# - ts: numeric matrix of time series data (one series per row)
# - family: list generated by one of the family functions of this package;
#     `distribution` selects the constraints, `size` (length 1 or one value
#     per series) is the upper bound for binomial-type families
# Output: TRUE if data fits to family, FALSE if not.
#     Stops with an error if missing or non-finite values are present or if
#     `size` has the wrong length. Distributions not listed below impose no
#     constraints and yield TRUE.
data_family_check <- function(ts, family) {
  if (any(is.na(ts))) {
    stop("Missing data in time series is not supported.")
  }
  if (any(!is.finite(ts))) {
    stop("Non-finite values in time series are not supported.")
  }
  count_dist <- c("poisson", "quasipoisson", "negative_binomial", "binomial", "quasibinomial")
  positive_dist <- c("inverse_gaussian", "gamma")
  upper_bounded_dist <- c("binomial", "quasibinomial")

  # Count data: non-negative integers only.
  if (family$distribution %in% count_dist) {
    if (any(ts != floor(ts)) || any(ts < 0)) {
      return(FALSE)
    }
  }
  # Strictly positive data.
  if (family$distribution %in% positive_dist) {
    if (any(ts <= 0)) {
      return(FALSE)
    }
  }
  # Binomial-type data: each series is bounded by its own `size`.
  if (family$distribution %in% upper_bounded_dist) {
    n <- family$size
    stopifnot("'size' must be length 1 or equal to number of series" = (length(n) == 1 || length(n) == nrow(ts)))
    if (length(n) == 1) {
      n <- rep(n, nrow(ts))
    }
    # n[row(ts)] gives every element of ts the bound of its series. Unlike a
    # loop over seq(nrow(ts)) (which yields c(1, 0) for zero rows), this
    # also works for empty input.
    if (!all(ts <= n[row(ts)])) {
      return(FALSE)
    }
  }
  TRUE
}


# Function for testing if covariates are in correct format
# Input:
# - covariates: list of covariates, or NULL for none. Each element is either a
#     numeric matrix (locations x time points) or an object created by the
#     TimeConstant / SpatialConstant functions (classes "time_constant" /
#     "spatial_constant").
# - nobs: number of time points in the time series
# - dim: number of locations in the time series
# - family: family list; if its `non_negative_parameters` flag is TRUE, all
#     covariate values must be non-negative
# Output: invisible TRUE if all covariates are valid, stops with error if not.
#     Warns when a covariate has more time points than the time series.
covariate_check <- function(covariates, nobs, dim, family) {
  stopifnot("covariates must be submitted in a list" = is.null(covariates) || is.list(covariates))
  # Iterating over NULL performs zero iterations, so no explicit NULL guard.
  for (cov in covariates) {
    if (inherits(cov, "time_constant")) {
      if (length(cov) != dim) {
        stop("Time constant covariates must have length equal to number of locations.")
      }
    } else if (inherits(cov, "spatial_constant")) {
      if (length(cov) < nobs) {
        stop("Spatial constant covariates must have at least the same number of time-points as ts.")
      } else if (length(cov) > nobs) {
        warning("Spatial constant covariates have more time-points than ts. Extra values will be ignored.")
      }
    } else if (is.matrix(cov)) {
      if (!is.numeric(cov)) {
        stop("Covariate matrices must be numeric.")
      }
      if (nrow(cov) != dim) {
        stop("Covariate matrices must have the same number of locations as ts.")
      }
      if (ncol(cov) < nobs) {
        stop("Covariate matrices must have at least the same number of time-points as ts.")
      } else if (ncol(cov) > nobs) {
        warning("Covariate matrices have more time-points than ts. Extra values will be ignored.")
      }
    } else {
      stop("Covariates must be either matrices or created using TimeConstant or SpatialConstant functions.")
    }
    # Shared by all covariate kinds: some link functions require
    # non-negative covariates.
    if (family$non_negative_parameters && any(cov < 0)) {
      stop("Covariate values must be non-negative for the link-function chosen within this family.")
    }
  }
  invisible(TRUE)
}



# Function to check structure of the model orders and time lags
# Input:
# - orders: numeric/logical vector or matrix specifying model orders.
#     * Vector: one entry per time lag giving the maximal spatial order used
#       for that lag; NA or -1 excludes the lag; values are floored.
#     * Matrix: binary (0/1 or logical) indicator matrix with rows for
#       spatial orders 0, 1, ... and one column per time lag.
# - time_lags: numeric vector of non-negative, finite, unique (after
#     flooring) time lags, one per order; NULL means lags 1, 2, ...
# Output: list with
# - orders: 0/1 matrix with rows "s_<spatial order>" and columns
#     "t_<time lag>"; excluded or all-zero lags and trailing all-zero rows
#     are dropped (a dimension of length zero gets no names)
# - time_lags: time lags of the retained columns
check_orders <- function(orders, time_lags) {

  # --- Input checks ---
  if (!is.numeric(orders) && !is.logical(orders)) {
    stop("Model orders must be numeric or logical.")
  }

  if (!is.vector(orders) && !is.matrix(orders)) {
    stop("Model orders must be a vector or matrix.")
  }

  if (!is.null(time_lags)) {
    if (!is.numeric(time_lags)) stop("Time lags must be numeric.")
    # Checked first so NA/NaN/Inf never reach the comparisons below, where a
    # missing value would make `if ()` fail with an unhelpful error.
    if (!all(is.finite(time_lags))) stop("Time lags must be finite.")
    if (any(time_lags < 0L)) stop("Time lags must be non-negative.")
    time_lags <- floor(time_lags)
    # Uniqueness is checked after flooring: e.g. 1.2 and 1.5 both become 1.
    if (anyDuplicated(time_lags) > 0L) stop("Time lags must be unique.")
  }

  if (is.vector(orders)) {
    # --- Vector case ---

    if (is.null(time_lags)) {
      time_lags <- seq_along(orders)
    } else if (length(time_lags) != length(orders)) {
      stop("Length of time_lags does not match the number of orders.")
    }

    if (any(!is.na(orders) & orders < -1)) {
      stop("Model orders must be NA or >= -1.")
    }

    orders <- floor(orders)
    orders[is.na(orders)] <- -1L

    # Orders of -1 (or NA) mark excluded time lags.
    valid <- orders >= 0L
    orders <- orders[valid]
    time_lags <- time_lags[valid]

    max_order <- if (length(orders) > 0L) max(orders) else 0L

    # Entry (s, t) is 1 when spatial order s is within the order of lag t.
    order_matrix <- outer(0:max_order, orders, FUN = function(i, j) as.integer(i <= j))

  } else {
    # --- Matrix case ---

    if (is.null(time_lags)) {
      time_lags <- seq_len(ncol(orders))
    } else if (length(time_lags) != ncol(orders)) {
      stop("Length of time_lags does not match the number of columns in orders.")
    }

    # %in% never returns NA, so missing entries are rejected cleanly.
    if (!all(orders %in% c(0, 1))) {
      stop("Orders must be binary (0/1) or logical.")
    }

    orders <- 1L * orders # Coerce to integer if logical

    # Remove zero columns
    keep_cols <- colSums(orders) > 0
    orders <- orders[, keep_cols, drop = FALSE]
    time_lags <- time_lags[keep_cols]

    # Remove trailing zero rows; an all-zero matrix keeps no rows at all
    # (max() of an empty vector would be -Inf).
    active_rows <- row(orders)[orders == 1]
    last_nonzero <- if (length(active_rows) > 0L) max(active_rows) else 0L
    order_matrix <- orders[seq_len(last_nonzero), , drop = FALSE]
  }

  # --- Naming ---
  # paste0() recycles zero-length input to a single string (e.g. "t_"), which
  # cannot name an empty dimension, so empty dimensions stay unnamed.
  n_rows <- nrow(order_matrix)
  rownames(order_matrix) <- if (n_rows > 0L) paste0("s_", seq_len(n_rows) - 1L) else NULL
  colnames(order_matrix) <- if (length(time_lags) > 0L) paste0("t_", time_lags) else NULL

  list(orders = order_matrix, time_lags = time_lags)
}


# Function to check and align parameters with model orders
# Input:
# - parameters: numeric vector or matrix of model parameters
# - orders: order matrix; a non-zero entry marks an active parameter
# Output: numeric matrix with the same dimensions as `orders`.
#     - Matrix input: dimensions must match `orders`; entries at inactive
#       (zero) orders are set to zero with a warning.
#     - Vector input: one value per active order, filled into the active
#       positions in column-major order; all other entries are zero.
param_check <- function(parameters, orders) {

  if (!is.numeric(parameters)) {
    stop("Parameters must be numeric (vector or matrix).")
  }

  # --- Matrix case ---
  if (is.matrix(parameters)) {

    if (!is.matrix(orders)) {
      stop("If parameters are a matrix, orders must also be a matrix.")
    }

    if (!identical(dim(parameters), dim(orders))) {
      stop("Dimensions of parameters and orders do not match.")
    }

    invalid_entries <- (parameters != 0) & (orders == 0)

    if (any(invalid_entries)) {
      warning("Some parameters are non-zero but ignored due to zero orders. These parameters are set to zero.")
      parameters[invalid_entries] <- 0
    }

    return(parameters)
  }

  # --- Vector case ---
  # Active orders are counted exactly as they are selected for filling
  # (non-zero entries); sum(orders) only agrees with that for 0/1 orders.
  active <- orders != 0
  if (length(parameters) != sum(active)) {
    stop("The number of parameters does not match the number of active (non-zero) orders.")
  }

  result <- matrix(0, nrow = nrow(orders), ncol = ncol(orders))
  result[active] <- parameters
  result
}


# Function to check model specification.
# Adjusts and standardizes inputs for model fitting etc.
# Input:
# - model: list containing model specification (intercept, past_obs,
#     past_obs_time_lags, past_mean, past_mean_time_lags, covariates)
# Output: adjusted model list. A missing intercept defaults to
#     "homogeneous"; order components are replaced by the processed order
#     matrices from check_orders() together with their time lags.
model_check <- function(model) {

  if (!is.list(model)) {
    stop("'model' must be a list.")
  }

  # Components are accessed with `[[` (exact matching) throughout: `$`
  # partially matches, so model$past_obs would silently return
  # past_obs_time_lags when past_obs itself is absent.

  # --- Intercept check ---
  # match.arg() returns the first choice when the intercept is NULL.
  model[["intercept"]] <- match.arg(model[["intercept"]], c("homogeneous", "inhomogeneous"))

  # --- Identifiability ---
  if (is.null(model[["past_obs"]]) && !is.null(model[["past_mean"]])) {
    stop("Models without regression on past observations are not identifiable.")
  }

  # --- Past observations (AR part) and past means (MA part) ---
  for (part in c("past_obs", "past_mean")) {
    if (!is.null(model[[part]])) {
      lag_name <- paste0(part, "_time_lags")
      checked <- check_orders(model[[part]], model[[lag_name]])
      model[[part]] <- checked$orders
      model[[lag_name]] <- checked$time_lags
    }
  }

  # --- Covariates (default time lags, which are not stored) ---
  if (!is.null(model[["covariates"]])) {
    model[["covariates"]] <- check_orders(model[["covariates"]], NULL)$orders
  }

  model
}


# Function to check model and parameters together
# Input:
# - model: list containing model specification; `intercept` is required
# - parameters: list containing model parameters (intercept, past_obs,
#     past_mean, covariates)
# - dim: integer, dimension of the time series; required length of the
#     intercept parameter for an inhomogeneous intercept
# Output: list with adjusted model and parameters. Order components that turn
#     out to be completely inactive are removed from both lists.
model_and_parameter_check <- function(model, parameters, dim) {

  # Components are accessed with `[[` (exact matching) throughout: `$`
  # partially matches, so model$past_obs would return past_obs_time_lags
  # when past_obs itself is absent and bypass the identifiability check.

  # --- Basic checks ---
  if (!is.list(model)) stop("'model' must be a list.")
  if (is.null(model[["intercept"]])) stop("Intercept specification is missing in the model.")
  if (is.null(model[["past_obs"]]) && !is.null(model[["past_mean"]])) {
    stop("Autoregressive part (past_obs) is required in the model.")
  }

  model[["intercept"]] <- match.arg(model[["intercept"]], c("homogeneous", "inhomogeneous"))

  # --- Check intercept parameters ---
  n_intercept <- if (model[["intercept"]] == "homogeneous") 1 else dim
  if (length(parameters[["intercept"]]) != n_intercept) {
    stop("Intercept parameter length does not match the model specification.")
  }

  # --- Helper for orders + parameters ---
  # Validates one orders/parameters pair with check_orders() and
  # param_check(). Returns NULL components when the pair is entirely inactive
  # so that the caller removes it; otherwise the parameter matrix gets the
  # dimnames of the processed order matrix.
  process_orders_and_parameters <- function(order, time_lags, param) {
    check <- check_orders(order, time_lags)
    new_order <- check$orders
    new_param <- param_check(param, new_order)

    if (all(new_order == 0) && all(new_param == 0)) {
      return(list(order = NULL, time_lags = NULL, param = NULL))
    }

    rownames(new_param) <- rownames(new_order)
    colnames(new_param) <- colnames(new_order)

    list(order = new_order, time_lags = check$time_lags, param = new_param)
  }

  # --- AR part (past_obs) and MA part (past_mean) ---
  # Assigning NULL via `[[<-` removes the component, as `$<-` did.
  for (part in c("past_obs", "past_mean")) {
    if (!is.null(model[[part]])) {
      lag_name <- paste0(part, "_time_lags")
      res <- process_orders_and_parameters(model[[part]], model[[lag_name]], parameters[[part]])
      model[[part]] <- res$order
      model[[lag_name]] <- res$time_lags
      parameters[[part]] <- res$param
    }
  }

  # An inactive past_obs is dropped above; that must not leave a past_mean
  # part without regression on past observations behind.
  if (is.null(model[["past_obs"]]) && !is.null(model[["past_mean"]])) {
    stop("Autoregressive part (past_obs) is required in the model.")
  }

  # --- Covariates ---
  if (!is.null(model[["covariates"]])) {
    res <- process_orders_and_parameters(model[["covariates"]], NULL, parameters[["covariates"]])
    model[["covariates"]] <- res$order
    parameters[["covariates"]] <- res$param
  }

  list(model = model, parameters = parameters)
}


# Function to check spatial weight matrices list
# Input:
# - wlist: list of spatial weight matrices (base matrices or Matrix
#     "dgCMatrix" objects), or NULL
# Output: number of rows of the first matrix, i.e. the dimension of the time
#     series (invisible); NA (invisible) if wlist is NULL.
#     Stops with an error on invalid input. Warns if some row sums differ
#     from 1.
wlist_check <- function(wlist) {
  if (is.null(wlist)) {
    return(invisible(NA))
  }
  if (!is.list(wlist)) {
    stop("wlist must be a list of matrices.")
  }
  if (length(wlist) == 0L) {
    stop("wlist must contain at least one matrix.")
  }
  # inherits() instead of class(W) == "dgCMatrix": base matrices have a
  # two-element class (c("matrix", "array")), which made the comparison
  # vector-valued and broke for lists mixing dense and sparse matrices.
  is_sparse <- vapply(wlist, inherits, logical(1), what = "dgCMatrix")
  if (!all(vapply(wlist, is.matrix, logical(1)) | is_sparse)) {
    stop("All elements of 'wlist' must be matrices.")
  }
  use_matrix_pkg <- any(is_sparse)
  if (use_matrix_pkg && !requireNamespace("Matrix", quietly = TRUE)) {
    stop("The 'Matrix' package must be installed in order to use the 'dgCMatrix' class.")
  }

  # One column of (nrow, ncol) per matrix; every column must equal the first.
  dims <- vapply(wlist, dim, integer(2))
  if (!all(dims == dims[, 1])) {
    stop("All matrices in 'wlist' must have the same dimensions.")
  }

  if (!all(vapply(wlist, function(W) all(W >= 0), logical(1)))) {
    stop("All matrices in 'wlist' must contain non-negative values.")
  }

  # Matrix::rowSums handles sparse and dense matrices alike; it is only
  # referenced when the Matrix namespace is known to be available.
  row_sums <- if (use_matrix_pkg) Matrix::rowSums else rowSums
  tol <- sqrt(.Machine$double.eps)
  normalized <- all(vapply(wlist, function(W) all(abs(row_sums(W) - 1) < tol), logical(1)))

  if (!normalized) {
    warning("The matrices in 'wlist' are not normalized. Consider normalizing them.")
  }
  invisible(nrow(wlist[[1]]))
}

# Function to check that enough spatial weight matrices are available
# Input:
# - required: required number of matrices
# - specific_list_length: number of matrices in the user-provided specific
#     list; 0 means no specific list was given and the default list is checked
# - default_list_length: number of matrices in the default list (wlist)
# - arg_name: argument name of the specific list, used in the error message
# Output: invisible NULL; stops with an error if too few matrices are available
check_length_of_wlist <- function(required, specific_list_length, default_list_length, arg_name) {
  if (specific_list_length == 0) {
    stopifnot("Too few matrices in wlist" = (default_list_length >= required))
  } else if (!isTRUE(all(specific_list_length >= required))) {
    # stopifnot() takes custom messages only from argument names and ignores
    # a "msg" attribute, so the dynamic message is raised with stop().
    # isTRUE(all(...)) also rejects NA, as stopifnot() would.
    stop(paste("Too few matrices in", arg_name))
  }
  invisible(NULL)
}