/*
 * Decompiled with CFR 0.152.
 */
package se.lth.cs.srl.pipeline;

import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.zip.ZipEntry;
import java.util.zip.ZipException;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
import se.lth.cs.srl.Learn;
import se.lth.cs.srl.SemanticRoleLabeler;
import se.lth.cs.srl.corpus.ArgMap;
import se.lth.cs.srl.corpus.Predicate;
import se.lth.cs.srl.corpus.Sentence;
import se.lth.cs.srl.corpus.Word;
import se.lth.cs.srl.features.Feature;
import se.lth.cs.srl.features.FeatureGenerator;
import se.lth.cs.srl.features.FeatureSet;
import se.lth.cs.srl.io.AllCoNLL09Reader;
import se.lth.cs.srl.languages.Language;
import se.lth.cs.srl.ml.Model;
import se.lth.cs.srl.ml.liblinear.Label;
import se.lth.cs.srl.ml.liblinear.LibLinearLearningProblem;
import se.lth.cs.srl.options.LearnOptions;
import se.lth.cs.srl.options.ParseOptions;
import se.lth.cs.srl.pipeline.AbstractStep;
import se.lth.cs.srl.pipeline.ArgumentClassifier;
import se.lth.cs.srl.pipeline.ArgumentIdentifier;
import se.lth.cs.srl.pipeline.Pipeline;
import se.lth.cs.srl.pipeline.Step;
import se.lth.cs.srl.util.BrownCluster;

/**
 * Global reranker for semantic role labeling.
 *
 * <p>For every predicate, candidate argument maps are produced by beam search over the argument
 * identification (AI) step followed by the argument classification (AC) step. Each candidate is
 * described by the AI/AC feature indices of its arguments plus one global feature for its core
 * argument label sequence, and a binary model ({@link AbstractStep#POSITIVE} vs.
 * {@link AbstractStep#NEGATIVE}) assigns it a rerank probability. The chosen candidate maximizes
 * {@code localProb * rerankProb^alfa}.
 *
 * <p>Feature index layout shared by training and parsing:
 * <ul>
 *   <li>{@code [0, sizeAIFeatures)}: AI features;</li>
 *   <li>starting at {@code sizeAIFeatures + sizeACFeatures * labelIndex}: one block of AC
 *       features per argument label in {@code argLabels};</li>
 *   <li>{@code sizePipelineFeatures + k}: the core-argument-label-sequence feature with id
 *       {@code k >= 1} taken from {@code calsMap}.</li>
 * </ul>
 */
public class Reranker
extends SemanticRoleLabeler {
    /** Name of the zip entry holding the serialized reranker model and label-sequence map. */
    public static final String FILENAME = "global";
    /** Exponent applied to the rerank probability when combining it with the local probability. */
    private final double alfa;
    /** When true, predicate identification is not run at parse time. */
    private final boolean noPI;
    private final int aiBeam;
    private final int acBeam;
    private Model model;
    private List<String> argLabels;
    private List<Feature> aiFeatures;
    private List<Feature> acFeatures;
    private int sizeAIFeatures;
    private int sizeACFeatures;
    /** Total size of the AI block plus one AC block per argument label. */
    private int sizePipelineFeatures;
    /** Core argument label sequence -> global feature id (ids are assigned from 1 upwards). */
    private Map<String, Integer> calsMap;
    private int calsCounter = 1;
    private Pipeline pipeline;
    private ArgumentIdentifier aiModule;
    private ArgumentClassifier acModule;
    /** rankCount[i] counts how often the candidate at position i of the beam was chosen. */
    private int[] rankCount;
    private int zeroArgMapCount = 0;

    /**
     * Loads a trained pipeline and reranker from the model zip file named in the parse options.
     *
     * @param parseOptions supplies the model file, beam widths, {@code alfa} and the skip-PI flag
     * @throws ZipException if the model file is not a valid zip archive
     * @throws IOException if the archive cannot be read or has no {@value #FILENAME} entry
     * @throws ClassNotFoundException if the serialized reranker model cannot be deserialized
     */
    public Reranker(ParseOptions parseOptions) throws ZipException, IOException, ClassNotFoundException {
        this(parseOptions.global_alfa, parseOptions.skipPI, parseOptions.global_aiBeam, parseOptions.global_acBeam);
        ZipFile zipFile = new ZipFile(parseOptions.modelFile);
        try {
            this.pipeline = this.noPI
                    ? Pipeline.fromZipFile(zipFile, new Step[]{Step.pd, Step.ai, Step.ac})
                    : Pipeline.fromZipFile(zipFile);
            System.out.println("Loading reranker from " + zipFile.getName());
            if (this.noPI) {
                System.out.println("Skipping predicate identification. Input is assumed to have predicates identified.");
            }
            this.argLabels = this.pipeline.getArgLabels();
            this.populateRerankerFeatureSets(this.pipeline.getFeatureSets(), this.pipeline.getFg());
            this.readRerankerModel(zipFile);
            // With PI skipped the steps are [pd, ai, ac]; otherwise one more step precedes AI.
            int aiIndex = this.noPI ? 1 : 2;
            this.aiModule = (ArgumentIdentifier)this.pipeline.steps.get(aiIndex);
            this.acModule = (ArgumentClassifier)this.pipeline.steps.get(aiIndex + 1);
        } finally {
            zipFile.close();
        }
    }

    /** Stores the scoring parameters and sizes the rank histogram to the maximum beam output. */
    private Reranker(double alfa, boolean noPI, int aiBeam, int acBeam) {
        this.alfa = alfa;
        this.noPI = noPI;
        this.aiBeam = aiBeam;
        this.acBeam = acBeam;
        this.rankCount = new int[aiBeam * acBeam];
    }

    /**
     * Trains a full pipeline and then a reranker on top of it, writing both into {@code zos}.
     *
     * <p>The training corpus is split into {@code global_numberOfCrossTrain} partitions. For each
     * partition a fresh pipeline is trained on the remaining partitions and used to produce beam
     * candidates for the held-out sentences. Candidates with the best score against the gold
     * argument map become positive instances and the rest negative ones.
     *
     * @param learnOptions training options (corpus, feature files, beams, cross-training setup)
     * @param zos model archive being written; the reranker is stored under {@value #FILENAME}
     * @throws IOException if reading the corpus or writing the model fails
     */
    public Reranker(LearnOptions learnOptions, ZipOutputStream zos) throws IOException {
        this(1.0, false, learnOptions.global_aiBeam, learnOptions.global_acBeam);
        List<Sentence> trainCorpus = new AllCoNLL09Reader(learnOptions.inputCorpus).readAll();
        BrownCluster bc = learnOptions.brownClusterFile == null ? null : new BrownCluster(learnOptions.brownClusterFile);
        Pipeline fullPipeline = Pipeline.trainNewPipeline(trainCorpus, learnOptions.getFeatureFiles(), zos, bc);
        FeatureGenerator fg = fullPipeline.getFg();
        this.argLabels = fullPipeline.getArgLabels();
        HashMap<Step, FeatureSet> featureSets = new HashMap<Step, FeatureSet>(fullPipeline.getFeatureSets());
        // The cross-training pipelines are built without the PI feature set, which is presumably
        // why AI and AC sit at step indices 1 and 2 below.
        featureSets.remove(Step.pi);
        // Drop the reference, presumably so the full pipeline can be garbage-collected early.
        fullPipeline = null;
        this.populateRerankerFeatureSets(featureSets, fg);
        LibLinearLearningProblem lp = new LibLinearLearningProblem(new File(learnOptions.tempDir, FILENAME), true);
        List<List<Sentence>> subCorpora = Reranker.partitionCorpus(trainCorpus,
                learnOptions.global_numberOfCrossTrain, learnOptions.deterministicReranker);
        this.calsMap = new HashMap<String, Integer>();
        for (int i = 0; i < subCorpora.size(); ++i) {
            List<Sentence> heldOut = subCorpora.get(i);
            // trainCorpus is reused as the buffer holding every partition except the held-out one.
            trainCorpus.clear();
            for (int j = 0; j < subCorpora.size(); ++j) {
                if (j != i) {
                    trainCorpus.addAll(subCorpora.get(j));
                }
            }
            Pipeline foldPipeline = Pipeline.trainNewPipeline(trainCorpus, fg, null, featureSets);
            ArgumentIdentifier foldAI = (ArgumentIdentifier)foldPipeline.steps.get(1);
            ArgumentClassifier foldAC = (ArgumentClassifier)foldPipeline.steps.get(2);
            for (Sentence sen : heldOut) {
                for (Predicate pred : sen.getPredicates()) {
                    ++this.predCount;
                    List<ArgMap> negatives = foldAC.beamSearch(pred, foldAI.beamSearch(pred, this.aiBeam), this.acBeam);
                    Set<ArgMap> positives = new HashSet<ArgMap>();
                    // Moves the best-scoring candidates from negatives into positives.
                    double score = Reranker.partitionBestArgMaps(negatives, pred.getArgMap(), positives);
                    if (learnOptions.global_insertGoldMapForTrain && score != 1.0) {
                        positives.add(new ArgMap(pred.getArgMap()));
                    }
                    for (ArgMap am : positives) {
                        lp.addInstance(AbstractStep.POSITIVE, this.buildTrainingInstance(pred, am));
                    }
                    for (ArgMap am : negatives) {
                        lp.addInstance(AbstractStep.NEGATIVE, this.buildTrainingInstance(pred, am));
                    }
                }
            }
        }
        lp.done();
        this.model = lp.train();
        zos.putNextEntry(new ZipEntry(FILENAME));
        // Deliberately not closed: closing an ObjectOutputStream closes the underlying zip stream,
        // which belongs to the caller.
        ObjectOutputStream oos = new ObjectOutputStream(zos);
        oos.writeObject(this.model);
        oos.writeObject(this.calsMap);
        oos.flush();
    }

    /**
     * Labels the predicates of {@code sen}.
     *
     * <p>First runs the step(s) preceding argument identification: one step when PI is skipped,
     * two otherwise. Then, for each predicate, it scores the beam-search candidates with the
     * reranker, selects one via {@link #softMax(List)}, and installs it as the predicate's
     * argument map.
     */
    @Override
    protected void parse(Sentence sen) {
        this.pipeline.steps.get(0).parse(sen);
        if (!this.noPI) {
            this.pipeline.steps.get(1).parse(sen);
        }
        for (Predicate pred : sen.getPredicates()) {
            List<ArgMap> candidates = this.acModule.beamSearch(pred, this.aiModule.beamSearch(pred, this.aiBeam), this.acBeam);
            if (candidates.isEmpty()) {
                // Nothing to choose from; leave the predicate's argument map untouched.
                continue;
            }
            for (ArgMap argMap : candidates) {
                this.scoreCandidate(pred, argMap);
            }
            int bestIndex = this.softMax(candidates);
            this.rankCount[bestIndex]++;
            ArgMap best = candidates.get(bestIndex);
            if (best.size() == 0) {
                ++this.zeroArgMapCount;
            }
            pred.setArgMap(best);
        }
    }

    /**
     * Classifies one candidate with the reranker model. For every label other than
     * {@link AbstractStep#NEGATIVE} (with a binary model, just POSITIVE), it stores that label's
     * probability as the candidate's rerank probability and resets the candidate's probability.
     */
    private void scoreCandidate(Predicate pred, ArgMap argMap) {
        ArrayList<Integer> indices = new ArrayList<Integer>();
        this.collectPipelineFeatureIndices(pred, argMap, indices);
        this.collectGlobalFeatures(pred, argMap, indices);
        List<Label> labels = this.model.classifyProb(indices);
        for (Label label : labels) {
            if (label.getLabel().equals(AbstractStep.NEGATIVE)) {
                continue;
            }
            argMap.setRerankProb(label.getProb());
            argMap.resetProb();
        }
    }

    /**
     * Builds the sorted feature-index vector of a training candidate. The candidate's core
     * argument label sequence is registered as a new global feature if it has not been seen yet.
     */
    private List<Integer> buildTrainingInstance(Predicate pred, ArgMap argMap) {
        List<Integer> indices = new ArrayList<Integer>();
        this.collectPipelineFeatureIndices(pred, argMap, indices);
        this.addAndCollectGlobalFeatures(pred, argMap, indices);
        Collections.sort(indices);
        return indices;
    }

    /**
     * Selects the best candidate.
     *
     * <p>Each candidate's probability is first set to its identification probability. When the
     * candidate has arguments, that value is multiplied by the {@code size()}-th root of its
     * labeling probability. The candidate maximizing {@code prob * rerankProb^alfa} is returned.
     * Exact ties keep the earlier candidate and print {@code "!same score.."}.
     *
     * @param argmaps candidates; their probability fields are overwritten
     * @return index of the chosen candidate, or -1 if {@code argmaps} is empty
     */
    private int softMax(List<ArgMap> argmaps) {
        for (ArgMap am : argmaps) {
            double prob = am.getIdProb();
            if (am.size() != 0) {
                prob *= Math.pow(am.getLblProb(), 1.0 / (double)am.size());
            }
            am.setProb(prob);
        }
        // Start below every real score so that a candidate is still chosen when all scores are 0
        // (e.g. an underflowed probability product); the caller indexes with the result.
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestIndex = -1;
        for (int i = 0; i < argmaps.size(); ++i) {
            ArgMap am = argmaps.get(i);
            double score = am.getProb() * Math.pow(am.getRerankProb(), this.alfa);
            if (score > bestScore) {
                bestIndex = i;
                bestScore = score;
            } else if (score == bestScore) {
                System.out.println("!same score..");
            }
        }
        if (bestIndex < 0 && !argmaps.isEmpty()) {
            // Every score was NaN; fall back to the first candidate rather than returning -1.
            bestIndex = 0;
        }
        return bestIndex;
    }

    /**
     * Adds the AI and AC feature indices of every argument in {@code argMap} to {@code indices}.
     *
     * <p>AC features are shifted into the block of the argument's label (see the class comment).
     * Duplicate indices within one argument are collapsed. Duplicates across arguments are kept
     * when {@code indices} is a list.
     *
     * @return {@code indices}, for convenience
     */
    private Collection<Integer> collectPipelineFeatureIndices(Predicate pred, ArgMap argMap, Collection<Integer> indices) {
        for (Word arg : argMap.keySet()) {
            HashSet<Integer> argIndices = new HashSet<Integer>();
            int offset = 0;
            for (Feature f : this.aiFeatures) {
                f.addFeatures(argIndices, pred, arg, offset, false);
                offset += f.size(false);
            }
            offset = this.sizeAIFeatures + this.sizeACFeatures * this.argLabels.indexOf(argMap.get(arg));
            for (Feature f : this.acFeatures) {
                f.addFeatures(argIndices, pred, arg, offset, false);
                offset += f.size(false);
            }
            indices.addAll(argIndices);
        }
        return indices;
    }

    /**
     * Training-time variant of {@link #collectGlobalFeatures}. It adds the global feature index
     * of the candidate's core argument label sequence and assigns the next free id to sequences
     * not seen before.
     */
    private void addAndCollectGlobalFeatures(Predicate pred, ArgMap argMap, Collection<Integer> indices) {
        String cals = Language.getLanguage().getCoreArgumentLabelSequence(pred, argMap);
        Integer index = this.calsMap.get(cals);
        if (index == null) {
            this.calsMap.put(cals, this.calsCounter);
            index = this.calsCounter++;
        }
        indices.add(this.sizePipelineFeatures + index);
    }

    /**
     * Parse-time lookup. It adds the global feature index of the candidate's core argument label
     * sequence only if that sequence was seen during training.
     */
    private void collectGlobalFeatures(Predicate pred, ArgMap argMap, Collection<Integer> indices) {
        String cals = Language.getLanguage().getCoreArgumentLabelSequence(pred, argMap);
        Integer index = this.calsMap.get(cals);
        if (index != null) {
            indices.add(this.sizePipelineFeatures + index);
        }
    }

    /**
     * Flattens the AI and AC feature sets into feature lists and computes the block sizes of the
     * feature index layout. Requires {@code argLabels} to be set.
     *
     * @param fg currently unused
     */
    private void populateRerankerFeatureSets(Map<Step, FeatureSet> featureSets, FeatureGenerator fg) {
        this.aiFeatures = Reranker.flattenFeatures(featureSets.get(Step.ai));
        this.acFeatures = Reranker.flattenFeatures(featureSets.get(Step.ac));
        this.sizeAIFeatures = Reranker.totalSize(this.aiFeatures);
        this.sizeACFeatures = Reranker.totalSize(this.acFeatures);
        this.sizePipelineFeatures = this.sizeAIFeatures + this.argLabels.size() * this.sizeACFeatures;
    }

    /** Concatenates the feature collections stored as values of {@code featureSet}. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static List<Feature> flattenFeatures(FeatureSet featureSet) {
        List<Feature> features = new ArrayList<Feature>();
        for (Map.Entry entry : featureSet.entrySet()) {
            features.addAll((Collection)entry.getValue());
        }
        return features;
    }

    /** Sums {@code f.size(false)} over the given features. */
    private static int totalSize(List<Feature> features) {
        int total = 0;
        for (Feature f : features) {
            total += f.size(false);
        }
        return total;
    }

    /**
     * Reads the reranker model and the label-sequence map from the {@value #FILENAME} entry.
     *
     * @throws IOException if the entry is missing or cannot be read
     */
    @SuppressWarnings("unchecked")
    private void readRerankerModel(ZipFile zipFile) throws IOException, ClassNotFoundException {
        ZipEntry entry = zipFile.getEntry(FILENAME);
        if (entry == null) {
            throw new IOException("Model file " + zipFile.getName() + " has no '" + FILENAME
                    + "' entry; is it a reranker model?");
        }
        ObjectInputStream ois = new ObjectInputStream(zipFile.getInputStream(entry));
        try {
            this.model = (Model)ois.readObject();
            this.calsMap = (Map<String, Integer>)ois.readObject();
        } finally {
            ois.close();
        }
    }

    /**
     * Moves the candidates with the highest {@code computeScore(goldStandard)} from
     * {@code candidates} into {@code bestArgMaps}.
     *
     * <p>The running best starts at 0.0, so candidates scoring 0.0 count as ties. If no candidate
     * scores above 0.0, every candidate becomes "best" and {@code candidates} ends up empty.
     *
     * @return the best score found (0.0 if {@code candidates} is empty)
     */
    private static double partitionBestArgMaps(List<ArgMap> candidates, Map<Word, String> goldStandard, Set<ArgMap> bestArgMaps) {
        double bestScore = 0.0;
        for (ArgMap candidate : candidates) {
            double curScore = candidate.computeScore(goldStandard);
            if (curScore > bestScore) {
                bestScore = curScore;
                bestArgMaps.clear();
                bestArgMaps.add(candidate);
            } else if (curScore == bestScore) {
                bestArgMaps.add(candidate);
            }
        }
        candidates.removeAll(bestArgMaps);
        return bestScore;
    }

    /**
     * Splits {@code sentences} into {@code numberOfPartitions} lists.
     *
     * @param deterministic when true, sentences are dealt round-robin; otherwise each sentence
     *     goes to a uniformly random partition (so partitions may be uneven or empty)
     */
    private static List<List<Sentence>> partitionCorpus(Iterable<Sentence> sentences, int numberOfPartitions, boolean deterministic) {
        List<List<Sentence>> subCorpora = new ArrayList<List<Sentence>>();
        for (int i = 0; i < numberOfPartitions; ++i) {
            subCorpora.add(new ArrayList<Sentence>());
        }
        int senCount = 0;
        for (Sentence s : sentences) {
            int index = deterministic
                    ? senCount % numberOfPartitions
                    : (int)Math.floor(Math.random() * (double)numberOfPartitions);
            subCorpora.get(index).add(s);
            ++senCount;
        }
        return subCorpora;
    }

    /** Reports the reranker settings, the rank histogram of chosen candidates and the empty-argmap count. */
    @Override
    protected String getSubStatus() {
        StringBuilder ret = new StringBuilder("Reranker status:\n");
        ret.append("AI beam:\t\t" + this.aiBeam + "\n");
        ret.append("AC beam:\t\t" + this.acBeam + "\n");
        ret.append("Alfa:\t\t\t" + this.alfa + "\n");
        ret.append("\n");
        ret.append("Reranker choices:\n");
        ret.append("Rank\tFrequency\n");
        for (int i = 0; i < this.rankCount.length; ++i) {
            ret.append(i + 1 + "\t" + this.rankCount[i] + "\n");
        }
        ret.append("\n");
        ret.append("Number of zero size argmaps:\t" + this.zeroArgMapCount + "\n");
        return ret.toString();
    }
}

