Add configs for DPO (#379)
* init

* nit

* style
natolambert authored Oct 8, 2024
1 parent fafd1ec commit f0b3def
Showing 8 changed files with 158 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -4,6 +4,8 @@ rejection_sampling/shards1
token_length.png
*.tfevents.*

oe-eval-internal/

results
models
wandb
31 changes: 31 additions & 0 deletions configs/train_configs/dpo/tulu3_preview_pref_v3.1.yaml
@@ -0,0 +1,31 @@
model_name_or_path: /model
model_revision: main
use_flash_attn: true
gradient_checkpointing: true
dataset_mixer:
  # Same UltraFeedback data from Tulu 2
  allenai/ultrafeedback_binarized_cleaned_train: 1.0
  # Custom conversion of Daring Anteater synthetic data into preferences
  ai2-adapt-dev/DaringAnteater-prefs-RM-filter: 1.0
  # Modifications of WildChat data to preferences
  ai2-adapt-dev/WildChat-prefs-280824: 1.0
tokenizer_name: /model
use_slow_tokenizer: true
max_seq_length: 2048
preprocessing_num_workers: 16
per_device_train_batch_size: 1
gradient_accumulation_steps: 16 # designed for 8 GPUs, so batch size 128
learning_rate: 5.0e-7
lr_scheduler_type: linear
warmup_ratio: 0.1
weight_decay: 0.0
num_train_epochs: 1
output_dir: /output
with_tracking: true
report_to:
- wandb
logging_steps: 1
use_lora: false
dpo_loss_type: dpo_norm
dpo_beta: 5
checkpointing_steps: 1000
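
Note: the comment on gradient_accumulation_steps implies an effective batch size of per_device_train_batch_size × gradient_accumulation_steps × num_gpus = 1 × 16 × 8 = 128. A minimal sketch of that arithmetic, assuming PyYAML is available and taking the 8-GPU count from the inline comment (it is not a field in the file):

# Sketch: load the DPO config and compute the effective batch size it targets.
import yaml

with open("configs/train_configs/dpo/tulu3_preview_pref_v3.1.yaml") as f:
    cfg = yaml.safe_load(f)

num_gpus = 8  # assumption taken from the "designed for 8 GPUs" comment
effective_batch_size = (
    cfg["per_device_train_batch_size"]
    * cfg["gradient_accumulation_steps"]
    * num_gpus
)
print(effective_batch_size)  # 1 * 16 * 8 = 128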
33 changes: 33 additions & 0 deletions configs/train_configs/dpo/tulu3_preview_pref_v3.2.yaml
@@ -0,0 +1,33 @@
model_name_or_path: /model
model_revision: main
use_flash_attn: true
gradient_checkpointing: true
dataset_mixer:
  # Same UltraFeedback data from Tulu 2
  allenai/ultrafeedback_binarized_cleaned_train: 1.0
  # Custom conversion of Daring Anteater synthetic data into preferences
  ai2-adapt-dev/DaringAnteater-prefs-RM-filter: 1.0
  # Modifications of WildChat data to preferences
  ai2-adapt-dev/WildChat-prefs-280824: 1.0
  # Nectar binarized with Anthropic helpful and harmless prompts
  ai2-adapt-dev/nectar_binarized-anthropic-hh: 1.0
tokenizer_name: /model
use_slow_tokenizer: true
max_seq_length: 2048
preprocessing_num_workers: 16
per_device_train_batch_size: 1
gradient_accumulation_steps: 16 # designed for 8 GPUs, so batch size 128
learning_rate: 5.0e-7
lr_scheduler_type: linear
warmup_ratio: 0.1
weight_decay: 0.0
num_train_epochs: 1
output_dir: /output
with_tracking: true
report_to:
- wandb
logging_steps: 1
use_lora: false
dpo_loss_type: dpo_norm
dpo_beta: 5
checkpointing_steps: 1000
33 changes: 33 additions & 0 deletions configs/train_configs/dpo/tulu3_preview_pref_v3.3.yaml
@@ -0,0 +1,33 @@
model_name_or_path: /model
model_revision: main
use_flash_attn: true
gradient_checkpointing: true
dataset_mixer:
  # Same UltraFeedback data from Tulu 2
  allenai/ultrafeedback_binarized_cleaned_train: 1.0
  # Custom conversion of Daring Anteater synthetic data into preferences
  ai2-adapt-dev/DaringAnteater-prefs-RM-filter: 1.0
  # Modifications of WildChat data to preferences
  ai2-adapt-dev/WildChat-prefs-280824: 1.0
  # Custom IF Eval data with Llama 3.1 405B for chosen and Tulu 2 as rejected
  ai2-adapt-dev/Llama-3.1-if_taxonomy_tulu: 1.0
tokenizer_name: /model
use_slow_tokenizer: true
max_seq_length: 2048
preprocessing_num_workers: 16
per_device_train_batch_size: 1
gradient_accumulation_steps: 16 # designed for 8 GPUs, so batch size 128
learning_rate: 5.0e-7
lr_scheduler_type: linear
warmup_ratio: 0.1
weight_decay: 0.0
num_train_epochs: 1
output_dir: /output
with_tracking: true
report_to:
- wandb
logging_steps: 1
use_lora: false
dpo_loss_type: dpo_norm
dpo_beta: 5
checkpointing_steps: 1000
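
All three configs set dpo_loss_type: dpo_norm with dpo_beta: 5. As an illustration only, assuming dpo_norm denotes a length-normalized variant of the DPO objective (per-sequence log-probs divided by token count before the usual policy-vs-reference comparison), here is a minimal PyTorch sketch of such a loss; the function and argument names are hypothetical and this is not a claim about the repository's exact implementation in open_instruct/dpo_tune.py:

# Hypothetical sketch of a length-normalized DPO loss; beta mirrors dpo_beta above.
import torch
import torch.nn.functional as F


def dpo_norm_loss(
    policy_chosen_logps: torch.Tensor,    # summed log-probs of chosen responses
    policy_rejected_logps: torch.Tensor,  # summed log-probs of rejected responses
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    chosen_lengths: torch.Tensor,         # token counts of chosen responses
    rejected_lengths: torch.Tensor,
    beta: float = 5.0,
) -> torch.Tensor:
    # Length-normalize each sequence log-prob (the assumed "norm" in dpo_norm).
    pi_logratio = policy_chosen_logps / chosen_lengths - policy_rejected_logps / rejected_lengths
    ref_logratio = ref_chosen_logps / chosen_lengths - ref_rejected_logps / rejected_lengths
    # Standard DPO comparison: margin of the policy log-ratio over the frozen reference.
    return -F.logsigmoid(beta * (pi_logratio - ref_logratio)).mean()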
7 changes: 5 additions & 2 deletions open_instruct/mix_data.py
@@ -14,12 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from open_instruct.finetune import FlatArguments

# script for mixing and saving data
from .utils import ArgumentParserPlus, FlatArguments, get_datasets
from open_instruct.utils import ArgumentParserPlus, get_datasets

# Run as module for local imports, e.g.:
# python -m open_instruct.mix_data configs/train_configs/sft/default.yaml --dataset_mix_dir=output/tmp/
# python open_instruct/mix_data.py configs/train_configs/sft/tulu3_8b_preview_mix_v3.4.yaml --dataset_mix_dir=output/tmp/
# can pass --save_to_hub=allenai/tulu-v3.1-mix-preview-4096-OLMoE
# note that = is needed with our argparser


def main():
54 changes: 54 additions & 0 deletions open_instruct/mix_data_preferences.py
@@ -0,0 +1,54 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 AllenAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from open_instruct.dpo_tune import FlatArguments

# script for mixing and saving data
from open_instruct.utils import ArgumentParserPlus, get_datasets

# Run as module for local imports, e.g.:
# python open_instruct/mix_data_preferences.py configs/train_configs/sft/tulu3_8b_preview_mix_v3.4.yaml --dataset_mix_dir=output/tmp/
# can pass --save_to_hub=allenai/tulu-v3.1-mix-preview-4096-OLMoE
# note that = is needed with our argparser


def main():
    parser = ArgumentParserPlus((FlatArguments))
    args = parser.parse()

    # assert that dataset_mixer is not None in the config
    assert args.dataset_mixer is not None, "dataset_mixer is required in config"

    raw_datasets = get_datasets(
        args.dataset_mixer,
        configs=args.dataset_config_name,
        splits=["train"],
        save_data_dir=args.dataset_mix_dir,  # location where dataset is saved as json
        columns_to_keep=["chosen", "rejected"],
        keep_ids=True,
    )

    # print first 5 samples of dataset
    for i in range(5):
        print(raw_datasets["train"][i])

    # if args.save_to_hub is not None, push dataset to hub
    if args.save_to_hub:
        raw_datasets["train"].push_to_hub(args.save_to_hub, private=True)


if __name__ == "__main__":
    main()
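
As a rough usage sketch, the dataset_mixer block from tulu3_preview_pref_v3.1.yaml can also be passed to get_datasets directly, mirroring the call in mix_data_preferences.py above; omitting the optional arguments used there (configs, save_data_dir) assumes they have defaults:

# Sketch: mix the v3.1 preference datasets in-process, without the argparser.
from open_instruct.utils import get_datasets

dataset_mixer = {
    "allenai/ultrafeedback_binarized_cleaned_train": 1.0,
    "ai2-adapt-dev/DaringAnteater-prefs-RM-filter": 1.0,
    "ai2-adapt-dev/WildChat-prefs-280824": 1.0,
}

raw_datasets = get_datasets(
    dataset_mixer,
    splits=["train"],
    columns_to_keep=["chosen", "rejected"],
    keep_ids=True,
)
print(raw_datasets["train"][0])  # inspect one mixed preference example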
