/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.optimization;

import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.NumericOps;
import breeze.linalg.norm$;
import breeze.math.Field;
import breeze.storage.Zero;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.Logging;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.optimization.Gradient;
import org.apache.spark.mllib.optimization.Updater;
import org.apache.spark.rdd.RDD;
import org.slf4j.Logger;
import scala.Function0;
import scala.Function2;
import scala.MatchError;
import scala.None$;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.StringContext;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Seq;
import scala.collection.TraversableOnce;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;

/**
 * Singleton implementing mini-batch stochastic gradient descent (SGD) over an RDD whose
 * elements pair a double (read via {@code _1}, presumably the label) with a feature
 * {@link Vector}.
 *
 * <p>NOTE(review): this file is CFR decompiler output of what appears to be a Scala
 * {@code object}. The {@code MODULE$} field, the static initializer and {@code readResolve}
 * follow that encoding. It is not valid, compilable Java as written. Examples:
 * {@code Logging.class.logName(...)} calls, anonymous {@code Serializable} classes that take
 * constructor arguments, a {@code static final} field assigned in the constructor, and
 * {@code None$}-typed locals that are later assigned a {@code Some}. Treat it as a readable
 * reference of the algorithm, not as source to build.
 */
@DeveloperApi
public final class GradientDescent$
implements Logging,
Serializable {
    /** The single instance; assigned by the private constructor, which the static initializer runs. */
    public static final GradientDescent$ MODULE$;
    /** Logger field kept on behalf of the {@code Logging} trait; transient, so it is not serialized. */
    private transient Logger org$apache$spark$internal$Logging$$log_;

    // Creates the singleton when the class is initialized (the constructor stores it in MODULE$).
    static {
        new GradientDescent$();
    }

    // Getter/setter pair for the Logging trait's logger field.
    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$internal$Logging$$log_ = x$1;
    }

    // The methods below forward to the shared Logging trait implementations, passing this instance.

    public String logName() {
        return Logging.class.logName((Logging)this);
    }

    public Logger log() {
        return Logging.class.log((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.class.logInfo((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.class.logDebug((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.class.logTrace((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.class.logWarning((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.class.logError((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.class.logInfo((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.class.logDebug((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.class.logTrace((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.class.logWarning((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.class.logError((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.class.isTraceEnabled((Logging)this);
    }

    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.class.initializeLogIfNecessary((Logging)this, (boolean)isInterpreter);
    }

    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.class.initializeLogIfNecessary((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.class.initializeLogIfNecessary$default$2((Logging)this);
    }

    /**
     * Runs mini-batch SGD for at most {@code numIterations} iterations, stopping early once the
     * convergence test passes.
     *
     * <p>Each iteration does the following:
     * <ul>
     *   <li>broadcasts the current weights;</li>
     *   <li>samples {@code data} with fraction {@code miniBatchFraction};</li>
     *   <li>sums the gradient, loss and example count over the sample with {@code treeAggregate};</li>
     *   <li>if the sample is non-empty, records the mean loss and passes the averaged gradient
     *       to {@code updater}.</li>
     * </ul>
     * The convergence test only runs once two successive updated weight vectors exist.
     *
     * @param data training examples; {@code _1} is read as a double (presumably the label) and
     *     {@code _2} is the feature vector
     * @param gradient2 returns the loss of one example; also given the cumulative gradient vector
     * @param updater returns new weights (first element) and a regularization value (second)
     * @param stepSize step size passed to {@code updater}
     * @param numIterations maximum number of iterations to run
     * @param regParam regularization parameter passed to {@code updater}
     * @param miniBatchFraction fraction of {@code data} sampled per iteration
     * @param initialWeights starting weights. Their values ({@code toArray}) seed a dense weight
     *     vector, and the object itself is returned unchanged when {@code data} is empty.
     * @param convergenceTol relative tolerance used by {@code isConverged}
     * @return the final weights, and an array with one entry per iteration whose sample was
     *     non-empty. Each entry is the mean mini-batch loss plus the regularization value from
     *     the preceding {@code updater} call.
     */
    public Tuple2<Vector, double[]> runMiniBatchSGD(RDD<Tuple2<Object, Vector>> data, Gradient gradient2, Updater updater, double stepSize, int numIterations, double regParam, double miniBatchFraction, Vector initialWeights, double convergenceTol) {
        // A convergence test on sampled mini-batches is noisy; warn when both features are combined.
        if (miniBatchFraction < 1.0 && convergenceTol > 0.0) {
            this.logWarning((Function0<String>)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final String apply() {
                    return "Testing against a convergenceTol when using miniBatchFraction < 1.0 can be unstable because of the stochasticity in sampling.";
                }
            });
        }
        // The total sampled fraction across all iterations is below one pass over the data.
        if ((double)numIterations * miniBatchFraction < 1.0) {
            this.logWarning((Function0<String>)new Serializable(numIterations, miniBatchFraction){
                public static final long serialVersionUID = 0L;
                private final int numIterations$1;
                private final double miniBatchFraction$1;

                public final String apply() {
                    return new StringBuilder().append((Object)"Not all examples will be used if numIterations * miniBatchFraction < 1.0: ").append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"numIterations=", " and miniBatchFraction=", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.numIterations$1), BoxesRunTime.boxToDouble((double)this.miniBatchFraction$1)}))).toString();
                }
                {
                    this.numIterations$1 = numIterations$1;
                    this.miniBatchFraction$1 = miniBatchFraction$1;
                }
            });
        }
        // Receives one loss value per iteration whose sample was non-empty.
        ArrayBuffer stochasticLossHistory = new ArrayBuffer(numIterations);
        // Weights after the previous and the latest update; None until the first updates happen.
        // (The decompiler typed these as None$ even though Some values are assigned below.)
        None$ previousWeights = None$.MODULE$;
        None$ currentWeights = None$.MODULE$;
        long numExamples = data.count();
        // No data: return the initial weights unchanged with an empty loss history.
        if (numExamples == 0L) {
            this.logWarning((Function0<String>)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final String apply() {
                    return "GradientDescent.runMiniBatchSGD returning initial weights, no data found";
                }
            });
            return new Tuple2((Object)initialWeights, stochasticLossHistory.toArray(ClassTag$.MODULE$.Double()));
        }
        if ((double)numExamples * miniBatchFraction < 1.0) {
            this.logWarning((Function0<String>)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final String apply() {
                    return "The miniBatchFraction is too small";
                }
            });
        }
        Vector weights = Vectors$.MODULE$.dense(initialWeights.toArray());
        int n = weights.size();
        // Initial regularization value: call the updater with a zero gradient, step size 0.0 and
        // iteration 1, and keep only the second tuple element.
        double regVal = updater.compute(weights, Vectors$.MODULE$.zeros(weights.size()), 0.0, 1, regParam)._2$mcD$sp();
        boolean converged = false;
        // 1-based iteration counter, boxed in an IntRef because the warning closure below captures it.
        IntRef i = IntRef.create((int)1);
        while (!converged && i.elem <= numIterations) {
            // Decompiler temporaries for the treeAggregate arguments.
            int x$5;
            Serializable x$4;
            Serializable x$3;
            Tuple3 x$2;
            // Ship the current weights to executors for this iteration.
            Broadcast bcWeights = data.context().broadcast((Object)weights, ClassTag$.MODULE$.apply(Vector.class));
            // Sample with seed 42 + i; the leading false presumably means "without replacement".
            // NOTE(review): the seed depends only on the iteration number, so every call draws the
            // same sequence of mini-batches for the same data.
            RDD qual$1 = data.sample(false, miniBatchFraction, (long)(42 + i.elem));
            // Aggregate (gradient sum, loss sum, example count) over the sample. The zero value is
            // (zero dense vector of length n, 0.0, 0L), followed by seqOp x$3 and combOp x$4.
            // x$5 is the default fourth treeAggregate argument (presumably the tree depth).
            Tuple3 tuple3 = (Tuple3)qual$1.treeAggregate((Object)(x$2 = new Tuple3((Object)DenseVector$.MODULE$.zeros$mDc$sp(n, ClassTag$.MODULE$.Double(), (Zero)Zero.DoubleZero$.MODULE$), (Object)BoxesRunTime.boxToDouble((double)0.0), (Object)BoxesRunTime.boxToLong((long)0L))), (Function2)(x$3 = new Serializable(gradient2, bcWeights){
                public static final long serialVersionUID = 0L;
                private final Gradient gradient$1;
                private final Broadcast bcWeights$1;

                // Folds in one example: compute(...) returns its loss, and c._1 is returned as-is.
                // The gradient is therefore presumably accumulated into c._1 in place by
                // compute; confirm against the Gradient implementation.
                public final Tuple3<DenseVector<Object>, Object, Object> apply(Tuple3<DenseVector<Object>, Object, Object> c, Tuple2<Object, Vector> v) {
                    double l = this.gradient$1.compute((Vector)v._2(), v._1$mcD$sp(), (Vector)this.bcWeights$1.value(), Vectors$.MODULE$.fromBreeze((breeze.linalg.Vector<Object>)((breeze.linalg.Vector)c._1())));
                    return new Tuple3(c._1(), (Object)BoxesRunTime.boxToDouble((double)(BoxesRunTime.unboxToDouble((Object)c._2()) + l)), (Object)BoxesRunTime.boxToLong((long)(BoxesRunTime.unboxToLong((Object)c._3()) + 1L)));
                }
                {
                    this.gradient$1 = gradient$1;
                    this.bcWeights$1 = bcWeights$1;
                }
            }), (Function2)(x$4 = new Serializable(){
                public static final long serialVersionUID = 0L;

                // Merges two partial results; `+=` adds c2's gradient into c1's vector in place.
                public final Tuple3<DenseVector<Object>, Object, Object> apply(Tuple3<DenseVector<Object>, Object, Object> c1, Tuple3<DenseVector<Object>, Object, Object> c2) {
                    return new Tuple3(((NumericOps)c1._1()).$plus$eq(c2._1(), DenseVector$.MODULE$.canAddIntoD()), (Object)BoxesRunTime.boxToDouble((double)(BoxesRunTime.unboxToDouble((Object)c1._2()) + BoxesRunTime.unboxToDouble((Object)c2._2()))), (Object)BoxesRunTime.boxToLong((long)(BoxesRunTime.unboxToLong((Object)c1._3()) + BoxesRunTime.unboxToLong((Object)c2._3()))));
                }
            }), x$5 = qual$1.treeAggregate$default$4((Object)x$2), ClassTag$.MODULE$.apply(Tuple3.class));
            // Decompiled pattern match on the aggregate result; null falls through to MatchError.
            if (tuple3 != null) {
                Tuple3 tuple32;
                DenseVector gradientSum = (DenseVector)tuple3._1();
                double lossSum = BoxesRunTime.unboxToDouble((Object)tuple3._2());
                long miniBatchSize = BoxesRunTime.unboxToLong((Object)tuple3._3());
                // Redundant re-packing and re-destructuring emitted by the decompiler.
                Tuple3 tuple33 = tuple32 = new Tuple3((Object)gradientSum, (Object)BoxesRunTime.boxToDouble((double)lossSum), (Object)BoxesRunTime.boxToLong((long)miniBatchSize));
                DenseVector gradientSum2 = (DenseVector)tuple33._1();
                double lossSum2 = BoxesRunTime.unboxToDouble((Object)tuple33._2());
                long miniBatchSize2 = BoxesRunTime.unboxToLong((Object)tuple33._3());
                // Release this iteration's broadcast; false presumably requests a non-blocking destroy.
                // NOTE(review): this is not in a finally block, so the broadcast is not destroyed
                // if treeAggregate throws.
                bcWeights.destroy(false);
                if (miniBatchSize2 > 0L) {
                    // Record the mean mini-batch loss plus the current regularization value.
                    stochasticLossHistory.$plus$eq((Object)BoxesRunTime.boxToDouble((double)(lossSum2 / (double)miniBatchSize2 + regVal)));
                    // Step with the averaged gradient (gradient sum / batch size) at iteration i.
                    Tuple2<Vector, Object> update2 = updater.compute(weights, Vectors$.MODULE$.fromBreeze((breeze.linalg.Vector<Object>)((breeze.linalg.Vector)gradientSum2.$div((Object)BoxesRunTime.boxToDouble((double)miniBatchSize2), DenseVector$.MODULE$.dv_s_Op_Double_OpDiv()))), stepSize, i.elem, regParam);
                    weights = (Vector)update2._1();
                    regVal = update2._2$mcD$sp();
                    previousWeights = currentWeights;
                    currentWeights = new Some((Object)weights);
                    // Test convergence only when both previous and current weights are defined,
                    // i.e. from the second successful update onward.
                    None$ none$ = previousWeights;
                    None$ none$2 = None$.MODULE$;
                    if (none$ == null ? none$2 != null : !none$.equals(none$2)) {
                        None$ none$3 = currentWeights;
                        None$ none$4 = None$.MODULE$;
                        if (none$3 == null ? none$4 != null : !none$3.equals(none$4)) {
                            converged = this.isConverged((Vector)previousWeights.get(), (Vector)currentWeights.get(), convergenceTol);
                        }
                    }
                } else {
                    // Empty sample: skip the update and loss entry for this iteration, but still advance i.
                    this.logWarning((Function0<String>)new Serializable(numIterations, i){
                        public static final long serialVersionUID = 0L;
                        private final int numIterations$1;
                        private final IntRef i$1;

                        public final String apply() {
                            return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Iteration (", "/", "). The size of sampled batch is zero"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.i$1.elem), BoxesRunTime.boxToInteger((int)this.numIterations$1)}));
                        }
                        {
                            this.numIterations$1 = numIterations$1;
                            this.i$1 = i$1;
                        }
                    });
                }
                ++i.elem;
                continue;
            }
            throw new MatchError((Object)tuple3);
        }
        // Log the last (up to) 10 recorded stochastic losses.
        this.logInfo((Function0<String>)new Serializable(stochasticLossHistory){
            public static final long serialVersionUID = 0L;
            private final ArrayBuffer stochasticLossHistory$1;

            public final String apply() {
                return new StringOps(Predef$.MODULE$.augmentString("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s")).format((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{((TraversableOnce)this.stochasticLossHistory$1.takeRight(10)).mkString(", ")}));
            }
            {
                this.stochasticLossHistory$1 = stochasticLossHistory$1;
            }
        });
        return new Tuple2((Object)weights, stochasticLossHistory.toArray(ClassTag$.MODULE$.Double()));
    }

    /**
     * Same as the full {@code runMiniBatchSGD} overload, with {@code convergenceTol = 0.001}.
     */
    public Tuple2<Vector, double[]> runMiniBatchSGD(RDD<Tuple2<Object, Vector>> data, Gradient gradient2, Updater updater, double stepSize, int numIterations, double regParam, double miniBatchFraction, Vector initialWeights) {
        return this.runMiniBatchSGD(data, gradient2, updater, stepSize, numIterations, regParam, miniBatchFraction, initialWeights, 0.001);
    }

    /**
     * Tests whether the weights moved less than a relative tolerance between two updates.
     *
     * <p>Both vectors are converted to Breeze dense vectors. The test is
     * {@code norm(previous - current) < convergenceTol * max(norm(current), 1.0)}, using
     * Breeze's default {@code norm} (presumably the L2 norm). The 1.0 floor keeps the threshold
     * absolute when the weights are small.
     *
     * @return true if the change is below the tolerance
     */
    private boolean isConverged(Vector previousWeights, Vector currentWeights, double convergenceTol) {
        DenseVector currentBDV;
        DenseVector previousBDV = previousWeights.asBreeze().toDenseVector$mcD$sp(ClassTag$.MODULE$.Double());
        double solutionVecDiff = BoxesRunTime.unboxToDouble((Object)norm$.MODULE$.apply(previousBDV.$minus((Object)(currentBDV = currentWeights.asBreeze().toDenseVector$mcD$sp(ClassTag$.MODULE$.Double())), DenseVector$.MODULE$.canSubD()), DenseVector$.MODULE$.canNorm((Field)Field.fieldDouble$.MODULE$)));
        return solutionVecDiff < convergenceTol * Math.max(BoxesRunTime.unboxToDouble((Object)norm$.MODULE$.apply((Object)currentBDV, DenseVector$.MODULE$.canNorm((Field)Field.fieldDouble$.MODULE$))), 1.0);
    }

    // Java serialization hook: replace any deserialized copy with the singleton.
    private Object readResolve() {
        return MODULE$;
    }

    // Publishes this instance as MODULE$ and initializes the Logging trait's state.
    private GradientDescent$() {
        MODULE$ = this;
        Logging.class.$init$((Logging)this);
    }
}

