/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.Transducer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Multinomial;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import com.carrotsearch.hppc.IntObjectHashMap;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.logging.Logger;

/**
 * A {@link Transducer} whose states are created explicitly by name and whose
 * transitions are keyed by indices into an input and an output {@link Alphabet}.
 *
 * <p>Weights are treated as log-probabilities: {@link #estimate()} overwrites
 * initial, final and transition weights with {@code logProbability} values, so
 * {@link Double#NEGATIVE_INFINITY} denotes an impossible event.
 *
 * <p>When the transducer is made trainable (see {@link #setTrainable(boolean)}),
 * count estimators are allocated for initial states, final states and for the
 * outgoing transitions of every state. Adding a state always switches training
 * off again because the estimators are sized by the number of states.
 */
public class FeatureTransducer
extends Transducer {
    private static final Logger logger = MalletLogger.getLogger(FeatureTransducer.class.getName());
    Alphabet inputAlphabet;
    Alphabet outputAlphabet;
    /** All states, in creation order; a state's position equals its index. */
    ArrayList<State> states = new ArrayList<State>();
    /**
     * States whose initial weight was not impossible when they were added.
     * This list is only populated by {@code addState}; it is not refreshed by
     * {@link #estimate()} or {@code State.setInitialWeight}.
     */
    ArrayList<State> initialStates = new ArrayList<State>();
    /** Lookup from state name to state; also used to resolve transition destinations lazily. */
    HashMap<String, State> name2state = new HashMap<String, State>();
    /** Counts of initial-state usage; {@code null} unless trainable. */
    Multinomial.Estimator initialStateCounts;
    /** Counts of final-state usage; {@code null} unless trainable. */
    Multinomial.Estimator finalStateCounts;
    boolean trainable = false;
    private static final long serialVersionUID = 1L;

    /**
     * Creates an empty transducer over the given alphabets.
     *
     * @param inputAlphabet alphabet used to index transition inputs
     * @param outputAlphabet alphabet used to index transition outputs
     */
    public FeatureTransducer(Alphabet inputAlphabet, Alphabet outputAlphabet) {
        this.inputAlphabet = inputAlphabet;
        this.outputAlphabet = outputAlphabet;
    }

    /**
     * Creates an empty transducer that uses the same alphabet for inputs and outputs.
     *
     * @param dictionary the shared alphabet
     */
    public FeatureTransducer(Alphabet dictionary) {
        this(dictionary, dictionary);
    }

    /** Creates an empty transducer with a fresh alphabet shared by inputs and outputs. */
    public FeatureTransducer() {
        this(new Alphabet());
    }

    /** @return the alphabet used to index transition inputs */
    public Alphabet getInputAlphabet() {
        return this.inputAlphabet;
    }

    /** @return the alphabet used to index transition outputs */
    public Alphabet getOutputAlphabet() {
        return this.outputAlphabet;
    }

    /**
     * Adds a new state together with its outgoing transitions.
     *
     * <p>The four arrays are parallel: element {@code i} of each describes the
     * {@code i}-th outgoing transition. Destination states are referenced by
     * name and resolved lazily, so they may be added later.
     *
     * <p>Adding a state turns training off (all count estimators are dropped).
     *
     * @param name unique name of the new state
     * @param initialWeight weight of starting in this state; the state is
     *        registered as an initial state unless this is negative infinity
     * @param finalWeight weight of ending in this state
     * @param inputs input-alphabet index of each transition
     * @param outputs output-alphabet index of each transition
     * @param weights weight of each transition
     * @param destinationNames name of the destination state of each transition
     * @throws IllegalArgumentException if a state with this name already exists
     */
    public void addState(String name, double initialWeight, double finalWeight, int[] inputs, int[] outputs, double[] weights, String[] destinationNames) {
        if (this.name2state.get(name) != null) {
            throw new IllegalArgumentException("State with name `" + name + "' already exists.");
        }
        State s = new State(name, this.states.size(), initialWeight, finalWeight, inputs, outputs, weights, destinationNames, this);
        this.states.add(s);
        // Weights are log-probabilities, so negative infinity means "impossible".
        // The former test (initialWeight < NEGATIVE_INFINITY) could never be true,
        // which left initialStates permanently empty.
        if (initialWeight > Double.NEGATIVE_INFINITY) {
            this.initialStates.add(s);
        }
        this.name2state.put(name, s);
        // Estimators are sized by the state count, so they must be discarded.
        this.setTrainable(false);
    }

    /**
     * Adds a new state whose transition inputs and outputs are given as
     * alphabet entries rather than indices. Entries not yet present are added
     * to the respective alphabet.
     *
     * @see #addState(String, double, double, int[], int[], double[], String[])
     */
    public void addState(String name, double initialWeight, double finalWeight, Object[] inputs, Object[] outputs, double[] weights, String[] destinationNames) {
        int[] inputIndices = this.inputAlphabet.lookupIndices(inputs, true);
        int[] outputIndices = this.outputAlphabet.lookupIndices(outputs, true);
        this.addState(name, initialWeight, finalWeight, inputIndices, outputIndices, weights, destinationNames);
    }

    @Override
    public int numStates() {
        return this.states.size();
    }

    @Override
    public Transducer.State getState(int index) {
        return this.states.get(index);
    }

    /** @return an iterator over the states registered as initial states */
    @Override
    public Iterator<State> initialStateIterator() {
        return this.initialStates.iterator();
    }

    /** @return whether count estimators are currently allocated */
    public boolean isTrainable() {
        return this.trainable;
    }

    /**
     * Switches training on or off. Switching on allocates fresh Laplace
     * estimators (discarding any previous counts); switching off drops them.
     *
     * @param f {@code true} to allocate estimators, {@code false} to discard them
     */
    public void setTrainable(boolean f) {
        this.trainable = f;
        if (f) {
            this.initialStateCounts = new Multinomial.LaplaceEstimator(this.states.size());
            this.finalStateCounts = new Multinomial.LaplaceEstimator(this.states.size());
        } else {
            this.initialStateCounts = null;
            this.finalStateCounts = null;
        }
        for (State s : this.states) {
            s.setTrainable(f);
        }
    }

    /** Clears all accumulated counts; does nothing when the transducer is not trainable. */
    public void reset() {
        if (!this.trainable) {
            return;
        }
        this.initialStateCounts.reset();
        this.finalStateCounts.reset();
        for (State s : this.states) {
            s.reset();
        }
    }

    /**
     * Replaces every initial, final and transition weight with the
     * log-probability estimated from the accumulated counts.
     *
     * @throws IllegalStateException if the transducer is not trainable
     */
    public void estimate() {
        if (this.initialStateCounts == null || this.finalStateCounts == null) {
            throw new IllegalStateException("This transducer not currently trainable.");
        }
        Multinomial initialStateDistribution = this.initialStateCounts.estimate();
        Multinomial finalStateDistribution = this.finalStateCounts.estimate();
        for (int i = 0; i < this.states.size(); ++i) {
            State s = this.states.get(i);
            s.initialWeight = initialStateDistribution.logProbability(i);
            s.finalWeight = finalStateDistribution.logProbability(i);
            s.estimate();
        }
    }

    /**
     * A single weighted arc. Transitions leaving the same state with the same
     * input are chained through {@link #nextWithSameInput}; the head of each
     * chain is stored in the source state's {@code input2transitions} map.
     */
    protected class Transition {
        int input;
        int output;
        double weight;
        /** Position of this transition in the source state's {@code transitions} array. */
        int index;
        String destinationName;
        /** Resolved lazily from {@link #destinationName}. */
        State destination = null;
        Transition nextWithSameInput;

        /**
         * Creates the transition and pushes it onto the front of the source
         * state's chain for {@code input}.
         */
        public Transition(int input, int output, double weight, State sourceState, String destinationName) {
            this.input = input;
            this.output = output;
            this.weight = weight;
            // Prepend: the previous chain head (or null) becomes our successor.
            this.nextWithSameInput = sourceState.input2transitions.get(input);
            sourceState.input2transitions.put(input, this);
            this.destinationName = destinationName;
        }

        /**
         * Returns the destination state, looking it up by name on first use.
         * If no state with that name exists the lookup yields {@code null}
         * (only checked via assertion).
         */
        public State getDestinationState() {
            if (this.destination == null) {
                this.destination = FeatureTransducer.this.name2state.get(this.destinationName);
                assert (this.destination != null);
            }
            return this.destination;
        }
    }

    /**
     * Iterates over the outgoing transitions of one state, either all of them
     * or only those with a given input.
     *
     * <p>The mode is encoded in {@link #index}:
     * <ul>
     * <li>{@code >= -1}: iterating all transitions in array order; the value
     *     is the position of the transition last returned ({@code -1} before
     *     the first call).</li>
     * <li>{@code -2}: input-restricted, first matching transition is loaded
     *     but not yet returned.</li>
     * <li>{@code -3}: input-restricted, following the
     *     {@code nextWithSameInput} chain.</li>
     * </ul>
     */
    protected class TransitionIterator
    extends Transducer.TransitionIterator {
        int index;
        Transition transition;
        State source;
        int input;

        /** Creates an iterator over every transition leaving {@code source}. */
        public TransitionIterator(State source) {
            this.source = source;
            this.input = -1;
            this.index = -1;
            this.transition = null;
        }

        /** Creates an iterator over the transitions leaving {@code source} that have the given input index. */
        public TransitionIterator(State source, int input) {
            this.source = source;
            this.input = input;
            this.index = -2;
            this.transition = source.input2transitions.get(input);
        }

        @Override
        public boolean hasNext() {
            if (this.index >= -1) {
                return this.index < this.source.transitions.length - 1;
            }
            if (this.index == -2) {
                // First matching transition (if any) has been preloaded.
                return this.transition != null;
            }
            return this.transition.nextWithSameInput != null;
        }

        @Override
        public Transducer.State nextState() {
            if (this.index >= -1) {
                this.transition = this.source.transitions[++this.index];
            } else if (this.index == -2) {
                // The preloaded transition is the one to return; just change mode.
                this.index = -3;
            } else {
                this.transition = this.transition.nextWithSameInput;
            }
            return this.transition.getDestinationState();
        }

        /**
         * Returns the raw {@link #index} field.
         * NOTE(review): for input-restricted iterators this is the sentinel
         * -2 or -3 rather than the current transition's array position.
         */
        @Override
        public int getIndex() {
            return this.index;
        }

        @Override
        public Object getInput() {
            return FeatureTransducer.this.inputAlphabet.lookupObject(this.transition.input);
        }

        @Override
        public Object getOutput() {
            return FeatureTransducer.this.outputAlphabet.lookupObject(this.transition.output);
        }

        @Override
        public double getWeight() {
            return this.transition.weight;
        }

        @Override
        public Transducer.State getSourceState() {
            return this.source;
        }

        @Override
        public Transducer.State getDestinationState() {
            return this.transition.getDestinationState();
        }

        /**
         * Adds {@code count} to the source state's count for the current transition.
         *
         * @throws IllegalStateException if the transducer is not trainable
         */
        public void incrementCount(double count) {
            logger.info("FeatureTransducer incrementCount " + count);
            // Fail the same way incrementInitialCount/incrementFinalCount do
            // instead of with a bare NullPointerException.
            if (this.source.transitionCounts == null) {
                throw new IllegalStateException("Transducer is not currently trainable.");
            }
            this.source.transitionCounts.increment(this.transition.index, count);
        }
    }

    /**
     * A named state holding its initial/final weights and outgoing transitions.
     */
    public class State
    extends Transducer.State {
        String name;
        int index;
        double initialWeight;
        double finalWeight;
        /** Outgoing transitions in the order they were supplied. */
        Transition[] transitions;
        /** Maps an input index to the head of the chain of transitions with that input. */
        IntObjectHashMap<Transition> input2transitions;
        /** Per-transition counts; {@code null} unless trainable. */
        Multinomial.Estimator transitionCounts;
        FeatureTransducer transducer;
        private static final long serialVersionUID = 1L;

        /**
         * Builds the state and one {@link Transition} per array position.
         * The four arrays are expected to have equal length (checked only via
         * assertion).
         */
        protected State(String name, int index, double initialWeight, double finalWeight, int[] inputs, int[] outputs, double[] weights, String[] destinationNames, FeatureTransducer transducer) {
            assert (inputs.length == outputs.length && inputs.length == weights.length && inputs.length == destinationNames.length);
            this.transducer = transducer;
            this.name = name;
            this.index = index;
            this.initialWeight = initialWeight;
            this.finalWeight = finalWeight;
            this.transitions = new Transition[inputs.length];
            // Must exist before any Transition is constructed: the Transition
            // constructor registers itself in this map.
            this.input2transitions = new IntObjectHashMap<Transition>();
            this.transitionCounts = null;
            for (int i = 0; i < inputs.length; ++i) {
                Transition t = new Transition(inputs[i], outputs[i], weights[i], this, destinationNames[i]);
                t.index = i;
                this.transitions[i] = t;
            }
        }

        @Override
        public Transducer getTransducer() {
            return this.transducer;
        }

        @Override
        public double getInitialWeight() {
            return this.initialWeight;
        }

        @Override
        public double getFinalWeight() {
            return this.finalWeight;
        }

        @Override
        public void setInitialWeight(double v) {
            this.initialWeight = v;
        }

        @Override
        public void setFinalWeight(double v) {
            this.finalWeight = v;
        }

        /** Allocates a fresh transition-count estimator, or drops it when {@code f} is false. */
        private void setTrainable(boolean f) {
            this.transitionCounts = f ? new Multinomial.LaplaceEstimator(this.transitions.length) : null;
        }

        /** @return the transition-count estimator, or {@code null} when not trainable */
        public Multinomial.Estimator getTransitionEstimator() {
            return this.transitionCounts;
        }

        /** Clears the transition counts if an estimator is allocated. */
        private void reset() {
            if (this.transitionCounts != null) {
                this.transitionCounts.reset();
            }
        }

        @Override
        public int getIndex() {
            return this.index;
        }

        /**
         * Returns an iterator over outgoing transitions. Only input-side
         * constraints are supported: a {@code null} input iterates every
         * transition, otherwise only those matching the input at
         * {@code inputPosition}.
         *
         * @throws UnsupportedOperationException if {@code output} is non-null
         *         or either position is negative
         */
        @Override
        public Transducer.TransitionIterator transitionIterator(Sequence input, int inputPosition, Sequence output, int outputPosition) {
            if (inputPosition < 0 || outputPosition < 0 || output != null) {
                throw new UnsupportedOperationException("Not yet implemented.");
            }
            if (input == null) {
                return this.transitionIterator();
            }
            return this.transitionIterator(input, inputPosition);
        }

        /**
         * Returns an iterator over the transitions whose input equals the
         * element of {@code inputSequence} at {@code inputPosition}.
         *
         * @throws IllegalArgumentException if that element is not in the input alphabet
         */
        @Override
        public Transducer.TransitionIterator transitionIterator(Sequence inputSequence, int inputPosition) {
            int inputIndex = FeatureTransducer.this.inputAlphabet.lookupIndex(inputSequence.get(inputPosition), false);
            if (inputIndex == -1) {
                throw new IllegalArgumentException("Input not in dictionary.");
            }
            return this.transitionIterator(inputIndex);
        }

        /**
         * Returns an iterator over the transitions whose input is {@code o}.
         *
         * @throws IllegalArgumentException if {@code o} is not in the input alphabet
         */
        public Transducer.TransitionIterator transitionIterator(Object o) {
            int inputIndex = FeatureTransducer.this.inputAlphabet.lookupIndex(o, false);
            if (inputIndex == -1) {
                throw new IllegalArgumentException("Input not in dictionary.");
            }
            return this.transitionIterator(inputIndex);
        }

        /** Returns an iterator over the transitions with the given input-alphabet index. */
        public Transducer.TransitionIterator transitionIterator(int input) {
            return new TransitionIterator(this, input);
        }

        /** Returns an iterator over all outgoing transitions. */
        @Override
        public Transducer.TransitionIterator transitionIterator() {
            return new TransitionIterator(this);
        }

        @Override
        public String getName() {
            return this.name;
        }

        /**
         * Adds {@code count} to this state's initial-state count.
         *
         * @throws IllegalStateException if the transducer is not trainable
         */
        public void incrementInitialCount(double count) {
            if (FeatureTransducer.this.initialStateCounts == null) {
                throw new IllegalStateException("Transducer is not currently trainable.");
            }
            FeatureTransducer.this.initialStateCounts.increment(this.index, count);
        }

        /**
         * Adds {@code count} to this state's final-state count.
         *
         * @throws IllegalStateException if the transducer is not trainable
         */
        public void incrementFinalCount(double count) {
            if (FeatureTransducer.this.finalStateCounts == null) {
                throw new IllegalStateException("Transducer is not currently trainable.");
            }
            FeatureTransducer.this.finalStateCounts.increment(this.index, count);
        }

        /**
         * Overwrites every outgoing transition weight with the log-probability
         * estimated from the accumulated transition counts.
         *
         * @throws IllegalStateException if the transducer is not trainable
         */
        private void estimate() {
            if (this.transitionCounts == null) {
                throw new IllegalStateException("Transducer is not currently trainable.");
            }
            Multinomial transitionDistribution = this.transitionCounts.estimate();
            for (int i = 0; i < this.transitions.length; ++i) {
                this.transitions[i].weight = transitionDistribution.logProbability(i);
            }
        }
    }
}

