Skip to content

Commit

Permalink
feat: Align classification results with sklearn classification ones (#63
Browse files Browse the repository at this point in the history
)

* feat: Align classification results with sklearn classification ones

* build: Bump version 1.2.3 -> 1.2.4

* docs: Add changelog

* feat: Print also task config during experiment starting phase (#62)

* fix: Add->Added in changelog

Approved By: @lorenzomammana
  • Loading branch information
AlessandroPolidori authored Sep 20, 2023
1 parent ed54da9 commit 4a0cd8d
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 7 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@

# Changelog
All notable changes to this project will be documented in this file.
### [1.2.4]
#### Added

- Return also probabilities in Classification's module predict step and add them to `self.res`.


### [1.2.3]

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "quadra"
version = "1.2.3"
version = "1.2.4"
description = "Deep Learning experiment orchestration library"
authors = [
{ name = "Alessandro Polidori", email = "[email protected]" },
Expand Down Expand Up @@ -118,7 +118,7 @@ repository = "https://github.com/orobix/quadra"

# Adapted from https://realpython.com/pypi-publish-python-package/#version-your-package
[tool.bumpver]
current_version = "1.2.3"
current_version = "1.2.4"
version_pattern = "MAJOR.MINOR.PATCH"
commit_message = "build: Bump version {old_version} -> {new_version}"
commit = true
Expand Down
2 changes: 1 addition & 1 deletion quadra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.2.3"
__version__ = "1.2.4"


def get_version():
Expand Down
2 changes: 1 addition & 1 deletion quadra/modules/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
grayscale_cam = ndimage.zoom(grayscale_cam_low_res, zoom_factors, order=1)
else:
grayscale_cam = None
return predicted_classes, grayscale_cam
return predicted_classes, grayscale_cam, torch.max(probs, dim=1)[0].tolist()


class MultilabelClassificationModule(BaseLightningModule):
Expand Down
9 changes: 6 additions & 3 deletions quadra/tasks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,12 @@ def generate_report(self) -> None:
log.warning("There is no prediction to generate the report. Skipping report generation.")
return
all_outputs = [x[0] for x in predictions_outputs]
if not all_outputs:
all_probs = [x[2] for x in predictions_outputs]
if not all_outputs or not all_probs:
log.warning("There is no prediction to generate the report. Skipping report generation.")
return
all_outputs = [item for sublist in all_outputs for item in sublist]
all_probs = [item for sublist in all_probs for item in sublist]
all_targets = [target.tolist() for im, target in self.datamodule.test_dataloader()]
all_targets = [item for sublist in all_targets for item in sublist]

Expand All @@ -335,16 +337,17 @@ def generate_report(self) -> None:
output_folder_test = "test"
test_dataloader = self.datamodule.test_dataloader()
test_dataset = cast(ImageClassificationListDataset, test_dataloader.dataset)
res = pd.DataFrame(
self.res = pd.DataFrame(
{
"sample": list(test_dataset.x),
"real_label": all_targets,
"pred_label": all_outputs,
"probability": all_probs,
}
)
os.makedirs(output_folder_test, exist_ok=True)
save_classification_result(
results=res,
results=self.res,
output_folder=output_folder_test,
confmat=self.report_confmat,
accuracy=accuracy,
Expand Down
1 change: 1 addition & 0 deletions quadra/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def extras(config: DictConfig) -> None:
def print_config(
config: DictConfig,
fields: Sequence[str] = (
"task",
"trainer",
"model",
"datamodule",
Expand Down

0 comments on commit 4a0cd8d

Please sign in to comment.