-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[0.6.7.4] [feat]: update Soft Q learning
- Loading branch information
1 parent
5db764e
commit 297118d
Showing
11 changed files
with
247 additions
and
219 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,13 +5,13 @@ | |
Email: [email protected] | ||
Date: 2023-01-01 16:20:49 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-07-28 13:10:04 | ||
LastEditTime: 2024-07-31 14:26:18 | ||
Discription: | ||
''' | ||
from joyrl import algos, framework, envs | ||
from joyrl.run import run | ||
|
||
__version__ = "0.6.7.3" | ||
__version__ = "0.6.7.4" | ||
|
||
__all__ = [ | ||
"algos", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2024-01-25 09:58:33 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-07-21 15:17:11 | ||
LastEditTime: 2024-07-31 10:15:12 | ||
Discription: | ||
''' | ||
import torch | ||
|
@@ -27,7 +27,6 @@ def __init__(self, cfg, **kwargs): | |
self.target_update = cfg.target_update | ||
self.sample_count = 0 | ||
self.update_step = 0 | ||
self.ou_noise = OUNoise(self.action_size_list) | ||
|
||
def create_model(self): | ||
self.model = QNetwork(self.cfg, self.state_size_list).to(self.device) | ||
|
@@ -52,7 +51,6 @@ def sample_action(self, state, **kwargs): | |
action = self.predict_action(state) | ||
else: | ||
action = get_model_actions(self.model, mode = 'random', actor_outputs = [{}] * len(self.action_size_list)) | ||
action = self.ou_noise.get_action(action, self.sample_count) | ||
return action | ||
|
||
@torch.no_grad() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
''' | ||
Author: JiangJi | ||
Email: [email protected] | ||
Date: 2024-07-30 13:40:26 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-07-30 13:40:27 | ||
Discription: | ||
''' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
''' | ||
Author: JiangJi | ||
Email: [email protected] | ||
Date: 2023-12-20 23:39:18 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-07-30 13:40:49 | ||
Discription: | ||
''' | ||
class AlgoConfig():
    """Hyper-parameter container for the Soft Q-learning algorithm."""

    def __init__(self) -> None:
        # Exploration schedule: epsilon decays exponentially from start to end.
        # Using epsilon_start == epsilon_end gives a constant epsilon.
        self.epsilon_start = 0.95        # initial epsilon value
        self.epsilon_end = 0.01          # final epsilon value
        self.epsilon_decay = 500         # decay rate (in sample steps)
        self.gamma = 0.95                # discount factor for future rewards
        self.alpha = 0.4                 # softmax temperature parameter
        self.lr = 0.0001                 # optimizer learning rate
        self.buffer_type = 'REPLAY_QUE'  # experience replay buffer type
        self.max_buffer_size = 100000    # maximum replay buffer capacity
        self.batch_size = 64             # mini-batch size per update
        self.target_update = 4           # target network sync frequency (update steps)
        # Per-input-branch value-network layer specs; example entry:
        # {'name': 'feature_1', 'layers':
        #     [{'layer_type': 'linear', 'layer_size': [64], 'activation': 'ReLU'},
        #      {'layer_type': 'linear', 'layer_size': [64], 'activation': 'ReLU'}]}
        self.branch_layers = []
        # Layers applied after the branch outputs are merged; example entry:
        # {'layer_type': 'linear', 'layer_size': [256], 'activation': 'ReLU'}
        self.merge_layers = []
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
''' | ||
Author: JiangJi | ||
Email: [email protected] | ||
Date: 2024-07-30 13:40:11 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-07-30 13:40:12 | ||
Discription: | ||
''' | ||
from joyrl.algos.base.data_handler import BaseDataHandler | ||
|
||
class DataHandler(BaseDataHandler):
    """Soft Q-learning data handler.

    No algorithm-specific processing is needed; all behavior is inherited
    unchanged from BaseDataHandler.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
''' | ||
Author: JiangJi | ||
Email: [email protected] | ||
Date: 2024-01-25 09:58:33 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-07-31 10:19:50 | ||
Discription: | ||
''' | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import math,random | ||
from joyrl.algos.base.policy import BasePolicy | ||
from joyrl.algos.base.noise import OUNoise | ||
from joyrl.algos.base.network import * | ||
|
||
class Policy(BasePolicy):
    ''' Soft Q-learning policy.

    Exploration is epsilon-greedy on top of a Boltzmann (softmax) action
    distribution: with probability epsilon a uniformly random action is
    drawn, otherwise predict_action() samples
    a ~ pi(a|s) = exp((Q(s,a) - V(s)) / alpha),
    where V(s) = alpha * logsumexp(Q(s,.) / alpha) is the soft state value.
    Learning regresses Q(s,a) toward r + gamma * (1 - done) * V_soft(s')
    computed from a periodically re-synced target network.
    '''
    def __init__(self, cfg, **kwargs):
        super(Policy, self).__init__(cfg, **kwargs)
        self.gamma = cfg.gamma  # reward discount factor
        self.alpha = cfg.alpha  # softmax temperature of the soft value / action distribution
        # e-greedy parameters
        self.epsilon_start = cfg.epsilon_start
        self.epsilon_end = cfg.epsilon_end
        self.epsilon_decay = cfg.epsilon_decay
        self.target_update = cfg.target_update  # sync target net every `target_update` learn steps
        self.sample_count = 0  # total actions sampled; drives the epsilon decay schedule
        self.update_step = 0  # number of completed learn() calls

    def create_model(self):
        ''' Build the online and target Q-networks and start them fully synced. '''
        self.model = QNetwork(self.cfg, self.state_size_list).to(self.device)
        self.target_model = QNetwork(self.cfg, self.state_size_list).to(self.device)
        self.target_model.load_state_dict(self.model.state_dict())  # copy initial parameters

    def load_model_meta(self, model_meta):
        ''' Restore policy metadata.

        Reloading sample_count resumes the epsilon decay schedule where a
        previous run left off instead of restarting at epsilon_start.
        '''
        super().load_model_meta(model_meta)
        if model_meta.get('sample_count') is not None:
            self.sample_count = model_meta['sample_count']

    def sample_action(self, state, **kwargs):
        ''' Sample an action for environment interaction (epsilon-greedy).

        Epsilon decays exponentially from epsilon_start to epsilon_end with
        rate epsilon_decay, balancing exploration and exploitation.
        '''
        self.sample_count += 1
        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
            math.exp(-1. * self.sample_count / self.epsilon_decay)
        self.update_model_meta({'sample_count': self.sample_count})
        if random.random() > self.epsilon:
            # exploit: sample from the softmax policy implied by the Q-network
            action = self.predict_action(state)
        else:
            # explore: random action for every action branch
            action = get_model_actions(self.model, mode = 'random', actor_outputs = [{}] * len(self.action_size_list))
        return action

    @torch.no_grad()
    def predict_action(self, state, **kwargs):
        ''' Sample an action from the Boltzmann distribution over Q-values.

        Returns a list with one sampled integer action per action branch.
        '''
        state = self.process_sample_state(state)
        model_outputs = self.model(state)
        actor_outputs = model_outputs['actor_outputs']
        q_values = [actor_outputs[i]['q_value'] for i in range(len(self.action_size_list))]
        # soft state value per branch: V(s) = alpha * logsumexp(Q(s,.) / alpha)
        values_soft = [self.alpha * torch.logsumexp(q_value / self.alpha, dim=1, keepdim=True) for q_value in q_values]
        # BUGFIX: the logits must be divided by alpha — pi(a|s) = exp((Q - V) / alpha).
        # Without the division, softmax(Q - V) == softmax(Q) by shift invariance,
        # so the temperature alpha had no effect on action sampling.
        probs = [F.softmax((q_value - value_soft) / self.alpha, dim=1) for q_value, value_soft in zip(q_values, values_soft)]
        dists = [torch.distributions.Categorical(probs = prob) for prob in probs]
        actions = [dist.sample().cpu().numpy().item() for dist in dists]
        return actions

    def learn(self, **kwargs):
        ''' One gradient update of the online Q-network from a sampled batch.

        Loss is the sum over action branches of
        MSE(Q(s,a), r + gamma * (1 - done) * V_soft(s')); the target network
        is re-synced every `target_update` learn steps.
        '''
        super().learn(**kwargs)  # populates self.states/actions/rewards/next_states/dones
        self.summary_loss = []
        tot_loss = 0
        mse_loss = nn.MSELoss()  # hoisted: one criterion reused across branches
        actor_outputs = self.model(self.states)['actor_outputs']
        target_actor_outputs = self.target_model(self.next_states)['actor_outputs']
        for i in range(len(self.action_size_list)):
            # current Q estimate of the taken action
            actual_q_value = actor_outputs[i]['q_value'].gather(1, self.actions[i].long())
            # soft value of the next state from the target network
            next_q_value = target_actor_outputs[i]['q_value']
            next_v_soft = self.alpha * torch.logsumexp(next_q_value / self.alpha, dim=1, keepdim=True)
            # soft Bellman target; (1 - dones) zeroes the bootstrap at episode ends
            target_q_value = self.rewards + (1 - self.dones) * self.gamma * next_v_soft
            loss_i = mse_loss(actual_q_value, target_q_value)
            tot_loss += loss_i
            self.summary_loss.append(loss_i.item())
        self.optimizer.zero_grad()
        tot_loss.backward()
        # clip gradients elementwise to [-1, 1] to avoid gradient explosion
        for param in self.model.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()
        # update target net every C steps (syncs on the very first call as well)
        if self.update_step % self.target_update == 0:
            self.target_model.load_state_dict(self.model.state_dict())
        self.update_step += 1
        self.update_summary()  # update summary
|
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.