/*
 * Decompiled with CFR 0.152.
 */
package eu.dnetlib.dhp.oa.dedup;

import eu.dnetlib.dhp.application.ArgumentApplicationParser;
import eu.dnetlib.dhp.oa.dedup.AbstractSparkAction;
import eu.dnetlib.dhp.oa.dedup.DedupUtility;
import eu.dnetlib.dhp.oa.dedup.IdGenerator;
import eu.dnetlib.dhp.oa.dedup.OpenorgsUtility;
import eu.dnetlib.dhp.oa.dedup.SparkCreateMergeRels;
import eu.dnetlib.dhp.schema.common.ModelSupport;
import eu.dnetlib.dhp.schema.oaf.DataInfo;
import eu.dnetlib.dhp.schema.oaf.Qualifier;
import eu.dnetlib.dhp.schema.oaf.Relation;
import eu.dnetlib.dhp.utils.ISLookupClientFactory;
import eu.dnetlib.enabling.is.lookup.rmi.ISLookUpException;
import eu.dnetlib.enabling.is.lookup.rmi.ISLookUpService;
import eu.dnetlib.pace.config.Config;
import eu.dnetlib.pace.config.DedupConfig;
import eu.dnetlib.pace.model.SparkDeduper;
import eu.dnetlib.pace.tree.support.TreeProcessor;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.Spliterators;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.commons.io.IOUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.FlatMapGroupsFunction;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.dom4j.DocumentException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xml.sax.SAXException;

public class SparkRefineMergeRels
extends AbstractSparkAction {
    private static final Logger log = LoggerFactory.getLogger(SparkRefineMergeRels.class);

    /**
     * Schema of the rows emitted when a conflictual group is split: two non-nullable string
     * columns, {@code source} (the new dedup id) and {@code target} (the member identifier).
     */
    private static final StructType rowSchema = new StructType(new StructField[]{new StructField("source", DataTypes.StringType, false, Metadata.empty()), new StructField("target", DataTypes.StringType, false, Metadata.empty())});

    /**
     * Minimum average similarity a candidate must have towards the current members of a group
     * in order to join it (see {@code allMaxCliquesGreedyFinder}). Never reassigned, hence final.
     */
    private static final double THRESHOLD = 0.7;

    /**
     * Creates the action by forwarding both arguments to the {@link AbstractSparkAction} constructor.
     *
     * @param parser the parsed command-line arguments of the job
     * @param spark  the Spark session used by the job
     */
    public SparkRefineMergeRels(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    /**
     * Job entry point: loads the parameter specification from the classpath, parses the
     * command-line arguments, configures Kryo serialization for the OAF model classes and
     * runs the refinement against the lookup service found at {@code isLookUpUrl}.
     *
     * @param args command-line arguments, validated against {@code refineMergeRels_parameters.json}
     * @throws IllegalStateException if the parameter specification is not on the classpath
     * @throws Exception             on any parsing, lookup or Spark failure
     */
    public static void main(String[] args) throws Exception {
        final String parametersResource = "/eu/dnetlib/dhp/oa/dedup/refineMergeRels_parameters.json";
        final String parametersJson;
        // Close the resource stream deterministically and decode it with an explicit charset
        // instead of relying on the platform default.
        try (InputStream in = SparkRefineMergeRels.class.getResourceAsStream(parametersResource)) {
            if (in == null) {
                throw new IllegalStateException("Missing classpath resource: " + parametersResource);
            }
            parametersJson = IOUtils.toString(in, "UTF-8");
        }

        ArgumentApplicationParser parser = new ArgumentApplicationParser(parametersJson);
        parser.parseArgument(args);

        String isLookUpUrl = parser.get("isLookUpUrl");
        log.info("isLookupUrl {}", isLookUpUrl);

        SparkConf conf = new SparkConf();
        conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
        conf.registerKryoClasses(ModelSupport.getOafModelClasses());

        new SparkRefineMergeRels(parser, getSparkWithHiveSession(conf))
            .run(ISLookupClientFactory.getLookUpService(isLookUpUrl));
    }

    /**
     * Refines the merge relations of every dedup configuration of the given action set.
     *
     * <p>For each configuration the entities are loaded and enriched with their negative
     * constraints, the groups whose members share a constraint are detected, those groups are
     * split again via {@link #splitGroup}, and the resulting relations are written together
     * with the untouched ones to {@code <mergeRelPath>_refined}, which is then renamed over
     * the original merge relation path.
     *
     * @param isLookUpService lookup service used to fetch the dedup configurations
     */
    @Override
    void run(ISLookUpService isLookUpService) throws DocumentException, IOException, ISLookUpException, SAXException {
        final String graphBasePath = parser.get("graphBasePath");
        final String isLookUpUrl = parser.get("isLookUpUrl");
        final String actionSetId = parser.get("actionSetId");
        final String workingPath = parser.get("workingPath");
        final int numPartitions = Optional.ofNullable(parser.get("numPartitions")).map(Integer::valueOf).orElse(1000);

        log.info("numPartitions: '{}'", numPartitions);
        log.info("graphBasePath: '{}'", graphBasePath);
        log.info("isLookUpUrl:   '{}'", isLookUpUrl);
        log.info("actionSetId:   '{}'", actionSetId);
        log.info("workingPath:   '{}'", workingPath);

        for (DedupConfig dedupConf : getConfigurations(isLookUpService, actionSetId)) {
            final String subEntity = dedupConf.getWf().getSubEntityValue();
            log.info("Processing mergerels for: '{}'", subEntity);

            final String mergeRelPath = DedupUtility.createMergeRelPath(workingPath, actionSetId, subEntity);
            final String refinedPath = mergeRelPath + "_refined";

            // Entities parsed with the dedup model, plus their negativeConstraints column.
            final SparkDeduper deduper = new SparkDeduper(dedupConf);
            final Dataset<Row> parsedEntities = spark
                .read()
                .textFile(DedupUtility.createEntityPath(graphBasePath, subEntity))
                .transform(deduper.model().parseJsonDataset());
            final Dataset<Row> entities = appendNegativeConstraints(spark, parsedEntities, graphBasePath, subEntity);

            // One row per (groupId, member entity): the "merges" relations joined with their target.
            final Dataset<Row> rawMergeRels = spark
                .read()
                .load(mergeRelPath)
                .as(Encoders.bean(Relation.class))
                .where("relClass == 'merges'")
                .select("source", "target")
                .join(entities, functions.col("target").equalTo(entities.col("identifier")))
                .withColumnRenamed("source", "groupId");

            final Dataset<Row> conflictualIds = getConflictualIds(rawMergeRels);

            // Members of the groups that have to be split.
            final Dataset<Row> conflictualMergeRels = rawMergeRels
                .join(conflictualIds, rawMergeRels.col("groupId").equalTo(functions.col("conflictualGroupId")), "left_semi");

            // Relations (any relClass) that do not touch a conflictual group are kept as they are.
            final Dataset<Relation> cleanMergeRels = spark
                .read()
                .load(mergeRelPath)
                .as(Encoders.bean(Relation.class))
                .join(conflictualIds, functions.expr("source = conflictualGroupId OR target = conflictualGroupId"), "left_anti")
                .as(Encoders.bean(Relation.class));

            final Encoder<Row> splitEncoder = RowEncoder.apply(rowSchema);
            final Dataset<Row> splitMergeRels = conflictualMergeRels
                .groupByKey((MapFunction<Row, String>) row -> row.getAs("groupId"), Encoders.STRING())
                .flatMapGroups((FlatMapGroupsFunction<String, Row, Row>) (groupId, rows) -> splitGroup(rows, dedupConf), splitEncoder);

            // Every (dedupId, memberId) pair yields the relation in both directions.
            final Dataset<Relation> refinedRels = splitMergeRels
                .flatMap((FlatMapFunction<Row, Relation>) pair -> {
                    final String dedupId = pair.getString(0);
                    final String memberId = pair.getString(1);
                    return Arrays
                        .asList(
                            rel(dedupId, memberId, "merges", dedupConf),
                            rel(memberId, dedupId, "isMergedIn", dedupConf))
                        .iterator();
                }, Encoders.bean(Relation.class));

            saveParquet(refinedRels.union(cleanMergeRels), refinedPath, SaveMode.Overwrite);
            renameParquet(spark, refinedPath, mergeRelPath);
        }
    }

    /**
     * Adds a {@code negativeConstraints} array column to the entities. Two entities sharing at
     * least one constraint value are treated as incompatible by {@link #weigh}.
     *
     * <p>For {@code organization} the constraints are:
     * <ul>
     *   <li>{@code "nwo"} when the lower-cased identifier contains {@code nwo};</li>
     *   <li>the {@code groupId} (as string) of the family the organization belongs to, as
     *       returned by {@code OpenorgsUtility.createFamilies} for {@code IsParentOf} relations
     *       (presumably one group per parent/child hierarchy; verify against that helper).</li>
     * </ul>
     *
     * <p>For every other sub-entity no constraint is defined, so the entities are returned
     * unchanged apart from an empty {@code negativeConstraints} column. Returning an empty
     * DataFrame here would drop all entities and break callers that resolve the
     * {@code identifier} and {@code negativeConstraints} columns on the result.
     *
     * @param spark         the Spark session
     * @param entities      the entities, with at least an {@code identifier} column
     * @param graphBasePath base path of the graph; relations are read from {@code graphBasePath + "/relation"}
     * @param subEntity     the sub-entity type being processed
     * @return the entities enriched with the {@code negativeConstraints} column
     */
    public static Dataset<Row> appendNegativeConstraints(SparkSession spark, Dataset<Row> entities, String graphBasePath, String subEntity) {
        if (!"organization".equals(subEntity)) {
            // Explicit element type: an empty array literal has no element to infer it from.
            return entities.withColumn("negativeConstraints", functions.array(new Column[0]).cast("array<string>"));
        }

        Column nwoConstraint = functions
            .when(functions.lower(functions.col("identifier")).contains("nwo"), functions.array(new Column[]{functions.lit("nwo")}))
            .otherwise(functions.array(new Column[0]));
        Dataset<Row> entitiesWithNwo = entities.withColumn("negativeConstraints", nwoConstraint);

        Dataset<Row> families = OpenorgsUtility.createFamilies(spark, graphBasePath + "/relation", "IsParentOf");
        Column familyConstraint = functions
            .when(families.col("groupId").isNotNull(), functions.array(new Column[]{families.col("groupId").cast("string")}))
            .otherwise(functions.array(new Column[0]));

        // Left join: organizations outside any family keep only their nwo constraint (if any).
        return entitiesWithNwo
            .join(families, entitiesWithNwo.col("identifier").equalTo(families.col("id")), "left")
            .withColumn("negativeConstraints", functions.array_union(functions.col("negativeConstraints"), familyConstraint))
            .drop(new String[]{"id", "groupId", "nwoConstraint", "familyConstraint"});
    }

    /**
     * Finds the groups in which at least one negative constraint value occurs more than once.
     *
     * <p>The constraint arrays of all members of a group are concatenated; the group is
     * conflictual when the concatenation is longer than its de-duplicated version.
     *
     * @param rawMergeRels rows with at least the {@code groupId} and {@code negativeConstraints} columns
     * @return a single-column dataset ({@code conflictualGroupId}) with the ids of the conflictual groups
     */
    private static Dataset<Row> getConflictualIds(Dataset<Row> rawMergeRels) {
        final Column pooledConstraints = functions
            .flatten(functions.collect_list(functions.col("negativeConstraints")))
            .alias("allConstraints");

        final Column allConstraints = functions.col("allConstraints");
        final Column hasRepeatedConstraint = functions
            .size(allConstraints)
            .notEqual(functions.size(functions.array_distinct(allConstraints)));

        return rawMergeRels
            .select("groupId", "negativeConstraints")
            .groupBy("groupId")
            .agg(pooledConstraints)
            .where(hasRepeatedConstraint)
            .select(functions.col("groupId").as("conflictualGroupId"));
    }

    /**
     * Builds a dedup relation between two identifiers.
     *
     * <p>The relation type is the entity type followed by its capitalized form (for example
     * {@code organizationOrganization}), the sub-type is {@code dedup}, and the data info marks
     * the relation as inferred with the configuration id as inference provenance.
     *
     * @param source    identifier of the relation source
     * @param target    identifier of the relation target
     * @param relClass  relation class to set
     * @param dedupConf dedup configuration providing the entity type and the configuration id
     * @return the populated relation
     */
    private static Relation rel(String source, String target, String relClass, DedupConfig dedupConf) {
        final String entityType = dedupConf.getWf().getEntityType();
        final String capitalizedType = entityType.substring(0, 1).toUpperCase() + entityType.substring(1);

        final Relation relation = new Relation();
        relation.setSource(source);
        relation.setTarget(target);
        relation.setRelClass(relClass);
        relation.setRelType(entityType + capitalizedType);
        relation.setSubRelType("dedup");

        final Qualifier provenance = new Qualifier();
        provenance.setClassid("sysimport:dedup");
        provenance.setClassname("sysimport:dedup");
        provenance.setSchemeid("dnet:provenanceActions");
        provenance.setSchemename("dnet:provenanceActions");

        final DataInfo dataInfo = new DataInfo();
        dataInfo.setDeletedbyinference(false);
        dataInfo.setInferred(true);
        dataInfo.setInvisible(false);
        dataInfo.setInferenceprovenance(dedupConf.getWf().getConfigurationId());
        dataInfo.setProvenanceaction(provenance);

        relation.setDataInfo(dataInfo);
        return relation;
    }

    /**
     * Splits the members of one conflictual group into smaller compatible groups.
     *
     * <p>The pairwise similarity matrix of the members is computed, the greedy group finder is
     * run on it, and for every resulting group with at least two members a new dedup id is
     * generated from the identifier of the first member returned by the group's stream.
     * Members that end up alone produce no row.
     *
     * @param values    the rows of the group; each must expose an {@code identifier} column
     * @param dedupConf dedup configuration used to score the pairs
     * @return rows of the form {@code (newDedupId, memberIdentifier)}
     */
    public static Iterator<Row> splitGroup(Iterator<Row> values, DedupConfig dedupConf) {
        final List<Row> members = new ArrayList<>();
        values.forEachRemaining(members::add);

        final double[][] similarities = getSimMatrix(members, dedupConf);

        final List<Row> splitRels = new ArrayList<>();
        for (Set<Integer> group : allMaxCliquesGreedyFinder(similarities)) {
            if (group.size() < 2) {
                continue;
            }
            final int representative = group.stream().findFirst().orElseThrow(RuntimeException::new);
            final String dedupId = IdGenerator.generate((String) members.get(representative).getAs("identifier"));
            for (int memberIndex : group) {
                final Object memberId = members.get(memberIndex).getAs("identifier");
                splitRels.add(RowFactory.create(dedupId, memberId));
            }
        }
        return splitRels.iterator();
    }

    /**
     * Greedily partitions the indices {@code 0..n-1} into groups of mutually compatible nodes.
     *
     * <p>Every not-yet-assigned index seeds a group; each other unassigned index is then
     * visited in ascending order and joins the group when none of its similarities towards the
     * current members (in either direction) is {@code -1.0} and the average of its similarities
     * towards the current members is at least {@code THRESHOLD}. Each index ends up in exactly
     * one group; singletons are returned too.
     *
     * @param simMatrix square matrix of pairwise similarities, where {@code -1.0} marks an
     *                  incompatible pair
     * @return the groups, in the order of their seed index
     */
    public static List<Set<Integer>> allMaxCliquesGreedyFinder(double[][] simMatrix) {
        final int size = simMatrix.length;
        final boolean[] assigned = new boolean[size];
        final List<Set<Integer>> groups = new ArrayList<>();

        for (int seed = 0; seed < size; seed++) {
            if (assigned[seed]) {
                continue;
            }
            final Set<Integer> group = new HashSet<>();
            group.add(seed);

            for (int candidate = 0; candidate < size; candidate++) {
                if (candidate == seed || assigned[candidate]) {
                    continue;
                }
                boolean conflict = false;
                double similaritySum = 0.0;
                int compared = 0;
                for (int member : group) {
                    final double similarity = simMatrix[candidate][member];
                    if (similarity == -1.0 || simMatrix[member][candidate] == -1.0) {
                        conflict = true;
                        break;
                    }
                    similaritySum += similarity;
                    compared++;
                }
                // Written as ">=" so that a NaN average keeps the candidate out of the group.
                if (!conflict && compared > 0 && similaritySum / compared >= THRESHOLD) {
                    group.add(candidate);
                }
            }

            for (int member : group) {
                assigned[member] = true;
            }
            groups.add(group);
        }
        return groups;
    }

    /**
     * Computes the symmetric pairwise similarity matrix of the given entities.
     *
     * <p>The diagonal is {@code 1.0}; every other cell holds the value returned by
     * {@link #weigh} for the pair, computed once per unordered pair and mirrored.
     *
     * @param entities  the rows to compare
     * @param dedupConf dedup configuration used to build the {@code TreeProcessor}
     * @return an {@code n x n} matrix, where {@code n} is the number of entities
     */
    public static double[][] getSimMatrix(List<Row> entities, DedupConfig dedupConf) {
        final TreeProcessor scorer = new TreeProcessor((Config) dedupConf);
        final int size = entities.size();
        final double[][] matrix = new double[size][size];

        for (int row = 0; row < size; row++) {
            matrix[row][row] = 1.0;
            final Row left = entities.get(row);
            for (int col = row + 1; col < size; col++) {
                final double score = weigh(left, entities.get(col), scorer);
                matrix[row][col] = score;
                matrix[col][row] = score;
            }
        }
        return matrix;
    }

    /**
     * Scores the similarity of two entities, honouring their negative constraints.
     *
     * @param a             first entity
     * @param b             second entity
     * @param treeProcessor processor used to score the pair when no constraint forbids it
     * @return {@code -1.0} when the two entities share at least one negative constraint,
     *         otherwise the score computed by {@code treeProcessor}
     */
    public static double weigh(Row a, Row b, TreeProcessor treeProcessor) {
        final List<Object> constraintsA = negativeConstraintsOf(a);
        final List<Object> constraintsB = negativeConstraintsOf(b);
        // disjoint() is true whenever either list is empty, so no separate emptiness check is needed.
        if (!Collections.disjoint(constraintsA, constraintsB)) {
            return -1.0;
        }
        return treeProcessor.computeScore(a, b);
    }

    /**
     * Reads the {@code negativeConstraints} column of a row, treating a row without schema,
     * a missing column and a null value all as "no constraints".
     */
    private static List<Object> negativeConstraintsOf(Row row) {
        if (row.schema() == null || !Arrays.asList(row.schema().fieldNames()).contains("negativeConstraints")) {
            return Collections.emptyList();
        }
        final int index = row.fieldIndex("negativeConstraints");
        if (row.isNullAt(index)) {
            return Collections.emptyList();
        }
        return row.getList(index);
    }
}

