/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.operator.learner.subgroups;

import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.AbstractLearner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.subgroups.RuleSet;
import com.rapidminer.operator.learner.subgroups.hypothesis.Hypothesis;
import com.rapidminer.operator.learner.subgroups.hypothesis.Rule;
import com.rapidminer.operator.learner.subgroups.utility.UtilityFunction;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Subgroup discovery learner.
 *
 * <p>Runs a level-wise (breadth-first) search over {@link Hypothesis} objects. The search starts
 * from the refinements of the empty hypothesis; level {@code i} holds hypotheses with
 * {@code i + 1} literals. At each level the learner:
 * <ol>
 *   <li>applies every hypothesis to every example;</li>
 *   <li>drops low-coverage hypotheses and, optionally, caps how many are kept;</li>
 *   <li>generates rules from the survivors and scores them with all utility functions;</li>
 *   <li>refines those hypotheses whose optimistic estimate can still reach the current utility
 *       bound.</li>
 * </ol>
 *
 * <p>Two discovery modes are supported (parameter {@value #PARAMETER_DISCOVERY_MODE}):
 * <ul>
 *   <li>{@link #DISCOVERY_MODE_ABOVE_MINIMUM_UTILITY}: every rule whose main utility is at least
 *       {@value #PARAMETER_MIN_UTILITY} is kept;</li>
 *   <li>{@link #DISCOVERY_MODE_K_BEST_RULES}: only the {@value #PARAMETER_K_BEST_RULES} rules
 *       with the highest main utility are kept.</li>
 * </ul>
 */
public class SubgroupDiscovery extends AbstractLearner {

    /** Parameter selecting the discovery mode (index into {@link #DISCOVERY_MODES}). */
    public static final String PARAMETER_DISCOVERY_MODE = "mode";

    /** Display names of the discovery modes; indices match the {@code DISCOVERY_MODE_*} constants. */
    public static final String[] DISCOVERY_MODES = new String[]{"above minimum utility", "k best rules"};

    /** Keep every rule whose utility reaches the minimum utility. */
    public static final int DISCOVERY_MODE_ABOVE_MINIMUM_UTILITY = 0;

    /** Keep only the k rules with the highest utility. */
    public static final int DISCOVERY_MODE_K_BEST_RULES = 1;

    public static final String PARAMETER_UTILITY_FUNCTION = "utility_function";
    public static final String PARAMETER_RULE_GENERATION = "rule_generation";
    public static final String[] RULE_GENERATION_MODES = Hypothesis.RULE_GENERATION_MODES;
    public static final String PARAMETER_MAX_DEPTH = "max_depth";
    public static final String PARAMETER_MIN_UTILITY = "min_utility";
    public static final String PARAMETER_K_BEST_RULES = "k_best_rules";
    public static final String PARAMETER_MIN_COVERAGE = "min_coverage";
    public static final String PARAMETER_MAX_CACHE = "max_cache";

    public SubgroupDiscovery(OperatorDescription description) {
        super(description);
    }

    /**
     * Discovers subgroup rules on the given example set.
     *
     * @param exampleSet the training data; its label's mapping supplies the positive class index
     * @return a {@link RuleSet} holding the selected rules, sorted by descending main utility
     * @throws OperatorException as propagated from the parameter accessors / framework calls
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        int mode = this.getParameterAsInt(PARAMETER_DISCOVERY_MODE);
        int maxDepth = this.getParameterAsInt(PARAMETER_MAX_DEPTH);
        double minUtility = this.getParameterAsDouble(PARAMETER_MIN_UTILITY);
        int kBestRules = this.getParameterAsInt(PARAMETER_K_BEST_RULES);
        int ruleGenerationMode = this.getParameterAsInt(PARAMETER_RULE_GENERATION);
        double coverageThreshold = this.getParameterAsDouble(PARAMETER_MIN_COVERAGE);
        int maxCache = this.getParameterAsInt(PARAMETER_MAX_CACHE);

        // The number of search levels is bounded by both max_depth and the attribute count.
        int searchDepth = Math.min(maxDepth, exampleSet.getAttributes().size());

        // Sum the total and positive-class example weight. Each example weighs 1.0 when the set
        // has no weight attribute. Both lookups are loop-invariant, so they are done once here.
        boolean weighted = exampleSet.getAttributes().getWeight() != null;
        double positiveIndex = exampleSet.getAttributes().getLabel().getMapping().getPositiveIndex();
        double totalWeight = 0.0;
        double totalPositiveWeight = 0.0;
        for (Example example : exampleSet) {
            double weight = weighted ? example.getWeight() : 1.0;
            totalWeight += weight;
            if (example.getLabel() == positiveIndex) {
                totalPositiveWeight += weight;
            }
        }

        UtilityFunction[] utilityFunctions = UtilityFunction.getUtilityFunctions(totalWeight, totalPositiveWeight);
        UtilityFunction mainUtilityFunction = utilityFunctions[this.getParameterAsInt(PARAMETER_UTILITY_FUNCTION)];
        Class<? extends UtilityFunction> mainFunctionClass = mainUtilityFunction.getClass();
        RuleComparator ruleComparator = new RuleComparator(mainFunctionClass);

        LinkedList<Rule> acceptedRules = new LinkedList<Rule>();
        // Kept sorted by descending main utility, so index kBestRules - 1 is the weakest entry.
        ArrayList<Rule> bestRules = new ArrayList<Rule>(kBestRules);

        LinkedList<Hypothesis> hypotheses = new LinkedList<Hypothesis>();
        hypotheses.addAll(new Hypothesis().restrictedRefine(exampleSet.getAttributes()));

        for (int depth = 0; depth < searchDepth && !hypotheses.isEmpty(); ++depth) {
            this.log("evaluating " + hypotheses.size() + " hypotheses with " + (depth + 1) + " literals");
            for (Example example : exampleSet) {
                for (Hypothesis hypothesis : hypotheses) {
                    hypothesis.apply(example);
                }
            }

            removeLowCoverageHypotheses(hypotheses, totalWeight, coverageThreshold);
            if (maxCache != -1) {
                truncateToCache(hypotheses, maxCache);
            }

            this.log("generating rules from " + hypotheses.size() + " hypotheses");
            LinkedList<Hypothesis> nextHypotheses = new LinkedList<Hypothesis>();
            for (Hypothesis hypothesis : hypotheses) {
                LinkedList<Rule> rules = hypothesis.generateRules(ruleGenerationMode, exampleSet.getAttributes().getLabel());
                for (Rule rule : rules) {
                    // Record the rule's utility under every available function.
                    for (UtilityFunction function : utilityFunctions) {
                        rule.setUtility(function, function.utility(rule));
                    }
                    double utility = mainUtilityFunction.utility(rule);
                    switch (mode) {
                        case DISCOVERY_MODE_ABOVE_MINIMUM_UTILITY:
                            if (utility >= minUtility) {
                                acceptedRules.add(rule);
                                this.log("scored: " + rule);
                            }
                            break;
                        case DISCOVERY_MODE_K_BEST_RULES:
                            if (bestRules.size() < kBestRules) {
                                bestRules.add(rule);
                                this.log("scored: " + rule + " [q(h)=" + utility + "]");
                                Collections.sort(bestRules, ruleComparator);
                            } else if (utility > bestRules.get(kBestRules - 1).getUtility(mainFunctionClass)) {
                                bestRules.set(kBestRules - 1, rule);
                                this.log("scored: " + rule + " [q(h)=" + utility + "]");
                                Collections.sort(bestRules, ruleComparator);
                                // The pruning bound must be the k-th best utility, i.e. the
                                // weakest rule still in the list after re-sorting. The inserted
                                // rule's own utility can be higher than that. Using it would
                                // prune hypotheses that can still yield top-k rules.
                                minUtility = bestRules.get(kBestRules - 1).getUtility(mainFunctionClass);
                            }
                            break;
                        default:
                            // Unknown mode: no rule is collected (same as before).
                            break;
                    }
                }
                // Refine only hypotheses whose optimistic estimate can still reach the bound.
                // In k-best mode, min_utility serves as the bound until the first replacement.
                double optimisticEstimate = mainUtilityFunction.optimisticEstimate(hypothesis);
                if (optimisticEstimate >= minUtility) {
                    for (Hypothesis nextHypothesis : hypothesis.restrictedRefine()) {
                        nextHypotheses.add(nextHypothesis);
                    }
                }
            }
            hypotheses = nextHypotheses;
        }

        RuleSet model = new RuleSet(exampleSet);
        switch (mode) {
            case DISCOVERY_MODE_ABOVE_MINIMUM_UTILITY:
                Collections.sort(acceptedRules, ruleComparator);
                for (Rule rule : acceptedRules) {
                    model.addRule(rule);
                }
                break;
            case DISCOVERY_MODE_K_BEST_RULES:
                Collections.sort(bestRules, ruleComparator);
                for (Rule rule : bestRules) {
                    model.addRule(rule);
                }
                break;
            default:
                break;
        }
        return model;
    }

    /**
     * Removes, in place, every hypothesis whose covered weight is at most
     * {@code coverageThreshold} as a fraction of {@code totalWeight}. Logs how many were removed.
     */
    private void removeLowCoverageHypotheses(List<Hypothesis> hypotheses, double totalWeight, double coverageThreshold) {
        int discarded = 0;
        for (Iterator<Hypothesis> iterator = hypotheses.iterator(); iterator.hasNext();) {
            if (iterator.next().getCoveredWeight() / totalWeight <= coverageThreshold) {
                iterator.remove();
                ++discarded;
            }
        }
        if (discarded > 0) {
            this.log("removed " + discarded + " hypotheses not exceeding min coverage");
        }
    }

    /**
     * Sorts hypotheses by descending covered weight and drops those beyond the first
     * {@code maxCache}, i.e. the ones with the lowest coverage.
     */
    private void truncateToCache(LinkedList<Hypothesis> hypotheses, int maxCache) {
        Collections.sort(hypotheses, new HypothesisComparator());
        int deleteHypotheses = hypotheses.size() - maxCache;
        for (int j = 0; j < deleteHypotheses; ++j) {
            hypotheses.removeLast();
        }
        if (deleteHypotheses > 0) {
            this.log("removed " + deleteHypotheses + " hypotheses with the lowest coverage");
        }
    }

    /**
     * Supports polynominal and binominal attributes, a binominal class and weighted examples.
     */
    @Override
    public boolean supportsCapability(LearnerCapability lc) {
        return lc == LearnerCapability.POLYNOMINAL_ATTRIBUTES
                || lc == LearnerCapability.BINOMINAL_ATTRIBUTES
                || lc == LearnerCapability.BINOMINAL_CLASS
                || lc == LearnerCapability.WEIGHTED_EXAMPLES;
    }

    /**
     * Adds this learner's parameters to those of {@link AbstractLearner}.
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeCategory(PARAMETER_DISCOVERY_MODE, "Discovery mode.", DISCOVERY_MODES, 1));
        types.add(new ParameterTypeCategory(PARAMETER_UTILITY_FUNCTION, "Utility function.", UtilityFunction.FUNCTIONS, 6));
        types.add(new ParameterTypeDouble(PARAMETER_MIN_UTILITY, "Minimum quality which has to be reached.", Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.0));
        types.add(new ParameterTypeInt(PARAMETER_K_BEST_RULES, "Report the k best rules.", 1, Integer.MAX_VALUE, 10));
        types.add(new ParameterTypeCategory(PARAMETER_RULE_GENERATION, "Determines which rules are generated.", RULE_GENERATION_MODES, 3));
        types.add(new ParameterTypeInt(PARAMETER_MAX_DEPTH, "Maximum depth of BFS.", 0, Integer.MAX_VALUE, 5));
        types.add(new ParameterTypeDouble(PARAMETER_MIN_COVERAGE, "Only consider rules which exceed the given coverage threshold.", 0.0, 1.0, 0.0));
        types.add(new ParameterTypeInt(PARAMETER_MAX_CACHE, "Bounds the number of rules which are evaluated (only the most supported rules are used).", -1, Integer.MAX_VALUE, -1));
        return types;
    }

    /** Orders hypotheses by descending covered weight. */
    private static class HypothesisComparator implements Comparator<Hypothesis> {
        private HypothesisComparator() {
        }

        @Override
        public int compare(Hypothesis firstHypothesis, Hypothesis secondHypothesis) {
            return Double.compare(secondHypothesis.getCoveredWeight(), firstHypothesis.getCoveredWeight());
        }
    }

    /** Orders rules by descending utility, as stored for the given utility function class. */
    private static class RuleComparator implements Comparator<Rule> {
        private final Class<? extends UtilityFunction> functionClass;

        public RuleComparator(Class<? extends UtilityFunction> functionClass) {
            this.functionClass = functionClass;
        }

        @Override
        public int compare(Rule firstRule, Rule secondRule) {
            return Double.compare(secondRule.getUtility(this.functionClass), firstRule.getUtility(this.functionClass));
        }
    }
}

