Skip to content

Commit

Permalink
Fix sac
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Jun 20, 2024
1 parent 7c98b1c commit 46df033
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion torchbenchmark/models/soft_actor_critic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .envs import load_gym
from .sac import SACAgent
from .replay import PrioritizedReplayBuffer, ReplayBuffer
from .utils import hard_update, soft_update
from .sac_utils import hard_update, soft_update


def learn_standard(
Expand Down
12 changes: 6 additions & 6 deletions torchbenchmark/models/soft_actor_critic/nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import distributions as pyd
from torch import nn

from . import utils
from . import sac_utils
from torchbenchmark.util.distribution import SquashedNormal

def weight_init(m):
Expand All @@ -30,11 +30,11 @@ def __init__(self, obs_shape, out_dim=50):
self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1)
self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1)

output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
obs_shape[1:], kernel_size=(3, 3), stride=(2, 2)
)
for _ in range(3):
output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
(output_height, output_width), kernel_size=(3, 3), stride=(1, 1)
)

Expand Down Expand Up @@ -63,15 +63,15 @@ def __init__(self, obs_shape, out_dim=50):
self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
obs_shape[1:], kernel_size=(8, 8), stride=(4, 4)
)

output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
(output_height, output_width), kernel_size=(4, 4), stride=(2, 2)
)

output_height, output_width = utils.compute_conv_output(
output_height, output_width = sac_utils.compute_conv_output(
(output_height, output_width), kernel_size=(3, 3), stride=(1, 1)
)

Expand Down
2 changes: 1 addition & 1 deletion torchbenchmark/models/soft_actor_critic/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import torch

from . import envs, nets, replay, utils
from . import envs, nets, replay, sac_utils


class SACAgent:
Expand Down

0 comments on commit 46df033

Please sign in to comment.