# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Szkolenie CIFAR-10 z wykorzystaniem wielu GPU z synchronicznymi aktualizacjami.

Dokładność:
cifar10_multi_gpu_train.py osiąga dokładność ~86% po 100 tys. krokach (256 epokach
danych) zgodnie z oceną dokonaną poprzez cifar10_eval.py.

Prędkość: Przy rozmiarze grupy 128.

System        | Czas kroku (sek./grupę) |     Accuracy
--------------------------------------------------------------------
1 Tesla K20m  | 0,35-0,60               | ~86% przy 60 tys. kroków  (5 godzin)
1 Tesla K40m  | 0,25-0,35               | ~86% przy 100 tys. kroków (4 godziny)
2 Tesla K20m  | 0,13-0,20               | ~84% przy 30 tys. kroków  (2,5 godziny)
3 Tesla K20m  | 0,13-0,18               | ~84% przy 30 tys. kroków
4 Tesla K20m  | ~0,10                   | ~84% przy 30 tys. kroków

Sposób użycia:
Zapoznaj się z samouczkiem i stroną internetową, aby dowiedzieć się, jak
pobrać zestaw danych CIFAR-10, skompilować program i przeszkolić model.

http://tensorflow.org/tutorials/deep_cnn/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os.path
import shutil
import re
import time

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
import cifar10

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                           """Katalog zapisu dzienników zdarzeń """
                           """i punktów kontrolnych.""")
tf.app.flags.DEFINE_integer('max_steps', 1000000,
                            """Liczba grup do uruchomienia.""")
tf.app.flags.DEFINE_integer('num_gpus', 1,
                            """Liczba używanych układów GPU.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Czy rejestrować rozmieszczenie urządzeń.""")

def _activation_summary(x):
  """Funkcja pomocnicza do tworzenia podsumowań dla aktywacji.

  Tworzy podsumowanie zapewniające histogram aktywacji.
  Tworzy podsumowanie, które mierzy rzadkość aktywacji.

  Argumenty:
    x: Tensor
  Zwracane wartości:
    brak
  """
  # Usuwa 'filar_[0-9]/' z nazwy, jeśli jest to sesja szkolenia na wielu GPU.
  # Pomaga to zapewnić przejrzystość prezentacji w tensorboard.
  tensor_name = re.sub('%s_[0-9]*/' % cifar10.TOWER_NAME, '', x.op.name)
  tf.summary.histogram(tensor_name + '/aktywacje', x)
  tf.summary.scalar(tensor_name + '/rzadkość',
                                       tf.nn.zero_fraction(x))


def _variable_on_cpu(name, shape, initializer):
  """Funkcja pomocnicza do tworzenia zmiennej przechowywanej w pamięci CPU.

  Argumenty:
    name: nazwa zmiennej
    shape: lista liczb całkowitych
    initializer: inicjalizator zmiennej

  Zwracane wartości:
    Tensor zmiennej
  """
  with tf.device('/cpu:0'):
    var = tf.get_variable(name, shape, initializer=initializer, dtype=tf.float32)
  return var


def _variable_with_weight_decay(name, shape, stddev, wd):
  """Funkcja pomocnicza do tworzenia zainicjowanej zmiennej z rozkładem wag.

  Zauważ, że zmienna jest inicjalizowana z obciętym rozkładem normalnym.
  Rozkład wag jest dodawany tylko wtedy, gdy został określony.

  Argumenty:
    name: nazwa zmiennej
    shape: lista liczb całkowitych
    stddev: odchylenie standardowe obciętego rozkładu Gaussa
    wd: dodanie rozkładu wag L2Loss pomnożone przez tę liczbę zmiennoprzecinkową.
	    Przy wartości None rozkład wag nie zostanie dodany do tej zmiennej.

  Zwracane wartości:
    Tensor zmiennej
  """
  var = _variable_on_cpu(
      name,
      shape,
      tf.truncated_normal_initializer(stddev=stddev, dtype=tf.float32))
  if wd is not None:
    weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='strata_wag')
    tf.add_to_collection('straty', weight_decay)
  return var

def inference(images):
  """Budowanie modelu CIFAR-10.

  Argumenty:
    images: Obrazy zwracane przez distorted_inputs() lub inputs().

  Zwracane wartości:
    Logity.
  """
  # Tworzymy instancje wszystkich zmiennych przy użyciu tf.get_variable() zamiast
  # tf.Variable() w celu współdzielenia zmiennych w wielu przebiegach szkolenia na
  # GPU. Gdybyśmy uruchamiali ten model na pojedynczym GPU, moglibyśmy uprościć tę
  # funkcję, zastępując wszystkie wystąpienia tf.get_variable() przez tf.Variable().
  #
  # splot1
  with tf.variable_scope('splot1') as scope:
    kernel = _variable_with_weight_decay('wagi',
                                         shape=[5, 5, 3, 64],
                                         stddev=5e-2,
                                         wd=0.0)
    conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('obciazenia', [64], tf.constant_initializer(0.0))
    pre_activation = tf.nn.bias_add(conv, biases)
    conv1 = tf.nn.relu(pre_activation, name=scope.name)
    _activation_summary(conv1)

  # łącz1
  pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                         padding='SAME', name='łącz1')
  # norm1
  norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                    name='norm1')

  # splot2
  with tf.variable_scope('splot2') as scope:
    kernel = _variable_with_weight_decay('wagi',
                                         shape=[5, 5, 64, 64],
                                         stddev=5e-2,
                                         wd=0.0)
    conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('obciazenia', [64], tf.constant_initializer(0.1))
    pre_activation = tf.nn.bias_add(conv, biases)
    conv2 = tf.nn.relu(pre_activation, name=scope.name)
    _activation_summary(conv2)

  # norm2
  norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                    name='norm2')
  # łącz2
  pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],
                         strides=[1, 2, 2, 1], padding='SAME', name='łącz2')

  # lokal3
  with tf.variable_scope('lokal3') as scope:
    # Nadanie głębi, aby umożliwić przeprowadzenie operacji mnożenia macierzy.
    reshape = tf.reshape(pool2, [FLAGS.batch_size, -1])
    dim = reshape.get_shape()[1].value
    weights = _variable_with_weight_decay('wagi', shape=[dim, 384],
                                          stddev=0.04, wd=0.004)
    biases = _variable_on_cpu('obciazenia', [384], tf.constant_initializer(0.1))
    local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
    _activation_summary(local3)

  # lokal4
  with tf.variable_scope('lokal4') as scope:
    weights = _variable_with_weight_decay('wagi', shape=[384, 192],
                                          stddev=0.04, wd=0.004)
    biases = _variable_on_cpu('obciazenia', [192], tf.constant_initializer(0.1))
    local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name=scope.name)
    _activation_summary(local4)

  # warstwa liniowa (WX + b),
  # Nie stosujemy tutaj funkcji softmax, ponieważ
  # tf.nn.sparse_softmax_cross_entropy_with_logits akceptuje nieskalowane logity
  # i wykonuje softmax wewnętrznie dla zwiększenia wydajności.
  with tf.variable_scope('liniowa_softmax') as scope:
    weights = _variable_with_weight_decay('wagi', [192, cifar10.NUM_CLASSES],
                                          stddev=1/192.0, wd=0.0)
    biases = _variable_on_cpu('obciazenia', [cifar10.NUM_CLASSES],
                              tf.constant_initializer(0.0))
    softmax_linear = tf.add(tf.matmul(local4, weights), biases, name=scope.name)
    _activation_summary(softmax_linear)

  return softmax_linear

def loss(logits, labels):
  """Dodaje L2Loss do wszystkich możliwych do wyuczenia zmiennych.

  Dodaje podsumowania dla "Loss" i "Loss/avg".
  Argumenty:
    logits: Logity z inference().
    labels: Etykiety z distorted_inputs lub inputs(). Tensor 1-D
            o kształcie [batch_size]

  Zwracane wartości:
    Tensor straty typu float.
  """
  # Obliczanie średniej straty entropii krzyżowej w całej grupie.
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits, name='entropia_krzyzowa_na_przyklad')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='entropia_krzyzowa')
  tf.add_to_collection('straty', cross_entropy_mean)

  # Całkowita strata jest definiowana jako strata entropii krzyżowej plus
  # wszystkie warunki rozkładu wag (strata L2).
  return tf.add_n(tf.get_collection('straty'), name='strata_razem')

def tower_loss(scope, images, labels):
  """Obliczenie całkowitej straty na pojedynczym filarze z uruchomionym modelem CIFAR.

  Argumenty:
    scope: unikatowy przedrostek identyfikujący filar CIFAR, np. 'filar_0'
    images: Obrazy. Tensor 4D o kształcie [batch_size, height, width, 3].
    labels: Etykiety. Tensor 1D tensor of shape [batch_size].

  Zwracane wartości:
    Tensor o kształcie [] zawierający całkowitą stratę dla grupy danych
  """

  # Budowanie grafu wnioskowania.
  logits = inference(images)

  # Budowanie części grafu obliczającego straty. Zauważ, że total_loss
  # będziemy wyliczać przy użyciu poniższej własnej funkcji.
  _ = loss(logits, labels)

  # Zebranie wszystkich strat tylko dla obecnego filaru.
  losses = tf.get_collection('straty', scope)

  # Obliczenie całkowitej straty dla bieżącego filaru.
  total_loss = tf.add_n(losses, name='strata_razem')

  # Dołączenie podsumowania skalarnego dla wszystkich indywidualnych strat
  # i straty całkowitej; to samo dotyczy uśrednionej wersji strat.
  for l in losses + [total_loss]:
    # Usuwa 'filar_[0-9]/' z nazwy, jeśli jest to sesja szkolenia na wielu GPU.
    # Pomaga to zapewnić przejrzystość prezentacji w tensorboard.
    loss_name = re.sub('%s_[0-9]*/' % cifar10.TOWER_NAME, '', l.op.name)
    tf.summary.scalar(loss_name, l)

  return total_loss


def average_gradients(tower_grads):
  """Obliczanie średniego gradientu dla każdej zmiennej współdzielonej
     pomiędzy wszystkie filary.

  Zauważ, że funkcja ta zapewnia punkt synchronizacji pomiędzy wszystkimi filarami.

  Argumenty:
    tower_grads: Lista list krotek (gradient, zmienna). Zewnętrzna
    lista dotyczy poszczególnych gradientów. Wewnętrzna lista dotyczy
    obliczeń gradientów dla każdego filaru.
  Zwracane wartości:
     Lista par (gradient, zmienna) gdzie gradient został uśredniony
     na podstawie wszystkich filarów.
  """
  average_grads = []
  for grad_and_vars in zip(*tower_grads):
    # Zauważ, że każdy element grad_and_vars wygląda następująco:
    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
    grads = []
    for g, _ in grad_and_vars:
      # Dodanie do gradientów wymiaru 0 w celu przedstawienia filaru.
      expanded_g = tf.expand_dims(g, 0)

      # Dodanie wymiaru 'filar', który będziemy poniżej uśredniać.
      grads.append(expanded_g)

    # Usrednianie wymiaru 'filar'.
    grad = tf.concat(axis=0, values=grads)
    grad = tf.reduce_mean(grad, 0)

    # Należy pamiętać, że zmienne są nadmiarowe, ponieważ są współdzielone między
	# filarami. Więc... po prostu zwrócimy wskaźnik pierwszego filaru do zmiennej.
    v = grad_and_vars[0][1]
    grad_and_var = (grad, v)
    average_grads.append(grad_and_var)
  return average_grads


def train():
  """Szkolenie CIFAR-10 dla określonej liczby kroków."""
  with tf.Graph().as_default(), tf.device('/cpu:0'):
    # Utworzenie zmiennej zliczającej liczbę wywołań train(). Jest to równe
    # liczbie przetworzonych grup * FLAGS.num_gpus.
    global_step = tf.get_variable(
        'krok_globalny', [],
        initializer=tf.constant_initializer(0), trainable=False)

    # Obliczanie harmonogramu współczynnika uczenia.
    num_batches_per_epoch = (cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
                             FLAGS.batch_size)
    decay_steps = int(num_batches_per_epoch * cifar10.NUM_EPOCHS_PER_DECAY)

    # Zmniejszanie współczynnika uczenia wykładniczo w oparciu o liczbę kroków.
    lr = tf.train.exponential_decay(cifar10.INITIAL_LEARNING_RATE,
                                    global_step,
                                    decay_steps,
                                    cifar10.LEARNING_RATE_DECAY_FACTOR,
                                    staircase=True)

    # Utworzenie optymalizatora, który przeprowadza gradient prosty.
    opt = tf.train.GradientDescentOptimizer(lr)

    # Pobieranie obrazów i etykiet dla CIFAR-10.
    images, labels = cifar10.distorted_inputs()
    batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * FLAGS.num_gpus)
    # Obliczanie gradientów dla każdego filaru modelu.
    tower_grads = []
    with tf.variable_scope(tf.get_variable_scope()):
      for i in xrange(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
          with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
            # Usuwanie z kolejki jednej grupy dla GPU
            image_batch, label_batch = batch_queue.dequeue()
            # Obliczanie straty dla jednego filaru modelu CIFAR. Ta funkcja tworzy
            # cały model CIFAR, ale współdzieli zmienne pomiędzy wszystkie filary.
            loss = tower_loss(scope, image_batch, label_batch)

            # Ponowne użycie zmiennych dla następnego filaru.
            tf.get_variable_scope().reuse_variables()

            # Zachowanie podsumowań z ostatniego filaru.
            summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)

            # Obliczanie gradientów dla grupy danych dla tego filaru CIFAR.
            grads = opt.compute_gradients(loss)

            # Śledzenie gradientów na wszystkich filarach.
            tower_grads.append(grads)

    # Musimy obliczyć średnią każdego gradientu. Zauważ, że jest to punkt
    # synchronizacji na wszystkich filarach.
    grads = average_gradients(tower_grads)

    # Dodanie podsumowania w celu śledzenia współczynnika uczenia.
    summaries.append(tf.summary.scalar('wsp_uczenia', lr))

    # Dodanie histogramów dla gradientów.
    for grad, var in grads:
      if grad is not None:
        summaries.append(tf.summary.histogram(var.op.name + '/gradienty', grad))

    # Zastosowanie gradientów w celu dostosowania współdzielonych zmiennych.
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    # Dodanie histogramów dla możliwych do wyuczenia zmiennych.
    for var in tf.trainable_variables():
      summaries.append(tf.summary.histogram(var.op.name, var))

    # Śledzenie średnich kroczących wszystkich możliwych do wyuczenia zmiennych.
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # Grupowanie wszystkich aktualizacji w pojedynczą operację szkolenia.
    train_op = tf.group(apply_gradient_op, variables_averages_op)

    # Utworzenie zapisu.
    saver = tf.train.Saver(tf.global_variables())

    # Budowanie operacji podsumowującej na podstawie ostatnich podsumowań filaru.
    summary_op = tf.summary.merge(summaries)

    # Budowanie operacji inicjalizacji do uruchomienia poniżej.
    init = tf.global_variables_initializer()

    # Rozpoczęcie uruchamiania operacji na grafie. allow_soft_placement musi być
    # ustawione na True, aby budować filary na GPU, ponieważ niektóre operacje
    # nie mają implementacji GPU.
    sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    # Uruchomienie operacji w kolejce.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model rozbieżny ze stratą = NaN'

      if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = duration / FLAGS.num_gpus

        format_str = ('%s: krok %d, strata = %.2f (%.1f przykładów/sek.; %.3f '
                      'sek./grupę)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))

      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Okresowy zapis punktu kontrolnego modelu.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)


def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  if os.path.exists(FLAGS.train_dir):
    shutil.rmtree(FLAGS.train_dir)
  os.makedirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
  main()
