Skip to content

Commit

Permalink
Support multi-node with online DPO / PPO (#370)
Browse files Browse the repository at this point in the history
* Support multi-node with online DPO / PPO

* support all weka clusters
  • Loading branch information
vwxyzjn authored Oct 3, 2024
1 parent 8cb01df commit 0d3e95a
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 16 deletions.
81 changes: 69 additions & 12 deletions mason.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import re
from typing import List
import beaker
import os
Expand All @@ -12,6 +13,12 @@ def parse_beaker_dataset(dataset_str):

return {"mount_path": splt[0], "beaker": splt[1]}

# Beaker clusters that provide a Weka (shared high-performance) filesystem mount.
# Used below to choose between mounting Weka and falling back to the NFS cache
# paths, and to gate the multi-node NCCL env-var setup.
WEKA_CLUSTERS = [
    "ai2/jupiter-cirrascale-2",
    "ai2/saturn-cirrascale",
    "ai2/neptune-cirrascale",
]


def get_args():
parser = argparse.ArgumentParser()
Expand All @@ -24,6 +31,7 @@ def get_args():
)
parser.add_argument("--budget", type=str, help="Budget to use.", required=True)
parser.add_argument("--gpus", type=int, help="Number of gpus", default=0)
parser.add_argument("--num_nodes", type=int, help="Number of nodes", default=1)
parser.add_argument(
"--image",
type=str,
Expand Down Expand Up @@ -118,7 +126,7 @@ def parse_commands(command_args: List[str]) -> List[List[str]]:
return commands


def get_env_vars(pure_docker_mode, cluster: List[str], beaker_secrets, whoami, resumable):
def get_env_vars(pure_docker_mode: bool, cluster: List[str], beaker_secrets: List[str], whoami: str, resumable: bool, num_nodes: int):
env_vars = []
useful_secrets = [
"HF_TOKEN",
Expand Down Expand Up @@ -151,8 +159,8 @@ def get_env_vars(pure_docker_mode, cluster: List[str], beaker_secrets, whoami, r
),
])

# if we are not running on jupiter2, we try to mount the NFS
if "ai2/jupiter-cirrascale-2" not in cluster:
# if none of the clusters is a Weka cluster, we mount the NFS
if all(c not in WEKA_CLUSTERS for c in cluster):
env_vars.extend([
beaker.EnvVar(
name="HF_DATASETS_CACHE",
Expand All @@ -171,8 +179,19 @@ def get_env_vars(pure_docker_mode, cluster: List[str], beaker_secrets, whoami, r
value=f"/net/nfs.cirrascale/allennlp/deletable_checkpoint_states/{global_wandb_id}",
),
])
# if we only run on jupiter 2, we try to mount weka
elif len(cluster) == 1 and "ai2/jupiter-cirrascale-2" in cluster:
if len(cluster) == 1 and "ai2/pluto-cirrascale" in cluster:
env_vars.extend([
beaker.EnvVar(
name="NCCL_IB_HCA",
value="^=mlx5_1,mlx5_2",
),
beaker.EnvVar(
name="NCCL_DEBUG",
value="INFO",
),
])
# if all clusters are Weka clusters, we mount Weka
elif all(c in WEKA_CLUSTERS for c in cluster):
env_vars.extend([
beaker.EnvVar(
name="HF_DATASETS_CACHE",
Expand All @@ -187,6 +206,21 @@ def get_env_vars(pure_docker_mode, cluster: List[str], beaker_secrets, whoami, r
value=f"/weka/allennlp/deletable_checkpoint_states/{global_wandb_id}",
),
])
if num_nodes > 1:
env_vars.extend([
beaker.EnvVar(
name="NCCL_SOCKET_IFNAME",
value="ib",
),
beaker.EnvVar(
name="NCCL_IB_HCA",
value="^=mlx5_bond_0",
),
beaker.EnvVar(
name="NCCL_DEBUG",
value="INFO",
),
])
# don't mount anything; assume no cache
else:
pass
Expand All @@ -209,16 +243,16 @@ def get_env_vars(pure_docker_mode, cluster: List[str], beaker_secrets, whoami, r
def get_datasets(beaker_datasets, cluster: List[str]):
"""if pure docker mode we don't mount the NFS; so we can run it on jupiter2"""
res = []
# if we are not running on jupiter2, we try to mount the NFS
if "ai2/jupiter-cirrascale-2" not in cluster:
# if none of the clusters is a Weka cluster, we mount the NFS
if all(c not in WEKA_CLUSTERS for c in cluster):
res = [
beaker.DataMount(
source=beaker.DataSource(host_path="/net/nfs.cirrascale"),
mount_path="/net/nfs.cirrascale",
),
]
# if we only run on jupiter 2, we try to mount weka
elif len(cluster) == 1 and "ai2/jupiter-cirrascale-2" in cluster:
# if all clusters are Weka clusters, we mount Weka
elif all(c in WEKA_CLUSTERS for c in cluster):
res = [
beaker.DataMount(
source=beaker.DataSource(weka="oe-adapt-default"),
Expand All @@ -245,28 +279,51 @@ def make_task_spec(args, command, i, beaker_secrets, whoami, resumable: bool):
full_command = command
command = ['/bin/bash', '-c']
setup_commands = (
"echo 'Running on host: $BEAKER_REPLICA_RANK' && "
"echo 'Running on host: $BEAKER_LEADER_REPLICA_HOSTNAME' && "
"git config --global safe.directory '*' && " # fix the permission issue with git
"umask 000 && " # fix the permission issue with the cache folder
)
if not args.pure_docker_mode:
setup_commands += f"cd {os.getcwd()} && "
fully_command = setup_commands + " ".join(full_command)

join_full_command = " ".join(full_command)
# override accelerate call
if args.num_nodes > 1:
join_full_command = re.sub(
r'--num_processes (\d+)',
lambda m: (
f'--num_processes {int(m.group(1)) * args.num_nodes} '
f'--num_machines {args.num_nodes} '
'--machine_rank $BEAKER_REPLICA_RANK '
'--main_process_ip $BEAKER_LEADER_REPLICA_HOSTNAME '
'--main_process_port 29400 '
),
join_full_command
)
full_command = setup_commands + join_full_command
print(f"{full_command=}")


spec = beaker.TaskSpec(
name=f"{args.task_name}__{i}",
image=beaker.ImageSource(beaker=args.image),
command=command,
arguments=[fully_command],
arguments=[full_command],
result=beaker.ResultSpec(path="/output"),
datasets=get_datasets(args.beaker_datasets, args.cluster),
context=beaker.TaskContext(priority=beaker.Priority(args.priority),
preemptible=args.preemptible),
constraints=beaker.Constraints(cluster=args.cluster),
env_vars=get_env_vars(args.pure_docker_mode, args.cluster, beaker_secrets, whoami, resumable),
env_vars=get_env_vars(args.pure_docker_mode, args.cluster, beaker_secrets, whoami, resumable, args.num_nodes),
resources=beaker.TaskResources(gpu_count=args.gpus),
replicas=args.num_nodes,
)
if args.num_nodes > 1:
spec.leader_selection = True
spec.host_networking = True
spec.propagate_failure = True
spec.propagate_preemption = True

return spec

Expand Down
4 changes: 2 additions & 2 deletions open_instruct/online_dpo_vllm_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,8 +678,8 @@ def repeat_generator():
g_vllm_responses[:] = g_padded_response_ids
broadcast(g_vllm_responses, 0)
local_vllm_responses = g_vllm_responses[
accelerator.local_process_index
* queries.shape[0] : (accelerator.local_process_index + 1)
accelerator.process_index
* queries.shape[0] : (accelerator.process_index + 1)
* queries.shape[0]
]
query_responses = torch.cat((queries, local_vllm_responses), 1)
Expand Down
4 changes: 2 additions & 2 deletions open_instruct/ppo_vllm_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,8 +753,8 @@ def repeat_generator():
g_vllm_responses[:] = g_padded_response_ids
broadcast(g_vllm_responses, 0)
local_vllm_responses = g_vllm_responses[
accelerator.local_process_index
* queries.shape[0] : (accelerator.local_process_index + 1)
accelerator.process_index
* queries.shape[0] : (accelerator.process_index + 1)
* queries.shape[0]
]
query_responses = torch.cat((queries, local_vllm_responses), 1)
Expand Down

0 comments on commit 0d3e95a

Please sign in to comment.