# GRPO With Replay Buffer

This experimental trainer trains a model with GRPO, but replaces groups (and their corresponding completions) whose rewards have zero standard deviation with groups that had high rewards and high standard deviation and were already used for training in earlier batches.
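To see why zero-variance groups are worth replacing, recall that GRPO normalizes each reward against the mean and standard deviation of its group, so a group whose completions all receive the same reward produces all-zero advantages and contributes no learning signal. The snippet below is a minimal pure-Python sketch of that normalization, not the trainer's actual code; the helper name `group_advantages`, the `eps` value, and the use of the population standard deviation are assumptions made for illustration.

```python
import statistics

def group_advantages(rewards, num_generations, eps=1e-4):
    """Normalize rewards within each group of `num_generations` completions."""
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mean = statistics.fmean(group)
        # Population std (an assumption); it is 0.0 when all rewards are equal.
        std = statistics.pstdev(group)
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages

# Two prompts, four completions each: the first group is constant, the second is not.
rewards = [0.5, 0.5, 0.5, 0.5, 0.0, 1.0, 0.0, 1.0]
advantages = group_advantages(rewards, num_generations=4)
print(advantages[:4])  # all zeros: this group carries no gradient signal
print(advantages[4:])  # non-zero values of alternating sign
```

Groups like the first one are the ones this trainer swaps out for previously seen groups.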

## Usage

```python
import torch
from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferConfig, GRPOWithReplayBufferTrainer
from datasets import load_dataset

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

# Return identical rewards about 25% of the time so that some groups have 0 std
def custom_reward_func(completions, **kwargs):
    if torch.rand(1).item() < 0.25:
        return [0] * len(completions)  # constant rewards: zero std within every group
    else:
        return torch.rand(len(completions)).tolist()

training_args = GRPOWithReplayBufferConfig(
    output_dir="./tmp",
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    num_generations=4,
    max_completion_length=8,
    replay_buffer_size=8,
    report_to="none",
)

trainer = GRPOWithReplayBufferTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=[custom_reward_func],
    args=training_args,
    train_dataset=dataset,
)

# Optional: snapshot the initial parameters, e.g. to check after training that the weights changed
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()
```

## GRPOWithReplayBufferTrainer[[trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer]]

#### trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer[[trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer]]

[Source](https://github.com/huggingface/trl/blob/main/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py#L60)

#### train[[trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer.train]]

[Source](https://github.com/huggingface/trl/blob/main/transformers/trainer.py#L1325)

Main training entry point.

**Parameters:**

resume_from_checkpoint (`str` or `bool`, *optional*) : If a `str`, local path to a saved checkpoint as saved by a previous instance of `Trainer`. If a `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance of `Trainer`. If present, training will resume from the model/optimizer/scheduler states loaded here.

trial (`optuna.Trial` or `dict[str, Any]`, *optional*) : The trial run or the hyperparameter dictionary for hyperparameter search.

ignore_keys_for_eval (`list[str]`, *optional*) : A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.

**Returns:**

`~trainer_utils.TrainOutput`

Object containing the global step count, training loss, and metrics.

#### save_model[[trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer.save_model]]

[Source](https://github.com/huggingface/trl/blob/main/transformers/trainer.py#L3752)

Saves the model so that you can reload it using `from_pretrained()`.

Only saves from the main process.

#### push_to_hub[[trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer.push_to_hub]]

[Source](https://github.com/huggingface/trl/blob/main/transformers/trainer.py#L3999)

Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`.

**Parameters:**

commit_message (`str`, *optional*, defaults to `"End of training"`) : Message to commit while pushing.

blocking (`bool`, *optional*, defaults to `True`) : Whether the function should return only when the `git push` has finished.

token (`str`, *optional*, defaults to `None`) : Token with write permission to overwrite Trainer's original args.

revision (`str`, *optional*) : The git revision to commit from. Defaults to the head of the "main" branch.

kwargs (`dict[str, Any]`, *optional*) : Additional keyword arguments passed along to `~Trainer.create_model_card`.

**Returns:**

The URL of the repository where the model was pushed if `blocking=True`, or a `Future` object tracking the
progress of the commit if `blocking=False`.

## GRPOWithReplayBufferConfig[[trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferConfig]]

#### trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferConfig[[trl.experimental.grpo_with_replay_buffer.GRPOWithReplayBufferConfig]]

[Source](https://github.com/huggingface/trl/blob/main/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py#L21)

**New parameters:**

replay_buffer_size (`int`, *optional*, defaults to `0`) : Size of the replay buffer, a cache that stores the rollouts with the highest advantage scores and variance per group. If a new group has 0 variance, it is replaced with a group sampled from the replay buffer.
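
The replacement step described above can be sketched roughly as follows. This is an illustration under stated assumptions rather than the trainer's implementation: the function name `replace_zero_variance_groups`, the dict layout with `"prompt"` and `"rewards"` keys, and the plain list used as the buffer are all hypothetical.

```python
import random
import statistics

def replace_zero_variance_groups(groups, buffer, rng=random):
    """Return a new batch where constant-reward groups are swapped for buffered ones.

    `groups` and `buffer` are lists of dicts with a "rewards" list (hypothetical layout).
    """
    batch = []
    for group in groups:
        if statistics.pstdev(group["rewards"]) == 0 and buffer:
            batch.append(rng.choice(buffer))  # reuse an informative past group
        else:
            batch.append(group)  # keep groups that already have variance
    return batch

buffer = [{"prompt": "old", "rewards": [0.1, 0.9, 0.4, 0.7]}]
groups = [
    {"prompt": "a", "rewards": [0.0, 0.0, 0.0, 0.0]},  # zero variance: replaced
    {"prompt": "b", "rewards": [0.2, 0.8, 0.5, 0.3]},  # has variance: kept
]
batch = replace_zero_variance_groups(groups, buffer)
print([g["prompt"] for g in batch])
```

When the buffer is empty, this sketch leaves zero-variance groups in place; how the real trainer handles that case is not covered here.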

## ReplayBuffer[[trl.experimental.grpo_with_replay_buffer.ReplayBuffer]]

#### trl.experimental.grpo_with_replay_buffer.ReplayBuffer[[trl.experimental.grpo_with_replay_buffer.ReplayBuffer]]

[Source](https://github.com/huggingface/trl/blob/main/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py#L28)

A simple replay buffer to store and sample previously seen rollouts.
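
For intuition, a bounded buffer that keeps only the highest-scoring rollouts can be built on a min-heap. The class below is a hypothetical sketch, not the `ReplayBuffer` shipped with TRL: the name `TopKReplayBuffer`, its `add`/`sample` methods, and the choice to treat a size of 0 as "disabled" are assumptions made for illustration.

```python
import heapq
import itertools
import random

class TopKReplayBuffer:
    """Hypothetical bounded buffer that keeps the `max_size` highest-scoring items."""

    def __init__(self, max_size):
        self.max_size = max_size
        self._heap = []  # min-heap of (score, tie_breaker, item)
        self._counter = itertools.count()  # avoids comparing items on score ties

    def add(self, score, item):
        if self.max_size <= 0:
            return  # assumption: a size of 0 disables the buffer
        entry = (score, next(self._counter), item)
        if len(self._heap) < self.max_size:
            heapq.heappush(self._heap, entry)
        elif score > self._heap[0][0]:
            heapq.heapreplace(self._heap, entry)  # evict the lowest-scoring item

    def sample(self, k, rng=random):
        k = min(k, len(self._heap))  # never ask for more items than are stored
        return [item for _, _, item in rng.sample(self._heap, k)]

    def __len__(self):
        return len(self._heap)

buf = TopKReplayBuffer(max_size=2)
for score, name in [(0.1, "low"), (0.9, "high"), (0.5, "mid")]:
    buf.add(score, name)
print(sorted(buf.sample(2)))
```

A heap keeps eviction of the weakest entry cheap, which matters when every training step may try to add new groups.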

