

#######################################################
# INTERNAL
# functions based on descriptions in Auer-Gervini, 2008

# Objective value for a candidate dimension d.
#   d      : number of leading entries of Lambda to set aside
#   Lambda : numeric vector of variances (eigenvalues)
# Returns (m - d) * log(G / A), where m is length(Lambda) and G and A are
# the geometric and arithmetic means of the entries Lambda[(d+1):m].
fhat <- function(d, Lambda) {
  nTotal <- length(Lambda)
  trailing <- Lambda[(d + 1):nTotal]
  geometricMean <- exp(mean(log(trailing)))
  arithmeticMean <- mean(trailing)
  nRemaining <- nTotal - d
  nRemaining * log(geometricMean / arithmeticMean)
}

# Dimension selected for a given prior parameter theta.
#   theta  : weight applied to the term (length(Lambda) - 1 - d)
#   Lambda : numeric vector of variances (eigenvalues)
# Evaluates fhat(d, Lambda) + theta * (length(Lambda) - 1 - d) for every
# d in 0..length(Lambda)-1 and returns the largest d attaining the maximum.
dhat <- function(theta, Lambda) {
  nLambda <- length(Lambda)
  candidates <- seq(0, nLambda - 1)
  score <- sapply(candidates, function(d) {
    fhat(d, Lambda) + theta * (nLambda - 1 - d)
  })
  best <- which(score == max(score))
  # ties are resolved in favour of the largest candidate dimension
  max(candidates[best])
}

# Lower end of the theta interval associated with dimension d.
#   d      : candidate dimension (0 .. length(Lambda) - 1)
#   Lambda : numeric vector of variances (eigenvalues)
# Returns 0 for the largest candidate dimension; otherwise the maximum over
# k in (d+1)..(length(Lambda)-1) of (fhat(k) - fhat(d)) / (k - d).
thetaLower <- function(d, Lambda) {
  lastDim <- length(Lambda) - 1
  if (d == lastDim) {
    return(0)
  }
  baseline <- fhat(d, Lambda)
  slopes <- sapply((d + 1):lastDim, function(k) {
    (fhat(k, Lambda) - baseline) / (k - d)
  })
  max(slopes)
}

# Upper end of the theta interval associated with dimension d.
#   d      : candidate dimension (0 .. length(Lambda) - 1)
#   Lambda : numeric vector of variances (eigenvalues)
# Returns Inf for d == 0; otherwise the minimum over k in 0..(d-1) of
# (fhat(k) - fhat(d)) / (k - d).
thetaUpper <- function(d, Lambda) {
  if (d == 0) {
    return(Inf)
  }
  baseline <- fhat(d, Lambda)
  slopes <- sapply(0:(d - 1), function(k) {
    (fhat(k, Lambda) - baseline) / (k - d)
  })
  min(slopes)
}

#######################################################
# EXTERNAL

# broken-stick model
# Broken-stick values.
#   k : vector of component indices
#   n : total number of components (single number)
# For each k0 in k, returns sum(1/j for j in k0..n) / n.
# Stops with 'bad values' if any element of k exceeds n.
# Always returns a numeric vector, including numeric(0) for empty k.
brokenStick <- function(k, n) {
    if (any(n < k)) stop('bad values')
    x <- 1/seq_len(n)
    # vapply (rather than sapply) pins the result type, so an empty k
    # yields numeric(0) instead of an empty list
    vapply(k, function(k0) sum(x[k0:n])/n, numeric(1))
}

# broken stick to get dimension
# Dimension estimate from the broken-stick model.
#   lambda : numeric vector of variances, or an object inheriting from
#            "SamplePCA", in which case its 'variances' slot is used
#   FUZZ   : tolerance used when comparing observed and expected fractions
# Returns one less than the index of the first component whose fraction of
# total variance minus the matching broken-stick value is below FUZZ.
bsDimension <- function(lambda, FUZZ=0.005) {
  if (inherits(lambda, "SamplePCA")) {
    lambda <- lambda@variances
  }
  nComp <- length(lambda)
  expected <- brokenStick(1:nComp, nComp)
  observed <- lambda / sum(lambda)
  firstBelow <- which(observed - expected < FUZZ)[1]
  firstBelow - 1
}

#######################################################
# S4 interface

# S4 class holding the quantities computed by the AuerGervini() constructor.
# All five slots are numeric vectors:
#   Lambda       : the variances (eigenvalues) the analysis was run on
#   dLevels      : candidate dimensions retained by the constructor
#   changePoints : theta values at which the selected dimension changes
#   lowerBounds  : thetaLower() for every candidate dimension
#   upperBounds  : thetaUpper() for every candidate dimension
setClass(
  "AuerGervini",
  slots = c(
    Lambda = "numeric",
    dLevels = "numeric",
    changePoints = "numeric",
    lowerBounds = "numeric",
    upperBounds = "numeric"
  )
)

# Constructor for "AuerGervini" objects.
#   Lambda  : numeric vector of variances, or an object inheriting from
#             "SamplePCA", in which case its 'variances' slot is used
#   epsilon : when positive, entries of Lambda not exceeding it are dropped
# The remaining values are sorted in decreasing order. thetaLower() and
# thetaUpper() are evaluated for every candidate dimension
# 0..length(Lambda)-1, and the dimensions whose lower bound does not exceed
# their upper bound are kept (in reverse order) as dLevels. changePoints are
# the reversed lower bounds of those dimensions with the first one dropped.
# Stops if no values remain after filtering and sorting, because no
# candidate dimension exists in that case.
AuerGervini <- function(Lambda, epsilon=2e-16) {
  if (inherits(Lambda, "SamplePCA")) {
    Lambda <- Lambda@variances
  }
  if (epsilon > 0) {
    Lambda <- Lambda[Lambda > epsilon]
  }
  Lambda <- rev(sort(Lambda))
  if (length(Lambda) == 0) {
    # 1:length(Lambda) would silently yield c(1, 0) here and produce an
    # object full of NA values, so fail loudly instead
    stop("no usable values remain in 'Lambda' after removing missing ",
         "values and values not exceeding 'epsilon'")
  }
  rg <- seq_along(Lambda) - 1
  lowerBounds <- vapply(rg, thetaLower, numeric(1), Lambda=Lambda)
  upperBounds <- vapply(rg, thetaUpper, numeric(1), Lambda=Lambda)
  # a dimension is attainable only if its theta interval is non-empty
  feasible <- lowerBounds <= upperBounds
  dLevels <- rev(rg[feasible])
  changePoints <- rev(lowerBounds[feasible])[-1]
  new("AuerGervini", Lambda=Lambda,
      dLevels=dLevels, changePoints=changePoints,
      lowerBounds=lowerBounds, upperBounds=upperBounds)
}

# Upper limit of the theta range used by the step-length criteria and plot.
#   object : object with a numeric 'changePoints' slot
#   n, m   : the two data dimensions (agDimension passes nrow and ncol)
# Computes delta * (18.8402 + 1.9523*m - 0.0005*m^2), divided by n when
# n >= m and multiplied by n/m^2 otherwise, and returns the larger of that
# value and 1.03 times the largest change point.
estimateTop <- function(object, n, m) {
  delta <- 1
  scaled <- delta * (18.8402 + 1.9523 * m - 0.0005 * m^2)
  bound <- if (n >= m) {
    scaled / n
  } else {
    scaled * n / m^2
  }
  max(bound, 1.03 * object@changePoints)
}

########################################################################
##    Different criteria to choose significantly long step length     ##
########################################################################

#----------------------------------------------------------------------
# twice mean criterion
#   object : "AuerGervini" object (its changePoints and dLevels are used)
#   n, m   : data dimensions, forwarded to estimateTop()
# Step lengths are the gaps between consecutive change points, with
# estimateTop() closing the last step. With more than three steps, a step
# is "long" when it exceeds twice the mean step length; otherwise only the
# maximal step counts. Returns the dLevels entry following the first long
# step, or 0 when there is none.
agDimension1 <- function(object, n, m) {
  steps <- diff(c(object@changePoints, estimateTop(object, n, m)))
  isLong <- if (length(steps) > 3) {
    steps > 2 * mean(steps)
  } else {
    steps == max(steps)
  }
  firstLong <- which(isLong)[1]
  ifelse(any(isLong), object@dLevels[1 + firstLong], 0)
}

#---------------------------------------------------------------------
# k-means criterion (four variants, selected with 'version')
#   object       : "AuerGervini" object (changePoints and dLevels are used)
#   logtransform : if TRUE, step lengths are log-transformed first
#   n, m         : data dimensions, forwarded to estimateTop()
#   version      : which variant to apply (1, 2, 3, or 4)
#     1: initial centers are min. and max.; the steps assigned to the
#        cluster started from the max. are "long"
#     2: initial centers are second min. and second max.; the steps
#        assigned to the cluster started from the second max. are "long"
#     3: initial centers are min. and max.; steps at least as large as the
#        larger fitted center are "long"
#     4: as version 1, but when ceiling(log(number of steps)/2) > 2 a third
#        initial center (the median) is added and the cluster started from
#        the max. is "long"
# With three or fewer steps, only the maximal step is "long" in every
# version. Returns the dLevels entry following the first long step, or 0
# when there is none.
# The default is version 4: the four variants used to be four successive
# definitions of this same function name, so only the last one (version 4)
# was ever in effect.
agDimension2 <- function(object, logtransform, n, m, version=4) {
  if (!is.numeric(version) || length(version) != 1 ||
      !(version %in% 1:4)) {
    stop("'version' must be one of 1, 2, 3, or 4")
  }
  stepLength <- diff(c(object@changePoints, estimateTop(object, n, m)))
  if (logtransform) {
    stepLength <- log(stepLength)
  }
  if (length(stepLength) > 3) {
    if (version == 2) {
      sortsl <- sort(stepLength, decreasing=FALSE)
      centers <- c(sortsl[2], sortsl[length(sortsl)-1])
    } else {
      centers <- c(min(stepLength), max(stepLength))
    }
    kmeanfit <- kmeans(stepLength, centers=centers)
    if (version == 3) {
      magic <- (stepLength >= max(kmeanfit$centers))
    } else {
      # the second initial center is the larger one, so its cluster is
      # labelled 2
      magic <- (kmeanfit$cluster == 2)
    }
    # choose k = 3 if there are many features (extra center is median)
    if (version == 4 && ceiling(log(length(stepLength))/2) > 2) {
      kmeanfit <- kmeans(stepLength, centers=c(min(stepLength),
                  median(stepLength), max(stepLength)))
      magic <- (kmeanfit$cluster == 3)
    }
  } else {
    magic <- (stepLength == max(stepLength))
  }
  ifelse(any(magic), object@dLevels[1+which(magic)[1]], 0)
}

#-------------------------------------------------------------------------
# spectral clustering criterion
#   object       : "AuerGervini" object (changePoints and dLevels are used)
#   logtransform : if TRUE, step lengths are log-transformed first
#   n, m         : data dimensions, forwarded to estimateTop()
# With more than three steps, the step lengths are split into two groups by
# specc() and the group with the larger mean is "long"; otherwise only the
# maximal step counts. Returns the dLevels entry following the first long
# step, or 0 when there is none.
# specc() is not defined in this file (presumably kernlab::specc; confirm
# against the package imports).
agDimension3 <- function(object, logtransform, n, m) {
  steps <- diff(c(object@changePoints, estimateTop(object, n, m)))
  if (logtransform) {
    steps <- log(steps)
  }
  if (length(steps) <= 3) {
    isLong <- (steps == max(steps))
  } else {
    # duplicate the 1D step lengths into two identical columns, i.e. points
    # lying on the line y = x
    embedded <- cbind(X1=steps, X2=steps)
    clustering <- specc(embedded, centers=2)
    meanOfFirst <- mean(steps[clustering == 1])
    meanOfSecond <- mean(steps[clustering == 2])
    longLabel <- ifelse(meanOfFirst > meanOfSecond, 1, 2)
    isLong <- (clustering == longLabel)
  }
  firstLong <- which(isLong)[1]
  ifelse(any(isLong), object@dLevels[1 + firstLong], 0)
}

#------------------------------------------------------------------------
# naive t-test criterion (two variants, selected with 'version')
# Detects a significant change point in the differences of the sorted step
# lengths by comparing successive running means with a t-based statistic.
#   object       : "AuerGervini" object (changePoints and dLevels are used)
#   logtransform : if TRUE, step lengths are log-transformed first
#   n, m         : data dimensions, forwarded to estimateTop()
#   version      : which variant to apply (1 or 2)
#     1: standard deviation, sample size and degrees of freedom are based
#        on the first 3 + iter differences
#     2: one more difference is included (4 + iter) to make the standard
#        deviation larger
# With four or fewer steps, only the maximal step is "long". Returns the
# dLevels entry following the first long step, or 0 when there is none.
# The default is version 2: the two variants used to be two successive
# definitions of this same function name, so only the last one (version 2)
# was ever in effect.
# TO DO: include significance level alpha in arguments, currently
# set significance level to be 0.01
agDimension4 <- function(object, logtransform, n, m, version=2) {
  if (!is.numeric(version) || length(version) != 1 ||
      !(version %in% 1:2)) {
    stop("'version' must be 1 or 2")
  }
  stepLength <- diff(c(object@changePoints, estimateTop(object, n, m)))
  if (logtransform) {
    stepLength <- log(stepLength)
  }
  if (length(stepLength) > 4) {
    sort1 <- sort(stepLength, decreasing=FALSE, method="quick",
             index.return=TRUE)
    diffsl <- diff(sort1$x)
    # running means of the first 3, 4, ... differences
    meanlist <- cumsum(diffsl)[3:length(diffsl)]/(3:length(diffsl))
    meannum <- length(meanlist)
    # version 2 uses one more difference than version 1
    extra <- version - 1
    iter <- 0
    pvec <- NULL
    repeat {
      if (iter == meannum-1) break
      used <- 3 + extra + iter
      sdvalue <- sd(diffsl[1:used] - meanlist[iter+1+extra])
      pvalue <- 1 - pt((meanlist[iter+2]-meanlist[iter+1])/(sdvalue/sqrt(used)),
                used)
      pvec <- c(pvec, pvalue)
      if (pvalue < 0.01) break
      iter <- iter + 1
    }
    # NOTE(review): when a significant p-value stops the loop, the selected
    # set starts at sorted position iter + 4; the fallback starts at
    # which.min(pvec) + 4. iter is 0-based and which.min() is 1-based, so
    # the two offsets differ by one for the same comparison. Confirm that
    # this is intended.
    if (iter < meannum-1) {
      first <- iter + 4
    } else {
      # no significant change point: fall back to the smallest p-value
      first <- which.min(pvec) + 4
    }
    slset <- sort1$ix[first:length(stepLength)]
    magic <- (seq_along(stepLength) %in% slset)
  } else {
    magic <- (stepLength == max(stepLength))
  }
  ifelse(any(magic), object@dLevels[1+which(magic)[1]], 0)
}

#-------------------------------------------------------------------------
# changepoint criterion (cpt.mean in changepoint package)
#   object       : "AuerGervini" object (changePoints and dLevels are used)
#   logtransform : if TRUE, step lengths are log-transformed first
#   n, m         : data dimensions, forwarded to estimateTop()
# With more than three steps, cpt.mean() is run on the sorted step lengths;
# if cpts() reports a change point, every step sorted after it is "long".
# In all other cases only the maximal step is "long". Returns the dLevels
# entry following the first long step, or 0 when there is none.
agDimension5 <- function(object, logtransform, n, m) {
  steps <- diff(c(object@changePoints, estimateTop(object, n, m)))
  if (logtransform) {
    steps <- log(steps)
  }
  nSteps <- length(steps)
  if (nSteps > 3) {
    ordered <- sort(steps, decreasing=FALSE, method="qu",
                    index.return=TRUE)
    located <- cpts(cpt.mean(ordered$x, Q=2))
    if (length(located) != 0) {
      # original positions of the steps sorted above the change point
      upperSet <- ordered$ix[(located + 1):nSteps]
      isLong <- (seq_along(steps) %in% upperSet)
    } else {
      isLong <- (steps == max(steps))
    }
  } else {
    isLong <- (steps == max(steps))
  }
  firstLong <- which(isLong)[1]
  ifelse(any(isLong), object@dLevels[1 + firstLong], 0)
}

#------------------------------------------------------------------------
# cpm criterion (use detectChangePointBatch function in cpm package)
# use detectChangePointBatch function to detect the significant change 
# point in sorted step lengths, cpmethod is cpmType in the function
#   object       : "AuerGervini" object (changePoints and dLevels are used)
#   logtransform : if TRUE, step lengths are log-transformed first
#   cpmethod     : passed to detectChangePointBatch() as its second argument
#   n, m         : data dimensions, forwarded to estimateTop()
# With more than three steps, detectChangePointBatch() is run on the sorted
# step lengths (alpha=0.05); if a change is detected, every step sorted
# after the change point is "long". In all other cases only the maximal
# step is "long". Returns the dLevels entry following the first long step,
# or 0 when there is none.
# Stops if logtransform is TRUE and cpmethod is "Exponential" or
# "ExponentialAdjusted".
agDimension6 <- function(object, logtransform, cpmethod, n, m) {
  # reject the incompatible combination before doing any work
  if (logtransform) {
    if (cpmethod %in% c("Exponential", "ExponentialAdjusted")) {
      stop("CPM type is not appropriate after log transformation. ",
           "Try other CPM types.")
    }
  }
  stepLength <- diff(c(object@changePoints, estimateTop(object, n, m)))
  if (logtransform) {
    stepLength <- log(stepLength)
  }
  if (length(stepLength) > 3) {
    sort1 <- sort(stepLength, decreasing=FALSE, method="quick",
             index.return=TRUE)
    fit <- detectChangePointBatch(sort1$x, cpmethod, alpha=0.05,
             lambda=NA)
    cp <- fit$changePoint
    if (fit$changeDetected) {
      # original positions of the steps sorted above the change point
      slset <- sort1$ix[(cp+1):length(stepLength)]
      magic <- (seq_along(stepLength) %in% slset)
    } else {
      magic <- (stepLength == max(stepLength))
    }
  } else {
    magic <- (stepLength == max(stepLength))
  }
  ifelse(any(magic), object@dLevels[1+which(magic)[1]], 0)
}

#----------------------------------------------------------------------
# Front end that applies one of the step-length criteria.
#   data         : matrix-like object; only nrow(data) and ncol(data) are used
#   object       : "AuerGervini" object
#   agmethod     : criterion name; "cpm" is used when the argument is missing
#   logtransform : forwarded to every criterion except "meantwice"
#   cpmethod     : forwarded to the "cpm" criterion only
# Returns the value of the selected agDimension1..agDimension6 function.
agDimension <- function(data, object,
                        agmethod=c("meantwice","kmean","specc","ttest","changepoint","cpm"),
                        logtransform=FALSE,
                        cpmethod="Student") {
  # default criterion is "cpm" rather than the first listed choice
  if (missing(agmethod)) {
    agmethod <- "cpm"
  }
  agmethod <- match.arg(agmethod)
  nSamples <- nrow(data)
  nFeatures <- ncol(data)
  if (agmethod == "meantwice") {
    agDimension1(object, nSamples, nFeatures)
  } else if (agmethod == "kmean") {
    agDimension2(object, logtransform, nSamples, nFeatures)
  } else if (agmethod == "specc") {
    agDimension3(object, logtransform, nSamples, nFeatures)
  } else if (agmethod == "ttest") {
    agDimension4(object, logtransform, nSamples, nFeatures)
  } else if (agmethod == "changepoint") {
    agDimension5(object, logtransform, nSamples, nFeatures)
  } else {
    # match.arg() leaves "cpm" as the only remaining possibility
    agDimension6(object, logtransform, cpmethod, nSamples, nFeatures)
  }
}

# y is a numeric vector whose first element is the number of samples, whose
# second element is the number of features, and whose third element is the
# Auer-Gervini estimate. The arguments of agDimension are not passed to
# this method; instead, supply the agDimension result as part of y.
# Plot method: draws the step function of dLevels against changePoints over
# theta in [0, estimateTop(...)] and marks y[3] with a dashed horizontal
# line. Returns x invisibly.
setMethod("plot", c("AuerGervini", "numeric"), function(x, y,
        main="Bayesian Sensitivity Analysis",
        ...) {
  nSamples <- y[1]
  nFeatures <- y[2]
  thetaMax <- estimateTop(x, nSamples, nFeatures)
  dimSteps <- stepfun(x@changePoints, x@dLevels)
  plot(dimSteps, xlab="Prior Theta", ylab="Number of Components",
       main=main, xlim=c(0, thetaMax), ...)
  # y[3] holds the dimension estimate to highlight
  abline(h=y[3], lty=2, lwd=2, col='pink')
  invisible(x)
})

# TO DO: need to pass the arguments of agDimension into the class here.
# setMethod("summary", "AuerGervini", function(object, ...) {
#   cat("An '", class(object), "' object that estimates the number of ",
#       "principal components to be ", agDimension(object), ".\n", sep="")
# })
