/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp.xgboost;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.VerificationField;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.rexp.DecorationUtil;
import org.jpmml.rexp.ModelConverter;
import org.jpmml.rexp.RExp;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RFactorVector;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RIntegerVector;
import org.jpmml.rexp.RNumberVector;
import org.jpmml.rexp.RRaw;
import org.jpmml.rexp.RStringVector;
import org.jpmml.rexp.RVector;
import org.jpmml.xgboost.FeatureMap;
import org.jpmml.xgboost.Learner;
import org.jpmml.xgboost.ObjFunction;
import org.jpmml.xgboost.XGBoostUtil;

/**
 * Converts an R XGBoost booster object (an {@link RGenericVector}) into a PMML
 * {@link MiningModel}.
 *
 * <p>The booster is read through the following named elements: {@code "raw"} (the
 * serialized learner), {@code "fmap"} (the feature map, either a file path or an
 * in-memory data frame), and the elements {@code "feature_names"}, {@code "schema"}
 * and {@code "ntreelimit"}, whose lookups are null-checked by this class and are
 * therefore treated as optional.
 *
 * <p>The learner and the feature map are loaded on first use and cached in instance
 * fields. The caching is not synchronized.
 */
public class XGBoostConverter
extends ModelConverter<RGenericVector> {
    /** Lazily loaded learner; see {@link #ensureLearner()}. */
    private Learner learner = null;
    /** Lazily loaded feature map; see {@link #ensureFeatureMap()}. */
    private FeatureMap featureMap = null;
    /** Value of the {@code "compact"} converter option, defaulting to {@code true}. */
    private boolean compact = this.getOption("compact", Boolean.TRUE);

    /**
     * @param booster the R booster object to convert
     */
    public XGBoostConverter(RGenericVector booster) {
        super((RExp)booster);
    }

    /**
     * Registers the label and the features of the booster with the encoder.
     *
     * <p>The target field name defaults to {@code "_target"} and the target categories
     * default to {@code null}; both are overridden by the {@code "response_name"} and
     * {@code "response_levels"} elements of the booster's {@code "schema"} element when
     * those are present.
     *
     * @param encoder the encoder that receives the label and the features
     * @throws IllegalArgumentException if {@code "feature_names"} is present and its
     *         size differs from the number of feature map entries
     */
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector booster = (RGenericVector)this.getObject();
        RStringVector featureNames = booster.getStringElement("feature_names", false);
        RGenericVector schema = booster.getGenericElement("schema", false);

        FeatureMap featureMap = this.ensureFeatureMap();
        if (featureNames != null) {
            XGBoostConverter.checkFeatureMap(featureMap, featureNames);
        }

        Learner learner = this.ensureLearner();
        ObjFunction obj = learner.obj();

        String targetField = "_target";
        List<String> targetCategories = null;
        if (schema != null) {
            RStringVector responseName = schema.getStringElement("response_name", false);
            RStringVector responseLevels = schema.getStringElement("response_levels", false);
            if (responseName != null) {
                targetField = (String)responseName.asScalar();
            }
            if (responseLevels != null) {
                targetCategories = responseLevels.getValues();
            }
        }

        Label label = obj.encodeLabel(targetField, targetCategories, (ModelEncoder)encoder);
        encoder.setLabel(label);

        List<? extends Feature> features = featureMap.encodeFeatures(learner, (PMMLEncoder)encoder);
        for (Feature feature : features) {
            encoder.addFeature(feature);
        }
    }

    /**
     * Encodes the learner as a PMML mining model.
     *
     * <p>The options passed to the learner are {@code "missing"} (taken from the
     * booster's {@code "schema"} element, or {@code null}), {@code "compact"},
     * {@code "numeric"} (always {@code true}) and {@code "ntree_limit"} (taken from the
     * booster's {@code "ntreelimit"} element, or {@code null}).
     *
     * @param schema the schema to encode against; it is first passed through
     *        {@code Learner#toXGBoostSchema}
     * @return the encoded mining model
     */
    public MiningModel encodeModel(Schema schema) {
        RGenericVector booster = (RGenericVector)this.getObject();
        RNumberVector<?> ntreeLimit = booster.getNumericElement("ntreelimit", false);
        RGenericVector boosterSchema = booster.getGenericElement("schema", false);
        // The "schema" element may be absent (encodeSchema null-checks the same lookup),
        // so it must not be dereferenced unconditionally.
        RNumberVector<?> missing = (boosterSchema != null) ? boosterSchema.getNumericElement("missing", false) : null;

        Learner learner = this.ensureLearner();

        Map<String, Object> options = new LinkedHashMap<>();
        options.put("missing", missing != null ? missing.asScalar() : null);
        options.put("compact", this.compact);
        options.put("numeric", true);
        options.put("ntree_limit", ntreeLimit != null ? ValueUtil.asInteger((Number)ntreeLimit.asScalar()) : null);

        Schema xgbSchema = learner.toXGBoostSchema(schema);
        return learner.encodeModel(options, xgbSchema);
    }

    /**
     * Builds verification data from a data frame whose columns are positionally aligned
     * with the feature map entries.
     *
     * <p>Columns of {@code INDICATOR} entries that share a name are merged into a single
     * factor column: a row gets the entry's value (or {@code "true"} when the entry has
     * no value) if the column holds {@code 1} at that row. Columns of
     * {@code QUANTITIVE}, {@code INTEGER} and {@code FLOAT} entries are passed through
     * unchanged.
     *
     * @param dataFrame the data frame holding one column per feature map entry
     * @return the result of {@code encodeVerificationData} for the merged columns
     * @throws IllegalArgumentException if the column count differs from the number of
     *         feature map entries, or an entry has an unsupported type
     */
    protected Map<VerificationField, List<?>> encodeActiveValues(RGenericVector dataFrame) {
        FeatureMap featureMap = this.ensureFeatureMap();
        XGBoostConverter.checkFeatureMap(featureMap, dataFrame);

        List<? extends FeatureMap.Entry> entries = featureMap.getEntries();
        // Insertion order is kept, so output columns follow first appearance in the feature map.
        Map<String, RVector<?>> data = new LinkedHashMap<>();

        for (int i = 0; i < dataFrame.size(); ++i) {
            FeatureMap.Entry entry = entries.get(i);
            RVector<?> column = dataFrame.getVectorValue(i);
            String name = entry.getName();
            FeatureMap.Entry.Type type = entry.getType();

            switch (type) {
                case INDICATOR: {
                    FeatureMap.IndicatorEntry indicatorEntry = (FeatureMap.IndicatorEntry)entry;
                    RFactorVector factorColumn = (RFactorVector)data.get(name);
                    if (factorColumn == null) {
                        factorColumn = XGBoostConverter.createIndicatorColumn(column.size());
                        data.put(name, factorColumn);
                    }
                    XGBoostConverter.applyIndicatorMask(factorColumn.getFactorValues(), column.getValues(), indicatorEntry.getValue());
                    break;
                }
                case QUANTITIVE:
                case INTEGER:
                case FLOAT: {
                    data.put(name, column);
                    break;
                }
                default: {
                    throw new IllegalArgumentException(String.valueOf(type));
                }
            }
        }

        List<RVector<?>> columns = new ArrayList<>(data.values());
        List<String> names = new ArrayList<>(data.keySet());
        return XGBoostConverter.encodeVerificationData(columns, names);
    }

    /**
     * Creates a factor column whose factor values are backed by a mutable list of
     * {@code size} {@code null} elements.
     */
    private static RFactorVector createIndicatorColumn(final int size) {
        return new RFactorVector(null, null){
            private final List<String> factorValues = new ArrayList<>(size);
            {
                for (int row = 0; row < size; ++row) {
                    this.factorValues.add(null);
                }
            }

            public List<String> getFactorValues() {
                return this.factorValues;
            }
        };
    }

    /**
     * Writes the indicator value into every row of {@code factorValues} whose mask
     * element equals {@code 1}; a {@code null} value is written as {@code "true"}.
     */
    private static void applyIndicatorMask(List<String> factorValues, List<?> mask, String value) {
        String effectiveValue = (value != null) ? value : "true";
        for (int row = 0; row < mask.size(); ++row) {
            Number rowMask = (Number)mask.get(row);
            if (rowMask != null && rowMask.doubleValue() == 1.0) {
                factorValues.set(row, effectiveValue);
            }
        }
    }

    /** Returns the cached feature map, loading it on first use. */
    private FeatureMap ensureFeatureMap() {
        if (this.featureMap == null) {
            this.featureMap = this.loadFeatureMap();
        }
        return this.featureMap;
    }

    /** Returns the cached learner, loading it on first use. */
    private Learner ensureLearner() {
        if (this.learner == null) {
            this.learner = this.loadLearner();
        }
        return this.learner;
    }

    /**
     * Loads the feature map from the booster's {@code "fmap"} element.
     *
     * @throws IllegalArgumentException wrapping any {@link IOException} raised while loading
     */
    private FeatureMap loadFeatureMap() {
        RGenericVector booster = (RGenericVector)this.getObject();
        RVector<?> fmap = DecorationUtil.getVectorElement((RGenericVector)booster, (String)"fmap");
        try {
            return XGBoostConverter.loadFeatureMap(fmap);
        }
        catch (IOException ioe) {
            throw new IllegalArgumentException(ioe);
        }
    }

    /**
     * Loads the learner from the booster's {@code "raw"} element.
     *
     * @throws IllegalArgumentException wrapping any {@link IOException} raised while loading
     */
    private Learner loadLearner() {
        RGenericVector booster = (RGenericVector)this.getObject();
        RRaw raw = (RRaw)booster.getElement("raw");
        try {
            return XGBoostConverter.loadLearner(raw);
        }
        catch (IOException ioe) {
            throw new IllegalArgumentException(ioe);
        }
    }

    /**
     * Verifies that the vector has exactly one element per feature map entry.
     *
     * @throws IllegalArgumentException if the sizes differ
     */
    private static void checkFeatureMap(FeatureMap featureMap, RVector<?> vector) {
        List<?> entries = featureMap.getEntries();
        if (vector.size() != entries.size()) {
            throw new IllegalArgumentException("Invalid 'fmap' element. Expected " + vector.size() + " features, got " + entries.size() + " features");
        }
    }

    /**
     * Dispatches on the runtime type of the {@code "fmap"} element: a string vector is
     * read as a file path, a generic vector as an in-memory data frame.
     *
     * @throws IllegalArgumentException if the element is of neither type
     */
    private static FeatureMap loadFeatureMap(RVector<?> fmap) throws IOException {
        if (fmap instanceof RStringVector) {
            return XGBoostConverter.loadFeatureMap((RStringVector)fmap);
        }
        if (fmap instanceof RGenericVector) {
            return XGBoostConverter.loadFeatureMap((RGenericVector)fmap);
        }
        throw new IllegalArgumentException("Unsupported 'fmap' element type: " + (fmap != null ? fmap.getClass().getName() : "null"));
    }

    /** Loads the feature map from the file whose path is the scalar value of {@code fmap}. */
    private static FeatureMap loadFeatureMap(RStringVector fmap) throws IOException {
        File file = new File((String)fmap.asScalar());
        try (InputStream is = new FileInputStream(file)) {
            return XGBoostUtil.loadFeatureMap(is);
        }
    }

    /**
     * Builds the feature map from a data frame whose first three columns are read as an
     * integer id column and two factor columns (name and type).
     *
     * @throws IllegalArgumentException if the id at row {@code i} is not {@code i}
     */
    private static FeatureMap loadFeatureMap(RGenericVector fmap) {
        RIntegerVector id = fmap.getIntegerValue(0);
        RFactorVector name = fmap.getFactorValue(1);
        RFactorVector type = fmap.getFactorValue(2);
        FeatureMap featureMap = new FeatureMap();
        for (int i = 0; i < id.size(); ++i) {
            // Ids must form the sequence 0, 1, 2, ... so that they match row positions.
            if (i != id.getValue(i)) {
                throw new IllegalArgumentException("Invalid 'fmap' element. Expected feature id " + i + ", got " + id.getValue(i));
            }
            featureMap.addEntry(name.getFactorValue(i), type.getFactorValue(i));
        }
        return featureMap;
    }

    /**
     * Deserializes the learner from the raw bytes, passing the platform's native byte
     * order, a {@code null} third argument and the path {@code "$.Model"} to
     * {@code XGBoostUtil.loadLearner}.
     */
    private static Learner loadLearner(RRaw raw) throws IOException {
        byte[] value = raw.getValue();
        try (InputStream is = new ByteArrayInputStream(value)) {
            return XGBoostUtil.loadLearner(is, (ByteOrder)ByteOrder.nativeOrder(), null, (String)"$.Model");
        }
    }
}

