The official PyTorch implementation of Diffusion-Reward Adversarial Imitation Learning (DRAIL), published at NeurIPS 2024.
Chun-Mao Lai*¹, Hsiang-Chun Wang*¹, Ping-Chun Hsieh², Yu-Chiang Frank Wang¹,³, Min-Hung Chen³, Shao-Hua Sun¹

¹NTU Robot Learning Lab, ²National Yang Ming Chiao Tung University, ³NVIDIA Research Taiwan

(*Equal contribution)
[Paper] [Website] [BibTeX] [ICLRW'24 Poster]
This work proposes Diffusion-Reward Adversarial Imitation Learning (DRAIL), which integrates a diffusion model into GAIL, aiming to yield more precise and smoother rewards for policy learning. Specifically, we propose a diffusion discriminative classifier to construct an enhanced discriminator; then, we design diffusion rewards based on the classifier's output for policy learning. We conduct extensive experiments in navigation, manipulation, and locomotion, verifying DRAIL's effectiveness compared to prior imitation learning methods. Moreover, additional experimental results demonstrate the generalizability and data efficiency of DRAIL. Visualized learned reward functions of GAIL and DRAIL suggest that DRAIL can produce more precise and smoother rewards.
For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing.
- This codebase requires Python 3.8 or higher. All package requirements are in `requirements.txt`. To install from scratch using Anaconda, use the following commands:

  ```shell
  conda create -n [your_env_name] python=3.8
  conda activate [your_env_name]
  ./utils/setup.sh
  ```
- Set up Weights and Biases by first logging in with `wandb login <YOUR_API_KEY>` and then editing `config.yaml` with your W&B username and project name.
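  The W&B fields in `config.yaml` might look something like the sketch below; the key names here are illustrative placeholders, so match them to the keys actually present in the repo's `config.yaml`:

  ```yaml
  # Illustrative only -- edit the existing keys in config.yaml rather than adding these.
  wandb:
    entity: your-wandb-username   # your W&B username or team
    project: your-project-name    # the W&B project to log runs under
  ```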
- Download expert demonstration datasets to `./expert_datasets`. Most of the expert demonstration datasets we used are provided by goal_prox_il. We provide a script for downloading and post-processing the expert datasets:

  ```shell
  ./utils/expert_data.sh
  ```
- For the Walker task, we provide a script to generate expert demonstrations with fewer trajectories. Please check out the Walker Expert Data section below to obtain the data and place it under `./expert_datasets`. Execute the following command to post-process the expert data:

  ```shell
  ./utils/clip_walker.py
  ```
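The clip scripts trim an expert dataset down to its first few trajectories. A minimal sketch of that idea, assuming the transitions are stored as flat arrays with episode-termination flags (the `clip_trajectories` helper and the storage layout are illustrative; the repo's scripts may load, split, and save the data differently):

```python
import numpy as np

def clip_trajectories(obs, actions, done, n_traj):
    """Keep only the first n_traj trajectories, splitting on episode-end flags.

    Hypothetical helper for illustration; not the repo's actual implementation.
    """
    ends = np.flatnonzero(done)            # indices where an episode terminates
    if n_traj >= len(ends):
        return obs, actions, done          # nothing to clip
    cut = ends[n_traj - 1] + 1             # keep transitions up to the n-th episode end
    return obs[:cut], actions[:cut], done[:cut]

# Toy dataset: 8 transitions forming 3 episodes (ends at indices 2, 4, 7).
done = np.array([0, 0, 1, 0, 1, 0, 0, 1], dtype=bool)
obs = np.arange(8)
actions = np.arange(8)
obs2, actions2, done2 = clip_trajectories(obs, actions, done, n_traj=2)  # keeps 5 transitions
```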
To replicate the experiments conducted in our paper, follow these steps:
- Select Configuration Files: The wandb sweep configuration files for all tasks can be found in the `configs` directory. Inside each lowest-level directory (e.g., `./configs/push/1.50/2000`, `./configs/walker/5traj`, etc.), you'll find six common files: `drail.yaml`, `drail-un.yaml`, `gail.yaml`, `wail.yaml`, `bc.yaml`, and `diffusion-policy.yaml`.
- Run Experiments: After selecting the desired configuration file, execute the following command:

  ```shell
  ./utils/wandb.sh <Path_to_Configuration_File.yaml>
  ```
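For example, a concrete invocation can be composed from a task directory and a method name (the specific setting below is just one of the configurations listed in this README):

```shell
# Hypothetical example: FetchPush at noise level 1.50 with 2000 expert transitions,
# running our method DRAIL. Swap METHOD for any of the six config files.
TASK_DIR="./configs/push/1.50/2000"
METHOD="drail"   # one of: drail, drail-un, gail, wail, bc, diffusion-policy
echo "./utils/wandb.sh ${TASK_DIR}/${METHOD}.yaml"
```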
Below are detailed descriptions of each task:
- Path: `./configs/maze`
- Expert Coverages: The `./configs/maze` directory contains subdirectories named after different expert coverages, including `100`, `75`, `50`, and `25`. Each folder represents a specific expert coverage configuration.
- Path: `./configs/pick`
- Noise Levels: The `./configs/pick` directory contains subdirectories named after different noise levels, including `1.00`, `1.25`, `1.50`, `1.75`, and `2.00`. Each folder represents a specific noise level configuration.
- Path: `./configs/push`
- Noise Levels: The `./configs/push` directory contains subdirectories named after different noise levels, including `1.00`, `1.25`, `1.50`, `1.75`, and `2.00`. Each folder represents a specific noise level configuration.
- Expert Transitions: Within the `./configs/push/1.50` directory, there are four additional subdirectories: `2000`, `5000`, `10000`, and `20311`. These numbers denote the number of expert transitions available for this setting. For other noise levels, the default number of expert transitions is 20311.
- Path: `./configs/hand`
- Noise Levels: The `./configs/hand` directory contains subdirectories named after different noise levels, including `1.00`, `1.25`, `1.50`, `1.75`, and `2.00`. Each folder represents a specific noise level configuration.
- Path: `./configs/ant`
- Noise Levels: The `./configs/ant` directory contains subdirectories named after different noise levels, including `0.00`, `0.01`, `0.03`, and `0.05`. Each folder represents a specific noise level configuration.
- Path: `./configs/walker`
- Expert Trajectories: The `./configs/walker` directory contains subdirectories named after different numbers of expert trajectories, including `25traj`, `5traj`, `3traj`, `2traj`, and `1traj`.
- Expert Data: To obtain the expert data, follow these steps using the provided code and configuration under the `./configs/walker/expert` directory:
  - Use the provided configuration file `ppo.yaml` to train a PPO model (default name: `ppo_walker_expert_model.pt`) and put it into the `expert_datasets` directory.
  - Modify the `--load-file` argument in the `collect_trajs.sh` script to point to the path of your trained model (typically located in `./data/trained_models/Walker2d-v3/...`).
  - Execute the following command:

    ```shell
    ./configs/walker/expert/collect_trajs.sh
    ```
  - The expert data will be stored in `./data/traj/Walker2d-v3/...`.
- `drail`: Implementation of our main method.
  - `drail/drail.py`: Code for our method, DRAIL.
  - `drail/drail_un.py`: Code for the variant of our method, DRAIL-UN.
  - `drail/ddpm`: Directory containing the diffusion models.
    - `drail/ddpm/ddpm_condition.py`: Code for the diffusion model of DRAIL.
    - `drail/ddpm/ddpm_condition.py`: Code for the diffusion model of DRAIL-UN.
    - `drail/ddpm/policy_model.py`: Code for the diffusion model of Diffusion Policy.
- `utils`: Useful scripts.
  - `utils/clip_halfcheetah.py`: Generate the HalfCheetah expert data with fewer trajectories.
  - `utils/clip_walker.py`: Generate the Walker expert data with fewer trajectories.
  - `utils/clip_push.py`: Generate the FetchPush expert data with fewer transitions.
  - `utils/wandb.sh`: Script to automatically create and execute a wandb command from a configuration file.
  - `utils/setup.sh`: Script to install and set up the conda environment.
  - `utils/expert_data.sh`: Script to download and post-process the expert demonstrations.
- `shape_env`: Customized environment code.
  - `shape_env/rollout_sine2000.py`: Code to generate the `Sine` function expert transitions.
- `goal_prox`: Customized environment code from goal_prox_il.
  - `goal_prox/envs/ant.py`: AntGoal locomotion task.
  - `goal_prox/envs/fetch/custom_fetch.py`: FetchPick task.
  - `goal_prox/envs/fetch/custom_push.py`: FetchPush task.
  - `goal_prox/envs/hand/manipulate.py`: HandRotate task.
  - `goal_prox/gym_minigrid`: MiniGrid code for the navigation environment from maximecb.
- `rl-toolkit`: Base RL code and code for imitation learning baselines from rl-toolkit.
  - `rl-toolkit/rlf/algos/on_policy/ppo.py`: PPO policy updater code used for RL.
  - `rl-toolkit/rlf/algos/il/gail.py`: Baseline Generative Adversarial Imitation Learning (GAIL) code.
  - `rl-toolkit/rlf/algos/il/wail.py`: Baseline Wasserstein Adversarial Imitation Learning (WAIL) code.
  - `rl-toolkit/rlf/algos/il/bc.py`: Baseline Behavioral Cloning (BC) code.
  - `rl-toolkit/rlf/algos/il/dp.py`: Baseline Diffusion Policy code.
- `d4rl`: Codebase from D4RL: Datasets for Deep Data-Driven Reinforcement Learning, used for Maze2D.
- The base code was adapted from goal_prox_il.
- The Grid world environment was obtained from maximecb.
- The Fetch and Hand Rotate environments were customized based on OpenAI implementations.
- The Ant environment was customized by goal_prox_il and originated from Farama-Foundation.
- The SAC code was obtained from denisyarats.
- The Maze2D environment is based on D4RL: Datasets for Deep Data-Driven Reinforcement Learning.
- The expert demonstrations for the Maze, FetchPick, FetchPush, HandRotate, and AntReach tasks were obtained from goal_prox_il.
If you find DRAIL useful, please consider giving us a star and citing our work:

```bibtex
@article{lai2024diffusion,
  title={Diffusion-Reward Adversarial Imitation Learning},
  author={Lai, Chun-Mao and Wang, Hsiang-Chun and Hsieh, Ping-Chun and Wang, Yu-Chiang Frank and Chen, Min-Hung and Sun, Shao-Hua},
  journal={arXiv preprint arXiv:2405.16194},
  year={2024}
}
```
Copyright © 2024, NVIDIA Corporation. All rights reserved.
This work is made available under the NVIDIA Source Code License-NC. Click here to view a copy of this license.