/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.gui.tools.ExtendedJTabbedPane;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.BayBoostBaseModelInfo;
import com.rapidminer.operator.learner.meta.ContingencyMatrix;
import com.rapidminer.tools.LogService;
import java.awt.Component;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class BayBoostModel
extends PredictionModel {
    private static final long serialVersionUID = 5821921049035718838L;
    private final List<BayBoostBaseModelInfo> modelInfo;
    private final double[] priors;
    private int maxModelNumber = -1;
    private static final String MAX_MODEL_NUMBER = "iteration";
    private static final String CONV_TO_CRISP = "crisp";
    private double threshold = 0.5;

    public BayBoostModel(ExampleSet exampleSet, List<BayBoostBaseModelInfo> modelInfos, double[] priors) {
        super(exampleSet);
        this.modelInfo = modelInfos;
        this.priors = priors;
    }

    public BayBoostBaseModelInfo getBayBoostBaseModelInfo(int index) {
        return this.modelInfo.get(index);
    }

    public void setParameter(String name, String value) throws OperatorException {
        if (name.equalsIgnoreCase(MAX_MODEL_NUMBER)) {
            try {
                this.maxModelNumber = Integer.parseInt(value);
                return;
            }
            catch (NumberFormatException numberFormatException) {}
        } else if (name.equalsIgnoreCase(CONV_TO_CRISP)) {
            this.threshold = Double.parseDouble(value.trim());
            return;
        }
        super.setParameter(name, value);
    }

    public void setMaxModelNumber(int numModels) {
        this.maxModelNumber = numModels;
    }

    @Override
    public Component getVisualizationComponent(IOContainer container) {
        ExtendedJTabbedPane tabPane = new ExtendedJTabbedPane();
        int i = 0;
        while (i < this.getNumberOfModels()) {
            Model model = this.getModel(i);
            tabPane.add("Model " + (i + 1), model.getVisualizationComponent(container));
            ++i;
        }
        return tabPane;
    }

    @Override
    public String toString() {
        StringBuffer result = new StringBuffer(String.valueOf(super.toString()) + com.rapidminer.tools.Tools.getLineSeparator() + "Number of inner models: " + this.getNumberOfModels() + com.rapidminer.tools.Tools.getLineSeparators(2));
        int i = 0;
        while (i < this.getNumberOfModels()) {
            Model model = this.getModel(i);
            result.append(String.valueOf(i > 0 ? com.rapidminer.tools.Tools.getLineSeparator() : "") + "Embedded model #" + i + ":" + com.rapidminer.tools.Tools.getLineSeparator() + model.toResultString());
            ++i;
        }
        return result.toString();
    }

    public int getNumberOfModels() {
        if (this.maxModelNumber >= 0) {
            return Math.min(this.maxModelNumber, this.modelInfo.size());
        }
        return this.modelInfo.size();
    }

    private double[] getFactorsForModel(int modelNr, int predicted) {
        ContingencyMatrix cm = this.modelInfo.get(modelNr).getContingencyMatrix();
        return cm.getLiftRatiosForPrediction(predicted);
    }

    private double getPriorOfClass(int classIndex) {
        return this.priors[classIndex];
    }

    public double[] getPriors() {
        double[] result = new double[this.priors.length];
        System.arraycopy(this.priors, 0, result, 0, result.length);
        return result;
    }

    public Model getModel(int index) {
        return this.modelInfo.get(index).getModel();
    }

    public ContingencyMatrix getContingencyMatrix(int index) {
        return this.modelInfo.get(index).getContingencyMatrix();
    }

    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        Attribute[] specialAttributes = this.createSpecialAttributes(exampleSet);
        this.initIntermediateResultAttributes(exampleSet, specialAttributes);
        int i = 0;
        while (i < this.getNumberOfModels()) {
            Model model = this.getModel(i);
            ExampleSet clonedExampleSet = (ExampleSet)exampleSet.clone();
            clonedExampleSet = model.apply(clonedExampleSet);
            this.updateEstimates(clonedExampleSet, this.getContingencyMatrix(i), specialAttributes);
            PredictionModel.removePredictedLabel(clonedExampleSet);
            ++i;
        }
        for (Example example : exampleSet) {
            this.translateOddsIntoPredictions(example, specialAttributes, this.getTrainingHeader().getAttributes().getLabel());
        }
        this.cleanUpSpecialAttributes(exampleSet, specialAttributes);
        return exampleSet;
    }

    private Attribute[] createSpecialAttributes(ExampleSet exampleSet) throws OperatorException {
        String attributePrefix = "BayBoostModelPrediction";
        Attribute[] specialAttributes = new Attribute[this.getLabel().getMapping().size()];
        int i = 0;
        while (i < specialAttributes.length) {
            specialAttributes[i] = Tools.createSpecialAttribute(exampleSet, "BayBoostModelPrediction" + i, 2);
            ++i;
        }
        return specialAttributes;
    }

    private void cleanUpSpecialAttributes(ExampleSet exampleSet, Attribute[] specialAttributes) throws OperatorException {
        int i = 0;
        while (i < specialAttributes.length) {
            exampleSet.getAttributes().remove(specialAttributes[i]);
            exampleSet.getExampleTable().removeAttribute(specialAttributes[i]);
            ++i;
        }
    }

    private void initIntermediateResultAttributes(ExampleSet exampleSet, Attribute[] specAttrib) {
        double[] priorOdds = new double[this.priors.length];
        int i = 0;
        while (i < priorOdds.length) {
            priorOdds[i] = this.priors[i] == 1.0 ? Double.POSITIVE_INFINITY : this.priors[i] / (1.0 - this.priors[i]);
            ++i;
        }
        for (Example example : exampleSet) {
            int i2 = 0;
            while (i2 < specAttrib.length) {
                example.setValue(specAttrib[i2], priorOdds[i2]);
                ++i2;
            }
        }
    }

    private void translateOddsIntoPredictions(Example example, Attribute[] specAttrib, Attribute trainingSetLabel) {
        String bestLabel;
        double probSum = 0.0;
        double[] classProb = new double[specAttrib.length];
        int bestIndex = 0;
        int n = 0;
        while (n < classProb.length) {
            double odds = example.getValue(specAttrib[n]);
            if (Double.isNaN(odds)) {
                this.logWarning("Found NaN odd ratio estimate.");
                classProb[n] = 1.0;
            } else {
                classProb[n] = Double.isInfinite(odds) ? 1.0 : odds / (1.0 + odds);
            }
            probSum += classProb[n];
            if (classProb[n] > classProb[bestIndex]) {
                bestIndex = n;
            }
            ++n;
        }
        if (probSum != 1.0) {
            int k = 0;
            while (k < classProb.length) {
                int n2 = k++;
                classProb[n2] = classProb[n2] / probSum;
            }
        }
        if (this.getLabel().isNominal() && this.getLabel().getMapping().size() == 2 && this.threshold != 0.5) {
            int posIndex = this.getLabel().getMapping().getPositiveIndex();
            int negIndex = this.getLabel().getMapping().getNegativeIndex();
            this.threshold = this.threshold >= 0.0 && this.threshold <= 1.0 ? this.threshold : 0.5;
            bestLabel = this.getLabel().getMapping().mapIndex(classProb[posIndex] >= this.threshold ? posIndex : negIndex);
        } else {
            bestLabel = this.getLabel().getMapping().mapIndex(bestIndex);
        }
        example.setValue(example.getAttributes().getPredictedLabel(), trainingSetLabel.getMapping().mapString(bestLabel));
        int k = 0;
        while (k < classProb.length) {
            if (Double.isNaN(classProb[k]) || classProb[k] < 0.0 || classProb[k] > 1.0) {
                this.logWarning("Found illegal confidence value: " + classProb[k]);
            }
            example.setConfidence(this.getLabel().getMapping().mapIndex(k), classProb[k]);
            ++k;
        }
    }

    private void updateEstimates(ExampleSet exampleSet, ContingencyMatrix cm, Attribute[] specialAttributes) {
        for (Example example : exampleSet) {
            int predicted = (int)example.getPredictedLabel();
            int j = 0;
            while (j < cm.getNumberOfClasses()) {
                double liftRatioCurrent = cm.getLiftRatio(j, predicted);
                if (Double.isNaN(liftRatioCurrent)) {
                    this.logWarning("Ignoring non-applicable model.");
                } else if (Double.isInfinite(liftRatioCurrent)) {
                    if (example.getValue(specialAttributes[j]) != 0.0) {
                        int k = 0;
                        while (k < specialAttributes.length) {
                            example.setValue(specialAttributes[k], 0.0);
                            ++k;
                        }
                        example.setValue(specialAttributes[j], liftRatioCurrent);
                    }
                } else {
                    double oldValue = example.getValue(specialAttributes[j]);
                    if (Double.isNaN(oldValue)) {
                        this.logWarning("Found NaN value in intermediate odds ratio estimates!");
                    }
                    if (!Double.isInfinite(oldValue)) {
                        example.setValue(specialAttributes[j], oldValue * liftRatioCurrent);
                    }
                }
                ++j;
            }
        }
    }

    public static boolean adjustIntermediateProducts(double[] products, double[] liftFactors) {
        int i = 0;
        while (i < liftFactors.length) {
            if (Double.isNaN(liftFactors[i])) {
                LogService.getGlobal().log("Ignoring non-applicable model.", 5);
            } else if (Double.isInfinite(liftFactors[i])) {
                if (products[i] != 0.0) {
                    int j = 0;
                    while (j < products.length) {
                        products[j] = 0.0;
                        ++j;
                    }
                    products[i] = liftFactors[i];
                    return true;
                }
            } else {
                int n = i;
                products[n] = products[n] * liftFactors[i];
                if (Double.isNaN(products[i])) {
                    LogService.getGlobal().log("Found NaN value in intermediate odds ratio estimates!", 5);
                }
            }
            ++i;
        }
        return false;
    }

    public double[] getModelWeights() throws OperatorException {
        if (this.getLabel().getMapping().size() != 2) {
            throw new UserError(null, 114, "BayBoostModel", this.getLabel());
        }
        int maxWeight = 10;
        int pos = this.getLabel().getMapping().getPositiveIndex();
        int neg = this.getLabel().getMapping().getNegativeIndex();
        double[] weights = new double[this.getNumberOfModels() + 1];
        double odds = this.getPriorOfClass(pos) / this.getPriorOfClass(neg);
        weights[0] = Math.log(odds);
        int i = 1;
        while (i < weights.length) {
            double[] liftRatiosPos = this.getFactorsForModel(i - 1, pos);
            double logPosRatio = Math.log(liftRatiosPos[pos]);
            logPosRatio = Math.min((double)maxWeight, Math.max((double)(-maxWeight), logPosRatio));
            double[] liftRatiosNeg = this.getFactorsForModel(i - 1, neg);
            double logNegRatio = Math.log(liftRatiosNeg[pos]);
            double indep = (logPosRatio + (logNegRatio = Math.min((double)maxWeight, Math.max((double)(-maxWeight), logNegRatio)))) / 2.0;
            if (com.rapidminer.tools.Tools.isEqual(indep, maxWeight) || com.rapidminer.tools.Tools.isEqual(indep, -maxWeight)) {
                logPosRatio = 10.0 * indep;
                indep = 0.0;
            }
            weights[0] = weights[0] + indep;
            weights[i] = logPosRatio -= indep;
            ++i;
        }
        return weights;
    }
}

