Skip to content

Commit

Permalink
Synthetic preference dataset (#332)
Browse files Browse the repository at this point in the history
* push changes

* Support hf revision when generating

* style an quality

* add synthetic preference dataset scripts

* deal with reference generation as part of rejection sampling as well

* add litellm dependency

* suppport dataset_mixer_list

* Add openai api key to useful secretes

* style and quality

* quick push

* add reference

* quick update
  • Loading branch information
vwxyzjn authored Oct 2, 2024
1 parent 1012ff7 commit d2c1838
Show file tree
Hide file tree
Showing 10 changed files with 542 additions and 90 deletions.
31 changes: 23 additions & 8 deletions docs/algorithms/rejection_sampling.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,52 @@ This code supports HF models, local models and also API-based models (e.g., `gpt

```bash
# 1. first sample a bunch of completions given prompts
# Here is an example created dataset: https://huggingface.co/datasets/vwxyzjn/generation_1724272894
# Here is an example created dataset: https://huggingface.co/datasets/vwxyzjn/generation_1727879425
python open_instruct/rejection_sampling/generation.py \
--dataset_name allenai/tulu-v2-sft-mixture \
--dataset_mixer_list allenai/tulu-v2-sft-mixture 100 \
--dataset_splits train \
--model_name_or_path allenai/llama-3-tulu-2-8b \
--num_completions 3 \
--save_filename output/completions.jsonl \
--sanity_check \
--push_to_hub
--push_to_hub
```

### Scoring completions
You can use either a single RM to score responses or a list of RMs. In the latter case, we will take the majority vote to compute the final score. The RMs can be models explicitly trained as RMs, HF LMs, or API-based models.

Note that by default we include the reference completion in the list of completions to perform rejection sampling. This can be disabled by setting `--no_include_reference_completion_for_rejection_sampling`

```bash
# 2.1 tokenize them and run a reward model to filter them
# Here is an example created dataset: https://huggingface.co/datasets/vwxyzjn/rejection_sampling_1724273165
# Here is an example created dataset for raw scores: https://huggingface.co/datasets/vwxyzjn/rejection_sampling_scores_1724273165
# Here is an example created dataset: https://huggingface.co/datasets/vwxyzjn/rejection_sampling_1727887719
# Here is an example created dataset for raw scores: https://huggingface.co/datasets/vwxyzjn/rejection_sampling_scores_1727887719/
python open_instruct/rejection_sampling/rejection_sampling.py \
--input_filename output/completions.jsonl \
--model_names_or_paths allenai/llama-3-tulu-2-8b-uf-mean-rm \
--save_filename_scores output/completions_scores.jsonl \
--save_filename output/rejection_sampled_completions.jsonl \
--num_completions 3 \
--push_to_hub \
--num_gpus 1 \

# 2.1.2 without reference completion in rejection sampling
# Here is an example created dataset: https://huggingface.co/datasets/vwxyzjn/rejection_sampling_1727887719
# Here is an example created dataset for raw scores: https://huggingface.co/datasets/vwxyzjn/rejection_sampling_scores_1727887719/
python open_instruct/rejection_sampling/rejection_sampling.py \
--input_filename output/completions.jsonl \
--model_names_or_paths allenai/llama-3-tulu-2-8b-uf-mean-rm \
--save_filename_scores output/completions_scores.jsonl \
--save_filename output/rejection_sampled_completions.jsonl \
--no_include_reference_completion_for_rejection_sampling \
--num_completions 3 \
--push_to_hub \
--num_gpus 1 \

# 2.2 tokenize them and run llm as a judge
# Note then when using LLM as a judge, it's possible that llm api failed to produce a score in our expected
# format, so score extraction failed and we simply mark the score -1.
# Here is an example created dataset: https://huggingface.co/datasets/vwxyzjn/rejection_sampling_1724273303
# Here is an example created dataset for raw scores: https://huggingface.co/datasets/vwxyzjn/rejection_sampling_scores_1724273303
# Here is an example created dataset: https://huggingface.co/datasets/vwxyzjn/rejection_sampling_1727889563
# Here is an example created dataset for raw scores: https://huggingface.co/datasets/vwxyzjn/rejection_sampling_scores_1727889563
python open_instruct/rejection_sampling/rejection_sampling.py \
--input_filename output/completions.jsonl \
--model_names_or_paths gpt-4o-mini \
Expand Down
81 changes: 81 additions & 0 deletions docs/algorithms/synthetic_preference_dataset.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Synthetic preference dataset

This section focuses explicitly on creating synthetic preference datasets.

# Debug run (use an interactive session)

This code supports HF models, local models and also API-based models (e.g., `gpt-4`). For generating completions, the code now accepts one model at a time, but we're working on adding an ensemble of models. Stay tuned.

```bash
# 1. first sample a bunch of completions given prompts
# Here is an example created dataset: https://huggingface.co/datasets/vwxyzjn/generation_1725567768
python open_instruct/rejection_sampling/generation.py \
--dataset_mixer_list HuggingFaceH4/no_robots 100 \
--dataset_splits train \
--model_name_or_path allenai/llama-3-tulu-2-8b \
--num_completions 3 \
--save_filename output/completions.jsonl \
--sanity_check \
--push_to_hub
```

### Create preference pairs

```bash
# 2.1 do LLM as a judge to create synthetic preference dataset
# Here is an example created dataset: https://huggingface.co/datasets/vwxyzjn/synthetic_preference_dataset_1725567862
python open_instruct/rejection_sampling/synthetic_preference_dataset.py \
--input_filename output/completions.jsonl \
--model gpt-4o-2024-08-06 \
--save_filename output/synthetic_preferences.jsonl \
--num_completions 3 \
--push_to_hub \
```


You can visualize the dataset via

```bash
python -m costa_utils.hf_viz \
--sft vwxyzjn/synthetic_preference_dataset_1725567862 \
--split train \
--sft_messages_column_name whole_conversation

python -m costa_utils.hf_viz \
--preference vwxyzjn/synthetic_preference_dataset_1725567862 \
--split train \
--preference_chosen_column_name chosen \
--preference_rejected_column_name rejected
```

![synthetic_preference_dataset](synthetic_preference_dataset.png)

# Run through the entire dataset run

To run through the entire dataset you would need a lot more GPUs to finish the generation more quickly.


```bash
# NOTE: the scripts below only generate 400 prompts, so it's for demonstration purposes only. The scripts are highly scalable, and you could modify its `num_prompts=400` to something else like 300000 for the tulu dataset.

# you need to make sure your default beaker workspace has WANDB_API_KEY and HF_TOKEN secrets in them
beaker secret write HF_TOKEN xxxxxxxxxxxx
beaker secret write WANDB_API_KEY xxxxxxxxxxx

# Docker mode: using caches from WEKA
deploy_mode="docker_weka" bash scripts/synthetic_preference_dataset.bash

# Docker mode: using caches from NFS
deploy_mode="docker_nfs" bash scripts/synthetic_preference_dataset.bash

# Docker mode: do not use caches
deploy_mode="docker" bash scripts/synthetic_preference_dataset.bash

# If you have environment setup with NFS and want to launch debug mode:
deploy_mode="nfs" bash scripts/synthetic_preference_dataset.bash
```

You can see a demo [here](https://drive.google.com/file/d/1dq3KG15ajpOv8tFYEZGS4tlW7G55oOYP/view?usp=sharing)

<img width="1327" alt="image" src="https://github.com/user-attachments/assets/71a15671-e054-4eab-a571-715881958e74">

Binary file added docs/algorithms/synthetic_preference_dataset.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions mason.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def get_env_vars(pure_docker_mode, cluster: List[str], beaker_secrets, whoami, r
"HF_TOKEN",
"WANDB_API_KEY",
"BEAKER_TOKEN",
"OPENAI_API_KEY",
]
for useful_secret in useful_secrets:
if f"{whoami}_{useful_secret}" in beaker_secrets:
Expand Down
73 changes: 34 additions & 39 deletions open_instruct/rejection_sampling/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,22 @@
from pprint import pformat
from typing import Dict, List, Optional

from datasets import load_dataset
from huggingface_hub import HfApi
from huggingface_hub.repocard import RepoCard
from rich.pretty import pprint
from transformers import AutoTokenizer, HfArgumentParser
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from open_instruct.dataset_processor import (
INPUT_IDS_PROMPT_KEY,
DatasetConfig,
SFTDatasetProcessor,
)
from open_instruct.rejection_sampling.api_generate import ( # Import your classes
LLMGenerationConfig,
LLMProcessor,
)
from open_instruct.utils import ArgumentParserPlus, combine_dataset

api = HfApi()
# we don't use `multiprocessing.cpu_count()` because typically we only have 12 CPUs
Expand All @@ -44,6 +49,11 @@

@dataclass
class Args:
dataset_mixer_list: List[str]
dataset_splits: List[str] = None
dataset_start_idx: int = 0
dataset_end_idx: Optional[int] = None

model_name_or_path: str = "cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr"
revision: str = "main"
save_filename: str = "completions.jsonl"
Expand All @@ -66,18 +76,6 @@ class GenerationArgs:
tensor_parallel_size: int = 1


@dataclass
class DatasetArgs:
dataset_name: str = None
dataset_text_field: str = "prompt"
dataset_train_split: str = "train"
dataset_test_split: str = "validation"
dataset_start_idx: int = 0
dataset_end_idx: Optional[int] = 100
sanity_check: bool = False
sanity_check_size: int = 100


def save_jsonl(save_filename: str, table: Dict[str, List]):
first_key = list(table.keys())[0]
os.makedirs(os.path.dirname(save_filename), exist_ok=True)
Expand All @@ -100,6 +98,7 @@ def generate_with_vllm(model_name_or_path: str, revision: str, prompt_token_ids:
revision=revision,
tokenizer_revision=revision,
tensor_parallel_size=gen_args.tensor_parallel_size,
max_model_len=gen_args.response_length,
)

# filter out prompts which are beyond the model's max token length
Expand Down Expand Up @@ -144,35 +143,32 @@ def format_conversation(messages: list) -> str:
return "\n".join(formatted_conversation)


def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):

ds = load_dataset(dataset_args.dataset_name)
if dataset_args.sanity_check:
for key in ds:
ds[key] = ds[key].select(range(min(dataset_args.sanity_check_size, len(ds[key]))))
if dataset_args.dataset_end_idx is None:
dataset_args.dataset_end_idx = len(ds[dataset_args.dataset_train_split])
for key in ds:
ds[key] = ds[key].select(range(dataset_args.dataset_start_idx, dataset_args.dataset_end_idx))
pprint([dataset_args, args, gen_args])
def main(args: Args, dataset_config: DatasetConfig, gen_args: GenerationArgs):
dataset = combine_dataset(
args.dataset_mixer_list,
splits=args.dataset_splits,
columns_to_keep=[dataset_config.sft_messages_key],
)
if args.dataset_end_idx is None:
args.dataset_end_idx = len(dataset)
dataset = dataset.select(range(args.dataset_start_idx, args.dataset_end_idx))
pprint([dataset_config, args, gen_args])

if "gpt-3.5" in args.model_name_or_path or "gpt-4" in args.model_name_or_path:
ds = ds.map(
dataset = dataset.map(
lambda x: {"prompt": format_conversation(x["messages"][:-1])},
num_proc=NUM_CPUS_FOR_DATASET_MAP,
)
messages = ds[dataset_args.dataset_train_split]["prompt"]
messages = dataset["prompt"]
responses = asyncio.run(generate_with_openai(args.model_name_or_path, messages, args, gen_args))
outputs = [{"outputs": [{"text": response} for response in responses]}]

else:
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, revision=args.revision)

ds = ds.map(
lambda x: {"prompt_token_ids": tokenizer.apply_chat_template(x["messages"][:-1])},
num_proc=NUM_CPUS_FOR_DATASET_MAP,
)
prompt_token_ids = ds[dataset_args.dataset_train_split]["prompt_token_ids"]
dataset_processor = SFTDatasetProcessor(tokenizer=tokenizer, config=dataset_config)
dataset = dataset_processor.tokenize(dataset)
dataset = dataset_processor.filter(dataset)
prompt_token_ids = dataset[INPUT_IDS_PROMPT_KEY]
outputs = generate_with_vllm(args.model_name_or_path, args.revision, prompt_token_ids, gen_args)

# Assuming we generate n=3 completions per prompt; the outputs will look like:
Expand All @@ -185,7 +181,7 @@ def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):
# ...
table = defaultdict(list)
num_prompt_with_identical_completions = 0
for output, messages in zip(outputs, ds[dataset_args.dataset_train_split]["messages"]):
for output, messages in zip(outputs, dataset["messages"]):
# if the model completions are exactly the same across all completions per prompt, we can skip this
if len(set(tuple(item["text"]) for item in output["outputs"])) == 1:
num_prompt_with_identical_completions += 1
Expand Down Expand Up @@ -231,8 +227,8 @@ def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):
args:
{pformat(vars(args))}
dataset_args:
{pformat(vars(dataset_args))}
dataset_config:
{pformat(vars(dataset_config))}
gen_args:
{pformat(vars(gen_args))}
Expand All @@ -251,6 +247,5 @@ def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):


if __name__ == "__main__":
parser = HfArgumentParser((Args, DatasetArgs, GenerationArgs))
args, dataset_args, gen_args = parser.parse_args_into_dataclasses()
main(args, dataset_args, gen_args)
parser = ArgumentParserPlus((Args, DatasetConfig, GenerationArgs))
main(*parser.parse())
Loading

0 comments on commit d2c1838

Please sign in to comment.