/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.rules.multilabel.attributeclassobservers;

import com.github.javacliparser.IntOption;
import java.io.Serializable;
import moa.classifiers.rules.core.NumericRulePredicate;
import moa.classifiers.rules.core.Utils;
import moa.classifiers.rules.multilabel.attributeclassobservers.NumericStatisticsObserver;
import moa.classifiers.rules.multilabel.core.AttributeExpansionSuggestion;
import moa.classifiers.rules.multilabel.core.splitcriteria.MultiLabelSplitCriterion;
import moa.core.DoubleVector;
import moa.core.ObjectRepository;
import moa.options.AbstractOptionHandler;
import moa.tasks.TaskMonitor;

/**
 * Numeric-attribute observer for multi-output rules that stores observed input values
 * in a binary search tree, each node keeping per-output statistics.
 *
 * <p>Every node has a cut point (the first value that created it). Its {@code leftStatistics}
 * accumulate the statistics of all instances that reached the node with an input value
 * {@code <= cutPoint}. Its {@code rightStatistics} accumulate those with a value {@code > cutPoint}.
 * Candidate splits ("input {@code <=} cutPoint") are evaluated for every node by an in-order
 * traversal in {@link #getBestEvaluatedSplitSuggestion}.
 *
 * <p>Growth is bounded by {@link #maxNodesOption}. The value is copied into {@link #maxNodes}
 * when the first non-NaN value is observed. The root is created unconditionally and is not
 * counted in {@link #numNodes}, so the tree can hold up to {@code maxNodes + 1} nodes.
 *
 * <p>Split evaluation uses the instance fields {@link #leftStatistics} and
 * {@link #rightStatistics} as scratch accumulators. Concurrent calls on the same
 * instance therefore interfere with each other.
 */
public class MultiLabelBSTree
extends AbstractOptionHandler
implements NumericStatisticsObserver {
    public IntOption maxNodesOption = new IntOption("maxNodesOption", 'z', "Maximum number of nodes", 50, 0, Integer.MAX_VALUE);
    /** Node budget; copied from {@link #maxNodesOption} when the root is created. */
    protected int maxNodes;
    /** Number of non-root nodes created so far. */
    protected int numNodes;
    private static final long serialVersionUID = 1L;
    /** Root of the tree; {@code null} until the first non-NaN value is observed. */
    protected Node root = null;
    /** Scratch: running "left of cut point" statistics, non-null only during split evaluation. */
    protected DoubleVector[] leftStatistics;
    /** Scratch: running "right of cut point" statistics, non-null only during split evaluation. */
    protected DoubleVector[] rightStatistics;

    /**
     * Records one observation of the input attribute.
     *
     * @param inputAttributeValue value of the input attribute; NaN values are ignored
     * @param statistics          per-output statistics contributed by this observation
     */
    @Override
    public void observeAttribute(double inputAttributeValue, DoubleVector[] statistics) {
        if (Double.isNaN(inputAttributeValue)) {
            // A missing value cannot be placed in the ordering, so it is skipped.
            return;
        }
        if (this.root == null) {
            this.root = new Node(inputAttributeValue, statistics);
            // The node budget is latched here; later changes to the option are not picked up.
            this.maxNodes = this.maxNodesOption.getValue();
        } else {
            this.root.observeAttribute(inputAttributeValue, statistics);
        }
    }

    /**
     * Evaluates every stored cut point and returns the one with the highest merit.
     *
     * @param criterion           split criterion used to score each candidate
     * @param preSplitStatistics  per-output statistics of all data before splitting
     * @param inputAttributeIndex index of the input attribute this observer tracks
     * @return the best suggestion, or {@code null} if no value has been observed yet
     */
    @Override
    public AttributeExpansionSuggestion getBestEvaluatedSplitSuggestion(MultiLabelSplitCriterion criterion, DoubleVector[] preSplitStatistics, int inputAttributeIndex) {
        int numOutputs = preSplitStatistics.length;
        this.leftStatistics = new DoubleVector[numOutputs];
        this.rightStatistics = new DoubleVector[numOutputs];
        for (int i = 0; i < numOutputs; ++i) {
            // Before the first cut point everything is on the right and the left side is empty.
            this.leftStatistics[i] = new DoubleVector(new double[preSplitStatistics[i].numValues()]);
            this.rightStatistics[i] = new DoubleVector(preSplitStatistics[i]);
        }
        try {
            return this.searchForBestSplitOption(this.root, null, criterion, preSplitStatistics, inputAttributeIndex);
        } finally {
            // Release the scratch accumulators even if the criterion throws.
            this.leftStatistics = null;
            this.rightStatistics = null;
        }
    }

    /**
     * Does an in-order traversal rooted at {@code currentNode}, scoring the split at each node's cut point.
     *
     * <p>On entry, {@link #leftStatistics} and {@link #rightStatistics} must describe the split just
     * below the smallest cut point of this subtree. On return they are restored to that state.
     *
     * @param currentNode         subtree root; {@code null} returns {@code currentBestOption} unchanged
     * @param currentBestOption   best suggestion found so far, or {@code null}
     * @param criterion           split criterion used to score candidates
     * @param preSplitStatistics  per-output statistics before splitting
     * @param inputAttributeIndex index of the input attribute this observer tracks
     * @return the best suggestion after considering this subtree
     */
    protected AttributeExpansionSuggestion searchForBestSplitOption(Node currentNode, AttributeExpansionSuggestion currentBestOption, MultiLabelSplitCriterion criterion, DoubleVector[] preSplitStatistics, int inputAttributeIndex) {
        if (currentNode == null) {
            return currentBestOption;
        }
        int numOutputs = this.leftStatistics.length;
        // Smaller cut points first.
        if (currentNode.left != null) {
            currentBestOption = this.searchForBestSplitOption(currentNode.left, currentBestOption, criterion, preSplitStatistics, inputAttributeIndex);
        }
        // Move everything this node saw at or below its cut point from the right side to the left side.
        for (int i = 0; i < numOutputs; ++i) {
            this.leftStatistics[i].addValues(currentNode.leftStatistics[i]);
            this.rightStatistics[i].subtractValues(currentNode.leftStatistics[i]);
        }
        // One {left, right} pair per output; rows are allocated once (no throw-away inner arrays).
        DoubleVector[][] postSplitDists = new DoubleVector[numOutputs][];
        for (int i = 0; i < numOutputs; ++i) {
            postSplitDists[i] = new DoubleVector[] {this.leftStatistics[i], this.rightStatistics[i]};
        }
        double merit = criterion.getMeritOfSplit(preSplitStatistics, postSplitDists);
        if (currentBestOption == null || merit > currentBestOption.merit) {
            // postSplitDists aliases the running accumulators, which keep changing during the
            // traversal. Utils.copy is expected to snapshot them (assumed deep copy; confirm).
            // NOTE(review): the boolean presumably selects the "<= cutPoint" side; confirm against NumericRulePredicate.
            currentBestOption = new AttributeExpansionSuggestion(new NumericRulePredicate(inputAttributeIndex, currentNode.cutPoint, true), Utils.copy(postSplitDists), merit);
        }
        // Larger cut points next.
        if (currentNode.right != null) {
            currentBestOption = this.searchForBestSplitOption(currentNode.right, currentBestOption, criterion, preSplitStatistics, inputAttributeIndex);
        }
        // Undo this node's contribution so the caller sees the accumulators unchanged.
        for (int i = 0; i < numOutputs; ++i) {
            this.leftStatistics[i].subtractValues(currentNode.leftStatistics[i]);
            this.rightStatistics[i].addValues(currentNode.leftStatistics[i]);
        }
        return currentBestOption;
    }

    @Override
    public String getPurposeString() {
        return "Stores statistics for all output attributes for a given input attribute.";
    }

    /** Intentionally empty: this observer contributes no description text. */
    @Override
    public void getDescription(StringBuilder sb, int indent) {
    }

    /** Intentionally empty: all state is created lazily on the first observation. */
    @Override
    protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
    }

    /**
     * Tree node holding a cut point and the statistics of instances routed through it.
     * It is a non-static inner class because it reads and updates the enclosing tree's node budget.
     */
    protected class Node
    implements Serializable {
        private static final long serialVersionUID = 1L;
        /** Input value that created this node; the split is "value <= cutPoint". */
        private double cutPoint;
        /** Statistics of instances reaching this node with value <= cutPoint. */
        private DoubleVector[] leftStatistics;
        /** Statistics of instances reaching this node with value > cutPoint. */
        private DoubleVector[] rightStatistics;
        private Node left;
        private Node right;

        /**
         * Creates a node whose cut point is {@code inputAttributeValue}.
         * The creating observation is counted on the left side.
         */
        public Node(double inputAttributeValue, DoubleVector[] statistics) {
            this.cutPoint = inputAttributeValue;
            int numOutputAttributes = statistics.length;
            this.leftStatistics = new DoubleVector[numOutputAttributes];
            this.rightStatistics = new DoubleVector[numOutputAttributes];
            for (int i = 0; i < numOutputAttributes; ++i) {
                this.leftStatistics[i] = new DoubleVector(statistics[i]);
                this.rightStatistics[i] = new DoubleVector();
            }
        }

        /**
         * Routes one observation into this subtree.
         *
         * <p>The observation's statistics are added to this node's side accumulator. It then
         * descends into the matching child, or creates that child if the budget allows.
         * Values equal to the cut point stop here.
         */
        public void observeAttribute(double inputAttributeValue, DoubleVector[] statistics) {
            if (inputAttributeValue <= this.cutPoint) {
                this.addAll(this.leftStatistics, statistics);
                if (inputAttributeValue < this.cutPoint) {
                    if (this.left == null) {
                        this.left = this.createChildIfRoom(inputAttributeValue, statistics);
                    } else {
                        this.left.observeAttribute(inputAttributeValue, statistics);
                    }
                }
            } else {
                this.addAll(this.rightStatistics, statistics);
                if (this.right == null) {
                    this.right = this.createChildIfRoom(inputAttributeValue, statistics);
                } else {
                    this.right.observeAttribute(inputAttributeValue, statistics);
                }
            }
        }

        /** Adds {@code statistics[i]} into {@code target[i]} for every output i. */
        private void addAll(DoubleVector[] target, DoubleVector[] statistics) {
            for (int i = 0; i < statistics.length; ++i) {
                target[i].addValues(statistics[i]);
            }
        }

        /**
         * Returns a new child node if the node budget allows, otherwise {@code null}.
         * When the budget is exhausted the value is still counted in this node's accumulators.
         */
        private Node createChildIfRoom(double inputAttributeValue, DoubleVector[] statistics) {
            if (MultiLabelBSTree.this.numNodes >= MultiLabelBSTree.this.maxNodes) {
                return null;
            }
            Node child = new Node(inputAttributeValue, statistics);
            ++MultiLabelBSTree.this.numNodes;
            return child;
        }
    }
}

