 Copyright © Sorbonne University.

 This source code is licensed under the MIT license found in the LICENSE file
 in the root directory of this source tree.

# Outlook
In this notebook, using BBRL, we code a simple version of the DQN algorithm
with neither a replay buffer nor a target network, so as to better understand
its inner mechanisms.

To understand this code, you first need to know more about [the BBRL interaction
model](https://github.com/osigaud/bbrl/blob/master/docs/overview.md). Then you
should run [a didactic
example](https://github.com/osigaud/bbrl/blob/master/docs/notebooks/02-multi_env_noautoreset.student.ipynb)
to see how agents interact in BBRL when `autoreset=False`.

The DQN algorithm is explained in [this
video](https://www.youtube.com/watch?v=CXwvOMJujZk) and you can also read [the
corresponding slides](http://pages.isir.upmc.fr/~sigaud/teach/dqn.pdf).

In [None]:
# Prepare the environment
try:
    from easypip import easyimport
except ModuleNotFoundError:
    from subprocess import run
    assert run(["pip", "install", "easypip"]).returncode == 0, "Could not install easypip"
    from easypip import easyimport

easyimport("swig")
easyimport("bbrl_utils").setup(maze_mdp=True)

import os

import bbrl_gymnasium  # noqa: F401
import torch
import torch.nn as nn
from bbrl.agents import Agent, Agents
from bbrl_utils.algorithms import EpisodicAlgo
from bbrl_utils.nn import build_mlp, setup_optimizer
from bbrl_utils.notebook import setup_tensorboard
from omegaconf import OmegaConf

# Learning environment

## Configuration

The learning environment is controlled by a configuration that defines a few
important things, as described in the example below. This configuration can
hold as much extra information as you need; the example below is the minimal
one.

```python
params = {
    # This defines a path for logs and saved models
    "base_dir": "${gym_env.env_name}/myalgo_${current_time:}",

    # The Gymnasium environment
    "gym_env": {
        "env_name": "CartPoleContinuous-v1",
    },

    # Algorithm
    "algorithm": {
        # Seed used for the random number generator
        "seed": 1023,

        # Number of parallel training environments
        "n_envs": 8,
                
        # Minimum number of steps between two evaluations
        "eval_interval": 500,
        
        # Number of parallel evaluation environments
        "nb_evals": 10,

        # Number of epochs (loops)
        "max_epochs": 40000,

    },
}

# Creates the configuration object, i.e. cfg.algorithm.nb_evals is 10
cfg = OmegaConf.create(params)
```

## The RL algorithm

In this notebook, the RL algorithm is based on `EpisodicAlgo`, which defines
the algorithm environment when using episodes. To use such an environment, we
just need to subclass `EpisodicAlgo` and define two things, namely the
`train_policy` and the `eval_policy`. Both are BBRL agents that, given the
environment state, select the action to perform.

```py
class MyAlgo(EpisodicAlgo):
    def __init__(self, cfg):
        super().__init__(cfg)

        # Define the train and evaluation policies
        # (the agents compute the workspace `action` variable)
        self.train_policy = MyPolicyAgent(...)
        self.eval_policy = MyEvalAgent(...)

algo = MyAlgo(cfg)
```

The `EpisodicAlgo` defines useful objects:

- `algo.cfg` is the configuration
- `algo.nb_steps` (integer) is the number of steps since the training began
- `algo.logger` is a logger that can be used to collect statistics during training:
    - `algo.logger.add_log("critic_loss", critic_loss, algo.nb_steps)` registers the `critic_loss` value on tensorboard
- `algo.evaluate()` evaluates the current `eval_policy` if needed, and keeps the
  agent if it is the best so far (in terms of average cumulative reward)
- `algo.visualize_best()` runs the best agent on one episode, and displays the video



It also defines an `iter_episodes` method, whose usage is simple:

```py
# With episodes
for workspace in algo.iter_episodes():
    # workspace is a workspace containing transitions
    # Episodes shorter than the longest one contain duplicated
    # transitions (with `env/done` set to true)
    ...
```
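
The padding mentioned in the comments above can be pictured with plain Python lists. This is only an illustrative sketch: `pad_episodes` is a hypothetical helper, not part of BBRL, and the assumption that padding repeats the final transition of each episode is taken from the comment above rather than from the library code.

```python
def pad_episodes(episodes):
    """Pad every episode to the length of the longest one.

    Each episode is a list of (observation, done) pairs. Shorter episodes are
    extended by repeating their final transition, whose `done` flag is True.
    """
    horizon = max(len(episode) for episode in episodes)
    padded = []
    for episode in episodes:
        last_transition = episode[-1]
        padded.append(episode + [last_transition] * (horizon - len(episode)))
    return padded


episodes = [
    [("s0", False), ("s1", False), ("s2", True)],  # 3 steps
    [("s0", False), ("s1", True)],                 # 2 steps, will be padded
]
padded = pad_episodes(episodes)
print(padded[1])
```

After padding, all episodes have the same length, so they can be stacked into tensors with a common time dimension.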

## Definition of agents

The [DQN](https://daiwk.github.io/assets/dqn.pdf) algorithm is a critic-only
algorithm. Thus we just need a Critic agent (which is also used to output
actions) and an Environment agent.

### The critic agent

The critic agent is an instance of the `DiscreteQAgent` class. We first build
a deterministic neural network that takes the state as input (so it has one
input neuron per state variable) and that outputs the Q-value of each action
in that state (so it has one output neuron per action).

Like any BBRL agent, the `DiscreteQAgent` has a `forward()` function that takes a
time step as input. This `forward()` function outputs the Q-values of all
actions at the corresponding time step. Additionally, if the critic is used to
choose an action, it also outputs the chosen action at the same time step.

In [None]:
class DiscreteQAgent(Agent):
    """BBRL agent (discrete actions) based on a MLP"""

    def __init__(self, state_dim, hidden_layers, action_dim):
        super().__init__()
        self.model = build_mlp(
            [state_dim] + list(hidden_layers) + [action_dim], activation=nn.ReLU()
        )

    def forward(self, t: int, **kwargs):
        """An Agent can use self.workspace"""

        # Retrieves the observation from the environment at time t
        obs = self.get(("env/env_obs", t))

        # Computes the critic (Q) values for the observation
        q_values = self.model(obs)

        # ... and sets the q-values (one for each possible action)
        self.set(("q_values", t), q_values)

#### Greedily choosing the action

The ArgmaxActionSelector is in charge of choosing the action whose Q-value is
the highest given the Q-values of all actions. We may use it when we do not
want to explore.

In [None]:
class ArgmaxActionSelector(Agent):
    """BBRL agent that selects the best action based on Q(s,a)"""

    def forward(self, t: int, **kwargs):
        q_values = self.get(("q_values", t))
        action = q_values.argmax(1)
        self.set(("action", t), action)

### Creating an Exploration method

Like Q-learning, DQN needs some exploration to prevent premature convergence.
Here we use the simple $\epsilon$-greedy exploration method.
It is implemented as an agent which chooses an action based on the Q-values.

In [None]:
class EGreedyActionSelector(Agent):
    def __init__(self, epsilon):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, t: int, **kwargs):
        # Retrieves the q values
        # (matrix nb. of episodes x nb. of actions)
        q_values: torch.Tensor = self.get(("q_values", t))
        size, nb_actions = q_values.shape

        # Flag: True (with probability epsilon) where a random action is taken
        is_random = torch.rand(size) < self.epsilon
        
        # Actions (random / argmax)
        random_action = torch.randint(nb_actions, size=(size,))
        max_action = q_values.argmax(-1)

        # Choose the action based on the is_random flag
        action = torch.where(is_random, random_action, max_action)

        # Sets the action at time t
        self.set(("action", t), action)
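
The selection rule can also be checked outside BBRL with a minimal plain-Python sketch. The names `epsilon_greedy` and `rng` are hypothetical and only serve the illustration; the sketch uses the standard `random` module instead of torch tensors.

```python
import random


def epsilon_greedy(q_rows, epsilon, rng):
    """Pick one action per row of Q-values.

    With probability epsilon the action is drawn uniformly at random,
    otherwise it is the action with the highest Q-value.
    """
    actions = []
    for q in q_rows:
        if rng.random() < epsilon:
            actions.append(rng.randrange(len(q)))
        else:
            actions.append(max(range(len(q)), key=q.__getitem__))
    return actions


q_rows = [[0.1, 0.9], [0.7, 0.2]]
# With epsilon = 0 the choice is always greedy
print(epsilon_greedy(q_rows, 0.0, random.Random(0)))  # [1, 0]
```

With `epsilon = 0.1`, roughly one action out of ten is random, which is the behaviour expected from the agent above.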

## Heart of the algorithm

### Computing the critic loss

The role of the `compute_critic_loss` function is to implement the Bellman
backup rule. In Q-learning, this rule was written:

$$Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \alpha [ r(s_t,a_t) + \gamma \max_a
Q(s_{t+1},a) - Q(s_t,a_t)]$$

In DQN, the update rule $Q \leftarrow Q + \alpha \delta$ is replaced by a
gradient descent step over the Q-network.

We first compute a target value $\text{target} = r(s_t,a_t) + \gamma \max_a
Q(s_{t+1},a)$ from a set of samples.

Then we get a TD error $\delta$ by subtracting $Q(s_t,a_t)$ from the target
for these samples, and we use the squared TD error as a loss function:
$\text{loss} = (\text{target} - Q(s_t,a_t))^2$.
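
As a quick sanity check, here is the same computation for a single transition in plain Python. All the numbers are made up for the illustration.

```python
gamma = 0.99          # discount factor
reward = 1.0          # r(s_t, a_t)
q_next = [0.5, 2.0]   # Q(s_{t+1}, a) for each action a
q_taken = 1.5         # Q(s_t, a_t)

target = reward + gamma * max(q_next)  # about 2.98
td_error = target - q_taken            # about 1.48
loss = td_error ** 2                   # about 2.19

print(f"target={target:.2f}, td_error={td_error:.2f}, loss={loss:.4f}")
```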

To implement the above calculation in BBRL, the difficulty is to properly deal
with time indexes.

The `compute_critic_loss` function receives rewards, q_values and actions as
tensors that have been computed over a complete episode.

We need to take `reward[1:]`, which means all the rewards except the first one
because the reward from $(s_t, a_t)$ is $r_{t+1}$. Similarly, to get $\max_a
Q(s_{t+1}, a)$, we need to ignore the first of the max_q values, using
`max_q[1:]`.

Do not forget to apply `.detach()` when computing the values of $\max_a
Q(s_{t+1}, a)$, as **we do not want to apply gradient descent on this $\max_a
Q(s_{t+1}, a)$**, we only apply gradient descent to $Q(s_t, a_t)$ according to
this target value. In practice, `x.detach()` returns a tensor that holds the
same values as `x` but is cut off from the computation graph, so no gradient
flows back through it.
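
To see what treating the target as a constant changes, it can help to write the gradient of the loss explicitly. The lines below are a sketch using notation that is not in the rest of this notebook: $\theta$ denotes the parameters of the Q-network, $y$ the detached target and $\eta$ the step size of a plain gradient descent step (the notebook itself uses Adam).

$$\nabla_\theta \big(y - Q_\theta(s_t,a_t)\big)^2 = -2\,\big(y - Q_\theta(s_t,a_t)\big)\,\nabla_\theta Q_\theta(s_t,a_t)$$

$$\theta \leftarrow \theta + 2\eta\,\delta\,\nabla_\theta Q_\theta(s_t,a_t), \qquad \delta = y - Q_\theta(s_t,a_t)$$

The step moves $Q_\theta(s_t,a_t)$ towards the target in proportion to $\delta$, as the tabular rule does. If $y$ were not detached, the gradient would contain the extra term $2\,\delta\,\nabla_\theta y$, which would also push the target itself towards the current estimate.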

The `must_bootstrap` tensor is used as a trick to deal with terminal states,
as explained
[here](https://github.com/osigaud/bbrl/blob/master/docs/time_limits.md). In
practice, `must_bootstrap` is the logical negation of `terminated`. In the
`autoreset=False` version we use full episodes, thus `must_bootstrap` is always
True for all steps but the last one.
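
The effect of `must_bootstrap` on the target can be sketched in a few lines of plain Python. The function `td_target` is a hypothetical helper written for this illustration and the values are arbitrary.

```python
def td_target(reward, max_q_next, terminated, gamma=0.99):
    """Target value for one transition.

    If the next state is terminal there is no future return to add,
    so the bootstrap term is multiplied by 0.
    """
    must_bootstrap = not terminated
    return reward + gamma * max_q_next * must_bootstrap


# The episode really ended: the target is the reward alone
print(td_target(1.0, 10.0, terminated=True))   # 1.0
# The episode goes on, or was only cut by a time limit: we bootstrap
print(td_target(1.0, 10.0, terminated=False))
```

An episode that stops because of a time limit is not terminated, so its last transition still bootstraps from the next state value.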

To compute $Q(s_t,a_t)$ we use the [`torch.gather()`](https://pytorch.org/docs/stable/generated/torch.gather.html) function. This function is
a little tricky to use, see [this
page](https://github.com/osigaud/bbrl/blob/master/docs/using_gather.md) for
useful explanations.

In particular, `gather` returns one Q-value per time step, including the last
one, whereas the target is only defined for the first T-1 steps (it needs a
next state). Hence the need for `qvals[:-1]` (we ignore the last time step).
Finally we just need to compute the difference `target - qvals[:-1]`, square
it, take the mean and return it as the loss.
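
The time indexing described above can be traced with nested Python lists in place of tensors. This is a didactic sketch, not the solution expected in the cell below: `dqn_loss` is a hypothetical name, the loops replace the tensor slices and `gather`, and the `done` mask for padded transitions is deliberately ignored.

```python
def dqn_loss(reward, must_bootstrap, q_values, action, gamma):
    """Mean squared TD error over all transitions.

    reward, must_bootstrap, action: T lists of B values.
    q_values: T lists of B lists of A values.
    """
    n_steps, n_episodes = len(reward), len(reward[0])
    squared_errors = []
    for t in range(n_steps - 1):  # the last step has no successor
        for b in range(n_episodes):
            # Counterpart of gather followed by [:-1]: Q(s_t, a_t)
            q_taken = q_values[t][b][action[t][b]]
            # Counterpart of max_q[1:]: max_a Q(s_{t+1}, a)
            max_q_next = max(q_values[t + 1][b])
            # Counterpart of reward[1:]: the reward is stored at index t + 1
            target = reward[t + 1][b] + gamma * max_q_next * must_bootstrap[t + 1][b]
            squared_errors.append((target - q_taken) ** 2)
    return sum(squared_errors) / len(squared_errors)


# One episode (B = 1), three time steps (T = 3), two actions (A = 2)
reward = [[0.0], [1.0], [1.0]]
must_bootstrap = [[True], [True], [False]]
q_values = [[[1.0, 2.0]], [[0.5, 1.5]], [[3.0, 4.0]]]
action = [[1], [0], [0]]

loss = dqn_loss(reward, must_bootstrap, q_values, action, gamma=0.5)
print(loss)  # 0.15625
```

The first transition gives a target of 1.75 against a Q-value of 2.0. The second one ends in a terminal state, so its target is the reward 1.0 against a Q-value of 0.5.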

In [None]:
def compute_critic_loss(
    cfg,
    reward: torch.Tensor,
    must_bootstrap: torch.Tensor,
    done: torch.Tensor,
    q_values: torch.Tensor,
    action: torch.LongTensor,
) -> torch.Tensor:
    """Compute the temporal difference loss from a dataset to
    update a critic

    For the tensor dimensions:

    - T = maximum number of time steps
    - B = number of episodes run in parallel
    - A = number of possible actions

    :param cfg: The configuration
    :param reward: A (T x B) tensor containing the rewards
    :param must_bootstrap: a (T x B) tensor containing 0 at (t, b) if the
        episode b was terminated at time $t$ (or before)
    :param done: a (T x B) tensor containing True at (t, b) if the
        episode b is truncated or terminated at time $t$ (or before)
    :param q_values: a (T x B x A) tensor containing the Q-values at each time
        step, and for each action
    :param action: a (T x B) long tensor containing the chosen action

    :return: The DQN loss
    """
    # We compute the max of Q-values over all actions and detach (so that this
    # part of the computation graph is not included in the gradient
    # backpropagation)

    # Compute the loss

    assert False, 'Not implemented yet'


    return critic_loss

## Main training loop

Note that everything about the shared workspace between all the agents is
completely hidden under the hood. This results in a gain of productivity, at
the expense of having to dig into the BBRL code if you want to understand the
details, change the multiprocessing model, etc.

The next cell defines an `EpisodicDQN` class that handles the various parts of
the training loop.

In [None]:
class EpisodicDQN(EpisodicAlgo):
    def __init__(self, cfg):
        super().__init__(cfg)

        # Get the observation / action state space dimensions
        obs_size, act_size = self.train_env.get_obs_and_actions_sizes()

        # Our discrete Q-Agent
        self.q_agent = DiscreteQAgent(
            obs_size, cfg.algorithm.architecture.hidden_size, act_size
        )

        # The e-greedy strategy (when training)
        explorer = EGreedyActionSelector(cfg.algorithm.epsilon)

        # The training agent combines the Q agent and the explorer
        self.train_policy = Agents(self.q_agent, explorer)

        # The optimizer for the Q-Agent parameters
        self.optimizer = setup_optimizer(self.cfg.optimizer, self.q_agent)

        # ...and the evaluation policy (select the action with the highest Q-value)
        self.eval_policy = Agents(self.q_agent, ArgmaxActionSelector())

    def run(self):
        for train_workspace in self.iter_episodes():
            q_values, terminated, done, reward, action = train_workspace[
                "q_values", "env/terminated", "env/done", "env/reward", "action"
            ]

            # Determines whether values of the critic should be propagated
            # True if the episode reached a time limit or if the task was not done
            # See https://github.com/osigaud/bbrl/blob/master/docs/time_limits.md
            must_bootstrap = ~terminated

            # Compute critic loss
            critic_loss = compute_critic_loss(
                self.cfg, reward, must_bootstrap, done, q_values, action
            )

            # Store the loss for tensorboard display
            self.logger.add_log("critic_loss", critic_loss, self.nb_steps)
            self.logger.add_log("q_values/min", q_values.max(-1).values.min(), self.nb_steps)
            self.logger.add_log("q_values/max", q_values.max(-1).values.max(), self.nb_steps)
            self.logger.add_log("q_values/mean", q_values.max(-1).values.mean(), self.nb_steps)

            # Gradient step
            self.optimizer.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.q_agent.parameters(), self.cfg.algorithm.max_grad_norm
            )
            self.optimizer.step()

            # Evaluate the current policy (if needed)
            self.evaluate()
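
The gradient step above clips the gradient norm before calling the optimizer. The sketch below shows the idea on a flat list of gradient values. `clip_by_global_norm` is a hypothetical helper; the small `eps` term is an assumption that mirrors, as far as we know, what `torch.nn.utils.clip_grad_norm_` does.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale the gradients so that their global L2 norm is at most max_norm.

    Returns the rescaled gradients and the norm before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=0.5)
print(norm_before)             # 5.0
print(math.hypot(*clipped))    # slightly below 0.5
```

Gradients whose norm is already below `max_norm` are left unchanged, so clipping only limits unusually large updates.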

In [None]:
# We setup tensorboard before running DQN
setup_tensorboard("./outputs/tblogs")

In [None]:
params = {
    "save_best": False,
    "base_dir": "${gym_env.env_name}/dqn-simple-S${algorithm.seed}_${current_time:}",
    "collect_stats": True,
    "algorithm": {
        "seed": 3,
        "max_grad_norm": 0.5,
        "epsilon": 0.1,
        "n_envs": 8,
        "eval_interval": 5_000,
        "max_epochs": 500,
        "nb_evals": 10,
        "discount_factor": 0.99,
        "architecture": {"hidden_size": [256, 256]},
    },
    "gym_env": {
        "env_name": "CartPole-v1",
    },
    "optimizer": {
        "classname": "torch.optim.Adam",
        "lr": 2e-3,
    },
}

dqn = EpisodicDQN(OmegaConf.create(params))

In [None]:
# Run and visualize the best agent
dqn.run()
dqn.visualize_best()

## What's next?

To get a full DQN, we need to do the following:
- Add a replay buffer. We can add a replay buffer independently from the
  target network. The version with a replay buffer and no target network
  corresponds to [the NFQ
  algorithm](https://link.springer.com/content/pdf/10.1007/11564096_32.pdf).
  This will be the aim of the next notebook.
- Before adding the replay buffer, we will first move to a version of DQN
  which uses the AutoResetGymAgent. This will also be done in the next
  notebook.
- We should also add a few extra mechanisms which are present in the full DQN
  version: starting to learn once the replay buffer is full enough, decreasing
  the exploration rate epsilon...
<!-- - We could also add visualization tools to visualize the learned Q network, by using the `plot_critic` function available in [`bbrl.visu.visu_critics`](https://github.com/osigaud/bbrl/blob/master/src/bbrl/visu/visu_critics.py#L13) -->