/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.meta;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCreationException;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.Learner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.lazy.DefaultLearner;
import com.rapidminer.operator.learner.meta.AbstractMetaLearner;
import com.rapidminer.operator.learner.meta.AdditiveRegressionModel;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.OperatorService;
import java.util.Iterator;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Meta learner that builds an additive regression model.
 *
 * <p>A {@link DefaultLearner} model is trained first on a temporary copy of the label.
 * The temporary label is then replaced by the residuals of that model, and the inner
 * learner is trained repeatedly on the residuals left by the previous round. The
 * predictions of the inner models (but not of the default model) are scaled by the
 * {@value #PARAMETER_SHRINKAGE} parameter before they are subtracted.</p>
 *
 * <p>Nominal (binominal / polynominal) labels are reported as unsupported.</p>
 */
public class AdditiveRegression
extends AbstractMetaLearner {
    /** Parameter key: number of residual models to train. */
    public static final String PARAMETER_ITERATIONS = "iterations";
    /** Parameter key: factor applied to the inner models' predictions. */
    public static final String PARAMETER_SHRINKAGE = "shrinkage";

    /**
     * Creates the operator.
     *
     * @param description the operator description handed to the superclass
     */
    public AdditiveRegression(OperatorDescription description) {
        super(description);
    }

    /**
     * Learns the default model plus one residual model per iteration.
     *
     * <p>The input set is cloned and a temporary attribute named {@code working_label}
     * is added to the example table reached through that clone, so the original label
     * values are never overwritten. The temporary attribute is removed again before
     * this method returns, whether it completes normally or with an exception.</p>
     *
     * @param exampleSet the training data
     * @return an {@link AdditiveRegressionModel} built from the default model, the
     *         residual models and the shrinkage parameter
     * @throws OperatorException if the default learner cannot be created or any of the
     *         learning / application steps fails
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        ExampleSet workingExampleSet = (ExampleSet)exampleSet.clone();
        Attribute originalLabel = workingExampleSet.getAttributes().getLabel();

        // Create the temporary label and fill it with the original label values.
        // It is registered as a regular attribute first so the values can be written,
        // then re-registered with the label role.
        Attribute workingLabel = AttributeFactory.createAttribute(originalLabel, "working_label");
        workingExampleSet.getExampleTable().addAttribute(workingLabel);
        workingExampleSet.getAttributes().addRegular(workingLabel);
        for (Example example : workingExampleSet) {
            example.setValue(workingLabel, example.getValue(originalLabel));
        }
        workingExampleSet.getAttributes().remove(workingLabel);
        workingExampleSet.getAttributes().setLabel(workingLabel);

        Model defaultModel;
        Model[] residualModels;
        try {
            // Baseline model; its (unshrunk) predictions define the first residuals.
            defaultModel = this.createBaselineLearner().learn(workingExampleSet);
            this.residualReplace(workingExampleSet, defaultModel, false);

            // Each round fits the inner learner to the current residuals and then
            // subtracts its shrunk predictions to obtain the next residuals.
            residualModels = new Model[this.getParameterAsInt(PARAMETER_ITERATIONS)];
            for (int iteration = 0; iteration < residualModels.length; iteration++) {
                residualModels[iteration] = this.applyInnerLearner(workingExampleSet);
                this.residualReplace(workingExampleSet, residualModels[iteration], true);
            }
        } finally {
            // Always drop the temporary column; otherwise a failed run would leave
            // "working_label" behind in the example table.
            workingExampleSet.getAttributes().remove(workingLabel);
            workingExampleSet.getExampleTable().removeAttribute(workingLabel);
        }

        return new AdditiveRegressionModel(exampleSet, defaultModel, residualModels, this.getParameterAsDouble(PARAMETER_SHRINKAGE));
    }

    /**
     * Instantiates the {@link DefaultLearner} used for the baseline model.
     *
     * @return the newly created learner
     * @throws OperatorException wrapping the {@link OperatorCreationException} if the
     *         operator cannot be created
     */
    private Learner createBaselineLearner() throws OperatorException {
        try {
            return OperatorService.createOperator(DefaultLearner.class);
        } catch (OperatorCreationException e) {
            throw new OperatorException(this.getName() + ": not able to create default classifier!", e);
        }
    }

    /**
     * Applies {@code model} to {@code exampleSet} and overwrites each example's label
     * with {@code label - prediction}.
     *
     * @param exampleSet the set whose label values are replaced by residuals
     * @param model the model whose predictions are subtracted
     * @param shrinkage if {@code true}, predictions are multiplied by the
     *        {@value #PARAMETER_SHRINKAGE} parameter before subtraction
     * @throws OperatorException if applying the model or reading the parameter fails
     */
    private void residualReplace(ExampleSet exampleSet, Model model, boolean shrinkage) throws OperatorException {
        // The parameter does not change while iterating, so look it up once.
        double shrinkageFactor = shrinkage ? this.getParameterAsDouble(PARAMETER_SHRINKAGE) : 1.0;

        ExampleSet resultSet = model.apply(exampleSet);
        try {
            Attribute label = exampleSet.getAttributes().getLabel();
            Iterator<Example> originalReader = exampleSet.iterator();
            Iterator<Example> predictionReader = resultSet.iterator();
            // Walk both sets in lock-step; stop as soon as either one is exhausted.
            while (originalReader.hasNext() && predictionReader.hasNext()) {
                Example originalExample = originalReader.next();
                Example predictionExample = predictionReader.next();
                double prediction = predictionExample.getPredictedLabel();
                if (shrinkage) {
                    prediction *= shrinkageFactor;
                }
                originalExample.setValue(label, originalExample.getLabel() - prediction);
            }
        } finally {
            // Remove the prediction attribute even if the loop above failed.
            PredictionModel.removePredictedLabel(resultSet);
        }
    }

    /**
     * @return {@code 1}: at least one inner operator is required
     */
    @Override
    public int getMinNumberOfInnerOperators() {
        return 1;
    }

    /**
     * @return {@code 1}: at most one inner operator is allowed
     */
    @Override
    public int getMaxNumberOfInnerOperators() {
        return 1;
    }

    /**
     * Rejects binominal and polynominal class labels; every other capability is
     * decided by the superclass.
     *
     * @param capability the capability being queried
     * @return {@code false} for nominal class capabilities, otherwise the superclass answer
     */
    @Override
    public boolean supportsCapability(LearnerCapability capability) {
        if (capability.equals(LearnerCapability.BINOMINAL_CLASS)
                || capability.equals(LearnerCapability.POLYNOMINAL_CLASS)) {
            return false;
        }
        return super.supportsCapability(capability);
    }

    /**
     * Extends the inherited parameter list with {@value #PARAMETER_ITERATIONS} and
     * {@value #PARAMETER_SHRINKAGE}.
     *
     * @return the superclass parameter list with the two additional parameters appended
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        // NOTE(review): numeric arguments are presumably (min, max, default) - confirm against ParameterTypeInt/Double.
        types.add(new ParameterTypeInt(PARAMETER_ITERATIONS, "The number of iterations.", 1, Integer.MAX_VALUE, 10));
        types.add(new ParameterTypeDouble(PARAMETER_SHRINKAGE, "Reducing this learning rate prevent overfitting but increases the learning time.", 0.0, 1.0, 1.0));
        return types;
    }
}

