/*
 * Decompiled with CFR 0.152.
 */
package org.extratrees;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import org.extratrees.AbstractTrees;
import org.extratrees.FactorBinaryTree;
import org.extratrees.Matrix;
import org.extratrees.TaskCutResult;

/**
 * Ensemble of randomized binary trees whose outputs are factors (non-negative
 * integer class labels).
 *
 * <p>Splits are scored by the size-weighted Gini impurity of the two children.
 * Predictions are a majority vote over the trees, with ties going to the lowest
 * class index. Optional per-row task ids enable task cuts ("multitask" learning),
 * which are only implemented for at most two classes.
 */
public class FactorExtraTrees
extends AbstractTrees<FactorBinaryTree> {
    /**
     * Gini impurity below which a child is treated as constant (single class).
     * The value is kept bit-for-bit as it appeared in the original code.
     */
    private static final double PURE_GINI_THRESHOLD = 9.999999999999998E-15;

    /** Class label of each training row; every value is in {@code [0, nFactors)}. */
    int[] output;
    /**
     * Number of classes. The constructor sets it to the largest label + 1 (at least 1);
     * it can be overridden with {@link #setnFactors(int)}.
     */
    int nFactors;

    /**
     * Creates an ensemble without task information.
     *
     * @param matrix training inputs, one row per data point
     * @param nArray class label of each row; must be non-negative
     * @throws IllegalArgumentException if the lengths disagree or a label is negative
     */
    public FactorExtraTrees(Matrix matrix, int[] nArray) {
        this(matrix, nArray, null);
    }

    /**
     * Creates an ensemble, optionally with a task id for each data point.
     *
     * @param matrix  training inputs, one row per data point
     * @param nArray  class label of each row; must be non-negative
     * @param nArray2 task id of each row, or {@code null} for no task information
     * @throws IllegalArgumentException if {@code nArray} or {@code nArray2} does not have
     *     {@code matrix.nrows} entries, or if a label is negative
     */
    public FactorExtraTrees(Matrix matrix, int[] nArray, int[] nArray2) {
        if (matrix.nrows != nArray.length) {
            throw new IllegalArgumentException("Input and output do not have same length.");
        }
        if (nArray2 != null && matrix.nrows != nArray2.length) {
            throw new IllegalArgumentException("Input and tasks do not have the same number of data points.");
        }
        this.input = matrix;
        this.output = nArray;
        // nFactors = max(label) + 1, but never less than 1 (e.g. for an empty output).
        this.nFactors = 1;
        for (int label : nArray) {
            if (label < 0) {
                // IllegalArgumentException extends RuntimeException, so existing handlers still match.
                throw new IllegalArgumentException("Bug: negative output (factor) values.");
            }
            if (label >= this.nFactors) {
                this.nFactors = label + 1;
            }
        }
        this.setTasks(nArray2);
        // Every input column is a candidate split column.
        this.cols = new ArrayList<Integer>(matrix.ncols);
        for (int col = 0; col < matrix.ncols; ++col) {
            this.cols.add(col);
        }
    }

    /**
     * Returns the number of classes.
     *
     * @return the number of classes used for voting and impurity counts
     */
    public int getnFactors() {
        return this.nFactors;
    }

    /**
     * Overrides the number of classes.
     *
     * @param n new number of classes; should exceed every label the trees can predict
     */
    public void setnFactors(int n) {
        this.nFactors = n;
    }

    /**
     * Returns a new ensemble over the same training data that holds only the selected trees.
     *
     * @param blArray {@code blArray[i]} keeps tree {@code i}; must not be longer than the
     *     number of trees
     * @return the new ensemble; the tree objects are shared, not copied
     */
    public FactorExtraTrees selectTrees(boolean[] blArray) {
        FactorExtraTrees selected = new FactorExtraTrees(this.input, this.output, this.tasks);
        // The constructor recomputes nFactors from the labels. Carry over any value set via
        // setnFactors so the subset votes over the same number of classes as this ensemble.
        selected.nFactors = this.nFactors;
        selected.trees = new ArrayList<FactorBinaryTree>();
        for (int i = 0; i < blArray.length; ++i) {
            if (blArray[i]) {
                selected.trees.add(this.trees.get(i));
            }
        }
        return selected;
    }

    /**
     * Predicts one data point by majority vote over the given trees.
     *
     * @param arrayList trees to vote with
     * @param dArray    input features of the data point
     * @param n         number of classes; every tree prediction must be in {@code [0, n)}
     * @return the most voted class (lowest index on ties), or -1 if {@code n == 0}
     */
    public static int getValue(ArrayList<FactorBinaryTree> arrayList, double[] dArray, int n) {
        int[] votes = new int[n];
        for (FactorBinaryTree tree : arrayList) {
            votes[tree.getValue(dArray)]++;
        }
        return FactorExtraTrees.getMaxIndex(votes);
    }

    /**
     * Returns the index of the first maximal element.
     *
     * @param nArray values to scan
     * @return index of the first occurrence of the maximum, or -1 if {@code nArray} is empty
     */
    public static int getMaxIndex(int[] nArray) {
        int bestIndex = -1;
        int bestValue = Integer.MIN_VALUE;
        for (int i = 0; i < nArray.length; ++i) {
            // Strict '>' keeps the earliest index on ties.
            if (nArray[i] > bestValue) {
                bestValue = nArray[i];
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Returns each tree's individual prediction for every row.
     *
     * @param matrix inputs, one row per data point
     * @return a {@code matrix.nrows x trees.size()} matrix whose entry (i, j) is the class
     *     predicted by tree j for row i
     */
    public Matrix getAllValues(Matrix matrix) {
        int nTrees = this.trees.size();
        Matrix values = new Matrix(matrix.nrows, nTrees);
        double[] row = new double[matrix.ncols];
        for (int i = 0; i < matrix.nrows; ++i) {
            matrix.copyRow(i, row);
            for (int j = 0; j < nTrees; ++j) {
                values.set(i, j, this.trees.get(j).getValue(row));
            }
        }
        return values;
    }

    /**
     * Task-aware variant of {@link #getAllValues(Matrix)}.
     *
     * @param matrix inputs, one row per data point
     * @param nArray task id of each row
     * @return a {@code matrix.nrows x trees.size()} matrix of per-tree predictions
     * @throws IllegalArgumentException if {@code nArray.length != matrix.nrows}
     */
    public Matrix getAllValuesMT(Matrix matrix, int[] nArray) {
        if (matrix.nrows != nArray.length) {
            throw new IllegalArgumentException("Inputs and tasks do not have the same length.");
        }
        int nTrees = this.trees.size();
        Matrix values = new Matrix(matrix.nrows, nTrees);
        double[] row = new double[matrix.ncols];
        for (int i = 0; i < matrix.nrows; ++i) {
            matrix.copyRow(i, row);
            for (int j = 0; j < nTrees; ++j) {
                values.set(i, j, this.trees.get(j).getValueMT(row, nArray[i]));
            }
        }
        return values;
    }

    /**
     * Predicts every row by majority vote over this ensemble.
     *
     * @param matrix inputs, one row per data point
     * @return predicted class of each row
     */
    public int[] getValues(Matrix matrix) {
        return FactorExtraTrees.getValues(this.trees, matrix, this.nFactors);
    }

    /**
     * Predicts every row by majority vote over the given trees.
     *
     * @param arrayList trees to vote with
     * @param matrix    inputs, one row per data point
     * @param n         number of classes
     * @return predicted class of each row
     */
    public static int[] getValues(ArrayList<FactorBinaryTree> arrayList, Matrix matrix, int n) {
        int[] predictions = new int[matrix.nrows];
        double[] row = new double[matrix.ncols];
        for (int i = 0; i < matrix.nrows; ++i) {
            FactorExtraTrees.readRow(matrix, i, row);
            predictions[i] = FactorExtraTrees.getValue(arrayList, row, n);
        }
        return predictions;
    }

    /**
     * Task-aware majority-vote prediction for every row.
     *
     * @param matrix inputs, one row per data point
     * @param nArray task id of each row
     * @return predicted class of each row
     * @throws IllegalArgumentException if {@code nArray.length != matrix.nrows}
     */
    public int[] getValuesMT(Matrix matrix, int[] nArray) {
        // Same precondition (and message) as getAllValuesMT.
        if (matrix.nrows != nArray.length) {
            throw new IllegalArgumentException("Inputs and tasks do not have the same length.");
        }
        int[] predictions = new int[matrix.nrows];
        double[] row = new double[matrix.ncols];
        for (int i = 0; i < matrix.nrows; ++i) {
            FactorExtraTrees.readRow(matrix, i, row);
            predictions[i] = this.getValueMT(row, nArray[i]);
        }
        return predictions;
    }

    /** Copies row {@code row} of {@code matrix} into {@code buffer} (length {@code matrix.ncols}). */
    private static void readRow(Matrix matrix, int row, double[] buffer) {
        for (int col = 0; col < matrix.ncols; ++col) {
            buffer[col] = matrix.get(row, col);
        }
    }

    /**
     * Task-aware majority-vote prediction for a single data point.
     *
     * @param dArray input features of the data point
     * @param n      task id of the data point
     * @return the most voted class (lowest index on ties)
     */
    public int getValueMT(double[] dArray, int n) {
        int[] votes = new int[this.nFactors];
        for (FactorBinaryTree tree : this.trees) {
            votes[tree.getValueMT(dArray, n)]++;
        }
        return FactorExtraTrees.getMaxIndex(votes);
    }

    /**
     * Computes the Gini impurity {@code 1 - sum(c_i^2) / (sum c_i)^2} of class counts.
     *
     * <p>The sums use {@code long}/{@code double} arithmetic, so large nodes (more than
     * about 46,340 samples) do not overflow {@code int}.
     *
     * @param nArray non-negative count of each class
     * @return the impurity, or {@code NaN} when all counts are zero (0/0)
     */
    public static double getGiniIndex(int[] nArray) {
        long sumOfSquares = 0L;
        long total = 0L;
        for (int count : nArray) {
            sumOfSquares += (long) count * count;
            total += count;
        }
        double totalSquared = (double) total * (double) total;
        return 1.0 - (double) sumOfSquares / totalSquared;
    }

    /**
     * Builds an internal split node.
     *
     * @param factorBinaryTree  left subtree
     * @param factorBinaryTree2 right subtree
     * @param n                 split column
     * @param d                 split threshold
     * @param n2                number of training rows under this node
     * @return the new node
     */
    @Override
    protected FactorBinaryTree makeFilledTree(FactorBinaryTree factorBinaryTree, FactorBinaryTree factorBinaryTree2, int n, double d, int n2) {
        FactorBinaryTree node = new FactorBinaryTree();
        node.column = n;
        node.threshold = d;
        node.nSuccessors = n2;
        node.left = factorBinaryTree;
        node.right = factorBinaryTree2;
        return node;
    }

    /**
     * Tries random thresholds on per-task scores and returns the best task cut found.
     *
     * @param nArray training rows in the current node
     * @param set    tasks present in the current node
     * @param d      score to beat; only cuts with a strictly lower score are returned
     * @return the best task cut, or {@code null} if none beats {@code d} or fewer than two
     *     tasks are present
     * @throws RuntimeException if there are more than two classes
     */
    @Override
    protected TaskCutResult getTaskCut(int[] nArray, Set<Integer> set, double d) {
        if (this.nFactors > 2) {
            throw new RuntimeException("Multitask learning is not implemented 3 or more factors (classes).");
        }
        if (set.size() <= 1) {
            return null;
        }
        int[][] table = this.getFactorTaskTable(nArray);
        // Cheap early exit first; getTaskScores has no side effects, so ordering is free.
        if (!this.hasAtLeast2Tasks(set, table)) {
            return null;
        }
        double[] taskScores = this.getTaskScores(table);
        // NOTE(review): getRange presumably returns {min, max} of the scores — confirm in AbstractTrees.
        double[] range = this.getRange(taskScores);
        double bestScore = d;
        TaskCutResult best = null;
        for (int i = 0; i < this.numRandomTaskCuts; ++i) {
            double threshold = this.getRandom(range[0], range[1]);
            TaskCutResult candidate = new TaskCutResult();
            this.calculateTaskCutScore(taskScores, table, threshold, candidate);
            // A NaN score (an empty side) never compares lower, so it is never chosen.
            if (candidate.score < bestScore) {
                best = candidate;
                bestScore = candidate.score;
            }
        }
        return best;
    }

    /**
     * Checks whether at least two tasks in {@code set} have any rows in class 0 or class 1.
     *
     * @param set    candidate task ids
     * @param nArray class-by-task count table with at least two rows
     * @return {@code true} once a second non-empty task is found
     */
    protected boolean hasAtLeast2Tasks(Set<Integer> set, int[][] nArray) {
        boolean seenOne = false;
        for (int task : set) {
            boolean nonEmpty = nArray[0][task] > 0 || nArray[1][task] > 0;
            if (!nonEmpty) {
                continue;
            }
            if (seenOne) {
                return true;
            }
            seenOne = true;
        }
        return false;
    }

    /**
     * Counts rows by (class, task).
     *
     * @param nArray training rows to count
     * @return table {@code t} with {@code t[class][task]} = number of rows with that label and task
     */
    private int[][] getFactorTaskTable(int[] nArray) {
        // new int[a][b] already allocates every zeroed row.
        int[][] table = new int[this.nFactors][this.nTasks];
        for (int row : nArray) {
            table[this.output[row]][this.tasks[row]]++;
        }
        return table;
    }

    /**
     * Computes a smoothed fraction of class-0 rows for each task.
     *
     * <p>Each task's score is {@code (n0 + prior * w) / (n0 + n1 + w)} with weight
     * {@code w = 1}. The prior is the overall class-0 fraction with add-one smoothing,
     * {@code (N0 + 1) / (N0 + N1 + 2)}.
     *
     * @param nArray class-by-task count table with exactly two class rows used
     * @return one score per task
     */
    private double[] getTaskScores(int[][] nArray) {
        double priorWeight = 1.0;
        double[] classTotals = FactorExtraTrees.sumAlong2nd(nArray);
        double priorMass = (classTotals[0] + 1.0) / (classTotals[0] + classTotals[1] + 2.0) * priorWeight;
        double[] scores = new double[this.nTasks];
        for (int task = 0; task < this.nTasks; ++task) {
            int class0 = nArray[0][task];
            int class1 = nArray[1][task];
            scores[task] = ((double) class0 + priorMass) / ((double) (class0 + class1) + priorWeight);
        }
        return scores;
    }

    /**
     * Sums each of the first two rows of a table.
     *
     * @param nArray table with at least two rows; the column count is taken from row 0
     * @return {@code {sum(row 0), sum(row 1)}}
     */
    public static double[] sumAlong2nd(int[][] nArray) {
        double[] totals = new double[2];
        for (int col = 0; col < nArray[0].length; ++col) {
            totals[0] += (double) nArray[0][col];
            totals[1] += (double) nArray[1][col];
        }
        return totals;
    }

    /**
     * Scores splitting the given rows on column {@code n} at threshold {@code d}.
     * Rows with value {@code < d} go left; the rest go right.
     *
     * @param nArray    training rows in the current node
     * @param n         split column
     * @param d         split threshold
     * @param cutResult receives the counts, weighted Gini score and purity flags
     */
    @Override
    protected void calculateCutScore(int[] nArray, int n, double d, AbstractTrees.CutResult cutResult) {
        int[] leftCounts = new int[this.nFactors];
        int[] rightCounts = new int[this.nFactors];
        for (int row : nArray) {
            int label = this.output[row];
            if (this.input.get(row, n) < d) {
                leftCounts[label]++;
            } else {
                rightCounts[label]++;
            }
        }
        this.cutResultFromCounts(cutResult, leftCounts, rightCounts);
    }

    /**
     * Scores splitting tasks by their score: tasks with score {@code < d} go left, the
     * rest go right. Records which tasks ended up on each side.
     *
     * @param dArray        per-task scores
     * @param nArray        class-by-task count table
     * @param d             threshold on the task scores
     * @param taskCutResult receives the task sets, counts, score and purity flags
     */
    private void calculateTaskCutScore(double[] dArray, int[][] nArray, double d, TaskCutResult taskCutResult) {
        int[] leftCounts = new int[this.nFactors];
        int[] rightCounts = new int[this.nFactors];
        taskCutResult.leftTasks = new HashSet<Integer>();
        taskCutResult.rightTasks = new HashSet<Integer>();
        int nTaskColumns = nArray[0].length;
        for (int task = 0; task < nTaskColumns; ++task) {
            boolean goesLeft = dArray[task] < d;
            int[] side = goesLeft ? leftCounts : rightCounts;
            for (int f = 0; f < this.nFactors; ++f) {
                side[f] += nArray[f][task];
            }
            if (goesLeft) {
                taskCutResult.leftTasks.add(task);
            } else {
                taskCutResult.rightTasks.add(task);
            }
        }
        this.cutResultFromCounts(taskCutResult, leftCounts, rightCounts);
    }

    /**
     * Fills a cut result from per-class counts on each side.
     *
     * <p>The score is the size-weighted mean Gini impurity of the two sides (lower is
     * better). If a side is empty its Gini is NaN, which makes the score NaN.
     *
     * @param cutResult result to fill
     * @param nArray    per-class counts of the left side
     * @param nArray2   per-class counts of the right side
     */
    private void cutResultFromCounts(AbstractTrees.CutResult cutResult, int[] nArray, int[] nArray2) {
        double leftGini = FactorExtraTrees.getGiniIndex(nArray);
        double rightGini = FactorExtraTrees.getGiniIndex(nArray2);
        cutResult.countLeft = FactorExtraTrees.sum(nArray);
        cutResult.countRight = FactorExtraTrees.sum(nArray2);
        double weighted = leftGini * (double) cutResult.countLeft + rightGini * (double) cutResult.countRight;
        cutResult.score = weighted / (double) (cutResult.countLeft + cutResult.countRight);
        // A side is "constant" when effectively all of its rows share one class.
        cutResult.leftConst = leftGini < PURE_GINI_THRESHOLD;
        cutResult.rightConst = rightGini < PURE_GINI_THRESHOLD;
    }

    /**
     * Builds a leaf that predicts the majority class of the given rows.
     *
     * @param nArray training rows reaching this leaf
     * @param set    tasks present at this leaf
     * @return the leaf; with no rows it predicts class 0
     */
    @Override
    public FactorBinaryTree makeLeaf(int[] nArray, Set<Integer> set) {
        FactorBinaryTree leaf = new FactorBinaryTree();
        leaf.nSuccessors = nArray.length;
        leaf.tasks = set;
        int[] counts = new int[this.nFactors];
        for (int row : nArray) {
            counts[this.output[row]]++;
        }
        leaf.value = FactorExtraTrees.getMaxIndex(counts);
        return leaf;
    }
}

