/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.features.weighting;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.operator.IOContainer;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorCreationException;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.features.weighting.AbstractWeighting;
import com.rapidminer.operator.learner.functions.kernel.JMySVMLearner;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.tools.OperatorService;
import java.util.LinkedList;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Computes attribute weights from an inner {@link JMySVMLearner}.
 *
 * <p>For numerical labels and for nominal labels with exactly two values, a single SVM is trained
 * and the weights it delivers are returned. For nominal labels with any other number of values,
 * one "class vs. rest" SVM is trained per class value. For every attribute, the absolute weights
 * of those runs are averaged, weighted by the number of examples of each class.
 */
public class SVMWeighting extends AbstractWeighting {

    /** Key of the SVM complexity parameter; it is passed through to the inner SVM. */
    public static final String PARAMETER_C = "C";

    /** Name of the temporary two-valued label used for the class-vs-rest runs. */
    private static final String TEMP_LABEL_NAME = "temp_label";

    public SVMWeighting(OperatorDescription description) {
        super(description);
    }

    /**
     * Calculates attribute weights for the given example set.
     *
     * @param exampleSet the labelled example set. During multi-class computation its label is
     *     temporarily replaced; the original label is restored before this method returns or throws.
     * @return the weights, with their source set to this operator's name
     * @throws UserError if the example set has no label (error 105) or the inner SVM cannot be
     *     created (error 904)
     * @throws OperatorException if the label is neither numerical nor nominal, or if the inner SVM fails
     */
    @Override
    public AttributeWeights calculateWeights(ExampleSet exampleSet) throws OperatorException {
        Attribute label = exampleSet.getAttributes().getLabel();
        if (label == null) {
            throw new UserError(this, 105);
        }
        JMySVMLearner svmOperator = createSvmOperator();
        AttributeWeights result;
        if (label.isNumerical() || (label.isNominal() && label.getMapping().size() == 2)) {
            result = calculateAttributeWeights(svmOperator, exampleSet);
        } else if (label.isNominal()) {
            result = calculateOneVsAllWeights(svmOperator, exampleSet, label);
        } else {
            // Previously this only logged, and the null result caused a NullPointerException below.
            String message = "Calculation of SVM weights only possible for numerical or nominal labels.";
            this.logError(message);
            throw new OperatorException(message);
        }
        result.setSource(this.getName());
        return result;
    }

    /** Creates the inner SVM learner and configures it to deliver attribute weights. */
    private JMySVMLearner createSvmOperator() throws OperatorException {
        JMySVMLearner svmOperator;
        try {
            svmOperator = OperatorService.createOperator(JMySVMLearner.class);
        } catch (OperatorCreationException e) {
            throw new UserError(this, 904, "inner SVM operator", e.getMessage());
        }
        // NOTE(review): kernel type "0" is presumably the linear (dot) kernel,
        // for which the weights are meaningful -- confirm against JMySVMLearner.
        svmOperator.setParameter("kernel_type", "0");
        svmOperator.setParameter(PARAMETER_C, String.valueOf(this.getParameterAsDouble(PARAMETER_C)));
        svmOperator.setParameter("calculate_weights", "true");
        return svmOperator;
    }

    /**
     * Trains one class-vs-rest SVM per value of the nominal {@code label}. The absolute weights are
     * then combined into an average weighted by class frequency. Returns an empty weights object
     * if no class has any example.
     */
    private AttributeWeights calculateOneVsAllWeights(JMySVMLearner svmOperator, ExampleSet exampleSet, Attribute label) throws OperatorException {
        exampleSet.recalculateAttributeStatistics(label);
        List<AttributeWeights> classWeights = new LinkedList<AttributeWeights>();
        // classFrequencies[i] is the example count of the class that produced classWeights.get(i).
        int[] classFrequencies = new int[label.getMapping().size()];
        int totalFrequency = 0;
        for (String value : label.getMapping().getValues()) {
            int frequency = (int) exampleSet.getStatistics(label, "count", value);
            if (frequency == 0) {
                // The class weight would be multiplied by 0 anyway; skip the
                // pointless training run on an all-negative label.
                continue;
            }
            classFrequencies[classWeights.size()] = frequency;
            totalFrequency += frequency;
            classWeights.add(calculateClassVsRestWeights(svmOperator, exampleSet, label, value));
        }

        AttributeWeights result = new AttributeWeights();
        if (classWeights.isEmpty()) {
            // No labelled examples: avoids get(0) on an empty list and a 0/0 division.
            return result;
        }
        for (String attributeName : classWeights.get(0).getAttributeNames()) {
            double weightedSum = 0.0;
            int classIndex = 0;
            for (AttributeWeights weights : classWeights) {
                weightedSum += Math.abs(weights.getWeight(attributeName)) * (double) classFrequencies[classIndex++];
            }
            result.setWeight(attributeName, weightedSum / (double) totalFrequency);
        }
        return result;
    }

    /**
     * Trains the SVM on a temporary two-valued label. The label is "positive" for examples whose
     * {@code label} equals {@code value} and "negative" otherwise. The original label is
     * reinstated, and the temporary column removed from the example table, even if training fails.
     */
    private AttributeWeights calculateClassVsRestWeights(Operator svmOperator, ExampleSet exampleSet, Attribute label, String value) throws OperatorException {
        // NOTE(review): value type 6 is presumably Ontology.BINOMINAL -- confirm.
        Attribute tempLabel = AttributeFactory.createAttribute(TEMP_LABEL_NAME, 6);
        int positiveIndex = tempLabel.getMapping().mapString("positive");
        int negativeIndex = tempLabel.getMapping().mapString("negative");
        int currentLabelIndex = label.getMapping().mapString(value);
        exampleSet.getExampleTable().addAttribute(tempLabel);
        try {
            exampleSet.getAttributes().addRegular(tempLabel);
            for (Example example : exampleSet) {
                // Compare as double. Casting a missing label (NaN) to int yields 0, which would
                // wrongly mark it as the class with index 0. Here it counts as negative.
                boolean isCurrentClass = example.getValue(label) == currentLabelIndex;
                example.setValue(tempLabel, isCurrentClass ? positiveIndex : negativeIndex);
            }
            exampleSet.getAttributes().remove(tempLabel);
            exampleSet.getAttributes().setLabel(tempLabel);
            return calculateAttributeWeights(svmOperator, exampleSet);
        } finally {
            // Restore the caller's example set regardless of success or failure.
            exampleSet.getAttributes().setLabel(label);
            exampleSet.getExampleTable().removeAttribute(tempLabel);
        }
    }

    /** Applies the SVM to the example set and extracts the weights from its output. */
    private AttributeWeights calculateAttributeWeights(Operator svmOperator, ExampleSet exampleSet) throws OperatorException {
        IOContainer ioContainer = new IOContainer(exampleSet);
        ioContainer = svmOperator.apply(ioContainer);
        return ioContainer.remove(AttributeWeights.class);
    }

    /** Adds the SVM complexity parameter {@value #PARAMETER_C} (default 0.0) to the inherited parameters. */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeDouble(PARAMETER_C, "The SVM complexity weighting factor.", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.0));
        return types;
    }
}

