⚡ SheepRLHF 🐑

📖 What

An easy-to-use framework for training large language models with reinforcement learning in PyTorch, accelerated with Lightning Fabric. The framework is designed to be modular and extensible. It supports different RL algorithms, different tasks, and different datasets.

Quick start

git clone https://github.com/Eclectic-Sheep/sheeprlhf.git && cd sheeprlhf
pip install -e "."
# Launch SFT training for the summarization task with the OPT model, accelerated by Lightning Fabric on GPU
python sheeprlhf.py train data=summarization model=opt task=sft fabric=auto_cuda
# Optionally, install the evaluation extras and evaluate the fine-tuned model
pip install -e ".[eval]"
python sheeprlhf.py eval task=perplexity experiment_dir=<path_to_sft_experiment>

This fine-tunes the OPT model on the summarization dataset. The dataset and model are downloaded first, then training starts, accelerated with Lightning Fabric, and all metrics are logged locally with TensorBoard.
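Since metrics are logged with TensorBoard, you can inspect them by pointing TensorBoard at the experiment output directory (the exact path depends on your run configuration; the placeholder below is only illustrative):

# <path_to_experiment_logs> is the output directory created by the run
tensorboard --logdir <path_to_experiment_logs>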

Configurations

Here are the configurations available out of the box:

| Dataset Name | Config Name |
| --- | --- |
| CarperAI/openai_summarize_comparisons | summarization |
| Dahoas/full-hh-rlhf | helpful_harmless |

| Model Name | Config Name |
| --- | --- |
| OPT | opt |
| GPT2 | gpt2 |
| Phi | phi |

| Train Task Name | Config Name |
| --- | --- |
| Supervised Fine-Tuning | sft |
| Reward Modeling | rm |
| Proximal Policy Optimization | ppo |
| Direct Preference Optimization | dpo |

| Evaluation Task Name | Config Name |
| --- | --- |
| Perplexity | perplexity |
| ROUGE | rouge |
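These config names compose directly on the command line, following the same pattern as the quick start. For example, training a reward model with GPT2 on the helpful/harmless dataset could look like this:

python sheeprlhf.py train data=helpful_harmless model=gpt2 task=rm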

❓ Why

We want a framework for RL algorithms applied to LLMs, starting from common RLHF algorithms, that is both simple and scalable thanks to Lightning Fabric. A single framework for different types of tasks and algorithms should allow developers to easily experiment with different configurations.

📝 How

Reinforcement Learning from Human Feedback (RLHF) is a technique that combines traditional reinforcement learning (RL) with human preferences to train more effective and safer policies. Instead of relying solely on reward signals obtained from the environment, RLHF integrates human feedback to guide the learning process: reward signals are not crafted manually but learned from human judgments. Moreover, we have implemented Direct Preference Optimization (DPO) for aligning models to human preferences without training a reward model.
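For intuition, the DPO objective increases the log-probability margin of the preferred response over the rejected one, measured relative to a frozen reference model. The sketch below is a minimal illustration (not the loss shipped in sheeprlhf/loss) and assumes the inputs are summed log-probabilities of whole responses:

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor,
             ref_rejected_logps: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    """Minimal DPO sketch; inputs are per-response summed log-probabilities."""
    # How much more likely each response is under the policy vs. the frozen reference
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    # Push the preferred response's ratio above the rejected one's, scaled by beta
    logits = beta * (chosen_logratios - rejected_logratios)
    return -F.logsigmoid(logits).mean()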

Usage

SheepRLHF is designed to be modular and extensible. The framework provides two entry points: train and eval. The train entry point is used to train a model, while the eval entry point is used to evaluate a model. After selecting the entry point, the user can select the task, the model, and the data to use. All other configurations can be changed by passing them as command line arguments.
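For instance, a model fine-tuned with the quick-start command above can be scored with ROUGE by pointing the eval entry point at its experiment directory (the path placeholder is the same one used in the quick start):

python sheeprlhf.py eval task=rouge experiment_dir=<path_to_sft_experiment>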

Extending the framework

The repository is structured as follows:

  • agent: Contains the implementation of the agents for RL algorithms.
  • config: Contains the default configurations for entry points or experiments.
  • data: Contains the implementation of the data processors that can be extended to support new datasets. It also includes dataset and data collator implementations.
  • loss: Contains the implementation of the loss functions for available tasks.
  • model: Contains the implementation of wrapper model classes for LLMs.
  • structure: This folder has all configurations for the framework, including the default configurations. The user can add new settings to the framework by adding new configurations to this folder.
    • data.py: Contains the configuration for each dataset available.
    • fabric.py: Configurations for Lightning Fabric instance.
    • generation.py: Contains the configuration parameters for text generation.
    • model.py: Contains the configuration for each model available.
    • optim.py: Optimization configuration.
    • run.py: Entry point configurations for training and evaluation.
    • task.py: Contains the configuration for each available task, such as SFT, DPO, and PPO.
  • task: In this folder, we have implementations for each task that the framework supports.
    • train: Contains the implementation of the training algorithms such as SFT, DPO, and PPO.
    • eval: Contains the implementation of the evaluation algorithms such as perplexity and ROUGE.
  • utils: Contains utilities and helper functions.
  • cli.py: Contains the entry points for the framework.

Adding new models

All models are defined as configuration dataclasses in sheeprlhf/structure/model.py. To add a new model available on Hugging Face, add a new configuration to that file. For example, to add the OPT 350M model:

@dataclass
class OPTConfig(ModelConfig):
    """Configurations for OPT based models."""
    config_name: str = "opt"
    repo_name: str = "facebook/opt-350m"
    embedding_dim_name: Optional[str] = "word_embed_proj_dim"
    lora_cfg: Optional[LORAConfig] = LORAConfig(targets="('q_proj','v_proj')")
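Once the dataclass is defined, the model is selected on the command line through its config_name (here model=opt, as in the quick start). A variant for a different checkpoint can reuse the same fields; the class below is purely illustrative and assumes new configurations are picked up like the built-in ones:

@dataclass
class OPTLargeConfig(OPTConfig):
    """Hypothetical configuration for a larger OPT checkpoint (illustrative only)."""

    config_name: str = "opt_1b"           # would be selected with model=opt_1b
    repo_name: str = "facebook/opt-1.3b"  # existing Hugging Face repository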

Enabling LoRA

SheepRLHF supports LoRA out of the box, which reduces memory requirements by updating only a small subset of parameters. To enable it, set the fine-tuning mode and LoRA options on the command line:

python sheeprlhf.py train task=sft model=opt data=summarization model.finetune_mode=LORA model.lora_cfg.rank=16
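For intuition, here is a minimal LoRA sketch (not SheepRLHF's implementation): the pretrained weight stays frozen and only a low-rank update, parameterized by two small matrices, is trained.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Minimal LoRA sketch: a frozen base linear layer plus a trainable low-rank update."""

    def __init__(self, base: nn.Linear, rank: int = 16, alpha: float = 32.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # the pretrained weights stay frozen
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Effective weight is W + scaling * (B @ A); only A and B receive gradients
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

# Example: wrap an attention projection, mirroring the ('q_proj', 'v_proj') targets
# shown in the OPT configuration above.
lora_layer = LoRALinear(nn.Linear(768, 768), rank=16)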

🙇 Contributing

The best way to contribute is to open an issue to discuss a new feature or a bug, or to open a PR that fixes a bug or adds a feature. For development, install the development dependencies and the pre-commit hooks by running the following commands:

pip install ".[dev]"
pre-commit install

ℹ️ Acknowledgements

This work and its code are the result of a long educational and experimental journey. Please reach out on GitHub if anything is unclear or you need help; contributions are even more welcome. We would like to thank the following works for their contributions to the field and for inspiring this project.

Libraries

Blog Posts

Research Articles

📭 Who

You can contact us for any further questions or discussions:

📄 License

This project is licensed under the terms of the Apache License 2.0. Please see the LICENSE file for details. Be aware that the project may also use third-party libraries or models available online, which may be distributed under different licenses.