#' @title
#' flassomsm
#'
#' @description
#' Fits a penalized regression model with combined Fusedlasso penalty using hybrid algorithm
#'
#' @details This is the core function of the package. This function fits a penalized Cox-type regression model
#' within the framework of a multi-state model. It is designed to handle transition-specific covariate effects
#' across multiple states by incorporating a regularization approach that combines both the Lasso penalty and
#' the Fused penalty. The penalization is of the L1 type, meaning it applies to the absolute values of
#' the regression coefficients, encouraging sparsity in the model. Additionally, it penalizes the absolute
#' differences between corresponding coefficients across different transitions, promoting similarity or grouping
#' of effects across transitions when appropriate. This dual-penalty structure enables both variable selection and
#' smoothing of covariate effects across related transitions, which is particularly useful in complex multi-state
#' settings where covariate effects may share underlying patterns but still exhibit transition-specific behaviors.
#' The parameters are estimated with a hybrid algorithm that combines PIRLS (outer loop) and ADMM (inner loop).
#'
#' @param msdata multi-state data in long (extended) format with columns Tstart, Tstop, status and trans (covariates expanded transition-wise)
#' @param X expanded covariate matrix of the msdata
#' @param p number of covariates in the dataset before expanding
#' @param lambda_lasso parameter for lasso penalty
#' @param lambda_fused parameter for fused penalty
#' @param tol_outer convergence tolerance for the outer (PIRLS) loop
#' @param max_outer maximum number of outer (PIRLS) iterations
#' @param rho augmented Lagrangian parameter
#' @param tol_admm convergence tolerance for the inner (ADMM) loop
#' @param max_admm maximum number of inner (ADMM) iterations per outer iteration
#' @param trace logical; if TRUE, intermediate beta estimates are reported every 20 outer iterations
#' @param MSM_profile logical; if TRUE, the Fisher information matrices (F) and score vectors (S) from each outer iteration and the penalty matrix (D_fused) are also returned
#' @param use_parallel logical; if TRUE, a future multisession plan is set up before fitting
#' @return A list with elements beta (data frame of estimated coefficients with standard errors, z scores and p values), iterations (number of outer iterations performed), aic (Akaike Information Criterion), gcv (GCV criterion) and df (degrees of freedom)
#' @import numDeriv
#'         corpcor
#'         progressr
#'         future.apply
#' @importFrom stats rnorm pnorm
#' @importFrom survival Surv coxph survfit
#'
#' @examples
#' ##
#' set.seed(123)
#' data(msdata_3state)
#' covs1 <- msdata_3state[,9:10]
#' flassomsm(msdata = msdata_3state,X=msdata_3state[,c(11:dim(msdata_3state)[[2]])],
#' p = length(covs1),lambda_lasso = 0.3,lambda_fused = 0.5,tol_outer = 1e-4,
#' max_outer = 50, rho = 1, tol_admm = 1e-4, max_admm = 100,trace = TRUE,
#' MSM_profile = FALSE)
#'
#' # Example with 2 covariates and 3 transitions
#'
#'\donttest{
#' # Simulate msdata_4state instead of loading from disk
#' msdata_4state <- simdata(seed=123,n=1000,dist="weibull",cdist="exponential",
#'                  cparams=list(rate = 0.1),lambdas=c(0.1, 0.2, 0.3, 0.4),
#'                  gammas=c(1.5, 2, 2.5, 2.6),beta_list=list(c(-0.05, 0.01, 0.5, 0.6),
#'                  c(-0.03, 0.02, 0.07, 0.08),c(-0.04, 0.03, 0.04, -0.03),
#'                  c(-0.05,0.05,0.6,0.8)),cov_means=c(0,10,2,3),cov_sds=c(1,20,5,1.05),
#'                  trans_list=list(c(2, 3, 4, 5),c(3, 4, 5),c(4, 5), c(5), c()),
#'                  state_names=c("Tx", "Rec", "Death", "Reldeath", "srv"))
#'
#' set.seed(123)
#' sub_msdata_4state <- msdata_4state[msdata_4state$id %in% sample(unique(msdata_4state$id), 10), ]
#' covs1 <- sub_msdata_4state[,9:10]
#' flassomsm(msdata = sub_msdata_4state,X=sub_msdata_4state[,c(13:32)],
#'          p = length(covs1),lambda_lasso = 0.5,lambda_fused = 0.6,tol_outer = 1e-4,
#'          max_outer = 50, rho = 1, tol_admm = 1e-4, max_admm = 100,trace = TRUE,MSM_profile = FALSE)
#'}
#' # (The example above uses 2 covariates and 10 transitions)
#' ##
#' @export
#' @author Atanu Bhattacharjee,Gajendra Kumar Vishwakarma,Abhipsa Tripathy


flassomsm <- function(msdata, X, p, lambda_lasso, lambda_fused,
                           tol_outer = 1e-4, max_outer = 50,
                           rho = 1, tol_admm = 1e-4, max_admm = 100,
                           trace = TRUE, MSM_profile = FALSE,
                           use_parallel = TRUE) {
  # Fits a penalized Cox-type multi-state model with a lasso penalty on every
  # coefficient and a fused penalty on pairwise differences of the same
  # covariate across transitions.
  #
  # Outer loop (PIRLS): at the current beta, form the Newton working response
  # z = beta + I^{-1} U from the score U and the numerically differentiated
  # Fisher information I.
  # Inner loop (ADMM): approximately solve
  #   min_b 0.5 (b - z)' I (b - z) + sum_r lambda_r |(C b)_r|
  # by splitting gamma = C b with scaled dual variable u.
  #
  # Returns a list with `beta` (data frame of estimates, standard errors,
  # z scores and p values), `iterations`, `aic`, `gcv` and `df`; when
  # MSM_profile = TRUE also `F` and `S` (information matrix and score at each
  # outer iteration) and `D_fused` (the penalty matrix).

  if (!requireNamespace("progressr", quietly = TRUE)) stop("progressr package is required")
  if (use_parallel) {
    if (!requireNamespace("future.apply", quietly = TRUE)) stop("future.apply package is required")
    # availableCores() - 1 is 0 on a single-core machine, which is not a valid
    # worker count. The caller's plan is restored on exit so this function
    # does not leave a global side effect behind.
    # NOTE(review): no futures are dispatched below, so this plan is
    # currently unused by the fit itself.
    n_workers <- max(1L, as.integer(future::availableCores()) - 1L)
    old_plan <- future::plan(future::multisession, workers = n_workers)
    on.exit(future::plan(old_plan), add = TRUE)
  }

  if (max_outer < 1) stop("max_outer must be at least 1", call. = FALSE)

  X <- as.matrix(X)
  event <- msdata$status
  transition <- msdata$trans
  n_trans <- length(unique(transition))
  n_coef <- n_trans * p
  if (ncol(X) != n_coef) {
    stop("X must have p * (number of transitions) = ", n_coef,
         " columns, but has ", ncol(X), call. = FALSE)
  }
  beta <- rep(1e-2, n_coef)

  # Number at risk at each distinct event time within each transition, as
  # reported by survfit(); keyed by the transition label.
  risksets_msm <- function(msdata) {
    transitions <- unique(msdata$trans)
    riskset_list <- list()

    for (t in transitions) {
      fit <- survfit(Surv(Tstart, Tstop, status) ~ 1, data = subset(msdata, trans == t))
      riskset_list[[as.character(t)]] <- fit$n.risk
    }

    riskset_list
  }

  # Penalty matrix: an identity block (one lasso row per coefficient)
  # followed, for every covariate i and transition pair j < k, by one row
  # encoding beta[i, j] - beta[i, k]. Coefficients are indexed covariate-major:
  # column (i - 1) * n_trans + j is covariate i in transition j.
  generate_penalty_structure <- function(p, n_trans) {
    n_coef <- p * n_trans
    P <- matrix(0, n_coef + choose(n_trans, 2) * p, n_coef)
    # Matrix indexing avoids the dimension drop of P[1:n, ] when n_coef == 1.
    P[cbind(seq_len(n_coef), seq_len(n_coef))] <- 1
    row_idx <- n_coef + 1

    # seq_len() leaves the fused block empty when there is a single transition.
    for (i in seq_len(p)) {
      for (j in seq_len(max(n_trans - 1, 0))) {
        for (k in seq(j + 1, n_trans)) {
          P[row_idx, (i - 1) * n_trans + j] <- 1
          P[row_idx, (i - 1) * n_trans + k] <- -1
          row_idx <- row_idx + 1
        }
      }
    }

    P
  }

  # Log partial likelihood as implemented by this package.
  # NOTE(review): riskset_list[[h]] holds survfit's n.risk for each distinct
  # event time of transition h, but it is indexed below by the row index `i`
  # of the full data, and the "risk set" is taken to be the first
  # `risk_set_size` rows of the whole data set rather than the rows of
  # transition h at risk at Tstop[i]. Events whose row index exceeds the
  # length of n.risk give NA and are skipped. This differs from the textbook
  # Cox partial likelihood for (Tstart, Tstop] data -- confirm intent.
  # score_msm() uses the same construction.
  likelihood_msm <- function(beta, X, riskset_list, event, transition) {
    X <- as.matrix(X)
    f <- as.vector(X %*% as.numeric(beta))
    ef <- exp(f)
    log_partial_likelihood <- 0

    for (h in unique(transition)) {
      n_at_risk <- riskset_list[[as.character(h)]]
      for (i in which(transition == h & event == 1)) {
        risk_set_size <- n_at_risk[i]
        if (length(risk_set_size) == 0 || is.na(risk_set_size)) next
        if (risk_set_size > 0) {
          risk_sum <- sum(ef[1:min(risk_set_size, length(ef))])
          log_partial_likelihood <- log_partial_likelihood + (f[i] - log(risk_sum))
        }
      }
    }

    log_partial_likelihood
  }

  # Gradient of likelihood_msm() with respect to beta (same risk-set rule).
  score_msm <- function(beta, X, riskset_list, event, transition) {
    X <- as.matrix(X)
    f <- as.vector(X %*% as.numeric(beta))
    ef <- exp(f)
    score <- rep(0, length(beta))

    for (h in unique(transition)) {
      n_at_risk <- riskset_list[[as.character(h)]]
      for (i in which(transition == h & event == 1)) {
        risk_set_size <- n_at_risk[i]
        if (length(risk_set_size) == 0 || is.na(risk_set_size)) next
        if (risk_set_size > 0) {
          risk_indices <- 1:min(risk_set_size, length(ef))
          risk_sum <- sum(ef[risk_indices])
          weighted_X <- colSums(X[risk_indices, , drop = FALSE] * ef[risk_indices]) / risk_sum
          score <- score + (X[i, ] - weighted_X)
        }
      }
    }

    score
  }

  # Observed Fisher information: negative numerical Hessian
  # (numDeriv::hessian) of likelihood_msm() at beta.
  fisherinfor_msm <- function(beta, X, riskset_list, event, transition) {
    loglik <- function(b) likelihood_msm(b, X, riskset_list, event, transition)
    -hessian(func = loglik, x = beta)
  }

  # Proximal operator of lambda * |x| (elementwise): shrink toward zero by
  # lambda and set exactly to zero inside [-lambda, lambda]. The previous
  # version returned x / (1 + lambda) inside the band, which is discontinuous
  # at |x| = lambda and never produces exact zeros.
  soft_threshold <- function(x, lambda) {
    sign(x) * pmax(abs(x) - lambda, 0)
  }

  riskset_list <- risksets_msm(msdata)
  # D_fused is the full penalty matrix (identity block + fused differences)
  # and serves directly as the ADMM constraint matrix C. Stacking a second
  # identity on top of it, as rbind(diag(n_coef), D_fused) did, penalized
  # every coefficient twice with total weight lambda_lasso + lambda_fused.
  D_fused <- generate_penalty_structure(p, n_trans)
  C <- D_fused
  lambda_vec <- c(rep(lambda_lasso, n_coef), rep(lambda_fused, nrow(C) - n_coef))
  covariate_names <- colnames(X)

  if (MSM_profile) {
    F_list <- list()
    S_list <- list()
  }

  progressr::with_progress({
    pbar <- progressr::progressor(steps = max_outer)

    for (outer_iter in seq_len(max_outer)) {
      pbar(sprintf("PIRLS iteration %d/%d", outer_iter, max_outer))

      score <- score_msm(beta, X, riskset_list, event, transition)
      info <- fisherinfor_msm(beta, X, riskset_list, event, transition)

      if (MSM_profile) {
        F_list[[outer_iter]] <- info
        S_list[[outer_iter]] <- score
      }

      # Small ridge so the numerically differentiated information stays invertible.
      info <- info + diag(1e-6, nrow(info))
      # Newton working response; falls back to a pseudoinverse (as inv_mat
      # does below) when info is singular.
      z_working <- beta + as.numeric(tryCatch(solve(info, score),
                                              error = function(e) pseudoinverse(info) %*% score))

      inv_mat <- tryCatch(solve(info + rho * crossprod(C)),
                          error = function(e) pseudoinverse(info + rho * crossprod(C)))
      info_z <- info %*% z_working  # constant across ADMM iterations

      beta_admm <- rep(0, n_coef)
      gamma <- rep(0, nrow(C))
      u <- rep(0, nrow(C))

      for (admm_iter in seq_len(max_admm)) {
        beta_admm_new <- as.numeric(inv_mat %*% (info_z + rho * crossprod(C, gamma - u)))

        if (any(!is.finite(beta_admm_new))) {
          # gamma and u are unchanged when this happens, so every further
          # iteration would produce the same non-finite update: warn and stop
          # instead of silently spinning until max_admm.
          warning("NA or Inf detected in ADMM update; stopping ADMM for this PIRLS iteration",
                  call. = FALSE)
          break
        }

        beta_admm <- beta_admm_new
        Cbeta <- as.numeric(C %*% beta_admm)
        gamma_old <- gamma
        gamma <- soft_threshold(Cbeta + u, lambda_vec / rho)
        u <- u + Cbeta - gamma

        # Stop on small primal (C b - gamma) and dual (rho * C'(gamma - gamma_old))
        # residuals. The previous rule compared the ADMM iterate with the
        # outer-loop beta, so it could stop before the L1 terms took effect.
        primal_res <- max(abs(Cbeta - gamma))
        dual_res <- rho * max(abs(crossprod(C, gamma - gamma_old)))
        if (primal_res < tol_admm && dual_res < tol_admm) break
      }

      beta_new <- beta_admm

      if (trace && outer_iter %% 20 == 0) {
        beta_df <- data.frame(Covariate = covariate_names, Estimate = round(beta_new, 4))
        message("\nIntermediate beta estimates:\n",
                paste(utils::capture.output(print(beta_df)), collapse = "\n"))
      }

      if (max(abs(beta_new - beta)) < tol_outer) break
      beta <- beta_new
    }
  })

  beta_final <- beta
  # Standard errors from the (ridge-stabilized) information of the last PIRLS
  # iteration, with the same pseudoinverse fallback used above.
  info_inv <- tryCatch(solve(info), error = function(e) pseudoinverse(info))
  se_beta <- sqrt(abs(diag(info_inv)))
  z_scores <- ifelse(se_beta > 0, beta_final / se_beta, 0)
  # pnorm(-|z|) avoids the cancellation in 1 - pnorm(|z|) for large |z|.
  p_values <- 2 * pnorm(-abs(z_scores))

  beta_final_df <- data.frame(Covariate = covariate_names,
                              Estimate = beta_final,
                              Std_Error = se_beta,
                              Z_score = z_scores,
                              P_value = p_values)

  # NOTE(review): info %*% solve(info) is the identity whenever info is
  # invertible, so df equals ncol(info) regardless of the penalty; an
  # effective degrees of freedom for the penalized fit may have been intended.
  df <- tryCatch({
    sum(diag(info %*% solve(info)))
  }, error = function(e) NA)

  nlpl <- tryCatch({
    -likelihood_msm(beta = beta_final, X = X, riskset_list = riskset_list,
                    event = event, transition = transition)
  }, error = function(e) NA)

  N <- nrow(msdata)
  aic <- if (!is.na(nlpl) && !is.na(df)) 2 * (nlpl + df) else NA
  # NOTE(review): this divides by N twice; the usual GCV is
  # nlpl / (N * (1 - df / N)^2). Left unchanged to preserve reported values.
  gcv <- if (!is.na(nlpl) && !is.na(df)) (1 / N) * nlpl / (N * ((1 - df / N)^2)) else NA

  result <- list(beta = beta_final_df,
                 iterations = outer_iter,
                 aic = aic,
                 gcv = gcv,
                 df = df)

  if (MSM_profile) {
    result$F <- F_list
    result$S <- S_list
    result$D_fused <- D_fused
  }

  result
}

# Column names referenced through non-standard evaluation (for example inside
# subset() and Surv() formulas) are declared here so that R CMD check does not
# report them as undefined global variables.
utils::globalVariables(
  c(
    "id", "from", "to", "trans",
    "Tstart", "Tstop", "time", "status"
  )
)
