/*
 * Decompiled with CFR 0.152.
 */
package se.lth.cs.srl.pipeline;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import se.lth.cs.srl.Learn;
import se.lth.cs.srl.corpus.ArgMap;
import se.lth.cs.srl.corpus.Predicate;
import se.lth.cs.srl.corpus.Sentence;
import se.lth.cs.srl.corpus.Word;
import se.lth.cs.srl.features.FeatureSet;
import se.lth.cs.srl.ml.LearningProblem;
import se.lth.cs.srl.ml.Model;
import se.lth.cs.srl.ml.liblinear.Label;
import se.lth.cs.srl.pipeline.ArgumentStep;
import se.lth.cs.srl.util.Constant;

/**
 * Argument-identification step of the SRL pipeline.
 *
 * <p>For every predicate of a sentence, it decides which words are arguments of
 * that predicate. It does not decide which role they play. Identified arguments
 * are recorded in the predicate's argument map under the placeholder label
 * {@code "ARG"}. A later pipeline step presumably assigns the actual role label
 * (TODO confirm against the pipeline driver).
 */
public class ArgumentIdentifier
extends ArgumentStep {
    /** Prefix used for this step's learning-problem / model files. */
    private static final String FILEPREFIX = "ai_";

    /**
     * Creates an argument identifier backed by the given feature set.
     *
     * @param fs the feature set handed to {@link ArgumentStep}
     */
    public ArgumentIdentifier(FeatureSet fs) {
        super(fs);
    }

    /**
     * Adds one training instance for every (predicate, word) pair of the sentence.
     * Word index 0 is skipped; it is presumably the artificial root token (TODO confirm).
     *
     * @param s the sentence to extract training instances from
     */
    @Override
    public void extractInstances(Sentence s) {
        // The sentence length does not depend on the predicate, so read it once.
        int size = s.size();
        for (Predicate pred : s.getPredicates()) {
            for (int i = 1; i < size; ++i) {
                this.addInstance(pred, (Word)s.get(i));
            }
        }
    }

    /**
     * Classifies every (predicate, word) pair of the sentence.
     * Each word classified as {@code POSITIVE} is added to the predicate's argument
     * map with the placeholder label {@code "ARG"}. Word index 0 is skipped, as in
     * {@link #extractInstances(Sentence)}.
     *
     * @param s the sentence whose predicates receive identified arguments
     */
    @Override
    public void parse(Sentence s) {
        int size = s.size();
        for (Predicate pred : s.getPredicates()) {
            for (int i = 1; i < size; ++i) {
                Word arg = (Word)s.get(i);
                Integer label = super.classifyInstance(pred, arg);
                if (label.equals(POSITIVE)) {
                    pred.addArgMap(arg, "ARG");
                }
            }
        }
    }

    /**
     * Returns the gold label for a (predicate, word) pair.
     *
     * @param pred the predicate
     * @param arg  the candidate argument word
     * @return {@code POSITIVE} if {@code arg} is already in the predicate's
     *         argument map, otherwise {@code NEGATIVE}
     */
    @Override
    protected Integer getLabel(Predicate pred, Word arg) {
        return pred.getArgMap().containsKey(arg) ? POSITIVE : NEGATIVE;
    }

    /** Prepares the learning problems of this step using the {@code "ai_"} file prefix. */
    @Override
    public void prepareLearning() {
        super.prepareLearning(FILEPREFIX);
    }

    /**
     * Returns the file name under which this step's models are stored.
     *
     * @return {@code "ai_.models"}
     */
    @Override
    protected String getModelFileName() {
        // Compile-time constant; evaluates to exactly "ai_.models".
        return FILEPREFIX + ".models";
    }

    /**
     * Runs a left-to-right beam search over the words of the predicate's sentence.
     * At each word, every candidate argument map is branched once per label
     * returned by the model:
     * <ul>
     *   <li>A {@code POSITIVE} label adds the word as {@code "ARG"} together with
     *       that label's probability.</li>
     *   <li>Any other label multiplies that label's probability into the
     *       candidate's probability.</li>
     * </ul>
     * After each word, only the first {@code beamSize} candidates in
     * {@link ArgMap#REVERSE_PROB_COMPARATOR} order are kept.
     *
     * <p>When the search ends, each surviving candidate's probability is copied
     * into its identification probability via {@code setIdProb}. The running
     * probability is then reset.
     *
     * <p>NOTE(review): candidates are collected in a {@link TreeSet}, so branches
     * for which the comparator returns 0 are merged into one. Confirm that
     * {@code REVERSE_PROB_COMPARATOR} breaks ties if distinct equal-probability
     * branches must all be kept.
     *
     * @param pred     the predicate whose arguments are searched
     * @param beamSize the maximum number of candidates kept after each word
     * @return the surviving candidate argument maps, best first
     */
    List<ArgMap> beamSearch(Predicate pred, int beamSize) {
        List<ArgMap> candidates = new ArrayList<ArgMap>();
        candidates.add(new ArgMap());
        Sentence s = pred.getMySentence();
        TreeSet<ArgMap> newCandidates = new TreeSet<ArgMap>(ArgMap.REVERSE_PROB_COMPARATOR);
        String posPrefix = super.getPOSPrefix(pred.getPOS());
        if (posPrefix == null) {
            // Fall back to the first configured POS prefix when none matches.
            posPrefix = this.featureSet.POSPrefixes[0];
        }
        Model model = (Model)this.models.get(posPrefix);
        int size = s.size();
        for (int i = 1; i < size; ++i) {
            newCandidates.clear();
            Word arg = (Word)s.get(i);
            Collection<Integer> indices = super.collectIndices(pred, arg, posPrefix, new TreeSet<Integer>());
            List<Label> probs = model.classifyProb(indices);
            for (ArgMap argmap : candidates) {
                for (Label label : probs) {
                    ArgMap branch = new ArgMap(argmap);
                    if (label.getLabel().equals(POSITIVE)) {
                        branch.put(arg, "ARG", label.getProb());
                    } else {
                        branch.multiplyProb(label.getProb());
                    }
                    newCandidates.add(branch);
                }
            }
            // Keep only the best `beamSize` branches for the next word.
            candidates.clear();
            Iterator<ArgMap> it = newCandidates.iterator();
            for (int j = 0; j < beamSize && it.hasNext(); ++j) {
                candidates.add(it.next());
            }
        }
        for (ArgMap argmap : candidates) {
            argmap.setIdProb(argmap.getProb());
            argmap.resetProb();
        }
        return candidates;
    }

    /**
     * Adds one training instance for the pair ({@code pred}, {@code arg}) to the
     * learning problem selected by the predicate's POS prefix.
     * <ul>
     *   <li>Words flagged via {@code getIgnoredArgument()} are skipped.</li>
     *   <li>If no POS prefix matches and {@code skipNonMatchingPredicates} is set,
     *       the instance is dropped. Otherwise the first configured prefix is
     *       used.</li>
     *   <li>Nothing is added when {@code collectIndices} returns {@code null}.</li>
     * </ul>
     *
     * @param pred the predicate
     * @param arg  the candidate argument word
     */
    protected void addInstance(Predicate pred, Word arg) {
        if (arg.getIgnoredArgument()) {
            return;
        }
        String posPrefix = this.getPOSPrefix(pred.getPOS());
        if (posPrefix == null) {
            if (Learn.learnOptions.skipNonMatchingPredicates) {
                return;
            }
            posPrefix = this.featureSet.POSPrefixes[0];
        }
        LearningProblem lp = (LearningProblem)this.learningProblems.get(posPrefix);
        Collection<Integer> indices = this.collectIndices(pred, arg, posPrefix, new TreeSet<Integer>());
        if (indices == null) {
            return;
        }
        // The guard tests the argument-identification (AI_) extension setting.
        // Those are the settings actually passed below; the previous code tested
        // PI_useStartExtention, the predicate-identification setting.
        if (Constant.AI_useStartExtention != -1) {
            lp.addInstance(this.getLabel(pred, arg), indices, arg.getSparseEntensions(), Constant.AI_useStartExtention, Constant.AI_useEndExtention, this.offset);
        } else {
            lp.addInstance(this.getLabel(pred, arg), indices);
        }
    }
}

