package org.deeplearning4j.mlp;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;

/**
 * @author Alex Black
 */
public class MnistMLPExample {
    private static final Logger log = LoggerFactory.getLogger(MnistMLPExample.class);

    @Parameter(names = "-useSparkLocal", description = "Lokalne środowisko Spark (do testów, bez skryptu spark-submit)", arity = 1)
    private boolean useSparkLocal = true;

    @Parameter(names = "-batchSizePerWorker", description = "Liczba rekordów do przetworzenia przez każdy węzeł wykonawczy")
    private int batchSizePerWorker = 16;

    @Parameter(names = "-numEpochs", description = "Liczba epok treningowych")
    private int numEpochs = 15;

    public static void main(String[] args) throws Exception {
        new MnistMLPExample().entryPoint(args);
    }

    protected void entryPoint(String[] args) throws Exception {
        // Analiza argumentów z wiersza poleceń.
        JCommander jcmdr = new JCommander(this);
        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            // Błędne argumenty -> wyświetlenie informacji.
            jcmdr.usage();
            try { Thread.sleep(500); } catch (Exception e2) { }
            throw e;
        }

        SparkConf sparkConf = new SparkConf();
        if (useSparkLocal) {
            sparkConf.setMaster("local[*]");
        }
        sparkConf.setAppName("Wielowarstwowy perceptron dla środowiska Spark");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        // Załadowanie danych do pamięci i zrównoleglenie.
        // Z reguły nie jest to dobra praktyka, ale prosta w zastosowaniu w tym przykładzie.
        DataSetIterator iterTrain = new MnistDataSetIterator(batchSizePerWorker, true, 12345);
        DataSetIterator iterTest = new MnistDataSetIterator(batchSizePerWorker, true, 12345);
        List<DataSet> trainDataList = new ArrayList<>();
        List<DataSet> testDataList = new ArrayList<>();
        while (iterTrain.hasNext()) {
            trainDataList.add(iterTrain.next());
        }
        while (iterTest.hasNext()) {
            testDataList.add(iterTest.next());
        }

        JavaRDD<DataSet> trainData = sc.parallelize(trainDataList);
        JavaRDD<DataSet> testData = sc.parallelize(testDataList);


        //----------------------------------
        // Przygotowanie konfiguracji i przetrenowanie sieci.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .activation(Activation.LEAKYRELU)
            .weightInit(WeightInit.XAVIER)
            .learningRate(0.02)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .regularization(true).l2(1e-4)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build())
            .layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
            .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
            .pretrain(false).backprop(true)
            .build();

        // Konfiguracja środowiska Spark na potrzeby treningu.
        // Opis opcji konfiguracyjnych dostępny jest na stronie http://deeplearning4j.org/spark.
        TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(batchSizePerWorker)    // Każdy obiekt DataSet zawiera domyślnie 32 rekordy.
            .averagingFrequency(5)
            .workerPrefetchNumBatches(2)    // Asynchroniczne odczytywanie 2 dodatkowych rekordów przez każdy węzeł wykonawczy.

            .batchSizePerWorker(batchSizePerWorker)
            .build();

        // Utworzenie sieci w środowisku Spark.
        SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, tm);

        // Przeprowadzenie treningu.
        for (int i = 0; i < numEpochs; i++) {
            sparkNet.fit(trainData);
            log.info("Zakończona epoka {}", i);
        }

        // Ocena modelu (rozproszonego).
        Evaluation evaluation = sparkNet.evaluate(testData);
        log.info("***** Ocena *****");
        log.info(evaluation.stats());

        // Usunięcie tymczasowych plików treningowych, ponieważ nie będą już potrzebne.
        tm.deleteTempFiles(sc);

        log.info("***** Koniec przykładu *****");
    }
}
