
#' Plot MCMC traceplots and density plots
#'
#' Draws a traceplot and/or overlaid kernel density estimates for the selected
#' columns of \code{samples}.  The figure is written to a PDF file when
#' \code{file} is given; otherwise a new graphics device is opened, unless a
#' knitr graphics device is configured, in which case the current device is used.
#'
#' @param samples Array of MCMC samples, with one row per MCMC iteration and one column per parameter
#' @param var Parameter names to plot.  A name given without indices also selects all of its indexed elements (e.g. \code{'x'} selects \code{'x[1]'}, \code{'x[2]'}, ...).  If \code{NULL} (e.g. when \code{samples} has no column names), all columns are plotted
#' @param ind Indices of the MCMC iterations (rows of \code{samples}) to plot.  Cannot be combined with \code{burnin}
#' @param burnin Number of initial MCMC samples to discard.  Cannot be combined with \code{ind}
#' @param width Width of the plot
#' @param height Height of the plot
#' @param legend Logical, whether to include a legend of parameter names
#' @param legend.location Location of legend
#' @param traceplot Logical, whether to include traceplots
#' @param densityplot Logical, whether to include density plots
#' @param file Optional filename to save figure as a PDF file
#'
#' @return Invisibly, the list of graphical parameters that were in effect
#'   while plotting (the value of the \code{par()} call that restores the
#'   previous settings).
#'
#' @examples
#' samples <- cbind(rnorm(1000), rgamma(1000, 1))
#' colnames(samples) <- c('alpha', 'beta')
#' samplesPlot(samples)
#' 
#' @export
samplesPlot <- function(samples, var=colnames(samples), ind=NULL, burnin=NULL, width=7, height=4, legend=TRUE, legend.location='topright', traceplot=TRUE, densityplot=TRUE, file=NULL) {
    ## validate arguments *before* any graphics device is opened, so that a
    ## bad call cannot leave a half-written file or an orphaned device behind
    if(!traceplot && !densityplot) stop('both traceplot and densityplot are false')
    if(!is.null(ind) && !is.null(burnin)) stop('only specify either ind or burnin')
    ## process samples: select columns
    if(!is.null(var)) {
        var <- gsub('\\[', '\\\\\\[', gsub('\\]', '\\\\\\]', var))   ## add \\ before any '[' or ']' appearing in var
        var <- unlist(lapply(var, function(n) grep(paste0('^', n,'(\\[.+\\])?$'), colnames(samples), value=TRUE)))  ## expand any indexing
        if(length(var) == 0) stop('no columns of samples match var')
        samples <- samples[, var, drop=FALSE]
    }
    ## process samples: select rows
    if(!is.null(ind)) samples <- samples[ind, , drop=FALSE]
    if(!is.null(burnin)) {
        ## (burnin+1):nrow would silently count downwards when burnin >= nrow
        if(burnin >= nrow(samples)) stop('burnin must be smaller than the number of MCMC samples')
        samples <- samples[seq.int(burnin+1, nrow(samples)), , drop=FALSE]
    }
    nparam <- ncol(samples)
    if(nparam == 0) stop('samples has no columns to plot')
    ## open the output device
    if(!is.null(file)) {
        pdf(file, width=width, height=height)
    } else {
        ## Only open a new device when not called from Rmarkdown/knitr.  knitr is
        ## looked up dynamically so this code carries no static reference to it;
        ## a failed lookup (e.g. knitr not installed) is treated as "not in knitr".
        knitrDev <- tryCatch(getExportedValue('knitr', 'opts_chunk')$get('dev'), error = function(e) NULL)
        if(is.null(knitrDev)) dev.new(height=height, width=width)
    }
    par.save <- par(no.readonly = TRUE)
    finished <- FALSE
    ## cleanup that also runs on error: close the file device, or put the
    ## graphical parameters of the on-screen device back
    on.exit(if(!is.null(file)) dev.off() else if(!finished) par(par.save), add=TRUE)
    par(mfrow=c(1,traceplot+densityplot), cex=0.7, cex.main=1.5, cex.axis=0.9, lab=c(3,3,7), mgp=c(0,0.4,0), mar=c(1.6,1.6,2,0.6), oma=c(0,0,0,0), tcl=-0.3, bty='l')
    paramNames <- colnames(samples)
    showLegend <- legend && is.character(paramNames)
    legendCols <- rainbow(nparam, alpha=0.5)
    if(traceplot) {  ## traceplot
        lineCols <- rainbow(nparam, alpha=0.75)
        plot(seq_len(nrow(samples)), ylim=range(samples), type='n', main='Traceplots', xlab='', ylab='')
        for(i in seq_len(nparam))
            lines(samples[,i], col=lineCols[i])
        ## when both panels are drawn, the legend goes on the density panel only
        if(showLegend && !densityplot)
            legend(legend=paramNames, fill=legendCols, bty='n', x=legend.location)
    }  ## finish traceplot
    if(densityplot) {  ## densityplot
        fillCols <- rainbow(nparam, alpha=0.2)
        ## estimate each density once; reuse for axis limits and for drawing
        dens <- lapply(seq_len(nparam), function(i) density(samples[,i]))
        xlim <- range(unlist(lapply(dens, function(d) d$x)))
        yMax <- max(unlist(lapply(dens, function(d) d$y)))
        plot(1, xlim=xlim, ylim=c(0,yMax), type='n', main='Posterior Densities', xlab='', ylab='', yaxt='n')
        for(i in seq_len(nparam))
            polygon(dens[[i]], col=fillCols[i], border=fillCols[i])
        if(showLegend)
            legend(legend=paramNames, fill=legendCols, bty='n', x=legend.location)
    }  ## finish densityplot
    ## restore par while the plotting device is still current: calling par()
    ## after dev.off() would open a brand new default device
    plot.par <- par(par.save)
    finished <- TRUE
    invisible(plot.par)
}
 


#' Compare summary statistics from multiple MCMC chains
#'
#' Parameter plots from each chain show the mean and 95% credible intervals
#' (the 2.5% and 97.5% sample quantiles).  Parameters are laid out along the
#' x-axis, split over one or more rows, with one point and interval per chain.
#'
#' @param samplesList List of arrays of MCMC samples from different chains (one row per MCMC iteration, one named column per parameter).  A single array is treated as one chain
#' @param var Parameter names to plot.  A name given without indices also selects all of its indexed elements (e.g. \code{'x'} selects \code{'x[1]'}, \code{'x[2]'}, ...).  If \code{NULL}, all parameters are plotted
#' @param nrows Number of rows in the resulting plot.  Fewer rows are drawn when there are not enough parameters to fill them
#' @param width Width of figure
#' @param height Height of figure
#' @param legend Logical, whether to include a legend of chain names
#' @param legend.location Legend location
#' @param jitter Scale factor for spreading out lines from each chain
#' @param buffer.right Additional buffer on right side of plot
#' @param buffer.left Additional buffer on left side of plot
#' @param cex Expansion coefficient for text
#' @param file Filename for saving figure to a PDF file
#'
#' @return Invisibly, the list of graphical parameters that were in effect
#'   while plotting (the value of the \code{par()} call that restores the
#'   previous settings).
#'
#' @examples
#' samples1 <- cbind(rnorm(1000, 1), rgamma(1000, 1), rpois(1000, 1))
#' colnames(samples1) <- c('alpha', 'beta', 'gamma')
#' samples2 <- cbind(rnorm(1000, 2), rgamma(1000, 2), rpois(1000, 2))
#' colnames(samples2) <- c('alpha', 'beta', 'gamma')
#' samplesList <- list(chain1 = samples1, chain2 = samples2)
#' chainsPlot(samplesList, nrows = 1, jitter = .3, buffer.left = .5, buffer.right = .5)
#'
#' @export
chainsPlot <- function(samplesList, var=NULL, nrows=3, width=7, height=min(1+3*nrows,7), legend=!is.null(names(samplesList)), legend.location='topright', jitter=1, buffer.right=0, buffer.left=0, cex=1, file=NULL) {
    if(nrows < 1) stop('nrows must be at least 1')
    ## a single array of samples is treated as one chain; inherits() copes with
    ## objects having several classes (a matrix has class c('matrix','array'))
    if(!inherits(samplesList, c('list', 'mcmc.list'))) samplesList <- list(samplesList)
    if(!is.null(var)) {
        var <- gsub('\\[', '\\\\\\[', gsub('\\]', '\\\\\\]', var))   ## add \\ before any '[' or ']' appearing in var
        samplesList <- lapply(samplesList, function(samples) {
            theseVar <- unlist(lapply(var, function(n) grep(paste0('^', n,'(\\[.+\\])?$'), colnames(samples), value=TRUE)))  ## expand any indexing
            samples[, theseVar, drop=FALSE]
        })
    }
    chainParamNamesList <- lapply(samplesList, function(s) colnames(s))
    nChains <- length(samplesList)
    paramNamesAll <- unique(unlist(chainParamNamesList))
    nParamsAll <- length(paramNamesAll)
    if(nParamsAll == 0) stop('no parameters to plot')
    cols <- rainbow(nChains)
    ## construct 3D summary array (chain x statistic x parameter); entries stay
    ## NA for parameters that a chain does not contain
    summary <- array(as.numeric(NA), dim = c(nChains, 3, nParamsAll),
                     dimnames = list(names(samplesList), c('mean','low','upp'), paramNamesAll))
    for(iChain in seq_len(nChains)) {
        if(length(chainParamNamesList[[iChain]]) == 0) next
        theseSamples <- samplesList[[iChain]]
        thisSummary <- rbind(mean = apply(theseSamples, 2, mean),
                             low  = apply(theseSamples, 2, function(x) quantile(x, 0.025)),
                             upp  = apply(theseSamples, 2, function(x) quantile(x, 0.975)))
        summary[iChain,c('mean','low','upp'),colnames(thisSummary)] <- thisSummary
    }
    nParamsPerRow <- ceiling(nParamsAll/nrows)
    ## number of rows that actually receive parameters; equals nrows unless
    ## there are too few parameters to put something in every row
    nRowsUsed <- ceiling(nParamsAll/nParamsPerRow)
    sq <- if(nChains==1) 0 else seq(-1, 1, length.out=nChains)
    scale <- width/nParamsPerRow * jitter * 0.1  ## adjust jitter scale factor at end
    ## a legend of chain names can only be drawn when the chains are named
    showLegend <- legend && !is.null(names(samplesList))
    ## open the output device (only after all inputs have been processed)
    if(!is.null(file)) {
        pdf(file, width=width, height=height)
    } else {
        ## Only open a new device when not called from Rmarkdown/knitr.  knitr is
        ## looked up dynamically so this code carries no static reference to it;
        ## a failed lookup (e.g. knitr not installed) is treated as "not in knitr".
        knitrDev <- tryCatch(getExportedValue('knitr', 'opts_chunk')$get('dev'), error = function(e) NULL)
        if(is.null(knitrDev)) dev.new(height=height, width=width)
    }
    par.save <- par(no.readonly = TRUE)
    finished <- FALSE
    ## cleanup that also runs on error: close the file device, or put the
    ## graphical parameters of the on-screen device back
    on.exit(if(!is.null(file)) dev.off() else if(!finished) par(par.save), add=TRUE)
    par(mfrow=c(nRowsUsed,1), oma=c(3,1,1,1), mar=c(4,1,0,1), mgp=c(3,0.5,0))
    for(iRow in seq_len(nRowsUsed)) {
        ## cap at nParamsAll so a partially filled row never indexes past the end
        rowParamInd <- seq.int(1+(iRow-1)*nParamsPerRow, min(iRow*nParamsPerRow, nParamsAll))
        nRowParams <- length(rowParamInd)
        rowParamNames <- paramNamesAll[rowParamInd]
        xs <- seq_len(nRowParams)
        names(xs) <- rowParamNames
        ylim <- range(summary[,c('low','upp'),rowParamNames], na.rm=TRUE)
        ## empty frame; every row uses the same x-range so columns line up
        plot(x=1, y=0, type='n', xlim=c(1-buffer.left,nParamsPerRow+buffer.right), ylim=ylim, xaxt='n', ylab='', xlab='', tcl=-0.3, cex.axis=cex)
        axis(1, at=xs, labels=FALSE, tcl=-0.3)
        ## rotated parameter labels, placed just below the plotting region
        usr <- par('usr')
        text(x=xs, y=usr[3]-0.1*(usr[4]-usr[3]), labels=rowParamNames, srt=45, adj=1, xpd=TRUE, cex=0.9*cex)
        for(iChain in seq_len(nChains)) {
            ps <- intersect(rowParamNames, chainParamNamesList[[iChain]])
            if(length(ps) == 0) next   ## this chain has none of this row's parameters
            xsJittered <- xs + sq[iChain]*scale
            points(x=xsJittered[ps], y=summary[iChain,'mean',ps], pch=16, col=cols[iChain])
            segments(x0=xsJittered[ps], y0=summary[iChain,'low',ps], y1=summary[iChain,'upp',ps], lwd=1, col=cols[iChain])
        }
        if(showLegend) legend(legend.location, legend=names(samplesList), pch=16, col=cols, cex=cex)
    }
    ## restore par while the plotting device is still current: calling par()
    ## after dev.off() would open a brand new default device
    plot.par <- par(par.save)
    finished <- TRUE
    invisible(plot.par)
}



##### IN MCMC_utils.R
#####
#####samplesSummary <- function(samples) {
#####    cbind(
#####        `Mean`      = apply(samples, 2, mean),
#####        `Median`    = apply(samples, 2, median),
#####        `St.Dev.`   = apply(samples, 2, sd),
#####        `95%CI_low` = apply(samples, 2, function(x) quantile(x, 0.025)),
#####        `95%CI_upp` = apply(samples, 2, function(x) quantile(x, 0.975)))
#####}


## ## utility for plotting MCMC samples from multiple chains (one parameter only)
## samplesPlot2 <- function(samplesList, ind=1, burnin=NULL, legend=TRUE, legend.location='topright') {
## #  dev.new(height=height, width=width)
##   nChains <- length(samplesList)
##   par(mfrow=c(1,2), cex=0.7, cex.main=1.5, lab=c(3,3,7), mgp=c(0,0.6,0), mar=c(2,1,2,1), oma=c(0,0,0,0), tcl=-0.3, yaxt='n', bty='l')
##   samples <- samplesList[[1]][, ind, drop=FALSE]
##   if (nChains > 1)
##     for (chain in 2:nChains)
##       samples <- cbind(samples, samplesList[[chain]][, ind, drop=FALSE])
##   if(!is.null(burnin))
##     samples <- samples[(burnin+1):dim(samples)[1], , drop=FALSE]
##   colnames(samples) <- paste("Chain", 1:nChains)
##   nparam <- ncol(samples)
##   rng <- range(samples)
##   plot(1:nrow(samples), ylim=rng, type='n', main=paste('Traceplots', ind), xlab='', ylab='')
##   for(i in 1:nparam)
##     lines(samples[,i], col=rainbow(nparam, alpha=0.75)[i])
##   xMin <- xMax <- yMax <- NULL
##   for(i in 1:nparam) {
##     d <- density(samples[,i])
##     xMin <- min(xMin,d$x); xMax <- max(xMax,d$x); yMax <- max(yMax, d$y) }
##   plot(1, xlim=c(xMin,xMax), ylim=c(0,yMax), type='n', main=paste('Posterior Densities', ind), xlab='', ylab='')
##   alpha_density <- 0.2
##   for(i in 1:nparam)
##     polygon(density(samples[,i]), col=rainbow(nparam, alpha=alpha_density)[i], border=rainbow(nparam, alpha=alpha_density)[i])
##   if(legend & !is.null(dimnames(samples)) & is.character(dimnames(samples)[[2]]))
##     legend(legend=dimnames(samples)[[2]], fill=rainbow(nparam, alpha=0.5), bty='n', x=legend.location)
## }

## ## utility for plotting MCMC samples from multiple chains (also allows multiple parameters)
## samplesPlot3 <- function(samplesList, ind=1, burnin=NULL, legend=TRUE, legend.location='topright', 
##                          common.scale = TRUE, nCols = 2) {
##   #  dev.new(height=height, width=width)
##   nChains <- length(samplesList)
##   nPar <- length(ind)
##   if (common.scale && nPar > 1) {
##     samplesAll <- samplesList[[1]][, ind, drop=FALSE]
##     if (nChains > 1)
##       for (chain in 2:nChains)
##         samplesAll <- cbind(samplesAll, samplesList[[chain]][, ind, drop=FALSE])
##     if(!is.null(burnin))
##       samplesAll <- samplesAll[(burnin+1):dim(samplesAll)[1], , drop=FALSE]
##     rng <- range(samplesAll)
##     xMin <- xMax <- yMax <- NULL
##     for(i in 1:ncol(samplesAll)) {
##       d <- density(samplesAll[,i])
##       xMin <- min(xMin,d$x); xMax <- max(xMax,d$x); yMax <- max(yMax, d$y) 
##     }
##   }
##   nRows <- ceiling(nPar / nCols * 2)
##   par(mfrow=c(nRows, nCols), cex=0.7, cex.main=1, lab=c(3,3,7), mgp=c(0,0.6,0), mar=c(2,1,2,1), oma=c(0,0,0,0), tcl=-0.3, yaxt='n', bty='l')
##   for (par in ind) {
##     samples <- samplesList[[1]][, par, drop=FALSE]
##     if (nChains > 1)
##       for (chain in 2:nChains)
##         samples <- cbind(samples, samplesList[[chain]][, par, drop=FALSE])
##     if(!is.null(burnin))
##       samples <- samples[(burnin+1):dim(samples)[1], , drop=FALSE]
##     colnames(samples) <- paste("Chain", 1:nChains)
##     if (!common.scale || nPar == 1) {
##       rng <- range(samples)
##       xMin <- xMax <- yMax <- NULL
##       for(i in 1:nChains) {
##         d <- density(samples[,i])
##         xMin <- min(xMin,d$x); xMax <- max(xMax,d$x); yMax <- max(yMax, d$y) }
##     }
##     plot(1:nrow(samples), ylim=rng, type='n', main=paste('Traceplots', par), xlab='', ylab='')
##     for(i in 1:nChains)
##       lines(samples[,i], col=rainbow(nChains, alpha=0.75)[i])
##     plot(1, xlim=c(xMin,xMax), ylim=c(0,yMax), type='n', main=paste('Posterior Densities', par), xlab='', ylab='')
##     alpha_density <- 0.2
##     for(i in 1:nChains)
##       polygon(density(samples[,i]), col=rainbow(nChains, alpha=alpha_density)[i], border=rainbow(nChains, alpha=alpha_density)[i])
##     if(legend & !is.null(dimnames(samples)) & is.character(dimnames(samples)[[2]]))
##       legend(legend=dimnames(samples)[[2]], fill=rainbow(nChains, alpha=0.5), bty='n', x=legend.location)
##   }
## }


