/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.bayes;

import Jama.Matrix;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.example.table.NominalMapping;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.bayes.DiscriminantModel;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.tools.math.matrix.CovarianceMatrix;

public class LinearDiscriminantAnalysis
extends AbstractLearner {
    public LinearDiscriminantAnalysis(OperatorDescription description) {
        super(description);
    }

    public Model learn(ExampleSet exampleSet) throws OperatorException {
        int numberOfNumericalAttributes = 0;
        for (Attribute attribute : exampleSet.getAttributes()) {
            if (!attribute.isNumerical()) continue;
            ++numberOfNumericalAttributes;
        }
        NominalMapping labelMapping = exampleSet.getAttributes().getLabel().getMapping();
        String[] labelValues = new String[labelMapping.size()];
        int i = 0;
        while (i < labelMapping.size()) {
            labelValues[i] = labelMapping.mapIndex(i);
            ++i;
        }
        Matrix[] meanVectors = this.getMeanVectors(exampleSet, numberOfNumericalAttributes, labelValues);
        Matrix[] inverseCovariance = this.getInverseCovarianceMatrices(exampleSet, labelValues);
        return this.getModel(exampleSet, labelValues, meanVectors, inverseCovariance, this.getAprioriProbabilities(exampleSet, labelValues));
    }

    protected DiscriminantModel getModel(ExampleSet exampleSet, String[] labels, Matrix[] meanVectors, Matrix[] inverseCovariances, double[] aprioriProbabilities) throws UndefinedParameterError {
        return new DiscriminantModel(exampleSet, labels, meanVectors, inverseCovariances, aprioriProbabilities, 0.0);
    }

    private double[] getAprioriProbabilities(ExampleSet exampleSet, String[] labels) {
        double[] aprioriProbabilites = new double[labels.length];
        double totalSize = exampleSet.size();
        Attribute labelAttribute = exampleSet.getAttributes().getLabel();
        SplittedExampleSet labelSet = SplittedExampleSet.splitByAttribute(exampleSet, exampleSet.getAttributes().getLabel());
        int labelIndex = 0;
        String[] stringArray = labels;
        int n = labels.length;
        int n2 = 0;
        while (n2 < n) {
            String label = stringArray[n2];
            int i = 0;
            while (i < labels.length) {
                labelSet.selectSingleSubset(i);
                if (labelSet.getExample(0).getNominalValue(labelAttribute).equals(label)) break;
                ++i;
            }
            aprioriProbabilites[labelIndex] = (double)labelSet.size() / totalSize;
            ++labelIndex;
            ++n2;
        }
        return aprioriProbabilites;
    }

    protected Matrix[] getMeanVectors(ExampleSet exampleSet, int numberOfAttributes, String[] labels) {
        Matrix[] classMeanVectors = new Matrix[labels.length];
        Attribute labelAttribute = exampleSet.getAttributes().getLabel();
        SplittedExampleSet labelSet = SplittedExampleSet.splitByAttribute(exampleSet, exampleSet.getAttributes().getLabel());
        int labelIndex = 0;
        String[] stringArray = labels;
        int n = labels.length;
        int n2 = 0;
        while (n2 < n) {
            String label = stringArray[n2];
            int i = 0;
            while (i < labels.length) {
                labelSet.selectSingleSubset(i);
                if (labelSet.getExample(0).getNominalValue(labelAttribute).equals(label)) break;
                ++i;
            }
            labelSet.recalculateAllAttributeStatistics();
            double[] meanValues = new double[numberOfAttributes];
            int i2 = 0;
            for (Attribute attribute : labelSet.getAttributes()) {
                if (attribute.isNumerical()) {
                    meanValues[i2] = labelSet.getStatistics(attribute, "average");
                }
                ++i2;
            }
            classMeanVectors[labelIndex] = new Matrix(meanValues, 1);
            ++labelIndex;
            ++n2;
        }
        return classMeanVectors;
    }

    protected Matrix[] getInverseCovarianceMatrices(ExampleSet exampleSet, String[] labels) throws UndefinedParameterError {
        Matrix[] classInverseCovariances = new Matrix[labels.length];
        Matrix inverse = CovarianceMatrix.getCovarianceMatrix(exampleSet).inverse();
        int i = 0;
        while (i < labels.length) {
            classInverseCovariances[i] = inverse;
            ++i;
        }
        return classInverseCovariances;
    }

    public boolean supportsCapability(LearnerCapability capability) {
        if (capability.equals(LearnerCapability.NUMERICAL_ATTRIBUTES)) {
            return true;
        }
        if (capability.equals(LearnerCapability.BINOMINAL_CLASS)) {
            return true;
        }
        return capability.equals(LearnerCapability.POLYNOMINAL_CLASS);
    }
}

