package org.deeplearning4j.examples.convolution;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Utworzony przez agibsonccc on 16.09.2015.
 */
public class LenetMnistExample {
    private static final Logger log = LoggerFactory.getLogger(LenetMnistExample.class);

    public static void main(String[] args) throws Exception {
        int nChannels = 1;  // Liczba kanałów wejściowych.
        int outputNum = 10; // Liczba możliwych wyników.
        int batchSize = 64; // Wielkość paczki testowej.
        int nEpochs = 1;    // Liczba epok treningowych.
        int iterations = 1; // Liczba iteracji treningowych.
        int seed = 123;

        /*
            Utworzenie iteratora i zdefiniowanie wielkości paczki dla każdej iteracji.
        */
        log.info("Load data....");
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize,true,12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize,false,12345);

        /*
            Utworzenie sieci neuronowej.
         */
        log.info("Ładowanie danych...");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations) // Iteracje treningowe.
                .regularization(true).l2(0.0005)
                /*
                    Odkomentuj poniższe wiersze, aby zmniejszyć szybkość uczenia i zmienić obciążenie.
                 */
                .learningRate(.01)//.biasLearningRate(0.02)
                //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        // Metody nIn i nOut definiują głębokość; argumentem metody nIn jest liczba kanałów, a nOut liczba filtrów, które będą zastosowane.
                        .nIn(nChannels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2,2)
                        .stride(2,2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        // Zwróć uwagę, że w następnych warstwach nie trzeba stosować metody nIn.
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2,2)
                        .stride(2,2)
                        .build())
                .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
                        .nOut(500).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(28,28,1)) //See note below
                .backprop(true).pretrain(false).build();

        /*
          Kilka uwag do wiersza .setInputType(InputType.convolutionalFlat(28,28,1)):
          (a) Dodaje preprocesory wykonujące takie operacje jak przejścia pomiędzy
              warstwą konwolucyjną a wstępnie próbkującą lub gęstą.
          (b) Dodatkowo weryfikuje konfigurację.
          (c) W razie potrzeby wywołuje metodę nIn (określającą liczbę neuronów wejściowych,
              czyli głębokość wejścia w przypadku sieci CNN) dla każdej warstwy z argumentem
              równym wielkości poprzedniej warstwy (ale nie nadpisuje wartości ręcznie
              określonych przez użytkownika).
              Klasy InputType  można użyć również dla innych rodzajów warstw (RNN, MLP itp.),
              nie tylko dla CNN.
          W przypadku zwykłych obrazów (jeżeli użyta jest klasa ImageRecordReader) należy
          użyć klasy InputType.convolutional(height,width,depth).
          Czytnik rekordów MNIST jest specjalnym przypadkiem zwracającym obrazy o wymiarach
          28x28 pikseli z odcieniami szarości (nChannels=1) w „spłaszczonym” formacie
          wektorowym (tj. jako wektory 1x784); dlatego został tu użyty typ convolutionalFlat.
        */

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();


        log.info("Trening modelu...");
        model.setListeners(new ScoreIterationListener(1));
        for( int i=0; i<nEpochs; i++ ) {
            model.fit(mnistTrain);
            log.info("*** Zakończona epoka {} ***", i);

            log.info("Ocena modelu...");
            Evaluation eval = new Evaluation(outputNum);
            while(mnistTest.hasNext()){
                DataSet ds = mnistTest.next();
                INDArray output = model.output(ds.getFeatureMatrix(), false);
                eval.eval(ds.getLabels(), output);

            }
            log.info(eval.stats());
            mnistTest.reset();
        }
        log.info("**************** Koniec przykładu ********************");
    }
}
