# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Budowanie sieci CIFAR-10.

Podsumowanie dostępnych funkcji:

 # Przetwarzanie obrazów wejściowych i etykiet na potrzeby szkolenia. Jeśli
 # chcesz przeprowadzić ewaluację, użyj zamiast tego inputs().
 inputs, labels = distorted_inputs()

 # Obliczanie wnioskowania na podstawie danych wejściowych modelu w celu
 # wykonania predykcji.
 predictions = inference(inputs)

 # Obliczenie całkowitej straty predykcji w odniesieniu do etykiet.
 loss = loss(predictions, labels)

 # Utworzenie grafu do uruchomienia jednego kroku szkolenia w odniesieniu do straty.
 train_op = train(loss, global_step)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re
import sys
import tarfile

from six.moves import urllib, xrange # pylint: disable=redefined-builtin
import tensorflow as tf

import cifar10_input

FLAGS = tf.app.flags.FLAGS

# Podstawowe parametry modelu.
tf.app.flags.DEFINE_integer('batch_size', 128,
                            """Liczba obrazów do przetworzenia w grupie.""")
tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
                           """Ścieżka do katalogu danych CIFAR-10.""")

# Przetwarzaj obrazy w tym rozmiarze. Zauważ, że różni się on od oryginalnego
# rozmiaru obrazów CIFAR 32 x 32. Jeśli ktoś zmieni tę liczbę, wówczas zmieni się
# cała architektura modelu i każdy model będzie musiał zostać ponownie przeszkolony.
IMAGE_SIZE = 24
NUM_CLASSES = 10

# Stałe globalne opisujące zbiór danych CIFAR-10.
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000


# Stałe opisujące proces szkolenia.
MOVING_AVERAGE_DECAY = 0.9999     # Spadek do wykorzystania dla średniej kroczącej.
NUM_EPOCHS_PER_DECAY = 350.0      # Liczba epok, po których zmniejsza się współczynnik uczenia.
LEARNING_RATE_DECAY_FACTOR = 0.1  # Współczynnik spadku tempa uczenia.
INITIAL_LEARNING_RATE = 0.1       # Początkowy współczynnik uczenia.

# Jeśli model jest szkolony z wykorzystaniem wielu układów GPU, nazwy wszystkich
# operacji zostaną dla odróżnienia poprzedzone prefiksem określonym w tower_name.
# Zauważ, że prefiks ten jest usuwany z nazw podsumowań podczas wizualizacji modelu.
TOWER_NAME = 'filar'

DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'


def distorted_inputs():
  """Konstruowanie zniekształconych danych wejściowych dla szkolenia CIFAR przy użyciu operacji Reader.

  Zwracane wartości:
    images: Obrazy. Tensor 4D o rozmiarach [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3].
    labels: Etykiety. Tensor 1D o rozmiarze [batch_size].

  Komunikaty o błędach:
    ValueError: Jeśli brak data_dir
  """
  if not FLAGS.data_dir:
    raise ValueError('Proszę podać data_dir')
  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
  images, labels = _distorted_inputs(data_dir=data_dir,
                                     batch_size=FLAGS.batch_size)
  return images, labels

def read_cifar10(filename_queue):
  """Odczyt i przetwarzanie przykładów z plików danych CIFAR10.

  Zalecenie: jeśli chcesz N równoległych odczytów, wywołaj tę funkcję N razy.
  To da Ci N niezależnych funkcji wczytujących różne pliki i pozycje w obrębie
  tych plików, co zapewni lepsze wymieszanie przykładów.

  Argumenty:
    filename_queue: Kolejka łańcuchów z nazwami plików do odczytu.

  Zwracane wartości:
    Obiekt reprezentujący pojedynczy przykład, z następującymi polami:
      height: liczba wierszy w wyniku (32)
      width: liczba kolumn w wyniku (32)
      depth: liczba kanałów kolorów w wyniku (3)
      key: skalarny tensor znakowy opisujący nazwę pliku i numer rekordu
        dla tego przykładu.
      label: tensor typu int32 z etykietą w zakresie 0..9.
      uint8image: tensor [height, width, depth] typu uint8 z danymi obrazu
  """

  class CIFAR10Record(object):
    pass
  result = CIFAR10Record()

  # Wymiary obrazów w zbiorze danych CIFAR-10.
  # Opis formatu wejściowego znajduje się na stronie
  # http://www.cs.toronto.edu/~kriz/cifar.html.
  label_bytes = 1  # 2 dla CIFAR-100
  result.height = 32
  result.width = 32
  result.depth = 3
  image_bytes = result.height * result.width * result.depth
  # Każdy rekord składa się z etykiety, po której występuje obraz,
  # przy stałej liczbie bajtów dla każdego rekordu.
  record_bytes = label_bytes + image_bytes

  # Odczyt rekordu przez pobranie nazw plików z filename_queue.
  # W formacie CIFAR-10 nie ma nagłówka ani stopki, więc header_bytes
  # i footer_bytes pozostawiamy przy domyślnej wartości 0.
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  result.key, value = reader.read(filename_queue)

  # Konwersja łańcucha na wektor typu uint8 o długości record_bytes.
  record_bytes = tf.decode_raw(value, tf.uint8)

  # Pierwsze bajty reprezentują etykietę, którą konwertujemy z uint8->int32.
  result.label = tf.cast(
      tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

  # Pozostałe bajty za etykietą reprezentują obraz, który przekształcamy
  # z [depth * height * width] na [depth, height, width].
  depth_major = tf.reshape(
      tf.strided_slice(record_bytes, [label_bytes],
                       [label_bytes + image_bytes]),
      [result.depth, result.height, result.width])
  # Konwersja z [depth, height, width] na [height, width, depth].
  result.uint8image = tf.transpose(depth_major, [1, 2, 0])

  return result

def _distorted_inputs(data_dir, batch_size):
  """Konstruowanie zniekształconych danych wejściowych dla szkolenia CIFAR
     przy użyciu operacji Reader.

  Argumenty:
    data_dir: Ścieżka do katalogu danych CIFAR-10.
    batch_size: Liczba obrazów na grupę.

  Zwracane wartości:
    images: Obrazy. Tensor 4D o rozmiarach [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3].
    labels: Etykiety. Tensor 1D o rozmiarze [batch_size].
  """
  filenames = [os.path.join(data_dir, 'grupa_danych_%d.bin' % i)
               for i in xrange(1, 6)]
  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Nie znaleziono pliku: ' + f)

  # Tworzenie kolejki, która tworzy nazwy plików do odczytu.
  filename_queue = tf.train.string_input_producer(filenames)

  # Wczytywanie przykładów z plików w kolejce nazw plików.
  read_input = read_cifar10(filename_queue)
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = IMAGE_SIZE
  width = IMAGE_SIZE

  # Przetwarzanie obrazu na potrzeby szkolenia sieci. Zwróć uwagę na wiele
  # losowych zniekształceń zastosowanych do obrazu.
  
  # Losowe przycięcie części obrazu [wysokość, szerokość].
  distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

  # Losowe odwrócenie obrazu w poziomie.
  distorted_image = tf.image.random_flip_left_right(distorted_image)

  # Ponieważ te operacje nie są przemienne, rozważ randomizację kolejności
  # ich wykonywania.
  # UWAGA: ponieważ per_image_standardization zeruje średnią i tworzy jednostkę
  # stddev, prawdopodobnie nie da to żadnego efektu; patrz tensorflow#1458.
  distorted_image = tf.image.random_brightness(distorted_image,
                                               max_delta=63)
  distorted_image = tf.image.random_contrast(distorted_image,
                                             lower=0.2, upper=1.8)

  # Odjęcie średniej i podzielenie przez wariancję pikseli.
  float_image = tf.image.per_image_standardization(distorted_image)

  # Ustawienie kształtu tensorów.
  float_image.set_shape([height, width, 3])
  read_input.label.set_shape([1])

  # Zapewnienie dobrych właściwości miksujących dla losowego przetasowania.
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                           min_fraction_of_examples_in_queue)
  print ('Wypełnianie kolejki %d obrazami CIFAR przed rozpoczęciem szkolenia. '
         'To zajmie kilka minut.' % min_queue_examples)

  # Generowanie grupy obrazów i etykiet poprzez utworzenie kolejki przykładów.
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size,
                                         shuffle=True)

def inputs(eval_data):
  """Konstruowanie danych wejściowych do ewaluacji CIFAR przy użyciu operacji Reader.

  Argumenty:
    eval_data: wartość logiczna, wskazująca, czy należy korzystać ze zbioru treningowego,
	  czy ewaluacyjnego.

  Zwracane wartości:
    images: Obrazy. Tensor 4D o rozmiarach [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3].
    labels: Etykiety. Tensor 1D o rozmiarze [batch_size].

  Komunikaty o błędach:
    ValueError: Jeśli brak data_dir
  """
  if not FLAGS.data_dir:
    raise ValueError('Proszę podać data_dir')
  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
  images, labels = _inputs(eval_data=eval_data,
                           data_dir=data_dir,
                           batch_size=FLAGS.batch_size)
  return images, labels

def _inputs(eval_data, data_dir, batch_size):
  """Konstruowanie danych wejściowych do ewaluacji CIFAR przy użyciu operacji Reader.

  Argumenty:
    eval_data: wartość logiczna, wskazująca, czy należy korzystać ze zbioru treningowego,
	  czy ewaluacyjnego.
    data_dir: Ścieżka do katalogu danych CIFAR-10.
    batch_size: Liczba obrazów na grupę.

  Zwracane wartości:
    images: Obrazy. Tensor 4D o rozmiarach [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3].
    labels: Etykiety. Tensor 1D o rozmiarze [batch_size].
  """
  if not eval_data:
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in xrange(1, 6)]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
  else:
    filenames = [os.path.join(data_dir, 'test_batch.bin')]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Nie znaleziono pliku: ' + f)

  # Tworzenie kolejki, która tworzy nazwy plików do odczytu.
  filename_queue = tf.train.string_input_producer(filenames)

  # Wczytywanie przykładów z plików w kolejce nazw plików.
  read_input = read_cifar10(filename_queue)
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = IMAGE_SIZE
  width = IMAGE_SIZE

  # Przetwarzanie obrazu na potrzeby ewaluacji.
  # Przycięcie środkowej [wysokość, szerokość] części obrazu.
  resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
                                                         height, width)

  # Odjęcie średniej i podzielenie przez wariancję pikseli.
  float_image = tf.image.per_image_standardization(resized_image)

  # Ustawienie kształtu tensorów.
  float_image.set_shape([height, width, 3])
  read_input.label.set_shape([1])

  # Zapewnienie dobrych właściwości miksujących dla losowego przetasowania.
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(num_examples_per_epoch *
                           min_fraction_of_examples_in_queue)

  # Generowanie grupy obrazów i etykiet poprzez utworzenie kolejki przykładów.
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size,
                                         shuffle=False)

def maybe_download_and_extract():
  """Pobranie i wypakowanie archiwum tar ze strony internetowej Aleksa."""
  dest_directory = FLAGS.data_dir
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):
    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Pobieranie %s %.1f%%' % (filename,
          float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()
    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Pomyślnie pobrano', filename, statinfo.st_size, 'bajtów.')
  extracted_dir_path = os.path.join(dest_directory, 'cifar-10-batches-bin')
  if not os.path.exists(extracted_dir_path):
    tarfile.open(filepath, 'r:gz').extractall(dest_directory)

def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, shuffle):
  """Konstruowanie skolejkowanej grupy obrazów i etykiet.

  Argumenty:
    image: Tensor 3-D Tensor [wysokość, szerokość, 3] typu float32.
    label: Tensor 1-D typu int32
    min_queue_examples: int32, minimalna liczba próbek do przechowywania
	  w kolejce dostarczającej grupy przykładów.
    batch_size: Liczba obrazów na grupę.
    shuffle: wartość logiczna określająca, czy użyta będzie kolejka tasująca.

  Zwracane wartości:
    images: Obrazy. Tensor 4D o rozmiarze [batch_size, wysokość, szerokość, 3].
    labels: Etykiety. Tensor 1D o rozmiarze [batch_size].
  """
  # Utworzenie kolejki, która przetasuje przykłady, a następnie
  # wczyta obrazy 'batch_size' + etykiety z kolejki przykładów.
  num_preprocess_threads = 16
  if shuffle:
    images, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)
  else:
    images, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size)

  # Wyświetlanie obrazów szkoleniowych w wizualizerze.
  tf.summary.image('obrazy', images)

  return images, tf.reshape(label_batch, [batch_size])
