diff --git a/yogo/model.py b/yogo/model.py index 929314f6..bf9efbee 100644 --- a/yogo/model.py +++ b/yogo/model.py @@ -143,6 +143,7 @@ def from_pth( return model, { "step": global_step, "class_names": class_names, + "normalize_images": params["normalize_images"], } def to(self, device, *args, **kwargs): diff --git a/yogo/utils/test_model.py b/yogo/utils/test_model.py index fb0e9247..ecb11b8a 100644 --- a/yogo/utils/test_model.py +++ b/yogo/utils/test_model.py @@ -40,7 +40,7 @@ def test_model(args: argparse.Namespace) -> None: "slurm-job-id": os.getenv("SLURM_JOB_ID", default=None), } - log_to_wandb = args.wandb or len(args.wandb_resume_id) > 0 + log_to_wandb = args.wandb or (args.wandb_resume_id is not None) if log_to_wandb: print("logging to wandb")