package eu.dnetlib.data.transform;

import com.wcohen.ss.JaroWinkler;
import eu.dnetlib.data.bulktag.Pair;
import eu.dnetlib.data.proto.FieldTypeProtos.Author;
import eu.dnetlib.data.proto.FieldTypeProtos.KeyValue;
import eu.dnetlib.pace.model.Person;
import org.apache.commons.lang3.StringUtils;

import java.text.Normalizer;
import java.util.*;
import java.util.function.Function;

import static java.util.stream.Collectors.*;

/**
 * Merges several author lists into a single list, using the list that carries the highest
 * number of ORCID identifiers as the reference ("pivot") list and enriching its ORCID-less
 * authors with data from similarly named, ORCID-bearing authors found in the other lists.
 */
public class AuthorMerger {

    /** Default minimum name similarity score required to merge a candidate into a pivot author. */
    private static final double THRESHOLD = 0.95;

    /** PID key identifying an ORCID; compared case-insensitively. */
    private static final String ORCID = "orcid";

    /** Upper bound on the number of ORCID-bearing authors taken from the non-pivot lists. */
    private static final int MAX_AUTHORS = 200;

    /**
     * Merges the given author lists using a caller-supplied similarity threshold.
     *
     * @param authors   the author lists to merge, one list per source record
     * @param threshold minimum similarity score (inclusive) a candidate must reach against a
     *                  pivot author to be merged into it
     * @return the merged author list; empty when {@code authors} is empty
     */
    public static List<Author> merge(final Collection<List<Author>> authors, final double threshold) {
        // Normalise ORCID pids up front so that every later comparison sees the same key/value shape.
        final List<List<Author>> normalised = authors.stream()
                .map(group -> group.stream()
                        .map(AuthorMerger::fixORCID)
                        .collect(toList()))
                .collect(toList());
        return doMerge(normalised, threshold);
    }

    /**
     * Merges the given author lists using the default similarity threshold.
     *
     * @param authors the author lists to merge, one list per source record
     * @return the merged author list; empty when {@code authors} is empty
     */
    public static List<Author> merge(final Collection<List<Author>> authors) {
        return merge(authors, THRESHOLD);
    }

    /**
     * Core merge routine; expects ORCID pids to be already normalised by {@link #fixORCID(Author)}.
     *
     * <p>Result layout: the ORCID-bearing authors of the pivot list come first, followed by the
     * pivot authors without ORCID, each enriched with the matching candidates.
     */
    private static List<Author> doMerge(final Collection<List<Author>> authors, final double threshold) {
        if (authors.isEmpty()) {
            return new ArrayList<>();
        }

        final List<Author> first = authors.iterator().next();
        if (authors.size() == 1) {
            return first;
        }

        final Optional<List<Author>> richest = selectPivots(authors);
        if (!richest.isPresent()) {
            // No list carries any ORCID: there is nothing to propagate, keep the first list as is.
            return first;
        }
        final List<Author> pivots = richest.get();

        final List<Author> res = pivots.stream()
                .filter(AuthorMerger::hasOrcid)
                .collect(toCollection(ArrayList::new));

        if (pivots.size() == res.size()) {
            // Every pivot author already has an ORCID, no enrichment needed.
            return res;
        }

        final Collection<Author> candidates = collectCandidates(authors, pivots);

        for (final Author pivot : pivots) {
            if (!hasOrcid(pivot)) {
                res.add(enrich(pivot, candidates, threshold));
            }
        }

        return res;
    }

    /**
     * Picks the list with the highest (strictly positive) number of ORCID-bearing authors.
     * On ties the list encountered first wins.
     *
     * @return the pivot list, or empty when no list contains any ORCID
     */
    private static Optional<List<Author>> selectPivots(final Collection<List<Author>> authors) {
        List<Author> best = null;
        int bestCount = 0;
        for (final List<Author> group : authors) {
            final int count = countOrcid(group);
            if (count > bestCount) {
                best = group;
                bestCount = count;
            }
        }
        return Optional.ofNullable(best);
    }

    /**
     * Gathers the ORCID-bearing authors from every list not equal to the pivot list, capped at
     * {@link #MAX_AUTHORS} and de-duplicated by ORCID value (the last occurrence wins).
     */
    private static Collection<Author> collectCandidates(final Collection<List<Author>> authors, final List<Author> pivots) {
        return authors.stream()
                .filter(group -> !group.equals(pivots))
                .flatMap(List::stream)
                .filter(AuthorMerger::hasOrcid)
                .limit(MAX_AUTHORS)
                .collect(toMap(
                        AuthorMerger::orcidOf,
                        author -> author,
                        (a1, a2) -> a2))
                .values();
    }

    /**
     * Builds a copy of {@code pivot} with every sufficiently similar candidate merged into it,
     * then collapses the resulting pids so that each key appears once (the last value wins).
     */
    private static Author enrich(final Author pivot, final Collection<Author> candidates, final double threshold) {
        final Author.Builder b = Author.newBuilder(pivot);

        // Deliberately sequential: the builder is shared mutable state and must not be
        // updated concurrently from a parallel stream.
        for (final Author candidate : candidates) {
            if (sim(candidate, pivot) >= threshold) {
                b.mergeFrom(candidate);
            }
        }

        // LinkedHashMap keeps the pids in first-seen order instead of hash order.
        final Collection<KeyValue> pids = b.getPidList().stream()
                .collect(toMap(
                        KeyValue::getKey,
                        kv -> kv,
                        (kv1, kv2) -> kv2,
                        LinkedHashMap::new))
                .values();
        b.clearPid();
        b.addAllPid(pids);

        return b.build();
    }

    /**
     * Returns the value of the first ORCID pid of the given author.
     *
     * @throws IllegalStateException when the author has no ORCID pid; callers are expected to
     *                               filter with {@link #hasOrcid(Author)} first
     */
    private static String orcidOf(final Author author) {
        return author.getPidList().stream()
                .filter(p -> p.getKey().equalsIgnoreCase(ORCID))
                .map(KeyValue::getValue)
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("author has no ORCID pid"));
    }

    /**
     * Returns a copy of the author in which every pid whose key contains "orcid" (ignoring case)
     * has its key set to {@code "ORCID"}; when such a pid's value contains {@code "orcid.org"},
     * the value is replaced by the text following its last {@code '/'}.
     */
    private static Author fixORCID(final Author author) {
        final Author.Builder b = Author.newBuilder(author);
        for (final KeyValue.Builder pid : b.getPidBuilderList()) {
            // Locale.ROOT: the default locale may map 'I' to a dotless 'ı' (e.g. Turkish),
            // which would make "ORCID" fail the check below.
            if (pid.getKey().toLowerCase(Locale.ROOT).contains(ORCID)) {
                pid.setKey("ORCID");
                if (pid.getValue().contains("orcid.org")) {
                    pid.setValue(StringUtils.substringAfterLast(pid.getValue(), "/"));
                }
            }
        }
        return b.build();
    }

    /** Counts the authors of the list that carry at least one ORCID pid. */
    private static int countOrcid(final List<Author> authors) {
        return (int) authors.stream()
                .filter(AuthorMerger::hasOrcid)
                .count();
    }

    /** Tells whether the author has a pid whose key equals "orcid", ignoring case. */
    private static boolean hasOrcid(final Author a) {
        return a.getPidList().stream().anyMatch(p -> p.getKey().equalsIgnoreCase(ORCID));
    }

    /**
     * Computes the Jaro-Winkler score between the normalised names of two authors: surnames are
     * compared when both parsed persons report {@code isAccurate()}, full names otherwise.
     */
    private static double sim(final Author a, final Author b) {
        final Person pa = parse(a);
        final Person pb = parse(b);

        if (pa.isAccurate() && pb.isAccurate()) {
            return new JaroWinkler().score(
                    normalize(pa.getSurnameString()),
                    normalize(pb.getSurnameString()));
        }
        return new JaroWinkler().score(
                normalize(pa.getNormalisedFullname()),
                normalize(pb.getNormalisedFullname()));
    }

    /**
     * Builds a {@link Person} from "surname, name" when the author has a surname, or from the
     * full name otherwise.
     */
    private static Person parse(final Author author) {
        if (author.hasSurname()) {
            return new Person(author.getSurname() + ", " + author.getName(), false);
        }
        return new Person(author.getFullname(), false);
    }

    /**
     * Lower-cases the NFD-decomposed string and replaces runs of non-word characters, combining
     * diacritical marks, punctuation, digits and newlines with a single space, then trims.
     */
    private static String normalize(final String s) {
        return nfd(s).toLowerCase(Locale.ROOT)
                // do not compact the regexes in a single expression, would cause StackOverflowError in case of large input strings
                .replaceAll("(\\W)+", " ")
                .replaceAll("(\\p{InCombiningDiacriticalMarks})+", " ")
                .replaceAll("(\\p{Punct})+", " ")
                .replaceAll("(\\d)+", " ")
                .replaceAll("(\\n)+", " ")
                .trim();
    }

    /** Applies Unicode canonical decomposition (NFD) to the string. */
    private static String nfd(final String s) {
        return Normalizer.normalize(s, Normalizer.Form.NFD);
    }

}
