{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
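        "\n# rlberry PPO on Farm1 from farm-gym-games\n\nThis example trains an rlberry PPO agent on the `Farm1-v0` environment from farm-gym-games, then plots its learning curve and the actions it takes during one evaluation episode. It requires both farm-gym-games and rlberry; install them before running it:\n\n```bash\npip install git+https://github.com/rlberry-py/rlberry.git\npip install git+https://github.com/farm-gym/farm-gym-games\n```\n"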
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import farmgym_games\n",
        "import numpy as np\n",
        "from rlberry.envs import gym_make\n",
        "\n",
        "from rlberry.agents.torch import PPOAgent\n",
        "from rlberry.manager import AgentManager, evaluate_agents, plot_writer_data\n",
        "from rlberry.agents.torch.utils.training import model_factory_from_env\n",
        "\n",
        "import gymnasium as gym\n",
        "\n",
        "import pandas as pd\n",
        "import seaborn as sns\n",
        "import time\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Policy (actor) network configuration, passed to PPO as policy_net_kwargs.\n",
        "policy_configs = {\n",
        "    \"type\": \"MultiLayerPerceptron\",  # A network architecture\n",
        "    \"layer_sizes\": (256, 256),  # Network dimensions\n",
        "    \"reshape\": False,\n",
        "    \"is_policy\": True,\n",
        "}\n",
        "\n",
        "# Value (critic) network configuration, passed to PPO as value_net_kwargs;\n",
        "# it has a single output unit.\n",
        "value_configs = {\n",
        "    \"type\": \"MultiLayerPerceptron\",\n",
        "    \"layer_sizes\": (256, 256),\n",
        "    \"reshape\": False,\n",
        "    \"out_size\": 1,\n",
        "}\n",
        "\n",
        "# Human-readable label for each integer action chosen by the agent.\n",
        "# NOTE: assumed to follow Farm1-v0's action order - verify against farmgym_games.\n",
        "actions_txt = [\n",
        "    \"doing nothing\",\n",
        "    \"1L of water\",\n",
        "    \"5L of water\",\n",
        "    \"harvesting\",\n",
        "    \"sow some seeds\",\n",
        "    \"scatter fertilizer\",\n",
        "    \"scatter herbicide\",\n",
        "    \"scatter pesticide\",\n",
        "    \"remove weeds by hand\",\n",
        "]\n",
        "\n",
        "# Farm1-v0 is created through the GymV26Environment-v0 compatibility env, so it\n",
        "# follows the Gymnasium API: reset() -> (obs, info) and\n",
        "# step() -> (obs, reward, terminated, truncated, info).\n",
        "env_ctor, env_kwargs = gym_make, {\"id\": \"GymV26Environment-v0\", \"env_id\": \"Farm1-v0\"}\n",
        "\n",
        "HORIZON = 365  # number of simulated days per episode\n",
        "MAX_ROLLOUTS = 100  # give up searching for a rewarding episode after this many tries\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    # The __main__ guard keeps the \"spawn\"-ed worker processes from re-running\n",
        "    # this script when it is executed as a plain .py file.\n",
        "    manager = AgentManager(\n",
        "        PPOAgent,\n",
        "        (env_ctor, env_kwargs),\n",
        "        agent_name=\"PPOAgent\",\n",
        "        init_kwargs=dict(\n",
        "            policy_net_fn=model_factory_from_env,\n",
        "            policy_net_kwargs=policy_configs,\n",
        "            value_net_fn=model_factory_from_env,\n",
        "            value_net_kwargs=value_configs,\n",
        "            learning_rate=9e-5,\n",
        "            n_steps=5 * HORIZON,\n",
        "            batch_size=HORIZON,\n",
        "            eps_clip=0.1,\n",
        "        ),\n",
        "        fit_budget=5e5,\n",
        "        eval_kwargs=dict(eval_horizon=HORIZON),\n",
        "        n_fit=1,\n",
        "        parallelization=\"process\",\n",
        "        mp_context=\"spawn\",\n",
        "        enable_tensorboard=True,\n",
        "        seed=42,\n",
        "    )\n",
        "\n",
        "    init_time = time.time()\n",
        "    manager.fit()\n",
        "    print(\"training time in s is \", time.time() - init_time)\n",
        "\n",
        "    # Learning curve: the \"episode_rewards\" logged during training, smoothed tensorboard-style.\n",
        "    fig, ax = plt.subplots(figsize=(12, 6))\n",
        "    plot_writer_data(manager, tag=\"episode_rewards\", smooth_weight=0.8, ax=ax)\n",
        "    fig.savefig(\"ppo_regret.pdf\")\n",
        "\n",
        "    agent = manager.agent_handlers[0]  # select the agent from the manager (n_fit=1)\n",
        "    # Build the evaluation env exactly like the training env; a bare\n",
        "    # gym_make(\"Farm1-v0\") would bypass the GymV26 wrapper used for training.\n",
        "    env = env_ctor(**env_kwargs)\n",
        "\n",
        "    # Replay episodes until one collects a total reward >= 1, trying at most\n",
        "    # MAX_ROLLOUTS times so a policy that never succeeds cannot hang the notebook.\n",
        "    for attempt in range(MAX_ROLLOUTS):\n",
        "        rew = 0\n",
        "        records = []  # one dict per simulated day, turned into a DataFrame once below\n",
        "        obs, _ = env.reset()\n",
        "        for day in range(HORIZON):\n",
        "            action = agent.policy(obs)\n",
        "            obs, reward, terminated, truncated, _ = env.step(action)\n",
        "            records.append({\"action\": actions_txt[int(action)], \"reward\": reward})\n",
        "            rew = rew + reward\n",
        "            if terminated:\n",
        "                print(\"Plant is Dead\")\n",
        "                break\n",
        "            if truncated:\n",
        "                break\n",
        "\n",
        "        print(rew)\n",
        "        if rew >= 1:\n",
        "            break\n",
        "    else:\n",
        "        print(f\"no episode reached a total reward >= 1 in {MAX_ROLLOUTS} attempts; plotting the last one\")\n",
        "\n",
        "    episode = pd.DataFrame(records, columns=[\"action\", \"reward\"])\n",
        "\n",
        "    # How often each action was taken during the recorded episode.\n",
        "    fig, ax = plt.subplots(figsize=(12, 6))\n",
        "    sns.countplot(data=episode, x=\"action\", order=episode[\"action\"].value_counts().index, ax=ax)\n",
        "    fig.savefig(\"ppo_barplot.pdf\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}