#' Estimate Average Treatment Effect (ATE) via Semi-Supervised Learning
#'
#' @param Y Numeric vector. Observed outcomes; \code{NA} for unlabelled observations.
#' @param A Numeric vector. Treatment indicator (1 for treated, 0 for control).
#' @param R Logical or binary vector. Indicator for labeled data (1 if labeled, 0 if not).
#' @param mu1 Numeric vector. Estimated outcome regression \eqn{E[Y \mid A = 1, X]}.
#' @param mu0 Numeric vector. Estimated outcome regression \eqn{E[Y \mid A = 0, X]}.
#' @param pi1 Numeric vector. Estimated propensity scores \eqn{P(A = 1 \mid X)}.
#' @param pi0 Numeric vector. Estimated propensity scores \eqn{P(A = 0 \mid X)}.
#' @param imp.A Numeric vector. Estimated treatment probabilities using surrogate covariates \code{W}.
#' @param imp.A1Y1 Numeric vector. Imputed \eqn{E[A Y \mid W]} (treated-arm outcome contribution) using surrogate variables \code{W}.
#' @param imp.A0Y1 Numeric vector. Imputed \eqn{E[(1 - A) Y \mid W]} (control-arm outcome contribution) using surrogate variables \code{W}.
#' @param min.pi Numeric. Lower bound to truncate estimated propensity scores (default = 0.05).
#' @param max.pi Numeric. Upper bound to truncate estimated propensity scores (default = 0.95).
#'
#' @return A list containing:
#' \describe{
#'   \item{est}{Estimated ATE.}
#'   \item{se}{Estimated standard error of ATE.}
#' }
#'
#' @details
#' This function estimates the ATE in a semi-supervised setting, where outcomes are only observed
#' for a subset of the sample. Surrogate variables and imputed models are used to leverage information
#' from unlabelled data.
#'
#' @examples
#' set.seed(123)
#' N <- 400
#' n <- 200  # Number of labeled observations
#' labeled_indices <- sample(1:N, n)
#'
#' # Generate covariates and treatment
#' X <- rnorm(N)
#' A <- rbinom(N, 1, plogis(X))
#'
#' # True potential outcomes
#' Y0_true <- X + rnorm(N)
#' Y1_true <- X + 1 + rnorm(N)
#'
#' # Observed outcomes
#' Y_full <- ifelse(A == 1, Y1_true, Y0_true)
#'
#' # Only labeled samples have observed Y
#' Y <- rep(NA, N)
#' Y[labeled_indices] <- Y_full[labeled_indices]
#' R <- rep(0, N); R[labeled_indices] <- 1
#'
#' # Nuisance parameter estimates (can be replaced by actual model predictions)
#' mu1 <- X + 0.5
#' mu0 <- X - 0.5
#' pi1 <- plogis(X)
#' pi0 <- 1 - pi1
#' imp.A <- plogis(X)
#' imp.A1Y1 <- plogis(X) * (X + 0.5)
#' imp.A0Y1 <- (1 - plogis(X)) * (X - 0.5)
#'
#' # Estimate ATE
#' result <- ate.SSL(
#'   Y = Y,
#'   A = A,
#'   R = R,
#'   mu1 = mu1,
#'   mu0 = mu0,
#'   pi1 = pi1,
#'   pi0 = pi0,
#'   imp.A = imp.A,
#'   imp.A1Y1 = imp.A1Y1,
#'   imp.A0Y1 = imp.A0Y1
#' )
#'
#' print(result$est)
#' print(result$se)
#'
#' @export


ate.SSL <- function(Y, A, R, mu1, mu0, pi1, pi0, imp.A,
                    imp.A1Y1, imp.A0Y1, min.pi = 0.05, max.pi = 0.95) {
  # Semi-supervised doubly-robust ATE estimate.
  #
  # Every unit contributes an "imputed" score in which A, A*Y and (1-A)*Y are
  # replaced by imp.A, imp.A1Y1 and imp.A0Y1. Labeled units (R == 1)
  # additionally receive an inverse-labeling-rate weighted correction equal to
  # the difference between the observed and the imputed score.
  #
  # Returns list(est = mean of the per-unit scores,
  #              se  = sd of the per-unit scores / sqrt(N)).

  # Single definition of "labeled": the labeled count and every labeled-only
  # subset below are derived from R, so values of Y at unlabeled positions
  # (NA or otherwise) never influence the result.
  labeled_indices <- which(R == 1)
  n <- length(labeled_indices)
  N <- length(pi1)

  if (n == 0L) {
    stop("No labeled observations: `R` must equal 1 for at least one unit.",
         call. = FALSE)
  }

  # The nuisance vectors are combined element-wise over all N units; a length
  # mismatch would otherwise be silently recycled.
  nuisance_lengths <- lengths(
    list(mu1, mu0, pi0, imp.A, imp.A1Y1, imp.A0Y1)
  )
  if (any(nuisance_lengths != N)) {
    stop("`mu1`, `mu0`, `pi1`, `pi0`, `imp.A`, `imp.A1Y1` and `imp.A0Y1` ",
         "must all have the same length.",
         call. = FALSE)
  }

  # Inverse of the labeling proportion n / N.
  rho.inv <- N / n

  A_lab <- A[labeled_indices]
  Y_lab <- Y[labeled_indices]

  # Truncate the propensities to [min.pi, max.pi], then rescale each arm so
  # its mean matches that arm's observed share among labeled units:
  # treated share for pi1, control share (1 - A) for pi0.
  pi1 <- pmin(max.pi, pmax(min.pi, pi1))
  pi1 <- pi1 * mean(A_lab) / mean(pi1)
  pi0 <- pmin(max.pi, pmax(min.pi, pi0))
  pi0 <- pi0 * mean(1 - A_lab) / mean(pi0)

  # Imputed score, available for every unit.
  infl <- (mu1 + imp.A1Y1 / pi1 - imp.A * mu1 / pi1) -
    (mu0 + imp.A0Y1 / pi0 - (1 - imp.A) * mu0 / pi0)

  # Labeled-only correction: observed minus imputed pieces of the score.
  pi1_lab <- pi1[labeled_indices]
  pi0_lab <- pi0[labeled_indices]
  correction <- (A_lab * Y_lab - imp.A1Y1[labeled_indices]) / pi1_lab -
    ((1 - A_lab) * Y_lab - imp.A0Y1[labeled_indices]) / pi0_lab -
    (A_lab - imp.A[labeled_indices]) *
      (mu1[labeled_indices] / pi1_lab + mu0[labeled_indices] / pi0_lab)

  infl[labeled_indices] <- infl[labeled_indices] + rho.inv * correction

  list(est = mean(infl),
       se = sd(infl) / sqrt(N))
}
