//Example 21-1. Creating and training a decision tree

#include <opencv2/opencv.hpp>
#include <stdio.h>
#include <iostream>

using namespace std;
using namespace cv;

void help(char **argv) {
  cout << "\n\n"
       << "Using binary decision trees to learn to recognize poisonous\n"
       << "    from edible mushrooms based on visible attributes.\n" 
       << "    This program demonstrates how to create and a train a \n"
       << "    decision tree using ml library in OpenCV.\n"
       << "Call:\n" << argv[0] << " <csv-file-path>\n\n"
       << "\nIf you don't enter a file, it defaults to ../mushroom/agaricus-lepiota.data\n"
       << endl;
}

int main(int argc, char *argv[]) {
  // jeli wywoujcy poda nazw pliku, to doskonale; w przeciwnym razie program uywa domylnej nazwy
  //
  const char *csv_file_name = argc >= 2 ? argv[1] : "../mushroom/agaricus-lepiota.data";
  cout << "Wersja OpenCV: " << CV_VERSION << endl;
  help(argv);

  // wczytanie podanego pliku CSV
  //
  cv::Ptr<cv::ml::TrainData> data_set =
      csv_file_name, // nazwa pliku wejciowego
    0,             // linie nagwkowe (tyle ignoruje)
    0,             // w tej kolumnie zaczynaj si odpowiedzi
    1,             // w tej kolumnie zaczynaj si dane wejciowe
    "cat[0-22]"    // wszystkie 23 kolumny to kategorie
  );               // uywa domylnych znakw rozdzielajcego (,) i braku danych (?)

  
  // sprawdzenie, czy wczytujemy waciwe dane
  //
  int n_samples = data_set->getNSamples();
  if (n_samples == 0) {
    cerr << "Nie udao si wczyta pliku: " << csv_file_name << endl;
    exit(-1);
  } else {
    cout << "Wczytano " << n_samples << " prbek z " << csv_file_name << endl;
  }

  // podzia danych w taki sposb, e 90% to dane szkoleniowe
  //
  data_set->setTrainTestSplitRatio(0.90, false);
  int n_train_samples = data_set->getNTrainSamples();
  int n_test_samples = data_set->getNTestSamples();
  cout << "Znaleziono " << n_train_samples << " prbek szkoleniowych i "
       << n_test_samples << " prbek testowych." << endl;

  // tworzenie klasyfikatora DTrees
  //
  cv::Ptr<cv::ml::RTrees> dtree = cv::ml::RTrees::create();
  // ustawienia parametrw
  //
  // to s parametry ze starego pliku mushrooms.cpp

  // ustawienie prawdopodobiestw a priori tak, aby kara grzyby trujce 10 razy bardziej ni jadalne
  //
  float _priors[] = {1.0, 10.0};
  cv::Mat priors(1, 2, CV_32F, _priors);
  dtree->setMaxDepth(8);
  dtree->setMinSampleCount(10);
  dtree->setRegressionAccuracy(0.01f);
  dtree->setUseSurrogates(false /* true */);
  dtree->setMaxCategories(15);
  dtree->setCVFolds(0 /*10*/); // warto rna od zera powoduje zrzut rdzenia
  dtree->setUse1SERule(true);
  dtree->setTruncatePrunedTree(true);
  // dtree->setPriors( priors );
  dtree->setPriors(cv::Mat()); // warto rna od zera powoduje zrzut rdzenia
  // szkolenie modelu
  // uwaga: wykorzystujemy tylko szkoleniow cz zbioru danych
  //

  dtree->train(data_set);

  // po zakoczeniu szkolenia powinnimy by w stanie obliczy skal bdw zarwno w odniesieniu do danych szkoleniowych,
  // jak i do odoonych danych testowych
  //
  cv::Mat results;
  float train_performance = dtree->calcError(data_set,
                                             false, // uyj danych szkoleniowych
                                             results // cv::noArray()
                                             );
  std::vector<cv::String> names;
  data_set->getNames(names);
  Mat flags = data_set->getVarSymbolFlags();

  // wasne obliczenia statystyczne:
  //
  {
    cv::Mat expected_responses = data_set->getResponses();
    int good = 0, bad = 0, total = 0;
    for (int i = 0; i < data_set->getNTrainSamples(); ++i) {
      float received = results.at<float>(i, 0);
      float expected = expected_responses.at<float>(i, 0);
      cv::String r_str = names[(int)received];
      cv::String e_str = names[(int)expected];
      cout << "Oczekiwano: " << e_str << ", otrzymano: " << r_str << endl;
      if (received == expected)
        good++;
      else
        bad++;
      total++;
    }
    cout << "Poprawne odpowiedzi: " <<(float(good)/total) <<" % " << endl;
                cout << "Niepoprawne odpowiedzi: " << (float(bad) / total) << "%"
         << endl;
  }
  float test_performance = dtree->calcError(data_set,
                                            true, // uyj danych testowych
                                            results // cv::noArray()
                                            );
  cout << "Skuteczno na danych szkoleniowych: " << train_performance << "%" << endl;
  cout << "Skuteczno na danych testowych: " <<test_performance <<" % " <<endl;
  return 0;
}
