/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost;

import com.google.common.primitives.Floats;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.Arrays;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.Schema;
import org.jpmml.xgboost.GradientBooster;
import org.jpmml.xgboost.JSONUtil;
import org.jpmml.xgboost.ObjFunction;
import org.jpmml.xgboost.RegTree;
import org.jpmml.xgboost.XGBoostDataInput;

/**
 * Gradient booster backed by an ensemble of regression trees ({@link RegTree}).
 *
 * <p>The model state can be populated either from the binary format via
 * {@link #loadBinary(XGBoostDataInput)} or from the JSON format via
 * {@link #loadJSON(JsonObject)}.
 */
public class GBTree extends GradientBooster {
    // Declared number of trees; also the length of the trees and tree_info arrays.
    private int num_trees;
    // num_roots, num_feature and num_output_group are only populated by loadBinary.
    private int num_roots;
    private int num_feature;
    private int num_output_group;
    private int size_leaf_vector;
    private RegTree[] trees;
    // One integer per tree; not read within this class (presumably the output group of each tree -- confirm).
    private int[] tree_info;

    /**
     * Returns the algorithm name of this booster.
     *
     * @return the constant {@code "GBTree"}
     */
    @Override
    public String getAlgorithmName() {
        return "GBTree";
    }

    /**
     * Reads the booster parameters, trees and per-tree info from the binary format.
     *
     * <p>The read order is significant and must match the serialized layout.
     *
     * @param input the binary model input
     * @throws IOException if reading from {@code input} fails
     */
    @Override
    public void loadBinary(XGBoostDataInput input) throws IOException {
        this.num_trees = input.readInt();
        this.num_roots = input.readInt();
        this.num_feature = input.readInt();
        // Skip reserved/padding fields (units are defined by XGBoostDataInput.readReserved).
        input.readReserved(3);
        this.num_output_group = input.readInt();
        this.size_leaf_vector = input.readInt();
        input.readReserved(32);
        this.trees = (RegTree[])input.readObjectArray(RegTree.class, this.num_trees);
        this.tree_info = input.readIntArray(this.num_trees);
    }

    /**
     * Reads the booster parameters, trees and per-tree info from the JSON format.
     *
     * <p>Only {@code num_trees} and {@code size_leaf_vector} are read from
     * {@code model.gbtree_model_param}. The first {@code num_trees} entries of
     * {@code model.trees} are loaded, and {@code model.tree_info} is converted to an int array.
     *
     * @param gradientBooster the JSON object containing a {@code "model"} member
     * @throws IllegalArgumentException if {@code num_trees} is negative, or if the
     *         {@code "trees"} array is missing or holds fewer than {@code num_trees} entries
     */
    @Override
    public void loadJSON(JsonObject gradientBooster) {
        JsonObject model = gradientBooster.getAsJsonObject("model");
        JsonObject gbtreeModelParam = model.getAsJsonObject("gbtree_model_param");
        this.num_trees = gbtreeModelParam.getAsJsonPrimitive("num_trees").getAsInt();
        this.size_leaf_vector = gbtreeModelParam.getAsJsonPrimitive("size_leaf_vector").getAsInt();

        // Validate up front so a malformed model fails with a descriptive message
        // rather than a NullPointerException, IndexOutOfBoundsException or NegativeArraySizeException.
        if (this.num_trees < 0) {
            throw new IllegalArgumentException("Invalid number of trees: " + this.num_trees);
        }
        JsonArray trees = model.getAsJsonArray("trees");
        if (trees == null) {
            throw new IllegalArgumentException("Missing \"trees\" array in gbtree model");
        }
        if (trees.size() < this.num_trees) {
            throw new IllegalArgumentException("Expected " + this.num_trees + " trees, got " + trees.size());
        }

        this.trees = new RegTree[this.num_trees];
        for (int i = 0; i < this.num_trees; ++i) {
            JsonObject tree = trees.get(i).getAsJsonObject();
            this.trees[i] = new RegTree();
            this.trees[i].loadJSON(tree);
        }
        this.tree_info = JSONUtil.toIntArray(model.getAsJsonArray("tree_info"));
    }

    /**
     * Encodes this tree ensemble as a PMML mining model by delegating to the objective function.
     *
     * <p>The trees come from {@link #trees()}. The weights come from {@link #tree_weights()}
     * and are passed as {@code null} when that method returns {@code null}.
     *
     * @param obj the objective function that performs the actual encoding
     * @param base_score forwarded unchanged to {@code obj}
     * @param ntreeLimit forwarded unchanged to {@code obj}; may be {@code null}
     *        (presumably limits the number of trees used -- confirm in ObjFunction)
     * @param schema forwarded unchanged to {@code obj}
     * @return the mining model produced by {@code obj}
     */
    public MiningModel encodeMiningModel(ObjFunction obj, float base_score, Integer ntreeLimit, Schema schema) {
        // Go through the accessor methods so that subclass overrides are honored.
        RegTree[] trees = this.trees();
        float[] weights = this.tree_weights();
        return obj.encodeMiningModel(Arrays.asList(trees), weights != null ? Floats.asList(weights) : null, base_score, ntreeLimit, schema);
    }

    /**
     * Returns the declared number of trees.
     *
     * @return the tree count read by the last {@code loadBinary}/{@code loadJSON} call
     */
    public int num_trees() {
        return this.num_trees;
    }

    /**
     * Returns the loaded trees.
     *
     * @return the internal tree array (not a copy)
     */
    public RegTree[] trees() {
        return this.trees;
    }

    /**
     * Returns per-tree weights, if any.
     *
     * @return always {@code null} in this implementation, meaning no weights are supplied
     *         to the objective function in {@link #encodeMiningModel}
     */
    public float[] tree_weights() {
        return null;
    }
}

