[0.6.7.4] [feat]: update Soft Q learning
johnjim0816 committed Jul 31, 2024
1 parent 5db764e commit 297118d
Showing 11 changed files with 247 additions and 219 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -65,6 +65,7 @@ More tutorials and API documentation are hosted on [JoyRL docs](https://datawhal
| TD3 | [TD3 Paper](https://arxiv.org/pdf/1802.09477) | [johnjim0816](https://github.com/johnjim0816) | |
| A2C/A3C | [A3C Paper](https://arxiv.org/abs/1602.01783) | [johnjim0816](https://github.com/johnjim0816) | |
| PPO | [PPO Paper](https://arxiv.org/abs/1707.06347) | [johnjim0816](https://github.com/johnjim0816) | |
| SoftQ | [SoftQ Paper](https://arxiv.org/abs/1702.08165) | [johnjim0816](https://github.com/johnjim0816) | |

## Why JoyRL?

@@ -82,7 +83,7 @@ More tutorials and API documentation are hosted on [JoyRL docs](https://datawhal
| [rlpyt](https://github.com/astooke/rlpyt) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) | 11 | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
| [ChainerRL](https://github.com/chainer/chainerrl) | [![GitHub stars](https://img.shields.io/github/stars/chainer/chainerrl)](https://github.com/chainer/chainerrl/stargazers) | 18 | :heavy_check_mark: (gym) | :x: | :heavy_check_mark: | :x: | Chainer |
| [Tianshou](https://github.com/thu-ml/tianshou) | [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | 20 | :heavy_check_mark: (Gymnasium) | :x: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
| [JoyRL](https://github.com/datawhalechina/joyrl) | ![GitHub stars](https://img.shields.io/github/stars/datawhalechina/joyrl) | 10 | :heavy_check_mark: (Gymnasium) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
| [JoyRL](https://github.com/datawhalechina/joyrl) | ![GitHub stars](https://img.shields.io/github/stars/datawhalechina/joyrl) | 11 | :heavy_check_mark: (Gymnasium) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |

Here are some other highlights of JoyRL:

1 change: 1 addition & 0 deletions docs/README.md
@@ -63,6 +63,7 @@ More tutorials and API documentation are hosted on [JoyRL docs](https://datawhal
| TD3 | [TD3 Paper](https://arxiv.org/pdf/1802.09477) | [johnjim0816](https://github.com/johnjim0816) | |
| A2C/A3C | [A3C Paper](https://arxiv.org/abs/1602.01783) | [johnjim0816](https://github.com/johnjim0816) | |
| PPO | [PPO Paper](https://arxiv.org/abs/1707.06347) | [johnjim0816](https://github.com/johnjim0816) | |
| SoftQ | [SoftQ Paper](https://arxiv.org/abs/1702.08165) | [johnjim0816](https://github.com/johnjim0816) | |

## Why JoyRL?

4 changes: 2 additions & 2 deletions joyrl/__init__.py
@@ -5,13 +5,13 @@
Email: [email protected]
Date: 2023-01-01 16:20:49
LastEditor: JiangJi
LastEditTime: 2024-07-28 13:10:04
LastEditTime: 2024-07-31 14:26:18
Description:
'''
from joyrl import algos, framework, envs
from joyrl.run import run

__version__ = "0.6.7.3"
__version__ = "0.6.7.4"

__all__ = [
"algos",
4 changes: 1 addition & 3 deletions joyrl/algos/DQN/policy.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2024-01-25 09:58:33
LastEditor: JiangJi
LastEditTime: 2024-07-21 15:17:11
LastEditTime: 2024-07-31 10:15:12
Description:
'''
import torch
@@ -27,7 +27,6 @@ def __init__(self, cfg, **kwargs):
self.target_update = cfg.target_update
self.sample_count = 0
self.update_step = 0
self.ou_noise = OUNoise(self.action_size_list)

def create_model(self):
self.model = QNetwork(self.cfg, self.state_size_list).to(self.device)
@@ -52,7 +51,6 @@ def sample_action(self, state, **kwargs):
action = self.predict_action(state)
else:
action = get_model_actions(self.model, mode = 'random', actor_outputs = [{}] * len(self.action_size_list))
action = self.ou_noise.get_action(action, self.sample_count)
return action

@torch.no_grad()
10 changes: 10 additions & 0 deletions joyrl/algos/SoftQ/__init__.py
@@ -0,0 +1,10 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: [email protected]
Date: 2024-07-30 13:40:26
LastEditor: JiangJi
LastEditTime: 2024-07-30 13:40:27
Description:
'''
49 changes: 49 additions & 0 deletions joyrl/algos/SoftQ/config.py
@@ -0,0 +1,49 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: [email protected]
Date: 2023-12-20 23:39:18
LastEditor: JiangJi
LastEditTime: 2024-07-30 13:40:49
Description:
'''
class AlgoConfig():
''' algorithm parameters
'''
def __init__(self) -> None:
# set epsilon_start=epsilon_end to get fixed epsilon, i.e. epsilon=epsilon_end
self.epsilon_start = 0.95 # epsilon start value
self.epsilon_end = 0.01 # epsilon end value
self.epsilon_decay = 500 # epsilon decay
self.gamma = 0.95 # reward discount factor
self.alpha = 0.4 # temperature parameter of softmax
self.lr = 0.0001 # learning rate
self.buffer_type = 'REPLAY_QUE' # replay buffer type
self.max_buffer_size = 100000 # replay buffer size
self.batch_size = 64 # batch size
self.target_update = 4 # target network update frequency
# value network layers config
# [{'name': 'feature_1', 'layers': [{'layer_type': 'linear', 'layer_size': [256], 'activation': 'relu'}, {'layer_type': 'linear', 'layer_size': [256], 'activation': 'relu'}]}]
self.branch_layers = [
# {
# 'name': 'feature_1',
# 'layers':
# [
# {'layer_type': 'linear', 'layer_size': [64], 'activation': 'ReLU'},
# {'layer_type': 'linear', 'layer_size': [64], 'activation': 'ReLU'},
# ]
# },
# {
# 'name': 'feature_2',
# 'layers':
# [
# {'layer_type': 'linear', 'layer_size': [64], 'activation': 'ReLU'},
# {'layer_type': 'linear', 'layer_size': [64], 'activation': 'ReLU'},
# ]
# }
]
self.merge_layers = [
# {'layer_type': 'linear', 'layer_size': [256], 'activation': 'ReLU'},
# {'layer_type': 'linear', 'layer_size': [256], 'activation': 'ReLU'},
]
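For reference, a sketch of how `alpha`, `gamma`, and the epsilon settings above are used. The notation below is an assumption based on the [SoftQ paper](https://arxiv.org/abs/1702.08165) and on the new `policy.py` further down; it is not text from the commit itself.

```latex
% soft state value induced by the temperature alpha
V_{\mathrm{soft}}(s) = \alpha \log \sum_{a} \exp\!\left(\frac{Q(s,a)}{\alpha}\right)

% Boltzmann policy used for action selection
\pi(a \mid s) = \exp\!\left(\frac{Q(s,a) - V_{\mathrm{soft}}(s)}{\alpha}\right)

% exponential epsilon-greedy schedule implied by epsilon_start / epsilon_end / epsilon_decay
\varepsilon_t = \varepsilon_{\mathrm{end}} + \left(\varepsilon_{\mathrm{start}} - \varepsilon_{\mathrm{end}}\right) e^{-t/\tau},
\qquad \tau = \texttt{epsilon\_decay}
```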
16 changes: 16 additions & 0 deletions joyrl/algos/SoftQ/data_handler.py
@@ -0,0 +1,16 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: [email protected]
Date: 2024-07-30 13:40:11
LastEditor: JiangJi
LastEditTime: 2024-07-30 13:40:12
Description:
'''
from joyrl.algos.base.data_handler import BaseDataHandler

class DataHandler(BaseDataHandler):
def __init__(self, cfg):
super().__init__(cfg)

102 changes: 102 additions & 0 deletions joyrl/algos/SoftQ/policy.py
@@ -0,0 +1,102 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: [email protected]
Date: 2024-01-25 09:58:33
LastEditor: JiangJi
LastEditTime: 2024-07-31 10:19:50
Description:
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import math,random
from joyrl.algos.base.policy import BasePolicy
from joyrl.algos.base.network import *

class Policy(BasePolicy):
def __init__(self, cfg, **kwargs):
super(Policy, self).__init__(cfg, **kwargs)
self.gamma = cfg.gamma
self.alpha = cfg.alpha
# e-greedy parameters
self.epsilon_start = cfg.epsilon_start
self.epsilon_end = cfg.epsilon_end
self.epsilon_decay = cfg.epsilon_decay
self.target_update = cfg.target_update
self.sample_count = 0
self.update_step = 0

def create_model(self):
self.model = QNetwork(self.cfg, self.state_size_list).to(self.device)
self.target_model = QNetwork(self.cfg, self.state_size_list).to(self.device)
        self.target_model.load_state_dict(self.model.state_dict()) # copy parameters to the target network

def load_model_meta(self, model_meta):
super().load_model_meta(model_meta)
if model_meta.get('sample_count') is not None:
self.sample_count = model_meta['sample_count']

def sample_action(self, state, **kwargs):
''' sample action
'''
        # epsilon must decay (linearly, exponentially, etc.) to balance exploration and exploitation
self.sample_count += 1
self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
math.exp(-1. * self.sample_count / self.epsilon_decay)
self.update_model_meta({'sample_count': self.sample_count})
if random.random() > self.epsilon:
            # before the first update, network inference may take longer
action = self.predict_action(state)
else:
action = get_model_actions(self.model, mode = 'random', actor_outputs = [{}] * len(self.action_size_list))
return action

@torch.no_grad()
def predict_action(self,state, **kwargs):
''' predict action
'''
state = self.process_sample_state(state)
model_outputs = self.model(state)
actor_outputs = model_outputs['actor_outputs']
        q_values = [actor_outputs[i]['q_value'] for i in range(len(self.action_size_list))]
        # soft state value: V_soft(s) = alpha * logsumexp(Q(s, .) / alpha)
        values_soft = [self.alpha * torch.logsumexp(q_value / self.alpha, dim=1, keepdim=True) for q_value in q_values]
        # Boltzmann policy: pi(a|s) = exp((Q(s, a) - V_soft(s)) / alpha)
        probs = [F.softmax((q_value - value_soft) / self.alpha, dim=1) for q_value, value_soft in zip(q_values, values_soft)]
        dists = [torch.distributions.Categorical(probs = prob) for prob in probs]
        actions = [dist.sample().cpu().numpy().item() for dist in dists]
        return actions

def learn(self, **kwargs):
''' learn policy
'''
super().learn(**kwargs)
# compute current Q values
self.summary_loss = []
tot_loss = 0
actor_outputs = self.model(self.states)['actor_outputs']
target_actor_outputs = self.target_model(self.next_states)['actor_outputs']
for i in range(len(self.action_size_list)):
actual_q_value = actor_outputs[i]['q_value'].gather(1, self.actions[i].long())
            # compute the soft value of the next state: V_soft(s') = alpha * logsumexp(Q(s', .) / alpha)
            next_q_value = target_actor_outputs[i]['q_value']
            next_v_soft = self.alpha * torch.logsumexp(next_q_value / self.alpha, dim=1, keepdim=True)
            # compute the soft Bellman target: y = r + gamma * (1 - done) * V_soft(s')
            target_q_value = self.rewards + (1 - self.dones) * self.gamma * next_v_soft
# compute loss
loss_i = nn.MSELoss()(actual_q_value, target_q_value)
tot_loss += loss_i
self.summary_loss.append(loss_i.item())
self.optimizer.zero_grad()
tot_loss.backward()
# clip to avoid gradient explosion
for param in self.model.parameters():
param.grad.data.clamp_(-1, 1)
self.optimizer.step()
# update target net every C steps
if self.update_step % self.target_update == 0:
self.target_model.load_state_dict(self.model.state_dict())
self.update_step += 1
self.update_summary() # update summary

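To make the update in `learn()` above easy to check in isolation, here is a minimal, self-contained PyTorch sketch of the soft Bellman target and loss for a single discrete-action branch. The function name `soft_q_loss` and the random test tensors are illustrative only and not part of joyrl.

```python
import torch
import torch.nn.functional as F

def soft_q_loss(q_values, next_q_values, actions, rewards, dones, gamma=0.95, alpha=0.4):
    ''' One-branch soft Q-learning loss, mirroring the target computation in learn().
        Shapes: q_values / next_q_values [B, A]; actions, rewards, dones [B, 1].
    '''
    # Q(s, a) for the actions actually taken
    actual_q = q_values.gather(1, actions.long())
    # soft value of the next state: V_soft(s') = alpha * logsumexp(Q(s', .) / alpha)
    next_v_soft = alpha * torch.logsumexp(next_q_values / alpha, dim=1, keepdim=True)
    # soft Bellman target: y = r + gamma * (1 - done) * V_soft(s')
    target_q = rewards + gamma * (1.0 - dones) * next_v_soft
    # regress Q(s, a) toward the target; detach() keeps gradients out of the target values
    return F.mse_loss(actual_q, target_q.detach())

if __name__ == "__main__":
    # smoke test with random data: batch of 64, 2 discrete actions
    B, A = 64, 2
    loss = soft_q_loss(
        q_values = torch.randn(B, A, requires_grad = True),
        next_q_values = torch.randn(B, A),
        actions = torch.randint(0, A, (B, 1)),
        rewards = torch.randn(B, 1),
        dones = torch.zeros(B, 1),
    )
    print(loss.item())
```

As `alpha` approaches 0 the logsumexp term approaches a hard max over actions, so the target reduces to the standard DQN target.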
71 changes: 0 additions & 71 deletions joyrl/algos/SoftQ/softq.py

This file was deleted.
