-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[0.6.7.2] rewrite action noise; update DDPG and TD3
- Loading branch information
1 parent
c86dc5b
commit 425b3a9
Showing
11 changed files
with
161 additions
and
140 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,13 +5,13 @@ | |
Email: [email protected] | ||
Date: 2023-01-01 16:20:49 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-07-21 16:02:00 | ||
LastEditTime: 2024-07-28 11:25:20 | ||
Discription: | ||
''' | ||
from joyrl import algos, framework, envs | ||
from joyrl.run import run | ||
|
||
__version__ = "0.6.7.1" | ||
__version__ = "0.6.7.2" | ||
|
||
__all__ = [ | ||
"algos", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,13 +5,11 @@ | |
Email: [email protected] | ||
Date: 2024-07-20 14:15:24 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-07-21 14:52:51 | ||
LastEditTime: 2024-07-21 16:38:14 | ||
Discription: | ||
''' | ||
import copy | ||
import torch.nn as nn | ||
from joyrl.algos.base.network import CriticNetwork, ActorNetwork | ||
|
||
from joyrl.algos.base.network import * | ||
|
||
class Model(nn.Module): | ||
def __init__(self, cfg ): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,21 +5,21 @@ | |
Email: [email protected] | ||
Date: 2024-02-25 15:46:04 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-07-21 15:11:52 | ||
LastEditTime: 2024-07-28 11:08:50 | ||
Discription: | ||
''' | ||
import torch | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
from joyrl.algos.base.policy import BasePolicy | ||
from joyrl.algos.base.noise import OUNoise | ||
from joyrl.algos.base.noise import MultiHeadActionNoise | ||
from joyrl.algos.base.network import * | ||
from .model import Model | ||
|
||
class Policy(BasePolicy): | ||
def __init__(self,cfg, **kwargs) -> None: | ||
super(Policy, self).__init__(cfg, **kwargs) | ||
self.ou_noise = OUNoise(self.action_size_list) | ||
self.action_noise = MultiHeadActionNoise('ou',self.action_size_list) | ||
self.gamma = cfg.gamma | ||
self.tau = cfg.tau | ||
self.sample_count = 0 # sample count | ||
|
@@ -40,9 +40,6 @@ def create_optimizer(self): | |
self.critic_optimizer = optim.Adam(self.model.critic.parameters(), lr=self.cfg.critic_lr) | ||
|
||
def create_summary(self): | ||
''' | ||
创建 tensorboard 数据 | ||
''' | ||
self.summary = { | ||
'scalar': { | ||
'tot_loss': 0.0, | ||
|
@@ -52,8 +49,6 @@ def create_summary(self): | |
} | ||
|
||
def update_summary(self): | ||
''' 更新 tensorboard 数据 | ||
''' | ||
if hasattr(self, 'tot_loss'): | ||
self.summary['scalar']['tot_loss'] = self.tot_loss.item() | ||
self.summary['scalar']['policy_loss'] = self.policy_loss.item() | ||
|
@@ -67,7 +62,7 @@ def sample_action(self, state, **kwargs): | |
actor_outputs = self.model.actor(state) | ||
self.mu = torch.cat([actor_outputs[i]['mu'] for i in range(len(self.action_size_list))], dim=1) | ||
actions = get_model_actions(self.model, mode = 'sample', actor_outputs = actor_outputs) | ||
actions = self.ou_noise.get_action(actions, self.sample_count) # add noise to action | ||
actions = self.action_noise.get_action(actions, t = self.sample_count) # add noise to action | ||
return actions | ||
|
||
def update_policy_transition(self): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
''' | ||
Author: JiangJi | ||
Email: [email protected] | ||
Date: 2024-07-21 16:37:59 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-07-21 16:38:00 | ||
Discription: | ||
''' | ||
import torch.nn as nn | ||
from joyrl.algos.base.network import * | ||
|
||
class Model(nn.Module): | ||
def __init__(self, cfg ): | ||
super(Model, self).__init__() | ||
state_size_list = cfg.obs_space_info.size | ||
action_size_list = cfg.action_space_info.size | ||
critic_input_size_list = state_size_list+ [[None, len(action_size_list)]] | ||
self.actor = ActorNetwork(cfg, input_size_list = state_size_list) | ||
self.critic_1 = CriticNetwork(cfg, input_size_list = critic_input_size_list) | ||
self.critic_2 = CriticNetwork(cfg, input_size_list = critic_input_size_list) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,14 +5,15 @@ | |
Email: [email protected] | ||
Date: 2023-12-25 09:28:26 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-07-21 15:59:56 | ||
LastEditTime: 2024-07-21 16:46:03 | ||
Discription: | ||
''' | ||
from enum import Enum | ||
import random | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.distributions import Categorical,Normal | ||
from enum import Enum | ||
from joyrl.algos.base.base_layer import LayerConfig | ||
from joyrl.algos.base.base_layer import create_layer | ||
|
||
|
@@ -179,6 +180,9 @@ def predict_action(self, **kwargs): | |
mean = kwargs.get("mean", None) | ||
return {"action": mean.detach().cpu().numpy().item(), "log_prob": None} | ||
|
||
def random_action(self, **kwargs): | ||
return {"action": random.uniform(self.action_low, self.action_high)} | ||
|
||
def get_log_prob_action(self, actor_output, action): | ||
''' get log_probs | ||
''' | ||
|
@@ -229,3 +233,6 @@ def predict_action(self, **kwargs): | |
action = mu * self.action_scale + self.action_bias | ||
return {"action": action.detach().cpu().numpy().item()} | ||
|
||
def random_action(self, **kwargs): | ||
return {"action": random.uniform(self.action_low, self.action_high)} | ||
|
Oops, something went wrong.