Skip to content

Commit

Permalink
Fix several long running tests in test_auto_unit (#468)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #468

These tests were accidentally made to run longer in #463

Reviewed By: ananthsub

Differential Revision: D47656034

fbshipit-source-id: f900a9138a23317f0d885ee0c65715e3f22ca984
  • Loading branch information
daniellepintz authored and facebook-github-bot committed Jul 21, 2023
1 parent 5c9fe83 commit a6552e1
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tests/framework/test_auto_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def forward(self, x):
return x

my_module = Net()
my_swa_params = SWAParams(epoch_start=1, anneal_epochs=5)
my_swa_params = SWAParams(epoch_start=1, anneal_epochs=3)

auto_unit = DummyAutoUnit(
module=my_module,
Expand All @@ -223,8 +223,9 @@ def forward(self, x):
input_dim = 2
dataset_len = 10
batch_size = 2

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
train(auto_unit, dataloader)
train(auto_unit, dataloader, max_epochs=5, max_steps_per_epoch=1)

orig_module = auto_unit.module
swa_module = auto_unit.swa_model
Expand Down Expand Up @@ -623,7 +624,7 @@ def forward(self, x):
return x

my_module = Net()
my_swa_params = SWAParams(epoch_start=1, anneal_epochs=5)
my_swa_params = SWAParams(epoch_start=1, anneal_epochs=3)

auto_unit = DummyAutoUnit(
module=my_module,
Expand All @@ -634,8 +635,9 @@ def forward(self, x):
input_dim = 2
dataset_len = 10
batch_size = 2

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
train(auto_unit, dataloader)
train(auto_unit, dataloader, max_epochs=5, max_steps_per_epoch=1)

orig_module = auto_unit.module.module
swa_module = auto_unit.swa_model.module.module
Expand Down Expand Up @@ -695,9 +697,9 @@ def forward(self, x):
input_dim = 2
dataset_len = 10
batch_size = 2

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
state = get_dummy_train_state(dataloader)
train(state, auto_unit)
train(auto_unit, dataloader, max_epochs=1, max_steps_per_epoch=1)

tc = unittest.TestCase()
tc.assertTrue(custom_noop_hook_called)
Expand Down

0 comments on commit a6552e1

Please sign in to comment.