#!/usr/bin/env python3
import argparse
import gym
import pybullet_envs

from lib import model

import numpy as np
import torch


# Default Gym environment id used when -e/--env is not given.
ENV_ID = "MinitaurBulletEnv-v0"


def play_episode(env, net):
    """Play a single episode, always taking the mean action predicted by *net*.

    Args:
        env: Environment exposing ``reset()`` and a ``step(action)`` that
            returns an ``(obs, reward, done, info)`` tuple.
        net: Callable taking a float32 tensor holding a batch of one
            observation and returning a 3-tuple whose first element is the
            action mean for that batch.

    Returns:
        Tuple ``(total_reward, total_steps)`` accumulated until the
        environment reports ``done``.
    """
    obs = env.reset()
    total_reward = 0.0
    total_steps = 0
    done = False
    # Pure inference: no autograd graph is needed for any forward pass.
    with torch.no_grad():
        while not done:
            # Convert through one ndarray instead of a Python list of
            # ndarrays, then add the batch dimension explicitly.
            obs_v = torch.as_tensor(np.asarray(obs), dtype=torch.float32).unsqueeze(0)
            # Only the mean is used; the other two outputs are ignored.
            mu_v, _var_v, _val_v = net(obs_v)
            action = mu_v.squeeze(dim=0).detach().cpu().numpy()
            # Clamp every action component to [-1, 1] before stepping.
            action = np.clip(action, -1, 1)
            obs, reward, done, _ = env.step(action)
            total_reward += reward
            total_steps += 1
    return total_reward, total_steps


def main():
    """Parse the command line, load the model, play one episode and print the result."""
    # NOTE(review): the Polish help/output strings below look like they lost
    # their diacritics (probably an encoding problem upstream). They are left
    # unchanged here because they are user-facing runtime text.
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", required=True, help="Nazwa pliku z modelem")
    parser.add_argument("-e", "--env", default=ENV_ID, help="Nazwa rodowiska, domylnie=" + ENV_ID)
    parser.add_argument("-r", "--record", help="Jeli podano, definiuje katalog sucy do zapisywania danych, domylnie=brak katalogu")
    args = parser.parse_args()

    # Set the environment's 'render' kwarg to False before it is created.
    spec = gym.envs.registry.spec(args.env)
    spec._kwargs['render'] = False
    env = gym.make(args.env)
    try:
        if args.record:
            env = gym.wrappers.Monitor(env, args.record)

        net = model.ModelA2C(env.observation_space.shape[0], env.action_space.shape[0])
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # machine. torch.load unpickles data: only load trusted model files.
        net.load_state_dict(torch.load(args.model, map_location="cpu"))
        # Inference mode; only matters if the model has dropout/batch-norm.
        net.eval()

        total_reward, total_steps = play_episode(env, net)
    finally:
        # Always release the environment (and any wrapper around it).
        env.close()
    print("Po %d krokach uzyskano nagrod %.3f" % (total_steps, total_reward))


if __name__ == "__main__":
    main()
