[0.6.7.2] rewrite action noise; update DDPG and TD3
johnjim0816 committed Jul 28, 2024
1 parent c86dc5b commit 425b3a9
Showing 11 changed files with 161 additions and 140 deletions.
4 changes: 2 additions & 2 deletions joyrl/__init__.py
@@ -5,13 +5,13 @@
Email: [email protected]
Date: 2023-01-01 16:20:49
LastEditor: JiangJi
LastEditTime: 2024-07-21 16:02:00
LastEditTime: 2024-07-28 11:25:20
Discription:
'''
from joyrl import algos, framework, envs
from joyrl.run import run

__version__ = "0.6.7.1"
__version__ = "0.6.7.2"

__all__ = [
"algos",
6 changes: 2 additions & 4 deletions joyrl/algos/DDPG/model.py
@@ -5,13 +5,11 @@
Email: [email protected]
Date: 2024-07-20 14:15:24
LastEditor: JiangJi
LastEditTime: 2024-07-21 14:52:51
LastEditTime: 2024-07-21 16:38:14
Discription:
'''
import copy
import torch.nn as nn
from joyrl.algos.base.network import CriticNetwork, ActorNetwork

from joyrl.algos.base.network import *

class Model(nn.Module):
def __init__(self, cfg ):
13 changes: 4 additions & 9 deletions joyrl/algos/DDPG/policy.py
@@ -5,21 +5,21 @@
Email: [email protected]
Date: 2024-02-25 15:46:04
LastEditor: JiangJi
LastEditTime: 2024-07-21 15:11:52
LastEditTime: 2024-07-28 11:08:50
Discription:
'''
import torch
import torch.nn.functional as F
import torch.optim as optim
from joyrl.algos.base.policy import BasePolicy
from joyrl.algos.base.noise import OUNoise
from joyrl.algos.base.noise import MultiHeadActionNoise
from joyrl.algos.base.network import *
from .model import Model

class Policy(BasePolicy):
def __init__(self,cfg, **kwargs) -> None:
super(Policy, self).__init__(cfg, **kwargs)
self.ou_noise = OUNoise(self.action_size_list)
self.action_noise = MultiHeadActionNoise('ou',self.action_size_list)
self.gamma = cfg.gamma
self.tau = cfg.tau
self.sample_count = 0 # sample count
@@ -40,9 +40,6 @@ def create_optimizer(self):
self.critic_optimizer = optim.Adam(self.model.critic.parameters(), lr=self.cfg.critic_lr)

def create_summary(self):
'''
Create tensorboard data
'''
self.summary = {
'scalar': {
'tot_loss': 0.0,
@@ -52,8 +49,6 @@ def create_summary(self):
}

def update_summary(self):
''' Update tensorboard data
'''
if hasattr(self, 'tot_loss'):
self.summary['scalar']['tot_loss'] = self.tot_loss.item()
self.summary['scalar']['policy_loss'] = self.policy_loss.item()
@@ -67,7 +62,7 @@ def sample_action(self, state, **kwargs):
actor_outputs = self.model.actor(state)
self.mu = torch.cat([actor_outputs[i]['mu'] for i in range(len(self.action_size_list))], dim=1)
actions = get_model_actions(self.model, mode = 'sample', actor_outputs = actor_outputs)
actions = self.ou_noise.get_action(actions, self.sample_count) # add noise to action
actions = self.action_noise.get_action(actions, t = self.sample_count) # add noise to action
return actions

def update_policy_transition(self):
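Note: the new MultiHeadActionNoise class is not shown in this diff. A minimal sketch of the Ornstein-Uhlenbeck process it presumably wraps when constructed with 'ou' follows; the class name, default parameters, and the sigma-annealing schedule are assumptions, not joyrl's implementation.

import numpy as np

class OUNoiseSketch:
    ''' Ornstein-Uhlenbeck exploration noise for a single action head (illustrative only). '''
    def __init__(self, action_dim, mu=0.0, theta=0.15, max_sigma=0.3, min_sigma=0.05, decay_period=10000):
        self.mu, self.theta = mu, theta
        self.sigma, self.max_sigma, self.min_sigma = max_sigma, max_sigma, min_sigma
        self.decay_period = decay_period
        self.action_dim = action_dim
        self.state = np.ones(action_dim) * mu  # internal noise state x_t

    def get_action(self, action, t=0):
        # anneal sigma from max_sigma to min_sigma over decay_period steps
        self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min(1.0, t / self.decay_period)
        # x_{t+1} = x_t + theta * (mu - x_t) + sigma * N(0, 1)
        dx = self.theta * (self.mu - self.state) + self.sigma * np.random.randn(self.action_dim)
        self.state = self.state + dx
        return np.asarray(action) + self.state

A multi-head wrapper would then hold one such process per entry of action_size_list and perturb each head's action independently, which matches how sample_action above passes the step count via t = self.sample_count.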
22 changes: 22 additions & 0 deletions joyrl/algos/TD3/model.py
@@ -0,0 +1,22 @@
#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: [email protected]
Date: 2024-07-21 16:37:59
LastEditor: JiangJi
LastEditTime: 2024-07-21 16:38:00
Discription:
'''
import torch.nn as nn
from joyrl.algos.base.network import *

class Model(nn.Module):
def __init__(self, cfg ):
super(Model, self).__init__()
state_size_list = cfg.obs_space_info.size
action_size_list = cfg.action_space_info.size
critic_input_size_list = state_size_list+ [[None, len(action_size_list)]]
self.actor = ActorNetwork(cfg, input_size_list = state_size_list)
self.critic_1 = CriticNetwork(cfg, input_size_list = critic_input_size_list)
self.critic_2 = CriticNetwork(cfg, input_size_list = critic_input_size_list)
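For context: the two critics defined above exist so TD3 can form pessimistic targets via clipped double Q-learning, taking the element-wise minimum of both estimates. A toy illustration, independent of the joyrl network classes:

import torch

q1 = torch.tensor([[1.2], [0.7], [2.0]])  # critic_1 estimates for a batch of 3 transitions
q2 = torch.tensor([[0.9], [1.1], [1.8]])  # critic_2 estimates for the same transitions
target_q = torch.min(q1, q2)              # element-wise min -> [[0.9], [0.7], [1.8]]

Using the minimum rather than either critic alone curbs the overestimation bias that a single bootstrapped critic tends to accumulate.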
114 changes: 44 additions & 70 deletions joyrl/algos/TD3/policy.py
@@ -1,69 +1,44 @@
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from joyrl.algos.base.policy import BasePolicy
from joyrl.algos.base.network import CriticNetwork, ActorNetwork

from joyrl.algos.base.noise import MultiHeadActionNoise
from joyrl.algos.base.network import *
from .model import Model

class Policy(BasePolicy):
def __init__(self, cfg):
super(Policy, self).__init__(cfg)
self.cfg = cfg
self.gamma = cfg.gamma
self.actor_lr = cfg.actor_lr
self.critic_lr = cfg.critic_lr
self.policy_noise = cfg.policy_noise # noise added to target policy during critic update
self.noise_clip = cfg.noise_clip # range to clip target policy noise
self.expl_noise = cfg.expl_noise # std of Gaussian exploration noise
self.action_noise = MultiHeadActionNoise('random',self.action_size_list, theta = cfg.expl_noise)
self.action_lows = [self.cfg.action_space_info.size[i][0] for i in range(len(self.action_size_list))]
self.action_highs = [self.cfg.action_space_info.size[i][1] for i in range(len(self.action_size_list))]
self.action_scales = [self.action_highs[i] - self.action_lows[i] for i in range(len(self.action_size_list))]
self.action_biases = [self.action_highs[i] + self.action_lows[i] for i in range(len(self.action_size_list))]
self.policy_freq = cfg.policy_freq # policy update frequency
self.tau = cfg.tau
self.sample_count = 0
self.update_step = 0
self.explore_steps = cfg.explore_steps # exploration steps before training
self.device = torch.device(cfg.device)
self.action_high = torch.FloatTensor(self.action_space.high).to(self.device)
self.action_low = torch.FloatTensor(self.action_space.low).to(self.device)
self.action_scale = torch.tensor((self.action_space.high - self.action_space.low)/2, device=self.device, dtype=torch.float32)
self.action_bias = torch.tensor((self.action_space.high + self.action_space.low)/2, device=self.device, dtype=torch.float32)
self.create_graph() # create graph and optimizer
self.create_summary() # create summary
self.to(self.device)

def get_action_size(self):
''' get action size
'''
# action_size must be [action_dim_1, action_dim_2, ...]
self.action_size_list = [self.action_space.shape[0]]
self.action_type_list = ['dpg']
self.action_high_list = [self.action_space.high[0]]
self.action_low_list = [self.action_space.low[0]]
setattr(self.cfg, 'action_size_list', self.action_size_list)
setattr(self.cfg, 'action_type_list', self.action_type_list)
setattr(self.cfg, 'action_high_list', self.action_high_list)
setattr(self.cfg, 'action_low_list', self.action_low_list)

def create_graph(self):
critic_input_size_list = self.state_size_list + [[None, self.action_size_list[0]]]
self.actor = ActorNetwork(self.cfg, input_size_list = self.state_size_list)
self.critic_1 = CriticNetwork(self.cfg, input_size_list = critic_input_size_list)
self.critic_2 = CriticNetwork(self.cfg, input_size_list = critic_input_size_list)
self.target_actor = ActorNetwork(self.cfg, input_size_list = self.state_size_list)
self.target_critic_1 = CriticNetwork(self.cfg, input_size_list = critic_input_size_list)
self.target_critic_2 = CriticNetwork(self.cfg, input_size_list = critic_input_size_list)
self.target_actor.load_state_dict(self.actor.state_dict())
self.target_critic_1.load_state_dict(self.critic_1.state_dict())
self.target_critic_2.load_state_dict(self.critic_2.state_dict())
self.create_optimizer()
def create_model(self):
''' create graph and optimizer
'''
self.model = Model(self.cfg)
self.target_model = Model(self.cfg)
self.target_model.load_state_dict(self.model.state_dict()) # or use this to copy parameters

def create_optimizer(self):
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr = self.actor_lr)
self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(), lr = self.critic_lr)
self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(), lr = self.critic_lr)
self.actor_optimizer = optim.Adam(self.model.actor.parameters(), lr=self.cfg.actor_lr)
self.critic_1_optimizer = optim.Adam(self.model.critic_1.parameters(), lr=self.cfg.critic_lr)
self.critic_2_optimizer = optim.Adam(self.model.critic_2.parameters(), lr=self.cfg.critic_lr)
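
The soft_update helper called in learn() below comes from BasePolicy and is not part of this diff; a common Polyak-averaging implementation looks roughly like this (a sketch, not necessarily joyrl's exact code):

def soft_update(local_model, target_model, tau):
    ''' Polyak averaging: target <- tau * local + (1 - tau) * target. '''
    for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
        target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)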

def create_summary(self):
'''
Create tensorboard data
'''
self.summary = {
'scalar': {
'tot_loss': 0.0,
@@ -74,8 +49,6 @@ def create_summary(self):
}

def update_summary(self):
''' Update tensorboard data
'''
if hasattr(self, 'tot_loss'):
self.summary['scalar']['tot_loss'] = self.tot_loss.item()
try:
@@ -88,33 +61,33 @@ def update_summary(self):
def sample_action(self, state, **kwargs):
self.sample_count += 1
if self.sample_count < self.explore_steps:
return self.action_space.sample()
return get_model_actions(self.model, mode = 'random', actor_outputs = [{}] * len(self.action_size_list))
else:
action = self.predict_action(state, **kwargs)
action_noise = self.expl_noise * np.random.normal(0, self.action_scale.cpu().detach().numpy(), size=self.action_size_list[0])
action = (action + action_noise).clip(self.action_space.low, self.action_space.high)
return action
actions = self.predict_action(state, **kwargs)
actions = self.action_noise.get_action(actions, t = self.sample_count) # add noise to action
return actions

@torch.no_grad()
def predict_action(self, state, **kwargs):
state = [torch.tensor(np.array(state), device=self.device, dtype=torch.float32).unsqueeze(dim=0)]
_ = self.actor(state)
action = self.actor.action_layers.get_actions()
return action[0]
state = self.process_sample_state(state)
actor_outputs = self.model.actor(state)
actions = get_model_actions(self.model, mode = 'predict', actor_outputs = actor_outputs)
return actions

def learn(self, **kwargs):
# state, action, reward, next_state, done = self.memory.sample(self.batch_size)
states, actions, next_states, rewards, dones = kwargs.get('states'), kwargs.get('actions'), kwargs.get('next_states'), kwargs.get('rewards'), kwargs.get('dones')

super().learn(**kwargs)
# update critic
noise = (torch.randn_like(actions) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
next_actions = self.target_actor(next_states)[0]['mu']
# next_actions = ((next_actions + noise) * self.action_scale + self.action_bias).clamp(-self.action_scale+self.action_bias, self.action_scale+ self.action_bias)
next_actions = (next_actions * self.action_scale + self.action_bias + noise).clamp(self.action_low, self.action_high)
target_q1, target_q2 = self.target_critic_1([next_states, next_actions]).detach(), self.target_critic_2([next_states, next_actions]).detach()
next_actor_outputs = self.target_model.actor(self.next_states)
# next_actions = get_model_actions(self.target_model, mode = 'predict', actor_outputs = actor_outputs)
next_mus = torch.cat([next_actor_outputs[i]['mu'] for i in range(len(self.action_size_list))], dim=1)
# noise = (torch.randn_like(next_mus) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
# next_mus_noised = (next_mus + noise).clamp(self.action_low, self.action_high)
target_q1, target_q2 = self.target_model.critic_1(self.next_states+ [next_mus]).detach(), self.target_model.critic_2(self.next_states+ [next_mus]).detach()
target_q = torch.min(target_q1, target_q2) # shape:[train_batch_size,n_actions]
target_q = rewards + self.gamma * target_q * (1 - dones)
current_q1, current_q2 = self.critic_1([states, actions]), self.critic_2([states, actions])
target_q = self.rewards + self.gamma * target_q * (1 - self.dones)
actions = [ (self.actions[i] - self.action_biases[i])/ self.action_scales[i] for i in range(len(self.actions)) ]
actions = torch.cat(actions, dim=1)
current_q1, current_q2 = self.model.critic_1(self.states + [actions]), self.model.critic_2(self.states + [actions])
# compute critic loss
critic_1_loss = F.mse_loss(current_q1, target_q)
critic_2_loss = F.mse_loss(current_q2, target_q)
@@ -125,19 +98,20 @@ def learn(self, **kwargs):
self.critic_2_optimizer.zero_grad()
critic_2_loss.backward()
self.critic_2_optimizer.step()
self.soft_update(self.critic_1, self.target_critic_1, self.tau)
self.soft_update(self.critic_2, self.target_critic_2, self.tau)
self.soft_update(self.model.critic_1, self.target_model.critic_1, self.tau)
self.soft_update(self.model.critic_2, self.target_model.critic_2, self.tau)
# Delayed policy updates
if self.sample_count % self.policy_freq == 0:
if self.update_step % self.policy_freq == 0:
# compute actor loss
act_ = self.actor(states)[0]['mu'] * self.action_scale + self.action_bias
actor_loss = -self.critic_1([states, act_]).mean()
actor_outputs = self.model.actor(self.states)
mus = torch.cat([actor_outputs[i]['mu'] for i in range(len(self.action_size_list))], dim=1)
actor_loss = -self.model.critic_1(self.states + [mus]).mean()
self.policy_loss = actor_loss
self.tot_loss = self.policy_loss + self.value_loss1 + self.value_loss2
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
self.soft_update(self.actor, self.target_actor, self.tau)
self.soft_update(self.model.actor, self.target_model.actor, self.tau)
self.update_step += 1
self.update_summary()
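
For reference: canonical TD3 also applies target policy smoothing, adding clipped Gaussian noise to the target actor's action before querying the twin critics; the corresponding noise/clamp lines appear above but are left commented out. A generic sketch of that step, with bounds and defaults that are purely illustrative:

import torch

def smoothed_target_action(next_mu, policy_noise=0.2, noise_clip=0.5, low=-1.0, high=1.0):
    # perturb the target action with clipped Gaussian noise, then clamp to the valid range
    noise = (torch.randn_like(next_mu) * policy_noise).clamp(-noise_clip, noise_clip)
    return (next_mu + noise).clamp(low, high)

The target itself then follows y = r + gamma * (1 - done) * min(Q1'(s', a'), Q2'(s', a')), which is what the target_q lines above compute.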

11 changes: 9 additions & 2 deletions joyrl/algos/base/action_layer.py
@@ -5,14 +5,15 @@
Email: [email protected]
Date: 2023-12-25 09:28:26
LastEditor: JiangJi
LastEditTime: 2024-07-21 15:59:56
LastEditTime: 2024-07-21 16:46:03
Discription:
'''
from enum import Enum
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical,Normal
from enum import Enum
from joyrl.algos.base.base_layer import LayerConfig
from joyrl.algos.base.base_layer import create_layer

@@ -179,6 +180,9 @@ def predict_action(self, **kwargs):
mean = kwargs.get("mean", None)
return {"action": mean.detach().cpu().numpy().item(), "log_prob": None}

def random_action(self, **kwargs):
return {"action": random.uniform(self.action_low, self.action_high)}

def get_log_prob_action(self, actor_output, action):
''' get log_probs
'''
@@ -229,3 +233,6 @@ def predict_action(self, **kwargs):
action = mu * self.action_scale + self.action_bias
return {"action": action.detach().cpu().numpy().item()}

def random_action(self, **kwargs):
return {"action": random.uniform(self.action_low, self.action_high)}
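
These random_action heads back the mode = 'random' branch that TD3's sample_action uses for warm-up before explore_steps. A schematic of that exploration schedule, with the helper names invented for illustration rather than taken from joyrl:

import random

def warmup_explore(step, explore_steps, low, high, policy_action_fn, noise_fn):
    ''' Illustrative schedule: uniform random actions during warm-up,
        then the deterministic policy action perturbed by exploration noise. '''
    if step < explore_steps:
        return random.uniform(low, high)        # what random_action above supplies per head
    return noise_fn(policy_action_fn(), t=step)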
