Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: fairinternal/fairseq-py#822

Differential Revision: D16800078

Pulled By: myleott

fbshipit-source-id: b86e08e01f2fe13c64b77f1d23a5f6800f252bf7
  • Loading branch information
myleott authored and facebook-github-bot committed Aug 14, 2019
1 parent baa8ce1 commit 7c89e13
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 33 deletions.
5 changes: 4 additions & 1 deletion tests/speech_recognition/asr_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ def check_encoder_output(encoder_output, batch_size=None):
"encoder_padding_mask must be a torch.Tensor" + _current_postion_info()
)
return False, msg
if mask.dtype != torch.uint8:
if (
mask.dtype != torch.uint8
and (not hasattr(torch, 'bool') or mask.dtype != torch.bool)
):
msg = (
"encoder_padding_mask must have dtype of uint8"
+ _current_postion_info()
Expand Down
68 changes: 36 additions & 32 deletions tests/test_binaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_transformer(self):
'--decoder-layers', '2',
'--encoder-embed-dim', '8',
'--decoder-embed-dim', '8',
])
], run_validation=True)
generate_main(data_dir)

def test_lightconv(self):
Expand Down Expand Up @@ -257,7 +257,9 @@ def test_transformer_lm(self):
with tempfile.TemporaryDirectory('test_transformer_lm') as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
train_language_model(data_dir, 'transformer_lm', ['--add-bos-token'])
train_language_model(
data_dir, 'transformer_lm', ['--add-bos-token'], run_validation=True,
)
eval_lm_main(data_dir)


Expand Down Expand Up @@ -457,7 +459,7 @@ def preprocess_translation_data(data_dir, extra_flags=None):
preprocess.main(preprocess_args)


def train_translation_model(data_dir, arch, extra_flags=None, task='translation'):
def train_translation_model(data_dir, arch, extra_flags=None, task='translation', run_validation=False):
train_parser = options.get_training_parser()
train_args = options.parse_args_and_arch(
train_parser,
Expand All @@ -477,20 +479,21 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation'
)
train.main(train_args)

# test validation
validate_parser = options.get_validation_parser()
validate_args = options.parse_args_and_arch(
validate_parser,
[
'--task', task,
data_dir,
'--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--valid-subset', 'valid',
'--max-tokens', '500',
'--no-progress-bar',
]
)
validate.main(validate_args)
if run_validation:
# test validation
validate_parser = options.get_validation_parser()
validate_args = options.parse_args_and_arch(
validate_parser,
[
'--task', task,
data_dir,
'--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--valid-subset', 'valid',
'--max-tokens', '500',
'--no-progress-bar',
]
)
validate.main(validate_args)


def generate_main(data_dir, extra_flags=None):
Expand Down Expand Up @@ -534,7 +537,7 @@ def preprocess_lm_data(data_dir):
preprocess.main(preprocess_args)


def train_language_model(data_dir, arch, extra_flags=None):
def train_language_model(data_dir, arch, extra_flags=None, run_validation=False):
train_parser = options.get_training_parser()
train_args = options.parse_args_and_arch(
train_parser,
Expand All @@ -557,20 +560,21 @@ def train_language_model(data_dir, arch, extra_flags=None):
)
train.main(train_args)

# test validation
validate_parser = options.get_validation_parser()
validate_args = options.parse_args_and_arch(
validate_parser,
[
'--task', 'language_modeling',
data_dir,
'--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--valid-subset', 'valid',
'--max-tokens', '500',
'--no-progress-bar',
]
)
validate.main(validate_args)
if run_validation:
# test validation
validate_parser = options.get_validation_parser()
validate_args = options.parse_args_and_arch(
validate_parser,
[
'--task', 'language_modeling',
data_dir,
'--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--valid-subset', 'valid',
'--max-tokens', '500',
'--no-progress-bar',
]
)
validate.main(validate_args)


def eval_lm_main(data_dir):
Expand Down

0 comments on commit 7c89e13

Please sign in to comment.