/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.lazy;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.container.Tupel;
import com.rapidminer.tools.math.container.GeometricDataCollection;
import java.util.ArrayList;
import java.util.Collection;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Prediction model for k-nearest-neighbour regression.
 *
 * <p>For every example to be scored, the values of the training attributes are collected into a
 * query vector and the nearest stored samples are looked up in {@link #samples}. The prediction is
 * either the plain mean of the neighbours' labels or, when {@link #weightByDistance} is set, a
 * distance-weighted combination of them.
 */
public class KNNRegressionModel
extends PredictionModel {
    private static final long serialVersionUID = -6292869962412072573L;
    /** Maximum number of neighbours requested per prediction. */
    private int k;
    /** Stored training points with their label values (presumably one entry per training example — confirm against the learner). */
    private GeometricDataCollection<Double> samples;
    /** Names of the training attributes, in the order used to build query vectors. */
    private ArrayList<String> sampleAttributeNames;
    /** If {@code true}, neighbour labels are weighted by distance instead of averaged uniformly. */
    private boolean weightByDistance;

    /**
     * Creates a model over the given sample collection.
     *
     * @param trainingSet the example set the model was trained on; the names of the attributes it
     *     iterates over are remembered and later looked up again by name in scored example sets
     * @param samples the neighbour-search structure holding the training points and their labels
     * @param k the number of neighbours to request per prediction
     * @param weightByDistance whether to weight neighbour labels by their distance
     */
    public KNNRegressionModel(ExampleSet trainingSet, GeometricDataCollection<Double> samples, int k, boolean weightByDistance) {
        super(trainingSet);
        this.k = k;
        this.samples = samples;
        this.weightByDistance = weightByDistance;
        Attributes attributes = trainingSet.getAttributes();
        this.sampleAttributeNames = new ArrayList<String>(attributes.size());
        for (Attribute attribute : attributes) {
            this.sampleAttributeNames.add(attribute.getName());
        }
    }

    /**
     * Writes a regression prediction for every example of {@code exampleSet} into
     * {@code predictedLabel}.
     *
     * <p>If no neighbours are returned for an example, the prediction is {@link Double#NaN}.
     *
     * @param exampleSet the examples to score; modified in place
     * @param predictedLabel the attribute receiving the predictions
     * @return {@code exampleSet}
     * @throws OperatorException if an attribute used during training is missing from
     *     {@code exampleSet}
     */
    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        ArrayList<Attribute> sampleAttributes = resolveSampleAttributes(exampleSet);
        // Reused query buffer: it is fully overwritten for each example before the lookup.
        double[] values = new double[sampleAttributes.size()];
        for (Example example : exampleSet) {
            for (int i = 0; i < values.length; ++i) {
                values[i] = example.getValue(sampleAttributes.get(i));
            }
            double result;
            if (this.weightByDistance) {
                result = distanceWeightedLabel(this.samples.getNearestValueDistances(this.k, values));
            } else {
                result = averageLabel(this.samples.getNearestValues(this.k, values));
            }
            example.setValue(predictedLabel, result);
        }
        return exampleSet;
    }

    /**
     * Looks up, by name, the attributes of {@code exampleSet} that correspond to the training
     * attributes, in training order.
     *
     * @throws OperatorException if one of the training attributes is not present
     */
    private ArrayList<Attribute> resolveSampleAttributes(ExampleSet exampleSet) throws OperatorException {
        Attributes attributes = exampleSet.getAttributes();
        ArrayList<Attribute> sampleAttributes = new ArrayList<Attribute>(this.sampleAttributeNames.size());
        for (String attributeName : this.sampleAttributeNames) {
            Attribute attribute = attributes.get(attributeName);
            // Attributes.get is assumed to return null for an unknown name; without this check the
            // failure would surface later as an uninformative NullPointerException.
            if (attribute == null) {
                throw new OperatorException("KNN regression model: attribute '" + attributeName + "' used during training is missing from the example set.");
            }
            sampleAttributes.add(attribute);
        }
        return sampleAttributes;
    }

    /**
     * Returns the arithmetic mean of the given neighbour labels, or {@link Double#NaN} if there are
     * none.
     *
     * <p>Divides by the number of neighbours actually returned rather than by {@code k}, so the mean
     * stays unbiased when fewer than {@code k} neighbours exist.
     */
    private static double averageLabel(Collection<Double> neighbourLabels) {
        if (neighbourLabels.isEmpty()) {
            return Double.NaN;
        }
        double sum = 0.0;
        for (double label : neighbourLabels) {
            sum += label;
        }
        return sum / neighbourLabels.size();
    }

    /**
     * Combines neighbour labels using the weights {@code 1 - d_i / D}, where {@code d_i} is a
     * neighbour's distance (first tuple element) and {@code D} the sum of all distances.
     *
     * <p>The weights sum to {@code n - 1} for {@code n} neighbours when {@code D > 0}, and every
     * weight is 1 when all distances are 0, so the result is normalised accordingly. A single
     * neighbour would receive weight 0, so its label is returned directly. With no neighbours, the
     * result is {@link Double#NaN}.
     */
    private static double distanceWeightedLabel(Collection<Tupel<Double, Double>> neighbourTupels) {
        int numberOfNeighbours = neighbourTupels.size();
        if (numberOfNeighbours == 0) {
            return Double.NaN;
        }
        if (numberOfNeighbours == 1) {
            return neighbourTupels.iterator().next().getSecond();
        }
        double totalDistance = 0.0;
        for (Tupel<Double, Double> tupel : neighbourTupels) {
            totalDistance += tupel.getFirst().doubleValue();
        }
        double totalSimilarity;
        if (totalDistance == 0.0) {
            // All neighbours coincide with the query: equal weights of 1, plain mean.
            totalDistance = 1.0;
            totalSimilarity = numberOfNeighbours;
        } else {
            totalSimilarity = numberOfNeighbours - 1;
        }
        double result = 0.0;
        for (Tupel<Double, Double> tupel : neighbourTupels) {
            result += tupel.getSecond() * (1.0 - tupel.getFirst() / totalDistance) / totalSimilarity;
        }
        return result;
    }
}

