{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# rlberry DQN on Farm1 from farm-gym-games\n\nThis example uses farm-gym-games and rlberry; please install both libraries before running it:\n\n```\npip install git+https://github.com/rlberry-py/rlberry.git\npip install git+https://github.com/farm-gym/farm-gym-games\n```\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import farmgym_games\n",
        "import numpy as np\n",
        "from rlberry.envs import gym_make\n",
        "\n",
        "from rlberry.agents.torch import DQNAgent\n",
        "from rlberry.manager import AgentManager, evaluate_agents, plot_writer_data\n",
        "from rlberry.agents.torch.utils.training import model_factory_from_env\n",
        "\n",
        "import pandas as pd\n",
        "import seaborn as sns\n",
        "import time\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "\n",
        "# Human-readable labels for the discrete actions; the list is indexed by the\n",
        "# action value returned by the agent's policy.\n",
        "actions_txt = [\n",
        "    \"doing nothing\",\n",
        "    \"1L of water\",\n",
        "    \"5L of water\",\n",
        "    \"harvesting\",\n",
        "    \"sow some seeds\",\n",
        "    \"scatter fertilizer\",\n",
        "    \"scatter herbicide\",\n",
        "    \"scatter pesticide\",\n",
        "    \"remove weeds by hand\",\n",
        "]\n",
        "\n",
        "# Maximum number of steps (one per day) in an episode; shared by the\n",
        "# evaluation horizon and by the rollout loop below.\n",
        "EPISODE_DAYS = 365\n",
        "\n",
        "env_ctor, env_kwargs = gym_make, {\"id\": \"Farm1-v0\"}\n",
        "\n",
        "# NOTE(review): the guard presumably matters because the manager is configured\n",
        "# with process parallelization and the \"spawn\" start method, which re-imports\n",
        "# the main module in worker processes.\n",
        "if __name__ == \"__main__\":\n",
        "    # --- Train a single DQN agent on Farm1-v0 ---\n",
        "    manager = AgentManager(\n",
        "        DQNAgent,\n",
        "        (env_ctor, env_kwargs),\n",
        "        agent_name=\"DQNAgent\",\n",
        "        init_kwargs=dict(\n",
        "            learning_rate=9e-5,\n",
        "        ),\n",
        "        fit_budget=5e5,\n",
        "        eval_kwargs=dict(eval_horizon=EPISODE_DAYS),\n",
        "        n_fit=1,\n",
        "        parallelization=\"process\",\n",
        "        mp_context=\"spawn\",\n",
        "        enable_tensorboard=True,\n",
        "        seed=42,\n",
        "    )\n",
        "\n",
        "    init_time = time.time()\n",
        "    manager.fit()\n",
        "    print(\"training time in s is \", time.time() - init_time)\n",
        "\n",
        "    # --- Plot the smoothed episode rewards logged during training ---\n",
        "    fig, ax = plt.subplots(figsize=(12, 6))\n",
        "    data = plot_writer_data(manager, tag=\"episode_rewards\", smooth_weight=0.8, ax=ax)  # smoothing tensorboard-style\n",
        "    fig.savefig(\"dqn_regret.pdf\")\n",
        "\n",
        "    # --- Roll out the trained policy ---\n",
        "    agent = manager.agent_handlers[0]  # select the agent from the manager\n",
        "    env = env_ctor(**env_kwargs)  # same environment the agent was trained on\n",
        "\n",
        "    # Repeat episodes until one reaches a total reward of at least 1; the\n",
        "    # actions of that last episode are the ones shown in the bar plot.\n",
        "    rew = 0\n",
        "    while rew < 1:\n",
        "        rew = 0\n",
        "        obs = env.reset()\n",
        "        # Collect one record per step and build the DataFrame once, instead of\n",
        "        # growing it with pd.concat inside the loop (quadratic, and\n",
        "        # concatenating an empty frame is deprecated in recent pandas).\n",
        "        steps = []\n",
        "        for day in range(EPISODE_DAYS):\n",
        "            action = agent.policy(obs)\n",
        "            obs, reward, is_done, _ = env.step(action)\n",
        "            steps.append({\"action\": actions_txt[action], \"reward\": reward})\n",
        "            rew = rew + reward\n",
        "            if is_done:\n",
        "                print(\"Plant is Dead\")\n",
        "                break\n",
        "        episode = pd.DataFrame(steps, columns=[\"action\", \"reward\"])\n",
        "\n",
        "        print(rew)\n",
        "\n",
        "    # --- Bar plot of how often each action was taken in the kept episode ---\n",
        "    fig, ax = plt.subplots(figsize=(12, 6))\n",
        "    sns.countplot(data=episode, x=\"action\", order=episode[\"action\"].value_counts().index, ax=ax)\n",
        "    fig.savefig(\"dqn_barplot.pdf\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}