#!/usr/bin/perl

use Modern::Perl '2018';
use Fatal qw(open);

# Output paths below are relative, so insist on being run from the
# directory that contains tools/genStan.
die "Must be run from top level" unless -e './tools/genStan';

# Every generated .stan file is written under this directory.
my $dest = 'src/stan_files';

# Stan functions block pulling in cmp_prob2.stan.  The output section
# prepends it (after the license include) to every model except
# unidim_adapt, which calls cmp_probs without an alpha argument and so
# includes cmp_prob1.stan instead.
my $itemModel = 'functions {
#include /functions/cmp_prob2.stan
}
';

# Dimension declarations placed at the top of every model's data block
# (both the unidimensional and the multivariate templates use it).
my $commonData =
'  // dimensions
  int<lower=1> NPA;             // number of players or objects or things
  int<lower=1> NCMP;            // number of unique comparisons
  int<lower=1> N;               // number of observations';

# Per-comparison response arrays, one element per unique comparison.
# The model templates use them as follows:
#   pick    - the observed response; rcat = pick + NTHRESH + 1 turns it into
#             a 1-based categorical outcome.
#   weight  - multiplier on the comparison's log-probability (and the number
#             of log_lik entries it expands to).
#   refresh - when non-zero, cmp_probs is recomputed; when zero, the
#             previous comparison's probability vector is reused.
my $responseData =
'  // response data
  int<lower=1, upper=NPA> pa1[NCMP];        // PA1 for observation N
  int<lower=1, upper=NPA> pa2[NCMP];        // PA2 for observation N
  int weight[NCMP];
  int pick[NCMP];
  int refresh[NCMP];';

# Pointwise log-likelihood for the unidimensional models, spliced into the
# generated quantities block.  $cmpAlpha is the same extra cmp_probs
# argument text the model block uses (', alpha' or ''), so log_lik always
# calls cmp_probs exactly as the model does.  Each comparison contributes
# weight[cmp] identical entries; log_lik is declared with N entries, so
# this presumes N == sum(weight).
sub mkUnidimLogLik {
    my ($cmpAlpha) = @_;
    return qq[  vector[NTHRESH*2 + 1] prob;
  vector[N] log_lik;
  int cur = 1;

  for (cmp in 1:NCMP) {
    real ll;
    if (refresh[cmp]) {
      prob = cmp_probs(scale$cmpAlpha, theta[pa1[cmp]], theta[pa2[cmp]], cumTh);
    }
    ll = categorical_lpmf(rcat[cmp] | prob);
    for (wx in 1:weight[cmp]) {
      log_lik[cur] = ll;
      cur = cur + 1;
    }
  }
];
}

# Priors shared across templates: $threshold_prior by mkUnidim,
# $multivariateThresholdPrior by the correlation and factor models, and
# $alpha_prior by the non-adaptive unidim model and the correlation model.
my $threshold_prior = 'threshold ~ normal(0, 2.0);';
my $multivariateThresholdPrior = "$threshold_prior";
my $alpha_prior = 'alpha ~ exponential(0.1);';

# Build the unidimensional Stan program (everything after the functions
# block).
#
# $adapt - when true, the data supplies varCorrection, a sigma parameter
#          replaces alpha, and scale is derived from sd(theta); cmp_probs
#          is then called without alpha.
# $ll    - when true, generated quantities also emits log_lik.
#
# Returns the program text as a string.
sub mkUnidim {
    my ($adapt, $ll) = @_;
    my $data = $adapt? "varCorrection" : "scale";
    my $par = $adapt? "sigma" : "alpha";
    my $scaleDef = $adapt? "
  real scale = (sd(theta) ^ varCorrection)/1.749;" : "";
    my $theta_sd = $adapt? "sigma" : "1.0";
    my $prior = $adapt? 'sigma ~ lognormal(1, 1);' : $alpha_prior;
    my $cmp_alpha = $adapt? '' : ', alpha';
    # Build log_lik from the same $cmp_alpha as the model block; a fixed
    # "scale, alpha" call would reference an undeclared alpha when $adapt.
    my $llBody = $ll? mkUnidimLogLik($cmp_alpha) : '';

    return qq[data {
$commonData
  int<lower=1> NTHRESH;         // number of thresholds
  real $data;
$responseData
}
transformed data {
  int rcat[NCMP];

  for (cmp in 1:NCMP) {
    rcat[cmp] = pick[cmp] + NTHRESH + 1;
  }
}
parameters {
  vector[NPA] theta;
  vector[NTHRESH] threshold;
  real<lower=0> $par;
}
transformed parameters {
  vector[NTHRESH] cumTh = cumulative_sum(threshold);$scaleDef
}
model {
  vector[NTHRESH*2 + 1] prob;
  $prior
  theta ~ normal(0, $theta_sd);
  $threshold_prior
  for (cmp in 1:NCMP) {
    if (refresh[cmp]) {
      prob = cmp_probs(scale$cmp_alpha, theta[pa1[cmp]], theta[pa2[cmp]], cumTh);
    }
    if (weight[cmp] == 1) {
      target += categorical_lpmf(rcat[cmp] | prob);
    } else {
      target += weight[cmp] * categorical_lpmf(rcat[cmp] | prob);
    }
  }
}
generated quantities {
  real thetaVar = variance(theta);
$llBody}
];
}

# Emit the data and transformed data blocks shared by the multivariate
# (correlation and factor) programs.
#
# $data      - extra declarations appended right after `scale` in data
# $tdataDecl - extra declarations appended right after `rcat` in
#              transformed data
# $tdata     - extra statements appended after the rcat loop
#
# Every argument is optional; a missing or undef one is treated as ''.
sub mvtDataCommon {
    my ($data, $tdataDecl, $tdata) = map { $_ // '' } @_[0 .. 2];

    return qq[data {
$commonData
  int<lower=1> NITEMS;
  int<lower=1> NTHRESH[NITEMS];         // number of thresholds
  int<lower=1> TOFFSET[NITEMS];
  vector[NITEMS] scale;$data
$responseData
  int item[NCMP];
}
transformed data {
  int totalThresholds = sum(NTHRESH);
  int rcat[NCMP];$tdataDecl
  for (cmp in 1:NCMP) {
    rcat[cmp] = pick[cmp] + NTHRESH[item[cmp]] + 1;
  }$tdata
}];
}

# Multivariate threshold storage: every item's thresholds live in one flat
# vector; item ix owns entries TOFFSET[ix] .. TOFFSET[ix] + NTHRESH[ix] - 1.
my $multivariateThresholdParam =
  'vector[totalThresholds] threshold;';

# cumTh holds the per-item cumulative sums of threshold.  The declaration
# and the loop that fills it are kept separate because declarations must
# precede statements in the transformed parameters block.
my $multivariateThresholdTParamDecl =
  'vector[totalThresholds] cumTh;';

my $multivariateThresholdTParam =
'for (ix in 1:NITEMS) {
    int from = TOFFSET[ix];
    int to = TOFFSET[ix] + NTHRESH[ix] - 1;
    cumTh[from:to] = cumulative_sum(threshold[from:to]);
  }';

# Model-block likelihood.  prob is sized for the item with the most
# thresholds and probSize marks how much of it the current item uses.
# When refresh[cmp] is 0 the previous prob/probSize are reused.  Each term
# is multiplied by weight[cmp].  Expects theta (indexed [player, item]) and
# alpha (indexed by item) to be in scope.
my $multivariateQuickLikelihoodDecl =
  'vector[max(NTHRESH)*2 + 1] prob;
  int probSize;';

my $multivariateQuickLikelihood =
'for (cmp in 1:NCMP) {
    if (refresh[cmp]) {
      int ix = item[cmp];
      int from = TOFFSET[ix];
      int to = TOFFSET[ix] + NTHRESH[ix] - 1;
      probSize = (2*NTHRESH[ix]+1);
      prob[:probSize] = cmp_probs(scale[ix], alpha[ix],
               theta[pa1[cmp], ix],
               theta[pa2[cmp], ix], cumTh[from:to]);
    }
    if (weight[cmp] == 1) {
      target += categorical_lpmf(rcat[cmp] | prob[:probSize]);
    } else {
      target += weight[cmp] * categorical_lpmf(rcat[cmp] | prob[:probSize]);
    }
  }';

# Generated-quantities counterpart: the same probability computation, but
# it writes one log_lik entry per unit of weight[cmp] instead of adding to
# target (log_lik is declared with N entries).
my $multivariateLikelihoodDecl =
  '
  vector[max(NTHRESH)*2 + 1] prob;
  int probSize;
  vector[N] log_lik;
  int cur = 1;
';

my $multivariateLikelihood = '

  for (cmp in 1:NCMP) {
    real ll;
    if (refresh[cmp]) {
      int ix = item[cmp];
      int from = TOFFSET[ix];
      int to = TOFFSET[ix] + NTHRESH[ix] - 1;
      probSize = (2*NTHRESH[ix]+1);
      prob[:probSize] = cmp_probs(scale[ix], alpha[ix],
               theta[pa1[cmp], ix],
               theta[pa2[cmp], ix], cumTh[from:to]);
    }
    ll = categorical_lpmf(rcat[cmp] | prob[:probSize]);
    for (wx in 1:weight[cmp]) {
      log_lik[cur] = ll;
      cur = cur + 1;
    }
  }';

# Build the correlation program.  Per-item abilities theta are a
# non-centred transform of rawTheta through a Cholesky correlation factor
# with an LKJ(2) prior.
#
# $ll - when true, generated quantities also emits log_lik.
#
# Returns the complete program text (data blocks included).
sub mkCorModel {
    my ($ll) = @_;
    my ($llDecl, $llBody) = $ll
        ? ($multivariateLikelihoodDecl, $multivariateLikelihood)
        : ('', '');

    my $program = qq[
parameters {
  $multivariateThresholdParam
  vector<lower=0>[NITEMS] alpha;
  matrix[NPA,NITEMS]      rawTheta;
  cholesky_factor_corr[NITEMS] rawThetaCorChol;
}
transformed parameters {
  $multivariateThresholdTParamDecl
  matrix[NPA,NITEMS]      theta;

  // non-centered parameterization due to thin data
  for (pa in 1:NPA) {
    theta[pa,] = (rawThetaCorChol * rawTheta[pa,]')';
  }
  $multivariateThresholdTParam
}
model {
  $multivariateQuickLikelihoodDecl

  rawThetaCorChol ~ lkj_corr_cholesky(2);
  for (pa in 1:NPA) {
    rawTheta[pa,] ~ normal(0,1);
  }
  $multivariateThresholdPrior
  $alpha_prior
  $multivariateQuickLikelihood
}
generated quantities {$llDecl
  corr_matrix[NITEMS] thetaCor;
  thetaCor = multiply_lower_tri_self_transpose(rawThetaCorChol);$llBody
}
];

    return mvtDataCommon() . $program;
}

# Build the factor-model program.  Each item's theta is a unique component
# (rawUniqueTheta * rawUnique) plus one loaded factor term per path in
# factorItemPath.
#
# $ll  - when true, generated quantities also emits log_lik.
# $psi - when true, the factors are correlated through a Psi correlation
#        matrix whose lower-triangular elements get priors scaled by
#        psiScalePrior; otherwise only rawFactor[,1] gets a prior.
#
# Returns the complete program text (data blocks included).
sub mkFactorModel {
    my ($ll, $psi) = @_;
    my $llDecl = $ll? $multivariateLikelihoodDecl : '';
    my $llBody = $ll? $multivariateLikelihood : '';

    my $psiData = '';
    my $psiParam = '';
    my $psiTparam = '';
    # NOTE(review): without Psi only rawFactor[,1] gets a prior, so any
    # further factor columns would have an improper flat prior; presumably
    # this variant (factor1*.stan) is only fit with NFACTORS == 1 -- confirm.
    my $latentPrior = '
  rawFactor[,1] ~ normal(0, 1);';
    if ($psi) {
        $psiData = '
  int<lower=1> NPSI;  // = NFACTORS * (NFACTORS-1) / 2;
  real psiScalePrior[NPSI];';
        $psiParam = '
  corr_matrix[NFACTORS] Psi;';
        $psiTparam = '
  cholesky_factor_corr[NFACTORS] CholPsi = cholesky_decompose(Psi);';
        # px is the `int px=1` declared at the top of the model block; it
        # walks psiScalePrior across the strictly-lower triangle of Psi.
        # A ~ statement must name the distribution without its _lpdf
        # suffix, otherwise stanc rejects the program.
        $latentPrior = '
  for (cx in 1:(NFACTORS-1)) {
    for (rx in (cx+1):NFACTORS) {
      target += normal_lpdf(logit(0.5 + Psi[rx,cx]/2.0) | 0, psiScalePrior[px]);
      px += 1;
    }
  }
  for (xx in 1:NPA) {
    rawFactor[xx,] ~ multi_normal_cholesky(rep_vector(0, NFACTORS), CholPsi);
  }';
    }

    # Extra data-block declarations describing the factor structure.
    my $factorData = "
  vector[NITEMS] alpha;
  int<lower=1> NFACTORS;
  real factorScalePrior[NFACTORS];$psiData
  int<lower=1> NPATHS;
  int factorItemPath[2,NPATHS];  // 1 is factor index, 2 is item index";
    my $factorTdataDecl = '
  vector[NPATHS] pathScalePrior;';
    # Validate every path and copy its factor's prior scale.  px is passed
    # to reject() as a value so the message names the offending path.
    my $factorTdata = '
  for (px in 1:NPATHS) {
    int fx = factorItemPath[1,px];
    int ix = factorItemPath[2,px];
    if (fx < 1 || fx > NFACTORS) {
      reject("factorItemPath[1,", px, "] names factor ", fx, " (NFACTORS=", NFACTORS, ")");
    }
    if (ix < 1 || ix > NITEMS) {
      reject("factorItemPath[2,", px, "] names item ", ix, " (NITEMS=", NITEMS, ")");
    }
    pathScalePrior[px] = factorScalePrior[fx];
  }';

    # In generated quantities, rawNegateFactor is initialised to 0 so a
    # factor that no path references is never negated.  Stan leaves ints
    # uninitialised (non-zero), so without this such a factor was flipped.
    return mvtDataCommon($factorData, $factorTdataDecl, $factorTdata).
    qq[
parameters {
  $multivariateThresholdParam$psiParam
  matrix[NPA,NFACTORS] rawFactor;      // do not interpret, see factor
  vector[NPATHS] rawLoadings; // do not interpret, see factorLoadings
  matrix[NPA,NITEMS] rawUniqueTheta; // do not interpret, see uniqueTheta
  vector[NITEMS] rawUnique;      // do not interpret, see unique
}
transformed parameters {
  $multivariateThresholdTParamDecl$psiTparam
  matrix[NPA,NITEMS] theta;
  vector[NPATHS] rawPathProp;  // always positive
  real rawPerComponentVar[NITEMS,1+NFACTORS];
  $multivariateThresholdTParam
  for (ix in 1:NITEMS) {
    theta[,ix] = rawUniqueTheta[,ix] * rawUnique[ix];
    rawPerComponentVar[ix, 1] = variance(theta[,ix]);
  }
  for (fx in 1:NFACTORS) {
    for (ix in 1:NITEMS) rawPerComponentVar[ix,1+fx] = 0;
  }
  for (px in 1:NPATHS) {
    int fx = factorItemPath[1,px];
    int ix = factorItemPath[2,px];
    vector[NPA] theta1 = rawLoadings[px] * rawFactor[,fx];
    rawPerComponentVar[ix,1+fx] = variance(theta1);
    theta[,ix] += theta1;
  }
  for (px in 1:NPATHS) {
    int fx = factorItemPath[1,px];
    int ix = factorItemPath[2,px];
    real resid = 0;
    real pred;
    for (cx in 1:(1+NFACTORS)) {
      if (cx == fx+1) {
        pred = rawPerComponentVar[ix,cx];
      } else {
        resid += rawPerComponentVar[ix,cx];
      }
    }
    rawPathProp[px] = pred / (pred + resid);
  }
}
model {
  $multivariateQuickLikelihoodDecl
  int px=1;

  $multivariateThresholdPrior$latentPrior
  rawLoadings ~ normal(0, 5.0);
  rawUnique ~ normal(0, 5.0);
  for (ix in 1:NITEMS) {
    rawUniqueTheta[,ix] ~ normal(0, 1.0);
  }
  $multivariateQuickLikelihood
  target += normal_lpdf(logit(0.5 + rawPathProp/2.0) | 0, pathScalePrior);
}
generated quantities {$llDecl
  vector[NPATHS] pathProp = rawPathProp;
  vector[NITEMS] sigma;
  vector[NPATHS] pathLoadings = rawLoadings;
  matrix[NPA,NFACTORS] factor = rawFactor;
  matrix[NPA,NITEMS] residual;
  matrix[NITEMS,NITEMS] residualItemCor;
  int rawSeenFactor[NFACTORS];
  int rawNegateFactor[NFACTORS];

  for (ix in 1:NITEMS) {
    residual[,ix] = rawUniqueTheta[,ix] * rawUnique[ix];
    residual[,ix] -= mean(residual[,ix]);
  }
  residualItemCor = crossprod(residual);
  residualItemCor = quad_form_diag(residualItemCor, 1.0 ./ sqrt(diagonal(residualItemCor)));

  for (fx in 1:NITEMS) {
    sigma[fx] = sd(theta[,fx]);
  }
  for (fx in 1:NFACTORS) {
    rawSeenFactor[fx] = 0;
    rawNegateFactor[fx] = 0;  // factors named by no path are never flipped
  }
  for (px in 1:NPATHS) {
    int fx = factorItemPath[1,px];
    int ix = factorItemPath[2,px];
    if (rawSeenFactor[fx] == 0) {
      rawSeenFactor[fx] = 1;
      rawNegateFactor[fx] = rawLoadings[px] < 0;
    }
    if (rawNegateFactor[fx]) {
      pathLoadings[px] = -pathLoadings[px];
    }
  }
  for (fx in 1:NFACTORS) {
    if (!rawNegateFactor[fx]) continue;
    factor[,fx] = -factor[,fx];
  }
  for (fx in 1:NPATHS) {
    if (pathLoadings[fx] < 0) pathProp[fx] = -pathProp[fx];
  }$llBody
}
];
}

# Write one generated Stan program to "$dest/$file": the license include,
# then the functions block, then the model text.  Uses 3-arg open so the
# path is never parsed for a mode.  print and close are checked because
# buffered write errors only surface at close; the file-level
# `use Fatal qw(open)` covers open.
sub writeStan {
    my ($file, $functions, $model) = @_;
    my $path = "$dest/$file";
    open my $fh, '>', $path;
    print {$fh} "#include /pre/license.stan\n", $functions, $model
        or die "write $path: $!";
    close $fh or die "close $path: $!";
    return;
}

{
    # unidim_adapt calls cmp_probs without alpha, so it includes
    # cmp_prob1.stan instead of the shared $itemModel functions block.
    my $adaptFunctions = 'functions {
#include /functions/cmp_prob1.stan
}
';

    # [ output file, functions block, program body ]
    my @programs = (
        [ 'unidim.stan',         $itemModel,      mkUnidim(0, 0) ],
        [ 'unidim_adapt.stan',   $adaptFunctions, mkUnidim(1, 0) ],
        [ 'unidim_ll.stan',      $itemModel,      mkUnidim(0, 1) ],
        [ 'correlation.stan',    $itemModel,      mkCorModel(0) ],
        [ 'correlation_ll.stan', $itemModel,      mkCorModel(1) ],
        [ 'factor1.stan',        $itemModel,      mkFactorModel(0, 0) ],
        [ 'factor1_ll.stan',     $itemModel,      mkFactorModel(1, 0) ],
        [ 'factor.stan',         $itemModel,      mkFactorModel(0, 1) ],
        [ 'factor_ll.stan',      $itemModel,      mkFactorModel(1, 1) ],
    );

    for my $program (@programs) {
        writeStan(@$program);
    }
}
