Fix simpo metrics (#374)
* Fix simpo metrics

* style and quality
vwxyzjn authored Oct 3, 2024
1 parent 0d3e95a commit 6624eba
Showing 1 changed file with 21 additions and 15 deletions: open_instruct/dpo_tune.py
@@ -1005,17 +1005,18 @@ def load_model():

             # We keep track of the loss at each logged step
             with torch.no_grad():
-                chosen_rewards = (args.dpo_beta * (policy_chosen_logps - reference_chosen_logps)).mean()
-                rejected_rewards = (args.dpo_beta * (policy_rejected_logps - reference_rejected_logps)).mean()
-                average_rewards = (chosen_rewards + rejected_rewards) / 2
-                accuracy = (chosen_rewards > rejected_rewards).float().mean()
-                margin = (chosen_rewards - rejected_rewards).mean()
                 local_metrics[0] += loss
-                local_metrics[1] += chosen_rewards
-                local_metrics[2] += rejected_rewards
-                local_metrics[3] += average_rewards
-                local_metrics[4] += accuracy
-                local_metrics[5] += margin
+                if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm":
+                    chosen_rewards = (args.dpo_beta * (policy_chosen_logps - reference_chosen_logps)).mean()
+                    rejected_rewards = (args.dpo_beta * (policy_rejected_logps - reference_rejected_logps)).mean()
+                    average_rewards = (chosen_rewards + rejected_rewards) / 2
+                    accuracy = (chosen_rewards > rejected_rewards).float().mean()
+                    margin = (chosen_rewards - rejected_rewards).mean()
+                    local_metrics[1] += chosen_rewards
+                    local_metrics[2] += rejected_rewards
+                    local_metrics[3] += average_rewards
+                    local_metrics[4] += accuracy
+                    local_metrics[5] += margin
                 local_metrics[6] += policy_chosen_logps.mean()
                 local_metrics[7] += policy_rejected_logps.mean()
                 if args.load_balancing_loss:
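For context (wording mine, not from the commit): DPO's implicit reward is defined against a reference model, r(x, y) = beta * (log pi(y|x) - log pi_ref(y|x)), while SimPO is reference-free, so these reward metrics are only meaningful for the `dpo` and `dpo_norm` loss types. A minimal sketch of that reward, with hypothetical sequence log-probs standing in for the script's `policy_*_logps` / `reference_*_logps`:

import torch

# Minimal sketch, not the repository's code: the DPO implicit reward
# needs reference log-probs, which a reference-free loss such as SimPO
# never produces.
def implicit_reward(beta: float, policy_logps: torch.Tensor, reference_logps: torch.Tensor) -> torch.Tensor:
    return beta * (policy_logps - reference_logps)

# Dummy log-probs for two preference pairs (hypothetical values).
chosen = implicit_reward(0.1, torch.tensor([-120.0, -98.5]), torch.tensor([-125.0, -100.0]))
rejected = implicit_reward(0.1, torch.tensor([-140.0, -110.0]), torch.tensor([-138.0, -108.0]))
print((chosen - rejected).mean())          # reward margin
print((chosen > rejected).float().mean())  # reward accuracy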
@@ -1035,14 +1036,19 @@ def load_model():
                     "learning_rate": lr_scheduler.get_last_lr()[0],
                     "epoch": episode / len(train_dataset),
                     "train_loss": global_metrics[0],
-                    "rewards/chosen": global_metrics[1],
-                    "rewards/rejected": global_metrics[2],
-                    "rewards/average": global_metrics[3],
-                    "rewards/accuracy": global_metrics[4],
-                    "rewards/margin": global_metrics[5],
                     "logps/chosen": global_metrics[6],
                     "logps/rejected": global_metrics[7],
                 }
+                if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm":
+                    metrics_to_log.update(
+                        {
+                            "rewards/chosen": global_metrics[1],
+                            "rewards/rejected": global_metrics[2],
+                            "rewards/average": global_metrics[3],
+                            "rewards/accuracy": global_metrics[4],
+                            "rewards/margin": global_metrics[5],
+                        }
+                    )
                 logger_str = (
                     f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {global_metrics[0]}"
                 )
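The logging side applies the same gate, so a SimPO run simply omits the `rewards/*` keys rather than reporting values that were never computed. A self-contained sketch of the pattern, using a membership test equivalent to the repeated `or` comparison (dummy numbers; a real run logs the reduced `global_metrics`):

# Sketch of the conditional-logging pattern with hypothetical values.
dpo_loss_type = "simpo"
metrics_to_log = {"train_loss": 0.42, "logps/chosen": -105.3, "logps/rejected": -131.8}
if dpo_loss_type in ("dpo", "dpo_norm"):
    metrics_to_log.update({"rewards/chosen": 0.5, "rewards/rejected": -0.2})
for name, value in metrics_to_log.items():
    print(f"{name}: {value}")  # with "simpo", no rewards/* keys appear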
