/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.bayes;

import weka.classifiers.Classifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

/**
 * Hidden Naive Bayes (HNB) classifier for nominal attributes and a nominal class.
 *
 * <p>For every non-class attribute (the "son") the class-conditional probability of its value
 * is estimated as a weighted average over all other non-class attributes (the "parents"). The
 * weight of a parent is the conditional mutual information between son and parent given the
 * class, estimated from the training data. See {@link #getTechnicalInformation()} for the
 * original publication.
 */
public class HNB
extends Classifier
implements TechnicalInformationHandler {
    static final long serialVersionUID = -4503874444306113214L;
    /** Number of training instances per class value. */
    private double[] m_ClassCounts;
    /**
     * Co-occurrence counts indexed by [class][flatValue1][flatValue2], where each value of each
     * non-class attribute is mapped to a flat index via {@link #m_StartAttIndex}. The diagonal
     * entry [c][v][v] is the number of class-c training instances having value v.
     */
    private double[][][] m_ClassAttAttCounts;
    /** Number of values of each attribute; for the class attribute, the number of classes. */
    private int[] m_NumAttValues;
    /** Total number of flat value indices over all non-class attributes. */
    private int m_TotalAttValues;
    /** Number of class values. */
    private int m_NumClasses;
    /** Number of attributes, including the class attribute. */
    private int m_NumAttributes;
    /** Number of training instances that have a class value. */
    private int m_NumInstances;
    /** Index of the class attribute. */
    private int m_ClassIndex;
    /** Flat index of each attribute's first value; -1 for the class attribute. */
    private int[] m_StartAttIndex;
    /** Parent weights: [son][parent] = conditional mutual information I(son; parent | class). */
    private double[][] m_condiMutualInfo;

    /**
     * Returns a short description of this classifier followed by its bibliographic reference.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Constructs Hidden Naive Bayes classification model with high classification accuracy and AUC.\n\nFor more information refer to:\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the HNB algorithm.
     *
     * @return the reference (Zhang, Jiang and Su, AAAI 2005)
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "H. Zhang and L. Jiang and J. Su");
        result.setValue(TechnicalInformation.Field.TITLE, "Hidden Naive Bayes");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Twentieth National Conference on Artificial Intelligence");
        result.setValue(TechnicalInformation.Field.YEAR, "2005");
        result.setValue(TechnicalInformation.Field.PAGES, "919-924");
        result.setValue(TechnicalInformation.Field.PUBLISHER, "AAAI Press");
        return result;
    }

    /**
     * Returns the data this classifier can train on: nominal attributes, a nominal class, and
     * instances whose class is missing (those are dropped in {@link #buildClassifier}). Missing
     * attribute values are not enabled.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Trains the model: collects class counts and per-class pairwise value co-occurrence
     * counts, then computes the conditional mutual information used as parent weights.
     *
     * @param instances the training data; not modified (a copy is used)
     * @throws Exception if the data does not satisfy {@link #getCapabilities()}
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().testWithFail(instances);
        instances = new Instances(instances);
        instances.deleteWithMissingClass();
        this.m_NumClasses = instances.numClasses();
        this.m_ClassIndex = instances.classIndex();
        this.m_NumAttributes = instances.numAttributes();
        this.m_NumInstances = instances.numInstances();

        // Map every value of every non-class attribute onto one contiguous flat index range.
        this.m_TotalAttValues = 0;
        this.m_StartAttIndex = new int[this.m_NumAttributes];
        this.m_NumAttValues = new int[this.m_NumAttributes];
        for (int att = 0; att < this.m_NumAttributes; att++) {
            if (att == this.m_ClassIndex) {
                this.m_StartAttIndex[att] = -1;
                this.m_NumAttValues[att] = this.m_NumClasses;
            } else {
                this.m_StartAttIndex[att] = this.m_TotalAttValues;
                this.m_NumAttValues[att] = instances.attribute(att).numValues();
                this.m_TotalAttValues += this.m_NumAttValues[att];
            }
        }

        // Count classes and, per class, every ordered pair of observed values (including each
        // value with itself, which fills the diagonal with single-value counts).
        this.m_ClassCounts = new double[this.m_NumClasses];
        this.m_ClassAttAttCounts = new double[this.m_NumClasses][this.m_TotalAttValues][this.m_TotalAttValues];
        int[] attIndex = new int[this.m_NumAttributes];
        for (int k = 0; k < this.m_NumInstances; k++) {
            Instance inst = instances.instance(k);
            int classVal = (int) inst.classValue();
            this.m_ClassCounts[classVal] += 1.0;
            for (int att = 0; att < this.m_NumAttributes; att++) {
                attIndex[att] = att == this.m_ClassIndex ? -1 : this.m_StartAttIndex[att] + (int) inst.value(att);
            }
            double[][] counts = this.m_ClassAttAttCounts[classVal];
            for (int att1 = 0; att1 < this.m_NumAttributes; att1++) {
                if (attIndex[att1] == -1) {
                    continue;
                }
                for (int att2 = 0; att2 < this.m_NumAttributes; att2++) {
                    if (attIndex[att2] != -1) {
                        counts[attIndex[att1]][attIndex[att2]] += 1.0;
                    }
                }
            }
        }

        // Parent weights. CMI is non-negative; clamping removes tiny negative rounding error so
        // the weighted-average factor in distributionForInstance is always positive.
        this.m_condiMutualInfo = new double[this.m_NumAttributes][this.m_NumAttributes];
        for (int son = 0; son < this.m_NumAttributes; son++) {
            if (son == this.m_ClassIndex) {
                continue;
            }
            for (int parent = 0; parent < this.m_NumAttributes; parent++) {
                if (parent != this.m_ClassIndex && parent != son) {
                    this.m_condiMutualInfo[son][parent] = Math.max(0.0, this.conditionalMutualInfo(son, parent));
                }
            }
        }
    }

    /**
     * Computes the conditional mutual information I(son; parent | class) in bits from the
     * training counts:
     * sum over c, j, k of P(c, j, k) * log2(P(c, j, k) * P(c) / (P(c, j) * P(c, k))),
     * which in counts is (N_cjk / N) * log2(N_cjk * N_c / (N_cj * N_ck)).
     *
     * @param son    index of the son attribute (not the class)
     * @param parent index of the parent attribute (not the class)
     * @return the conditional mutual information, or 0 if there were no training instances
     */
    private double conditionalMutualInfo(int son, int parent) {
        if (this.m_NumInstances == 0) {
            // No data: avoid 0/0 and treat the attributes as independent.
            return 0.0;
        }
        int sIndex = this.m_StartAttIndex[son];
        int pIndex = this.m_StartAttIndex[parent];
        double cmi = 0.0;
        for (int c = 0; c < this.m_NumClasses; c++) {
            double[][] counts = this.m_ClassAttAttCounts[c];
            double classCount = this.m_ClassCounts[c];
            for (int j = 0; j < this.m_NumAttValues[parent]; j++) {
                double parentCount = counts[pIndex + j][pIndex + j];
                for (int k = 0; k < this.m_NumAttValues[son]; k++) {
                    double jointCount = counts[pIndex + j][sIndex + k];
                    // A zero joint count contributes exactly 0 (limit p*log p -> 0). A positive
                    // joint count implies positive class, parent and son counts.
                    if (jointCount > 0.0) {
                        double sonCount = counts[sIndex + k][sIndex + k];
                        cmi += jointCount / this.m_NumInstances
                                * this.log2(jointCount * classCount, parentCount * sonCount);
                    }
                }
            }
        }
        return cmi;
    }

    /**
     * Returns log2(x / y), or 0 when either argument is non-positive.
     *
     * @param x numerator
     * @param y denominator
     * @return the base-2 logarithm of the ratio
     */
    private double log2(double x, double y) {
        if (x <= 0.0 || y <= 0.0) {
            return 0.0;
        }
        return Math.log(x / y) / Math.log(2.0);
    }

    /**
     * Computes the class membership probabilities for an instance.
     *
     * <p>Each class gets the smoothed prior times, for every non-missing non-class attribute,
     * the CMI-weighted average of its smoothed conditional probabilities given each other
     * non-missing attribute. If all weights are zero, the plain naive-Bayes estimate is used.
     * Scores are accumulated as logarithms. This is mathematically equivalent to multiplying
     * the factors, but cannot underflow to an all-zero vector when there are many attributes,
     * and such a vector cannot be normalized.
     *
     * @param instance the instance to classify; missing attribute values are ignored
     * @return the normalized class distribution
     * @throws Exception if the distribution cannot be normalized
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        // Flat value index per attribute; -1 excludes the class attribute and missing values
        // (casting a missing value, NaN, to int would otherwise select the first value).
        int[] attIndex = new int[this.m_NumAttributes];
        for (int att = 0; att < this.m_NumAttributes; att++) {
            attIndex[att] = att == this.m_ClassIndex || instance.isMissing(att)
                    ? -1
                    : this.m_StartAttIndex[att] + (int) instance.value(att);
        }

        double[] logProbs = new double[this.m_NumClasses];
        for (int classVal = 0; classVal < this.m_NumClasses; classVal++) {
            double[][] counts = this.m_ClassAttAttCounts[classVal];
            double logProb = Math.log((this.m_ClassCounts[classVal] + 1.0 / this.m_NumClasses)
                    / (this.m_NumInstances + 1.0));
            for (int son = 0; son < this.m_NumAttributes; son++) {
                int sIndex = attIndex[son];
                if (sIndex == -1) {
                    continue;
                }
                double prob = 0.0;
                double condiMutualInfoSum = 0.0;
                for (int parent = 0; parent < this.m_NumAttributes; parent++) {
                    int pIndex = attIndex[parent];
                    // A son is never its own parent.
                    if (parent == son || pIndex == -1) {
                        continue;
                    }
                    double weight = this.m_condiMutualInfo[son][parent];
                    condiMutualInfoSum += weight;
                    prob += weight * (counts[pIndex][sIndex] + 1.0 / this.m_NumAttValues[son])
                            / (counts[pIndex][pIndex] + 1.0);
                }
                if (condiMutualInfoSum > 0.0) {
                    prob /= condiMutualInfoSum;
                } else {
                    prob = (counts[sIndex][sIndex] + 1.0 / this.m_NumAttValues[son])
                            / (this.m_ClassCounts[classVal] + 1.0);
                }
                logProb += Math.log(prob);
            }
            logProbs[classVal] = logProb;
        }

        // Shift by the maximum so the best class maps to exp(0) = 1 before normalizing.
        double maxLogProb = Double.NEGATIVE_INFINITY;
        for (double logProb : logProbs) {
            maxLogProb = Math.max(maxLogProb, logProb);
        }
        double[] probs = new double[this.m_NumClasses];
        for (int classVal = 0; classVal < this.m_NumClasses; classVal++) {
            probs[classVal] = Math.exp(logProbs[classVal] - maxLogProb);
        }
        Utils.normalize(probs);
        return probs;
    }

    /**
     * Returns a short name for this classifier.
     *
     * @return the classifier name
     */
    public String toString() {
        return "HNB (Hidden Naive Bayes)";
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5516 $");
    }

    /**
     * Runs this classifier from the command line.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        HNB.runClassifier(new HNB(), args);
    }
}

