objective_funcs.py
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Categorical

from utils import softmax

# Per-element MSE so callers can reduce (or re-weight) the loss themselves.
mse = nn.MSELoss(reduction='none')
def dqn(batch, model, current_model, discount, tau):
    """Standard DQN loss: per-element MSE between Q(s, a) and the bootstrapped target."""
    states, actions, terminals, rewards, next_states = batch
    # action values for the next states from the target network
    q_next = current_model(next_states).squeeze()
    # probs = softmax(q_next, tau)  # only needed for the expectation-based variant below
    # maximum action value of the next states (zeroed for terminal transitions)
    # expected_q_next = (1 - terminals) * torch.sum(probs * q_next, dim=1)
    max_q_next = (1 - terminals) * torch.max(q_next, dim=1)[0]
    # bootstrapped targets, detached so no gradient flows into the target network
    rewards = rewards.float()
    # targets = (rewards + discount * expected_q_next).detach().float()
    targets = (rewards + discount * max_q_next).detach().float()
    # Q-values for the visited states (batch_size, num_actions), gathered at the taken actions
    outputs = model(states.float()).squeeze()
    actions = actions.view(-1, 1)
    outputs = torch.gather(outputs, 1, actions).squeeze()
    # per-element loss
    loss = mse(outputs, targets)
    return loss
def actor_critic(batch, model, current_model, discount, tau):
    """Actor-critic style loss: softmax policy over the Q-values plus a value (critic) term."""
    # data from previous episodes
    states, actions, terminals, rewards, next_states = batch
    q_next = current_model(next_states).squeeze()
    # maximum action value of the next states (zeroed for terminal transitions)
    # probs = softmax(q_next, tau)  # only needed for the expectation-based variant below
    # expected_q_next = (1 - terminals) * torch.sum(probs * q_next, dim=1)
    max_q_next = (1 - terminals) * torch.max(q_next, dim=1)[0]
    # bootstrapped targets, detached so no gradient flows into the target network
    rewards = rewards.float()
    targets = (rewards + discount * max_q_next).detach().float()
    # Q-values for the visited states (batch_size, num_actions)
    outputs = model(states.float()).squeeze()
    # softmax policy over the Q-values and log-probabilities of the taken actions
    probs = softmax(outputs, tau)
    dist = Categorical(probs)
    log_probs = dist.log_prob(actions)
    actions = actions.view(-1, 1)
    outputs = torch.gather(outputs, 1, actions).squeeze()
    # advantage uses detached Q-values so only the policy term drives log_probs
    advantage = targets - outputs.detach()
    policy_loss = torch.sum(-log_probs * advantage)
    value_loss = torch.sum(F.smooth_l1_loss(outputs, targets))
    return policy_loss + value_loss
def dqn_priority(batch, model, current_model, discount, tau, buffer):
    """DQN loss for prioritized replay. Identical to dqn(); `buffer` is accepted but not
    used inside this function, and the per-element losses are returned to the caller."""
    states, actions, terminals, rewards, next_states = batch
    q_next = current_model(next_states).squeeze()
    # probs = softmax(q_next, tau)  # only needed for the expectation-based variant below
    # maximum action value of the next states (zeroed for terminal transitions)
    # expected_q_next = (1 - terminals) * torch.sum(probs * q_next, dim=1)
    max_q_next = (1 - terminals) * torch.max(q_next, dim=1)[0]
    # bootstrapped targets, detached so no gradient flows into the target network
    rewards = rewards.float()
    targets = (rewards + discount * max_q_next).detach().float()
    # Q-values for the visited states, gathered at the taken actions
    outputs = model(states.float()).squeeze()
    actions = actions.view(-1, 1)
    outputs = torch.gather(outputs, 1, actions).squeeze()
    # per-element loss (one value per transition)
    loss = mse(outputs, targets)
    return loss
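

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original pipeline):
# builds a tiny Q-network and a synthetic batch in the
# (states, actions, terminals, rewards, next_states) layout assumed above,
# then takes one gradient step on the DQN loss. The network sizes, optimizer
# and hyper-parameters here are arbitrary placeholder choices.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    state_dim, num_actions, batch_size = 4, 2, 8
    model = nn.Linear(state_dim, num_actions)          # online network (placeholder)
    current_model = nn.Linear(state_dim, num_actions)  # target network (placeholder)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # synthetic batch in the expected layout
    states = torch.randn(batch_size, state_dim)
    actions = torch.randint(0, num_actions, (batch_size,))
    terminals = torch.zeros(batch_size)
    rewards = torch.randn(batch_size)
    next_states = torch.randn(batch_size, state_dim)
    batch = (states, actions, terminals, rewards, next_states)

    # per-element losses are reduced here with a plain mean
    loss = dqn(batch, model, current_model, discount=0.99, tau=1.0).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("dqn loss:", loss.item())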