/*-------------------------------------------------------------------------------
 This file is part of unityForest.

 Copyright (c) [2014-2018] [Marvin N. Wright]
 Modifications and extensions by Roman Hornung

 This software may be modified and distributed under the terms of the MIT license.

 Please note that the C++ core of divfor is distributed under MIT license and the
 R package "unityForest" under GPL3 license.
 #-------------------------------------------------------------------------------*/

#include <thread>
#include <chrono>
#include <iterator>
#include <cmath>

#include <Rcpp.h>

#include "Tree.h"
#include "utility.h"

//#include "debug_cp.h"

namespace unityForest
{

  // Default constructor: pointer members start out null, counters and sizes start at zero and
  // the tuning parameters take their DEFAULT_* constants. The concrete settings are supplied
  // afterwards through init().
  Tree::Tree()
      : dependent_varID(0),
        mtry(0),
        prop_var_root(0),
        num_samples(0),
        num_samples_oob(0),
        min_node_size(0),
        min_node_size_root(0),
        deterministic_varIDs(nullptr),
        split_select_varIDs(nullptr),
        split_select_weights(nullptr),
        case_weights(nullptr),
        manual_inbag(nullptr),
        oob_sampleIDs(0),
        holdout(false),
        keep_inbag(false),
        data(nullptr),
        variable_importance(nullptr),
        importance_mode(DEFAULT_IMPORTANCE_MODE),
        sample_with_replacement(true),
        sample_fraction(nullptr),
        memory_saving_splitting(false),
        splitrule(DEFAULT_SPLITRULE),
        alpha(DEFAULT_ALPHA),
        minprop(DEFAULT_MINPROP),
        num_random_splits(DEFAULT_NUM_RANDOM_SPLITS),
        max_depth(DEFAULT_MAXDEPTH),
        max_depth_root(DEFAULT_MAXDEPTHROOT),
        num_cand_trees(DEFAULT_NUMCANDTREES),
        depth(0),
        last_left_nodeID(0),
        last_left_nodeID_loop(0)
  {
  }

  // Constructor for an already grown tree: copies the given tree structure (child node IDs,
  // split variable IDs and split values) into the members of the same name. All remaining
  // members receive the same initial values as in the default constructor; in particular the
  // data pointer stays null.
  Tree::Tree(std::vector<std::vector<size_t>> &child_nodeIDs, std::vector<size_t> &split_varIDs,
             std::vector<double> &split_values)
      : dependent_varID(0),
        mtry(0),
        prop_var_root(0),
        num_samples(0),
        num_samples_oob(0),
        min_node_size(0),
        min_node_size_root(0),
        deterministic_varIDs(nullptr),
        split_select_varIDs(nullptr),
        split_select_weights(nullptr),
        case_weights(nullptr),
        manual_inbag(nullptr),
        split_varIDs(split_varIDs),
        split_values(split_values),
        child_nodeIDs(child_nodeIDs),
        oob_sampleIDs(0),
        holdout(false),
        keep_inbag(false),
        data(nullptr),
        variable_importance(nullptr),
        importance_mode(DEFAULT_IMPORTANCE_MODE),
        sample_with_replacement(true),
        sample_fraction(nullptr),
        memory_saving_splitting(false),
        splitrule(DEFAULT_SPLITRULE),
        alpha(DEFAULT_ALPHA),
        minprop(DEFAULT_MINPROP),
        num_random_splits(DEFAULT_NUM_RANDOM_SPLITS),
        max_depth(DEFAULT_MAXDEPTH),
        max_depth_root(DEFAULT_MAXDEPTHROOT),
        num_cand_trees(DEFAULT_NUMCANDTREES),
        depth(0),
        last_left_nodeID(0),
        last_left_nodeID_loop(0)
  {
  }

  // Constructor for repr_tree_mode: like the constructor for an already grown tree (the tree
  // structure is copied into the members of the same name), but additionally stores the
  // pointer to the data set in `data`.
  Tree::Tree(std::vector<std::vector<size_t>> &child_nodeIDs, std::vector<size_t> &split_varIDs,
             std::vector<double> &split_values,
             const Data *data_ptr)
      : dependent_varID(0),
        mtry(0),
        prop_var_root(0),
        num_samples(0),
        num_samples_oob(0),
        min_node_size(0),
        min_node_size_root(0),
        deterministic_varIDs(nullptr),
        split_select_varIDs(nullptr),
        split_select_weights(nullptr),
        case_weights(nullptr),
        manual_inbag(nullptr),
        split_varIDs(split_varIDs),
        split_values(split_values),
        child_nodeIDs(child_nodeIDs),
        oob_sampleIDs(0),
        holdout(false),
        keep_inbag(false),
        data(data_ptr),
        variable_importance(nullptr),
        importance_mode(DEFAULT_IMPORTANCE_MODE),
        sample_with_replacement(true),
        sample_fraction(nullptr),
        memory_saving_splitting(false),
        splitrule(DEFAULT_SPLITRULE),
        alpha(DEFAULT_ALPHA),
        minprop(DEFAULT_MINPROP),
        num_random_splits(DEFAULT_NUM_RANDOM_SPLITS),
        max_depth(DEFAULT_MAXDEPTH),
        max_depth_root(DEFAULT_MAXDEPTHROOT),
        num_cand_trees(DEFAULT_NUMCANDTREES),
        depth(0),
        last_left_nodeID(0),
        last_left_nodeID_loop(0)
  {
  }

  // Store the settings for this tree, create its root node and seed its random number
  // generator.
  //
  // Every argument except `seed` is copied into the member of the same name; the pointer
  // arguments are stored as pointers, so the objects they point to are not copied here.
  // `seed` is only used to seed random_number_generator.
  void Tree::init(const Data *data, uint mtry, double prop_var_root, size_t dependent_varID, size_t num_samples, uint seed,
                  std::vector<size_t> *deterministic_varIDs, std::vector<size_t> *split_select_varIDs,
                  std::vector<double> *split_select_weights, ImportanceMode importance_mode, uint min_node_size, uint min_node_size_root,
                  bool sample_with_replacement, bool memory_saving_splitting, SplitRule splitrule, std::vector<double> *case_weights,
                  std::vector<size_t> *manual_inbag, bool keep_inbag, std::vector<double> *sample_fraction, double alpha,
                  double minprop, bool holdout, uint num_random_splits, uint max_depth, uint max_depth_root, uint num_cand_trees, std::vector<size_t> repr_vars)
  {
    // Data set and basic quantities (set before the root node is created)
    this->data = data;
    this->dependent_varID = dependent_varID;
    this->num_samples = num_samples;
    this->mtry = mtry;
    this->prop_var_root = prop_var_root;
    this->memory_saving_splitting = memory_saving_splitting;
    this->repr_vars = repr_vars;

    // One (initially empty) list each for the left and the right child node IDs,
    // followed by the root node itself
    child_nodeIDs.emplace_back();
    child_nodeIDs.emplace_back();
    createEmptyNodeFullTree();

    // Seed the random number generator of this tree
    random_number_generator.seed(seed);

    // Split variable selection
    this->deterministic_varIDs = deterministic_varIDs;
    this->split_select_varIDs = split_select_varIDs;
    this->split_select_weights = split_select_weights;

    // Sampling of the observations
    this->sample_with_replacement = sample_with_replacement;
    this->sample_fraction = sample_fraction;
    this->case_weights = case_weights;
    this->manual_inbag = manual_inbag;
    this->keep_inbag = keep_inbag;
    this->holdout = holdout;

    // Splitting and stopping
    this->splitrule = splitrule;
    this->importance_mode = importance_mode;
    this->min_node_size = min_node_size;
    this->min_node_size_root = min_node_size_root;
    this->alpha = alpha;
    this->minprop = minprop;
    this->num_random_splits = num_random_splits;
    this->max_depth = max_depth;
    this->max_depth_root = max_depth_root;
    this->num_cand_trees = num_cand_trees;
  }

  // Construct a tree.
  void Tree::grow(std::vector<double> *variable_importance)
  {

    // CP();

    // Allocate memory for tree growing
    allocateMemory();

    this->variable_importance = variable_importance;

    // CP();

    // Bootstrap, dependent if weighted or not and with or without replacement
    if (!case_weights->empty())
    {
      if (sample_with_replacement)
      {
        bootstrapWeighted();
      }
      else
      {
        bootstrapWithoutReplacementWeighted();
      }
    }
    else if (sample_fraction->size() > 1)
    {
      if (sample_with_replacement)
      {
        bootstrapClassWise();
      }
      else
      {
        bootstrapWithoutReplacementClassWise();
      }
    }
    else if (!manual_inbag->empty())
    {
      setManualInbag();
    }
    else
    {
      if (sample_with_replacement)
      {
        bootstrap();
      }
      else
      {
        bootstrapWithoutReplacement();
      }
    }

    // CP();

    // Randomly draw a proportion prop_var_root of the variables (default 0.7), which will be used for the tree root:
    std::vector<size_t> varIDs_root;
    // Determine the number of variables to be used for the tree root as prop_var_root times the number of all available variables:
    size_t num_vars_root = round((data->getNumCols() - 1) * prop_var_root);
    if (num_vars_root == 0)
    {
      num_vars_root = 1;
    }

    // CP();

    // Draw num_vars_root variables without replacement from all available variables:
    // Get the vector with all available variables, while excluding the variables in data->getNoSplitVariables():
    const std::vector<size_t> &all_vars = *allowedVarIDs_;
    // Draw num_vars_root variables without replacement from all_vars:
    drawWithoutReplacementFromVector(varIDs_root, all_vars, random_number_generator, num_vars_root);

    // CP();

    // Generate 'num_cand_trees' trees (default: 1000) with random splits, where each tree is grown to a maximum depth of three:

    child_nodeIDs_loop.push_back(std::vector<size_t>());
    child_nodeIDs_loop.push_back(std::vector<size_t>());

    double best_decrease = -1;

    // Distribution to generate double between 0.0 and 1.0
    std::uniform_real_distribution<double> distr(0.0, 1.0);

    const size_t MAX_NODES = static_cast<size_t>(pow(2, max_depth_root + 1)) - 1;
    split_varIDs_loop.reserve(MAX_NODES);
    split_values_loop.reserve(MAX_NODES);
    child_nodeIDs_loop[0].reserve(MAX_NODES);
    child_nodeIDs_loop[1].reserve(MAX_NODES);
    start_pos_loop.reserve(MAX_NODES);
    end_pos_loop.reserve(MAX_NODES);

    // CP();

    for (size_t j = 0; j < num_cand_trees; ++j)
    {

      // Clear vectors
      clearRandomTree();

      // CP();

      // Make empty node:
      createEmptyNodeRandomTree();

      // CP();

      // Init start and end positions
      start_pos_loop[0] = 0;
      end_pos_loop[0] = sampleIDs.size();

      // While not all nodes terminal, split next node
      size_t num_open_nodes = 1;
      last_left_nodeID_loop = 0;
      size_t i = 0;
      depth = 0;
      std::vector<size_t> terminal_nodes; // Vector to store indices of terminal nodes
      while (num_open_nodes > 0)
      {

        // Split node at random
        bool is_terminal_node = splitNodeRandom(i, varIDs_root);

        if (is_terminal_node)
        {
          terminal_nodes.push_back(i); // Add index of terminal node to vector
          --num_open_nodes;
        }
        else
        {
          ++num_open_nodes;
          if (i >= last_left_nodeID_loop)
          {
            // If new level, increase depth
            // (left_node saves left-most node in current level, new level reached if that node is splitted)
            last_left_nodeID_loop = split_varIDs_loop.size() - 2;
            ++depth;
          }
        }
        ++i;
      }

      // CP();

      // Evaluate the tree:
      double decrease = evaluateRandomTree(terminal_nodes);

      // CP();

      // Save the current tree if it is better than the best tree so far:
      if (decrease >= best_decrease)
      {
        // If decrease is equal to best_decrease, the current tree is saved with a probability of 0.5:
        if (decrease == best_decrease)
        {
          if (distr(random_number_generator) < 0.5)
          {
            split_varIDs_best = split_varIDs_loop;
            split_values_best = split_values_loop;
            child_nodeIDs_best = child_nodeIDs_loop;
            best_decrease = decrease;
          }
        }
        else
        {
          split_varIDs_best = split_varIDs_loop;
          split_values_best = split_values_loop;
          child_nodeIDs_best = child_nodeIDs_loop;
          best_decrease = decrease;
        }
      }

      // CP();
    }

    // CP();

    // Extend the best tree to the maximum depth using conventional splitting:

    // Init start and end positions
    start_pos[0] = 0;
    end_pos[0] = sampleIDs.size();

    // The first root in the (full) tree is always from the tree root
    nodeID_in_root[0] = 0;

    // While not all nodes terminal, split next node
    size_t num_open_nodes = 1;
    size_t i = 0;
    depth = 0;
    while (num_open_nodes > 0)
    {
      // Split node
      bool is_terminal_node = splitNodeFullTree(i);
      if (is_terminal_node)
      {
        --num_open_nodes;
      }
      else
      {
        ++num_open_nodes;
        if (i >= last_left_nodeID)
        {
          // If new level, increase depth
          // (left_node saves left-most node in current level, new level reached if that node is splitted)
          last_left_nodeID = split_varIDs.size() - 2;
          ++depth;
        }
      }
      ++i;
    }

    // CP();

    // Delete sampleID vector to save memory
    /// sampleIDs.clear();
    /// sampleIDs.shrink_to_fit();
    cleanUpInternal();

    // CP();
  }

  // Predict using a tree.
  //
  // Drops every requested sample down the tree and stores the ID of the terminal node it
  // ends up in in prediction_terminal_nodeIDs (one entry per sample).
  //
  // prediction_data: data set from which the covariate values are read.
  // oob_prediction:  if true, only the num_samples_oob samples listed in oob_sampleIDs are
  //                  dropped down the tree; otherwise all rows of prediction_data are used.
  void Tree::predict(const Data *prediction_data, bool oob_prediction)
  {

    const size_t num_samples_predict = oob_prediction ? num_samples_oob : prediction_data->getNumRows();

    prediction_terminal_nodeIDs.resize(num_samples_predict, 0);

    // For each sample start in root, drop down the tree and return final value
    for (size_t i = 0; i < num_samples_predict; ++i)
    {
      const size_t sample_idx = oob_prediction ? oob_sampleIDs[i] : i;

      size_t nodeID = 0;

      // A node without children (both child IDs equal to 0) is terminal
      while (child_nodeIDs[0][nodeID] != 0 || child_nodeIDs[1][nodeID] != 0)
      {
        const size_t split_varID = split_varIDs[nodeID];
        const double value = prediction_data->get(sample_idx, split_varID);

        bool go_right;
        if (prediction_data->isOrderedVariable(split_varID))
        {
          // Ordered variable: values up to and including the split value go left
          go_right = !(value <= split_values[nodeID]);
        }
        else
        {
          // Unordered variable: the split value is a bit mask over the factor levels
          // (level k corresponds to bit k - 1); a set bit sends the level to the right.
          size_t factorID = floor(value) - 1;
          size_t splitID = floor(split_values[nodeID]);

          // The shifted 1 must be as wide as splitID. A plain `1 << factorID` is an int
          // shift, which overflows for factor levels beyond the width of int.
          go_right = (splitID & (static_cast<size_t>(1) << factorID)) != 0;
        }

        // Move to child
        nodeID = go_right ? child_nodeIDs[1][nodeID] : child_nodeIDs[0][nodeID];
      }

      prediction_terminal_nodeIDs[i] = nodeID;
    }
  }

  // Compute the unity VIM.
  //
  // For every node that is marked in is_in_best and is reached by at least one OOB
  // observation, the value of computeUFNodeImportance() is added to the entry of
  // forest_importance that belongs to the split variable of that node. Nothing happens if no
  // node is marked.
  //
  // forest_importance: importance values, indexed by variable ID; updated in place.
  //
  // NOTE(review): the observations are routed with `value <= split value` only, whereas
  // predict() treats unordered variables separately -- confirm that unordered covariates
  // cannot occur here.
  void Tree::computeUFImportance(std::vector<double> &forest_importance)
  {

    // Determine the node IDs in the tree for which is_in_best is 1 (in ascending order):
    std::vector<size_t> best_nodeIDs;
    for (size_t i = 0; i < is_in_best.size(); ++i)
    {
      if (is_in_best[i] == 1)
      {
        best_nodeIDs.push_back(i);
      }
    }

    // Nothing to do if no node is marked:
    if (best_nodeIDs.empty())
    {
      return;
    }

    // Lookup table from node ID to the position of that node in best_nodeIDs. It replaces a
    // linear search through best_nodeIDs for every node an observation passes through.
    // Unmarked nodes carry the sentinel value `not_marked`, which is not a valid position.
    const size_t not_marked = best_nodeIDs.size();
    std::vector<size_t> slot_of_node(is_in_best.size(), not_marked);
    for (size_t slot = 0; slot < best_nodeIDs.size(); ++slot)
    {
      slot_of_node[best_nodeIDs[slot]] = slot;
    }

    // Drop the OOB observations down the tree and for each node in best_nodeIDs, determine the
    // OOB observations that pass through the node:
    std::vector<std::vector<size_t>> oob_sampleIDs_nodeID(best_nodeIDs.size());
    for (size_t sampleID : oob_sampleIDs)
    {
      size_t nodeID = 0;
      while (true)
      {
        // Node IDs outside the range of is_in_best are never marked
        if (nodeID < slot_of_node.size() && slot_of_node[nodeID] != not_marked)
        {
          oob_sampleIDs_nodeID[slot_of_node[nodeID]].push_back(sampleID);
        }
        if (child_nodeIDs[0][nodeID] == 0 && child_nodeIDs[1][nodeID] == 0)
        {
          break;
        }
        size_t split_varID = split_varIDs[nodeID];
        double value = data->get(sampleID, split_varID);
        if (value <= split_values[nodeID])
        {
          nodeID = child_nodeIDs[0][nodeID];
        }
        else
        {
          nodeID = child_nodeIDs[1][nodeID];
        }
      }
    }

    // Loop through best_nodeIDs (ascending node ID) and compute the importance of the variables.
    // Nodes that are not reached by any OOB observation are skipped.
    for (size_t slot = 0; slot < best_nodeIDs.size(); ++slot)
    {
      if (oob_sampleIDs_nodeID[slot].empty())
      {
        continue;
      }

      // Calculate the importance of the variable for the node using OOB observations:
      forest_importance[split_varIDs[best_nodeIDs[slot]]] += computeUFNodeImportance(best_nodeIDs[slot], oob_sampleIDs_nodeID[slot]);
    }
  }

  // Compute the unity VIM contribution of a single node.
  //
  // The OOB split criterion value of the node is computed twice: once for the OOB
  // observations as they are and once together with a random permutation of these
  // observations. The difference (original minus permuted) is multiplied by
  // end_pos[nodeID] - start_pos[nodeID].
  //
  // nodeID:               ID of the node.
  // oob_sampleIDs_nodeID: IDs of the OOB observations that pass through the node.
  // Returns the weighted difference described above.
  double Tree::computeUFNodeImportance(size_t nodeID, std::vector<size_t> oob_sampleIDs_nodeID)
  {
    // Criterion value for the unpermuted OOB observations
    const double crit_original = computeOOBSplitCriterionValue(nodeID, oob_sampleIDs_nodeID);

    // Randomly reordered copy of the OOB observation IDs
    std::vector<size_t> permutations = oob_sampleIDs_nodeID;
    std::shuffle(permutations.begin(), permutations.end(), random_number_generator);

    // Criterion value after the permutation
    const double crit_permuted = computeOOBSplitCriterionValuePermuted(nodeID, oob_sampleIDs_nodeID, permutations);

    // Weight the difference by the width of the sample range of the node
    return (crit_original - crit_permuted) * (end_pos[nodeID] - start_pos[nodeID]);
  }

  // Compute the split scores for the unity VIM computation.
  //
  // All observations in sampleIDs are dropped down the tree. For every node that belongs to
  // the tree root and whose left child belongs to the tree root as well, split_criterion
  // receives the value of computeSplitCriterion() for the two children, multiplied by the
  // number of observations that reach the node. Entries that split_criterion newly gains
  // through the resize are set to -1.
  //
  // NOTE(review): the observations are routed with `value <= split value` only, whereas
  // predict() treats unordered variables separately -- confirm that unordered covariates
  // cannot occur here.
  void Tree::computeSplitCriterionValues()
  {
    const size_t num_nodes = split_varIDs.size();

    split_criterion.resize(num_nodes, -1);

    // For every node, the observations that pass through it
    std::vector<std::vector<size_t>> samples_in_node(num_nodes);
    for (size_t sampleID : sampleIDs)
    {
      size_t nodeID = 0;
      samples_in_node[nodeID].push_back(sampleID);

      // Descend until a node without children is reached
      while (child_nodeIDs[0][nodeID] != 0 || child_nodeIDs[1][nodeID] != 0)
      {
        const size_t split_varID = split_varIDs[nodeID];
        const double value = data->get(sampleID, split_varID);
        nodeID = (value <= split_values[nodeID]) ? child_nodeIDs[0][nodeID] : child_nodeIDs[1][nodeID];
        samples_in_node[nodeID].push_back(sampleID);
      }
    }

    for (size_t i = 0; i < num_nodes; ++i)
    {
      // Only the first node and nodes that are in the tree root are considered ...
      if (i != 0 && nodeID_in_root[i] == 0)
      {
        continue;
      }

      // ... and only if they have a left child that is in the tree root, too
      const size_t left_child = child_nodeIDs[0][i];
      if (left_child == 0 || nodeID_in_root[left_child] == 0)
      {
        continue;
      }
      const size_t right_child = child_nodeIDs[1][i];

      // Split criterion of the two children, weighted by the number of observations in the node
      split_criterion[i] = computeSplitCriterion(samples_in_node[left_child], samples_in_node[right_child]) * static_cast<double>(samples_in_node[i].size());
    }
  }

  // Compute the split scores for the CRTR analysis.
  //
  // The out-of-bag observations (those with an in-bag count of 0) are dropped down the tree.
  // For every node that belongs to the tree root, whose left child belongs to the tree root
  // as well, whose split variable is contained in repr_vars and whose two children are both
  // reached by at least one OOB observation, split_criterion receives the value of
  // computeSplitCriterion() for the two children, multiplied by the number of OOB
  // observations that reach the node. Entries that split_criterion newly gains through the
  // resize are set to -1.
  //
  // NOTE(review): the observations are routed with `value <= split value` only, whereas
  // predict() treats unordered variables separately -- confirm that unordered covariates
  // cannot occur here.
  void Tree::computeOOBSplitCriterionValues()
  {
    const size_t num_nodes = split_varIDs.size();

    split_criterion.resize(num_nodes, -1);

    // The out-of-bag observations are the indices of inbag_counts that are equal to 0.
    // (This is a local list; the member oob_sampleIDs is not used here.)
    std::vector<size_t> oob_ids;
    for (size_t idx = 0; idx < inbag_counts.size(); ++idx)
    {
      if (inbag_counts[idx] == 0)
      {
        oob_ids.push_back(idx);
      }
    }

    // For every node, the out-of-bag observations that pass through it
    std::vector<std::vector<size_t>> oob_in_node(num_nodes);
    for (size_t sampleID : oob_ids)
    {
      size_t nodeID = 0;
      oob_in_node[nodeID].push_back(sampleID);

      // Descend until a node without children is reached
      while (child_nodeIDs[0][nodeID] != 0 || child_nodeIDs[1][nodeID] != 0)
      {
        const size_t split_varID = split_varIDs[nodeID];
        const double value = data->get(sampleID, split_varID);
        nodeID = (value <= split_values[nodeID]) ? child_nodeIDs[0][nodeID] : child_nodeIDs[1][nodeID];
        oob_in_node[nodeID].push_back(sampleID);
      }
    }

    for (size_t i = 0; i < num_nodes; ++i)
    {
      // Only the first node and nodes that are in the tree root are considered ...
      if (i != 0 && nodeID_in_root[i] == 0)
      {
        continue;
      }

      // ... and only if they have a left child that is in the tree root, too ...
      const size_t left_child = child_nodeIDs[0][i];
      if (left_child == 0 || nodeID_in_root[left_child] == 0)
      {
        continue;
      }

      // ... and only if the split variable is one of the variables in repr_vars
      if (std::find(repr_vars.begin(), repr_vars.end(), split_varIDs[i]) == repr_vars.end())
      {
        continue;
      }

      const std::vector<size_t> &oob_left = oob_in_node[left_child];
      const std::vector<size_t> &oob_right = oob_in_node[child_nodeIDs[1][i]];

      // Both children must be reached by at least one out-of-bag observation
      if (oob_left.empty() || oob_right.empty())
      {
        continue;
      }

      // Split criterion of the two children based on the out-of-bag observations, weighted by
      // the number of out-of-bag observations that pass through the node
      split_criterion[i] = computeSplitCriterion(oob_left, oob_right) * static_cast<double>(oob_in_node[i].size());
    }
  }

  // Compute the split score value for a split in the unity VIM computation.
  //
  // Fallback implementation: ignores both arguments and always throws std::runtime_error.
  // (Judging from the error message, subclasses that support this computation are expected
  // to provide their own implementation.)
  double Tree::computeSplitCriterion(std::vector<size_t> sampleIDs_left_child, std::vector<size_t> sampleIDs_right_child)
  {
    // The arguments are intentionally unused
    static_cast<void>(sampleIDs_left_child);
    static_cast<void>(sampleIDs_right_child);

    throw std::runtime_error("computeSplitCriterion not implemented for this subclass.");
  }

  // Compute the split score value for a split in the CRTR analysis.
  //
  // Fallback implementation: ignores both arguments and always throws std::runtime_error.
  // (Judging from the error message, subclasses that support this computation are expected
  // to provide their own implementation.)
  double Tree::computeOOBSplitCriterionValue(size_t nodeID, std::vector<size_t> oob_sampleIDs_nodeID)
  {
    // The arguments are intentionally unused
    static_cast<void>(nodeID);
    static_cast<void>(oob_sampleIDs_nodeID);

    throw std::runtime_error("computeOOBSplitCriterionValue not implemented for this subclass.");
  }

  // Compute the OOB split criterion value for the node after permuting the OOB observations
  // (unity VIM).
  //
  // Fallback implementation: ignores all arguments and always throws std::runtime_error.
  // (Judging from the error message, subclasses that support this computation are expected
  // to provide their own implementation.)
  double Tree::computeOOBSplitCriterionValuePermuted(size_t nodeID, std::vector<size_t> oob_sampleIDs_nodeID, std::vector<size_t> permutations)
  {
    // The arguments are intentionally unused
    static_cast<void>(nodeID);
    static_cast<void>(oob_sampleIDs_nodeID);
    static_cast<void>(permutations);

    throw std::runtime_error("computeOOBSplitCriterionValuePermuted not implemented for this subclass.");
  }

  // Collect information on all splits in each variable (needed for determining the top-scoring
  // splits in the unity VIM computation).
  //
  // For every node of the tree, a SplitData object consisting of the tree index, the node ID
  // and the split criterion value of the node is appended to the list that belongs to the
  // split variable of the node.
  //
  // tree_idx:                index of this tree.
  // all_splits_per_variable: one list of splits per variable ID; extended in place.
  void Tree::collectSplits(size_t tree_idx, std::vector<std::vector<SplitData>> &all_splits_per_variable)
  {
    const size_t num_nodes = split_varIDs.size();

    for (size_t nodeID = 0; nodeID < num_nodes; ++nodeID)
    {
      // Split criterion value of the node
      const double score = split_criterion[nodeID];

      // Append the split to the list of its split variable
      all_splits_per_variable[split_varIDs[nodeID]].push_back(SplitData(tree_idx, nodeID, score));
    }
  }

  // Collect information on all splits in each variable (needed for determining the top-scoring splits in the CRTR analysis). 
  void Tree::collectOOBSplits(size_t tree_idx, std::vector<std::vector<SplitData>> &all_splits_per_variable)
  {

    // Loop over all nodes in the tree and collect the split data:
    for (size_t i = 0; i < split_varIDs.size(); ++i)
    {

      // If split_varIDs[i] is in repr_vars, collect the split data:
      if (std::find(repr_vars.begin(), repr_vars.end(), split_varIDs[i]) != repr_vars.end())
      {

        // Determine tree index:
        size_t tree_index = tree_idx;
        // Determine node ID:
        size_t nodeID = i;
        // Determine the index of split_varIDs[i] in repr_vars:
        size_t varID_index = std::find(repr_vars.begin(), repr_vars.end(), split_varIDs[i]) - repr_vars.begin();
        // Determine the split value:
        double split_value = split_criterion[i];

        // Create a SplitData object and add it to all_splits_per_variable:
        SplitData split_data(tree_index, nodeID, split_value);

        // Add the SplitData object to all_splits_per_variable:
        all_splits_per_variable[varID_index].push_back(split_data);
      }
    }
  }

  // Count the number of times each covariate is used for splitting in the tree (needed for
  // calculating the covariate scores in the CRTR analysis).
  //
  // Only nodes that are in the tree root and whose left child is in the tree root as well are
  // counted.
  //
  // var_counts: counts, indexed by variable ID; incremented in place.
  void Tree::countVariables(std::vector<size_t> &var_counts)
  {
    const size_t num_nodes = split_varIDs.size();

    for (size_t nodeID = 0; nodeID < num_nodes; ++nodeID)
    {
      // The node itself has to be the first node or a node of the tree root ...
      if (nodeID != 0 && nodeID_in_root[nodeID] == 0)
      {
        continue;
      }

      // ... and it needs a left child that is in the tree root, too
      const size_t left_child = child_nodeIDs[0][nodeID];
      if (left_child == 0 || nodeID_in_root[left_child] == 0)
      {
        continue;
      }

      ++var_counts[split_varIDs[nodeID]];
    }
  }

  // Compute the Uv vectors needed in the determination of the representative trees in the CRTR analysis. 
  void Tree::computeUv(size_t tree_ind, std::vector<std::vector<double>> &Uv)
  {

    size_t depth_temp = 1;
    std::vector<size_t> curr_child_nodeIDs;

    for (size_t i = 0; i < split_varIDs.size(); ++i)
    {
      if ((i == 0 || nodeID_in_root[i] != 0) && (child_nodeIDs[0][i] != 0 && nodeID_in_root[child_nodeIDs[0][i]] != 0))
      {
        // If nodeID_in_root[i] is in curr_child_nodeIDs, increment depth_temp:
        if (std::find(curr_child_nodeIDs.begin(), curr_child_nodeIDs.end(), nodeID_in_root[i]) != curr_child_nodeIDs.end())
        {
          depth_temp++;
          // Clear the curr_child_nodeIDs vector:
          curr_child_nodeIDs.clear();
        }

        // Add 1/(2^(depth_temp-1)) to the Uv for the variable:
        Uv[tree_ind][split_varIDs[i]] += 1.0 / std::pow(2.0, depth_temp - 1);

        // Add the child node IDs to the curr_child_nodeIDs vector:
        curr_child_nodeIDs.push_back(nodeID_in_root[child_nodeIDs[0][i]]);
        curr_child_nodeIDs.push_back(nodeID_in_root[child_nodeIDs[1][i]]);
      }
    }

    // Divide all elements of Uv[tree_ind] by depth_temp:
    for (size_t i = 0; i < Uv[tree_ind].size(); ++i)
    {
      Uv[tree_ind][i] /= depth_temp;
    }
  }

  // Set the score values for the nodes in the tree based on the scores_tree vector (CRTR
  // analysis).
  //
  // Every node that is in the tree root and whose left child is in the tree root as well
  // receives the score of its split variable. Entries that score_values newly gains through
  // the resize are set to -99.
  //
  // scores_tree: score values, indexed by variable ID.
  void Tree::setScoreVector(std::vector<double> scores_tree)
  {
    const size_t num_nodes = split_varIDs.size();

    score_values.resize(num_nodes, -99.0);

    for (size_t nodeID = 0; nodeID < num_nodes; ++nodeID)
    {
      // The node itself has to be the first node or a node of the tree root ...
      if (nodeID != 0 && nodeID_in_root[nodeID] == 0)
      {
        continue;
      }

      // ... and it needs a left child that is in the tree root, too
      const size_t left_child = child_nodeIDs[0][nodeID];
      if (left_child == 0 || nodeID_in_root[left_child] == 0)
      {
        continue;
      }

      // Score of the split variable of the node
      score_values[nodeID] = scores_tree[split_varIDs[nodeID]];
    }
  }

  // Mark the top-scoring splits (unity VIM).
  //
  // is_in_best gets one entry per node: 1 if the pair (tree index, node ID) is contained in
  // the set of best splits of the split variable of the node, 0 otherwise.
  //
  // tree_idex:  index of this tree.
  // bestSplits: for each variable ID, the set of (tree index, node ID) pairs of the best
  //             splits in that variable.
  void Tree::markBestSplits(size_t tree_idex, const std::vector<std::set<std::pair<size_t, size_t>>> &bestSplits)
  {
    const size_t num_nodes = split_varIDs.size();

    // One entry per node; every entry is written in the loop below
    is_in_best.resize(num_nodes);

    for (size_t nodeID = 0; nodeID < num_nodes; ++nodeID)
    {
      const std::set<std::pair<size_t, size_t>> &best_of_variable = bestSplits[split_varIDs[nodeID]];
      is_in_best[nodeID] = (best_of_variable.count(std::make_pair(tree_idex, nodeID)) > 0) ? 1 : 0;
    }
  }

  // Mark the top-scoring splits (CRTR analysis).
  //
  // is_in_best gets one entry per node: 1 if the split variable of the node is contained in
  // repr_vars and the pair (tree index, node ID) is contained in the set of best splits of
  // that variable, 0 otherwise.
  //
  // tree_idex:  index of this tree.
  // bestSplits: for each entry of repr_vars (same order), the set of (tree index, node ID)
  //             pairs of the best splits in that variable.
  void Tree::markBestOOBSplits(size_t tree_idex, const std::vector<std::set<std::pair<size_t, size_t>>> &bestSplits)
  {
    const size_t num_nodes = split_varIDs.size();

    // Make is_in_best of length num_nodes with all entries equal to 0. assign() is used
    // instead of resize() because resize() only initializes newly added entries and would
    // keep marks that are left over from an earlier marking pass.
    is_in_best.assign(num_nodes, 0);

    for (size_t nodeID = 0; nodeID < num_nodes; ++nodeID)
    {
      // Skip the node if its split variable is not in repr_vars
      const auto var_pos = std::find(repr_vars.begin(), repr_vars.end(), split_varIDs[nodeID]);
      if (var_pos == repr_vars.end())
      {
        continue;
      }

      // Position of the split variable within repr_vars
      const size_t varID_index = var_pos - repr_vars.begin();

      if (bestSplits[varID_index].count(std::make_pair(tree_idex, nodeID)) > 0)
      {
        // Mark the split as one of the best splits
        is_in_best[nodeID] = 1;
      }
    }
  }

  // For a node, determine the random set of split covariate candidates (only used in the tree
  // sprouts).
  //
  // mtry variable IDs are drawn without replacement: according to split_select_weights if
  // weights are given, otherwise from all variables except the no-split variables and the
  // deterministic variables. The deterministic variables are appended afterwards in any case.
  //
  // result: receives the candidate variable IDs.
  void Tree::createPossibleSplitVarSubset(std::vector<size_t> &result)
  {
    // Number of variables to draw from; for the corrected Gini importance, one dummy variable
    // is added for every variable that may be used for splitting
    size_t num_vars = data->getNumCols();
    if (importance_mode == IMP_GINI_CORRECTED)
    {
      num_vars += data->getNumCols() - data->getNoSplitVariables().size();
    }

    if (!split_select_weights->empty())
    {
      // Weighted draw
      drawWithoutReplacementWeighted(result, random_number_generator, *split_select_varIDs, mtry, *split_select_weights);
    }
    else if (deterministic_varIDs->empty())
    {
      // Unweighted draw that skips the no-split variables
      drawWithoutReplacementSkip(result, random_number_generator, num_vars, data->getNoSplitVariables(), mtry);
    }
    else
    {
      // Unweighted draw that skips the no-split variables and the deterministic variables;
      // the list of variables to skip is passed on in ascending order
      std::vector<size_t> skip;
      skip.insert(skip.end(), data->getNoSplitVariables().begin(), data->getNoSplitVariables().end());
      skip.insert(skip.end(), deterministic_varIDs->begin(), deterministic_varIDs->end());
      std::sort(skip.begin(), skip.end());
      drawWithoutReplacementSkip(result, random_number_generator, num_vars, skip, mtry);
    }

    // Always use deterministic variables
    result.insert(result.end(), deterministic_varIDs->begin(), deterministic_varIDs->end());
  }

  // Split node in tree sprout.
  bool Tree::splitNodeFullTree(size_t nodeID)
  {

    bool split_in_root;

    if ((nodeID_in_root[nodeID] != 0 || nodeID == 0) && !(child_nodeIDs_best[0][nodeID_in_root[nodeID]] == 0 && child_nodeIDs_best[1][nodeID_in_root[nodeID]] == 0))
    {
      // If node is in tree root and not terminal, use split from tree root
      split_varIDs[nodeID] = split_varIDs_best[nodeID_in_root[nodeID]];
      split_values[nodeID] = split_values_best[nodeID_in_root[nodeID]];

      // The split is in the tree root
      split_in_root = true;
    }
    else
    {

      // Select random subset of variables to possibly split at
      std::vector<size_t> possible_split_varIDs;
      createPossibleSplitVarSubset(possible_split_varIDs);

      // Call subclass method, sets split_varIDs and split_values
      bool stop = splitNodeInternal(nodeID, possible_split_varIDs);
      if (stop)
      {
        // Terminal node
        return true;
      }

      // The split is not in the tree root
      split_in_root = false;
    }

    size_t split_varID = split_varIDs[nodeID];
    double split_value = split_values[nodeID];

    // Save non-permuted variable for prediction
    split_varIDs[nodeID] = data->getUnpermutedVarID(split_varID);

    // Create child nodes
    size_t left_child_nodeID = split_varIDs.size();
    child_nodeIDs[0][nodeID] = left_child_nodeID;
    createEmptyNodeFullTree();
    if (split_in_root)
    {
      nodeID_in_root[left_child_nodeID] = child_nodeIDs_best[0][nodeID_in_root[nodeID]];
    }
    start_pos[left_child_nodeID] = start_pos[nodeID];

    size_t right_child_nodeID = split_varIDs.size();
    child_nodeIDs[1][nodeID] = right_child_nodeID;
    createEmptyNodeFullTree();
    if (split_in_root)
    {
      nodeID_in_root[right_child_nodeID] = child_nodeIDs_best[1][nodeID_in_root[nodeID]];
    }
    start_pos[right_child_nodeID] = end_pos[nodeID];

    // For each sample in node, assign to left or right child

    // Ordered: left is <= splitval and right is > splitval
    size_t pos = start_pos[nodeID];
    while (pos < start_pos[right_child_nodeID])
    {
      size_t sampleID = sampleIDs[pos];
      if (data->get(sampleID, split_varID) <= split_value)
      {
        // If going to left, do nothing
        ++pos;
      }
      else
      {
        // If going to right, move to right end
        --start_pos[right_child_nodeID];
        std::swap(sampleIDs[pos], sampleIDs[start_pos[right_child_nodeID]]);
      }
    }

    // End position of left child is start position of right child
    end_pos[left_child_nodeID] = start_pos[right_child_nodeID];
    end_pos[right_child_nodeID] = end_pos[nodeID];

    // No terminal node
    return false;
  }

  // Split node in random candidate tree root.
  //
  // Draws (with at most 500 attempts) a variable from varIDs_root that has at
  // least two different values in the node, picks a random split point between
  // two neighboring unique values and partitions the samples of the node.
  // Returns true if the node is terminal and false if it was split.
  bool Tree::splitNodeRandom(size_t nodeID, const std::vector<size_t> &varIDs_root)
  {

    if (checkWhetherFinalRandom(nodeID))
    {
      // Terminal node
      return true;
    }

    // Without candidate variables nothing can be drawn. This also avoids
    // constructing the distribution with the wrapped-around bound size() - 1.
    if (varIDs_root.empty())
    {
      return true;
    }

    // Randomly draw one variable out of varIDs_root that has at least two
    // different values in the current node
    std::uniform_int_distribution<size_t> uni(0, varIDs_root.size() - 1);

    const size_t max_draws = 500;
    bool found_variable = false;
    size_t drawnvarID = 0;

    for (size_t t = 0; t < max_draws; ++t)
    {
      const size_t varID = varIDs_root[uni(random_number_generator)];

      if (twoDifferentValues(nodeID, varID))
      {
        drawnvarID = varID;
        found_variable = true;
        break;
      }
    }

    if (!found_variable)
    {
      return true; // No suitable variable found; make node terminal
    }

    // Collect the values of the drawn variable in the node
    values_buffer.clear();
    values_buffer.reserve(end_pos_loop[nodeID] - start_pos_loop[nodeID]);

    for (size_t pos = start_pos_loop[nodeID]; pos < end_pos_loop[nodeID]; ++pos)
    {
      values_buffer.push_back(data->get(sampleIDs[pos], drawnvarID));
    }

    // Sort and remove duplicates
    std::sort(values_buffer.begin(), values_buffer.end());
    values_buffer.erase(std::unique(values_buffer.begin(), values_buffer.end()), values_buffer.end());

    // Defensive: a split needs two distinct values (size() - 2 below would wrap around otherwise)
    if (values_buffer.size() < 2)
    {
      return true;
    }

    // Select a split value between two neighboring unique values
    std::uniform_int_distribution<size_t> unif_dist(0, values_buffer.size() - 2);
    const size_t index = unif_dist(random_number_generator);
    const double lower = values_buffer[index];
    const double upper = values_buffer[index + 1];
    double split_value = (lower + upper) / 2;

    // The split value has to lie in [lower, upper) so that 'lower' goes left and
    // 'upper' goes right. Rounding (midpoint equal to the larger value) or
    // overflow of the sum can violate this; fall back to the smaller value then.
    if (!(split_value >= lower && split_value < upper))
    {
      split_value = lower;
    }

    // Set the split variable and split value
    split_varIDs_loop[nodeID] = data->getUnpermutedVarID(drawnvarID);
    split_values_loop[nodeID] = split_value;

    // Create child nodes
    const size_t left_child_nodeID = split_varIDs_loop.size();
    child_nodeIDs_loop[0][nodeID] = left_child_nodeID;
    createEmptyNodeRandomTree();
    start_pos_loop[left_child_nodeID] = start_pos_loop[nodeID];

    const size_t right_child_nodeID = split_varIDs_loop.size();
    child_nodeIDs_loop[1][nodeID] = right_child_nodeID;
    createEmptyNodeRandomTree();
    start_pos_loop[right_child_nodeID] = end_pos_loop[nodeID];

    // For each sample in node, assign to left or right child
    // Ordered: left is <= splitval and right is > splitval
    size_t pos = start_pos_loop[nodeID];
    while (pos < start_pos_loop[right_child_nodeID])
    {
      const size_t sampleID = sampleIDs[pos];
      if (data->get(sampleID, drawnvarID) <= split_value)
      {
        // If going to left, do nothing
        ++pos;
      }
      else
      {
        // If going to right, move to right end
        --start_pos_loop[right_child_nodeID];
        std::swap(sampleIDs[pos], sampleIDs[start_pos_loop[right_child_nodeID]]);
      }
    }

    // End position of left child is start position of right child
    end_pos_loop[left_child_nodeID] = start_pos_loop[right_child_nodeID];
    end_pos_loop[right_child_nodeID] = end_pos_loop[nodeID];

    // No terminal node
    return false;
  }

  // Check whether a variable has two different values in a node (needed for determining the split variables in the random candidate tree roots).
  // Returns true as soon as a value differing from the first value in the node is found,
  // and false if all values are equal or the node contains no samples.
  bool Tree::twoDifferentValues(size_t nodeID, size_t varID)
  {
    const size_t first_pos = start_pos_loop[nodeID];
    const size_t last_pos = end_pos_loop[nodeID];

    // Empty node: there is no first value to compare against, and reading
    // sampleIDs[first_pos] would access a position outside the node's range
    if (first_pos >= last_pos)
    {
      return false;
    }

    const double first = data->get(sampleIDs[first_pos], varID);

    for (size_t pos = first_pos + 1; pos < last_pos; ++pos)
    {
      if (data->get(sampleIDs[pos], varID) != first)
      {
        return true;
      }
    }

    return false; // all equal
  }

  // Create an empty node in a random candidate tree root.
  // Appends one zero-initialized entry to each per-node vector of the candidate
  // tree root and then calls the internal hook.
  void Tree::createEmptyNodeRandomTree()
  {
    // Sample range of the new node
    start_pos_loop.push_back(0);
    end_pos_loop.push_back(0);

    // Split information of the new node
    split_varIDs_loop.push_back(0);
    split_values_loop.push_back(0);

    // Left (0) and right (1) child IDs of the new node
    for (size_t side = 0; side < 2; ++side)
    {
      child_nodeIDs_loop[side].push_back(0);
    }

    createEmptyNodeRandomTreeInternal();
  }

  // Create an empty node in a tree sprout.
  // Appends one zero-initialized entry to each per-node vector of the tree
  // (including the mapping to the tree root) and then calls the internal hook.
  void Tree::createEmptyNodeFullTree()
  {
    // Sample range of the new node
    start_pos.push_back(0);
    end_pos.push_back(0);

    // Split information of the new node
    split_varIDs.push_back(0);
    split_values.push_back(0);

    // Left (0) and right (1) child IDs of the new node
    for (size_t side = 0; side < 2; ++side)
    {
      child_nodeIDs[side].push_back(0);
    }

    // ID of the corresponding node in the tree root
    nodeID_in_root.push_back(0);

    createEmptyNodeFullTreeInternal();
  }

  // Function used to clear some objects from the random candidate tree root.
  // Empties the per-node vectors of the candidate tree root and then calls the internal hook.
  void Tree::clearRandomTree()
  {
    // Left (0) and right (1) child IDs
    for (size_t side = 0; side < 2; ++side)
    {
      child_nodeIDs_loop[side].clear();
    }

    // Sample ranges
    start_pos_loop.clear();
    end_pos_loop.clear();

    // Split information
    split_varIDs_loop.clear();
    split_values_loop.clear();

    clearRandomTreeInternal();
  }

  // Draw the inbag samples uniformly with replacement; all samples that are
  // never drawn become the OOB samples.
  void Tree::bootstrap()
  {

    // Number of inbag draws: the fraction (*sample_fraction)[0] of the samples, truncated
    const size_t num_samples_inbag = static_cast<size_t>(num_samples * (*sample_fraction)[0]);

    // Reserve space (a little more than expected for the OOB samples, to be safe)
    sampleIDs.reserve(num_samples_inbag);
    oob_sampleIDs.reserve(num_samples * (exp(-(*sample_fraction)[0]) + 0.1));

    // Start with all samples OOB
    inbag_counts.resize(num_samples, 0);

    // Draw num_samples_inbag out of num_samples samples with replacement and count the draws
    std::uniform_int_distribution<size_t> unif_dist(0, num_samples - 1);
    for (size_t s = 0; s < num_samples_inbag; ++s)
    {
      const size_t draw = unif_dist(random_number_generator);
      sampleIDs.push_back(draw);
      ++inbag_counts[draw];
    }

    // Samples that were never drawn are OOB
    const size_t num_counts = inbag_counts.size();
    for (size_t s = 0; s < num_counts; ++s)
    {
      if (inbag_counts[s] == 0)
      {
        oob_sampleIDs.push_back(s);
      }
    }
    num_samples_oob = oob_sampleIDs.size();

    // Keep the inbag counts only if requested
    if (keep_inbag)
    {
      return;
    }
    inbag_counts.clear();
    inbag_counts.shrink_to_fit();
  }

  // Draw the inbag samples with replacement, with probabilities proportional to
  // the case weights. OOB samples are the zero-weight cases in holdout mode and
  // the never-drawn samples otherwise.
  void Tree::bootstrapWeighted()
  {

    // Number of inbag draws: the fraction (*sample_fraction)[0] of the samples, truncated
    const size_t num_samples_inbag = static_cast<size_t>(num_samples * (*sample_fraction)[0]);

    // Reserve space (a little more than expected for the OOB samples, to be safe)
    sampleIDs.reserve(num_samples_inbag);
    oob_sampleIDs.reserve(num_samples * (exp(-(*sample_fraction)[0]) + 0.1));

    // Start with all samples OOB
    inbag_counts.resize(num_samples, 0);

    // Draw num_samples_inbag samples with replacement according to the case weights and count the draws
    std::discrete_distribution<> weighted_dist(case_weights->begin(), case_weights->end());
    for (size_t s = 0; s < num_samples_inbag; ++s)
    {
      const size_t draw = weighted_dist(random_number_generator);
      sampleIDs.push_back(draw);
      ++inbag_counts[draw];
    }

    // Save OOB samples
    if (!holdout)
    {
      // Regular mode: samples that were never drawn
      for (size_t s = 0; s < inbag_counts.size(); ++s)
      {
        if (inbag_counts[s] == 0)
        {
          oob_sampleIDs.push_back(s);
        }
      }
    }
    else
    {
      // Holdout mode: cases with 0 weight
      for (size_t s = 0; s < case_weights->size(); ++s)
      {
        if ((*case_weights)[s] == 0)
        {
          oob_sampleIDs.push_back(s);
        }
      }
    }
    num_samples_oob = oob_sampleIDs.size();

    // Keep the inbag counts only if requested
    if (keep_inbag)
    {
      return;
    }
    inbag_counts.clear();
    inbag_counts.shrink_to_fit();
  }

  // Split the samples into inbag and OOB samples without replacement
  // (via shuffleAndSplit) and, if requested, record the inbag counts.
  void Tree::bootstrapWithoutReplacement()
  {

    // Number of inbag samples: the fraction (*sample_fraction)[0] of the samples, truncated
    const size_t num_samples_inbag = static_cast<size_t>(num_samples * (*sample_fraction)[0]);
    shuffleAndSplit(sampleIDs, oob_sampleIDs, num_samples, num_samples_inbag, random_number_generator);
    num_samples_oob = oob_sampleIDs.size();

    if (!keep_inbag)
    {
      return;
    }

    // All observations are 0 or 1 times inbag: start with 1 and reset the OOB samples to 0
    inbag_counts.resize(num_samples, 1);
    for (const size_t oob_sampleID : oob_sampleIDs)
    {
      inbag_counts[oob_sampleID] = 0;
    }
  }

  // Draw the inbag samples without replacement according to the case weights.
  // OOB samples are the zero-weight cases in holdout mode and the samples that
  // were not drawn otherwise.
  void Tree::bootstrapWithoutReplacementWeighted()
  {

    // Number of inbag samples: the fraction (*sample_fraction)[0] of the samples, truncated
    const size_t num_samples_inbag = static_cast<size_t>(num_samples * (*sample_fraction)[0]);
    drawWithoutReplacementWeighted(sampleIDs, random_number_generator, num_samples - 1, num_samples_inbag, *case_weights);

    // All observations are 0 or 1 times inbag
    inbag_counts.resize(num_samples, 0);
    for (const size_t inbag_sampleID : sampleIDs)
    {
      inbag_counts[inbag_sampleID] = 1;
    }

    // Save OOB samples
    if (!holdout)
    {
      // Regular mode: samples that were not drawn
      for (size_t s = 0; s < inbag_counts.size(); ++s)
      {
        if (inbag_counts[s] == 0)
        {
          oob_sampleIDs.push_back(s);
        }
      }
    }
    else
    {
      // Holdout mode: cases with 0 weight
      for (size_t s = 0; s < case_weights->size(); ++s)
      {
        if ((*case_weights)[s] == 0)
        {
          oob_sampleIDs.push_back(s);
        }
      }
    }
    num_samples_oob = oob_sampleIDs.size();

    // Keep the inbag counts only if requested
    if (keep_inbag)
    {
      return;
    }
    inbag_counts.clear();
    inbag_counts.shrink_to_fit();
  }

  // Class-wise bootstrap hook. This implementation does nothing.
  void Tree::bootstrapClassWise()
  {
    // Empty on purpose: this is a virtual function that is only implemented
    // for classification and probability trees.
  }

  // Class-wise bootstrap-without-replacement hook. This implementation does nothing.
  void Tree::bootstrapWithoutReplacementClassWise()
  {
    // Empty on purpose: this is a virtual function that is only implemented
    // for classification and probability trees.
  }

  // Build the inbag and OOB samples from the user-specified manual_inbag vector:
  // entry i gives how many times observation i is inbag; observations with a
  // count that is not positive are OOB.
  void Tree::setManualInbag()
  {
    const size_t num_entries = manual_inbag->size();

    sampleIDs.reserve(num_entries);
    inbag_counts.resize(num_samples, 0);

    for (size_t i = 0; i < num_entries; ++i)
    {
      if (!((*manual_inbag)[i] > 0))
      {
        // Not inbag: observation is OOB
        oob_sampleIDs.push_back(i);
        continue;
      }

      // Add the observation as many times as specified
      const size_t inbag_count = (*manual_inbag)[i];
      sampleIDs.insert(sampleIDs.end(), inbag_count, i);
      inbag_counts[i] = inbag_count;
    }
    num_samples_oob = oob_sampleIDs.size();

    // Shuffle samples
    std::shuffle(sampleIDs.begin(), sampleIDs.end(), random_number_generator);

    // Keep the inbag counts only if requested
    if (keep_inbag)
    {
      return;
    }
    inbag_counts.clear();
    inbag_counts.shrink_to_fit();
  }

} // namespace unityForest
