# methods.R
# S3 methods for class "random_gaussian_nb"
#
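# NOTE: These methods rely on predict.random_gaussian_nb() being defined
# elsewhere in the package: fitted() and plot(which = "prob_entropy") both
# call it through stats::predict().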

# Register `i` as a known global so R CMD check does not report a
# "no visible binding for global variable" NOTE for a bare `i`.
# NOTE(review): no code visible in this file references `i`; the
# declaration is kept for backward compatibility. Confirm no other file
# still relies on it before removing.
utils::globalVariables("i")

# ---------- internal helpers (not exported) ----------

# Collect the unique values stored under `field` (e.g. "num_feats",
# "cat_feats") across all bootstrap models.
# Always returns a vector: character(0) when `models` is empty OR when no
# model has a non-NULL `field`, so callers never have to special-case NULL.
#' @keywords internal
.all_model_feats <- function(models, field) {
  if (length(models) == 0) return(character(0))
  vals <- unlist(lapply(models, function(m) m[[field]]), use.names = FALSE)
  # unlist() yields NULL when every model lacks `field`; normalise that to
  # the same empty result returned for an empty `models` list.
  if (is.null(vals)) return(character(0))
  unique(vals)
}

# Mean number of entries in each bootstrap model's `feats` element; 0 when
# there are no models.
# `[["feats"]]` (exact matching) is used instead of `$feats`, which would
# silently partial-match another element (e.g. `feats_extra`) when `feats`
# is absent. It is also the same extraction plot() uses when counting
# feature selections.
#' @keywords internal
.avg_feats_per_model <- function(models) {
  if (length(models) == 0) return(0)
  mean(vapply(models, function(m) length(m[["feats"]]), numeric(1)))
}

# ---------- print ----------

#' @rdname random_gaussian_nb
#' @param x A \code{random_gaussian_nb} object.
#' @param ... Additional arguments (ignored).
#' @export
print.random_gaussian_nb <- function(x, ...) {
  # Writes a one-line header followed by the stored settings and class
  # labels, then returns `x` invisibly.
  #
  # Each label already carries its own indentation and alignment padding.
  # cat() is kept (rather than paste()) so numbers are formatted exactly
  # as cat() formats them.
  show_line <- function(label, value) {
    cat(label, value, "\n", sep = "")
  }

  cat("<Random Naive Bayes model (posterior averaging)>\n")
  show_line("  Bootstrap iterations: ", x$n_iter)
  show_line("  Feature fraction:     ", x$feature_fraction)
  show_line("  Parallel cores:       ", x$cores)
  show_line("  Classes:              ", paste(x$.classes, collapse = ", "))

  invisible(x)
}

# ---------- summary ----------

#' @rdname random_gaussian_nb
#' @param object A \code{random_gaussian_nb} object.
#' @param ... Additional arguments (ignored).
#' @export
summary.random_gaussian_nb <- function(object, ...) {
  # Prints a three-part report and returns `object` invisibly:
  #   1. model settings stored on the object,
  #   2. predictor coverage across all bootstrap models,
  #   3. the (fixed) likelihood families named in the report.
  ensemble <- object$.models

  n_numeric     <- length(.all_model_feats(ensemble, "num_feats"))
  n_categorical <- length(.all_model_feats(ensemble, "cat_feats"))
  mean_width    <- .avg_feats_per_model(ensemble)

  # One "<label><value>" line; labels carry their own indent and padding.
  emit <- function(label, value) cat(label, value, "\n", sep = "")

  cat("## Random Naive Bayes model summary\n\n")

  cat("Model:\n")
  emit("  - Bootstrap iterations: ", object$n_iter)
  emit("  - Feature fraction:     ", object$feature_fraction)
  emit("  - Parallel cores:       ", object$cores)
  emit("  - Classes:              ", paste(object$.classes, collapse = ", "))
  cat("\n")

  cat("Predictors (across all bootstrap models):\n")
  emit("  - Unique numeric predictors:     ", n_numeric)
  emit("  - Unique categorical predictors: ", n_categorical)
  emit("  - Avg. features per model:       ", sprintf("%.2f", mean_width))
  cat("\n")

  cat("Likelihoods:\n")
  cat("  - Numeric: Gaussian\n")
  cat("  - Categorical: Multinomial (Laplace-smoothed)\n\n")

  invisible(object)
}

# ---------- str ----------

#' @rdname random_gaussian_nb
#' @param object A \code{random_gaussian_nb} object.
#' @param ... Additional arguments (ignored).
#' @export
str.random_gaussian_nb <- function(object, ...) {
  # Compact, utils::str()-like overview of the fitted object: top-level
  # settings, training-data dimensions, and how many unique numeric /
  # categorical features appear across the bootstrap models. Returns
  # `object` invisibly.
  models <- object$.models
  num_feats <- .all_model_feats(models, "num_feats")
  cat_feats <- .all_model_feats(models, "cat_feats")

  # Short type tag derived from the stored value rather than hard-coded,
  # so e.g. a double `n_iter` is shown as "num", not "int".
  type_tag <- function(v) {
    if (is.factor(v)) {
      "factor"
    } else if (is.integer(v)) {
      "int"
    } else if (is.double(v)) {
      "num"
    } else if (is.character(v)) {
      "chr"
    } else if (is.logical(v)) {
      "logi"
    } else {
      class(v)[1L]
    }
  }

  X <- object$X_train

  cat("List of ", length(object), " (random_gaussian_nb)\n", sep = "")
  cat("$ n_iter          : ", type_tag(object$n_iter), " ",
      object$n_iter, "\n", sep = "")
  cat("$ feature_fraction: ", type_tag(object$feature_fraction), " ",
      object$feature_fraction, "\n", sep = "")
  cat("$ cores           : ", type_tag(object$cores), " ",
      object$cores, "\n", sep = "")
  cat("$ .classes        : ", paste(object$.classes, collapse = ", "),
      "\n", sep = "")
  # Report the actual container class (data.frame, matrix, tbl_df, ...).
  cat("$ X_train         : ", class(X)[1L], " [", nrow(X), " x ", ncol(X),
      "]\n", sep = "")
  cat("$ y_train         : ", type_tag(object$y_train), " [",
      length(object$y_train), "]\n", sep = "")
  cat("$ .models         : list(", length(models), ")\n", sep = "")
  cat("    - unique numeric feats     : ", length(num_feats), "\n", sep = "")
  cat("    - unique categorical feats : ", length(cat_feats), "\n", sep = "")
  invisible(object)
}

# ---------- nobs ----------

#' @importFrom stats nobs
#' @rdname random_gaussian_nb
#' @param object A \code{random_gaussian_nb} object.
#' @param ... Additional arguments (ignored).
#' @export
nobs.random_gaussian_nb <- function(object, ...) {
  # Number of training observations = first dimension of the stored
  # predictor table (NULL when `X_train` is absent, as with nrow()).
  training_data <- object$X_train
  dim(training_data)[1L]
}

# ---------- fitted ----------

#' @rdname random_gaussian_nb
#' @param object A \code{random_gaussian_nb} object.
#' @param ... Passed to \code{\link[stats:predict]{predict()}}.
#' @param type Prediction type forwarded to \code{predict()}; defaults to
#'   \code{"class"}. Must be supplied by name.
#' @export
fitted.random_gaussian_nb <- function(object, ..., type = "class") {
  # In-sample predictions: newdata = NULL is forwarded so predict() scores
  # the training data (the same NULL convention plot() documents for its
  # `newdata`).
  # `type` is a formal placed after `...` (so existing positional calls are
  # unaffected) instead of being hard-coded. Callers can therefore request
  # e.g. fitted(m, type = "prob") rather than hitting a "formal argument
  # matched by multiple actual arguments" error.
  stats::predict(object, newdata = NULL, type = type, ...)
}

# ---------- plot ----------

#' @rdname random_gaussian_nb
#' @param x A \code{random_gaussian_nb} object.
#' @param which Diagnostic to plot: \code{"feature_frequency"}, \code{"prior_variability"}, or \code{"prob_entropy"}.
#' @param newdata Optional new data for \code{"prob_entropy"}. If \code{NULL}, uses the training data.
#' @param top Number of top features to show for \code{"feature_frequency"}.
#' @param ... Passed to the underlying plotting function (e.g., \code{\link[graphics:barplot]{barplot()}},
#'   \code{\link[graphics:boxplot]{boxplot()}}, \code{\link[graphics:hist]{hist()}}).
#' @export
plot.random_gaussian_nb <- function(x,
                                    which = c("feature_frequency", "prior_variability", "prob_entropy"),
                                    newdata = NULL,
                                    top = 20,
                                    ...) {
  # Draws one base-graphics diagnostic for the bootstrap ensemble and
  # returns `x` invisibly:
  #   * feature_frequency - barplot of the `top` most frequently drawn
  #                         features, counts divided by number of models;
  #   * prior_variability - one boxplot per class of the per-model priors;
  #   * prob_entropy      - histogram of the entropy of the averaged
  #                         posterior from predict(type = "prob").
  #
  # par() is deliberately neither queried nor restored. Nothing here
  # changes a persistent graphical parameter (`las` etc. go to the plotting
  # call itself). Querying par(no.readonly = TRUE) opens a device when none
  # is open, and restoring every parameter (mfrow/mfcol/mfg included) can
  # disturb a caller's multi-panel layout (see the Note in ?par).
  which <- match.arg(which)
  models  <- x$.models
  classes <- x$.classes

  if (length(models) == 0) stop("No bootstrap models found in `x$.models`.")

  switch(
    which,
    feature_frequency = {
      feats_all <- unlist(lapply(models, `[[`, "feats"), use.names = FALSE)
      tab <- sort(table(feats_all), decreasing = TRUE)

      if (length(tab) == 0) stop("No features recorded in bootstrap models.")

      # Validate `top` as a single positive count. Checking the length first
      # keeps a zero- or multi-length `top` from erroring inside `||` with
      # an unhelpful message.
      top <- as.integer(top)
      if (length(top) != 1L || is.na(top) || top < 1L) {
        stop("`top` must be a positive integer.")
      }
      top <- min(top, length(tab))

      tab <- tab[seq_len(top)]
      # Selection count divided by the number of bootstrap models (a
      # per-model selection rate if each model lists a feature at most once).
      freq <- as.numeric(tab) / length(models)

      graphics::barplot(
        freq,
        names.arg = names(tab),
        las = 2,
        ylab = "Selection frequency",
        main = sprintf("Top-%d feature selection frequencies", top),
        ...
      )
    },
    prior_variability = {
      if (length(classes) == 0) stop("No classes found in `x$.classes`.")
      n_classes <- length(classes)

      priors <- vapply(
        models,
        function(m) {
          pr <- m$prior
          pr <- pr[classes]      # align to global class order
          pr[is.na(pr)] <- 0     # class absent from this model's prior
          as.numeric(pr)
        },
        FUN.VALUE = numeric(n_classes)
      )
      # vapply() returns a classes x models matrix when there are 2+
      # classes but a plain vector for exactly one class. Rebuilding with
      # an explicit nrow makes the transpose models x classes in both
      # cases, so boxplot() draws exactly one box per class.
      prior_mat <- t(matrix(priors, nrow = n_classes))

      graphics::boxplot(
        prior_mat,
        names = classes,
        ylab = "Class prior",
        main = "Bootstrap prior variability",
        ...
      )
    },
    prob_entropy = {
      # Assumes predict(type = "prob") yields one numeric column per class
      # with rows summing to 1 -- TODO confirm against the predict method.
      probs <- stats::predict(x, newdata = newdata, type = "prob")
      P <- as.matrix(probs)

      # Floor at machine epsilon so log() stays finite; a zero probability
      # then contributes ~eps * log(eps) (effectively 0) instead of NaN.
      P <- pmax(P, .Machine$double.eps)
      ent <- -rowSums(P * log(P))

      graphics::hist(
        ent,
        breaks = "FD",
        main = "Predictive entropy (averaged posterior)",
        xlab = "Entropy",
        ...
      )
    }
  )

  invisible(x)
}

