/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.tree;

import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.tree.DecisionTreeModel;
import org.apache.spark.ml.tree.InternalNode;
import org.apache.spark.ml.tree.LeafNode;
import org.apache.spark.ml.tree.Node;
import org.apache.spark.ml.tree.TreeEnsembleModel$;
import org.apache.spark.util.collection.OpenHashMap;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Iterable$;
import scala.collection.Seq;
import scala.collection.TraversableOnce;
import scala.collection.generic.GenericTraversableTemplate;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.StringBuilder;
import scala.math.Numeric;
import scala.math.Ordering;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/**
 * Module (companion-object) helpers for tree ensemble models: aggregates split-gain based feature
 * importances over one or more decision trees.
 *
 * <p>NOTE(review): this is CFR-decompiled Scala bytecode. The following are decompiler artifacts
 * and are not valid plain Java as written; they are kept verbatim:
 * <ul>
 *   <li>{@code OpenHashMap.mcD.sp}</li>
 *   <li>the synthetic {@code anonfun.*} outer types</li>
 *   <li>anonymous {@code Serializable} closures that take constructor arguments</li>
 * </ul>
 */
public final class TreeEnsembleModel$ {
    /** The singleton instance of this module. */
    public static final TreeEnsembleModel$ MODULE$;

    static {
        // A static final field can only be assigned in a static initializer, not in the constructor.
        MODULE$ = new TreeEnsembleModel$();
    }

    /**
     * Computes normalized feature importances across an ensemble of trees.
     *
     * <p>Processing happens in three steps:
     * <ol>
     *   <li>For each tree, per-feature importances are accumulated with
     *       {@link #computeFeatureImportance}.</li>
     *   <li>Each tree's importances are divided by their sum, so they add up to 1. Trees whose sum
     *       is 0 contribute nothing. The results are added into a running per-feature total.</li>
     *   <li>The running total is normalized with {@link #normalizeMapValues}.</li>
     * </ol>
     *
     * @param trees the trees whose importances are aggregated
     * @param numFeatures the size of the returned vector, or {@code -1} to infer it as the largest
     *     {@code maxSplitFeatureIndex()} over {@code trees} plus one
     * @return a sparse vector of the resolved size holding the normalized importance of each feature
     *     that appears in a split
     */
    public <M extends DecisionTreeModel> Vector featureImportances(M[] trees, int numFeatures) {
        Tuple2 tuple2;
        // Length of the returned vector. It is assigned on both branches below.
        int vectorSize;
        OpenHashMap.mcD.sp totalImportances = new OpenHashMap.mcD.sp(ClassTag$.MODULE$.Int(), ClassTag$.MODULE$.Double());
        Predef$.MODULE$.refArrayOps((Object[])trees).foreach((Function1)new Serializable((OpenHashMap)totalImportances){
            public static final long serialVersionUID = 0L;
            public final OpenHashMap totalImportances$1;

            public final void apply(M tree) {
                // Raw (unnormalized) importances for this single tree.
                OpenHashMap.mcD.sp importances = new OpenHashMap.mcD.sp(ClassTag$.MODULE$.Int(), ClassTag$.MODULE$.Double());
                TreeEnsembleModel$.MODULE$.computeFeatureImportance(tree.rootNode(), (OpenHashMap<Object, Object>)importances);
                // Sum of this tree's importances, used to normalize it before adding it to the total.
                double treeNorm = BoxesRunTime.unboxToDouble((Object)((TraversableOnce)importances.map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final double apply(Tuple2<Object, Object> x$4) {
                        return x$4._2$mcD$sp();
                    }
                }, Iterable$.MODULE$.canBuildFrom())).sum((Numeric)Numeric.DoubleIsFractional$.MODULE$));
                // Skip trees with zero total importance to avoid dividing by zero.
                if (treeNorm != 0.0) {
                    importances.foreach((Function1)new Serializable(this, treeNorm){
                        public static final long serialVersionUID = 0L;
                        private final /* synthetic */ anonfun.featureImportances.1 $outer;
                        private final double treeNorm$1;

                        public final double apply(Tuple2<Object, Object> x0$2) {
                            Tuple2<Object, Object> tuple2 = x0$2;
                            if (tuple2 != null) {
                                int idx = tuple2._1$mcI$sp();
                                double impt = tuple2._2$mcD$sp();
                                double normImpt = impt / this.treeNorm$1;
                                // Insert normImpt for a new feature, or add it to the existing total.
                                double d = this.$outer.totalImportances$1.changeValue$mcD$sp((Object)BoxesRunTime.boxToInteger((int)idx), (Function0)new Serializable(this, normImpt){
                                    public static final long serialVersionUID = 0L;
                                    private final double normImpt$1;

                                    public final double apply() {
                                        return this.apply$mcD$sp();
                                    }

                                    public double apply$mcD$sp() {
                                        return this.normImpt$1;
                                    }
                                    {
                                        this.normImpt$1 = normImpt$1;
                                    }
                                }, (Function1)new Serializable(this, normImpt){
                                    public static final long serialVersionUID = 0L;
                                    private final double normImpt$1;

                                    public final double apply(double x$5) {
                                        return this.apply$mcDD$sp(x$5);
                                    }

                                    public double apply$mcDD$sp(double x$5) {
                                        return x$5 + this.normImpt$1;
                                    }
                                    {
                                        this.normImpt$1 = normImpt$1;
                                    }
                                });
                                return d;
                            }
                            throw new MatchError(tuple2);
                        }
                        {
                            if ($outer == null) {
                                throw null;
                            }
                            this.$outer = $outer;
                            this.treeNorm$1 = treeNorm$1;
                        }
                    });
                }
            }
            {
                this.totalImportances$1 = totalImportances$1;
            }
        });
        this.normalizeMapValues((OpenHashMap<Object, Object>)totalImportances);
        if (numFeatures != -1) {
            // The caller supplied the vector size explicitly.
            vectorSize = numFeatures;
        } else {
            // Infer the size from the largest split feature index across all trees.
            // NOTE(review): Scala's `max` rejects an empty collection, so this branch presumably
            // requires `trees` to be non-empty. Confirm with the callers.
            int maxFeatureIndex = BoxesRunTime.unboxToInt((Object)Predef$.MODULE$.intArrayOps((int[])Predef$.MODULE$.refArrayOps((Object[])trees).map((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final int apply(M x$6) {
                    return x$6.maxSplitFeatureIndex();
                }
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()))).max((Ordering)Ordering.Int$.MODULE$));
            vectorSize = maxFeatureIndex + 1;
        }
        if (vectorSize == 0) {
            // A zero-length vector cannot hold any importances, so the map must be empty.
            Predef$.MODULE$.assert(totalImportances.size() == 0, (Function0)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final String apply() {
                    return new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Unknown error in computing feature"})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{" importance: No splits found, but some non-zero importances."})).s((Seq)Nil$.MODULE$)).toString();
                }
            });
        }
        // Sort the (featureIndex, importance) pairs by feature index, then split them into parallel
        // index and value sequences for the sparse vector. Presumably Vectors.sparse expects the
        // indices in increasing order; confirm against the Vectors.sparse contract.
        if ((tuple2 = ((GenericTraversableTemplate)totalImportances.iterator().toSeq().sortBy((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final int apply(Tuple2<Object, Object> x$7) {
                return x$7._1$mcI$sp();
            }
        }, (Ordering)Ordering.Int$.MODULE$)).unzip((Function1)Predef$.MODULE$.$conforms())) != null) {
            Seq indices = (Seq)tuple2._1();
            Seq values = (Seq)tuple2._2();
            return Vectors$.MODULE$.sparse(vectorSize, (int[])indices.toArray(ClassTag$.MODULE$.Int()), (double[])values.toArray(ClassTag$.MODULE$.Double()));
        }
        throw new MatchError((Object)tuple2);
    }

    /**
     * Computes normalized feature importances for a single tree.
     *
     * <p>The tree is wrapped in a one-element array and passed to
     * {@link #featureImportances(DecisionTreeModel[], int)}.
     *
     * @param tree the tree whose importances are computed
     * @param numFeatures the size of the returned vector, or {@code -1} to infer it from the tree
     * @param evidence$1 class tag of {@code M}. It is not read in this decompiled body.
     * @return a sparse vector of normalized feature importances
     */
    public <M extends DecisionTreeModel> Vector featureImportances(M tree, int numFeatures, ClassTag<M> evidence$1) {
        return this.featureImportances((DecisionTreeModel[])((Object[])new DecisionTreeModel[]{tree}), numFeatures);
    }

    /**
     * Accumulates raw split-gain importances for every internal node reachable from {@code node}.
     *
     * <p>For each {@link InternalNode}, {@code gain() * impurityStats().count()} is added to the
     * entry of {@code importances} keyed by the node's split feature index. The left subtree is
     * handled recursively and the right subtree by looping.
     *
     * @param node the root of the subtree to walk
     * @param importances a mutable map from feature index to accumulated importance; updated in place
     * @throws MatchError if a node is neither an {@link InternalNode} nor a {@link LeafNode}
     */
    public void computeFeatureImportance(Node node, OpenHashMap<Object, Object> importances) {
        Node node2;
        while ((node2 = node) instanceof InternalNode) {
            InternalNode internalNode = (InternalNode)node2;
            int feature = internalNode.split().featureIndex();
            // Weight the node's gain by the number of instances counted at that node.
            double scaledGain = internalNode.gain() * (double)internalNode.impurityStats().count();
            importances.changeValue$mcD$sp((Object)BoxesRunTime.boxToInteger((int)feature), (Function0)new Serializable(scaledGain){
                public static final long serialVersionUID = 0L;
                private final double scaledGain$1;

                public final double apply() {
                    return this.apply$mcD$sp();
                }

                public double apply$mcD$sp() {
                    return this.scaledGain$1;
                }
                {
                    this.scaledGain$1 = scaledGain$1;
                }
            }, (Function1)new Serializable(scaledGain){
                public static final long serialVersionUID = 0L;
                private final double scaledGain$1;

                public final double apply(double x$9) {
                    return this.apply$mcDD$sp(x$9);
                }

                public double apply$mcDD$sp(double x$9) {
                    return x$9 + this.scaledGain$1;
                }
                {
                    this.scaledGain$1 = scaledGain$1;
                }
            });
            this.computeFeatureImportance(internalNode.leftChild(), importances);
            // Continue the walk into the right child iteratively (tail position).
            node = internalNode.rightChild();
        }
        if (node2 instanceof LeafNode) {
            // Leaves carry no split, so they contribute nothing.
            return;
        }
        throw new MatchError((Object)node2);
    }

    /**
     * Rescales the values of {@code map} in place so that they sum to 1.
     *
     * <p>If the values sum to exactly 0, the map is left unchanged.
     *
     * @param map a mutable map from feature index to importance; updated in place
     */
    public void normalizeMapValues(OpenHashMap<Object, Object> map) {
        double total = BoxesRunTime.unboxToDouble((Object)((TraversableOnce)map.map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final double apply(Tuple2<Object, Object> x$10) {
                return x$10._2$mcD$sp();
            }
        }, Iterable$.MODULE$.canBuildFrom())).sum((Numeric)Numeric.DoubleIsFractional$.MODULE$));
        if (total != 0.0) {
            // Snapshot the keys first so the map is not updated while iterating over it.
            int[] keys = (int[])map.iterator().map((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final int apply(Tuple2<Object, Object> x$11) {
                    return x$11._1$mcI$sp();
                }
            }).toArray(ClassTag$.MODULE$.Int());
            Predef$.MODULE$.intArrayOps(keys).foreach((Function1)new Serializable(map, total){
                public static final long serialVersionUID = 0L;
                private final OpenHashMap map$1;
                public final double total$1;

                public final double apply(int key) {
                    return this.apply$mcDI$sp(key);
                }

                public double apply$mcDI$sp(int key) {
                    // Every key already exists, so the 0.0 default is not used in practice and the
                    // existing value is divided by the total.
                    return this.map$1.changeValue$mcD$sp((Object)BoxesRunTime.boxToInteger((int)key), (Function0)new Serializable(this){
                        public static final long serialVersionUID = 0L;

                        public final double apply() {
                            return this.apply$mcD$sp();
                        }

                        public double apply$mcD$sp() {
                            return 0.0;
                        }
                    }, (Function1)new Serializable(this){
                        public static final long serialVersionUID = 0L;
                        private final /* synthetic */ anonfun.normalizeMapValues.1 $outer;

                        public final double apply(double x$12) {
                            return this.apply$mcDD$sp(x$12);
                        }

                        public double apply$mcDD$sp(double x$12) {
                            return x$12 / this.$outer.total$1;
                        }
                        {
                            if ($outer == null) {
                                throw null;
                            }
                            this.$outer = $outer;
                        }
                    });
                }
                {
                    this.map$1 = map$1;
                    this.total$1 = total$1;
                }
            });
        }
    }

    private TreeEnsembleModel$() {
        // The singleton is published by the static initializer.
    }
}

