/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions.neuralnet;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.gui.tools.ExtendedJScrollPane;
import com.rapidminer.gui.tools.JRadioSelectionPanel;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.functions.neuralnet.SimpleNeuralNetVisualizer;
import java.awt.Component;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.encog.matrix.Matrix;
import org.encog.neural.data.NeuralData;
import org.encog.neural.data.basic.BasicNeuralData;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.Layer;

public class SimpleNeuralNetModel
extends PredictionModel {
    private static final long serialVersionUID = 332041465701627316L;
    private BasicNetwork network;
    private String[] attributeNames;
    private double[] attributeMin;
    private double[] attributeMax;
    private double labelMin;
    private double labelMax;

    protected SimpleNeuralNetModel(ExampleSet trainingExampleSet, BasicNetwork network, double[] attributeMin, double[] attributeMax, double labelMin, double labelMax) {
        super(trainingExampleSet);
        this.network = network;
        this.attributeNames = Tools.getRegularAttributeNames(trainingExampleSet);
        this.attributeMin = attributeMin;
        this.attributeMax = attributeMax;
        this.labelMin = labelMin;
        this.labelMax = labelMax;
    }

    public BasicNetwork getNeuralNet() {
        return this.network;
    }

    public String[] getAttributeNames() {
        return this.attributeNames;
    }

    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        for (Example example : exampleSet) {
            double[] data = new double[this.attributeNames.length];
            int i = 0;
            while (i < this.attributeNames.length) {
                data[i] = this.attributeMin[i] != this.attributeMax[i] ? (example.getValue(exampleSet.getAttributes().get(this.attributeNames[i])) - this.attributeMin[i]) / (this.attributeMax[i] - this.attributeMin[i]) : example.getValue(exampleSet.getAttributes().get(this.attributeNames[i])) - this.attributeMin[i];
                ++i;
            }
            BasicNeuralData neuralData = new BasicNeuralData(data);
            double prediction = this.network.compute((NeuralData)neuralData).getData(0);
            if (predictedLabel.isNominal()) {
                double scaled = (prediction - 0.5) * 2.0;
                int index = scaled > 0.0 ? predictedLabel.getMapping().getPositiveIndex() : predictedLabel.getMapping().getNegativeIndex();
                example.setValue(predictedLabel, index);
                example.setConfidence(predictedLabel.getMapping().getPositiveString(), 1.0 / (1.0 + Math.exp(-scaled)));
                example.setConfidence(predictedLabel.getMapping().getNegativeString(), 1.0 / (1.0 + Math.exp(scaled)));
                continue;
            }
            example.setValue(predictedLabel, prediction * (this.labelMax - this.labelMin) + this.labelMin);
        }
        return exampleSet;
    }

    public Component getVisualizationComponent(IOContainer ioContainer) {
        JRadioSelectionPanel mainPanel = new JRadioSelectionPanel();
        ExtendedJScrollPane graphView = new ExtendedJScrollPane(new SimpleNeuralNetVisualizer(this.network, this.attributeNames));
        Component textView = super.getVisualizationComponent(ioContainer);
        mainPanel.addComponent("Graph View", graphView, "Changes to a graphical view of this model.");
        mainPanel.addComponent("Text View", textView, "Changes to a textual description of this model.");
        return mainPanel;
    }

    public String toString() {
        StringBuffer result = new StringBuffer();
        List layers = this.network.getLayers();
        Iterator i = layers.iterator();
        int layerIndex = 0;
        while (i.hasNext()) {
            Layer layer = (Layer)i.next();
            String nodeString = layer.getNeuronCount() == 1 ? "1 node" : String.valueOf(layer.getNeuronCount()) + " nodes";
            String titleString = "Layer " + (layerIndex + 1) + " (" + nodeString + ")";
            result.append(String.valueOf(titleString) + com.rapidminer.tools.Tools.getLineSeparator());
            int t = 0;
            while (t < titleString.length()) {
                result.append("-");
                ++t;
            }
            result.append(com.rapidminer.tools.Tools.getLineSeparator());
            if (layerIndex == 0) {
                result.append(String.valueOf(Arrays.asList(this.attributeNames).toString()) + com.rapidminer.tools.Tools.getLineSeparators(2));
                if (layer.hasMatrix()) {
                    this.layerWeightsToString(result, layer.getMatrix(), layerIndex);
                }
            } else if (layer.hasMatrix()) {
                this.layerWeightsToString(result, layer.getMatrix(), layerIndex);
            }
            result.append(com.rapidminer.tools.Tools.getLineSeparator());
            ++layerIndex;
        }
        return result.toString();
    }

    private void layerWeightsToString(StringBuffer result, Matrix matrix, int currentLayerIndex) {
        result.append("Output Weights:" + com.rapidminer.tools.Tools.getLineSeparator());
        int rows = matrix.getRows();
        int cols = matrix.getCols();
        int c = 0;
        while (c < cols) {
            result.append(String.valueOf(com.rapidminer.tools.Tools.getLineSeparator()) + "* To Layer " + (currentLayerIndex + 2) + " - Node " + (c + 1) + ":" + com.rapidminer.tools.Tools.getLineSeparator());
            int r = 0;
            while (r < rows - 1) {
                result.append("From Node " + (r + 1) + ": ");
                result.append(matrix.get(r, c));
                result.append(com.rapidminer.tools.Tools.getLineSeparator());
                ++r;
            }
            result.append("From Threshold Node: ");
            result.append(matrix.get(rows - 1, c));
            result.append(com.rapidminer.tools.Tools.getLineSeparator());
            ++c;
        }
    }
}

