Skip to content

Commit

Permalink
Fix warning in aggregation.mean
Browse files Browse the repository at this point in the history
Summary:
This diff fixes the incorrect warning when running `mean.compute()` when the mean is exactly 0.

Instead of checking for the weighted sum of elements to be 0, we instead check for the total sum of weights to be zero (meaning that the average can be 0 without error, but we throw a warning when dividing by zero)

We also update the error message to reflect that the issue is no weight has been accumulated, since it is possible to call this function with only 0 weights.

Addresses: pytorch#185

Reviewed By: JKSenthil

Differential Revision: D50806243
  • Loading branch information
bobakfb authored and facebook-github-bot committed Dec 20, 2023
1 parent 235aa26 commit 233d935
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
4 changes: 4 additions & 0 deletions tests/metrics/aggregation/test_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def test_mean_class_compute_without_update(self) -> None:
metric = Mean()
self.assertEqual(metric.compute(), torch.tensor(0.0, dtype=torch.float64))

metric = Mean()
metric.update(torch.tensor([0.0, 0.0]), weight=0)
self.assertEqual(metric.compute(), torch.tensor(0.0, dtype=torch.float64))

def test_mean_class_update_input_valid_weight(self) -> None:
update_value = [
torch.rand(BATCH_SIZE),
Expand Down
12 changes: 8 additions & 4 deletions torcheval/metrics/aggregation/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ def __init__(
device: Optional[torch.device] = None,
) -> None:
super().__init__(device=device)
# weighted sum of values over the entire state
self._add_state(
"weighted_sum", torch.tensor(0.0, device=self.device, dtype=torch.float64)
)
# sum total of weights over the entire state
self._add_state(
"weights", torch.tensor(0.0, device=self.device, dtype=torch.float64)
)
Expand All @@ -82,9 +84,9 @@ def update(
ValueError: If value of weight is neither a ``float`` nor a ``int'' nor a ``torch.Tensor`` that matches the input tensor size.
"""

weighted_sum, weights = _mean_update(input, weight)
weighted_sum, net_weight = _mean_update(input, weight)
self.weighted_sum += weighted_sum
self.weights += weights
self.weights += net_weight
return self

@torch.inference_mode()
Expand All @@ -93,8 +95,10 @@ def compute(self: TMean) -> torch.Tensor:
If no calls to ``update()`` are made before ``compute()`` is called,
the function throws a warning and returns 0.0.
"""
if not self.weighted_sum:
logging.warning("No calls to update() have been made - returning 0.0")
if not torch.is_nonzero(self.weights):
logging.warning(
"There is no weight for the average, no samples with weight have been added (did you ever run update()?)- returning 0.0"
)
return torch.tensor(0.0, dtype=torch.float64)
return self.weighted_sum / self.weights

Expand Down

0 comments on commit 233d935

Please sign in to comment.