/*
 * Decompiled with CFR 0.152.
 */
package moa.tasks;

import com.github.javacliparser.FileOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import java.io.File;
import java.io.FileOutputStream;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Random;
import moa.classifiers.MultiClassClassifier;
import moa.core.Example;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.core.ObjectRepository;
import moa.core.TimingUtils;
import moa.evaluation.LearningEvaluation;
import moa.evaluation.LearningPerformanceEvaluator;
import moa.evaluation.preview.LearningCurve;
import moa.learners.Learner;
import moa.options.ClassOption;
import moa.streams.ExampleStream;
import moa.tasks.ClassificationMainTask;
import moa.tasks.TaskMonitor;

/**
 * Prequential (test-then-train) evaluation task with delayed labelling, run over
 * several learner copies ("folds") in parallel.
 *
 * <p>Every incoming instance is first used to test every fold. Depending on the
 * chosen validation methodology it is then appended to the pending-training queue
 * of some of the folds. A fold trains on the oldest queued instance only once its
 * queue holds more than {@code delay} instances, so the delay is counted in
 * instances queued for that fold.
 *
 * <p>The result is a {@link LearningCurve} whose entries contain the number of
 * processed instances, the elapsed evaluation time, the accumulated RAM-Hours and
 * the per-fold evaluator measurements averaged across folds.
 */
public class EvaluatePrequentialDelayedCV
extends ClassificationMainTask {
    private static final long serialVersionUID = 1L;

    // Indices into validationMethodologyOption's choice list.
    private static final int CROSS_VALIDATION = 0;
    private static final int BOOTSTRAP_VALIDATION = 1;
    private static final int SPLIT_VALIDATION = 2;

    // Divisor used by the RAM-Hours computation (2^30 bytes).
    private static final double BYTES_PER_GIGABYTE = 1.073741824E9;

    public ClassOption learnerOption = new ClassOption("learner", 'l', "Learner to train.", MultiClassClassifier.class, "moa.classifiers.bayes.NaiveBayes");
    public ClassOption streamOption = new ClassOption("stream", 's', "Stream to learn from.", ExampleStream.class, "generators.RandomTreeGenerator");
    public ClassOption evaluatorOption = new ClassOption("evaluator", 'e', "Classification performance evaluation method.", LearningPerformanceEvaluator.class, "WindowClassificationPerformanceEvaluator");
    public IntOption delayLengthOption = new IntOption("delay", 'k', "Number of instances before test instance is used for training", 1000, 1, Integer.MAX_VALUE);
    public IntOption instanceLimitOption = new IntOption("instanceLimit", 'i', "Maximum number of instances to test/train on  (-1 = no limit).", 100000000, -1, Integer.MAX_VALUE);
    public IntOption timeLimitOption = new IntOption("timeLimit", 't', "Maximum number of seconds to test/train for (-1 = no limit).", -1, -1, Integer.MAX_VALUE);
    public IntOption sampleFrequencyOption = new IntOption("sampleFrequency", 'f', "How many instances between samples of the learning performance.", 100000, 0, Integer.MAX_VALUE);
    public IntOption memCheckFrequencyOption = new IntOption("memCheckFrequency", 'q', "How many instances between memory bound checks.", 100000, 0, Integer.MAX_VALUE);
    public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to.", null, "csv", true);
    public IntOption numFoldsOption = new IntOption("numFolds", 'w', "The number of folds (e.g. distributed models) to be used.", 10, 1, Integer.MAX_VALUE);
    public MultiChoiceOption validationMethodologyOption = new MultiChoiceOption("validationMethodology", 'a', "Validation methodology to use.", new String[]{"Cross-Validation", "Bootstrap-Validation", "Split-Validation"}, new String[]{"k-fold distributed Cross Validation", "k-fold distributed Bootstrap Validation", "k-fold distributed Split Validation"}, 0);
    public IntOption randomSeedOption = new IntOption("randomSeed", 'r', "Seed for random behaviour of the task.", 1);

    /** One FIFO queue of instances awaiting (delayed) training per fold. */
    protected LinkedList<LinkedList<Example>> trainInstances;

    @Override
    public String getPurposeString() {
        return "Evaluates a classifier using delayed cross-validation evaluation by testing and only training with the example after the arrival of other k examples (delayed labeling) ";
    }

    @Override
    public Class<?> getTaskResultType() {
        return LearningCurve.class;
    }

    /**
     * Runs the delayed prequential evaluation.
     *
     * @param monitor    receives progress updates and is polled for abort and
     *                   preview requests every 10 processed instances
     * @param repository not used by this task
     * @return the accumulated {@link LearningCurve}, or {@code null} if the monitor
     *         requested an abort
     * @throws RuntimeException if the dump file cannot be opened
     */
    @Override
    protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) {
        Random random = new Random(this.randomSeedOption.getValue());
        ExampleStream stream = (ExampleStream)this.getPreparedClassOption(this.streamOption);

        int numFolds = this.numFoldsOption.getValue();
        Learner[] learners = new Learner[numFolds];
        Learner baseLearner = (Learner)this.getPreparedClassOption(this.learnerOption);
        baseLearner.resetLearning();
        LearningPerformanceEvaluator[] evaluators = new LearningPerformanceEvaluator[numFolds];
        LearningPerformanceEvaluator baseEvaluator = (LearningPerformanceEvaluator)this.getPreparedClassOption(this.evaluatorOption);
        for (int fold = 0; fold < numFolds; ++fold) {
            learners[fold] = (Learner)baseLearner.copy();
            learners[fold].setModelContext(stream.getHeader());
            evaluators[fold] = (LearningPerformanceEvaluator)baseEvaluator.copy();
        }

        LearningCurve learningCurve = new LearningCurve("learning evaluation instances");
        int maxInstances = this.instanceLimitOption.getValue();
        int maxSeconds = this.timeLimitOption.getValue();
        // Option values are read once; they are not expected to change mid-run.
        int sampleFrequency = this.sampleFrequencyOption.getValue();
        int delayLength = this.delayLengthOption.getValue();
        int methodology = this.validationMethodologyOption.getChosenIndex();
        long instancesProcessed = 0L;
        int secondsElapsed = 0;
        monitor.setCurrentActivity("Evaluating learner...", -1.0);

        this.trainInstances = new LinkedList<LinkedList<Example>>();
        for (int fold = 0; fold < numFolds; ++fold) {
            this.trainInstances.add(new LinkedList<Example>());
        }

        PrintStream immediateResultStream = openDumpStream(this.dumpFileOption.getFile());
        try {
            boolean firstDump = true;
            boolean preciseCPUTiming = TimingUtils.enablePreciseTiming();
            long evaluateStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
            long lastEvaluateStartTime = evaluateStartTime;
            double RAMHours = 0.0;

            while (stream.hasMoreInstances()
                    && (maxInstances < 0 || instancesProcessed < (long)maxInstances)
                    && (maxSeconds < 0 || secondsElapsed < maxSeconds)) {
                Example instance = (Example)stream.nextInstance();
                ++instancesProcessed;

                for (int fold = 0; fold < numFolds; ++fold) {
                    // Test first ...
                    double[] prediction = learners[fold].getVotesForInstance(instance);
                    evaluators[fold].addResult(instance, prediction);

                    // ... then queue the instance for this fold if the methodology selects it.
                    // A positive weight queues the instance exactly once, whatever its magnitude.
                    LinkedList<Example> foldQueue = this.trainInstances.get(fold);
                    if (trainingWeight(methodology, instancesProcessed, fold, numFolds, random) > 0) {
                        foldQueue.addLast(instance);
                    }

                    // Train on the oldest queued instance once the queue exceeds the delay.
                    if (foldQueue.size() > delayLength) {
                        learners[fold].trainOnInstance(foldQueue.removeFirst());
                    }
                }

                // A sample frequency of 0 disables periodic sampling (and avoids a modulo by
                // zero); a sample is still taken when the stream runs out.
                boolean periodicSample = sampleFrequency > 0 && instancesProcessed % (long)sampleFrequency == 0L;
                if (periodicSample || !stream.hasMoreInstances()) {
                    long evaluateTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
                    double time = TimingUtils.nanoTimeToSeconds(evaluateTime - evaluateStartTime);
                    double timeIncrement = TimingUtils.nanoTimeToSeconds(evaluateTime - lastEvaluateStartTime);
                    for (int fold = 0; fold < numFolds; ++fold) {
                        double modelGigabytes = (double)learners[fold].measureByteSize() / BYTES_PER_GIGABYTE;
                        RAMHours += modelGigabytes * (timeIncrement / 3600.0);
                    }
                    lastEvaluateStartTime = evaluateTime;
                    learningCurve.insertEntry(new LearningEvaluation(this.getEvaluationMeasurements(new Measurement[]{
                            new Measurement("learning evaluation instances", instancesProcessed),
                            new Measurement("evaluation time (" + (preciseCPUTiming ? "cpu " : "") + "seconds)", time),
                            new Measurement("model cost (RAM-Hours)", RAMHours)}, evaluators)));
                    if (immediateResultStream != null) {
                        if (firstDump) {
                            immediateResultStream.println(learningCurve.headerToString());
                            firstDump = false;
                        }
                        immediateResultStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1));
                        immediateResultStream.flush();
                    }
                }

                // Monitor interaction and the time-limit clock are refreshed every 10 instances.
                if (instancesProcessed % 10L != 0L) {
                    continue;
                }
                if (monitor.taskShouldAbort()) {
                    return null;
                }
                long estimatedRemainingInstances = stream.estimatedRemainingInstances();
                if (maxInstances > 0) {
                    long maxRemaining = (long)maxInstances - instancesProcessed;
                    if (estimatedRemainingInstances < 0L || maxRemaining < estimatedRemainingInstances) {
                        estimatedRemainingInstances = maxRemaining;
                    }
                }
                monitor.setCurrentActivityFractionComplete(estimatedRemainingInstances < 0L
                        ? -1.0
                        : (double)instancesProcessed / (double)(instancesProcessed + estimatedRemainingInstances));
                if (monitor.resultPreviewRequested()) {
                    monitor.setLatestResultPreview(learningCurve.copy());
                }
                secondsElapsed = (int)TimingUtils.nanoTimeToSeconds(TimingUtils.getNanoCPUTimeOfCurrentThread() - evaluateStartTime);
            }
            return learningCurve;
        }
        finally {
            // Also reached on abort and on exceptions, so the dump file is never leaked.
            if (immediateResultStream != null) {
                immediateResultStream.close();
            }
        }
    }

    /**
     * Opens the dump file for appending with auto-flush enabled.
     *
     * @param dumpFile target file, or {@code null} when no dump was requested
     * @return an appending stream, or {@code null} if {@code dumpFile} is {@code null}
     * @throws RuntimeException if the file cannot be opened
     */
    private static PrintStream openDumpStream(File dumpFile) {
        if (dumpFile == null) {
            return null;
        }
        try {
            // Append mode creates the file when it does not exist yet.
            return new PrintStream(new FileOutputStream(dumpFile, true), true);
        }
        catch (Exception ex) {
            throw new RuntimeException("Unable to open immediate result file: " + dumpFile, ex);
        }
    }

    /**
     * Decides whether a fold should queue the current instance for training.
     *
     * <ul>
     *   <li>Cross-Validation: every fold except the one whose index equals
     *       {@code instancesProcessed % numFolds}.</li>
     *   <li>Bootstrap-Validation: a {@code MiscUtils.poisson(1.0, random)} draw per fold.</li>
     *   <li>Split-Validation: only the fold whose index equals
     *       {@code instancesProcessed % numFolds}.</li>
     * </ul>
     *
     * @return a value greater than 0 if the fold should queue the instance; 1 for an
     *         unrecognised methodology index
     */
    private static int trainingWeight(int methodology, long instancesProcessed, int fold, int numFolds, Random random) {
        switch (methodology) {
            case CROSS_VALIDATION: {
                return instancesProcessed % (long)numFolds == (long)fold ? 0 : 1;
            }
            case BOOTSTRAP_VALIDATION: {
                return MiscUtils.poisson(1.0, random);
            }
            case SPLIT_VALIDATION: {
                return instancesProcessed % (long)numFolds == (long)fold ? 1 : 0;
            }
            default: {
                return 1;
            }
        }
    }

    /**
     * Builds the measurement row for one learning-curve entry.
     *
     * @param modelMeasurements measurements placed first in the result, unchanged;
     *                          may be {@code null}
     * @param subEvaluators     per-fold evaluators whose performance measurements are
     *                          combined via {@code Measurement.averageMeasurements};
     *                          {@code null} entries are skipped, and a {@code null} or
     *                          empty array contributes nothing
     * @return {@code modelMeasurements} followed by the averaged evaluator measurements
     */
    public Measurement[] getEvaluationMeasurements(Measurement[] modelMeasurements, LearningPerformanceEvaluator[] subEvaluators) {
        LinkedList<Measurement> measurementList = new LinkedList<Measurement>();
        if (modelMeasurements != null) {
            measurementList.addAll(Arrays.asList(modelMeasurements));
        }
        if (subEvaluators != null && subEvaluators.length > 0) {
            LinkedList<Measurement[]> subMeasurements = new LinkedList<Measurement[]>();
            for (LearningPerformanceEvaluator subEvaluator : subEvaluators) {
                if (subEvaluator == null) {
                    continue;
                }
                subMeasurements.add(subEvaluator.getPerformanceMeasurements());
            }
            Measurement[][] perFold = subMeasurements.toArray(new Measurement[subMeasurements.size()][]);
            measurementList.addAll(Arrays.asList(Measurement.averageMeasurements(perFold)));
        }
        return measurementList.toArray(new Measurement[measurementList.size()]);
    }
}

