Skip to content

Commit

Permalink
Fix several long running tests in test_auto_unit
Browse files Browse the repository at this point in the history
Summary: These tests were accidentally made to run longer in #463

Differential Revision: D47656034

fbshipit-source-id: fb11998c0bb500d3e3fada71a09d69d06224963b
  • Loading branch information
daniellepintz authored and facebook-github-bot committed Jul 21, 2023
1 parent 5c9fe83 commit ec19540
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions tests/framework/test_auto_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,10 @@ def forward(self, x):
input_dim = 2
dataset_len = 10
batch_size = 2
max_epochs = 2

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
train(auto_unit, dataloader)
train(auto_unit, dataloader, max_epochs=max_epochs)

orig_module = auto_unit.module
swa_module = auto_unit.swa_model
Expand Down Expand Up @@ -634,8 +636,10 @@ def forward(self, x):
input_dim = 2
dataset_len = 10
batch_size = 2
max_epochs = 2

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
train(auto_unit, dataloader)
train(auto_unit, dataloader, max_epochs=max_epochs)

orig_module = auto_unit.module.module
swa_module = auto_unit.swa_model.module.module
Expand Down Expand Up @@ -695,9 +699,10 @@ def forward(self, x):
input_dim = 2
dataset_len = 10
batch_size = 2
max_epochs = 1

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
state = get_dummy_train_state(dataloader)
train(state, auto_unit)
train(auto_unit, dataloader, max_epochs=max_epochs)

tc = unittest.TestCase()
tc.assertTrue(custom_noop_hook_called)
Expand Down

0 comments on commit ec19540

Please sign in to comment.