/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.lightgbm;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.Interval;
import org.dmg.pmml.InvalidValueTreatmentMethod;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Value;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Decorator;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldUtil;
import org.jpmml.converter.InvalidValueDecorator;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.lightgbm.BinaryCategoricalFeature;
import org.jpmml.lightgbm.BinomialLogisticRegression;
import org.jpmml.lightgbm.DirectCategoricalFeature;
import org.jpmml.lightgbm.Lambdarank;
import org.jpmml.lightgbm.LightGBMEncoder;
import org.jpmml.lightgbm.LightGBMUtil;
import org.jpmml.lightgbm.MultinomialLogisticRegression;
import org.jpmml.lightgbm.NullFeature;
import org.jpmml.lightgbm.ObjectiveFunction;
import org.jpmml.lightgbm.PandasUtil;
import org.jpmml.lightgbm.PoissonRegression;
import org.jpmml.lightgbm.Regression;
import org.jpmml.lightgbm.Section;
import org.jpmml.lightgbm.Tree;
import org.jpmml.lightgbm.visitors.TreeModelCompactor;
import org.jpmml.model.visitors.VisitorBattery;

/**
 * A LightGBM gradient boosting model, loaded from a list of parsed {@link Section}s
 * and convertible to a PMML {@link MiningModel}.
 */
public class GBDT {
    // Values read from the "tree" header section.
    private String version;
    private int max_feature_idx_;
    private int label_idx_;
    private String[] feature_names_;
    private String[] feature_infos_;
    private ObjectiveFunction object_function_;

    // One entry per "Tree=N" section, in section order.
    private Tree[] models_;

    // Optional trailing sections; empty when absent.
    private Map<String, String> feature_importances = Collections.emptyMap();
    private String linear_tree;
    private List<List<?>> pandas_categorical = Collections.emptyList();

    // Category value that is filtered out of a categorical feature's value list in encodeSchema
    // (presumably LightGBM's marker for a missing category - confirm).
    private static final Integer CATEGORY_MISSING = -1;

    /**
     * Populates this model from a list of parsed sections.
     *
     * <p>Expected layout: a {@code "tree"} header section; consecutive {@code "Tree=0"},
     * {@code "Tree=1"}, ... sections; an optional {@code "end of trees"} marker; an optional
     * feature importances section; an optional {@code "parameters:"} section followed by an
     * optional {@code "end of parameters"} marker; and an optional {@code "pandas_categorical:"}
     * section. Anything after these is ignored.
     *
     * @param sections the parsed sections, in order
     * @throws IllegalArgumentException if there are no sections, the first section is not the
     *     {@code "tree"} header, the header declares an unsupported version, or a trailing
     *     section cannot be parsed
     */
    public void load(List<Section> sections) {
        if (sections.isEmpty()) {
            throw new IllegalArgumentException("Expected a \"tree\" header section, got no sections");
        }

        int index = 0;

        Section header = sections.get(index);
        if (!header.checkId("tree")) {
            throw new IllegalArgumentException("Expected a \"tree\" header section");
        }

        this.version = header.getString("version");
        // A missing version is tolerated; an explicit one must be one of the known revisions
        if (this.version != null) {
            switch (this.version) {
                case "v2":
                case "v3":
                case "v4":
                    break;
                default:
                    throw new IllegalArgumentException("Version " + this.version + " is not supported");
            }
        }

        this.max_feature_idx_ = header.getInt("max_feature_idx");
        this.label_idx_ = header.getInt("label_index");
        this.feature_names_ = header.getStringArray("feature_names", this.max_feature_idx_ + 1);
        this.feature_infos_ = header.getStringArray("feature_infos", this.max_feature_idx_ + 1);
        this.object_function_ = GBDT.loadObjectiveFunction(header);
        index++;

        // Tree sections are numbered from zero, immediately after the header
        List<Tree> trees = new ArrayList<>();
        while (index < sections.size()) {
            Section section = sections.get(index);
            if (!section.checkId("Tree=" + (index - 1))) {
                break;
            }

            Tree tree = new Tree();
            tree.load(section);
            trees.add(tree);
            index++;
        }
        this.models_ = trees.toArray(new Tree[0]);

        index = GBDT.skipEndSection("end of trees", sections, index);

        if (index < sections.size()) {
            Section section = sections.get(index);
            // Both spellings of the section id are accepted
            if (section.checkId("feature importances:") || section.checkId("feature_importances:")) {
                this.feature_importances = this.loadFeatureSection(section);
                index++;
            }
        }

        if (index < sections.size()) {
            Section section = sections.get(index);
            if (section.checkId("parameters:")) {
                this.linear_tree = section.get("linear_tree", false);
                index++;
                index = GBDT.skipEndSection("end of parameters", sections, index);
            }
        }

        if (index < sections.size()) {
            Section section = sections.get(index);
            // The category levels are embedded in the section id itself
            if (section.checkId(id -> id.startsWith("pandas_categorical:"))) {
                this.pandas_categorical = this.loadPandasCategorical(section);
            }
        }
    }

    /**
     * Builds the input schema: the label from the objective function, plus one feature per
     * entry of the header's feature names/infos.
     *
     * <p>Features whose info is "none" become {@link NullFeature}s. Categorical features become
     * boolean, direct-categorical or pandas-backed categorical features. Other features become
     * binary {0, 1} or continuous features, the latter with an optional valid interval. Feature
     * importances, when known, are registered with the encoder.
     *
     * @param targetName the target field name, or {@code null} for {@code "_target"}
     * @param targetCategories the target categories passed to the objective function
     * @param encoder the encoder that receives the created data fields and decorators
     * @return the schema
     * @throws IllegalStateException if no objective function is set
     * @throws IllegalArgumentException if the header and {@code pandas_categorical} information
     *     are inconsistent
     */
    public Schema encodeSchema(String targetName, List<String> targetCategories, ModelEncoder encoder) {
        ObjectiveFunction objectiveFunction = this.getObjectiveFunction();
        if (objectiveFunction == null) {
            throw new IllegalStateException();
        }

        if (targetName == null) {
            targetName = "_target";
        }

        Label label = objectiveFunction.encodeLabel(targetName, targetCategories, encoder);

        String[] featureNames = this.feature_names_;
        String[] featureInfos = this.feature_infos_;
        if (featureNames.length != featureInfos.length) {
            throw new IllegalArgumentException("Expected " + featureNames.length + " feature infos, got " + featureInfos.length);
        }

        // Holds every kind of feature, not just NullFeature
        List<Feature> features = new ArrayList<>(featureNames.length);

        boolean hasPandasCategories = !this.pandas_categorical.isEmpty();
        int pandasCategoryIndex = 0;

        for (int i = 0; i < featureNames.length; i++) {
            String featureName = featureNames[i];
            String featureInfo = featureInfos[i];

            if (LightGBMUtil.isNone(featureInfo)) {
                features.add(new NullFeature((PMMLEncoder)encoder, featureName, DataType.DOUBLE));

                // NOTE(review): a single-level pandas category list at the current position is
                // consumed by this unused feature, keeping later categorical features aligned - confirm intent
                if (hasPandasCategories && pandasCategoryIndex < this.pandas_categorical.size()
                        && this.pandas_categorical.get(pandasCategoryIndex).size() == 1) {
                    pandasCategoryIndex++;
                }
                continue;
            }

            boolean binary = Boolean.TRUE.equals(this.isBinary(i));

            // When the trees do not decide, fall back to the shape of the feature info
            Boolean treeCategorical = this.isCategorical(i);
            boolean categorical = (treeCategorical != null) ? treeCategorical : LightGBMUtil.isValues(featureInfo);

            DataField dataField;
            Feature feature;

            if (categorical) {
                if (binary) {
                    throw new IllegalArgumentException("Feature " + featureName + " cannot be both binary and categorical");
                }

                // Objects.equals instead of '!=' so the comparison does not depend on Integer boxing identity
                List<?> values = LightGBMUtil.parseValues(featureInfo).stream()
                    .filter(value -> !Objects.equals(value, CATEGORY_MISSING))
                    .sorted()
                    .collect(Collectors.toList());

                DataType dataType = DataType.INTEGER;
                boolean direct = true;

                if (hasPandasCategories) {
                    if (pandasCategoryIndex >= this.pandas_categorical.size()) {
                        throw new IllegalArgumentException("Conflicting categorical feature information between the header and \"pandas_categorical\" sections");
                    }

                    List<?> pandasCategoryValues = this.pandas_categorical.get(pandasCategoryIndex);
                    if (pandasCategoryValues.size() < values.size()) {
                        throw new IllegalArgumentException("Expected at least " + values.size() + " category levels, got " + pandasCategoryValues.size() + " category levels");
                    }

                    // The pandas levels replace the integer codes as the field's value space
                    // (presumably the codes index into these levels - confirm)
                    values = pandasCategoryValues;
                    dataType = TypeUtil.getDataType(pandasCategoryValues);
                    direct = false;
                    pandasCategoryIndex++;
                }

                dataField = encoder.createDataField(featureName, OpType.CATEGORICAL, dataType, values);

                if (dataType == DataType.BOOLEAN && BooleanFeature.VALUES.equals(values)) {
                    feature = new BooleanFeature((PMMLEncoder)encoder, (Field)dataField);
                } else if (direct) {
                    feature = new DirectCategoricalFeature((PMMLEncoder)encoder, dataField);
                } else {
                    feature = new CategoricalFeature((PMMLEncoder)encoder, (Field)dataField);
                }

                encoder.addDecorator((Field)dataField, (Decorator)new InvalidValueDecorator(InvalidValueTreatmentMethod.AS_MISSING, null));
            } else {
                if (binary) {
                    dataField = encoder.createDataField(featureName, OpType.CATEGORICAL, DataType.INTEGER, Arrays.asList(0, 1));
                    feature = new BinaryFeature((PMMLEncoder)encoder, (Field)dataField, (Object)1);
                } else {
                    Interval interval = LightGBMUtil.parseInterval(featureInfo);

                    dataField = encoder.createDataField(featureName, OpType.CONTINUOUS, DataType.DOUBLE);
                    if (interval != null) {
                        dataField.addIntervals(new Interval[]{interval});
                    }
                    feature = new ContinuousFeature((PMMLEncoder)encoder, (Field)dataField);
                }

                encoder.addDecorator((Field)dataField, (Decorator)new InvalidValueDecorator(InvalidValueTreatmentMethod.AS_IS, null));
            }

            features.add(feature);

            Double importance = this.getFeatureImportance(featureName);
            if (importance != null) {
                encoder.addFeatureImportance(feature, (Number)importance);
            }
        }

        return new Schema(encoder, label, features);
    }

    /**
     * Adapts an externally supplied schema to this model's feature layout. Features are matched
     * to the header by position; feature importances are registered; and each feature is
     * converted to the binary, categorical or continuous form the trees use.
     *
     * @param schema a schema with exactly one feature per header feature
     * @return the transformed schema
     */
    public Schema toLightGBMSchema(final Schema schema) {
        final String[] featureNames = this.feature_names_;
        final String[] featureInfos = this.feature_infos_;

        final ModelEncoder encoder = schema.getEncoder();
        final List<? extends Feature> features = schema.getFeatures();

        SchemaUtil.checkSize(featureNames.length, features);
        SchemaUtil.checkSize(featureInfos.length, features);

        Function<Feature, Feature> function = feature -> {
            int index = features.indexOf(feature);
            if (index < 0) {
                throw new IllegalArgumentException();
            }
            return this.toLightGBMFeature(feature, index, featureNames[index], encoder);
        };

        return schema.toTransformedSchema(function);
    }

    /**
     * Converts a single schema feature at header position {@code index}.
     * Also records that feature's importance with the encoder.
     */
    private Feature toLightGBMFeature(Feature feature, int index, String featureName, ModelEncoder encoder) {
        Double importance = this.getFeatureImportance(featureName);
        if (importance != null) {
            encoder.addFeatureImportance(feature, (Number)importance);
        }

        if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature)feature;

            Boolean binary = this.isBinary(index);
            if (binary == null || binary.booleanValue()) {
                return binaryFeature;
            }

            Boolean categorical = this.isCategorical(index);
            if (categorical != null && categorical.booleanValue()) {
                return new BinaryCategoricalFeature((PMMLEncoder)encoder, binaryFeature);
            }
        } else if (feature instanceof CategoricalFeature) {
            Boolean categorical = this.isCategorical(index);
            if (categorical == null || categorical.booleanValue()) {
                return feature;
            }
        } else if (feature instanceof NullFeature) {
            return feature;
        } else if (feature instanceof WildcardFeature) {
            WildcardFeature wildcardFeature = (WildcardFeature)feature;

            if (Boolean.TRUE.equals(this.isBinary(index))) {
                // The returned feature is discarded; presumably called for its effect on the underlying field
                wildcardFeature.toCategoricalFeature(Arrays.asList(0, 1));

                return new BinaryFeature((PMMLEncoder)encoder, (Feature)wildcardFeature, (Object)1);
            }
        }

        return feature.toContinuousFeature();
    }

    /**
     * Builds the schema, encodes the model and wraps it in a PMML document.
     *
     * @param options conversion options (see {@link #encodeModel(Map, Schema)})
     * @param targetName the target field name, or {@code null} for the default
     * @param targetCategories the target categories
     * @return the PMML document
     */
    public PMML encodePMML(Map<String, ?> options, String targetName, List<String> targetCategories) {
        LightGBMEncoder encoder = new LightGBMEncoder();

        Schema schema = this.encodeSchema(targetName, targetCategories, encoder);
        MiningModel miningModel = this.encodeModel(options, schema);

        return encoder.encodePMML((Model)miningModel);
    }

    /**
     * Encodes the model using the given options. Reads {@code "num_iteration"} (any
     * {@link Number}), plus the options read by {@link #configureSchema} and {@link #configureModel}.
     *
     * @param options conversion options
     * @param schema the input schema
     * @return the configured mining model
     */
    public MiningModel encodeModel(Map<String, ?> options, Schema schema) {
        // Accept any Number, so that e.g. a Long option value does not fail with ClassCastException
        Number numIteration = (Number)options.get("num_iteration");
        Integer numIterations = (numIteration != null) ? Integer.valueOf(numIteration.intValue()) : null;

        schema = this.configureSchema(options, schema);

        MiningModel miningModel = this.encodeModel(numIterations, schema);

        return this.configureModel(options, miningModel);
    }

    /**
     * Encodes all trees through the objective function.
     *
     * @param numIterations the iteration limit passed to the objective function, or {@code null}
     * @param schema the input schema
     * @return the mining model, with algorithm name {@code "LightGBM"}
     * @throws IllegalStateException if no objective function is set
     */
    public MiningModel encodeModel(Integer numIterations, Schema schema) {
        ObjectiveFunction objectiveFunction = this.getObjectiveFunction();
        if (objectiveFunction == null) {
            throw new IllegalStateException();
        }

        return objectiveFunction.encodeModel(Arrays.asList(this.models_), numIterations, schema)
            .setAlgorithmName("LightGBM");
    }

    /**
     * Applies schema-level options. When {@code "nan_as_missing"} is {@code Boolean.TRUE},
     * the value {@code "NaN"} is added as a missing value to the data field behind every
     * FLOAT or DOUBLE feature that is not a {@link NullFeature}.
     *
     * @param options conversion options
     * @param schema the input schema
     * @return the transformed schema
     */
    public Schema configureSchema(Map<String, ?> options, Schema schema) {
        final Boolean nanAsMissing = (Boolean)options.get("nan_as_missing");
        final boolean addNaNMissingValue = Boolean.TRUE.equals(nanAsMissing);

        Function<Feature, Feature> function = feature -> {
            if (feature instanceof NullFeature || !addNaNMissingValue) {
                return feature;
            }

            DataType dataType = feature.getDataType();
            if (dataType == DataType.FLOAT || dataType == DataType.DOUBLE) {
                Field field = feature.getField();

                // Only data fields are annotated; derived fields are left untouched
                if (field instanceof DataField) {
                    DataField dataField = (DataField)field;
                    FieldUtil.addValues((Field)dataField, (Value.Property)Value.Property.MISSING, Collections.singletonList("NaN"));
                }
            }
            return feature;
        };

        return schema.toTransformedSchema(function);
    }

    /**
     * Applies model-level options. When {@code "compact"} is {@code Boolean.TRUE},
     * the {@link TreeModelCompactor} visitor is applied.
     *
     * @param options conversion options
     * @param miningModel the model to post-process in place
     * @return the same model instance
     */
    public MiningModel configureModel(Map<String, ?> options, MiningModel miningModel) {
        Boolean compact = (Boolean)options.get("compact");

        VisitorBattery visitors = new VisitorBattery();
        if (Boolean.TRUE.equals(compact)) {
            visitors.add(TreeModelCompactor.class);
        }
        visitors.applyTo((Visitable)miningModel);

        return miningModel;
    }

    /** Returns the header's feature names (the internal array, not a copy). */
    public String[] getFeatureNames() {
        return this.feature_names_;
    }

    /** Returns the header's feature infos (the internal array, not a copy). */
    public String[] getFeatureInfos() {
        return this.feature_infos_;
    }

    /** Returns the objective function, or {@code null} if none was loaded or set. */
    public ObjectiveFunction getObjectiveFunction() {
        return this.object_function_;
    }

    /** Replaces the objective function. */
    public void setObjectiveFunction(ObjectiveFunction object_function_) {
        this.object_function_ = object_function_;
    }

    /** Returns {@code true} if the parameters section declared {@code linear_tree} as "1". */
    public boolean hasLinearTree() {
        return Objects.equals("1", this.linear_tree);
    }

    /**
     * Returns FALSE if the feature info is not a binary interval or some tree uses the feature
     * non-binarily. Returns TRUE if at least one tree uses it binarily. Returns {@code null} if
     * no tree expresses an opinion.
     */
    private Boolean isBinary(int feature) {
        if (!LightGBMUtil.isBinaryInterval(this.feature_infos_[feature])) {
            return Boolean.FALSE;
        }
        return this.reduceTreeFlags(tree -> tree.isBinary(feature));
    }

    /**
     * Returns FALSE if the feature info is not a value list or some tree uses the feature
     * non-categorically. Returns TRUE if at least one tree uses it categorically. Returns
     * {@code null} if no tree expresses an opinion.
     */
    private Boolean isCategorical(int feature) {
        if (!LightGBMUtil.isValues(this.feature_infos_[feature])) {
            return Boolean.FALSE;
        }
        return this.reduceTreeFlags(tree -> tree.isCategorical(feature));
    }

    /**
     * Combines per-tree tri-state flags. Returns FALSE as soon as any tree reports FALSE,
     * TRUE if at least one tree reports TRUE, and {@code null} if every tree reports {@code null}.
     */
    private Boolean reduceTreeFlags(Function<Tree, Boolean> flag) {
        Boolean result = null;
        for (Tree tree : this.models_) {
            Boolean value = flag.apply(tree);
            if (value == null) {
                continue;
            }
            if (!value.booleanValue()) {
                return Boolean.FALSE;
            }
            result = Boolean.TRUE;
        }
        return result;
    }

    /** Returns the parsed importance of the named feature, or {@code null} if unknown. */
    private Double getFeatureImportance(String featureName) {
        String value = this.feature_importances.get(featureName);
        return value != null ? Double.valueOf(value) : null;
    }

    /**
     * Builds the objective function from the header's {@code "objective"} entry.
     * The first token is the name; remaining tokens are added to the config.
     *
     * @return the objective function, or {@code null} if the header has no objective or it
     *     standardizes to {@code "custom"}
     * @throws IllegalArgumentException if the objective is empty or not supported
     */
    private static ObjectiveFunction loadObjectiveFunction(Section section) {
        if (!section.containsKey("objective")) {
            return null;
        }

        String[] tokens = section.getStringArray("objective", -1);
        if (tokens.length == 0) {
            throw new IllegalArgumentException("Objective function is empty");
        }

        String name = tokens[0];

        Section config = new Section();
        config.put("name", name);
        if (section.containsKey("average_output")) {
            config.put("average_output", null);
        }
        for (int i = 1; i < tokens.length; i++) {
            // NOTE(review): presumably splits a "key:value" token on ':' - confirm against Section
            config.put(tokens[i], ':');
        }

        // Locale.ROOT: locale-sensitive lower-casing breaks names such as "BINARY" under a Turkish locale
        String standardizedName = GBDT.standardizeObjectiveFunctionName(name.toLowerCase(Locale.ROOT));
        switch (standardizedName) {
            case "regression":
            case "regression_l1":
            case "huber":
            case "fair":
            case "quantile":
                return new Regression(config);
            case "poisson":
            case "gamma":
            case "tweedie":
                return new PoissonRegression(config);
            case "lambdarank":
                return new Lambdarank(config);
            case "binary":
                config.put("num_class", "2");
                return new BinomialLogisticRegression(config);
            case "cross_entropy":
                config.put("num_class", "2");
                config.put("sigmoid", "1.0");
                return new BinomialLogisticRegression(config);
            case "multiclass":
                return new MultinomialLogisticRegression(config);
            case "custom":
                return null;
            default:
                throw new IllegalArgumentException(standardizedName);
        }
    }

    /** Maps an objective name alias (already lower-cased) to its canonical name; unknown names pass through. */
    private static String standardizeObjectiveFunctionName(String name) {
        switch (name) {
            case "regression":
            case "regression_l2":
            case "mean_squared_error":
            case "mse":
            case "l2":
            case "l2_root":
            case "root_mean_squared_error":
            case "rmse":
                return "regression";
            case "regression_l1":
            case "mean_absolute_error":
            case "l1":
            case "mae":
                return "regression_l1";
            case "multiclass":
            case "softmax":
                return "multiclass";
            case "multiclassova":
            case "multiclass_ova":
            case "ova":
            case "ovr":
                return "multiclassova";
            case "xentropy":
            case "cross_entropy":
                return "cross_entropy";
            case "xentlambda":
            case "cross_entropy_lambda":
                return "cross_entropy_lambda";
            case "mean_absolute_percentage_error":
            case "mape":
                return "mape";
            case "none":
            case "null":
            case "custom":
            case "na":
                return "custom";
            default:
                return name;
        }
    }

    /** Copies the section's entries whose keys are known feature names, preserving order. */
    private Map<String, String> loadFeatureSection(Section section) {
        Map<String, String> result = new LinkedHashMap<>(section);
        // HashSet makes retainAll linear instead of O(entries * features)
        result.keySet().retainAll(new HashSet<>(Arrays.asList(this.feature_names_)));
        return result;
    }

    /**
     * Parses the pandas category levels embedded in the section id.
     *
     * @return one list of levels per entry, or an empty list if the parser returns {@code null}
     * @throws IllegalArgumentException wrapping any parse failure, with the section id as message
     */
    private List<List<?>> loadPandasCategorical(Section section) {
        String id = section.id();
        try {
            // Wildcard type accepts the parser's element list type without an unchecked cast
            List<? extends List<?>> parsed = PandasUtil.parsePandasCategorical(id);
            if (parsed == null) {
                return Collections.emptyList();
            }
            return Collections.unmodifiableList(parsed);
        } catch (Exception e) {
            throw new IllegalArgumentException(id, e);
        }
    }

    /** Returns {@code index + 1} if the section at {@code index} has the given id, else {@code index}. */
    private static int skipEndSection(String id, List<Section> sections, int index) {
        boolean atEndSection = index < sections.size() && sections.get(index).checkId(id);
        return atEndSection ? index + 1 : index;
    }
}

