#' Partial Dependence and other Profiles
#'
#' Calculates different types of profiles across covariable values. By default, partial dependence profiles [1] are calculated. Other options are profiles of ALE (accumulated local effects, see [2]), response, predicted values ("M plots" or "marginal plots", see [2]), residuals, and shap. The results are aggregated either by (weighted) means or by (weighted) quartiles. Note that ALE profiles are calibrated by (weighted) average predictions. In contrast to the suggestions in [2], we calculate ALE profiles of factors in the same order as the factor levels. They are not reordered based on the similarity of other variables.
#'
#' For numeric covariables \code{v} with more than \code{n_bins} distinct values, the values are binned. Alternatively, \code{breaks} can be provided to specify the binning. For partial dependence profiles (and partly also ALE profiles), this behaviour can be overridden either by providing a vector of evaluation points (\code{pd_evaluate_at}) or an evaluation grid (\code{pd_grid}). By the latter we mean a data frame whose column name(s) define a (multi-)variate evaluation grid. For partial dependence, ALE, and prediction profiles, "model", "predict_function", "linkinv" and "data" are required. For response profiles, it is "y", "linkinv" and "data", and for shap profiles it is just "shap". "data" can be passed on the fly.
#'
#' @param x An object of class \code{flashlight} or \code{multiflashlight}.
#' @param v The variable to be profiled.
#' @param data An optional \code{data.frame}. Not used for \code{type = "shap"}.
#' @param by An optional vector of column names used to additionally group the results.
#' @param type Type of the profile: Either "partial dependence", "ale", "predicted", "response", "residual", or "shap".
#' @param stats Statistic to calculate: "mean" or "quartiles". For ALE profiles, only "mean" makes sense.
#' @param breaks Cut breaks for a numeric \code{v}.
#' @param n_bins Maximum number of unique values to evaluate for numeric \code{v}. Only used if neither \code{pd_grid} nor \code{pd_evaluate_at} is specified.
#' @param cut_type For the default "equal", bins of equal width are created for \code{v} by \code{pretty}. Choose "quantile" to create quantile bins.
#' @param use_linkinv Should retransformation function be applied? Default is TRUE. Not used for type "shap".
#' @param value_name Column name in resulting \code{data} containing the profile value. Defaults to "value".
#' @param q1_name Name of the resulting column with first quartile values. Only relevant for \code{stats} "quartiles".
#' @param q3_name Name of the resulting column with third quartile values. Only relevant for \code{stats} "quartiles".
#' @param label_name Column name in resulting \code{data} containing the label of the flashlight. Defaults to "label".
#' @param type_name Column name in the resulting \code{data} with the plot \code{type}.
#' @param counts_name Name of the column containing counts if \code{counts} is TRUE.
#' @param counts Should counts be added?
#' @param counts_weighted If \code{counts} is TRUE: Should counts be weighted by the case weights? If TRUE, the sum of \code{w} is returned by group.
#' @param v_labels If FALSE, return group centers of \code{v} instead of labels. Only relevant for types "response", "predicted", "residual" or "shap" and if \code{v} is being binned. This is useful, e.g., if different flashlights use different data sets and bin labels would not match.
#' @param pred Optional vector with predictions (after application of inverse link). Can be used to avoid recalculation of predictions over and over if the function is to be repeatedly called for different \code{v} and predictions are computationally expensive to make. Only relevant for \code{type = "predicted"} and \code{type = "ale"}.
#' @param pd_evaluate_at Vector with values of \code{v} used to evaluate the profile. Only relevant for type = "partial dependence" and "ale".
#' @param pd_grid A \code{data.frame} with grid values, e.g. generated by \code{expand.grid}. Only used for type = "partial dependence".
#' @param pd_indices A vector of row numbers to consider in calculating partial dependence profiles. Only used for type = "partial dependence" and "ale".
#' @param pd_n_max Maximum number of ICE profiles to calculate (will be randomly picked from \code{data}). Only used for type = "partial dependence" and "ale".
#' @param pd_seed Integer random seed used to select ICE profiles. Only used for type = "partial dependence" and "ale".
#' @param pd_center How should ICE curves be centered? Default is "no". Choose "first", "middle", or "last" to 0-center at specific evaluation points. Choose "mean" to center all profiles at the within-group means. Choose "0" to mean-center curves at 0. Only relevant for partial dependence.
#' @param ale_two_sided If \code{TRUE}, \code{v} is continuous and \code{breaks} are passed or being calculated, then two-sided derivatives are calculated for ALE instead of left derivatives. More specifically: Usually, local effects at value x are calculated using points between x-e and x. Set \code{ale_two_sided = TRUE} to use points between x-e/2 and x+e/2.
#' @param ... Further arguments passed to \code{cut3} resp. \code{formatC} in forming the cut breaks of the \code{v} variable. Not relevant for partial dependence and ALE profiles.
#' @return An object of classes \code{light_profile}, \code{light} (and a list) with the following elements.
#' \itemize{
#'   \item \code{data} A tibble containing results. Can be used to build fully customized visualizations. Its column names are specified by all other items in this list.
#'   \item \code{by} Names of group by variable.
#'   \item \code{v} The variable(s) evaluated.
#'   \item \code{type} Same as input \code{type}. For information only.
#'   \item \code{stats} Same as input \code{stats}.
#'   \item \code{value_name} Same as input \code{value_name}.
#'   \item \code{q1_name} Same as input \code{q1_name}.
#'   \item \code{q3_name} Same as input \code{q3_name}.
#'   \item \code{label_name} Same as input \code{label_name}.
#'   \item \code{type_name} Same as input \code{type_name}.
#'   \item \code{counts_name} Same as input \code{counts_name}.
#' }
#' @export
#' @references
#' [1] Friedman J. H. (2001). Greedy function approximation: A gradient boosting machine. The Annals of Statistics, 29:1189–1232.
#'
#' [2] Apley D. W. (2016). Visualizing the effects of predictor variables in black box supervised learning models. ArXiv <arXiv:1612.08468>.
#'
#' @examples
#' fit_full <- lm(Sepal.Length ~ ., data = iris)
#' mod_full <- flashlight(model = fit_full, label = "full", data = iris, y = "Sepal.Length")
#' light_profile(mod_full, v = "Species")
#' light_profile(mod_full, v = "Species", type = "response")
#' light_profile(mod_full, v = "Species", stats = "quartiles")
#' light_profile(mod_full, v = "Petal.Width", type = "residual")
#' light_profile(mod_full, v = "Petal.Width", pd_evaluate_at = 2:4)
#' @seealso \code{\link{light_effects}}, \code{\link{plot.light_profile}}.
light_profile <- function(x, ...) {
  # S3 generic: dispatch on the class of `x` (flashlight, multiflashlight,
  # or the erroring default method).
  UseMethod("light_profile", x)
}

#' @describeIn light_profile Default method not implemented yet.
#' @export
light_profile.default <- function(x, ...) {
  # Reached for any object without a dedicated light_profile method;
  # such inputs are rejected with an error.
  msg <- "No default method available yet."
  stop(msg)
}

#' @describeIn light_profile Profiles for flashlight.
#' @export
light_profile.flashlight <- function(x, v = NULL, data = NULL, by = x$by,
                                     type = c("partial dependence", "ale", "predicted",
                                              "response", "residual", "shap"),
                                     stats = c("mean", "quartiles"),
                                     breaks = NULL, n_bins = 11,
                                     cut_type = c("equal", "quantile"), use_linkinv = TRUE,
                                     value_name = "value", q1_name = "q1", q3_name = "q3",
                                     label_name = "label", type_name = "type",
                                     counts_name = "counts", counts = TRUE,
                                     counts_weighted = FALSE, v_labels = TRUE,
                                     pred = NULL, pd_evaluate_at = NULL, pd_grid = NULL,
                                     pd_indices = NULL, pd_n_max = 1000, pd_seed = NULL,
                                     pd_center = c("no", "first", "middle", "last", "mean", "0"),
                                     ale_two_sided = FALSE, ...) {
  type <- match.arg(type)
  stats <- match.arg(stats)
  cut_type <- match.arg(cut_type)
  pd_center <- match.arg(pd_center)

  # For SHAP profiles, the working data are the rows of the precomputed SHAP
  # table belonging to variable `v`; the `data` argument is ignored then.
  if (type == "shap") {
    if (!is.shap(x$shap)) {
      stop("No shap values calculated. Run 'add_shap' for the flashlight first.")
    }
    data <- x$shap$data[x$shap$data[[x$shap$variable_name]] == v, ]
  } else if (is.null(data)) {
    data <- x$data
  }

  # Checks (more will be done below or in the called functions).
  # All columns of the resulting data must have distinct names.
  stopifnot(!anyDuplicated(c(by, union(v, names(pd_grid)), if (counts) counts_name,
                             if (stats == "quartiles") c(q1_name, q3_name),
                             value_name, label_name, type_name)))

  # `pred` replaces the predictions on `data` for both "predicted" and "ale"
  # profiles, so it must provide exactly one value per row of `data`.
  if (!is.null(pred) && type %in% c("predicted", "ale") &&
      length(pred) != nrow(data)) {
    stop("Wrong number of predicted values passed.")
  }
  if (type == "ale" && stats == "quartiles") {
    stop("The cumsum step of ALE does not make sense for quartiles.")
  }

  # Update flashlight with the working data, grouping, and (if requested)
  # an identity retransformation instead of the stored inverse link.
  if (type != "shap") {
    x <- flashlight(x, data = data, by = by,
                    linkinv = if (use_linkinv) x$linkinv else function(z) z)
  }

  # Calculate profiles
  arg_list <- list(x = x, v = v, evaluate_at = pd_evaluate_at, breaks = breaks,
                   n_bins = n_bins, cut_type = cut_type, indices = pd_indices,
                   n_max = pd_n_max, seed = pd_seed, value_name = value_name)
  if (type == "partial dependence") {
    # ICE curves are computed first and averaged further below; `v` may be
    # replaced by the grid variable(s) determined in light_ice().
    arg_list <- c(arg_list, list(grid = pd_grid, center = pd_center,
                                 label_name = label_name, id_name = "id_xxx"))
    cp_profiles <- do.call(light_ice, arg_list)
    v <- cp_profiles$v
    data <- cp_profiles$data
  } else if (type == "ale") {
    # ALE returns already aggregated results
    arg_list <- c(arg_list, list(counts_name = counts_name,
                                 counts = counts, counts_weighted = counts_weighted,
                                 pred = pred, two_sided = ale_two_sided))
    agg <- do.call(ale_profile, arg_list)
  } else {
    # The remaining types are grouped by the (binned) values of one column `v`
    stopifnot(!is.null(v),
              length(v) == 1L,
              v %in% colnames(data),
              nrow(data) >= 1L)

    # Add predictions/response to data
    data[[value_name]] <- switch(type,
      response = response(x),
      predicted = if (is.null(pred)) predict(x) else pred,
      residual = residuals(x),
      shap = data[["shap_"]])

    # Replace v values by binned ones (bin labels or bin centers)
    cuts <- auto_cut(data[[v]], breaks = breaks,
                     n_bins = n_bins, cut_type = cut_type, ...)
    data[[v]] <- cuts$data[[if (v_labels) "level" else "value"]]
  }

  # Aggregate values by `by` and `v` (ALE was aggregated above)
  if (type != "ale") {
    agg <- grouped_stats(data = data, x = value_name, w = x$w, by = c(by, v),
                         stats = stats, counts = counts,
                         counts_weighted = counts_weighted,
                         counts_name = counts_name, q1_name = q1_name,
                         q3_name = q3_name, na.rm = TRUE)
  }

  # Finalize results
  agg[[label_name]] <- x$label

  # Code type as factor (relevant for light_effects)
  agg[[type_name]] <- factor(type, c("response", "predicted", "partial dependence",
                                     "ale", "residual", "shap"))

  # Collect results
  out <- list(data = agg, by = by, v = v, type = type, stats = stats,
              value_name = value_name, q1_name = q1_name, q3_name = q3_name,
              label_name = label_name, type_name = type_name, counts_name = counts_name)
  class(out) <- c("light_profile", "light", "list")
  out
}

#' @describeIn light_profile Profiles for multiflashlight.
#' @export
light_profile.multiflashlight <- function(x, v = NULL, data = NULL,
                                          breaks = NULL, n_bins = 11,
                                          cut_type = c("equal", "quantile"),
                                          pd_evaluate_at = NULL, pd_grid = NULL, ...) {
  cut_type <- match.arg(cut_type)

  # Without user-supplied breaks or evaluation points, derive one set of
  # breaks via common_breaks() and pass it to every flashlight, so that bins
  # (and hence bin labels) agree across the flashlights.
  no_user_binning <- is.null(breaks) && is.null(pd_evaluate_at) && is.null(pd_grid)
  if (no_user_binning) {
    breaks <- common_breaks(x = x, v = v, data = data, breaks = breaks,
                            n_bins = n_bins, cut_type = cut_type)
  }

  # One light_profile per flashlight, then combined into a single object
  profiles <- lapply(
    x, light_profile,
    v = v,
    data = data,
    breaks = breaks,
    n_bins = n_bins,
    cut_type = cut_type,
    pd_evaluate_at = pd_evaluate_at,
    pd_grid = pd_grid,
    ...
  )
  light_combine(profiles, new_class = "light_profile_multi")
}
