From 339fa052ebba2e15f7c6c7f0117bcee50013f3d7 Mon Sep 17 00:00:00 2001 From: Vinzent Steinberg Date: Thu, 14 Dec 2023 21:56:48 +0100 Subject: [PATCH] Fix covariance and add tests --- src/covariance.rs | 26 ++++++++++++++++---- src/lib.rs | 1 + tests/integration/covariance.rs | 43 +++++++++++++++++++++++++++++++++ tests/integration/main.rs | 1 + 4 files changed, 66 insertions(+), 5 deletions(-) create mode 100644 tests/integration/covariance.rs diff --git a/src/covariance.rs b/src/covariance.rs index 597c2c6..2eaa2fb 100644 --- a/src/covariance.rs +++ b/src/covariance.rs @@ -4,8 +4,24 @@ use serde_derive::{Deserialize, Serialize}; use crate::Merge; -/// Estimate the arithmetic mean and the covariance of a sequence of number pairs +/// Estimate the arithmetic means and the covariance of a sequence of number pairs /// ("population"). +/// +/// Because the variances are calculated as well, this can be used to calculate the Pearson +/// correlation coefficient. +/// +/// +/// ## Example +/// +/// ``` +/// use average::Covariance; +/// +/// let a: Covariance = [(1., 5.), (2., 4.), (3., 3.), (4., 2.), (5., 1.)].iter().cloned().collect(); +/// assert_eq!(a.mean_x(), 3.); +/// assert_eq!(a.mean_y(), 3.); +/// assert_eq!(a.population_covariance(), -2.5); +/// assert_eq!(a.sample_covariance(), -2.0); +/// ``` #[derive(Debug, Clone)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct Covariance { @@ -40,13 +56,13 @@ impl Covariance { let delta_x = x - self.avg_x; let delta_y = y - self.avg_y; - self.avg_x += delta_x / self.n.to_f64().unwrap(); + self.avg_x += delta_x / n; self.sum_x_2 += delta_x * delta_x * n * (n - 1.); - self.avg_y += delta_y / self.n.to_f64().unwrap(); + self.avg_y += delta_y / n; self.sum_y_2 += delta_y * delta_y * n * (n - 1.); - self.sum_prod += delta_x * delta_y; + self.sum_prod += delta_x * (y - self.avg_y); } /// Calculate the population covariance of the sample. @@ -56,7 +72,7 @@ impl Covariance { /// Returns NaN for an empty sample. #[inline] pub fn population_covariance(&self) -> f64 { - if self.n < 2 { + if self.n < 1 { return f64::NAN; } self.sum_prod / self.n.to_f64().unwrap() diff --git a/src/lib.rs b/src/lib.rs index a613fa2..786cb1e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,6 +93,7 @@ #![forbid(missing_docs)] #![forbid(missing_debug_implementations)] #![cfg_attr(feature = "nightly", feature(generic_const_exprs))] +#[cfg(feature = "std")] extern crate std; #[macro_use] mod macros; diff --git a/tests/integration/covariance.rs b/tests/integration/covariance.rs new file mode 100644 index 0000000..fd5564f --- /dev/null +++ b/tests/integration/covariance.rs @@ -0,0 +1,43 @@ +use average::Covariance; + +#[test] +fn simple() { + let mut cov = Covariance::new(); + assert!(cov.mean_x().is_nan()); + assert!(cov.mean_y().is_nan()); + assert!(cov.population_covariance().is_nan()); + assert!(cov.sample_covariance().is_nan()); + assert!(cov.population_pearson().is_nan()); + assert!(cov.sample_pearson().is_nan()); + + cov.add(1., 5.); + assert_eq!(cov.mean_x(), 1.); + assert_eq!(cov.mean_y(), 5.); + assert_eq!(cov.population_covariance(), 0.); + assert!(cov.sample_covariance().is_nan()); + // TODO: pearson + + cov.add(2., 4.); + assert_eq!(cov.mean_x(), 1.5); + assert_eq!(cov.mean_y(), 4.5); + assert_eq!(cov.population_covariance(), -0.25); + assert_eq!(cov.sample_covariance(), -0.5); + + cov.add(3., 3.); + assert_eq!(cov.mean_x(), 2.); + assert_eq!(cov.mean_y(), 4.); + assert_eq!(cov.population_covariance(), -2./3.); + assert_eq!(cov.sample_covariance(), -1.); + + cov.add(4., 2.); + assert_eq!(cov.mean_x(), 2.5); + assert_eq!(cov.mean_y(), 3.5); + assert_eq!(cov.population_covariance(), -1.25); + assert_eq!(cov.sample_covariance(), -5./3.); + + cov.add(5., 1.); + assert_eq!(cov.mean_x(), 3.); + assert_eq!(cov.mean_y(), 3.); + assert_eq!(cov.population_covariance(), -2.0); + assert_eq!(cov.sample_covariance(), -2.5); +} \ No newline at end of file diff --git a/tests/integration/main.rs b/tests/integration/main.rs index b48617d..3be96e1 100644 --- a/tests/integration/main.rs +++ b/tests/integration/main.rs @@ -22,3 +22,4 @@ mod skewness; #[cfg(feature = "std")] mod streaming_stats; mod weighted_mean; +mod covariance;