/*
 * Decompiled with CFR 0.152.
 */
package se.lth.cs.srl.ml.liblinear;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collection;
import java.util.HashMap;

/**
 * Weights of a linear classifier, read row by row from a LIBLINEAR-style weight listing, used
 * to score instances that are given as collections of active feature indices.
 *
 * <p>Conventions shared by all implementations:
 * <ul>
 *   <li>Feature indices passed to the scoring methods are 1-based: index {@code k} selects
 *       weight row {@code k - 1}. Indices outside {@code [1, features]} contribute nothing.</li>
 *   <li>Each active feature contributes its weight exactly once (feature values are implicitly
 *       1).</li>
 *   <li>When {@code bias > 0}, row {@code features} is the bias row; its weight multiplied by
 *       {@code bias} is added to every score. When {@code bias <= 0} that row is ignored.</li>
 * </ul>
 *
 * <p>All classes here are {@link Serializable}. Keep their class names, field names and field
 * types unchanged so that previously serialized instances remain readable.
 */
public abstract class WeightVector
implements Serializable {
    private static final long serialVersionUID = 1L;
    /** Value of the bias feature; the bias row is only used when this is positive. */
    protected double bias;
    /** Number of features, not counting the bias row. */
    protected int features;
    /** Number of classes. */
    protected int classes;

    /**
     * Stores the model dimensions shared by all implementations.
     *
     * @param bias value of the bias feature; the bias row is ignored unless this is positive
     * @param features number of features, not counting the bias row
     * @param classes number of classes
     */
    public WeightVector(double bias, int features, int classes) {
        this.bias = bias;
        this.classes = classes;
        this.features = features;
    }

    /**
     * Reads weight rows from {@code in} until end of stream and builds the matching
     * implementation.
     *
     * <p>Line {@code k} of the input holds the weights of row {@code k} (feature {@code k + 1});
     * line {@code features}, if present, is the bias row. Binary models ({@code classes == 2})
     * have a single weight per line; other models have one space-separated weight per class.
     *
     * @param in reader positioned at the first weight row
     * @param features number of features, not counting the bias row
     * @param classes number of classes; exactly {@code 2} selects a binary representation
     * @param bias value of the bias feature; the bias row is only used when this is positive
     * @param sparse if true, store only rows containing a non-zero weight in a hash map
     * @return the parsed weight vector
     * @throws IOException if reading from {@code in} fails
     */
    public static WeightVector parseWeights(BufferedReader in, int features, int classes, double bias, boolean sparse) throws IOException {
        if (sparse) {
            if (classes == 2) {
                return new BinarySparseVector(in, features, bias);
            }
            return new MultipleSparseVector(in, features, classes, bias);
        }
        if (classes == 2) {
            return new BinaryLibLinearVector(in, features, bias);
        }
        return new MultipleLibLinearVector(in, features, classes, bias);
    }

    /**
     * Computes a probability estimate for every class.
     *
     * @param ints active 1-based feature indices
     * @return one estimate per class, indexed by class
     */
    public abstract double[] computeAllProbs(Collection<Integer> ints);

    /**
     * Predicts a class for the given features.
     *
     * @param ints active 1-based feature indices
     * @return the 0-based index of the predicted class
     */
    public abstract short computeBestClass(Collection<Integer> ints);

    /**
     * Sparse multi-class weights: only rows with at least one non-zero weight are stored, keyed
     * by 0-based row index.
     */
    public static class MultipleSparseVector
    extends MultipleVector {
        private static final long serialVersionUID = 1L;
        /** Stored rows keyed by 0-based row index; key {@code features} is the bias row. */
        private HashMap<Integer, WeightArray> weightMap = new HashMap<Integer, WeightArray>();

        /**
         * Builds a sparse copy of a dense multi-class vector. Feature rows whose weights are all
         * zero are dropped; the bias row is always copied when {@code bias > 0}.
         *
         * @param vec dense vector to copy
         */
        public MultipleSparseVector(MultipleLibLinearVector vec) {
            super(vec.bias, vec.features, vec.classes);
            for (int i = 0; i < vec.features; ++i) {
                WeightArray wa = new WeightArray(this.classes);
                boolean nonZero = false;
                for (int j = 0; j < vec.classes; ++j) {
                    wa.weights[j] = vec.weights[j][i];
                    nonZero = nonZero || wa.weights[j] != 0.0f;
                }
                if (nonZero) {
                    this.weightMap.put(i, wa);
                }
            }
            if (this.bias > 0.0) {
                WeightArray wa = new WeightArray(this.classes);
                for (int j = 0; j < this.classes; ++j) {
                    wa.weights[j] = vec.weights[j][this.features];
                }
                this.weightMap.put(this.features, wa);
            }
        }

        /**
         * Reads one row per line until end of stream; line {@code k} holds one space-separated
         * weight per class for row {@code k}. Rows whose weights are all zero are not stored;
         * a line with fewer than {@code classes} values leaves the remaining weights at 0.
         *
         * @throws IOException if reading from {@code in} fails
         * @throws NumberFormatException if a value is not a valid float
         * @throws ArrayIndexOutOfBoundsException if a line has more than {@code classes} values
         */
        public MultipleSparseVector(BufferedReader in, int features, int classes, double bias) throws IOException {
            super(bias, features, classes);
            String str;
            int row = 0;
            while ((str = in.readLine()) != null) {
                WeightArray wa = new WeightArray(classes);
                int j = 0;
                boolean nonZero = false;
                for (String w : str.split(" ")) {
                    float f = Float.parseFloat(w);
                    wa.weights[j++] = f;
                    nonZero = nonZero || f != 0.0f;
                }
                if (nonZero) {
                    this.weightMap.put(row, wa);
                }
                ++row;
            }
        }

        @Override
        protected double[] computeScores(Collection<Integer> ints) {
            double[] ret = new double[this.classes];
            if (this.bias > 0.0) {
                WeightArray biasRow = this.weightMap.get(this.features);
                if (biasRow != null) {
                    for (int i = 0; i < this.classes; ++i) {
                        ret[i] = this.bias * (double) biasRow.weights[i];
                    }
                }
            }
            // One map lookup per feature; every class score still accumulates the feature
            // weights in iteration order, exactly as a per-class loop would.
            for (Integer in : ints) {
                if (in - 1 >= this.features) continue;
                // Indices < 1 map to negative keys, which are never stored, so they are skipped.
                WeightArray row = this.weightMap.get(in - 1);
                if (row == null) continue;
                for (int i = 0; i < this.classes; ++i) {
                    ret[i] += (double) row.weights[i];
                }
            }
            return ret;
        }

        /** One stored row: the weight of a single feature (or the bias) for every class. */
        private static class WeightArray
        implements Serializable {
            private static final long serialVersionUID = 1L;
            float[] weights;

            public WeightArray(int size) {
                this.weights = new float[size];
            }
        }
    }

    /**
     * Dense multi-class weights: {@code weights[c][k]} is the weight of row {@code k} for class
     * {@code c}; the array has {@code features + 1} rows so the bias row always fits.
     */
    public static class MultipleLibLinearVector
    extends MultipleVector {
        private static final long serialVersionUID = 1L;
        private float[][] weights;

        /**
         * Reads one row per line until end of stream; line {@code k} holds one space-separated
         * weight per class for row {@code k}. Only the first {@code classes} values of a line
         * are used.
         *
         * @throws IOException if reading from {@code in} fails
         * @throws NumberFormatException if a value is not a valid float
         * @throws ArrayIndexOutOfBoundsException if there are more than {@code features + 1}
         *     lines or a line has fewer than {@code classes} values
         */
        public MultipleLibLinearVector(BufferedReader in, int features, int classes, double bias) throws IOException {
            super(bias, features, classes);
            this.weights = new float[classes][features + 1];
            String str;
            int row = 0;
            while ((str = in.readLine()) != null) {
                String[] values = str.split(" ");
                for (int j = 0; j < classes; ++j) {
                    this.weights[j][row] = Float.parseFloat(values[j]);
                }
                ++row;
            }
        }

        @Override
        protected double[] computeScores(Collection<Integer> ints) {
            double[] ret = new double[this.classes];
            for (int i = 0; i < this.classes; ++i) {
                float[] classWeights = this.weights[i];
                double curvalue = this.bias > 0.0 ? this.bias * (double) classWeights[this.features] : 0.0;
                for (Integer in : ints) {
                    int row = in - 1;
                    // Skip indices outside [1, features] (including 0 and negatives) instead of
                    // failing, matching MultipleSparseVector.
                    if (row < 0 || row >= this.features) continue;
                    curvalue += (double) classWeights[row];
                }
                ret[i] = curvalue;
            }
            return ret;
        }
    }

    /**
     * Base class for models with one weight column per class. Subclasses compute a raw linear
     * score per class; this class turns the scores into probabilities and predictions.
     */
    public static abstract class MultipleVector
    extends WeightVector {
        private static final long serialVersionUID = 1L;

        public MultipleVector(double bias, int features, int classes) {
            super(bias, features, classes);
        }

        /**
         * Computes the raw linear score of every class.
         *
         * @param ints active 1-based feature indices
         * @return array of length {@code classes} holding one score per class
         */
        protected abstract double[] computeScores(Collection<Integer> ints);

        /**
         * Applies the logistic function to each class score and normalizes the results so they
         * sum to 1.
         *
         * <p>If every logistic value underflows to 0 (all scores below roughly -710), plain
         * normalization would compute 0 / 0 = NaN. In that regime {@code 1 / (1 + exp(-s))} is
         * approximately {@code exp(s)}, so the normalized values converge to the softmax of the
         * scores, which is returned instead.
         */
        @Override
        public double[] computeAllProbs(Collection<Integer> ints) {
            double[] scores = this.computeScores(ints);
            double[] ret = new double[this.classes];
            double sum = 0.0;
            for (int i = 0; i < this.classes; ++i) {
                ret[i] = 1.0 / (1.0 + Math.exp(-scores[i]));
                sum += ret[i];
            }
            if (sum == 0.0) {
                return softmaxInto(scores, ret);
            }
            for (int i = 0; i < this.classes; ++i) {
                ret[i] /= sum;
            }
            return ret;
        }

        /**
         * Writes the softmax of {@code scores} into {@code out} (shifting by the maximum score
         * so no exponent overflows) and returns {@code out}. If every score is negative
         * infinity, the classes cannot be told apart and a uniform distribution is written.
         */
        private static double[] softmaxInto(double[] scores, double[] out) {
            double max = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < out.length; ++i) {
                max = Math.max(max, scores[i]);
            }
            if (max == Double.NEGATIVE_INFINITY) {
                for (int i = 0; i < out.length; ++i) {
                    out[i] = 1.0 / out.length;
                }
                return out;
            }
            double sum = 0.0;
            for (int i = 0; i < out.length; ++i) {
                out[i] = Math.exp(scores[i] - max);
                sum += out[i];
            }
            // The maximum contributes exp(0) = 1, so sum >= 1 here.
            for (int i = 0; i < out.length; ++i) {
                out[i] /= sum;
            }
            return out;
        }

        /** Returns the class with the highest score; ties go to the lowest class index. */
        @Override
        public short computeBestClass(Collection<Integer> ints) {
            double[] scores = this.computeScores(ints);
            int best = 0;
            for (int i = 1; i < this.classes; ++i) {
                if (scores[i] > scores[best]) {
                    best = i;
                }
            }
            return (short) best;
        }
    }

    /** Sparse binary weights: only non-zero weights are stored, keyed by 0-based row index. */
    public static class BinarySparseVector
    extends BinaryVector {
        private static final long serialVersionUID = 1L;
        /** Non-zero weights keyed by 0-based row index; key {@code features} is the bias row. */
        private HashMap<Integer, Float> weightMap = new HashMap<Integer, Float>();

        /**
         * Builds a sparse copy of a dense binary vector. Zero feature weights are dropped; the
         * bias weight is always copied when {@code bias > 0}.
         *
         * @param vec dense vector to copy
         */
        public BinarySparseVector(BinaryLibLinearVector vec) {
            super(vec.bias, vec.features, 2);
            for (int i = 0; i < vec.features; ++i) {
                if (vec.weights[i] == 0.0f) continue;
                this.weightMap.put(i, Float.valueOf(vec.weights[i]));
            }
            if (this.bias > 0.0) {
                this.weightMap.put(this.features, Float.valueOf(vec.weights[this.features]));
            }
        }

        /**
         * Reads one weight per line until end of stream; line {@code k} is the weight of row
         * {@code k}. Zero weights are not stored.
         *
         * @throws IOException if reading from {@code in} fails
         * @throws NumberFormatException if a line is not a valid float
         */
        public BinarySparseVector(BufferedReader in, int features, double bias) throws IOException {
            super(bias, features, 2);
            String str;
            int row = 0;
            while ((str = in.readLine()) != null) {
                Float f = Float.valueOf(Float.parseFloat(str));
                if (f.floatValue() != 0.0f) {
                    this.weightMap.put(row, f);
                }
                ++row;
            }
        }

        @Override
        protected double computeScore(Collection<Integer> ints) {
            double sum = 0.0;
            if (this.bias > 0.0) {
                Float biasWeight = this.weightMap.get(this.features);
                if (biasWeight != null) {
                    sum = this.bias * (double) biasWeight.floatValue();
                }
            }
            for (Integer i : ints) {
                if (i - 1 >= this.features) continue;
                // Indices < 1 map to negative keys, which are never stored, so they are skipped.
                Float w = this.weightMap.get(i - 1);
                if (w == null) continue;
                sum += (double) w.floatValue();
            }
            return sum;
        }
    }

    /**
     * Dense binary weights: {@code weights[k]} is the weight of row {@code k}; the array has
     * {@code features + 1} entries so the bias row always fits.
     */
    public static class BinaryLibLinearVector
    extends BinaryVector {
        private static final long serialVersionUID = 1L;
        private float[] weights;

        /**
         * Reads one weight per line until end of stream; line {@code k} is the weight of row
         * {@code k}.
         *
         * @throws IOException if reading from {@code in} fails
         * @throws NumberFormatException if a line is not a valid float
         * @throws ArrayIndexOutOfBoundsException if there are more than {@code features + 1}
         *     lines
         */
        public BinaryLibLinearVector(BufferedReader in, int features, double bias) throws IOException {
            super(bias, features, 2);
            this.weights = new float[features + 1];
            String str;
            int row = 0;
            while ((str = in.readLine()) != null) {
                this.weights[row] = Float.parseFloat(str);
                ++row;
            }
        }

        @Override
        protected double computeScore(Collection<Integer> ints) {
            double sum = this.bias > 0.0 ? this.bias * (double) this.weights[this.features] : 0.0;
            for (Integer i : ints) {
                int row = i - 1;
                // Skip indices outside [1, features] (including 0 and negatives) instead of
                // failing, matching BinarySparseVector.
                if (row < 0 || row >= this.features) continue;
                sum += (double) this.weights[row];
            }
            return sum;
        }
    }

    /**
     * Base class for two-class models with a single weight column. A positive score predicts
     * class 0 and a non-positive score predicts class 1.
     */
    public static abstract class BinaryVector
    extends WeightVector {
        private static final long serialVersionUID = 1L;

        public BinaryVector(double bias, int features, int classes) {
            super(bias, features, classes);
        }

        /**
         * Computes the raw linear score of the single weight column.
         *
         * @param ints active 1-based feature indices
         * @return the score; positive values favor class 0
         */
        protected abstract double computeScore(Collection<Integer> ints);

        /** Returns {@code {p, 1 - p}} where {@code p} is the logistic function of the score. */
        @Override
        public double[] computeAllProbs(Collection<Integer> ints) {
            double prob = 1.0 / (1.0 + Math.exp(-this.computeScore(ints)));
            return new double[]{prob, 1.0 - prob};
        }

        @Override
        public short computeBestClass(Collection<Integer> ints) {
            if (this.computeScore(ints) > 0.0) {
                return 0;
            }
            return 1;
        }
    }
}

