Skip to content

Commit

Permalink
Eval adaptations (#2)
Browse files (browse the repository at this point in the history)
* readme: Slightly adapt

* interface: Allow empty save_path

* interface: Add methods to get results

* blearndataset: Add warning for depth norm file not existing

* interface: Turn off random flip if not training
  • Loading branch information
marcojob authored Oct 1, 2024
1 parent 892b059 commit d4d168e
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/desktop/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},
"remoteUser": "asl",
"initializeCommand": ".devcontainer/devcontainer-optional-mounts.sh",
"postStartCommand": "pip3 install -e .",
"postStartCommand": "/bin/bash",
"mounts": [
{
"source": "${localEnv:HOME}/.bash-git-prompt",
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## Overview

### metric_depth_network
Contains the metric network based on [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2). All important code diffs compared to this upstream codebase are outlined generally.
Contains the metric network based on [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2). All important code diffs compared to this upstream codebase are outlined in general.

## Unittests
## Unit tests
Unittests can be run either using `python3 -m unittest` or the CI script `ci/pr_unittest.bash`.
36 changes: 23 additions & 13 deletions radarmeetsvision/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self):
self.optimizer = None
self.output_channels = None
self.previous_best = self.reset_previous_best()
self.results = None
self.results_path = None
self.use_depth_prior = None

Expand Down Expand Up @@ -71,6 +72,9 @@ def set_results_path(self, results_path):
else:
logger.error(f'{self.results_path} does not exist')

def get_results(self):
return self.results, self.results_per_sample

def load_model(self, pretrained_from=None):
if self.encoder is not None and self.max_depth is not None and self.output_channels is not None and self.use_depth_prior is not None:
logger.info(f'Using pretrained from: {pretrained_from}')
Expand Down Expand Up @@ -120,7 +124,7 @@ def get_dataset_loader(self, task, datasets_dir, dataset_list):

# Convert to MultiDataset (also ok for one)
dataset = MultiDatasetLoader(datasets, self.depth_min_max)
if task == 'train':
if 'train' in task:
loader = DataLoader(dataset, batch_size=self.batch_size, pin_memory=True, drop_last=True, shuffle=True)

else:
Expand Down Expand Up @@ -193,9 +197,9 @@ def train_epoch(self, epoch, train_loader):
def validate_epoch(self, epoch, val_loader):
self.model.eval()

results, nsamples = get_empty_results(self.device)
self.results, self.results_per_sample, nsamples = get_empty_results(self.device)
for i, sample in enumerate(val_loader):
image, _, depth_target, mask = self.prepare_sample(sample, random_flip=True)
image, _, depth_target, mask = self.prepare_sample(sample, random_flip=False)

# TODO: Maybe not hardcode 10 here?
if mask.sum() > 10:
Expand All @@ -206,20 +210,26 @@ def validate_epoch(self, epoch, val_loader):

current_results = eval_depth(depth_prediction[mask], depth_target[mask])
if current_results is not None:
for k in results.keys():
results[k] += current_results[k]
for k in self.results.keys():
self.results[k] += current_results[k]
self.results_per_sample[k].append(current_results[k])
nsamples += 1

self.update_best_result(results, nsamples)
if i % 10 == 0:
abs_rel = (self.results["abs_rel"]/nsamples).item()
logger.info(f'Iter: {i}/{len(val_loader)}, Absrel: {abs_rel:.3f}')

self.update_best_result(self.results, nsamples)
self.save_checkpoint(epoch)


def save_checkpoint(self, epoch):
checkpoint = {
'model': self.model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'epoch': epoch,
'previous_best': self.previous_best,
}
if self.results_path is not None:
if self.results_path is not None and len(str(self.results_path)) > 1:
checkpoint = {
'model': self.model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'epoch': epoch,
'previous_best': self.previous_best,
}
# TODO: How to check properly if current path is not .
torch.save(checkpoint, self.results_path / f'latest_{epoch}.pth')
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,11 @@ def get_depth(self, index):
depth = np.load(depth_path_alt)

elif depth_normalized_path.is_file():
logger.warning("found normalized depth")
if self.depth_range is not None and self.depth_min is not None:
depth_normalized = np.load(depth_normalized_path)
depth_valid_mask = (depth_normalized > 0.0) & (depth_normalized <= 1.0)
depth = np.zeros(depth_normalized.shape, dtype='float32')
depth[depth_valid_mask] = depth_normalized[depth_valid_mask] * self.depth_range + self.depth_min
logger.warning("unormalized image")

else:
logger.error("Only found normalized depth, but did not find depth normalization file")
Expand Down Expand Up @@ -218,6 +216,9 @@ def get_depth_range(self):
norm_range = norm_max - norm_min
logger.info(f'Found norm range {norm_range:.3f} m')

else:
logger.error(f"Could not find: {depth_norm_file}")

return norm_min, norm_max, norm_range

def get_depth_prior(self, index, img_copy, depth):
Expand Down
31 changes: 21 additions & 10 deletions radarmeetsvision/metric_depth_network/util/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,29 @@ def print_epoch_summary(epoch, epochs, result_dict):

def get_empty_results(device):
results = {
'd1': torch.tensor([0.0]).to(device),
'd2': torch.tensor([0.0]).to(device),
'd3': torch.tensor([0.0]).to(device),
'abs_rel': torch.tensor([0.0]).to(device),
'sq_rel': torch.tensor([0.0]).to(device),
'rmse': torch.tensor([0.0]).to(device),
'rmse_log': torch.tensor([0.0]).to(device),
'log10': torch.tensor([0.0]).to(device),
'silog': torch.tensor([0.0]).to(device)
'd1': 0.0,
'd2': 0.0,
'd3': 0.0,
'abs_rel': 0.0,
'sq_rel': 0.0,
'rmse': 0.0,
'rmse_log': 0.0,
'log10': 0.0,
'silog': 0.0
}
results_per_sample = {
'd1': [],
'd2': [],
'd3': [],
'abs_rel': [],
'sq_rel': [],
'rmse': [],
'rmse_log': [],
'log10': [],
'silog': []
}
nsamples = torch.tensor([0.0]).to(device)
return results, nsamples
return results, results_per_sample, nsamples

def randomly_flip(img, target, valid_mask):
if random.random() < 0.5:
Expand Down
2 changes: 1 addition & 1 deletion radarmeetsvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def setup_global_logger(output_dir=None):
console_handler.setFormatter(color_formatter)
logger.addHandler(console_handler)

if output_dir is not None:
if output_dir is not None and len(output_dir) > 0:
log_file = datetime.now().strftime(f"{output_dir}/log_%Y-%m-%d_%H-%M-%S") + '.txt'
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.INFO)
Expand Down
5 changes: 4 additions & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ def main(config, checkpoints, datasets, results):
interface.set_output_channels(config['output_channels'])
interface.set_use_depth_prior(config['use_depth_prior'])

pretrained_from = Path(checkpoints) / config['pretrained_from']
pretrained_from = None
if config['pretrained_from'] is not None:
pretrained_from = Path(checkpoints) / config['pretrained_from']
interface.load_model(pretrained_from=pretrained_from)

interface.set_results_path(results)
interface.set_optimizer()

Expand Down

0 comments on commit d4d168e

Please sign in to comment.