/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.functions.kernel;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.functions.kernel.RVMModel;
import com.rapidminer.operator.learner.functions.kernel.rvm.ClassificationProblem;
import com.rapidminer.operator.learner.functions.kernel.rvm.ConstructiveRegression;
import com.rapidminer.operator.learner.functions.kernel.rvm.Parameter;
import com.rapidminer.operator.learner.functions.kernel.rvm.RVMClassification;
import com.rapidminer.operator.learner.functions.kernel.rvm.RVMRegression;
import com.rapidminer.operator.learner.functions.kernel.rvm.RegressionProblem;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelBasisFunction;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelCauchy;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelEpanechnikov;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelGaussianCombination;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelLaplace;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelMultiquadric;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelPoly;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelRadial;
import com.rapidminer.operator.learner.functions.kernel.rvm.kernel.KernelSigmoid;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeSingle;
import java.util.Iterator;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Learner operator that trains a Relevance Vector Machine (RVM).
 * <p>
 * For a numerical label it trains an {@link RVMRegression} or a {@link ConstructiveRegression};
 * for a binominal (two-valued nominal) label it trains an {@link RVMClassification}. The
 * variant is chosen with {@value #PARAMETER_RVM_TYPE}. One kernel basis function is centred on
 * each training example, using the kernel selected by {@value #PARAMETER_KERNEL_TYPE}.
 */
public class RVMLearner extends AbstractLearner {
    public static final String PARAMETER_RVM_TYPE = "rvm_type";
    public static final String PARAMETER_KERNEL_TYPE = "kernel_type";
    public static final String PARAMETER_MAX_ITERATION = "max_iteration";
    public static final String PARAMETER_MIN_DELTA_LOG_ALPHA = "min_delta_log_alpha";
    public static final String PARAMETER_ALPHA_MAX = "alpha_max";
    public static final String PARAMETER_KERNEL_LENGTHSCALE = "kernel_lengthscale";
    public static final String PARAMETER_KERNEL_DEGREE = "kernel_degree";
    public static final String PARAMETER_KERNEL_BIAS = "kernel_bias";
    public static final String PARAMETER_KERNEL_SIGMA1 = "kernel_sigma1";
    public static final String PARAMETER_KERNEL_SIGMA2 = "kernel_sigma2";
    public static final String PARAMETER_KERNEL_SIGMA3 = "kernel_sigma3";
    public static final String PARAMETER_KERNEL_SHIFT = "kernel_shift";
    public static final String PARAMETER_KERNEL_A = "kernel_a";
    public static final String PARAMETER_KERNEL_B = "kernel_b";

    /** Selectable RVM variants; the position in this array is the category index of {@value #PARAMETER_RVM_TYPE}. */
    public static final String[] RVM_TYPES = new String[]{"Regression-RVM", "Classification-RVM", "Constructive-Regression-RVM"};

    /** Selectable kernels; the position in this array is the category index of {@value #PARAMETER_KERNEL_TYPE}. */
    public static final String[] KERNEL_TYPES = new String[]{"rbf", "cauchy", "laplace", "poly", "sigmoid", "Epanechnikov", "gaussian combination", "multiquadric"};

    /** Category index of "Regression-RVM" in {@link #RVM_TYPES}. */
    private static final int RVM_TYPE_REGRESSION = 0;
    /** Category index of "Classification-RVM" in {@link #RVM_TYPES}. */
    private static final int RVM_TYPE_CLASSIFICATION = 1;
    /** Category index of "Constructive-Regression-RVM" in {@link #RVM_TYPES}. */
    private static final int RVM_TYPE_CONSTRUCTIVE_REGRESSION = 2;

    public RVMLearner(OperatorDescription description) {
        super(description);
    }

    /**
     * Reports which data this learner can handle: numerical attributes together with either a
     * binominal or a numerical label.
     *
     * @param lc the capability to query
     * @return {@code true} for numerical attributes, binominal labels and numerical labels
     */
    @Override
    public boolean supportsCapability(LearnerCapability lc) {
        return lc == LearnerCapability.NUMERICAL_ATTRIBUTES
                || lc == LearnerCapability.BINOMINAL_CLASS
                || lc == LearnerCapability.NUMERICAL_CLASS;
    }

    /**
     * Trains an RVM on the given example set.
     * <p>
     * The label/RVM-type combination is validated before any kernel basis functions are built,
     * so incompatible settings fail without doing the expensive setup work.
     *
     * @param exampleSet the training data; its label selects classification (nominal) or
     *                   regression (numerical)
     * @return an {@link RVMModel} wrapping the learned RVM
     * @throws UserError 114 if the label is nominal but does not have exactly two values;
     *                   207 if the selected RVM type does not match the label type;
     *                   924 if classification training fails with an
     *                   {@link ArrayIndexOutOfBoundsException}
     * @throws OperatorException if a parameter cannot be read
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        log("Creating RVM.");

        // Fail fast: check the label / RVM type combination before building kernels.
        Attribute label = exampleSet.getAttributes().getLabel();
        int rvmTypeIndex = getParameterAsInt(PARAMETER_RVM_TYPE);
        String rvmType = RVM_TYPES[rvmTypeIndex];
        boolean classification = label.isNominal();
        if (classification) {
            if (label.getMapping().size() != 2) {
                throw new UserError((Operator) this, 114, getName(), label.getName());
            }
            if (rvmTypeIndex != RVM_TYPE_CLASSIFICATION) {
                throw new UserError((Operator) this, 207, rvmType, PARAMETER_RVM_TYPE, "only Classification-RVM can be used for the given two class classification problem");
            }
        } else if (rvmTypeIndex != RVM_TYPE_REGRESSION && rvmTypeIndex != RVM_TYPE_CONSTRUCTIVE_REGRESSION) {
            throw new UserError((Operator) this, 207, rvmType, PARAMETER_RVM_TYPE, "only one of the regression types can be used for the given regression problem");
        }

        int numExamples = exampleSet.size();
        Parameter parameter = new Parameter();
        parameter.min_delta_log_alpha = getParameterAsDouble(PARAMETER_MIN_DELTA_LOG_ALPHA);
        parameter.alpha_max = getParameterAsDouble(PARAMETER_ALPHA_MAX);
        parameter.maxIterations = getParameterAsInt(PARAMETER_MAX_ITERATION);
        parameter.initAlpha = Math.pow(1.0 / (double) numExamples, 2.0);
        parameter.initSigma = 0.1;

        log("=> Creating input / output vectors.");
        // Rows are produced per example below, so only the outer arrays are allocated here.
        double[][] x = new double[numExamples][];
        double[][] t = new double[numExamples][];
        Iterator<Example> reader = exampleSet.iterator();
        for (int i = 0; reader.hasNext(); i++) {
            Example example = reader.next();
            x[i] = RVMModel.makeInputVector(example);
            t[i] = new double[]{example.getLabel()};
        }

        log("Creating kernel basis functions [" + KERNEL_TYPES[getParameterAsInt(PARAMETER_KERNEL_TYPE)] + "].");
        // One basis function per example plus slot 0 (left null by createKernels).
        KernelBasisFunction[] kernels = createKernels(x, numExamples + 1);

        com.rapidminer.operator.learner.functions.kernel.rvm.Model model;
        if (classification) {
            // For a nominal label the value is the mapping index (0 or 1 after the size check).
            int[] classes = new int[numExamples];
            for (int i = 0; i < numExamples; i++) {
                classes[i] = (int) t[i][0];
            }
            ClassificationProblem problem = new ClassificationProblem(x, classes, kernels);
            try {
                model = new RVMClassification(problem, parameter).learn();
                return new RVMModel(exampleSet, model);
            } catch (ArrayIndexOutOfBoundsException e) {
                throw new UserError(this, 924);
            }
        }

        RegressionProblem problem = new RegressionProblem(x, t, kernels);
        if (rvmTypeIndex == RVM_TYPE_REGRESSION) {
            model = new RVMRegression(problem, parameter).learn();
        } else {
            model = new ConstructiveRegression(problem, parameter).learn();
        }
        return new RVMModel(exampleSet, model);
    }

    /**
     * Creates one kernel basis function per input vector, using the kernel selected by
     * {@value #PARAMETER_KERNEL_TYPE} and the corresponding kernel parameters.
     * <p>
     * Basis function {@code j + 1} is centred on {@code x[j]}; index 0 of the returned array is
     * left {@code null} (presumably reserved for a bias term handled by the problem/learner
     * classes — TODO confirm). An unknown kernel index falls back to the rbf kernel.
     *
     * @param x          the input vectors; at least {@code numKernels - 1} rows are read
     * @param numKernels the length of the returned array
     * @return the basis functions, with slot 0 unset
     * @throws OperatorException if a kernel parameter cannot be read
     */
    public KernelBasisFunction[] createKernels(double[][] x, int numKernels) throws OperatorException {
        KernelBasisFunction[] kernels = new KernelBasisFunction[numKernels];
        // The kernel type does not change during the loop, so read it once.
        int kernelType = getParameterAsInt(PARAMETER_KERNEL_TYPE);
        double lengthScale = getParameterAsDouble(PARAMETER_KERNEL_LENGTHSCALE);
        double bias = getParameterAsDouble(PARAMETER_KERNEL_BIAS);
        double degree = getParameterAsDouble(PARAMETER_KERNEL_DEGREE);
        double a = getParameterAsDouble(PARAMETER_KERNEL_A);
        double b = getParameterAsDouble(PARAMETER_KERNEL_B);
        double sigma1 = getParameterAsDouble(PARAMETER_KERNEL_SIGMA1);
        double sigma2 = getParameterAsDouble(PARAMETER_KERNEL_SIGMA2);
        double sigma3 = getParameterAsDouble(PARAMETER_KERNEL_SIGMA3);
        double shift = getParameterAsDouble(PARAMETER_KERNEL_SHIFT);
        for (int j = 0; j < numKernels - 1; j++) {
            double[] input = x[j];
            KernelBasisFunction kernel;
            switch (kernelType) {
                case 1:
                    kernel = new KernelBasisFunction(new KernelCauchy(lengthScale), input);
                    break;
                case 2:
                    kernel = new KernelBasisFunction(new KernelLaplace(lengthScale), input);
                    break;
                case 3:
                    kernel = new KernelBasisFunction(new KernelPoly(lengthScale, bias, degree), input);
                    break;
                case 4:
                    kernel = new KernelBasisFunction(new KernelSigmoid(a, b), input);
                    break;
                case 5:
                    kernel = new KernelBasisFunction(new KernelEpanechnikov(sigma1, degree), input);
                    break;
                case 6:
                    kernel = new KernelBasisFunction(new KernelGaussianCombination(sigma1, sigma2, sigma3), input);
                    break;
                case 7:
                    kernel = new KernelBasisFunction(new KernelMultiquadric(sigma1, shift), input);
                    break;
                case 0:
                default:
                    // rbf, also used as the fallback for unknown indices
                    kernel = new KernelBasisFunction(new KernelRadial(lengthScale), input);
                    break;
            }
            kernels[j + 1] = kernel;
        }
        return kernels;
    }

    /**
     * Declares the operator parameters: RVM variant, kernel type, iteration/pruning controls and
     * the per-kernel parameters read in {@link #createKernels(double[][], int)}.
     *
     * @return the parameter types of the superclass followed by this learner's parameters
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        ParameterTypeSingle type = new ParameterTypeCategory(PARAMETER_RVM_TYPE, "The type of the RVM: Regression-RVM or Constructive-Regression-RVM for numerical labels, Classification-RVM for binominal labels.", RVM_TYPES, 0);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeCategory(PARAMETER_KERNEL_TYPE, "The type of the kernel functions.", KERNEL_TYPES, 0);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeInt(PARAMETER_MAX_ITERATION, "The maximum number of iterations used.", 1, Integer.MAX_VALUE, 100);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_MIN_DELTA_LOG_ALPHA, "Abort iteration if largest log alpha change is smaller than this", 0.0, Double.POSITIVE_INFINITY, 0.001);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_ALPHA_MAX, "Prune basis function if its alpha is bigger than this", 0.0, Double.POSITIVE_INFINITY, 1.0E12);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_KERNEL_LENGTHSCALE, "The lengthscale used in the rbf, cauchy, laplace and poly kernels.", 0.0, Double.POSITIVE_INFINITY, 3.0);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_KERNEL_DEGREE, "The degree used in the poly and Epanechnikov kernels.", 0.0, Double.POSITIVE_INFINITY, 2.0);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_KERNEL_BIAS, "The bias used in the poly kernel.", 0.0, Double.POSITIVE_INFINITY, 1.0);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_KERNEL_SIGMA1, "The SVM kernel parameter sigma1 (Epanechnikov, Gaussian Combination, Multiquadric).", 0.0, Double.POSITIVE_INFINITY, 1.0);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_KERNEL_SIGMA2, "The SVM kernel parameter sigma2 (Gaussian Combination).", 0.0, Double.POSITIVE_INFINITY, 0.0);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_KERNEL_SIGMA3, "The SVM kernel parameter sigma3 (Gaussian Combination).", 0.0, Double.POSITIVE_INFINITY, 2.0);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_KERNEL_SHIFT, "The SVM kernel parameter shift (Multiquadric).", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_KERNEL_A, "The SVM kernel parameter a (sigmoid).", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0);
        type.setExpert(false);
        types.add(type);
        type = new ParameterTypeDouble(PARAMETER_KERNEL_B, "The SVM kernel parameter b (sigmoid).", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.0);
        type.setExpert(false);
        types.add(type);
        return types;
    }
}

