From 3c9c1ffbee445eb7d011ef8f94da721bea9295be Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 28 Apr 2023 09:25:28 -0400 Subject: [PATCH 001/116] Integer metadata fill value --- src/aspire/source/image.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index bdc1a33c38..2ae1c44ec8 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -514,6 +514,11 @@ def set_metadata(self, metadata_fields, values, indices=None): if np.issubdtype(values.dtype, np.str_): values = values.astype("object") fill_value = "" + elif np.issubdtype(values.dtype, np.integer): + # For integers, we'll use the minimal value. + # This will be a large negative value when signed, + # and zero for unsigned integers. + fill_value = np.iinfo(values.dtype).min else: fill_value = np.nan From 391732bf5180e5e0ace39ee2f2992a9d60965b48 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 3 May 2023 15:57:28 -0400 Subject: [PATCH 002/116] Author updates --- .zenodo.json | 2 +- CONTRIBUTORS.rst | 4 ++-- setup.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.zenodo.json b/.zenodo.json index 5a043ed44f..7c9b4e6aca 100644 --- a/.zenodo.json +++ b/.zenodo.json @@ -18,7 +18,7 @@ "name": "Junchao Xia" }, { - "affiliation": "Princeton University", + "affiliation": "Columbia University", "name": "Chris Langfield", "orcid": "0000-0003-4151-203X" }, diff --git a/CONTRIBUTORS.rst b/CONTRIBUTORS.rst index e6e74e36e0..332feb3671 100644 --- a/CONTRIBUTORS.rst +++ b/CONTRIBUTORS.rst @@ -22,7 +22,7 @@ Developers of the Python and Matlab collection of ASPIRE codes are listed below. +------------------+-----------------+---------------------------+-----------------------------------+ | Ayelet Heimowitz | ayeltg | ayeleth@ariel.ac.il | Ariel University | +------------------+-----------------+---------------------------+-----------------------------------+ - | Chris Langfield | chris-langfield | langfield@princeton.edu | Princeton University | + | Chris Langfield | chris-langfield | cal2254@columbia.edu | Columbia University | +------------------+-----------------+---------------------------+-----------------------------------+ | Amit Moscovich | mosco | mosco@tauex.tau.ac.il | Tel Aviv University | +------------------+-----------------+---------------------------+-----------------------------------+ @@ -36,7 +36,7 @@ Developers of the Python and Matlab collection of ASPIRE codes are listed below. +------------------+-----------------+---------------------------+-----------------------------------+ | Garrett Wright | garrettwrong | gbwright@princeton.edu | Princeton University | +------------------+-----------------+---------------------------+-----------------------------------+ - | Junchao Xia | junchaoxia | junchao.xia@princeton.edu | Princeton University | + | Junchao Xia | junchaoxia | junchaoxiacn@gmail.com | OpenEye | +------------------+-----------------+---------------------------+-----------------------------------+ | diff --git a/setup.py b/setup.py index 11898cdcb9..8d0e1f7a15 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ def read(fname): long_description_content_type="text/markdown", license="GPLv3", url="https://github.com/ComputationalCryoEM/ASPIRE-Python", - author="Joakim Anden, Ayelet Heimowitz, Vineet Bansal, Robbie Brook, Itay Sason, Yoel Shkolnisky, Garrett Wright, Junchao Xia", + author="Joakim Anden, Vineet Bansal, Josh Carmichael, Chris Langfield, Ayelet Heimowitz, Yoel Shkolnisky, Amit Singer, Garrett Wright, Junchao Xia", author_email="devs.aspire@gmail.com", install_requires=[ "click", From 929bfea6e28ba708492fdd7badfdd8e9ef1272d5 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 20 Apr 2023 14:38:38 -0400 Subject: [PATCH 003/116] Better logging messages. --- src/aspire/abinitio/commonline_c3_c4.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/aspire/abinitio/commonline_c3_c4.py b/src/aspire/abinitio/commonline_c3_c4.py index b880452e5a..d901ce431b 100644 --- a/src/aspire/abinitio/commonline_c3_c4.py +++ b/src/aspire/abinitio/commonline_c3_c4.py @@ -373,7 +373,9 @@ def _estimate_inplane_rotations(self, vis): # Q is a rank-1 Hermitian matrix. eig_vals, eig_vecs = eigh(Q) leading_eig_vec = eig_vecs[:, -1] - logger.info(f"Top 5 eigenvalues of Q are {str(eig_vals[-5:][::-1])}.") + logger.info( + f"Top 3 eigenvalues of Q (rank-1) are {str(eig_vals[-3:][::-1])}." + ) # Calculate R_thetas. R_thetas = Rotation.about_axis( @@ -493,10 +495,10 @@ def _estimate_all_Riis_c3_c4(self, sclmatrix): # cos_diff should be <= 0.5, but due to discretization that might be violated. if np.max(cos_diff) > 0.5: bad_diffs = np.count_nonzero(cos_diff > 0.5) - logger.warning( + logger.debug( "cos(angular_diff) should be < 0.5." f"Found {bad_diffs} estimates exceeding 0.5, with maximum {np.max(cos_diff)}." - "Setting all bad estimates to 0.5." + " Setting all bad estimates to 0.5." ) cos_diff[cos_diff > 0.5] = 0.5 euler_y2 = np.arccos(cos_diff / (1 - cos_diff)) @@ -504,10 +506,10 @@ def _estimate_all_Riis_c3_c4(self, sclmatrix): # cos_diff should be <= 0, but due to discretization that might be violated. if np.max(cos_diff) > 0: bad_diffs = np.count_nonzero(cos_diff > 0) - logger.warning( + logger.debug( "cos(angular_diff) should be < 0." - f"Found {bad_diffs} estimates exceeding 0, with maximum {np.max(cos_diff)}" - "Setting all bad estimates to 0." + f"Found {bad_diffs} estimates exceeding 0, with maximum {np.max(cos_diff)}." + " Setting all bad estimates to 0." ) cos_diff[cos_diff > 0] = 0 euler_y2 = np.arccos((1 + cos_diff) / (1 - cos_diff)) @@ -680,16 +682,16 @@ def _J_sync_power_method(self, vijs): itr = 0 # Power method iterations + logger.info("Initiating Power Method.") while itr < max_iters and residual > epsilon: itr += 1 vec_new = self._signs_times_v(vijs, vec) vec_new = vec_new / norm(vec_new) residual = norm(vec_new - vec) vec = vec_new - - logger.info( - f"Power method used {itr} iterations. Maximum iterations set to {max_iters}." - ) + logger.info( + f"Iteration {itr}, residual {round(residual, 5)} (target {epsilon})" + ) # We need only the signs of the eigenvector J_sync = np.sign(vec) From 967c531515c1613ba02fbff68a00612b0aafc357 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 21 Apr 2023 10:49:51 -0400 Subject: [PATCH 004/116] use np.maximum in place of boolean slice --- src/aspire/abinitio/commonline_c3_c4.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aspire/abinitio/commonline_c3_c4.py b/src/aspire/abinitio/commonline_c3_c4.py index d901ce431b..64e2cf59fc 100644 --- a/src/aspire/abinitio/commonline_c3_c4.py +++ b/src/aspire/abinitio/commonline_c3_c4.py @@ -500,7 +500,7 @@ def _estimate_all_Riis_c3_c4(self, sclmatrix): f"Found {bad_diffs} estimates exceeding 0.5, with maximum {np.max(cos_diff)}." " Setting all bad estimates to 0.5." ) - cos_diff[cos_diff > 0.5] = 0.5 + cos_diff = np.maximum(cos_diff, 0.5) euler_y2 = np.arccos(cos_diff / (1 - cos_diff)) else: # cos_diff should be <= 0, but due to discretization that might be violated. @@ -511,7 +511,7 @@ def _estimate_all_Riis_c3_c4(self, sclmatrix): f"Found {bad_diffs} estimates exceeding 0, with maximum {np.max(cos_diff)}." " Setting all bad estimates to 0." ) - cos_diff[cos_diff > 0] = 0 + cos_diff = np.maximum(cos_diff, 0) euler_y2 = np.arccos((1 + cos_diff) / (1 - cos_diff)) # Calculate remaining Euler angles in ZYZ convention. From 0f9c8516c3948a08883fd79fecbf52dcc9ccee4a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 21 Apr 2023 11:58:12 -0400 Subject: [PATCH 005/116] maximum ~~> minimum. Ooops. --- src/aspire/abinitio/commonline_c3_c4.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aspire/abinitio/commonline_c3_c4.py b/src/aspire/abinitio/commonline_c3_c4.py index 64e2cf59fc..c5243ce656 100644 --- a/src/aspire/abinitio/commonline_c3_c4.py +++ b/src/aspire/abinitio/commonline_c3_c4.py @@ -500,7 +500,7 @@ def _estimate_all_Riis_c3_c4(self, sclmatrix): f"Found {bad_diffs} estimates exceeding 0.5, with maximum {np.max(cos_diff)}." " Setting all bad estimates to 0.5." ) - cos_diff = np.maximum(cos_diff, 0.5) + cos_diff = np.minimum(cos_diff, 0.5) euler_y2 = np.arccos(cos_diff / (1 - cos_diff)) else: # cos_diff should be <= 0, but due to discretization that might be violated. @@ -511,7 +511,7 @@ def _estimate_all_Riis_c3_c4(self, sclmatrix): f"Found {bad_diffs} estimates exceeding 0, with maximum {np.max(cos_diff)}." " Setting all bad estimates to 0." ) - cos_diff = np.maximum(cos_diff, 0) + cos_diff = np.minimum(cos_diff, 0) euler_y2 = np.arccos((1 + cos_diff) / (1 - cos_diff)) # Calculate remaining Euler angles in ZYZ convention. From 50f807627b7a8e7a2e9c85f40e12d982ee70ded1 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 4 May 2023 11:16:28 -0400 Subject: [PATCH 006/116] edit logger message --- src/aspire/abinitio/commonline_c3_c4.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/aspire/abinitio/commonline_c3_c4.py b/src/aspire/abinitio/commonline_c3_c4.py index c5243ce656..6958f2e5c9 100644 --- a/src/aspire/abinitio/commonline_c3_c4.py +++ b/src/aspire/abinitio/commonline_c3_c4.py @@ -682,7 +682,9 @@ def _J_sync_power_method(self, vijs): itr = 0 # Power method iterations - logger.info("Initiating Power Method.") + logger.info( + "Initiating power method to estimate J-synchronization matrix eigenvector." + ) while itr < max_iters and residual > epsilon: itr += 1 vec_new = self._signs_times_v(vijs, vec) From 6ef2a23ab2c14674d1c9f35bd1d483d171515be8 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 20 Apr 2023 15:16:50 -0400 Subject: [PATCH 007/116] in-place division ~~> long form division --- src/aspire/source/image.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 2ae1c44ec8..a325a83cd0 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -674,7 +674,8 @@ def downsample(self, L): ds_factor = self.L / L self.unique_filters = [f.scale(ds_factor) for f in self.unique_filters] - self.offsets /= ds_factor + # Using long form division to prevent casting float to int + self.offsets = self.offsets / ds_factor self.L = L From 53afb1a0ee9bf8608f6786fc10f91e3484441c99 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 21 Apr 2023 09:04:18 -0400 Subject: [PATCH 008/116] integer offsets smoke test --- tests/test_downsample.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_downsample.py b/tests/test_downsample.py index 1a44a538d2..cfbcab1ec6 100644 --- a/tests/test_downsample.py +++ b/tests/test_downsample.py @@ -42,6 +42,10 @@ def testDownsample3DCase(self, L, L_ds): # check signal energy is conserved self.assertTrue(self.checkSignalEnergy(vols_org, vols_ds)) + def testIntegerOffsets(self): + sim = Simulation(offsets=0) + _ = sim.downsample(3) + def checkCenterPoint(self, data_org, data_ds): # Check that center point is the same after ds L = data_org.shape[-1] From b4c065cc52fa97ce9d0445d97186850fdddf2330 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 25 Apr 2023 09:02:39 -0400 Subject: [PATCH 009/116] cast offsets to floats (self.dtype) in setter. revert long hand division to in-place --- src/aspire/source/image.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index a325a83cd0..2f01eabd0f 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -390,7 +390,7 @@ def offsets(self): @offsets.setter def offsets(self, values): - return self.set_metadata(["_rlnOriginX", "_rlnOriginY"], values) + return self.set_metadata(["_rlnOriginX", "_rlnOriginY"], np.array(values, dtype=self.dtype)) @property def amplitudes(self): @@ -674,8 +674,7 @@ def downsample(self, L): ds_factor = self.L / L self.unique_filters = [f.scale(ds_factor) for f in self.unique_filters] - # Using long form division to prevent casting float to int - self.offsets = self.offsets / ds_factor + self.offsets /= ds_factor self.L = L From f394b990291433ea2a39c159030c7b40b4f854f0 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 25 Apr 2023 09:43:04 -0400 Subject: [PATCH 010/116] cast amplitudes as self.dtype --- src/aspire/source/image.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 2f01eabd0f..98e017fb11 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -390,7 +390,9 @@ def offsets(self): @offsets.setter def offsets(self, values): - return self.set_metadata(["_rlnOriginX", "_rlnOriginY"], np.array(values, dtype=self.dtype)) + return self.set_metadata( + ["_rlnOriginX", "_rlnOriginY"], np.array(values, dtype=self.dtype) + ) @property def amplitudes(self): @@ -398,7 +400,7 @@ def amplitudes(self): @amplitudes.setter def amplitudes(self, values): - return self.set_metadata("_rlnAmplitude", values) + return self.set_metadata("_rlnAmplitude", np.array(values, dtype=self.dtype)) @property def angles(self): From ce2a19c24e69da2770cfb15db472a9cc9cf5642f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 4 Apr 2023 11:50:44 -0400 Subject: [PATCH 011/116] initial FRC and FSC addition --- src/aspire/image/image.py | 99 +++++++++++++++++++++++++ src/aspire/volume/volume.py | 100 +++++++++++++++++++++++++ tests/test_fourier_corrs.py | 143 ++++++++++++++++++++++++++++++++++++ 3 files changed, 342 insertions(+) create mode 100644 tests/test_fourier_corrs.py diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 2d56bb2847..3a59eadea1 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -478,6 +478,105 @@ def show(self, columns=5, figsize=(20, 10), colorbar=True): plt.show() + def frc(self, other, resolution=None, cutoff=0.143, eps=1e-4, dtype=None): + r""" + Compute the Fourier ring correlation between two images. + + Images are assumed to be well aligned. + + Stack of both images must be `1` and shape of both images must match. + + When `resolution` (pixel-size in Angstrom) is provided, returns + tuple(`estimated_resolution`, FRC as a Numpy array). `estimated_resolution` is 1/Angstrom. + + The FRC is defined as: + + .. math:: + + c(i) = \frac{ \Re( \sum_i{ \mathcal{F}_1(i) * {\mathcal{F}^{*}_2(i) } } ) }{\ + \sqrt{ \sum_i { | \mathcal{F}_1(i) |^2 } * \sum_i{| \mathcal{F}^{*}_2}(i) |^2 } } + + :param other: `Image` instance to compare. + :param resolution: Optional, pixel-size in Angstrom. #TODO check pixel-size vs 1/A? + :param cutoff: Cutoff value, traditionally `1.43`. + :param eps: Epsilon past boundary values, defaults 1e-4. + :param dtype: Optional, dtype. Defaults to `self.dtype`. + + :return: FRC as Numpy array or tuple(estimated_resolution, FRC as a Numpy array). + """ + + dtype = np.dtype(dtype or self.dtype) + + # When passed resolution, sanity check type. + if resolution is not None: + resolution = float(resolution) + + if not isinstance(other, Image): + raise TypeError( + f"`other` image must be an `Image` instance, received {type(other)}" + ) + + if self.shape != other.shape: + raise RuntimeError(f"Shapes do not match, {self.shape} != {other.shape}.") + + if self.stack_ndim != 1 or self.n_images != 1: + raise RuntimeError( + f"FRC is computed between two singletons, received {self}." + ) + + # Compute shells from 2D grid. + L = self.resolution + radii = grid_2d(L, shifted=True, normalized=False, dtype=dtype)["r"] + + # Compute centered Fourier transforms, + # upcasting when nessecary. + f1 = fft.centered_fft2(self.asnumpy().astype(dtype, copy=False)) + f2 = fft.centered_fft2(other.asnumpy().astype(dtype, copy=False)) + + correlations = np.zeros(L // 2, dtype=dtype) + inner_diameter = 0.5 + eps + for i in range(0, L // 2): + # Compute ring mask + outer_diameter = 0.5 + (i + 1) + eps + ring_mask = (radii > inner_diameter) & (radii < outer_diameter) + logger.debug(f"Ring, Elements: {i}, {np.sum(ring_mask)}") + + # Mask off values in Fourier space + r1 = ring_mask * f1 + r2 = ring_mask * f2 + + # Compute FRC + num = np.real(np.sum(r1 * np.conj(r2))) + den = np.sqrt(np.sum(np.abs(r1) ** 2) * np.sum(np.abs(r2) ** 2)) + # Assign + correlations[i] = num / den + # Update ring + inner_diameter = outer_diameter + + logger.debug(f"FRC: {correlations}") + result = correlations + + if resolution is not None: + if np.min(correlations) > cutoff: + # All correlations are above cutoff + c_ind = L // 2 # Index of highest sampled frequency. + elif np.max(correlations) < cutoff: + # All correlations are below cutoff. + c_ind = 0 + else: + # Correlations cross the cutoff. + # Find the first index of a correlation at `cutoff`. + c_ind = np.argmax(correlations <= cutoff) + + # Convert to frequency + c = c_ind * (1 / (L * resolution)) + + logger.debug(f"FRC Resolution: {c}") + # Construct the result tuple. + result = (c, result) + + return result + class CartesianImage(Image): def expand(self, basis): diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 9c9e9f95df..23a990f02c 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -95,6 +95,7 @@ def __init__(self, data, dtype=None): self.stack_shape = self._data.shape[:-3] self.n_vols = np.prod(self.stack_shape) self.resolution = self._data.shape[-1] + self.size = self._data.size # Numpy interop # https://numpy.org/devdocs/user/basics.interoperability.html#the-array-interface-protocol @@ -485,6 +486,105 @@ def load(cls, filename, permissive=True, dtype=np.float32): logger.info(f"{filename} with dtype {loaded_data.dtype} loaded as {dtype}") return cls(loaded_data.astype(dtype)) + def fsc(self, other, resolution=None, cutoff=0.143, eps=1e-4, dtype=None): + r""" + Compute the Fourier shell correlation between two volumes. + + Volumes are assumed to be well aligned. + + Stack of both volumes must be `1` and shape of both volumes must match. + + When `resolution` (pixel-size in Angstrom) is provided, returns + tuple(`estimated_resolution`, FSC as a Numpy array). `estimated_resolution` is 1/Angstrom. + + The FSC is defined as: + + .. math:: + + c(i) = \frac{ \Re( \sum_i{ \mathcal{F}_1(i) * {\mathcal{F}^{*}_2(i) } } ) }{\ + \sqrt{ \sum_i { | \mathcal{F}_1(i) |^2 } * \sum_i{| \mathcal{F}^{*}_2}(i) |^2 } } + + :param other: `Volume` instance to compare. + :param resolution: Optional, pixel-size in Angstrom. #TODO check pixel-size vs 1/A? + :param cutoff: Cutoff value, traditionally `1.43`. + :param eps: Epsilon past boundary values, defaults 1e-4. + :param dtype: Optional, dtype. Defaults to `self.dtype`. + + :return: FSC as Numpy array or tuple(estimated_resolution, FSC as a Numpy array). + """ + + dtype = np.dtype(dtype or self.dtype) + + # When passed resolution, sanity check type. + if resolution is not None: + resolution = float(resolution) + + if not isinstance(other, Volume): + raise TypeError( + f"`other` image must be an `Volume` instance, received {type(other)}" + ) + + if self.shape != other.shape: + raise RuntimeError(f"Shapes do not match, {self.shape} != {other.shape}.") + + if self.stack_ndim != 1 or self.n_vols != 1: + raise RuntimeError( + f"FSC is computed between two singletons, received {self}." + ) + + # Compute shells from 3D grid. + L = self.resolution + radii = grid_3d(L, shifted=True, normalized=False, dtype=dtype)["r"] + + # Compute centered Fourier transforms, + # upcasting when nessecary. + f1 = fft.centered_fftn(self.asnumpy().astype(dtype, copy=False)) + f2 = fft.centered_fftn(other.asnumpy().astype(dtype, copy=False)) + + correlations = np.zeros(L // 2, dtype=dtype) + inner_diameter = 0.5 + eps + for i in range(0, L // 2): + # Compute shell mask + outer_diameter = 0.5 + (i + 1) + eps + shell_mask = (radii > inner_diameter) & (radii < outer_diameter) + logger.debug(f"Shell, Elements: {i}, {np.sum(shell_mask)}") + + # Mask off values in Fourier space + s1 = shell_mask * f1 + s2 = shell_mask * f2 + + # Compute FSC + num = np.real(np.sum(s1 * np.conj(s2))) + den = np.sqrt(np.sum(np.abs(s1) ** 2) * np.sum(np.abs(s2) ** 2)) + # Assign + correlations[i] = num / den + # Update shell + inner_diameter = outer_diameter + + logger.debug(f"FSC: {correlations}") + result = correlations + + if resolution is not None: + if np.min(correlations) > cutoff: + # All correlations are above cutoff + c_ind = L // 2 # Index of highest sampled frequency. + elif np.max(correlations) < cutoff: + # All correlations are below cutoff. + c_ind = 0 + else: + # Correlations cross the cutoff. + # Find the first index of a correlation at `cutoff`. + c_ind = np.argmax(correlations <= cutoff) + + # Convert to frequency + c = c_ind * (1 / (L * resolution)) + + logger.debug(f"FSC Resolution: {c}") + # Construct the result tuple. + result = (c, result) + + return result + class CartesianVolume(Volume): def expand(self, basis): diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py new file mode 100644 index 0000000000..4bc4127d14 --- /dev/null +++ b/tests/test_fourier_corrs.py @@ -0,0 +1,143 @@ +import logging +import os + +import numpy as np +import pytest + +from aspire.source import Simulation +from aspire.utils import Rotation +from aspire.volume import Volume + +DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") + +logger = logging.getLogger(__name__) + +IMG_SIZES = [ + 64, + 65, +] +DTYPES = [ + np.float64, + np.float32, +] + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") +def dtype(request): + return request.param + + +@pytest.fixture(params=IMG_SIZES, ids=lambda x: f"img_size={x}") +def img_size(request): + return request.param + + +@pytest.fixture +def image_fixture(img_size, dtype): + """ + Serve up images with prescribed parameters. + """ + # Load sample molecule volume + v = Volume( + np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")), dtype=dtype + ).downsample(img_size) + + # Instantiate ASPIRE's Rotation class with a set of angles. + thetas = [0, 1.23] + rots = Rotation.about_axis("z", thetas, dtype=dtype) + + # Contruct the Simulation source. + src = Simulation( + L=img_size, n=2, vols=v, offsets=0, amplitudes=1, C=1, angles=rots.angles + ) + + img, img_rot = src.images[:] + + return img, img_rot + + +@pytest.fixture +def volume_fixture(img_size, dtype): + """ + Serve up volumes with prescribed parameters. + """ + # Load sample molecule volume + vol = Volume( + np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")), dtype=dtype + ).downsample(img_size) + + # Instantiate ASPIRE's Rotation class with a set of angles. + thetas = [1.23] + rots = Rotation.about_axis("z", thetas, dtype=dtype) + + vol_rot = vol.rotate(rots) + + return vol, vol_rot + + +# FRC + + +def test_frc_id(image_fixture): + img, _ = image_fixture + + frc = img.frc(img) + assert np.allclose(frc, 1) + + frc_resolution, frc = img.frc(img, resolution=1) + assert np.isclose(frc_resolution, 0.5, rtol=0.02) + assert np.allclose(frc, 1) + + +def test_frc_rot(image_fixture): + img_a, img_b = image_fixture + + frc_resolution, frc = img_a.frc(img_b, resolution=1) + assert np.isclose(frc_resolution, 0.031, rtol=0.01) + + +# @pytest.mark.skip(reason="Need to check for valid FRC curve....") +def test_frc_noise(image_fixture): + img_a, _ = image_fixture + + noise = np.random.normal( + loc=np.mean(img_a), scale=0.5 * np.std(img_a), size=img_a.size + ).reshape(img_a.shape) + img_n = img_a + noise + + frc_resolution, frc = img_a.frc(img_n, resolution=1) + assert np.isclose(frc_resolution, 0.3, rtol=0.3) + + +# FSC + + +def test_fsc_id(volume_fixture): + vol, _ = volume_fixture + + fsc = vol.fsc(vol) + assert np.allclose(fsc, 1) + + fsc_resolution, fsc = vol.fsc(vol, resolution=1) + assert np.isclose(fsc_resolution, 0.5, rtol=0.02) + assert np.allclose(fsc, 1) + + +def test_fsc_rot(volume_fixture): + vol_a, vol_b = volume_fixture + + fsc_resolution, fsc = vol_a.fsc(vol_b, resolution=1) + assert np.isclose(fsc_resolution, 0.0930, rtol=0.01) + + +# @pytest.mark.skip(reason="Need to check for valid FSC curve....") +def test_fsc_noise(volume_fixture): + vol_a, _ = volume_fixture + + noise = np.random.normal( + loc=np.mean(vol_a), scale=np.std(vol_a), size=vol_a.size + ).reshape(vol_a.shape) + vol_n = vol_a + noise + + fsc_resolution, fsc = vol_a.fsc(vol_n, resolution=1) + assert np.isclose(fsc_resolution, 0.38, rtol=0.1) From e761be468ddc50716f21c5a76db475ab9eaa29f3 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 5 Apr 2023 07:07:17 -0400 Subject: [PATCH 012/116] clearer Real operator in frc/fsc docstring --- src/aspire/image/image.py | 2 +- src/aspire/volume/volume.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 3a59eadea1..ae1a612435 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -493,7 +493,7 @@ def frc(self, other, resolution=None, cutoff=0.143, eps=1e-4, dtype=None): .. math:: - c(i) = \frac{ \Re( \sum_i{ \mathcal{F}_1(i) * {\mathcal{F}^{*}_2(i) } } ) }{\ + c(i) = \frac{ \operatorname{Re}( \sum_i{ \mathcal{F}_1(i) * {\mathcal{F}^{*}_2(i) } } ) }{\ \sqrt{ \sum_i { | \mathcal{F}_1(i) |^2 } * \sum_i{| \mathcal{F}^{*}_2}(i) |^2 } } :param other: `Image` instance to compare. diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 23a990f02c..3296959f0f 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -501,7 +501,7 @@ def fsc(self, other, resolution=None, cutoff=0.143, eps=1e-4, dtype=None): .. math:: - c(i) = \frac{ \Re( \sum_i{ \mathcal{F}_1(i) * {\mathcal{F}^{*}_2(i) } } ) }{\ + c(i) = \frac{ \operatorname{Re}( \sum_i{ \mathcal{F}_1(i) * {\mathcal{F}^{*}_2(i) } } ) }{\ \sqrt{ \sum_i { | \mathcal{F}_1(i) |^2 } * \sum_i{| \mathcal{F}^{*}_2}(i) |^2 } } :param other: `Volume` instance to compare. From 74a5060e7e9edafa29527103d18037b022cb8384 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 5 Apr 2023 07:57:08 -0400 Subject: [PATCH 013/116] add blue and pink noise adders, update unit test --- src/aspire/noise/__init__.py | 2 ++ src/aspire/noise/noise.py | 47 +++++++++++++++++++++++++++++++- tests/test_noise.py | 52 ++++++++++++++++++++++++------------ 3 files changed, 83 insertions(+), 18 deletions(-) diff --git a/src/aspire/noise/__init__.py b/src/aspire/noise/__init__.py index ef95e1675b..1c5c45e08d 100644 --- a/src/aspire/noise/__init__.py +++ b/src/aspire/noise/__init__.py @@ -1,8 +1,10 @@ from .noise import ( AnisotropicNoiseEstimator, + BlueNoiseAdder, CustomNoiseAdder, NoiseAdder, NoiseEstimator, + PinkNoiseAdder, WhiteNoiseAdder, WhiteNoiseEstimator, ) diff --git a/src/aspire/noise/noise.py b/src/aspire/noise/noise.py index d36427fdef..9d7a05f7df 100644 --- a/src/aspire/noise/noise.py +++ b/src/aspire/noise/noise.py @@ -6,7 +6,7 @@ from aspire.image import Image from aspire.image.xform import Xform from aspire.numeric import fft, xp -from aspire.operators import ArrayFilter, PowerFilter, ScalarFilter +from aspire.operators import ArrayFilter, FunctionFilter, PowerFilter, ScalarFilter from aspire.utils import grid_2d, randn, trange logger = logging.getLogger(__name__) @@ -183,6 +183,51 @@ def signal_power(self, p): self._build() +class ColoredNoiseAdder(WhiteNoiseAdder): + @abc.abstractmethod + def _spectrum(self, x, y): + """ + Colored noise spectrum (2d). + """ + + def _build(self): + """ + Builds underlying Filter for this NoiseAdder. + """ + custom_filter = FunctionFilter(f=self._spectrum) * ScalarFilter( + value=self.noise_var + ) + + # Call the __init__ from parent of WhiteNoiseAdder. + super(WhiteNoiseAdder, self).__init__( + noise_filter=custom_filter, seed=self.seed + ) + + +class BlueNoiseAdder(ColoredNoiseAdder): + """ + NoiseAdder where noise power increases with frequency. + """ + + def _spectrum(self, x, y): + s = x[-1] - x[-2] + f = s * np.hypot(x, y) + m = np.mean(f) + return f / m + + +class PinkNoiseAdder(ColoredNoiseAdder): + """ + NoiseAdder where noise power decreases with frequency. + """ + + def _spectrum(self, x, y): + s = x[-1] - x[-2] + f = 2 * s / (np.hypot(x, y) + s) + m = np.mean(f) + return f / m + + class NoiseEstimator: """ Noise Estimator base class. diff --git a/tests/test_noise.py b/tests/test_noise.py index 5a3fa4b3d7..803deb1b57 100644 --- a/tests/test_noise.py +++ b/tests/test_noise.py @@ -8,7 +8,9 @@ from aspire.image import Image from aspire.noise import ( AnisotropicNoiseEstimator, + BlueNoiseAdder, CustomNoiseAdder, + PinkNoiseAdder, WhiteNoiseAdder, WhiteNoiseEstimator, ) @@ -28,6 +30,12 @@ ] +def _noise_function(x, y): + f = np.exp(-(x * x + y * y) / (2 * 0.3**2)) + m = np.mean(f) + return f / m + + def sim_fixture_id(params): res = params[0] dtype = params[1] @@ -60,17 +68,6 @@ def adder(request): return request.param -# Create the custom noise function and associated Filter -def _pinkish_spectrum(x, y): - """ - Util method for generating a pink like noise spectrum. - """ - s = x[-1] - x[-2] - f = 2 * s / (np.hypot(x, y) + s) - m = np.mean(f) - return f / m - - def test_white_noise_estimator_clean_corners(sim_fixture): """ Tests that a clean image yields a noise estimate that is virtually zero. @@ -130,7 +127,7 @@ def test_custom_noise_adder(sim_fixture, target_noise_variance): by generating a sample of the noise. """ - custom_filter = FunctionFilter(f=_pinkish_spectrum) * ScalarFilter( + custom_filter = FunctionFilter(f=_noise_function) * ScalarFilter( value=target_noise_variance ) @@ -209,18 +206,39 @@ def test_from_snr_white(sim_fixture, target_noise_variance): @pytest.mark.parametrize( "target_noise_variance", VARS, ids=lambda param: f"var={param}" ) -def test_pink_iso_noise_estimation(sim_fixture, target_noise_variance): +def test_blue_iso_noise_estimation(sim_fixture, target_noise_variance): """ - Test that prescribing isotropic pink-ish noise + Test that prescribing isotropic blue-ish noise is close to target for a variety of paramaters. """ - custom_filter = FunctionFilter(f=_pinkish_spectrum) * ScalarFilter( - value=target_noise_variance + # Create the CustomNoiseAdder + sim_fixture.noise_adder = BlueNoiseAdder(var=target_noise_variance) + + # TODO, potentially remove or change to Isotropic after #842 + # Compare with AnisotropicNoiseEstimator consuming sim_from_snr + noise_estimator = AnisotropicNoiseEstimator(sim_fixture, batchSize=512) + est_noise_variance = noise_estimator.estimate() + logger.info( + "est_noise_variance, target_noise_variance =" + f" {est_noise_variance}, {target_noise_variance}" ) + # Check we're within 5% + assert np.isclose(est_noise_variance, target_noise_variance, rtol=0.05) + + +@pytest.mark.parametrize( + "target_noise_variance", VARS, ids=lambda param: f"var={param}" +) +def test_pink_iso_noise_estimation(sim_fixture, target_noise_variance): + """ + Test that prescribing isotropic pink-ish noise + is close to target for a variety of paramaters. + """ + # Create the CustomNoiseAdder - sim_fixture.noise_adder = CustomNoiseAdder(noise_filter=custom_filter) + sim_fixture.noise_adder = PinkNoiseAdder(var=target_noise_variance) # TODO, potentially remove or change to Isotropic after #842 # Compare with AnisotropicNoiseEstimator consuming sim_from_snr From 9e5cb4b8e5bb8f2432a0929e92d591857171ee71 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 5 Apr 2023 08:27:39 -0400 Subject: [PATCH 014/116] update FRC to use blue noise --- tests/test_fourier_corrs.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py index 4bc4127d14..51c2218a3b 100644 --- a/tests/test_fourier_corrs.py +++ b/tests/test_fourier_corrs.py @@ -4,6 +4,7 @@ import numpy as np import pytest +from aspire.noise import BlueNoiseAdder from aspire.source import Simulation from aspire.utils import Rotation from aspire.volume import Volume @@ -53,7 +54,13 @@ def image_fixture(img_size, dtype): img, img_rot = src.images[:] - return img, img_rot + noisy_src = Simulation( + L=img_size, n=2, vols=v, offsets=0, amplitudes=1, C=1, angles=rots.angles, + noise_adder=BlueNoiseAdder(var=np.var(img.asnumpy()*0.5)), + ) + img_noisy = noisy_src.images[0] + + return img, img_rot, img_noisy @pytest.fixture @@ -79,7 +86,7 @@ def volume_fixture(img_size, dtype): def test_frc_id(image_fixture): - img, _ = image_fixture + img, _, _ = image_fixture frc = img.frc(img) assert np.allclose(frc, 1) @@ -90,20 +97,14 @@ def test_frc_id(image_fixture): def test_frc_rot(image_fixture): - img_a, img_b = image_fixture + img_a, img_b, _ = image_fixture frc_resolution, frc = img_a.frc(img_b, resolution=1) assert np.isclose(frc_resolution, 0.031, rtol=0.01) -# @pytest.mark.skip(reason="Need to check for valid FRC curve....") def test_frc_noise(image_fixture): - img_a, _ = image_fixture - - noise = np.random.normal( - loc=np.mean(img_a), scale=0.5 * np.std(img_a), size=img_a.size - ).reshape(img_a.shape) - img_n = img_a + noise + img_a, _, img_n = image_fixture frc_resolution, frc = img_a.frc(img_n, resolution=1) assert np.isclose(frc_resolution, 0.3, rtol=0.3) From ef9c8b5973b6bd65df3f6998663e209c400b3dd3 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 5 Apr 2023 08:28:12 -0400 Subject: [PATCH 015/116] linting --- tests/test_fourier_corrs.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py index 51c2218a3b..e5aa916b79 100644 --- a/tests/test_fourier_corrs.py +++ b/tests/test_fourier_corrs.py @@ -55,8 +55,14 @@ def image_fixture(img_size, dtype): img, img_rot = src.images[:] noisy_src = Simulation( - L=img_size, n=2, vols=v, offsets=0, amplitudes=1, C=1, angles=rots.angles, - noise_adder=BlueNoiseAdder(var=np.var(img.asnumpy()*0.5)), + L=img_size, + n=2, + vols=v, + offsets=0, + amplitudes=1, + C=1, + angles=rots.angles, + noise_adder=BlueNoiseAdder(var=np.var(img.asnumpy() * 0.5)), ) img_noisy = noisy_src.images[0] From fa364fd11e9f41707a8f92d6a506d7f1fe73d3e5 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 5 Apr 2023 10:25:42 -0400 Subject: [PATCH 016/116] Use blueish noise for FSC test --- src/aspire/volume/volume.py | 4 ++-- tests/test_fourier_corrs.py | 26 ++++++++++++++------------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 3296959f0f..7a98be9672 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -538,8 +538,8 @@ def fsc(self, other, resolution=None, cutoff=0.143, eps=1e-4, dtype=None): # Compute centered Fourier transforms, # upcasting when nessecary. - f1 = fft.centered_fftn(self.asnumpy().astype(dtype, copy=False)) - f2 = fft.centered_fftn(other.asnumpy().astype(dtype, copy=False)) + f1 = fft.centered_fftn(self.asnumpy()[0].astype(dtype, copy=False)) + f2 = fft.centered_fftn(other.asnumpy()[0].astype(dtype, copy=False)) correlations = np.zeros(L // 2, dtype=dtype) inner_diameter = 0.5 + eps diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py index e5aa916b79..006a31bffd 100644 --- a/tests/test_fourier_corrs.py +++ b/tests/test_fourier_corrs.py @@ -5,8 +5,9 @@ import pytest from aspire.noise import BlueNoiseAdder +from aspire.numeric import fft from aspire.source import Simulation -from aspire.utils import Rotation +from aspire.utils import Rotation, grid_3d from aspire.volume import Volume DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") @@ -85,7 +86,14 @@ def volume_fixture(img_size, dtype): vol_rot = vol.rotate(rots) - return vol, vol_rot + # Scale gaussian noise radially + noise = np.random.normal(loc=0, scale=1, size=vol.shape) + noise = noise * (1.0 + grid_3d(img_size, normalized=False)["r"]) * 0.33 + vol_noise = Volume( + np.real(fft.centered_ifftn(fft.centered_fftn(vol.asnumpy()[0]) * (1 + noise))) + ) + + return vol, vol_rot, vol_noise # FRC @@ -120,7 +128,7 @@ def test_frc_noise(image_fixture): def test_fsc_id(volume_fixture): - vol, _ = volume_fixture + vol, _, _ = volume_fixture fsc = vol.fsc(vol) assert np.allclose(fsc, 1) @@ -131,20 +139,14 @@ def test_fsc_id(volume_fixture): def test_fsc_rot(volume_fixture): - vol_a, vol_b = volume_fixture + vol_a, vol_b, _ = volume_fixture fsc_resolution, fsc = vol_a.fsc(vol_b, resolution=1) assert np.isclose(fsc_resolution, 0.0930, rtol=0.01) -# @pytest.mark.skip(reason="Need to check for valid FSC curve....") def test_fsc_noise(volume_fixture): - vol_a, _ = volume_fixture - - noise = np.random.normal( - loc=np.mean(vol_a), scale=np.std(vol_a), size=vol_a.size - ).reshape(vol_a.shape) - vol_n = vol_a + noise + vol_a, _, vol_n = volume_fixture fsc_resolution, fsc = vol_a.fsc(vol_n, resolution=1) - assert np.isclose(fsc_resolution, 0.38, rtol=0.1) + assert np.isclose(fsc_resolution, 0.38, rtol=0.3) From 9b1ba844f92a61e8671682ab9ec43304f7431479 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 6 Apr 2023 14:46:17 -0400 Subject: [PATCH 017/116] convert towards a shared class with broadcasting --- src/aspire/image/image.py | 89 ++------ src/aspire/reconstruction/__init__.py | 1 + .../reconstruction/resolution_estimation.py | 192 ++++++++++++++++++ src/aspire/volume/volume.py | 92 ++------- tests/test_fourier_corrs.py | 36 ++-- 5 files changed, 239 insertions(+), 171 deletions(-) create mode 100644 src/aspire/reconstruction/resolution_estimation.py diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index ae1a612435..8382de5dea 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -9,6 +9,7 @@ import aspire.volume from aspire.nufft import anufft from aspire.numeric import fft, xp +from aspire.reconstruction import FourierRingCorrelation from aspire.utils import crop_pad_2d, grid_2d from aspire.utils.matrix import anorm @@ -478,17 +479,12 @@ def show(self, columns=5, figsize=(20, 10), colorbar=True): plt.show() - def frc(self, other, resolution=None, cutoff=0.143, eps=1e-4, dtype=None): + def frc(self, other, pixel_size, cutoff=0.143, eps=1e-4): r""" Compute the Fourier ring correlation between two images. Images are assumed to be well aligned. - Stack of both images must be `1` and shape of both images must match. - - When `resolution` (pixel-size in Angstrom) is provided, returns - tuple(`estimated_resolution`, FRC as a Numpy array). `estimated_resolution` is 1/Angstrom. - The FRC is defined as: .. math:: @@ -497,85 +493,30 @@ def frc(self, other, resolution=None, cutoff=0.143, eps=1e-4, dtype=None): \sqrt{ \sum_i { | \mathcal{F}_1(i) |^2 } * \sum_i{| \mathcal{F}^{*}_2}(i) |^2 } } :param other: `Image` instance to compare. - :param resolution: Optional, pixel-size in Angstrom. #TODO check pixel-size vs 1/A? + :param pixel_size: Pixel size in Angstrom. + For synthetic data, 1 is a reasonable value. :param cutoff: Cutoff value, traditionally `1.43`. :param eps: Epsilon past boundary values, defaults 1e-4. - :param dtype: Optional, dtype. Defaults to `self.dtype`. - :return: FRC as Numpy array or tuple(estimated_resolution, FRC as a Numpy array). + :return: tuple(estimated_resolution, FRC), + where `estimated_resolution` is in Angstrom + and FRC is a Numpy array of correlations. """ - dtype = np.dtype(dtype or self.dtype) - - # When passed resolution, sanity check type. - if resolution is not None: - resolution = float(resolution) - if not isinstance(other, Image): raise TypeError( f"`other` image must be an `Image` instance, received {type(other)}" ) - if self.shape != other.shape: - raise RuntimeError(f"Shapes do not match, {self.shape} != {other.shape}.") - - if self.stack_ndim != 1 or self.n_images != 1: - raise RuntimeError( - f"FRC is computed between two singletons, received {self}." - ) + frc = FourierRingCorrelation( + a=self.asnumpy(), + b=other.asnumpy(), + pixel_size=pixel_size, + cutoff=cutoff, + eps=eps, + ) - # Compute shells from 2D grid. - L = self.resolution - radii = grid_2d(L, shifted=True, normalized=False, dtype=dtype)["r"] - - # Compute centered Fourier transforms, - # upcasting when nessecary. - f1 = fft.centered_fft2(self.asnumpy().astype(dtype, copy=False)) - f2 = fft.centered_fft2(other.asnumpy().astype(dtype, copy=False)) - - correlations = np.zeros(L // 2, dtype=dtype) - inner_diameter = 0.5 + eps - for i in range(0, L // 2): - # Compute ring mask - outer_diameter = 0.5 + (i + 1) + eps - ring_mask = (radii > inner_diameter) & (radii < outer_diameter) - logger.debug(f"Ring, Elements: {i}, {np.sum(ring_mask)}") - - # Mask off values in Fourier space - r1 = ring_mask * f1 - r2 = ring_mask * f2 - - # Compute FRC - num = np.real(np.sum(r1 * np.conj(r2))) - den = np.sqrt(np.sum(np.abs(r1) ** 2) * np.sum(np.abs(r2) ** 2)) - # Assign - correlations[i] = num / den - # Update ring - inner_diameter = outer_diameter - - logger.debug(f"FRC: {correlations}") - result = correlations - - if resolution is not None: - if np.min(correlations) > cutoff: - # All correlations are above cutoff - c_ind = L // 2 # Index of highest sampled frequency. - elif np.max(correlations) < cutoff: - # All correlations are below cutoff. - c_ind = 0 - else: - # Correlations cross the cutoff. - # Find the first index of a correlation at `cutoff`. - c_ind = np.argmax(correlations <= cutoff) - - # Convert to frequency - c = c_ind * (1 / (L * resolution)) - - logger.debug(f"FRC Resolution: {c}") - # Construct the result tuple. - result = (c, result) - - return result + return frc.estimated_resolution, frc.correlations class CartesianImage(Image): diff --git a/src/aspire/reconstruction/__init__.py b/src/aspire/reconstruction/__init__.py index 103e116290..245ebdbdaa 100644 --- a/src/aspire/reconstruction/__init__.py +++ b/src/aspire/reconstruction/__init__.py @@ -1,3 +1,4 @@ from .estimator import Estimator from .kernel import FourierKernel, Kernel from .mean import MeanEstimator +from .resolution_estimation import FourierRingCorrelation, FourierShellCorrelation diff --git a/src/aspire/reconstruction/resolution_estimation.py b/src/aspire/reconstruction/resolution_estimation.py new file mode 100644 index 0000000000..518db9bc5a --- /dev/null +++ b/src/aspire/reconstruction/resolution_estimation.py @@ -0,0 +1,192 @@ +""" +This module contains code for estimating resolution achieved by reconstructions. +""" +import logging + +import numpy as np + +from aspire.numeric import fft +from aspire.utils import grid_2d + +logger = logging.getLogger(__name__) + + +class _FourierCorrelation: + r""" + Compute the Fourier correlations between two arrays. + + Underlying data (images/volumes) are assumed to be well aligned. + + The Fourier correlation is defined as: + + .. math:: + + c(i) = \frac{ \operatorname{Re}( \sum_i{ \mathcal{F}_1(i) * {\mathcal{F}^{*}_2(i) } } ) }{\ + \sqrt{ \sum_i { | \mathcal{F}_1(i) |^2 } * \sum_i{| \mathcal{F}^{*}_2}(i) |^2 } } +. + """ + + def __init__(self, a, b, pixel_size, cutoff=0.143, eps=1e-4): + """ + :param a: Input array a, shape(..., *dim). + :param b: Input array b, shape(..., *dim). + :param pixel_size: Pixel size in Angstrom. + :param cutoff: Cutoff value, traditionally `.143`. + :param eps: Epsilon past boundary values, defaults 1e-4. + """ + + # Sanity checks + if not hasattr(self, "dim"): + raise RuntimeError("Subclass must assign `dim`") + for x in (a, b): + if not isinstance(x, np.ndarray): + raise TypeError(f"`{x.__name__}` is not a Numpy array.") + + if not a.dtype == b.dtype: + raise TypeError( + f"Mismatched input types {a.dtype} != {b.dtype}. Cast `a` or `b`." + ) + # TODO, check-math/avoid complex inputs. + + # Shape checks + if not a.shape[-1] == b.shape[-1]: + raise RuntimeError( + f"`a` and `b` appear to have different data axis shapes, {a.shape[-1]} {b.shape[-1]}" + ) + + # To support arbitrary broadcasting simply, + # we'll force all shapes to be (-1, *(L,)*dim) + self._a, self._a_stack_shape = self._reshape(a) + self._b, self._b_stack_shape = self._reshape(b) + + self._analyzed = False + self.cutoff = cutoff + self.pixel_size = float(pixel_size) + self.eps = float(eps) + self._correlations = None + self.L = self._a.shape[-1] + self.dtype = self._a.dtype + + @property + def _fourier_axes(self): + return tuple(range(-self.dim, 0)) + + def _reshape(self, x): + """ + Returns `x` with flattened stack axis and `x`'s original stack shape, as determined by `dim`. + + :param x: Numpy ndarray + """ + # TODO, check 2d in put for dim=2 (singleton case) + original_stack_shape = x.shape[: -self.dim] + x = x.reshape(-1, *x.shape[-self.dim :]) + return x, original_stack_shape + + @property + def cutoff(self): + return self._cutoff + + @cutoff.setter + def cutoff(self, cutoff_correlation): + self._cutoff = float(cutoff_correlation) + self._analyzed = False # reset analysis + + @property + def correlations(self): + # There is no need to run this twice if we assume inputs are immutable + if self._correlations is not None: + return self._correlations + + # Compute shells from 2D grid. + radii = grid_2d(self.L, shifted=True, normalized=False, dtype=self.dtype)["r"] + + # Compute centered Fourier transforms, + # upcasting when nessecary. + f1 = fft.centered_fftn(self._a, axes=self._fourier_axes) + f2 = fft.centered_fftn(self._b, axes=self._fourier_axes) + + # Construct an output table of correlations + correlations = np.zeros( + (self.L // 2, self._a.shape[0], self._b.shape[0]), dtype=self.dtype + ) + + inner_diameter = 0.5 + self.eps + for i in range(0, self.L // 2): + # Compute ring mask + outer_diameter = 0.5 + (i + 1) + self.eps + ring_mask = (radii > inner_diameter) & (radii < outer_diameter) + logger.debug(f"Shell, Elements: {i}, {np.sum(ring_mask)}") + + # Mask off values in Fourier space + r1 = ring_mask * f1 + r2 = ring_mask * f2 + + # Compute FRC + num = np.real(np.sum(r1 * np.conj(r2), axis=self._fourier_axes)) + den = np.sqrt( + np.sum(np.abs(r1) ** 2, axis=self._fourier_axes) + * np.sum(np.abs(r2) ** 2, axis=self._fourier_axes) + ) + # Assign + correlations[i] = num / den + # Update ring + inner_diameter = outer_diameter + + # Repack the table as (_a, _b, L//2) + correlations = np.swapaxes(correlations, 0, 2) + # Then unpack the a and b shapes. + self._correlations = correlations.reshape( + *self._a_stack_shape, *self._b_stack_shape, self.L // 2 + ) + return self._correlations + + @property + def estimated_resolution(self): + """ """ + self.analyze_correlations() + return self._resolutions + + def analyze_correlations(self): + """ """ + if self._analyzed: + return + + c_inds = np.zeros(self.correlations.shape[:-1], dtype=int) + + # All correlations are above cutoff, + # set index of highest sampled frequency. + c_inds[np.min(self.correlations, axis=-1) > self.cutoff] = self.L // 2 + + # # All correlations are below cutoff, + # # set index to 0 + # elif np.max(correlations) < cutoff: + # c_ind = 0 + # else: + + # Correlations cross the cutoff. + # Find the first index of a correlation at `cutoff`. + c_ind = np.maximum(c_inds, np.argmax(self.correlations <= self.cutoff, axis=-1)) + + # Convert indices to frequency (as length 1/A) and assign + freqs = c_ind * (1 / (self.L * self.pixel_size)) + + self._resolutions = 1 / freqs + + def plot(self, to_file=False): + """ """ + + +class FourierRingCorrelation(_FourierCorrelation): + """ + See `_FourierCorrelation`. + """ + + dim = 2 + + +class FourierShellCorrelation(_FourierCorrelation): + """ + See `_FourierCorrelation`. + """ + + dim = 3 diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 7a98be9672..f17cef0430 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -486,17 +486,12 @@ def load(cls, filename, permissive=True, dtype=np.float32): logger.info(f"{filename} with dtype {loaded_data.dtype} loaded as {dtype}") return cls(loaded_data.astype(dtype)) - def fsc(self, other, resolution=None, cutoff=0.143, eps=1e-4, dtype=None): + def fsc(self, other, pixel_size=None, cutoff=0.143, eps=1e-4, dtype=None): r""" Compute the Fourier shell correlation between two volumes. Volumes are assumed to be well aligned. - Stack of both volumes must be `1` and shape of both volumes must match. - - When `resolution` (pixel-size in Angstrom) is provided, returns - tuple(`estimated_resolution`, FSC as a Numpy array). `estimated_resolution` is 1/Angstrom. - The FSC is defined as: .. math:: @@ -505,85 +500,30 @@ def fsc(self, other, resolution=None, cutoff=0.143, eps=1e-4, dtype=None): \sqrt{ \sum_i { | \mathcal{F}_1(i) |^2 } * \sum_i{| \mathcal{F}^{*}_2}(i) |^2 } } :param other: `Volume` instance to compare. - :param resolution: Optional, pixel-size in Angstrom. #TODO check pixel-size vs 1/A? - :param cutoff: Cutoff value, traditionally `1.43`. + :param pixel_size: Pixel size in Angstrom. + For synthetic data, 1 is a reasonable value. + :param cutoff: Cutoff value, traditionally `.143`. :param eps: Epsilon past boundary values, defaults 1e-4. - :param dtype: Optional, dtype. Defaults to `self.dtype`. - - :return: FSC as Numpy array or tuple(estimated_resolution, FSC as a Numpy array). + :return: tuple(estimated_resolution, FRC), + where `estimated_resolution` is in Angstrom + and FRC is a Numpy array of correlations. """ - - dtype = np.dtype(dtype or self.dtype) - - # When passed resolution, sanity check type. - if resolution is not None: - resolution = float(resolution) + from aspire.reconstruction import FourierShellCorrelation if not isinstance(other, Volume): raise TypeError( f"`other` image must be an `Volume` instance, received {type(other)}" ) - if self.shape != other.shape: - raise RuntimeError(f"Shapes do not match, {self.shape} != {other.shape}.") - - if self.stack_ndim != 1 or self.n_vols != 1: - raise RuntimeError( - f"FSC is computed between two singletons, received {self}." - ) + fsc = FourierShellCorrelation( + a=self.asnumpy(), + b=other.asnumpy(), + pixel_size=pixel_size, + cutoff=cutoff, + eps=eps, + ) - # Compute shells from 3D grid. - L = self.resolution - radii = grid_3d(L, shifted=True, normalized=False, dtype=dtype)["r"] - - # Compute centered Fourier transforms, - # upcasting when nessecary. - f1 = fft.centered_fftn(self.asnumpy()[0].astype(dtype, copy=False)) - f2 = fft.centered_fftn(other.asnumpy()[0].astype(dtype, copy=False)) - - correlations = np.zeros(L // 2, dtype=dtype) - inner_diameter = 0.5 + eps - for i in range(0, L // 2): - # Compute shell mask - outer_diameter = 0.5 + (i + 1) + eps - shell_mask = (radii > inner_diameter) & (radii < outer_diameter) - logger.debug(f"Shell, Elements: {i}, {np.sum(shell_mask)}") - - # Mask off values in Fourier space - s1 = shell_mask * f1 - s2 = shell_mask * f2 - - # Compute FSC - num = np.real(np.sum(s1 * np.conj(s2))) - den = np.sqrt(np.sum(np.abs(s1) ** 2) * np.sum(np.abs(s2) ** 2)) - # Assign - correlations[i] = num / den - # Update shell - inner_diameter = outer_diameter - - logger.debug(f"FSC: {correlations}") - result = correlations - - if resolution is not None: - if np.min(correlations) > cutoff: - # All correlations are above cutoff - c_ind = L // 2 # Index of highest sampled frequency. - elif np.max(correlations) < cutoff: - # All correlations are below cutoff. - c_ind = 0 - else: - # Correlations cross the cutoff. - # Find the first index of a correlation at `cutoff`. - c_ind = np.argmax(correlations <= cutoff) - - # Convert to frequency - c = c_ind * (1 / (L * resolution)) - - logger.debug(f"FSC Resolution: {c}") - # Construct the result tuple. - result = (c, result) - - return result + return fsc.estimated_resolution, fsc.correlations class CartesianVolume(Volume): diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py index 006a31bffd..ba7738a884 100644 --- a/tests/test_fourier_corrs.py +++ b/tests/test_fourier_corrs.py @@ -16,11 +16,11 @@ IMG_SIZES = [ 64, - 65, + # 65, ] DTYPES = [ np.float64, - np.float32, + # np.float32, ] @@ -81,7 +81,7 @@ def volume_fixture(img_size, dtype): ).downsample(img_size) # Instantiate ASPIRE's Rotation class with a set of angles. - thetas = [1.23] + thetas = [0.12] rots = Rotation.about_axis("z", thetas, dtype=dtype) vol_rot = vol.rotate(rots) @@ -102,26 +102,23 @@ def volume_fixture(img_size, dtype): def test_frc_id(image_fixture): img, _, _ = image_fixture - frc = img.frc(img) - assert np.allclose(frc, 1) - - frc_resolution, frc = img.frc(img, resolution=1) - assert np.isclose(frc_resolution, 0.5, rtol=0.02) + frc_resolution, frc = img.frc(img, pixel_size=1) + assert np.isclose(frc_resolution[0][0], 2, rtol=0.02) assert np.allclose(frc, 1) def test_frc_rot(image_fixture): img_a, img_b, _ = image_fixture - frc_resolution, frc = img_a.frc(img_b, resolution=1) - assert np.isclose(frc_resolution, 0.031, rtol=0.01) + frc_resolution, frc = img_a.frc(img_b, pixel_size=1) + assert np.isclose(frc_resolution[0][0], 1 / 0.031, rtol=0.01) def test_frc_noise(image_fixture): img_a, _, img_n = image_fixture - frc_resolution, frc = img_a.frc(img_n, resolution=1) - assert np.isclose(frc_resolution, 0.3, rtol=0.3) + frc_resolution, frc = img_a.frc(img_n, pixel_size=1) + assert np.isclose(frc_resolution[0][0], 1 / 0.3, rtol=0.3) # FSC @@ -130,23 +127,20 @@ def test_frc_noise(image_fixture): def test_fsc_id(volume_fixture): vol, _, _ = volume_fixture - fsc = vol.fsc(vol) - assert np.allclose(fsc, 1) - - fsc_resolution, fsc = vol.fsc(vol, resolution=1) - assert np.isclose(fsc_resolution, 0.5, rtol=0.02) + fsc_resolution, fsc = vol.fsc(vol, pixel_size=1) + assert np.isclose(fsc_resolution[0][0], 2.0, rtol=0.02) assert np.allclose(fsc, 1) def test_fsc_rot(volume_fixture): vol_a, vol_b, _ = volume_fixture - fsc_resolution, fsc = vol_a.fsc(vol_b, resolution=1) - assert np.isclose(fsc_resolution, 0.0930, rtol=0.01) + fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1) + assert np.isclose(fsc_resolution[0][0], 3.2, rtol=0.01) def test_fsc_noise(volume_fixture): vol_a, _, vol_n = volume_fixture - fsc_resolution, fsc = vol_a.fsc(vol_n, resolution=1) - assert np.isclose(fsc_resolution, 0.38, rtol=0.3) + fsc_resolution, fsc = vol_a.fsc(vol_n, pixel_size=1) + assert np.isclose(fsc_resolution[0][0], 1 / 0.38, rtol=0.3) From 5ca99dc692b7658019559df1ab73b8f560c553fb Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 6 Apr 2023 16:22:04 -0400 Subject: [PATCH 018/116] update and restore test cases for Fourier corrs --- tests/test_fourier_corrs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py index ba7738a884..5690042573 100644 --- a/tests/test_fourier_corrs.py +++ b/tests/test_fourier_corrs.py @@ -16,11 +16,11 @@ IMG_SIZES = [ 64, - # 65, + 65, ] DTYPES = [ np.float64, - # np.float32, + np.float32, ] @@ -45,7 +45,7 @@ def image_fixture(img_size, dtype): ).downsample(img_size) # Instantiate ASPIRE's Rotation class with a set of angles. - thetas = [0, 1.23] + thetas = [0, 0.123] rots = Rotation.about_axis("z", thetas, dtype=dtype) # Contruct the Simulation source. @@ -111,7 +111,7 @@ def test_frc_rot(image_fixture): img_a, img_b, _ = image_fixture frc_resolution, frc = img_a.frc(img_b, pixel_size=1) - assert np.isclose(frc_resolution[0][0], 1 / 0.031, rtol=0.01) + assert np.isclose(frc_resolution[0][0], 3.76, rtol=0.01) def test_frc_noise(image_fixture): From 9111a25cf1579314506dc2632cd28d53ae66190e Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 7 Apr 2023 07:10:53 -0400 Subject: [PATCH 019/116] refactor imports --- src/aspire/image/image.py | 4 +--- src/aspire/reconstruction/__init__.py | 1 - src/aspire/utils/__init__.py | 2 ++ .../resolution_estimation.py | 0 src/aspire/volume/volume.py | 14 ++++++++--- tests/test_fourier_corrs.py | 24 ++++++++++++------- 6 files changed, 30 insertions(+), 15 deletions(-) rename src/aspire/{reconstruction => utils}/resolution_estimation.py (100%) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 8382de5dea..b99d58cf24 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -9,9 +9,7 @@ import aspire.volume from aspire.nufft import anufft from aspire.numeric import fft, xp -from aspire.reconstruction import FourierRingCorrelation -from aspire.utils import crop_pad_2d, grid_2d -from aspire.utils.matrix import anorm +from aspire.utils import FourierRingCorrelation, anorm, crop_pad_2d, grid_2d logger = logging.getLogger(__name__) diff --git a/src/aspire/reconstruction/__init__.py b/src/aspire/reconstruction/__init__.py index 245ebdbdaa..103e116290 100644 --- a/src/aspire/reconstruction/__init__.py +++ b/src/aspire/reconstruction/__init__.py @@ -1,4 +1,3 @@ from .estimator import Estimator from .kernel import FourierKernel, Kernel from .mean import MeanEstimator -from .resolution_estimation import FourierRingCorrelation, FourierShellCorrelation diff --git a/src/aspire/utils/__init__.py b/src/aspire/utils/__init__.py index 7488854d25..dee6ed6068 100644 --- a/src/aspire/utils/__init__.py +++ b/src/aspire/utils/__init__.py @@ -31,6 +31,7 @@ ) from .logging import LogFilterByCount, get_full_version, tqdm, trange + from .matrix import ( acorr, ainner, @@ -64,6 +65,7 @@ ) from .random import Random, choice, rand, randi, randn, random from .relion_interop import RelionStarFile, relion_metadata_fields +from .resolution_estimation import FourierRingCorrelation, FourierShellCorrelation from .rotation import Rotation from .types import complex_type, real_type, utest_tolerance from .units import ratio_to_decibel, voltage_to_wavelength, wavelength_to_voltage diff --git a/src/aspire/reconstruction/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py similarity index 100% rename from src/aspire/reconstruction/resolution_estimation.py rename to src/aspire/utils/resolution_estimation.py diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index f17cef0430..81daede16f 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -7,8 +7,16 @@ import aspire.image from aspire.nufft import nufft from aspire.numeric import fft, xp -from aspire.utils import Rotation, crop_pad_3d, grid_2d, grid_3d, mat_to_vec, vec_to_mat -from aspire.utils.types import complex_type +from aspire.utils import ( + FourierShellCorrelation, + Rotation, + complex_type, + crop_pad_3d, + grid_2d, + grid_3d, + mat_to_vec, + vec_to_mat, +) logger = logging.getLogger(__name__) @@ -508,7 +516,7 @@ def fsc(self, other, pixel_size=None, cutoff=0.143, eps=1e-4, dtype=None): where `estimated_resolution` is in Angstrom and FRC is a Numpy array of correlations. """ - from aspire.reconstruction import FourierShellCorrelation + # from aspire.reconstruction import FourierShellCorrelation if not isinstance(other, Volume): raise TypeError( diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py index 5690042573..7b7d6d9b50 100644 --- a/tests/test_fourier_corrs.py +++ b/tests/test_fourier_corrs.py @@ -50,7 +50,14 @@ def image_fixture(img_size, dtype): # Contruct the Simulation source. src = Simulation( - L=img_size, n=2, vols=v, offsets=0, amplitudes=1, C=1, angles=rots.angles + L=img_size, + n=2, + vols=v, + offsets=0, + amplitudes=1, + C=1, + angles=rots.angles, + dtype=dtype, ) img, img_rot = src.images[:] @@ -64,6 +71,7 @@ def image_fixture(img_size, dtype): C=1, angles=rots.angles, noise_adder=BlueNoiseAdder(var=np.var(img.asnumpy() * 0.5)), + dtype=dtype, ) img_noisy = noisy_src.images[0] @@ -87,8 +95,8 @@ def volume_fixture(img_size, dtype): vol_rot = vol.rotate(rots) # Scale gaussian noise radially - noise = np.random.normal(loc=0, scale=1, size=vol.shape) - noise = noise * (1.0 + grid_3d(img_size, normalized=False)["r"]) * 0.33 + noise = np.random.normal(loc=0, scale=1, size=vol.shape).astype(dtype, copy=False) + noise = noise * (1.0 + grid_3d(img_size, normalized=False)["r"]) * 0.3 vol_noise = Volume( np.real(fft.centered_ifftn(fft.centered_fftn(vol.asnumpy()[0]) * (1 + noise))) ) @@ -109,16 +117,16 @@ def test_frc_id(image_fixture): def test_frc_rot(image_fixture): img_a, img_b, _ = image_fixture - + assert img_a.dtype == img_b.dtype frc_resolution, frc = img_a.frc(img_b, pixel_size=1) - assert np.isclose(frc_resolution[0][0], 3.76, rtol=0.01) + assert np.isclose(frc_resolution[0][0], 3.78, rtol=0.1) def test_frc_noise(image_fixture): img_a, _, img_n = image_fixture frc_resolution, frc = img_a.frc(img_n, pixel_size=1) - assert np.isclose(frc_resolution[0][0], 1 / 0.3, rtol=0.3) + assert np.isclose(frc_resolution[0][0], 3.5, rtol=0.2) # FSC @@ -136,11 +144,11 @@ def test_fsc_rot(volume_fixture): vol_a, vol_b, _ = volume_fixture fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1) - assert np.isclose(fsc_resolution[0][0], 3.2, rtol=0.01) + assert np.isclose(fsc_resolution[0][0], 3.225, rtol=0.01) def test_fsc_noise(volume_fixture): vol_a, _, vol_n = volume_fixture fsc_resolution, fsc = vol_a.fsc(vol_n, pixel_size=1) - assert np.isclose(fsc_resolution[0][0], 1 / 0.38, rtol=0.3) + assert np.isclose(fsc_resolution[0][0], 2.6, rtol=0.3) From d7e87e08ae534bae5ef66720b4a2d399258a36c0 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 7 Apr 2023 08:08:12 -0400 Subject: [PATCH 020/116] fixup the freq conversion --- src/aspire/utils/resolution_estimation.py | 34 +++++++++++++++++++---- tests/test_fourier_corrs.py | 12 ++++---- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 518db9bc5a..6e9734c735 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -147,7 +147,9 @@ def estimated_resolution(self): return self._resolutions def analyze_correlations(self): - """ """ + """ + Convert from the Fourier Correlations to frequencies and resolution. + """ if self._analyzed: return @@ -167,14 +169,34 @@ def analyze_correlations(self): # Find the first index of a correlation at `cutoff`. c_ind = np.maximum(c_inds, np.argmax(self.correlations <= self.cutoff, axis=-1)) - # Convert indices to frequency (as length 1/A) and assign - freqs = c_ind * (1 / (self.L * self.pixel_size)) + # Convert indices to frequency (as 1/Angstrom) + frequencies = self._freq(c_ind) - self._resolutions = 1 / freqs + # Convert to resolution in Angstrom, smaller is higher frequency. + self._resolutions = 1 / frequencies - def plot(self, to_file=False): - """ """ + def _freq(self, k): + """ + Converts `k` from index of Fourier transform to frequency (as length 1/A). + From Shannon-Nyquist, for a given pixel-size, sampling theorem limits us to the sampled frequency 1/pixel_size. + Thus the Bandwidth ranges from `[-1/pixel_size, 1/pixel_size]`, so the total bandwidth is `2*(1/pixel_size)`. + + Given a real space signal observed with `L` bins (pixels/voxels), each with a `pixel_size` in Angstrom, + we can compute the width of a Fourier space bin to be the `Bandwidth / L = (2*(1/pixel_size)) / L`. + Thus the frequency at an index `k` is `freq_k = k * 2 * (1 / pixel_size) / L = 2*k / (pixel_size * L) + """ + + # _freq(k) Units: 1 / (pixels * (Angstrom / pixel) = 1 / Angstrom + # Similar idea to wavenumbers (cm-1). Larger is higher frequency. + return k * 2 / (self.L * self.pixel_size) + + + def plot(self, to_file=False): + """ + Generates a Fourier correlation plot. + """ + class FourierRingCorrelation(_FourierCorrelation): """ diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py index 7b7d6d9b50..d5f7640da8 100644 --- a/tests/test_fourier_corrs.py +++ b/tests/test_fourier_corrs.py @@ -111,7 +111,7 @@ def test_frc_id(image_fixture): img, _, _ = image_fixture frc_resolution, frc = img.frc(img, pixel_size=1) - assert np.isclose(frc_resolution[0][0], 2, rtol=0.02) + assert np.isclose(frc_resolution[0][0], 1, rtol=0.02) assert np.allclose(frc, 1) @@ -119,14 +119,14 @@ def test_frc_rot(image_fixture): img_a, img_b, _ = image_fixture assert img_a.dtype == img_b.dtype frc_resolution, frc = img_a.frc(img_b, pixel_size=1) - assert np.isclose(frc_resolution[0][0], 3.78, rtol=0.1) + assert np.isclose(frc_resolution[0][0], 3.78/2, rtol=0.1) def test_frc_noise(image_fixture): img_a, _, img_n = image_fixture frc_resolution, frc = img_a.frc(img_n, pixel_size=1) - assert np.isclose(frc_resolution[0][0], 3.5, rtol=0.2) + assert np.isclose(frc_resolution[0][0], 3.5/2, rtol=0.2) # FSC @@ -136,7 +136,7 @@ def test_fsc_id(volume_fixture): vol, _, _ = volume_fixture fsc_resolution, fsc = vol.fsc(vol, pixel_size=1) - assert np.isclose(fsc_resolution[0][0], 2.0, rtol=0.02) + assert np.isclose(fsc_resolution[0][0], 1, rtol=0.02) assert np.allclose(fsc, 1) @@ -144,11 +144,11 @@ def test_fsc_rot(volume_fixture): vol_a, vol_b, _ = volume_fixture fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1) - assert np.isclose(fsc_resolution[0][0], 3.225, rtol=0.01) + assert np.isclose(fsc_resolution[0][0], 3.225/2, rtol=0.01) def test_fsc_noise(volume_fixture): vol_a, _, vol_n = volume_fixture fsc_resolution, fsc = vol_a.fsc(vol_n, pixel_size=1) - assert np.isclose(fsc_resolution[0][0], 2.6, rtol=0.3) + assert np.isclose(fsc_resolution[0][0], 2.6/2, rtol=0.3) From 13a1e2c27899121b824a1cbeb82653d898fcba66 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 7 Apr 2023 09:25:40 -0400 Subject: [PATCH 021/116] add plotting util --- src/aspire/utils/resolution_estimation.py | 50 +++++++++++++++++++++-- tests/test_fourier_corrs.py | 18 +++++--- 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 6e9734c735..ebd6ee0ee4 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -3,6 +3,7 @@ """ import logging +import matplotlib.pyplot as plt import numpy as np from aspire.numeric import fft @@ -173,6 +174,7 @@ def analyze_correlations(self): frequencies = self._freq(c_ind) # Convert to resolution in Angstrom, smaller is higher frequency. + # TODO: handle 0 freq self._resolutions = 1 / frequencies def _freq(self, k): @@ -184,19 +186,57 @@ def _freq(self, k): Given a real space signal observed with `L` bins (pixels/voxels), each with a `pixel_size` in Angstrom, we can compute the width of a Fourier space bin to be the `Bandwidth / L = (2*(1/pixel_size)) / L`. - Thus the frequency at an index `k` is `freq_k = k * 2 * (1 / pixel_size) / L = 2*k / (pixel_size * L) + Thus the frequency at an index `k` is `freq_k = k * 2 * (1 / pixel_size) / L = 2*k / (pixel_size * L) """ - + # _freq(k) Units: 1 / (pixels * (Angstrom / pixel) = 1 / Angstrom # Similar idea to wavenumbers (cm-1). Larger is higher frequency. return k * 2 / (self.L * self.pixel_size) - def plot(self, to_file=False): """ Generates a Fourier correlation plot. """ - + + # Construct x-axis labels + x_inds = np.arange(self.L // 2) + freqs = self._freq(x_inds) + # TODO: handle zero freq + freqs_angstrom = 1 / freqs + + # Check we're asking for a reasonable plot. + stack = self.correlations.shape[: -self.dim] + if len(stack) > 2: + raise RuntimeError( + f"Unable to plot figure tables with more than 2 dim, stack shape {stack}. Try reducing to a simpler request." + ) + if np.prod(stack) > 1: + raise RuntimeError( + f"Unable to plot figure tables with more than 1 figures, stack shape {stack}. Try reducing to a simpler request." + ) + + plt.figure(figsize=(8, 6)) + plt.title(self._plot_title) + plt.xlabel("Resolution (Angstrom)") + plt.ylabel("Correlation") + plt.ylim([0, 1]) + plt.plot(freqs_angstrom, self.correlations[0][0]) + # Display cutoff + plt.axhline( + y=self.cutoff, color="r", linestyle="--", label=f"cutoff={self.cutoff}" + ) + # Display resolution + plt.axvline( + x=self.estimated_resolution[0][0], + color="b", + linestyle=":", + label=f"Resolution={self.estimated_resolution[0][0]:.3f}", + ) + # x-axis in decreasing + plt.gca().invert_xaxis() + plt.legend() + plt.show() + class FourierRingCorrelation(_FourierCorrelation): """ @@ -204,6 +244,7 @@ class FourierRingCorrelation(_FourierCorrelation): """ dim = 2 + _plot_title = "Fourier Ring Correlation" class FourierShellCorrelation(_FourierCorrelation): @@ -212,3 +253,4 @@ class FourierShellCorrelation(_FourierCorrelation): """ dim = 3 + _plot_title = "Fourier Shell Correlation" diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py index d5f7640da8..a117e9236e 100644 --- a/tests/test_fourier_corrs.py +++ b/tests/test_fourier_corrs.py @@ -7,7 +7,7 @@ from aspire.noise import BlueNoiseAdder from aspire.numeric import fft from aspire.source import Simulation -from aspire.utils import Rotation, grid_3d +from aspire.utils import FourierShellCorrelation, Rotation, grid_3d from aspire.volume import Volume DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") @@ -119,14 +119,14 @@ def test_frc_rot(image_fixture): img_a, img_b, _ = image_fixture assert img_a.dtype == img_b.dtype frc_resolution, frc = img_a.frc(img_b, pixel_size=1) - assert np.isclose(frc_resolution[0][0], 3.78/2, rtol=0.1) + assert np.isclose(frc_resolution[0][0], 3.78 / 2, rtol=0.1) def test_frc_noise(image_fixture): img_a, _, img_n = image_fixture frc_resolution, frc = img_a.frc(img_n, pixel_size=1) - assert np.isclose(frc_resolution[0][0], 3.5/2, rtol=0.2) + assert np.isclose(frc_resolution[0][0], 3.5 / 2, rtol=0.2) # FSC @@ -144,11 +144,19 @@ def test_fsc_rot(volume_fixture): vol_a, vol_b, _ = volume_fixture fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1) - assert np.isclose(fsc_resolution[0][0], 3.225/2, rtol=0.01) + assert np.isclose(fsc_resolution[0][0], 3.225 / 2, rtol=0.01) def test_fsc_noise(volume_fixture): vol_a, _, vol_n = volume_fixture fsc_resolution, fsc = vol_a.fsc(vol_n, pixel_size=1) - assert np.isclose(fsc_resolution[0][0], 2.6/2, rtol=0.3) + assert np.isclose(fsc_resolution[0][0], 2.6 / 2, rtol=0.3) + + +def test_fsc_plot(volume_fixture): + vol_a, _, vol_n = volume_fixture + + # fsc_resolution, fsc = vol_a.fsc(vol_n, pixel_size=1) + fsc = FourierShellCorrelation(vol_a.asnumpy(), vol_n.asnumpy(), pixel_size=1) + fsc.plot() From 42a4907b942f78b7b4cbdf4ca3884adf43aa74db Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 7 Apr 2023 09:44:51 -0400 Subject: [PATCH 022/116] add figure save to plot and cleanup strings --- src/aspire/utils/resolution_estimation.py | 33 ++++++++++++++++------- tests/test_fourier_corrs.py | 19 ++++++++++--- 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index ebd6ee0ee4..b23f424614 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -179,23 +179,34 @@ def analyze_correlations(self): def _freq(self, k): """ - Converts `k` from index of Fourier transform to frequency (as length 1/A). - - From Shannon-Nyquist, for a given pixel-size, sampling theorem limits us to the sampled frequency 1/pixel_size. - Thus the Bandwidth ranges from `[-1/pixel_size, 1/pixel_size]`, so the total bandwidth is `2*(1/pixel_size)`. - - Given a real space signal observed with `L` bins (pixels/voxels), each with a `pixel_size` in Angstrom, - we can compute the width of a Fourier space bin to be the `Bandwidth / L = (2*(1/pixel_size)) / L`. - Thus the frequency at an index `k` is `freq_k = k * 2 * (1 / pixel_size) / L = 2*k / (pixel_size * L) + Converts `k` from index of Fourier transform to frequency (as + length 1/A). + + From Shannon-Nyquist, for a given pixel-size, sampling theorem + limits us to the sampled frequency 1/pixel_size. Thus the + Bandwidth ranges from `[-1/pixel_size, 1/pixel_size]`, so the + total bandwidth is `2*(1/pixel_size)`. + + Given a real space signal observed with `L` bins + (pixels/voxels), each with a `pixel_size` in Angstrom, we can + compute the width of a Fourier space bin to be the `Bandwidth + / L = (2*(1/pixel_size)) / L`. Thus the frequency at an index + `k` is `freq_k = k * 2 * (1 / pixel_size) / L = 2*k / + (pixel_size * L) """ # _freq(k) Units: 1 / (pixels * (Angstrom / pixel) = 1 / Angstrom # Similar idea to wavenumbers (cm-1). Larger is higher frequency. return k * 2 / (self.L * self.pixel_size) - def plot(self, to_file=False): + def plot(self, save_to_file=False): """ Generates a Fourier correlation plot. + + :param save_to_file: Optionally, save plot to file. + Defaults False, enabled by providing a string filename. + User is responsible for providing reasonable filename. + See `https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html`. """ # Construct x-axis labels @@ -235,6 +246,10 @@ def plot(self, to_file=False): # x-axis in decreasing plt.gca().invert_xaxis() plt.legend() + + if save_to_file: + plt.savefig(save_to_file) + plt.show() diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py index a117e9236e..cd87cb69cf 100644 --- a/tests/test_fourier_corrs.py +++ b/tests/test_fourier_corrs.py @@ -1,5 +1,6 @@ import logging import os +import tempfile import numpy as np import pytest @@ -10,6 +11,8 @@ from aspire.utils import FourierShellCorrelation, Rotation, grid_3d from aspire.volume import Volume +from .test_utils import matplotlib_no_gui + DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") logger = logging.getLogger(__name__) @@ -155,8 +158,16 @@ def test_fsc_noise(volume_fixture): def test_fsc_plot(volume_fixture): - vol_a, _, vol_n = volume_fixture + """ + Smoke test the plots. + """ + vol_a, vol_rot, _ = volume_fixture + + fsc = FourierShellCorrelation(vol_a.asnumpy(), vol_rot.asnumpy(), pixel_size=1) + + with matplotlib_no_gui(): + fsc.plot() - # fsc_resolution, fsc = vol_a.fsc(vol_n, pixel_size=1) - fsc = FourierShellCorrelation(vol_a.asnumpy(), vol_n.asnumpy(), pixel_size=1) - fsc.plot() + with tempfile.TemporaryDirectory() as tmp_input_dir: + file_path = os.path.join(tmp_input_dir, "fsc_curve.png") + fsc.plot(save_to_file=file_path) From b9f5db6d94ebf09e82d0e9df023ef470c4467060 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 7 Apr 2023 10:06:10 -0400 Subject: [PATCH 023/116] fixup vol noise fsc test and plotting test --- src/aspire/utils/resolution_estimation.py | 2 +- tests/test_fourier_corrs.py | 24 ++++++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index b23f424614..ee9b61dec8 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -230,7 +230,7 @@ def plot(self, save_to_file=False): plt.title(self._plot_title) plt.xlabel("Resolution (Angstrom)") plt.ylabel("Correlation") - plt.ylim([0, 1]) + plt.ylim([0, 1.1]) plt.plot(freqs_angstrom, self.correlations[0][0]) # Display cutoff plt.axhline( diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py index cd87cb69cf..3fa162bb08 100644 --- a/tests/test_fourier_corrs.py +++ b/tests/test_fourier_corrs.py @@ -97,12 +97,11 @@ def volume_fixture(img_size, dtype): vol_rot = vol.rotate(rots) - # Scale gaussian noise radially - noise = np.random.normal(loc=0, scale=1, size=vol.shape).astype(dtype, copy=False) - noise = noise * (1.0 + grid_3d(img_size, normalized=False)["r"]) * 0.3 - vol_noise = Volume( - np.real(fft.centered_ifftn(fft.centered_fftn(vol.asnumpy()[0]) * (1 + noise))) - ) + # Add some noise to the volume + noise = np.random.normal( + loc=0, scale=np.cbrt(vol.asnumpy().var()), size=vol.shape + ).astype(dtype, copy=False) + vol_noise = vol + noise return vol, vol_rot, vol_noise @@ -154,20 +153,27 @@ def test_fsc_noise(volume_fixture): vol_a, _, vol_n = volume_fixture fsc_resolution, fsc = vol_a.fsc(vol_n, pixel_size=1) - assert np.isclose(fsc_resolution[0][0], 2.6 / 2, rtol=0.3) + assert fsc_resolution[0][0] > 4 def test_fsc_plot(volume_fixture): """ Smoke test the plots. + + Also tests resetting the cutoff. """ - vol_a, vol_rot, _ = volume_fixture + vol_a, vol_b, _ = volume_fixture - fsc = FourierShellCorrelation(vol_a.asnumpy(), vol_rot.asnumpy(), pixel_size=1) + fsc = FourierShellCorrelation( + vol_a.asnumpy(), vol_b.asnumpy(), pixel_size=1, cutoff=0.5 + ) with matplotlib_no_gui(): fsc.plot() + # Reset cutoff + fsc.cutoff = 0.143 + with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "fsc_curve.png") fsc.plot(save_to_file=file_path) From 1753332ea2f7280560fae5c4624dd3197f259d5f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 7 Apr 2023 11:24:48 -0400 Subject: [PATCH 024/116] migrate to Blue and Pink Filters --- src/aspire/noise/noise.py | 54 ++++++++++++++------------------ src/aspire/operators/__init__.py | 2 ++ src/aspire/operators/filters.py | 42 +++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 31 deletions(-) diff --git a/src/aspire/noise/noise.py b/src/aspire/noise/noise.py index 9d7a05f7df..71935c16cc 100644 --- a/src/aspire/noise/noise.py +++ b/src/aspire/noise/noise.py @@ -6,7 +6,14 @@ from aspire.image import Image from aspire.image.xform import Xform from aspire.numeric import fft, xp -from aspire.operators import ArrayFilter, FunctionFilter, PowerFilter, ScalarFilter +from aspire.operators import ( + ArrayFilter, + BlueFilter, + FunctionFilter, + PinkFilter, + PowerFilter, + ScalarFilter, +) from aspire.utils import grid_2d, randn, trange logger = logging.getLogger(__name__) @@ -183,49 +190,34 @@ def signal_power(self, p): self._build() -class ColoredNoiseAdder(WhiteNoiseAdder): - @abc.abstractmethod - def _spectrum(self, x, y): - """ - Colored noise spectrum (2d). - """ +class BlueNoiseAdder(WhiteNoiseAdder): + """ + NoiseAdder where noise power increases with frequency. + """ def _build(self): """ Builds underlying Filter for this NoiseAdder. """ - custom_filter = FunctionFilter(f=self._spectrum) * ScalarFilter( - value=self.noise_var - ) + blue_filter = BlueFilter(value=self.noise_var) # Call the __init__ from parent of WhiteNoiseAdder. - super(WhiteNoiseAdder, self).__init__( - noise_filter=custom_filter, seed=self.seed - ) + super(WhiteNoiseAdder, self).__init__(noise_filter=blue_filter, seed=self.seed) -class BlueNoiseAdder(ColoredNoiseAdder): - """ - NoiseAdder where noise power increases with frequency. - """ - - def _spectrum(self, x, y): - s = x[-1] - x[-2] - f = s * np.hypot(x, y) - m = np.mean(f) - return f / m - - -class PinkNoiseAdder(ColoredNoiseAdder): +class PinkNoiseAdder(WhiteNoiseAdder): """ NoiseAdder where noise power decreases with frequency. """ - def _spectrum(self, x, y): - s = x[-1] - x[-2] - f = 2 * s / (np.hypot(x, y) + s) - m = np.mean(f) - return f / m + def _build(self): + """ + Builds underlying Filter for this NoiseAdder. + """ + pink_filter = PinkFilter(value=self.noise_var) + + # Call the __init__ from parent of WhiteNoiseAdder. + super(WhiteNoiseAdder, self).__init__(noise_filter=pink_filter, seed=self.seed) class NoiseEstimator: diff --git a/src/aspire/operators/__init__.py b/src/aspire/operators/__init__.py index 27789eaae8..442edc79fb 100644 --- a/src/aspire/operators/__init__.py +++ b/src/aspire/operators/__init__.py @@ -1,6 +1,7 @@ from .blk_diag_matrix import BlkDiagMatrix from .filters import ( ArrayFilter, + BlueFilter, CTFFilter, DualFilter, Filter, @@ -8,6 +9,7 @@ IdentityFilter, LambdaFilter, MultiplicativeFilter, + PinkFilter, PowerFilter, RadialCTFFilter, ScalarFilter, diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index b89598b003..b07e21f2f0 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -458,3 +458,45 @@ def __init__( alpha=alpha, B=B, ) + + +class BlueFilter(Filter): + """ + Filter where power increases with frequency. + """ + + def __init__(self, dim=None, value=1): + super().__init__(dim=dim, radial=True) + self.value = value + + def __repr__(self): + return f"BlueFilter (dim={self.dim}, value={self.value})" + + def _evaluate(self, omega): + f = np.sqrt(omega[0]) + m = np.mean(f) + f = f / m + + return self.value * f + + +class PinkFilter(Filter): + """ + Filter where power decreases with frequency. + """ + + def __init__(self, dim=None, value=1): + super().__init__(dim=dim, radial=True) + self.value = value + + def __repr__(self): + return f"PinkFilter (dim={self.dim}, value={self.value})" + + def _evaluate(self, omega): + step = np.abs(np.subtract(*omega[0][:2])) + # Avoid zero division + f = np.sqrt(2 * step / (omega[0] + step)) + m = np.mean(f) + f = f / m + + return self.value * f From 3b99ad3a183b5ed475f1a26a87d5a03043708377 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 7 Apr 2023 11:26:45 -0400 Subject: [PATCH 025/116] capture frequency div0 warning in context --- src/aspire/utils/resolution_estimation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index ee9b61dec8..1133d5605f 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -212,8 +212,9 @@ def plot(self, save_to_file=False): # Construct x-axis labels x_inds = np.arange(self.L // 2) freqs = self._freq(x_inds) - # TODO: handle zero freq - freqs_angstrom = 1 / freqs + # TODO: handle zero freq better + with np.errstate(divide="ignore"): + freqs_angstrom = 1 / freqs # Check we're asking for a reasonable plot. stack = self.correlations.shape[: -self.dim] From 3f6345779d818a783e152c79a2dd6a92338bdfa9 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 7 Apr 2023 11:42:22 -0400 Subject: [PATCH 026/116] Docsting and import cleanup --- src/aspire/noise/noise.py | 1 - src/aspire/utils/resolution_estimation.py | 68 +++++++++++++++++------ tests/test_fourier_corrs.py | 3 +- 3 files changed, 52 insertions(+), 20 deletions(-) diff --git a/src/aspire/noise/noise.py b/src/aspire/noise/noise.py index 71935c16cc..0333584f38 100644 --- a/src/aspire/noise/noise.py +++ b/src/aspire/noise/noise.py @@ -9,7 +9,6 @@ from aspire.operators import ( ArrayFilter, BlueFilter, - FunctionFilter, PinkFilter, PowerFilter, ScalarFilter, diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 1133d5605f..b5e0049e2e 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -27,11 +27,12 @@ class _FourierCorrelation: . """ - def __init__(self, a, b, pixel_size, cutoff=0.143, eps=1e-4): + def __init__(self, a, b, pixel_size=1, cutoff=0.143, eps=1e-4): """ :param a: Input array a, shape(..., *dim). :param b: Input array b, shape(..., *dim). :param pixel_size: Pixel size in Angstrom. + Defaults to 1. :param cutoff: Cutoff value, traditionally `.143`. :param eps: Epsilon past boundary values, defaults 1e-4. """ @@ -70,30 +71,56 @@ def __init__(self, a, b, pixel_size, cutoff=0.143, eps=1e-4): @property def _fourier_axes(self): + """ + Returns tuple representing the axis containing signal data + based on dimension `dim`. + """ return tuple(range(-self.dim, 0)) def _reshape(self, x): """ - Returns `x` with flattened stack axis and `x`'s original stack shape, as determined by `dim`. + Returns `x` with flattened stack axis and `x`'s original stack + shape, as determined by `dim`. :param x: Numpy ndarray + + :return: (stack flattened x, x_stack_shape) """ - # TODO, check 2d in put for dim=2 (singleton case) + # TODO, check 2d input for dim=2 (singleton case) original_stack_shape = x.shape[: -self.dim] x = x.reshape(-1, *x.shape[-self.dim :]) return x, original_stack_shape @property def cutoff(self): + """ + Returns `cutoff` value. + """ return self._cutoff @cutoff.setter - def cutoff(self, cutoff_correlation): - self._cutoff = float(cutoff_correlation) + def cutoff(self, cutoff): + """ + Sets `cutoff` value, and resets analysis, which is dependent + on `cutoff` values. + + :param cutoff: Float + """ + self._cutoff = float(cutoff) + if not (0 <= self._cutoff <= 1): + raise ValueError( + "Supplied correlation `cutoff` not in [0,1], {self._cutoff}" + ) self._analyzed = False # reset analysis @property def correlations(self): + """ + Compute and return the Fourier correlations of signal stacks a + cross b. + + :return: Numpy array + """ # There is no need to run this twice if we assume inputs are immutable if self._correlations is not None: return self._correlations @@ -143,7 +170,11 @@ def correlations(self): @property def estimated_resolution(self): - """ """ + """ + Return estimated resolution of stacks `a` cross `b`. + + :return: Numpy array. + """ self.analyze_correlations() return self._resolutions @@ -182,19 +213,22 @@ def _freq(self, k): Converts `k` from index of Fourier transform to frequency (as length 1/A). - From Shannon-Nyquist, for a given pixel-size, sampling theorem - limits us to the sampled frequency 1/pixel_size. Thus the - Bandwidth ranges from `[-1/pixel_size, 1/pixel_size]`, so the - total bandwidth is `2*(1/pixel_size)`. - - Given a real space signal observed with `L` bins - (pixels/voxels), each with a `pixel_size` in Angstrom, we can - compute the width of a Fourier space bin to be the `Bandwidth - / L = (2*(1/pixel_size)) / L`. Thus the frequency at an index - `k` is `freq_k = k * 2 * (1 / pixel_size) / L = 2*k / - (pixel_size * L) + :param k: Frequency index, integer or Numpy array of ints. + :return: Frequency in 1/Angstrom. """ + # From Shannon-Nyquist, for a given pixel-size, sampling theorem + # limits us to the sampled frequency 1/pixel_size. Thus the + # Bandwidth ranges from `[-1/pixel_size, 1/pixel_size]`, so the + # total bandwidth is `2*(1/pixel_size)`. + + # Given a real space signal observed with `L` bins + # (pixels/voxels), each with a `pixel_size` in Angstrom, we can + # compute the width of a Fourier space bin to be the `Bandwidth + # / L = (2*(1/pixel_size)) / L`. Thus the frequency at an index + # `k` is `freq_k = k * 2 * (1 / pixel_size) / L = 2*k / + # (pixel_size * L) + # _freq(k) Units: 1 / (pixels * (Angstrom / pixel) = 1 / Angstrom # Similar idea to wavenumbers (cm-1). Larger is higher frequency. return k * 2 / (self.L * self.pixel_size) diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py index 3fa162bb08..d2f54547a0 100644 --- a/tests/test_fourier_corrs.py +++ b/tests/test_fourier_corrs.py @@ -6,9 +6,8 @@ import pytest from aspire.noise import BlueNoiseAdder -from aspire.numeric import fft from aspire.source import Simulation -from aspire.utils import FourierShellCorrelation, Rotation, grid_3d +from aspire.utils import FourierShellCorrelation, Rotation from aspire.volume import Volume from .test_utils import matplotlib_no_gui From 261eb2fbd136283ceda7edbe27d1dbd87df23be9 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 7 Apr 2023 11:54:50 -0400 Subject: [PATCH 027/116] minor rebase import conflict (whitespace) --- src/aspire/utils/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/aspire/utils/__init__.py b/src/aspire/utils/__init__.py index dee6ed6068..ec5fb68230 100644 --- a/src/aspire/utils/__init__.py +++ b/src/aspire/utils/__init__.py @@ -31,7 +31,6 @@ ) from .logging import LogFilterByCount, get_full_version, tqdm, trange - from .matrix import ( acorr, ainner, From d962822807247bbd4ca257c5c04b11cd4488d20e Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 12 Apr 2023 14:37:55 -0400 Subject: [PATCH 028/116] Add nufft for frc still need to fixup shapes for 3d --- src/aspire/image/image.py | 5 +- src/aspire/utils/__init__.py | 2 +- src/aspire/utils/resolution_estimation.py | 61 +++++++++++++++++++++-- src/aspire/volume/volume.py | 6 ++- tests/test_fourier_corrs.py | 36 +++++++------ 5 files changed, 88 insertions(+), 22 deletions(-) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index b99d58cf24..f5c7d3838f 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -477,7 +477,7 @@ def show(self, columns=5, figsize=(20, 10), colorbar=True): plt.show() - def frc(self, other, pixel_size, cutoff=0.143, eps=1e-4): + def frc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft"): r""" Compute the Fourier ring correlation between two images. @@ -495,6 +495,8 @@ def frc(self, other, pixel_size, cutoff=0.143, eps=1e-4): For synthetic data, 1 is a reasonable value. :param cutoff: Cutoff value, traditionally `1.43`. :param eps: Epsilon past boundary values, defaults 1e-4. + :param method: Selects either 'fft' (on cartesian grid), + or 'nufft' (on polar grid). Defaults to 'fft'. :return: tuple(estimated_resolution, FRC), where `estimated_resolution` is in Angstrom @@ -512,6 +514,7 @@ def frc(self, other, pixel_size, cutoff=0.143, eps=1e-4): pixel_size=pixel_size, cutoff=cutoff, eps=eps, + method=method, ) return frc.estimated_resolution, frc.correlations diff --git a/src/aspire/utils/__init__.py b/src/aspire/utils/__init__.py index ec5fb68230..7a612d0311 100644 --- a/src/aspire/utils/__init__.py +++ b/src/aspire/utils/__init__.py @@ -1,3 +1,4 @@ +from .types import complex_type, real_type, utest_tolerance # isort:skip from .coor_trans import ( # isort:skip common_line_from_rots, crop_pad_2d, @@ -66,5 +67,4 @@ from .relion_interop import RelionStarFile, relion_metadata_fields from .resolution_estimation import FourierRingCorrelation, FourierShellCorrelation from .rotation import Rotation -from .types import complex_type, real_type, utest_tolerance from .units import ratio_to_decibel, voltage_to_wavelength, wavelength_to_voltage diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index b5e0049e2e..f57ca6b116 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import numpy as np +from aspire.nufft import nufft from aspire.numeric import fft from aspire.utils import grid_2d @@ -27,7 +28,7 @@ class _FourierCorrelation: . """ - def __init__(self, a, b, pixel_size=1, cutoff=0.143, eps=1e-4): + def __init__(self, a, b, pixel_size=1, cutoff=0.143, eps=1e-4, method="fft"): """ :param a: Input array a, shape(..., *dim). :param b: Input array b, shape(..., *dim). @@ -35,6 +36,8 @@ def __init__(self, a, b, pixel_size=1, cutoff=0.143, eps=1e-4): Defaults to 1. :param cutoff: Cutoff value, traditionally `.143`. :param eps: Epsilon past boundary values, defaults 1e-4. + :param method: Selects either 'fft' (on cartesian grid), + or 'nufft' (on polar grid). Defaults to 'fft'. """ # Sanity checks @@ -56,6 +59,15 @@ def __init__(self, a, b, pixel_size=1, cutoff=0.143, eps=1e-4): f"`a` and `b` appear to have different data axis shapes, {a.shape[-1]} {b.shape[-1]}" ) + # Method selection + methods = {"fft": self._fft_correlations, "nufft": self._nufft_correlations} + if method not in methods: + raise RuntimeError( + f"Requested method {method} not in available methods {methods}." + ) + self.method = method + self._correlation_method = methods[self.method] + # To support arbitrary broadcasting simply, # we'll force all shapes to be (-1, *(L,)*dim) self._a, self._a_stack_shape = self._reshape(a) @@ -121,10 +133,20 @@ def correlations(self): :return: Numpy array """ - # There is no need to run this twice if we assume inputs are immutable + # There is no need to run this twice if we assume inputs are immutable. if self._correlations is not None: return self._correlations + # Compute the correlations + self._correlations = self._correlation_method() + + return self._correlations + + def _fft_correlations(self): + """ + Computes Fourier Correlations using the FFT on a cartesian grid. + """ + # Compute shells from 2D grid. radii = grid_2d(self.L, shifted=True, normalized=False, dtype=self.dtype)["r"] @@ -163,10 +185,39 @@ def correlations(self): # Repack the table as (_a, _b, L//2) correlations = np.swapaxes(correlations, 0, 2) # Then unpack the a and b shapes. - self._correlations = correlations.reshape( + return correlations.reshape( *self._a_stack_shape, *self._b_stack_shape, self.L // 2 ) - return self._correlations + + def _nufft_correlations(self): + """ + Computes Fourier Correlations using the NUFFT on polar grid. + """ + + # TODO, should we use an internal tool (Polar2D?) for this + r = np.linspace(0, np.pi, self.L, endpoint=False, dtype=self.dtype) + phi = np.linspace(0, 2 * np.pi, 2 * self.L, endpoint=False, dtype=self.dtype) + x = r[:, np.newaxis] * np.cos(phi[np.newaxis, :]) + y = r[:, np.newaxis] * np.sin(phi[np.newaxis, :]) + fourier_pts = np.vstack((x.flatten(), y.flatten())) + + # Stack signal data. Note, we want a complex result. + signal = np.vstack((self._a, self._b)) + # Compute NUFFT and unpack as two 1D stacks of the polar grid + # points, one for each image. + f1, f2 = nufft(signal, fourier_pts, real=False).reshape( + 2, self._a.shape[0], len(r), len(phi) + ) + + # Compute the Fourier correlations + cov = np.sum(f1 * np.conj(f2), -1).real + norm1 = np.sqrt(np.sum(np.abs(f1) ** 2, -1)) + norm2 = np.sqrt(np.sum(np.abs(f2) ** 2, -1)) + + correlations = np.mean(cov / (norm1 * norm2), 0) + + # Then unpack the a and b shapes. + return correlations.reshape(*self._a_stack_shape, *self._b_stack_shape, self.L) @property def estimated_resolution(self): @@ -280,7 +331,7 @@ def plot(self, save_to_file=False): ) # x-axis in decreasing plt.gca().invert_xaxis() - plt.legend() + plt.legend(title=f"Method: {self.method}") if save_to_file: plt.savefig(save_to_file) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 81daede16f..1b05fb0ad1 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -494,7 +494,7 @@ def load(cls, filename, permissive=True, dtype=np.float32): logger.info(f"{filename} with dtype {loaded_data.dtype} loaded as {dtype}") return cls(loaded_data.astype(dtype)) - def fsc(self, other, pixel_size=None, cutoff=0.143, eps=1e-4, dtype=None): + def fsc(self, other, pixel_size=None, cutoff=0.143, eps=1e-4, method="fft"): r""" Compute the Fourier shell correlation between two volumes. @@ -512,6 +512,9 @@ def fsc(self, other, pixel_size=None, cutoff=0.143, eps=1e-4, dtype=None): For synthetic data, 1 is a reasonable value. :param cutoff: Cutoff value, traditionally `.143`. :param eps: Epsilon past boundary values, defaults 1e-4. + :param method: Selects either 'fft' (on cartesian grid), + or 'nufft' (on polar grid). Defaults to 'fft'. + :return: tuple(estimated_resolution, FRC), where `estimated_resolution` is in Angstrom and FRC is a Numpy array of correlations. @@ -529,6 +532,7 @@ def fsc(self, other, pixel_size=None, cutoff=0.143, eps=1e-4, dtype=None): pixel_size=pixel_size, cutoff=cutoff, eps=eps, + method=method, ) return fsc.estimated_resolution, fsc.correlations diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py index d2f54547a0..2c91468026 100644 --- a/tests/test_fourier_corrs.py +++ b/tests/test_fourier_corrs.py @@ -24,6 +24,7 @@ np.float64, np.float32, ] +METHOD = ["fft", "nufft"] @pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") @@ -31,6 +32,11 @@ def dtype(request): return request.param +@pytest.fixture(params=METHOD, ids=lambda x: f"method={x}") +def method(request): + return request.param + + @pytest.fixture(params=IMG_SIZES, ids=lambda x: f"img_size={x}") def img_size(request): return request.param @@ -108,54 +114,54 @@ def volume_fixture(img_size, dtype): # FRC -def test_frc_id(image_fixture): +def test_frc_id(image_fixture, method): img, _, _ = image_fixture - frc_resolution, frc = img.frc(img, pixel_size=1) + frc_resolution, frc = img.frc(img, pixel_size=1, method=method) assert np.isclose(frc_resolution[0][0], 1, rtol=0.02) assert np.allclose(frc, 1) -def test_frc_rot(image_fixture): +def test_frc_rot(image_fixture, method): img_a, img_b, _ = image_fixture assert img_a.dtype == img_b.dtype - frc_resolution, frc = img_a.frc(img_b, pixel_size=1) + frc_resolution, frc = img_a.frc(img_b, pixel_size=1, method=method) assert np.isclose(frc_resolution[0][0], 3.78 / 2, rtol=0.1) -def test_frc_noise(image_fixture): +def test_frc_noise(image_fixture, method): img_a, _, img_n = image_fixture - frc_resolution, frc = img_a.frc(img_n, pixel_size=1) + frc_resolution, frc = img_a.frc(img_n, pixel_size=1, method=method) assert np.isclose(frc_resolution[0][0], 3.5 / 2, rtol=0.2) # FSC -def test_fsc_id(volume_fixture): +def test_fsc_id(volume_fixture, method): vol, _, _ = volume_fixture - fsc_resolution, fsc = vol.fsc(vol, pixel_size=1) + fsc_resolution, fsc = vol.fsc(vol, pixel_size=1, method=method) assert np.isclose(fsc_resolution[0][0], 1, rtol=0.02) assert np.allclose(fsc, 1) -def test_fsc_rot(volume_fixture): +def test_fsc_rot(volume_fixture, method): vol_a, vol_b, _ = volume_fixture - fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1) + fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, method=method) assert np.isclose(fsc_resolution[0][0], 3.225 / 2, rtol=0.01) -def test_fsc_noise(volume_fixture): +def test_fsc_noise(volume_fixture, method): vol_a, _, vol_n = volume_fixture - fsc_resolution, fsc = vol_a.fsc(vol_n, pixel_size=1) + fsc_resolution, fsc = vol_a.fsc(vol_n, pixel_size=1, method=method) assert fsc_resolution[0][0] > 4 -def test_fsc_plot(volume_fixture): +def test_fsc_plot(volume_fixture, method): """ Smoke test the plots. @@ -164,8 +170,10 @@ def test_fsc_plot(volume_fixture): vol_a, vol_b, _ = volume_fixture fsc = FourierShellCorrelation( - vol_a.asnumpy(), vol_b.asnumpy(), pixel_size=1, cutoff=0.5 + vol_a.asnumpy(), vol_b.asnumpy(), pixel_size=1, method=method, cutoff=0.5 ) + # test + fsc.plot() with matplotlib_no_gui(): fsc.plot() From 93cefa410dcfdd118c4791b4f47f580fcdd600f6 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 18 Apr 2023 10:27:43 -0400 Subject: [PATCH 029/116] stashing 3d fsc extension --- src/aspire/utils/resolution_estimation.py | 36 +++++++++++++++++++---- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index f57ca6b116..272223801f 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -197,16 +197,41 @@ def _nufft_correlations(self): # TODO, should we use an internal tool (Polar2D?) for this r = np.linspace(0, np.pi, self.L, endpoint=False, dtype=self.dtype) phi = np.linspace(0, 2 * np.pi, 2 * self.L, endpoint=False, dtype=self.dtype) - x = r[:, np.newaxis] * np.cos(phi[np.newaxis, :]) - y = r[:, np.newaxis] * np.sin(phi[np.newaxis, :]) - fourier_pts = np.vstack((x.flatten(), y.flatten())) + if self.dim == 2: + x = r[:, np.newaxis] * np.cos(phi[np.newaxis, :]) + y = r[:, np.newaxis] * np.sin(phi[np.newaxis, :]) + fourier_pts = np.vstack((x.flatten(), y.flatten())) + # result_frame_shape = (len(r), len(phi)) + elif self.dim == 3: + theta = np.linspace(0, np.pi, self.L, endpoint=False, dtype=self.dtype) + x = ( + r[:, np.newaxis, np.newaxis] + * np.sin(theta[np.newaxis, :, np.newaxis]) + * np.cos(phi[np.newaxis, np.newaxis, :]) + ) + y = ( + r[:, np.newaxis, np.newaxis] + * np.sin(theta[np.newaxis, :, np.newaxis]) + * np.cos(phi[np.newaxis, np.newaxis, :]) + ) + z = ( + r[:, np.newaxis, np.newaxis] + * np.cos(theta[np.newaxis, :, np.newaxis]) + * np.ones((1, 1, 2 * self.L), dtype=self.dtype) + ) + fourier_pts = np.vstack((x.flatten(), y.flatten(), z.flatten())) + # result_frame_shape = (len(r), len(theta)*len(phi)) + else: + raise NotImplementedError( + "`nufft` based correlations only implemented for dimensions 2 and 3." + ) # Stack signal data. Note, we want a complex result. signal = np.vstack((self._a, self._b)) # Compute NUFFT and unpack as two 1D stacks of the polar grid # points, one for each image. f1, f2 = nufft(signal, fourier_pts, real=False).reshape( - 2, self._a.shape[0], len(r), len(phi) + 2, self._a.shape[0], len(r), -1 ) # Compute the Fourier correlations @@ -295,7 +320,8 @@ def plot(self, save_to_file=False): """ # Construct x-axis labels - x_inds = np.arange(self.L // 2) + # x_inds = np.arange(self.L // 2) + x_inds = np.arange(self.correlations.shape[-1]) freqs = self._freq(x_inds) # TODO: handle zero freq better with np.errstate(divide="ignore"): From 54c761d126739297b46bff26d83550ab02272c38 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 18 Apr 2023 12:09:56 -0400 Subject: [PATCH 030/116] Use same radial grid between polar nufft and cart fft for frc/fsc --- src/aspire/utils/resolution_estimation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 272223801f..7be4707710 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -195,7 +195,8 @@ def _nufft_correlations(self): """ # TODO, should we use an internal tool (Polar2D?) for this - r = np.linspace(0, np.pi, self.L, endpoint=False, dtype=self.dtype) + # For now use L//2 for compatibility with cartesian. + r = np.linspace(0, np.pi, self.L//2, endpoint=False, dtype=self.dtype) phi = np.linspace(0, 2 * np.pi, 2 * self.L, endpoint=False, dtype=self.dtype) if self.dim == 2: x = r[:, np.newaxis] * np.cos(phi[np.newaxis, :]) @@ -242,7 +243,7 @@ def _nufft_correlations(self): correlations = np.mean(cov / (norm1 * norm2), 0) # Then unpack the a and b shapes. - return correlations.reshape(*self._a_stack_shape, *self._b_stack_shape, self.L) + return correlations.reshape(*self._a_stack_shape, *self._b_stack_shape, r.shape[-1]) @property def estimated_resolution(self): From d3a35c4c0bc144f3f2cfe6ccea97069c811c5d4c Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 18 Apr 2023 12:29:01 -0400 Subject: [PATCH 031/116] Remove superfluous mean() in frc/fsc nufft method. [skip ci] --- src/aspire/utils/resolution_estimation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 7be4707710..1b874c3589 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -240,7 +240,7 @@ def _nufft_correlations(self): norm1 = np.sqrt(np.sum(np.abs(f1) ** 2, -1)) norm2 = np.sqrt(np.sum(np.abs(f2) ** 2, -1)) - correlations = np.mean(cov / (norm1 * norm2), 0) + correlations = cov / (norm1 * norm2) # Then unpack the a and b shapes. return correlations.reshape(*self._a_stack_shape, *self._b_stack_shape, r.shape[-1]) From ad4e11d4884be128bad65e2508fe56b4caf6a706 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 28 Apr 2023 09:12:32 -0400 Subject: [PATCH 032/116] Fix FSC nufft grid bug (sin sin), improved fsc volume test --- src/aspire/utils/resolution_estimation.py | 11 ++-- tests/test_fourier_corrs.py | 68 ++++++++++++++--------- 2 files changed, 48 insertions(+), 31 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 1b874c3589..d9b4c064a1 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -171,7 +171,7 @@ def _fft_correlations(self): r1 = ring_mask * f1 r2 = ring_mask * f2 - # Compute FRC + # Compute FC num = np.real(np.sum(r1 * np.conj(r2), axis=self._fourier_axes)) den = np.sqrt( np.sum(np.abs(r1) ** 2, axis=self._fourier_axes) @@ -196,7 +196,7 @@ def _nufft_correlations(self): # TODO, should we use an internal tool (Polar2D?) for this # For now use L//2 for compatibility with cartesian. - r = np.linspace(0, np.pi, self.L//2, endpoint=False, dtype=self.dtype) + r = np.linspace(0, np.pi, self.L // 2, endpoint=False, dtype=self.dtype) phi = np.linspace(0, 2 * np.pi, 2 * self.L, endpoint=False, dtype=self.dtype) if self.dim == 2: x = r[:, np.newaxis] * np.cos(phi[np.newaxis, :]) @@ -213,7 +213,7 @@ def _nufft_correlations(self): y = ( r[:, np.newaxis, np.newaxis] * np.sin(theta[np.newaxis, :, np.newaxis]) - * np.cos(phi[np.newaxis, np.newaxis, :]) + * np.sin(phi[np.newaxis, np.newaxis, :]) ) z = ( r[:, np.newaxis, np.newaxis] @@ -221,7 +221,6 @@ def _nufft_correlations(self): * np.ones((1, 1, 2 * self.L), dtype=self.dtype) ) fourier_pts = np.vstack((x.flatten(), y.flatten(), z.flatten())) - # result_frame_shape = (len(r), len(theta)*len(phi)) else: raise NotImplementedError( "`nufft` based correlations only implemented for dimensions 2 and 3." @@ -243,7 +242,9 @@ def _nufft_correlations(self): correlations = cov / (norm1 * norm2) # Then unpack the a and b shapes. - return correlations.reshape(*self._a_stack_shape, *self._b_stack_shape, r.shape[-1]) + return correlations.reshape( + *self._a_stack_shape, *self._b_stack_shape, r.shape[-1] + ) @property def estimated_resolution(self): diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py index 2c91468026..ff707c8240 100644 --- a/tests/test_fourier_corrs.py +++ b/tests/test_fourier_corrs.py @@ -6,8 +6,9 @@ import pytest from aspire.noise import BlueNoiseAdder +from aspire.numeric import fft from aspire.source import Simulation -from aspire.utils import FourierShellCorrelation, Rotation +from aspire.utils import FourierRingCorrelation, FourierShellCorrelation, Rotation from aspire.volume import Volume from .test_utils import matplotlib_no_gui @@ -96,19 +97,19 @@ def volume_fixture(img_size, dtype): np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")), dtype=dtype ).downsample(img_size) - # Instantiate ASPIRE's Rotation class with a set of angles. - thetas = [0.12] - rots = Rotation.about_axis("z", thetas, dtype=dtype) - - vol_rot = vol.rotate(rots) + # Invert correlation for some high frequency content + # Convert volume to Fourier space. + vol_trunc_f = fft.centered_fftn(vol.asnumpy()[0]) + # Get a frequency index. + trunc_frq = img_size // 3 + # Negate the power for some frequencies higher than `trunc_frq`. + vol_trunc_f[-trunc_frq:, :, :] *= -1.0 + vol_trunc_f[:, -trunc_frq:, :] *= -1.0 + vol_trunc_f[:, :, -trunc_frq:] *= -1.0 + # Convert volume from Fourier space to real space Volume. + vol_trunc = Volume(fft.centered_ifftn(vol_trunc_f).real) - # Add some noise to the volume - noise = np.random.normal( - loc=0, scale=np.cbrt(vol.asnumpy().var()), size=vol.shape - ).astype(dtype, copy=False) - vol_noise = vol + noise - - return vol, vol_rot, vol_noise + return vol, vol_trunc # FRC @@ -136,29 +137,46 @@ def test_frc_noise(image_fixture, method): assert np.isclose(frc_resolution[0][0], 3.5 / 2, rtol=0.2) +def test_frc_plot(image_fixture, method): + """ + Smoke test the plots. + + Also tests resetting the cutoff. + """ + img_a, img_b, _ = image_fixture + + frc = FourierRingCorrelation( + img_a.asnumpy(), img_b.asnumpy(), pixel_size=1, method=method, cutoff=0.5 + ) + + with matplotlib_no_gui(): + frc.plot() + + with tempfile.TemporaryDirectory() as tmp_input_dir: + file_path = os.path.join(tmp_input_dir, "frc_curve.png") + frc.plot(save_to_file=file_path) + + # FSC def test_fsc_id(volume_fixture, method): - vol, _, _ = volume_fixture + vol, _ = volume_fixture fsc_resolution, fsc = vol.fsc(vol, pixel_size=1, method=method) assert np.isclose(fsc_resolution[0][0], 1, rtol=0.02) assert np.allclose(fsc, 1) -def test_fsc_rot(volume_fixture, method): - vol_a, vol_b, _ = volume_fixture +def test_fsc_trunc(volume_fixture, method): + vol_a, vol_b = volume_fixture fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, method=method) - assert np.isclose(fsc_resolution[0][0], 3.225 / 2, rtol=0.01) - - -def test_fsc_noise(volume_fixture, method): - vol_a, _, vol_n = volume_fixture + assert fsc_resolution[0][0] > 1.5 - fsc_resolution, fsc = vol_a.fsc(vol_n, pixel_size=1, method=method) - assert fsc_resolution[0][0] > 4 + # The follow should correspond to the test_fsc_plot below. + fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, method=method, cutoff=0.5) + assert fsc_resolution[0][0] > 2.0 def test_fsc_plot(volume_fixture, method): @@ -167,13 +185,11 @@ def test_fsc_plot(volume_fixture, method): Also tests resetting the cutoff. """ - vol_a, vol_b, _ = volume_fixture + vol_a, vol_b = volume_fixture fsc = FourierShellCorrelation( vol_a.asnumpy(), vol_b.asnumpy(), pixel_size=1, method=method, cutoff=0.5 ) - # test - fsc.plot() with matplotlib_no_gui(): fsc.plot() From 4468f4cd99ab87109da3847e9808c9846f95dc13 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 28 Apr 2023 10:05:18 -0400 Subject: [PATCH 033/116] Remove excess lines in blue/pink noise adder code. --- src/aspire/noise/noise.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/aspire/noise/noise.py b/src/aspire/noise/noise.py index 0333584f38..b52a4f4a48 100644 --- a/src/aspire/noise/noise.py +++ b/src/aspire/noise/noise.py @@ -198,10 +198,11 @@ def _build(self): """ Builds underlying Filter for this NoiseAdder. """ - blue_filter = BlueFilter(value=self.noise_var) # Call the __init__ from parent of WhiteNoiseAdder. - super(WhiteNoiseAdder, self).__init__(noise_filter=blue_filter, seed=self.seed) + super(WhiteNoiseAdder, self).__init__( + noise_filter=BlueFilter(value=self.noise_var), seed=self.seed + ) class PinkNoiseAdder(WhiteNoiseAdder): @@ -213,10 +214,11 @@ def _build(self): """ Builds underlying Filter for this NoiseAdder. """ - pink_filter = PinkFilter(value=self.noise_var) # Call the __init__ from parent of WhiteNoiseAdder. - super(WhiteNoiseAdder, self).__init__(noise_filter=pink_filter, seed=self.seed) + super(WhiteNoiseAdder, self).__init__( + noise_filter=PinkFilter(value=self.noise_var), seed=self.seed + ) class NoiseEstimator: From 0e98e59b27754c29e480da4f3820cbd71f4a1fc6 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 28 Apr 2023 11:09:02 -0400 Subject: [PATCH 034/116] polishing tests and cleanup before review --- src/aspire/operators/filters.py | 4 +- src/aspire/utils/resolution_estimation.py | 67 +++++++++-------- src/aspire/volume/volume.py | 9 +-- tests/test_fourier_corrs.py | 90 +++++++++++++++++++++-- 4 files changed, 127 insertions(+), 43 deletions(-) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index b07e21f2f0..d8db65d7f0 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -470,7 +470,7 @@ def __init__(self, dim=None, value=1): self.value = value def __repr__(self): - return f"BlueFilter (dim={self.dim}, value={self.value})" + return f"BlueFilter(dim={self.dim}, value={self.value})" def _evaluate(self, omega): f = np.sqrt(omega[0]) @@ -490,7 +490,7 @@ def __init__(self, dim=None, value=1): self.value = value def __repr__(self): - return f"PinkFilter (dim={self.dim}, value={self.value})" + return f"PinkFilter(dim={self.dim}, value={self.value})" def _evaluate(self, omega): step = np.abs(np.subtract(*omega[0][:2])) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index d9b4c064a1..41a50498ec 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -12,6 +12,9 @@ logger = logging.getLogger(__name__) +# _FourierCorrelation holds a single implementation for both FSC and +# FRC based on dimension `dim`. + class _FourierCorrelation: r""" @@ -45,7 +48,7 @@ def __init__(self, a, b, pixel_size=1, cutoff=0.143, eps=1e-4, method="fft"): raise RuntimeError("Subclass must assign `dim`") for x in (a, b): if not isinstance(x, np.ndarray): - raise TypeError(f"`{x.__name__}` is not a Numpy array.") + raise TypeError(f"`{x}` is not a Numpy array.") if not a.dtype == b.dtype: raise TypeError( @@ -63,7 +66,7 @@ def __init__(self, a, b, pixel_size=1, cutoff=0.143, eps=1e-4, method="fft"): methods = {"fft": self._fft_correlations, "nufft": self._nufft_correlations} if method not in methods: raise RuntimeError( - f"Requested method {method} not in available methods {methods}." + f"Requested method {method} not in available methods {list(methods.keys())}." ) self.method = method self._correlation_method = methods[self.method] @@ -133,12 +136,11 @@ def correlations(self): :return: Numpy array """ - # There is no need to run this twice if we assume inputs are immutable. - if self._correlations is not None: - return self._correlations - - # Compute the correlations - self._correlations = self._correlation_method() + # Cache _correlations. + # There is no need to run this twice assuming inputs are immutable. + if self._correlations is None: + # Compute the correlations + self._correlations = self._correlation_method() return self._correlations @@ -150,8 +152,7 @@ def _fft_correlations(self): # Compute shells from 2D grid. radii = grid_2d(self.L, shifted=True, normalized=False, dtype=self.dtype)["r"] - # Compute centered Fourier transforms, - # upcasting when nessecary. + # Compute centered Fourier transforms. f1 = fft.centered_fftn(self._a, axes=self._fourier_axes) f2 = fft.centered_fftn(self._b, axes=self._fourier_axes) @@ -171,7 +172,7 @@ def _fft_correlations(self): r1 = ring_mask * f1 r2 = ring_mask * f2 - # Compute FC + # Compute Fourier Correlations num = np.real(np.sum(r1 * np.conj(r2), axis=self._fourier_axes)) den = np.sqrt( np.sum(np.abs(r1) ** 2, axis=self._fourier_axes) @@ -184,26 +185,30 @@ def _fft_correlations(self): # Repack the table as (_a, _b, L//2) correlations = np.swapaxes(correlations, 0, 2) - # Then unpack the a and b shapes. + # Then unpack the original a and b shapes. return correlations.reshape( *self._a_stack_shape, *self._b_stack_shape, self.L // 2 ) def _nufft_correlations(self): """ - Computes Fourier Correlations using the NUFFT on polar grid. + Computes Fourier Correlations using the NUFFT on a polar grid. """ - # TODO, should we use an internal tool (Polar2D?) for this - # For now use L//2 for compatibility with cartesian. + # TODO, we could use an internal tool (Polar2D?) for this. + # L//2 is intentionally used for compatibility with cartesian grid. + # This avoids having to have multiple methods for computing resolutions later. r = np.linspace(0, np.pi, self.L // 2, endpoint=False, dtype=self.dtype) phi = np.linspace(0, 2 * np.pi, 2 * self.L, endpoint=False, dtype=self.dtype) if self.dim == 2: + # 2D Polar points x = r[:, np.newaxis] * np.cos(phi[np.newaxis, :]) y = r[:, np.newaxis] * np.sin(phi[np.newaxis, :]) + # Because the values will be summed later, ordering does not matter. fourier_pts = np.vstack((x.flatten(), y.flatten())) - # result_frame_shape = (len(r), len(phi)) + elif self.dim == 3: + # 3D Spherical points theta = np.linspace(0, np.pi, self.L, endpoint=False, dtype=self.dtype) x = ( r[:, np.newaxis, np.newaxis] @@ -220,28 +225,30 @@ def _nufft_correlations(self): * np.cos(theta[np.newaxis, :, np.newaxis]) * np.ones((1, 1, 2 * self.L), dtype=self.dtype) ) + # Because the values will be summed later, ordering does not matter. fourier_pts = np.vstack((x.flatten(), y.flatten(), z.flatten())) else: raise NotImplementedError( "`nufft` based correlations only implemented for dimensions 2 and 3." ) - # Stack signal data. Note, we want a complex result. + # Stack signal data to create a larger NUFFT problem (better performance). + # Note, we want a complex result. signal = np.vstack((self._a, self._b)) - # Compute NUFFT and unpack as two 1D stacks of the polar grid + # Compute NUFFT, then unpack as two 1D stacks of the polar grid # points, one for each image. f1, f2 = nufft(signal, fourier_pts, real=False).reshape( 2, self._a.shape[0], len(r), -1 ) - # Compute the Fourier correlations + # Compute the Fourier correlations. cov = np.sum(f1 * np.conj(f2), -1).real norm1 = np.sqrt(np.sum(np.abs(f1) ** 2, -1)) norm2 = np.sqrt(np.sum(np.abs(f2) ** 2, -1)) correlations = cov / (norm1 * norm2) - # Then unpack the a and b shapes. + # Then unpack the original a and b shapes. return correlations.reshape( *self._a_stack_shape, *self._b_stack_shape, r.shape[-1] ) @@ -260,6 +267,9 @@ def analyze_correlations(self): """ Convert from the Fourier Correlations to frequencies and resolution. """ + # `_analyzed` attribute in conjunction with `cutoff` allow a + # user to try different cutoffs without recomputing the + # correlations (FFT/NUFFT calls). if self._analyzed: return @@ -269,21 +279,16 @@ def analyze_correlations(self): # set index of highest sampled frequency. c_inds[np.min(self.correlations, axis=-1) > self.cutoff] = self.L // 2 - # # All correlations are below cutoff, - # # set index to 0 - # elif np.max(correlations) < cutoff: - # c_ind = 0 - # else: - # Correlations cross the cutoff. # Find the first index of a correlation at `cutoff`. + # Should return 0 if not found, which corresponded to the case + # where all correlations are below cutoff. c_ind = np.maximum(c_inds, np.argmax(self.correlations <= self.cutoff, axis=-1)) # Convert indices to frequency (as 1/Angstrom) frequencies = self._freq(c_ind) # Convert to resolution in Angstrom, smaller is higher frequency. - # TODO: handle 0 freq self._resolutions = 1 / frequencies def _freq(self, k): @@ -304,7 +309,7 @@ def _freq(self, k): # (pixels/voxels), each with a `pixel_size` in Angstrom, we can # compute the width of a Fourier space bin to be the `Bandwidth # / L = (2*(1/pixel_size)) / L`. Thus the frequency at an index - # `k` is `freq_k = k * 2 * (1 / pixel_size) / L = 2*k / + # `k` is `freq_k = k * 2 * (1 / pixel_size) / L = k * 2 / # (pixel_size * L) # _freq(k) Units: 1 / (pixels * (Angstrom / pixel) = 1 / Angstrom @@ -322,7 +327,6 @@ def plot(self, save_to_file=False): """ # Construct x-axis labels - # x_inds = np.arange(self.L // 2) x_inds = np.arange(self.correlations.shape[-1]) freqs = self._freq(x_inds) # TODO: handle zero freq better @@ -367,6 +371,11 @@ def plot(self, save_to_file=False): plt.show() +# The following are user facing classes, and simply wrap +# `_FourierCorrelation` after assigning dimension `dim` and any +# dimension specific variables. + + class FourierRingCorrelation(_FourierCorrelation): """ See `_FourierCorrelation`. diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 1b05fb0ad1..4203b6a268 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -494,7 +494,7 @@ def load(cls, filename, permissive=True, dtype=np.float32): logger.info(f"{filename} with dtype {loaded_data.dtype} loaded as {dtype}") return cls(loaded_data.astype(dtype)) - def fsc(self, other, pixel_size=None, cutoff=0.143, eps=1e-4, method="fft"): + def fsc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft"): r""" Compute the Fourier shell correlation between two volumes. @@ -515,15 +515,14 @@ def fsc(self, other, pixel_size=None, cutoff=0.143, eps=1e-4, method="fft"): :param method: Selects either 'fft' (on cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. - :return: tuple(estimated_resolution, FRC), + :return: tuple(estimated_resolution, FSC), where `estimated_resolution` is in Angstrom - and FRC is a Numpy array of correlations. + and FSC is a Numpy array of correlations. """ - # from aspire.reconstruction import FourierShellCorrelation if not isinstance(other, Volume): raise TypeError( - f"`other` image must be an `Volume` instance, received {type(other)}" + f"`other` volume must be an `Volume` instance, received {type(other)}" ) fsc = FourierShellCorrelation( diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_corrs.py index ff707c8240..ed64461f66 100644 --- a/tests/test_fourier_corrs.py +++ b/tests/test_fourier_corrs.py @@ -5,6 +5,7 @@ import numpy as np import pytest +from aspire.image import Image from aspire.noise import BlueNoiseAdder from aspire.numeric import fft from aspire.source import Simulation @@ -139,7 +140,7 @@ def test_frc_noise(image_fixture, method): def test_frc_plot(image_fixture, method): """ - Smoke test the plots. + Smoke test the plot. Also tests resetting the cutoff. """ @@ -152,6 +153,9 @@ def test_frc_plot(image_fixture, method): with matplotlib_no_gui(): frc.plot() + # Reset cutoff, then plot again + frc.cutoff = 0.143 + with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "frc_curve.png") frc.plot(save_to_file=file_path) @@ -181,9 +185,7 @@ def test_fsc_trunc(volume_fixture, method): def test_fsc_plot(volume_fixture, method): """ - Smoke test the plots. - - Also tests resetting the cutoff. + Smoke test the plot. """ vol_a, vol_b = volume_fixture @@ -194,9 +196,83 @@ def test_fsc_plot(volume_fixture, method): with matplotlib_no_gui(): fsc.plot() - # Reset cutoff - fsc.cutoff = 0.143 - with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "fsc_curve.png") fsc.plot(save_to_file=file_path) + + +# Check the error checks. + + +def test_dtype_mismatch(): + a = np.empty((8, 8), dtype=np.float32) + b = a.astype(np.float64) + + with pytest.raises(TypeError, match=r"Mismatched input types"): + _ = FourierRingCorrelation(a, b) + + +def test_type_mismatch(): + a = np.empty((8, 8), dtype=np.float32) + b = a.tolist() + + with pytest.raises(TypeError, match=r".*is not a Numpy array"): + _ = FourierRingCorrelation(a, b) + + +def test_data_shape_mismatch(): + a = np.empty((8, 8), dtype=np.float32) + b = np.empty((8, 9), dtype=np.float32) + + with pytest.raises(RuntimeError, match=r".*different data axis shapes"): + _ = FourierRingCorrelation(a, b) + + +def test_method_na(): + a = np.empty((8, 8), dtype=np.float32) + + with pytest.raises( + RuntimeError, match=r"Requested method.*not in available methods" + ): + _ = FourierRingCorrelation(a, a, method="man") + + +def test_cutoff_range(): + a = np.empty((8, 8), dtype=np.float32) + + with pytest.raises(ValueError, match=r"Supplied correlation `cutoff` not in"): + _ = FourierRingCorrelation(a, a, cutoff=2) + + +def test_2d_stack_plot_raise(): + a = np.random.random((2, 3, 8, 8)).astype(np.float32) + + with pytest.raises( + RuntimeError, match=r"Unable to plot figure tables with more than 2 dim" + ): + FourierRingCorrelation(a, a).plot() + + +def test_multiple_stack_plot_raise(): + a = np.random.random((3, 8, 8)).astype(np.float32) + + with pytest.raises( + RuntimeError, match=r"Unable to plot figure tables with more than 1 figure" + ): + FourierRingCorrelation(a, a).plot() + + +def test_img_type_mismatch(): + a = Image(np.empty((8, 8), dtype=np.float32)) + b = a.asnumpy() + + with pytest.raises(TypeError, match=r"`other` image must be an `Image` instance"): + _ = a.frc(b, pixel_size=1) + + +def test_vol_type_mismatch(): + a = Volume(np.empty((8, 8, 8), dtype=np.float32)) + b = a.asnumpy() + + with pytest.raises(TypeError, match=r"`other` volume must be an `Volume` instance"): + _ = a.fsc(b, pixel_size=1) From ca506e332ddf2aa50ddc116cc8bd8f65fb011979 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 28 Apr 2023 11:14:19 -0400 Subject: [PATCH 035/116] Rename fsc/frc tests to test_fourier_correlation --- tests/{test_fourier_corrs.py => test_fourier_correlation.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_fourier_corrs.py => test_fourier_correlation.py} (100%) diff --git a/tests/test_fourier_corrs.py b/tests/test_fourier_correlation.py similarity index 100% rename from tests/test_fourier_corrs.py rename to tests/test_fourier_correlation.py From 269c7bf5c44dea5b4afd9ec4d5cc24565a1da861 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 3 May 2023 13:04:56 -0400 Subject: [PATCH 036/116] spelling typo --- src/aspire/utils/resolution_estimation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 41a50498ec..581624f364 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -87,7 +87,7 @@ def __init__(self, a, b, pixel_size=1, cutoff=0.143, eps=1e-4, method="fft"): @property def _fourier_axes(self): """ - Returns tuple representing the axis containing signal data + Returns tuple representing the axes containing signal data based on dimension `dim`. """ return tuple(range(-self.dim, 0)) From 4f88c5e9d58f2102032391b45cb7ac4500c882db Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 3 May 2023 14:06:08 -0400 Subject: [PATCH 037/116] rm superfluous sim source from FC tests, rms extra div --- tests/test_fourier_correlation.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index ed64461f66..eceda4347d 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -59,19 +59,6 @@ def image_fixture(img_size, dtype): rots = Rotation.about_axis("z", thetas, dtype=dtype) # Contruct the Simulation source. - src = Simulation( - L=img_size, - n=2, - vols=v, - offsets=0, - amplitudes=1, - C=1, - angles=rots.angles, - dtype=dtype, - ) - - img, img_rot = src.images[:] - noisy_src = Simulation( L=img_size, n=2, @@ -80,9 +67,10 @@ def image_fixture(img_size, dtype): amplitudes=1, C=1, angles=rots.angles, - noise_adder=BlueNoiseAdder(var=np.var(img.asnumpy() * 0.5)), + noise_adder=BlueNoiseAdder.from_snr(2), dtype=dtype, ) + img, img_rot = noisy_src.clean_images[:] img_noisy = noisy_src.images[0] return img, img_rot, img_noisy @@ -128,14 +116,14 @@ def test_frc_rot(image_fixture, method): img_a, img_b, _ = image_fixture assert img_a.dtype == img_b.dtype frc_resolution, frc = img_a.frc(img_b, pixel_size=1, method=method) - assert np.isclose(frc_resolution[0][0], 3.78 / 2, rtol=0.1) + assert np.isclose(frc_resolution[0][0], 1.89, rtol=0.1) def test_frc_noise(image_fixture, method): img_a, _, img_n = image_fixture frc_resolution, frc = img_a.frc(img_n, pixel_size=1, method=method) - assert np.isclose(frc_resolution[0][0], 3.5 / 2, rtol=0.2) + assert np.isclose(frc_resolution[0][0], 1.75, rtol=0.2) def test_frc_plot(image_fixture, method): From 16712784f79d961fe7d68cb65b8f3e85b6c50425 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 3 May 2023 14:52:53 -0400 Subject: [PATCH 038/116] add plot passthrough to Image/Volume fc wrappers --- src/aspire/image/image.py | 10 ++++- src/aspire/utils/resolution_estimation.py | 4 +- src/aspire/volume/volume.py | 10 ++++- tests/test_fourier_correlation.py | 50 +++++++++++++++++++---- 4 files changed, 62 insertions(+), 12 deletions(-) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index f5c7d3838f..f01ea635f7 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -477,7 +477,7 @@ def show(self, columns=5, figsize=(20, 10), colorbar=True): plt.show() - def frc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft"): + def frc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft", plot=False): r""" Compute the Fourier ring correlation between two images. @@ -497,6 +497,9 @@ def frc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft"): :param eps: Epsilon past boundary values, defaults 1e-4. :param method: Selects either 'fft' (on cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. + :param plot: Optionally plot to screen or file. + Defaults to `False`. `True` plots to screen. + Passing a filepath as a string will attempt to save to file. :return: tuple(estimated_resolution, FRC), where `estimated_resolution` is in Angstrom @@ -517,6 +520,11 @@ def frc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft"): method=method, ) + if plot is True: + frc.plot() + elif plot: + frc.plot(save_to_file=plot) + return frc.estimated_resolution, frc.correlations diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 581624f364..49aabd8c8b 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -367,8 +367,8 @@ def plot(self, save_to_file=False): if save_to_file: plt.savefig(save_to_file) - - plt.show() + else: + plt.show() # The following are user facing classes, and simply wrap diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 4203b6a268..367431685b 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -494,7 +494,7 @@ def load(cls, filename, permissive=True, dtype=np.float32): logger.info(f"{filename} with dtype {loaded_data.dtype} loaded as {dtype}") return cls(loaded_data.astype(dtype)) - def fsc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft"): + def fsc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft", plot=False): r""" Compute the Fourier shell correlation between two volumes. @@ -514,6 +514,9 @@ def fsc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft"): :param eps: Epsilon past boundary values, defaults 1e-4. :param method: Selects either 'fft' (on cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. + :param plot: Optionally plot to screen or file. + Defaults to `False`. `True` plots to screen. + Passing a filepath as a string will attempt to save to file. :return: tuple(estimated_resolution, FSC), where `estimated_resolution` is in Angstrom @@ -534,6 +537,11 @@ def fsc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft"): method=method, ) + if plot is True: + fsc.plot() + elif plot: + fsc.plot(save_to_file=plot) + return fsc.estimated_resolution, fsc.correlations diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index eceda4347d..610ef3e2c2 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -126,6 +126,23 @@ def test_frc_noise(image_fixture, method): assert np.isclose(frc_resolution[0][0], 1.75, rtol=0.2) +def test_frc_img_plot(image_fixture): + """ + Smoke test Image.frc(plot=) passthrough. + """ + img_a, _, img_n = image_fixture + + # Plot to screen + with matplotlib_no_gui(): + _ = img_a.frc(img_n, pixel_size=1, plot=True) + + # Plot to file + with tempfile.TemporaryDirectory() as tmp_input_dir: + file_path = os.path.join(tmp_input_dir, "img_frc_curve.png") + img_a.frc(img_n, pixel_size=1, plot=file_path) + assert os.path.exists(file_path) + + def test_frc_plot(image_fixture, method): """ Smoke test the plot. @@ -141,12 +158,12 @@ def test_frc_plot(image_fixture, method): with matplotlib_no_gui(): frc.plot() - # Reset cutoff, then plot again - frc.cutoff = 0.143 + # Reset cutoff, then plot again + frc.cutoff = 0.143 - with tempfile.TemporaryDirectory() as tmp_input_dir: - file_path = os.path.join(tmp_input_dir, "frc_curve.png") - frc.plot(save_to_file=file_path) + with tempfile.TemporaryDirectory() as tmp_input_dir: + file_path = os.path.join(tmp_input_dir, "frc_curve.png") + frc.plot(save_to_file=file_path) # FSC @@ -171,6 +188,23 @@ def test_fsc_trunc(volume_fixture, method): assert fsc_resolution[0][0] > 2.0 +def test_fsc_vol_plot(volume_fixture): + """ + Smoke test Image.frc(plot=) passthrough. + """ + vol_a, vol_b = volume_fixture + + # Plot to screen + with matplotlib_no_gui(): + _ = vol_a.fsc(vol_b, pixel_size=1, plot=True) + + # Plot to file + with tempfile.TemporaryDirectory() as tmp_input_dir: + file_path = os.path.join(tmp_input_dir, "img_fsc_curve.png") + vol_a.fsc(vol_b, pixel_size=1, plot=file_path) + assert os.path.exists(file_path) + + def test_fsc_plot(volume_fixture, method): """ Smoke test the plot. @@ -184,9 +218,9 @@ def test_fsc_plot(volume_fixture, method): with matplotlib_no_gui(): fsc.plot() - with tempfile.TemporaryDirectory() as tmp_input_dir: - file_path = os.path.join(tmp_input_dir, "fsc_curve.png") - fsc.plot(save_to_file=file_path) + with tempfile.TemporaryDirectory() as tmp_input_dir: + file_path = os.path.join(tmp_input_dir, "fsc_curve.png") + fsc.plot(save_to_file=file_path) # Check the error checks. From 74baf7df61a0e6de8301c3970612dbba50c99327 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 3 May 2023 15:26:23 -0400 Subject: [PATCH 039/116] _FourierCorrelation ~~> FourierCorrelation to popluate docs --- src/aspire/utils/resolution_estimation.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 49aabd8c8b..95d9e8f3bb 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -12,11 +12,11 @@ logger = logging.getLogger(__name__) -# _FourierCorrelation holds a single implementation for both FSC and +# FourierCorrelation holds a single implementation for both FSC and # FRC based on dimension `dim`. -class _FourierCorrelation: +class FourierCorrelation: r""" Compute the Fourier correlations between two arrays. @@ -372,22 +372,22 @@ def plot(self, save_to_file=False): # The following are user facing classes, and simply wrap -# `_FourierCorrelation` after assigning dimension `dim` and any +# `FourierCorrelation` after assigning dimension `dim` and any # dimension specific variables. -class FourierRingCorrelation(_FourierCorrelation): +class FourierRingCorrelation(FourierCorrelation): """ - See `_FourierCorrelation`. + See `FourierCorrelation`. """ dim = 2 _plot_title = "Fourier Ring Correlation" -class FourierShellCorrelation(_FourierCorrelation): +class FourierShellCorrelation(FourierCorrelation): """ - See `_FourierCorrelation`. + See `FourierCorrelation`. """ dim = 3 From 942f8f377a9461218033a4f41231a93446c26afc Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 May 2023 11:07:30 -0400 Subject: [PATCH 040/116] rename filter value arg to power --- src/aspire/noise/noise.py | 4 ++-- src/aspire/operators/filters.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/aspire/noise/noise.py b/src/aspire/noise/noise.py index b52a4f4a48..6741945e53 100644 --- a/src/aspire/noise/noise.py +++ b/src/aspire/noise/noise.py @@ -201,7 +201,7 @@ def _build(self): # Call the __init__ from parent of WhiteNoiseAdder. super(WhiteNoiseAdder, self).__init__( - noise_filter=BlueFilter(value=self.noise_var), seed=self.seed + noise_filter=BlueFilter(power=self.noise_var), seed=self.seed ) @@ -217,7 +217,7 @@ def _build(self): # Call the __init__ from parent of WhiteNoiseAdder. super(WhiteNoiseAdder, self).__init__( - noise_filter=PinkFilter(value=self.noise_var), seed=self.seed + noise_filter=PinkFilter(power=self.noise_var), seed=self.seed ) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index d8db65d7f0..91a45e287d 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -465,19 +465,19 @@ class BlueFilter(Filter): Filter where power increases with frequency. """ - def __init__(self, dim=None, value=1): + def __init__(self, dim=None, power=1): super().__init__(dim=dim, radial=True) - self.value = value + self.power = power def __repr__(self): - return f"BlueFilter(dim={self.dim}, value={self.value})" + return f"BlueFilter(dim={self.dim}, power={self.power})" def _evaluate(self, omega): f = np.sqrt(omega[0]) m = np.mean(f) f = f / m - return self.value * f + return self.power * f class PinkFilter(Filter): @@ -485,12 +485,12 @@ class PinkFilter(Filter): Filter where power decreases with frequency. """ - def __init__(self, dim=None, value=1): + def __init__(self, dim=None, power=1): super().__init__(dim=dim, radial=True) - self.value = value + self.power = power def __repr__(self): - return f"PinkFilter(dim={self.dim}, value={self.value})" + return f"PinkFilter(dim={self.dim}, power={self.power})" def _evaluate(self, omega): step = np.abs(np.subtract(*omega[0][:2])) @@ -499,4 +499,4 @@ def _evaluate(self, omega): m = np.mean(f) f = f / m - return self.value * f + return self.power * f From f6eb880a308df8aba1d97ca5723b3927ab344a35 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 May 2023 11:21:21 -0400 Subject: [PATCH 041/116] angstrom and shape matching --- src/aspire/utils/resolution_estimation.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 95d9e8f3bb..26f6644a72 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -35,7 +35,7 @@ def __init__(self, a, b, pixel_size=1, cutoff=0.143, eps=1e-4, method="fft"): """ :param a: Input array a, shape(..., *dim). :param b: Input array b, shape(..., *dim). - :param pixel_size: Pixel size in Angstrom. + :param pixel_size: Pixel size in angstrom. Defaults to 1. :param cutoff: Cutoff value, traditionally `.143`. :param eps: Epsilon past boundary values, defaults 1e-4. @@ -57,9 +57,9 @@ def __init__(self, a, b, pixel_size=1, cutoff=0.143, eps=1e-4, method="fft"): # TODO, check-math/avoid complex inputs. # Shape checks - if not a.shape[-1] == b.shape[-1]: + if not a.shape[-self.dim:] == b.shape[-self.dim:]: raise RuntimeError( - f"`a` and `b` appear to have different data axis shapes, {a.shape[-1]} {b.shape[-1]}" + f"`a` and `b` appear to have different data axis shapes, {a.shape} {b.shape}" ) # Method selection @@ -285,10 +285,10 @@ def analyze_correlations(self): # where all correlations are below cutoff. c_ind = np.maximum(c_inds, np.argmax(self.correlations <= self.cutoff, axis=-1)) - # Convert indices to frequency (as 1/Angstrom) + # Convert indices to frequency (as 1/angstrom) frequencies = self._freq(c_ind) - # Convert to resolution in Angstrom, smaller is higher frequency. + # Convert to resolution in angstrom, smaller is higher frequency. self._resolutions = 1 / frequencies def _freq(self, k): @@ -297,7 +297,7 @@ def _freq(self, k): length 1/A). :param k: Frequency index, integer or Numpy array of ints. - :return: Frequency in 1/Angstrom. + :return: Frequency in 1/angstrom. """ # From Shannon-Nyquist, for a given pixel-size, sampling theorem @@ -306,13 +306,13 @@ def _freq(self, k): # total bandwidth is `2*(1/pixel_size)`. # Given a real space signal observed with `L` bins - # (pixels/voxels), each with a `pixel_size` in Angstrom, we can + # (pixels/voxels), each with a `pixel_size` in angstrom, we can # compute the width of a Fourier space bin to be the `Bandwidth # / L = (2*(1/pixel_size)) / L`. Thus the frequency at an index # `k` is `freq_k = k * 2 * (1 / pixel_size) / L = k * 2 / # (pixel_size * L) - # _freq(k) Units: 1 / (pixels * (Angstrom / pixel) = 1 / Angstrom + # _freq(k) Units: 1 / (pixels * (angstrom / pixel) = 1 / angstrom # Similar idea to wavenumbers (cm-1). Larger is higher frequency. return k * 2 / (self.L * self.pixel_size) @@ -346,7 +346,7 @@ def plot(self, save_to_file=False): plt.figure(figsize=(8, 6)) plt.title(self._plot_title) - plt.xlabel("Resolution (Angstrom)") + plt.xlabel("Resolution (angstrom)") plt.ylabel("Correlation") plt.ylim([0, 1.1]) plt.plot(freqs_angstrom, self.correlations[0][0]) From bb2dbc368898e7c41133881a408c55ff978765bd Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 May 2023 11:21:46 -0400 Subject: [PATCH 042/116] 3d grid function --- src/aspire/utils/resolution_estimation.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 26f6644a72..564aadb0eb 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -8,7 +8,7 @@ from aspire.nufft import nufft from aspire.numeric import fft -from aspire.utils import grid_2d +from aspire.utils import grid_2d, grid_3d logger = logging.getLogger(__name__) @@ -150,7 +150,12 @@ def _fft_correlations(self): """ # Compute shells from 2D grid. - radii = grid_2d(self.L, shifted=True, normalized=False, dtype=self.dtype)["r"] + if self.dim == 2: + grid_function = grid_2d + elif self.dim == 3: + grid_function = grid_3d + + radii = grid_function(self.L, shifted=True, normalized=False, dtype=self.dtype)["r"] # Compute centered Fourier transforms. f1 = fft.centered_fftn(self._a, axes=self._fourier_axes) From 2136d4deddc817ccb66071ae37008bda06d95c6f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 May 2023 12:21:23 -0400 Subject: [PATCH 043/116] rename filter power arg to var --- src/aspire/noise/noise.py | 4 ++-- src/aspire/operators/filters.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/aspire/noise/noise.py b/src/aspire/noise/noise.py index 6741945e53..154e77f624 100644 --- a/src/aspire/noise/noise.py +++ b/src/aspire/noise/noise.py @@ -201,7 +201,7 @@ def _build(self): # Call the __init__ from parent of WhiteNoiseAdder. super(WhiteNoiseAdder, self).__init__( - noise_filter=BlueFilter(power=self.noise_var), seed=self.seed + noise_filter=BlueFilter(var=self.noise_var), seed=self.seed ) @@ -217,7 +217,7 @@ def _build(self): # Call the __init__ from parent of WhiteNoiseAdder. super(WhiteNoiseAdder, self).__init__( - noise_filter=PinkFilter(power=self.noise_var), seed=self.seed + noise_filter=PinkFilter(var=self.noise_var), seed=self.seed ) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 91a45e287d..004a7a12bd 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -465,19 +465,19 @@ class BlueFilter(Filter): Filter where power increases with frequency. """ - def __init__(self, dim=None, power=1): + def __init__(self, dim=None, var=1): super().__init__(dim=dim, radial=True) - self.power = power + self.var = var def __repr__(self): - return f"BlueFilter(dim={self.dim}, power={self.power})" + return f"BlueFilter(dim={self.dim}, var={self.var})" def _evaluate(self, omega): f = np.sqrt(omega[0]) m = np.mean(f) f = f / m - return self.power * f + return self.var * f class PinkFilter(Filter): @@ -485,12 +485,12 @@ class PinkFilter(Filter): Filter where power decreases with frequency. """ - def __init__(self, dim=None, power=1): + def __init__(self, dim=None, var=1): super().__init__(dim=dim, radial=True) - self.power = power + self.var = var def __repr__(self): - return f"PinkFilter(dim={self.dim}, power={self.power})" + return f"PinkFilter(dim={self.dim}, var={self.var})" def _evaluate(self, omega): step = np.abs(np.subtract(*omega[0][:2])) @@ -499,4 +499,4 @@ def _evaluate(self, omega): m = np.mean(f) f = f / m - return self.power * f + return self.var * f From be0a53576b9991674392227a1fc007946fb36262 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 May 2023 13:00:00 -0400 Subject: [PATCH 044/116] linter --- src/aspire/utils/resolution_estimation.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 564aadb0eb..3f7b27f22b 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -57,7 +57,7 @@ def __init__(self, a, b, pixel_size=1, cutoff=0.143, eps=1e-4, method="fft"): # TODO, check-math/avoid complex inputs. # Shape checks - if not a.shape[-self.dim:] == b.shape[-self.dim:]: + if not a.shape[-self.dim :] == b.shape[-self.dim :]: raise RuntimeError( f"`a` and `b` appear to have different data axis shapes, {a.shape} {b.shape}" ) @@ -153,9 +153,11 @@ def _fft_correlations(self): if self.dim == 2: grid_function = grid_2d elif self.dim == 3: - grid_function = grid_3d + grid_function = grid_3d - radii = grid_function(self.L, shifted=True, normalized=False, dtype=self.dtype)["r"] + radii = grid_function(self.L, shifted=True, normalized=False, dtype=self.dtype)[ + "r" + ] # Compute centered Fourier transforms. f1 = fft.centered_fftn(self._a, axes=self._fourier_axes) From 5baaa26bceac470f210d1f4bc43cec9b9be92852 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 May 2023 13:15:25 -0400 Subject: [PATCH 045/116] removed epsilon --- src/aspire/image/image.py | 4 +--- src/aspire/utils/resolution_estimation.py | 8 +++----- src/aspire/volume/volume.py | 4 +--- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index f01ea635f7..f921cc36c4 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -477,7 +477,7 @@ def show(self, columns=5, figsize=(20, 10), colorbar=True): plt.show() - def frc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft", plot=False): + def frc(self, other, pixel_size, cutoff=0.143, method="fft", plot=False): r""" Compute the Fourier ring correlation between two images. @@ -494,7 +494,6 @@ def frc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft", plot=Fals :param pixel_size: Pixel size in Angstrom. For synthetic data, 1 is a reasonable value. :param cutoff: Cutoff value, traditionally `1.43`. - :param eps: Epsilon past boundary values, defaults 1e-4. :param method: Selects either 'fft' (on cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. :param plot: Optionally plot to screen or file. @@ -516,7 +515,6 @@ def frc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft", plot=Fals b=other.asnumpy(), pixel_size=pixel_size, cutoff=cutoff, - eps=eps, method=method, ) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 3f7b27f22b..921c37a6d2 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -31,14 +31,13 @@ class FourierCorrelation: . """ - def __init__(self, a, b, pixel_size=1, cutoff=0.143, eps=1e-4, method="fft"): + def __init__(self, a, b, pixel_size=1, cutoff=0.143, method="fft"): """ :param a: Input array a, shape(..., *dim). :param b: Input array b, shape(..., *dim). :param pixel_size: Pixel size in angstrom. Defaults to 1. :param cutoff: Cutoff value, traditionally `.143`. - :param eps: Epsilon past boundary values, defaults 1e-4. :param method: Selects either 'fft' (on cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. """ @@ -79,7 +78,6 @@ def __init__(self, a, b, pixel_size=1, cutoff=0.143, eps=1e-4, method="fft"): self._analyzed = False self.cutoff = cutoff self.pixel_size = float(pixel_size) - self.eps = float(eps) self._correlations = None self.L = self._a.shape[-1] self.dtype = self._a.dtype @@ -168,10 +166,10 @@ def _fft_correlations(self): (self.L // 2, self._a.shape[0], self._b.shape[0]), dtype=self.dtype ) - inner_diameter = 0.5 + self.eps + inner_diameter = 0.5 for i in range(0, self.L // 2): # Compute ring mask - outer_diameter = 0.5 + (i + 1) + self.eps + outer_diameter = 0.5 + (i + 1) ring_mask = (radii > inner_diameter) & (radii < outer_diameter) logger.debug(f"Shell, Elements: {i}, {np.sum(ring_mask)}") diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 367431685b..a72413c97d 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -494,7 +494,7 @@ def load(cls, filename, permissive=True, dtype=np.float32): logger.info(f"{filename} with dtype {loaded_data.dtype} loaded as {dtype}") return cls(loaded_data.astype(dtype)) - def fsc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft", plot=False): + def fsc(self, other, pixel_size, cutoff=0.143, method="fft", plot=False): r""" Compute the Fourier shell correlation between two volumes. @@ -511,7 +511,6 @@ def fsc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft", plot=Fals :param pixel_size: Pixel size in Angstrom. For synthetic data, 1 is a reasonable value. :param cutoff: Cutoff value, traditionally `.143`. - :param eps: Epsilon past boundary values, defaults 1e-4. :param method: Selects either 'fft' (on cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. :param plot: Optionally plot to screen or file. @@ -533,7 +532,6 @@ def fsc(self, other, pixel_size, cutoff=0.143, eps=1e-4, method="fft", plot=Fals b=other.asnumpy(), pixel_size=pixel_size, cutoff=cutoff, - eps=eps, method=method, ) From e8a467e9c2f978c6852ac80f02895eeed97a3506 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 May 2023 13:28:14 -0400 Subject: [PATCH 046/116] Cartesian, correlation, caps --- src/aspire/utils/resolution_estimation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 921c37a6d2..d676ad3ee7 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -38,7 +38,7 @@ def __init__(self, a, b, pixel_size=1, cutoff=0.143, method="fft"): :param pixel_size: Pixel size in angstrom. Defaults to 1. :param cutoff: Cutoff value, traditionally `.143`. - :param method: Selects either 'fft' (on cartesian grid), + :param method: Selects either 'fft' (on Cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. """ @@ -144,7 +144,7 @@ def correlations(self): def _fft_correlations(self): """ - Computes Fourier Correlations using the FFT on a cartesian grid. + Computes Fourier correlations using the FFT on a Cartesian grid. """ # Compute shells from 2D grid. @@ -177,7 +177,7 @@ def _fft_correlations(self): r1 = ring_mask * f1 r2 = ring_mask * f2 - # Compute Fourier Correlations + # Compute Fourier correlations num = np.real(np.sum(r1 * np.conj(r2), axis=self._fourier_axes)) den = np.sqrt( np.sum(np.abs(r1) ** 2, axis=self._fourier_axes) @@ -197,11 +197,11 @@ def _fft_correlations(self): def _nufft_correlations(self): """ - Computes Fourier Correlations using the NUFFT on a polar grid. + Computes Fourier correlations using the NUFFT on a polar grid. """ # TODO, we could use an internal tool (Polar2D?) for this. - # L//2 is intentionally used for compatibility with cartesian grid. + # L//2 is intentionally used for compatibility with Cartesian grid. # This avoids having to have multiple methods for computing resolutions later. r = np.linspace(0, np.pi, self.L // 2, endpoint=False, dtype=self.dtype) phi = np.linspace(0, 2 * np.pi, 2 * self.L, endpoint=False, dtype=self.dtype) @@ -270,7 +270,7 @@ def estimated_resolution(self): def analyze_correlations(self): """ - Convert from the Fourier Correlations to frequencies and resolution. + Convert from the Fourier correlations to frequencies and resolution. """ # `_analyzed` attribute in conjunction with `cutoff` allow a # user to try different cutoffs without recomputing the From 5dbbf8cf2e1411695575dee1d1aa7b42b2e53ef8 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 May 2023 15:44:50 -0400 Subject: [PATCH 047/116] Migrate cutoff to method arguments --- src/aspire/image/image.py | 11 ++-- src/aspire/utils/resolution_estimation.py | 71 +++++++---------------- src/aspire/volume/volume.py | 9 ++- tests/test_fourier_correlation.py | 45 +++++++------- 4 files changed, 51 insertions(+), 85 deletions(-) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index f921cc36c4..96cc7ad9df 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -477,7 +477,7 @@ def show(self, columns=5, figsize=(20, 10), colorbar=True): plt.show() - def frc(self, other, pixel_size, cutoff=0.143, method="fft", plot=False): + def frc(self, other, pixel_size, cutoff, method="fft", plot=False): r""" Compute the Fourier ring correlation between two images. @@ -493,7 +493,7 @@ def frc(self, other, pixel_size, cutoff=0.143, method="fft", plot=False): :param other: `Image` instance to compare. :param pixel_size: Pixel size in Angstrom. For synthetic data, 1 is a reasonable value. - :param cutoff: Cutoff value, traditionally `1.43`. + :param cutoff: Cutoff value, traditionally `.143`. :param method: Selects either 'fft' (on cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. :param plot: Optionally plot to screen or file. @@ -514,16 +514,15 @@ def frc(self, other, pixel_size, cutoff=0.143, method="fft", plot=False): a=self.asnumpy(), b=other.asnumpy(), pixel_size=pixel_size, - cutoff=cutoff, method=method, ) if plot is True: - frc.plot() + frc.plot(cutoff=cutoff) elif plot: - frc.plot(save_to_file=plot) + frc.plot(cutoff=cutoff, save_to_file=plot) - return frc.estimated_resolution, frc.correlations + return frc.analyze_correlations(cutoff), frc.correlations class CartesianImage(Image): diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index d676ad3ee7..bcb589d7b3 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -31,13 +31,12 @@ class FourierCorrelation: . """ - def __init__(self, a, b, pixel_size=1, cutoff=0.143, method="fft"): + def __init__(self, a, b, pixel_size=1, method="fft"): """ :param a: Input array a, shape(..., *dim). :param b: Input array b, shape(..., *dim). :param pixel_size: Pixel size in angstrom. Defaults to 1. - :param cutoff: Cutoff value, traditionally `.143`. :param method: Selects either 'fft' (on Cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. """ @@ -75,8 +74,6 @@ def __init__(self, a, b, pixel_size=1, cutoff=0.143, method="fft"): self._a, self._a_stack_shape = self._reshape(a) self._b, self._b_stack_shape = self._reshape(b) - self._analyzed = False - self.cutoff = cutoff self.pixel_size = float(pixel_size) self._correlations = None self.L = self._a.shape[-1] @@ -104,28 +101,6 @@ def _reshape(self, x): x = x.reshape(-1, *x.shape[-self.dim :]) return x, original_stack_shape - @property - def cutoff(self): - """ - Returns `cutoff` value. - """ - return self._cutoff - - @cutoff.setter - def cutoff(self, cutoff): - """ - Sets `cutoff` value, and resets analysis, which is dependent - on `cutoff` values. - - :param cutoff: Float - """ - self._cutoff = float(cutoff) - if not (0 <= self._cutoff <= 1): - raise ValueError( - "Supplied correlation `cutoff` not in [0,1], {self._cutoff}" - ) - self._analyzed = False # reset analysis - @property def correlations(self): """ @@ -258,37 +233,27 @@ def _nufft_correlations(self): *self._a_stack_shape, *self._b_stack_shape, r.shape[-1] ) - @property - def estimated_resolution(self): - """ - Return estimated resolution of stacks `a` cross `b`. - - :return: Numpy array. - """ - self.analyze_correlations() - return self._resolutions - - def analyze_correlations(self): + def analyze_correlations(self, cutoff): """ Convert from the Fourier correlations to frequencies and resolution. + :param cutoff: Cutoff value, traditionally `.143`. """ - # `_analyzed` attribute in conjunction with `cutoff` allow a - # user to try different cutoffs without recomputing the - # correlations (FFT/NUFFT calls). - if self._analyzed: - return + + cutoff = float(cutoff) + if not (0 <= cutoff <= 1): + raise ValueError("Supplied correlation `cutoff` not in [0,1], {cutoff}") c_inds = np.zeros(self.correlations.shape[:-1], dtype=int) # All correlations are above cutoff, # set index of highest sampled frequency. - c_inds[np.min(self.correlations, axis=-1) > self.cutoff] = self.L // 2 + c_inds[np.min(self.correlations, axis=-1) > cutoff] = self.L // 2 # Correlations cross the cutoff. # Find the first index of a correlation at `cutoff`. # Should return 0 if not found, which corresponded to the case # where all correlations are below cutoff. - c_ind = np.maximum(c_inds, np.argmax(self.correlations <= self.cutoff, axis=-1)) + c_ind = np.maximum(c_inds, np.argmax(self.correlations <= cutoff, axis=-1)) # Convert indices to frequency (as 1/angstrom) frequencies = self._freq(c_ind) @@ -296,6 +261,8 @@ def analyze_correlations(self): # Convert to resolution in angstrom, smaller is higher frequency. self._resolutions = 1 / frequencies + return self._resolutions + def _freq(self, k): """ Converts `k` from index of Fourier transform to frequency (as @@ -321,15 +288,19 @@ def _freq(self, k): # Similar idea to wavenumbers (cm-1). Larger is higher frequency. return k * 2 / (self.L * self.pixel_size) - def plot(self, save_to_file=False): + def plot(self, cutoff, save_to_file=False): """ Generates a Fourier correlation plot. + :param cutoff: Cutoff value, traditionally `.143`. :param save_to_file: Optionally, save plot to file. Defaults False, enabled by providing a string filename. User is responsible for providing reasonable filename. See `https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html`. """ + cutoff = float(cutoff) + if not (0 <= cutoff <= 1): + raise ValueError("Supplied correlation `cutoff` not in [0,1], {cutoff}") # Construct x-axis labels x_inds = np.arange(self.correlations.shape[-1]) @@ -356,15 +327,15 @@ def plot(self, save_to_file=False): plt.ylim([0, 1.1]) plt.plot(freqs_angstrom, self.correlations[0][0]) # Display cutoff - plt.axhline( - y=self.cutoff, color="r", linestyle="--", label=f"cutoff={self.cutoff}" - ) + plt.axhline(y=cutoff, color="r", linestyle="--", label=f"cutoff={cutoff}") + estimated_resolution = self.analyze_correlations(cutoff)[0][0] + # Display resolution plt.axvline( - x=self.estimated_resolution[0][0], + x=estimated_resolution, color="b", linestyle=":", - label=f"Resolution={self.estimated_resolution[0][0]:.3f}", + label=f"Resolution={estimated_resolution:.3f}", ) # x-axis in decreasing plt.gca().invert_xaxis() diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index a72413c97d..ae1a8564b5 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -494,7 +494,7 @@ def load(cls, filename, permissive=True, dtype=np.float32): logger.info(f"{filename} with dtype {loaded_data.dtype} loaded as {dtype}") return cls(loaded_data.astype(dtype)) - def fsc(self, other, pixel_size, cutoff=0.143, method="fft", plot=False): + def fsc(self, other, pixel_size, cutoff, method="fft", plot=False): r""" Compute the Fourier shell correlation between two volumes. @@ -531,16 +531,15 @@ def fsc(self, other, pixel_size, cutoff=0.143, method="fft", plot=False): a=self.asnumpy(), b=other.asnumpy(), pixel_size=pixel_size, - cutoff=cutoff, method=method, ) if plot is True: - fsc.plot() + fsc.plot(cutoff=cutoff) elif plot: - fsc.plot(save_to_file=plot) + fsc.plot(cutoff=cutoff, save_to_file=plot) - return fsc.estimated_resolution, fsc.correlations + return fsc.analyze_correlations(cutoff), fsc.correlations class CartesianVolume(Volume): diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index 610ef3e2c2..06da9de77f 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -107,7 +107,7 @@ def volume_fixture(img_size, dtype): def test_frc_id(image_fixture, method): img, _, _ = image_fixture - frc_resolution, frc = img.frc(img, pixel_size=1, method=method) + frc_resolution, frc = img.frc(img, pixel_size=1, cutoff=0.143, method=method) assert np.isclose(frc_resolution[0][0], 1, rtol=0.02) assert np.allclose(frc, 1) @@ -115,14 +115,14 @@ def test_frc_id(image_fixture, method): def test_frc_rot(image_fixture, method): img_a, img_b, _ = image_fixture assert img_a.dtype == img_b.dtype - frc_resolution, frc = img_a.frc(img_b, pixel_size=1, method=method) + frc_resolution, frc = img_a.frc(img_b, pixel_size=1, cutoff=0.143, method=method) assert np.isclose(frc_resolution[0][0], 1.89, rtol=0.1) def test_frc_noise(image_fixture, method): img_a, _, img_n = image_fixture - frc_resolution, frc = img_a.frc(img_n, pixel_size=1, method=method) + frc_resolution, frc = img_a.frc(img_n, pixel_size=1, cutoff=0.143, method=method) assert np.isclose(frc_resolution[0][0], 1.75, rtol=0.2) @@ -134,12 +134,12 @@ def test_frc_img_plot(image_fixture): # Plot to screen with matplotlib_no_gui(): - _ = img_a.frc(img_n, pixel_size=1, plot=True) + _ = img_a.frc(img_n, pixel_size=1, cutoff=0.143, plot=True) # Plot to file with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "img_frc_curve.png") - img_a.frc(img_n, pixel_size=1, plot=file_path) + img_a.frc(img_n, pixel_size=1, cutoff=0.143, plot=file_path) assert os.path.exists(file_path) @@ -152,18 +152,15 @@ def test_frc_plot(image_fixture, method): img_a, img_b, _ = image_fixture frc = FourierRingCorrelation( - img_a.asnumpy(), img_b.asnumpy(), pixel_size=1, method=method, cutoff=0.5 + img_a.asnumpy(), img_b.asnumpy(), pixel_size=1, method=method ) with matplotlib_no_gui(): - frc.plot() - - # Reset cutoff, then plot again - frc.cutoff = 0.143 + frc.plot(cutoff=0.5) with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "frc_curve.png") - frc.plot(save_to_file=file_path) + frc.plot(cutoff=0.143, save_to_file=file_path) # FSC @@ -172,7 +169,7 @@ def test_frc_plot(image_fixture, method): def test_fsc_id(volume_fixture, method): vol, _ = volume_fixture - fsc_resolution, fsc = vol.fsc(vol, pixel_size=1, method=method) + fsc_resolution, fsc = vol.fsc(vol, pixel_size=1, cutoff=0.143, method=method) assert np.isclose(fsc_resolution[0][0], 1, rtol=0.02) assert np.allclose(fsc, 1) @@ -180,11 +177,11 @@ def test_fsc_id(volume_fixture, method): def test_fsc_trunc(volume_fixture, method): vol_a, vol_b = volume_fixture - fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, method=method) + fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.143, method=method) assert fsc_resolution[0][0] > 1.5 # The follow should correspond to the test_fsc_plot below. - fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, method=method, cutoff=0.5) + fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.5, method=method) assert fsc_resolution[0][0] > 2.0 @@ -196,12 +193,12 @@ def test_fsc_vol_plot(volume_fixture): # Plot to screen with matplotlib_no_gui(): - _ = vol_a.fsc(vol_b, pixel_size=1, plot=True) + _ = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.5, plot=True) # Plot to file with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "img_fsc_curve.png") - vol_a.fsc(vol_b, pixel_size=1, plot=file_path) + vol_a.fsc(vol_b, pixel_size=1, cutoff=0.143, plot=file_path) assert os.path.exists(file_path) @@ -212,15 +209,15 @@ def test_fsc_plot(volume_fixture, method): vol_a, vol_b = volume_fixture fsc = FourierShellCorrelation( - vol_a.asnumpy(), vol_b.asnumpy(), pixel_size=1, method=method, cutoff=0.5 + vol_a.asnumpy(), vol_b.asnumpy(), pixel_size=1, method=method ) with matplotlib_no_gui(): - fsc.plot() + fsc.plot(cutoff=0.5) with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "fsc_curve.png") - fsc.plot(save_to_file=file_path) + fsc.plot(cutoff=0.143, save_to_file=file_path) # Check the error checks. @@ -263,7 +260,7 @@ def test_cutoff_range(): a = np.empty((8, 8), dtype=np.float32) with pytest.raises(ValueError, match=r"Supplied correlation `cutoff` not in"): - _ = FourierRingCorrelation(a, a, cutoff=2) + _ = FourierRingCorrelation(a, a).analyze_correlations(cutoff=2) def test_2d_stack_plot_raise(): @@ -272,7 +269,7 @@ def test_2d_stack_plot_raise(): with pytest.raises( RuntimeError, match=r"Unable to plot figure tables with more than 2 dim" ): - FourierRingCorrelation(a, a).plot() + FourierRingCorrelation(a, a).plot(cutoff=0.143) def test_multiple_stack_plot_raise(): @@ -281,7 +278,7 @@ def test_multiple_stack_plot_raise(): with pytest.raises( RuntimeError, match=r"Unable to plot figure tables with more than 1 figure" ): - FourierRingCorrelation(a, a).plot() + FourierRingCorrelation(a, a).plot(cutoff=0.143) def test_img_type_mismatch(): @@ -289,7 +286,7 @@ def test_img_type_mismatch(): b = a.asnumpy() with pytest.raises(TypeError, match=r"`other` image must be an `Image` instance"): - _ = a.frc(b, pixel_size=1) + _ = a.frc(b, pixel_size=1, cutoff=0.143) def test_vol_type_mismatch(): @@ -297,4 +294,4 @@ def test_vol_type_mismatch(): b = a.asnumpy() with pytest.raises(TypeError, match=r"`other` volume must be an `Volume` instance"): - _ = a.fsc(b, pixel_size=1) + _ = a.fsc(b, pixel_size=1, cutoff=0.143) From 9d8de04a4b089202044cb1c7d4910a8eca54d45b Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 5 May 2023 13:18:52 -0400 Subject: [PATCH 048/116] Begin adding broadcasting support --- src/aspire/utils/resolution_estimation.py | 36 +++++++++++++---------- tests/test_fourier_correlation.py | 19 ++++++------ 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index bcb589d7b3..e96e6b02bc 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -28,7 +28,12 @@ class FourierCorrelation: c(i) = \frac{ \operatorname{Re}( \sum_i{ \mathcal{F}_1(i) * {\mathcal{F}^{*}_2(i) } } ) }{\ \sqrt{ \sum_i { | \mathcal{F}_1(i) |^2 } * \sum_i{| \mathcal{F}^{*}_2}(i) |^2 } } -. + + This implementation supports Numpy style broadcasting resulting in + up to two stack dimensions. For example, to compute all pairs + supply signal with stack shapes (m,1) and (1,n) to yield an (m,n) + table of results. Note that plotting is limited to a single + reference signal. """ def __init__(self, a, b, pixel_size=1, method="fft"): @@ -71,8 +76,12 @@ def __init__(self, a, b, pixel_size=1, method="fft"): # To support arbitrary broadcasting simply, # we'll force all shapes to be (-1, *(L,)*dim) + # and keep track of the stack shapes. self._a, self._a_stack_shape = self._reshape(a) self._b, self._b_stack_shape = self._reshape(b) + self._result_stack_shape = np.broadcast_shapes( + self._a_stack_shape, self._b_stack_shape + ) self.pixel_size = float(pixel_size) self._correlations = None @@ -138,7 +147,7 @@ def _fft_correlations(self): # Construct an output table of correlations correlations = np.zeros( - (self.L // 2, self._a.shape[0], self._b.shape[0]), dtype=self.dtype + (self.L // 2, *self._result_stack_shape), dtype=self.dtype ) inner_diameter = 0.5 @@ -163,12 +172,9 @@ def _fft_correlations(self): # Update ring inner_diameter = outer_diameter - # Repack the table as (_a, _b, L//2) - correlations = np.swapaxes(correlations, 0, 2) - # Then unpack the original a and b shapes. - return correlations.reshape( - *self._a_stack_shape, *self._b_stack_shape, self.L // 2 - ) + # Repack the table as (..., L//2) + correlations = np.swapaxes(correlations, 0, -1) + return correlations def _nufft_correlations(self): """ @@ -215,11 +221,11 @@ def _nufft_correlations(self): # Stack signal data to create a larger NUFFT problem (better performance). # Note, we want a complex result. signal = np.vstack((self._a, self._b)) - # Compute NUFFT, then unpack as two 1D stacks of the polar grid - # points, one for each image. - f1, f2 = nufft(signal, fourier_pts, real=False).reshape( - 2, self._a.shape[0], len(r), -1 - ) + # Compute one large NUFFT for all the signal frames, + f = nufft(signal, fourier_pts, real=False) + # then unpack as two 1D stacks of the polar grid points, one for _a and _b. + f = f.reshape(self._a.shape[0] + self._b.shape[0], len(r), -1) + f1, f2 = np.split(f, self._a.shape[0]) # Compute the Fourier correlations. cov = np.sum(f1 * np.conj(f2), -1).real @@ -325,10 +331,10 @@ def plot(self, cutoff, save_to_file=False): plt.xlabel("Resolution (angstrom)") plt.ylabel("Correlation") plt.ylim([0, 1.1]) - plt.plot(freqs_angstrom, self.correlations[0][0]) + plt.plot(freqs_angstrom, self.correlations[0]) # Display cutoff plt.axhline(y=cutoff, color="r", linestyle="--", label=f"cutoff={cutoff}") - estimated_resolution = self.analyze_correlations(cutoff)[0][0] + estimated_resolution = self.analyze_correlations(cutoff)[0] # Display resolution plt.axvline( diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index 06da9de77f..af0c9e2f0a 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -20,13 +20,14 @@ IMG_SIZES = [ 64, - 65, + # 65, ] DTYPES = [ np.float64, - np.float32, + # np.float32, ] -METHOD = ["fft", "nufft"] +# METHOD = ["fft", "nufft"] +METHOD = ["fft"] @pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") @@ -108,7 +109,7 @@ def test_frc_id(image_fixture, method): img, _, _ = image_fixture frc_resolution, frc = img.frc(img, pixel_size=1, cutoff=0.143, method=method) - assert np.isclose(frc_resolution[0][0], 1, rtol=0.02) + assert np.isclose(frc_resolution[0], 1, rtol=0.02) assert np.allclose(frc, 1) @@ -116,14 +117,14 @@ def test_frc_rot(image_fixture, method): img_a, img_b, _ = image_fixture assert img_a.dtype == img_b.dtype frc_resolution, frc = img_a.frc(img_b, pixel_size=1, cutoff=0.143, method=method) - assert np.isclose(frc_resolution[0][0], 1.89, rtol=0.1) + assert np.isclose(frc_resolution[0], 1.89, rtol=0.1) def test_frc_noise(image_fixture, method): img_a, _, img_n = image_fixture frc_resolution, frc = img_a.frc(img_n, pixel_size=1, cutoff=0.143, method=method) - assert np.isclose(frc_resolution[0][0], 1.75, rtol=0.2) + assert np.isclose(frc_resolution[0], 1.75, rtol=0.2) def test_frc_img_plot(image_fixture): @@ -170,7 +171,7 @@ def test_fsc_id(volume_fixture, method): vol, _ = volume_fixture fsc_resolution, fsc = vol.fsc(vol, pixel_size=1, cutoff=0.143, method=method) - assert np.isclose(fsc_resolution[0][0], 1, rtol=0.02) + assert np.isclose(fsc_resolution[0], 1, rtol=0.02) assert np.allclose(fsc, 1) @@ -178,11 +179,11 @@ def test_fsc_trunc(volume_fixture, method): vol_a, vol_b = volume_fixture fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.143, method=method) - assert fsc_resolution[0][0] > 1.5 + assert fsc_resolution[0] > 1.5 # The follow should correspond to the test_fsc_plot below. fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.5, method=method) - assert fsc_resolution[0][0] > 2.0 + assert fsc_resolution[0] > 2.0 def test_fsc_vol_plot(volume_fixture): From 1a1324bcb9f341326b39b123c4a9d5b5c1c18000 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 5 May 2023 15:06:48 -0400 Subject: [PATCH 049/116] Add few broadcasting tests and extend plotting. --- src/aspire/utils/resolution_estimation.py | 7 ++- tests/test_fourier_correlation.py | 58 +++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index e96e6b02bc..6db50b2ca2 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -331,7 +331,12 @@ def plot(self, cutoff, save_to_file=False): plt.xlabel("Resolution (angstrom)") plt.ylabel("Correlation") plt.ylim([0, 1.1]) - plt.plot(freqs_angstrom, self.correlations[0]) + for i, line in enumerate(self.correlations): + _label = None + if len(self.correlations) > 1: + _label = f"{i}" + plt.plot(freqs_angstrom, line, label=_label) + # Display cutoff plt.axhline(y=cutoff, color="r", linestyle="--", label=f"cutoff={cutoff}") estimated_resolution = self.analyze_correlations(cutoff)[0] diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index af0c9e2f0a..65be0799b9 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -296,3 +296,61 @@ def test_vol_type_mismatch(): with pytest.raises(TypeError, match=r"`other` volume must be an `Volume` instance"): _ = a.fsc(b, pixel_size=1, cutoff=0.143) + + +# Broadcasting + + +def test_frc_id_bcast(image_fixture, method): + img, _, _ = image_fixture + + k = 3 + img_b = Image(np.tile(img, (3, 1, 1))) + + frc_resolution, frc = img.frc(img_b, pixel_size=1, cutoff=0.143, method=method) + assert np.allclose( + frc_resolution, + [ + 1.0, + ] + * k, + rtol=0.02, + ) + assert np.allclose(frc, 1.0) + + +def test_fsc_id_bcast(volume_fixture, method): + vol, _ = volume_fixture + + k = 3 + vol_b = Volume(np.tile(vol.asnumpy(), (3, 1, 1, 1))) + + fsc_resolution, fsc = vol.fsc(vol_b, pixel_size=1, cutoff=0.143, method=method) + assert np.allclose( + fsc_resolution, + [ + 1.0, + ] + * k, + rtol=0.02, + ) + assert np.allclose(fsc, 1.0) + + +def test_frc_img_plot_bcast(image_fixture): + """ + Smoke test Image.frc(plot=) passthrough. + """ + img_a, img_b, img_n = image_fixture + + img_b = Image(np.vstack((img_a, img_b, img_n))) + + # Plot to screen + with matplotlib_no_gui(): + _ = img_a.frc(img_b, pixel_size=1, cutoff=0.143, plot=True) + + # Plot to file + with tempfile.TemporaryDirectory() as tmp_input_dir: + file_path = os.path.join(tmp_input_dir, "img_frc_curve.png") + img_a.frc(img_b, pixel_size=1, cutoff=0.143, plot=file_path) + assert os.path.exists(file_path) From 049bebbce399e59b7cb7471af75909fa328d05a5 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 5 May 2023 15:40:05 -0400 Subject: [PATCH 050/116] Raise for the unhandled all pairs plotting --- src/aspire/utils/resolution_estimation.py | 13 ++++++++--- tests/test_fourier_correlation.py | 28 ++++++++++++++++++++--- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 6db50b2ca2..c3701804e7 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -77,6 +77,8 @@ def __init__(self, a, b, pixel_size=1, method="fft"): # To support arbitrary broadcasting simply, # we'll force all shapes to be (-1, *(L,)*dim) # and keep track of the stack shapes. + self.a = a + self.b = b self._a, self._a_stack_shape = self._reshape(a) self._b, self._b_stack_shape = self._reshape(b) self._result_stack_shape = np.broadcast_shapes( @@ -142,8 +144,8 @@ def _fft_correlations(self): ] # Compute centered Fourier transforms. - f1 = fft.centered_fftn(self._a, axes=self._fourier_axes) - f2 = fft.centered_fftn(self._b, axes=self._fourier_axes) + f1 = fft.centered_fftn(self.a, axes=self._fourier_axes) + f2 = fft.centered_fftn(self.b, axes=self._fourier_axes) # Construct an output table of correlations correlations = np.zeros( @@ -323,7 +325,12 @@ def plot(self, cutoff, save_to_file=False): ) if np.prod(stack) > 1: raise RuntimeError( - f"Unable to plot figure tables with more than 1 figures, stack shape {stack}. Try reducing to a simpler request." + f"Unable to plot figure tables with more than 1 reference figures, stack shape {stack}. Try reducing to a simpler request." + ) + breakpoint() + if self._a_stack_shape[0] > 1 and self._a_stack_shape != self._b_stack_shape: + raise RuntimeError( + f"Unable to plot figure tables with more than 1 reference figures, stack shape {stack}. Try reducing to a simpler request." ) plt.figure(figsize=(8, 6)) diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index 65be0799b9..57942cb027 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -345,12 +345,34 @@ def test_frc_img_plot_bcast(image_fixture): img_b = Image(np.vstack((img_a, img_b, img_n))) - # Plot to screen + # Plot to screen, one:many with matplotlib_no_gui(): _ = img_a.frc(img_b, pixel_size=1, cutoff=0.143, plot=True) - # Plot to file + # Plot to file, many elementwise with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "img_frc_curve.png") - img_a.frc(img_b, pixel_size=1, cutoff=0.143, plot=file_path) + img_b.frc(img_b, pixel_size=1, cutoff=0.143, plot=file_path) assert os.path.exists(file_path) + + +def test_plot_bad_bcast(image_fixture): + """ + When reference is a stack, we should raise when attempting to plot + anything other than 1:1 elementwise. + """ + img_a, img_b, img_n = image_fixture + img_b = np.vstack((img_a, img_b, img_n)) + + # many:many, all pairs for (3,) x (2,1) + with pytest.raises(RuntimeError, match="Unable to plot figure tables"): + FourierRingCorrelation(img_b, img_b[:2].reshape(2, 1, *img_b.shape[-2:])).plot( + cutoff=0.143 + ) + + # many:one + with pytest.raises(RuntimeError, match="Unable to plot figure tables"): + FourierRingCorrelation( + img_b, + img_a.asnumpy(), + ).plot(cutoff=0.143) From 033d7a27d17451ef1441ecb7df81b61acb7e949b Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 8 May 2023 08:20:24 -0400 Subject: [PATCH 051/116] rm breakpoint --- src/aspire/utils/resolution_estimation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index c3701804e7..0ec06b3738 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -327,7 +327,7 @@ def plot(self, cutoff, save_to_file=False): raise RuntimeError( f"Unable to plot figure tables with more than 1 reference figures, stack shape {stack}. Try reducing to a simpler request." ) - breakpoint() + if self._a_stack_shape[0] > 1 and self._a_stack_shape != self._b_stack_shape: raise RuntimeError( f"Unable to plot figure tables with more than 1 reference figures, stack shape {stack}. Try reducing to a simpler request." From 8cb6d43586cc65eaf5e0f6973a35dc293df6fcef Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 8 May 2023 11:51:21 -0400 Subject: [PATCH 052/116] More broadcasting tests and plot labels --- src/aspire/utils/resolution_estimation.py | 45 ++++++++------- tests/test_fourier_correlation.py | 67 ++++++++++++++++++++++- 2 files changed, 90 insertions(+), 22 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 0ec06b3738..b0efaa0ff1 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -38,13 +38,13 @@ class FourierCorrelation: def __init__(self, a, b, pixel_size=1, method="fft"): """ - :param a: Input array a, shape(..., *dim). - :param b: Input array b, shape(..., *dim). - :param pixel_size: Pixel size in angstrom. - Defaults to 1. - :param method: Selects either 'fft' (on Cartesian grid), - or 'nufft' (on polar grid). Defaults to 'fft'. - """ + :param a: Input array a, shape(..., *dim). + :param b: Input array b, shape(..., *dim). + :param pixel_size: Pixel size in angstrom. + Defaults to 1. + :param method: Selects either 'fft' (on Cartesian grid), + or 'nufft' (on polar grid). Defaults to 'fft'. + 7""" # Sanity checks if not hasattr(self, "dim"): @@ -176,7 +176,7 @@ def _fft_correlations(self): # Repack the table as (..., L//2) correlations = np.swapaxes(correlations, 0, -1) - return correlations + return correlations.reshape(*self._result_stack_shape, self.L // 2) def _nufft_correlations(self): """ @@ -227,7 +227,7 @@ def _nufft_correlations(self): f = nufft(signal, fourier_pts, real=False) # then unpack as two 1D stacks of the polar grid points, one for _a and _b. f = f.reshape(self._a.shape[0] + self._b.shape[0], len(r), -1) - f1, f2 = np.split(f, self._a.shape[0]) + f1, f2 = np.vsplit(f, [self._a.shape[0]]) # Compute the Fourier correlations. cov = np.sum(f1 * np.conj(f2), -1).real @@ -236,10 +236,8 @@ def _nufft_correlations(self): correlations = cov / (norm1 * norm2) - # Then unpack the original a and b shapes. - return correlations.reshape( - *self._a_stack_shape, *self._b_stack_shape, r.shape[-1] - ) + # Then unpack as original a and b broadcasted shapes. + return correlations.reshape(*self._result_stack_shape, r.shape[-1]) def analyze_correlations(self, cutoff): """ @@ -296,7 +294,7 @@ def _freq(self, k): # Similar idea to wavenumbers (cm-1). Larger is higher frequency. return k * 2 / (self.L * self.pixel_size) - def plot(self, cutoff, save_to_file=False): + def plot(self, cutoff, save_to_file=False, labels=None): """ Generates a Fourier correlation plot. @@ -318,21 +316,26 @@ def plot(self, cutoff, save_to_file=False): freqs_angstrom = 1 / freqs # Check we're asking for a reasonable plot. - stack = self.correlations.shape[: -self.dim] + stack = self.correlations.shape[:-1] if len(stack) > 2: raise RuntimeError( f"Unable to plot figure tables with more than 2 dim, stack shape {stack}. Try reducing to a simpler request." ) - if np.prod(stack) > 1: - raise RuntimeError( - f"Unable to plot figure tables with more than 1 reference figures, stack shape {stack}. Try reducing to a simpler request." - ) - if self._a_stack_shape[0] > 1 and self._a_stack_shape != self._b_stack_shape: + if ( + self._a_stack_shape[0] > 1 and self._a_stack_shape != self._b_stack_shape + ) or (len(stack) == 2 and 1 not in stack): raise RuntimeError( f"Unable to plot figure tables with more than 1 reference figures, stack shape {stack}. Try reducing to a simpler request." ) + # Check `labels` length when provided. + if labels is not None: + if len(labels) != len(self.correlations): + raise ValueError( + f"Check `labels`. Provided len(labels) != len(self.correlations): {len(labels)} != {len(self.correlations)}." + ) + plt.figure(figsize=(8, 6)) plt.title(self._plot_title) plt.xlabel("Resolution (angstrom)") @@ -342,6 +345,8 @@ def plot(self, cutoff, save_to_file=False): _label = None if len(self.correlations) > 1: _label = f"{i}" + if labels is not None: + _label = labels[i] plt.plot(freqs_angstrom, line, label=_label) # Display cutoff diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index 57942cb027..6195cc0259 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -267,6 +267,15 @@ def test_cutoff_range(): def test_2d_stack_plot_raise(): a = np.random.random((2, 3, 8, 8)).astype(np.float32) + with pytest.raises( + RuntimeError, match=r"Unable to plot figure tables with more than 1 reference" + ): + FourierRingCorrelation(a, a).plot(cutoff=0.143) + + +def test_3d_stack_plot_raise(): + a = np.random.random((2, 3, 4, 8, 8)).astype(np.float32) + with pytest.raises( RuntimeError, match=r"Unable to plot figure tables with more than 2 dim" ): @@ -275,11 +284,12 @@ def test_2d_stack_plot_raise(): def test_multiple_stack_plot_raise(): a = np.random.random((3, 8, 8)).astype(np.float32) + b = np.reshape(a, (3, 1, 8, 8)) with pytest.raises( - RuntimeError, match=r"Unable to plot figure tables with more than 1 figure" + RuntimeError, match=r"Unable to plot figure tables with more than 1 reference" ): - FourierRingCorrelation(a, a).plot(cutoff=0.143) + FourierRingCorrelation(a, b).plot(cutoff=0.143) def test_img_type_mismatch(): @@ -302,6 +312,9 @@ def test_vol_type_mismatch(): def test_frc_id_bcast(image_fixture, method): + """ + Test FRC for (1) x (3), (1) x (1,3) , (1) x (3,1). + """ img, _, _ = image_fixture k = 3 @@ -317,6 +330,37 @@ def test_frc_id_bcast(image_fixture, method): rtol=0.02, ) assert np.allclose(frc, 1.0) + assert frc_resolution.shape == (3,) + + # (1) x (1,3) + img_b = img_b.stack_reshape(1, 3) + + frc_resolution, frc = img.frc(img_b, pixel_size=1, cutoff=0.143, method=method) + assert np.allclose( + frc_resolution, + [ + 1.0, + ] + * k, + rtol=0.02, + ) + assert np.allclose(frc, 1.0) + assert frc_resolution.shape == (1, 3) + + # (1) x (3,1) + img_b = img_b.stack_reshape(3, 1) + + frc_resolution, frc = img.frc(img_b, pixel_size=1, cutoff=0.143, method=method) + assert np.allclose( + frc_resolution, + [ + 1.0, + ] + * k, + rtol=0.02, + ) + assert np.allclose(frc, 1.0) + assert frc_resolution.shape == (3, 1) def test_fsc_id_bcast(volume_fixture, method): @@ -376,3 +420,22 @@ def test_plot_bad_bcast(image_fixture): img_b, img_a.asnumpy(), ).plot(cutoff=0.143) + + +def test_plot_labels(image_fixture): + """ + When reference is a stack, we should raise when attempting to plot + anything other than 1:1 elementwise. + """ + img_a, img_b, img_n = image_fixture + img_b = np.vstack((img_a, img_b, img_n)) + + frc = FourierRingCorrelation(img_a.asnumpy(), img_b) + with matplotlib_no_gui(): + frc.plot(cutoff=0.143, labels=["abc", "easyas", "123"]) + + with pytest.raises(ValueError, match="Check `labels`"): + frc.plot(cutoff=0.143, labels=["abc", "easyas", "123", "toomany"]) + + with pytest.raises(ValueError, match="Check `labels`"): + frc.plot(cutoff=0.143, labels=["toofew"]) From 7239fb2c7f134f7c5c11041eba9a04744dfed32c Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 8 May 2023 13:55:54 -0400 Subject: [PATCH 053/116] Adjust factor of two for using only positive freqs --- src/aspire/utils/resolution_estimation.py | 16 ++-------------- tests/test_fourier_correlation.py | 20 ++++++++++---------- 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index b0efaa0ff1..7ee3d78ba1 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -278,21 +278,9 @@ def _freq(self, k): :return: Frequency in 1/angstrom. """ - # From Shannon-Nyquist, for a given pixel-size, sampling theorem - # limits us to the sampled frequency 1/pixel_size. Thus the - # Bandwidth ranges from `[-1/pixel_size, 1/pixel_size]`, so the - # total bandwidth is `2*(1/pixel_size)`. - - # Given a real space signal observed with `L` bins - # (pixels/voxels), each with a `pixel_size` in angstrom, we can - # compute the width of a Fourier space bin to be the `Bandwidth - # / L = (2*(1/pixel_size)) / L`. Thus the frequency at an index - # `k` is `freq_k = k * 2 * (1 / pixel_size) / L = k * 2 / - # (pixel_size * L) - # _freq(k) Units: 1 / (pixels * (angstrom / pixel) = 1 / angstrom - # Similar idea to wavenumbers (cm-1). Larger is higher frequency. - return k * 2 / (self.L * self.pixel_size) + # Similar to wavenumbers. Larger is higher frequency. + return k / (self.L * self.pixel_size) def plot(self, cutoff, save_to_file=False, labels=None): """ diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index 6195cc0259..8a02b23bf0 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -109,7 +109,7 @@ def test_frc_id(image_fixture, method): img, _, _ = image_fixture frc_resolution, frc = img.frc(img, pixel_size=1, cutoff=0.143, method=method) - assert np.isclose(frc_resolution[0], 1, rtol=0.02) + assert np.isclose(frc_resolution[0], 2, rtol=0.02) assert np.allclose(frc, 1) @@ -117,14 +117,14 @@ def test_frc_rot(image_fixture, method): img_a, img_b, _ = image_fixture assert img_a.dtype == img_b.dtype frc_resolution, frc = img_a.frc(img_b, pixel_size=1, cutoff=0.143, method=method) - assert np.isclose(frc_resolution[0], 1.89, rtol=0.1) + assert np.isclose(frc_resolution[0], 3.78, rtol=0.1) def test_frc_noise(image_fixture, method): img_a, _, img_n = image_fixture frc_resolution, frc = img_a.frc(img_n, pixel_size=1, cutoff=0.143, method=method) - assert np.isclose(frc_resolution[0], 1.75, rtol=0.2) + assert np.isclose(frc_resolution[0], 3.5, rtol=0.2) def test_frc_img_plot(image_fixture): @@ -171,7 +171,7 @@ def test_fsc_id(volume_fixture, method): vol, _ = volume_fixture fsc_resolution, fsc = vol.fsc(vol, pixel_size=1, cutoff=0.143, method=method) - assert np.isclose(fsc_resolution[0], 1, rtol=0.02) + assert np.isclose(fsc_resolution[0], 2, rtol=0.02) assert np.allclose(fsc, 1) @@ -179,11 +179,11 @@ def test_fsc_trunc(volume_fixture, method): vol_a, vol_b = volume_fixture fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.143, method=method) - assert fsc_resolution[0] > 1.5 + assert fsc_resolution[0] > 3. # The follow should correspond to the test_fsc_plot below. fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.5, method=method) - assert fsc_resolution[0] > 2.0 + assert fsc_resolution[0] > 4.0 def test_fsc_vol_plot(volume_fixture): @@ -324,7 +324,7 @@ def test_frc_id_bcast(image_fixture, method): assert np.allclose( frc_resolution, [ - 1.0, + 2.0, ] * k, rtol=0.02, @@ -339,7 +339,7 @@ def test_frc_id_bcast(image_fixture, method): assert np.allclose( frc_resolution, [ - 1.0, + 2.0, ] * k, rtol=0.02, @@ -354,7 +354,7 @@ def test_frc_id_bcast(image_fixture, method): assert np.allclose( frc_resolution, [ - 1.0, + 2.0, ] * k, rtol=0.02, @@ -373,7 +373,7 @@ def test_fsc_id_bcast(volume_fixture, method): assert np.allclose( fsc_resolution, [ - 1.0, + 2.0, ] * k, rtol=0.02, From 18e91cce00ced06aacb93742fa9f763a0facda0a Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 9 May 2023 08:54:42 -0400 Subject: [PATCH 054/116] Add pixel mode default of None to FC methods --- src/aspire/image/image.py | 6 +++--- src/aspire/utils/resolution_estimation.py | 18 ++++++++++++------ src/aspire/volume/volume.py | 6 +++--- tests/test_fourier_correlation.py | 2 +- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 96cc7ad9df..0ccc8807fc 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -477,7 +477,7 @@ def show(self, columns=5, figsize=(20, 10), colorbar=True): plt.show() - def frc(self, other, pixel_size, cutoff, method="fft", plot=False): + def frc(self, other, cutoff, pixel_size=None, method="fft", plot=False): r""" Compute the Fourier ring correlation between two images. @@ -491,9 +491,9 @@ def frc(self, other, pixel_size, cutoff, method="fft", plot=False): \sqrt{ \sum_i { | \mathcal{F}_1(i) |^2 } * \sum_i{| \mathcal{F}^{*}_2}(i) |^2 } } :param other: `Image` instance to compare. - :param pixel_size: Pixel size in Angstrom. - For synthetic data, 1 is a reasonable value. :param cutoff: Cutoff value, traditionally `.143`. + :param pixel_size: Pixel size in angstrom. Default `None` + implies unit in pixels, equivalent to pixel_size=1. :param method: Selects either 'fft' (on cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. :param plot: Optionally plot to screen or file. diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 7ee3d78ba1..255e6e5cbb 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -36,12 +36,12 @@ class FourierCorrelation: reference signal. """ - def __init__(self, a, b, pixel_size=1, method="fft"): + def __init__(self, a, b, pixel_size=None, method="fft"): """ :param a: Input array a, shape(..., *dim). :param b: Input array b, shape(..., *dim). :param pixel_size: Pixel size in angstrom. - Defaults to 1. + Default `None` implies "pixel" units. :param method: Selects either 'fft' (on Cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. 7""" @@ -85,7 +85,13 @@ def __init__(self, a, b, pixel_size=1, method="fft"): self._a_stack_shape, self._b_stack_shape ) + # Handle `pixel_size` and `pixel_mode` + self._pixel_units = "angstrom" + if pixel_size is None: + pixel_size = 1.0 + self._pixel_units = "pixels" self.pixel_size = float(pixel_size) + self._correlations = None self.L = self._a.shape[-1] self.dtype = self._a.dtype @@ -301,7 +307,7 @@ def plot(self, cutoff, save_to_file=False, labels=None): freqs = self._freq(x_inds) # TODO: handle zero freq better with np.errstate(divide="ignore"): - freqs_angstrom = 1 / freqs + freqs_units = 1 / freqs # Check we're asking for a reasonable plot. stack = self.correlations.shape[:-1] @@ -326,7 +332,7 @@ def plot(self, cutoff, save_to_file=False, labels=None): plt.figure(figsize=(8, 6)) plt.title(self._plot_title) - plt.xlabel("Resolution (angstrom)") + plt.xlabel(f"Resolution ({self._pixel_units})") plt.ylabel("Correlation") plt.ylim([0, 1.1]) for i, line in enumerate(self.correlations): @@ -335,7 +341,7 @@ def plot(self, cutoff, save_to_file=False, labels=None): _label = f"{i}" if labels is not None: _label = labels[i] - plt.plot(freqs_angstrom, line, label=_label) + plt.plot(freqs_units, line, label=_label) # Display cutoff plt.axhline(y=cutoff, color="r", linestyle="--", label=f"cutoff={cutoff}") @@ -348,7 +354,7 @@ def plot(self, cutoff, save_to_file=False, labels=None): linestyle=":", label=f"Resolution={estimated_resolution:.3f}", ) - # x-axis in decreasing + # x-axis decreasing plt.gca().invert_xaxis() plt.legend(title=f"Method: {self.method}") diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index ae1a8564b5..b04d559efa 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -494,7 +494,7 @@ def load(cls, filename, permissive=True, dtype=np.float32): logger.info(f"{filename} with dtype {loaded_data.dtype} loaded as {dtype}") return cls(loaded_data.astype(dtype)) - def fsc(self, other, pixel_size, cutoff, method="fft", plot=False): + def fsc(self, other, cutoff, pixel_size=None, method="fft", plot=False): r""" Compute the Fourier shell correlation between two volumes. @@ -508,9 +508,9 @@ def fsc(self, other, pixel_size, cutoff, method="fft", plot=False): \sqrt{ \sum_i { | \mathcal{F}_1(i) |^2 } * \sum_i{| \mathcal{F}^{*}_2}(i) |^2 } } :param other: `Volume` instance to compare. - :param pixel_size: Pixel size in Angstrom. - For synthetic data, 1 is a reasonable value. :param cutoff: Cutoff value, traditionally `.143`. + :param pixel_size: Pixel size in angstrom. Default `None` + implies unit in pixels, equivalent to pixel_size=1. :param method: Selects either 'fft' (on cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. :param plot: Optionally plot to screen or file. diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index 8a02b23bf0..63bdc2fff6 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -179,7 +179,7 @@ def test_fsc_trunc(volume_fixture, method): vol_a, vol_b = volume_fixture fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.143, method=method) - assert fsc_resolution[0] > 3. + assert fsc_resolution[0] > 3.0 # The follow should correspond to the test_fsc_plot below. fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.5, method=method) From 4ab888c4017f2b1b20bb22b5480b7374d526bf86 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 9 May 2023 08:55:59 -0400 Subject: [PATCH 055/116] restore all parameterized FC tests --- tests/test_fourier_correlation.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index 63bdc2fff6..0d8d8ac87e 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -20,14 +20,13 @@ IMG_SIZES = [ 64, - # 65, + 65, ] DTYPES = [ np.float64, - # np.float32, + np.float32, ] -# METHOD = ["fft", "nufft"] -METHOD = ["fft"] +METHOD = ["fft", "nufft"] @pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") From 35a928e3b6816036bda83bd51d4ef696f82d7cf5 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 9 May 2023 09:49:18 -0400 Subject: [PATCH 056/116] convert volume fixture to use grid_3d --- tests/test_fourier_correlation.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index 0d8d8ac87e..67326a9f14 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -9,7 +9,12 @@ from aspire.noise import BlueNoiseAdder from aspire.numeric import fft from aspire.source import Simulation -from aspire.utils import FourierRingCorrelation, FourierShellCorrelation, Rotation +from aspire.utils import ( + FourierRingCorrelation, + FourierShellCorrelation, + Rotation, + grid_3d, +) from aspire.volume import Volume from .test_utils import matplotlib_no_gui @@ -89,12 +94,10 @@ def volume_fixture(img_size, dtype): # Invert correlation for some high frequency content # Convert volume to Fourier space. vol_trunc_f = fft.centered_fftn(vol.asnumpy()[0]) - # Get a frequency index. - trunc_frq = img_size // 3 - # Negate the power for some frequencies higher than `trunc_frq`. - vol_trunc_f[-trunc_frq:, :, :] *= -1.0 - vol_trunc_f[:, -trunc_frq:, :] *= -1.0 - vol_trunc_f[:, :, -trunc_frq:] *= -1.0 + # Get high frequency indices + trunc_frq = grid_3d(img_size, normalized=True)["r"] > 1 / 2 + # Negate the power for high freq content + vol_trunc_f[trunc_frq] *= -1.0 # Convert volume from Fourier space to real space Volume. vol_trunc = Volume(fft.centered_ifftn(vol_trunc_f).real) @@ -182,7 +185,7 @@ def test_fsc_trunc(volume_fixture, method): # The follow should correspond to the test_fsc_plot below. fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.5, method=method) - assert fsc_resolution[0] > 4.0 + assert fsc_resolution[0] > 3.9 def test_fsc_vol_plot(volume_fixture): From 04a0c9dca96c52c76133c81b0dcd31ab8fd87c44 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 9 May 2023 11:05:12 -0400 Subject: [PATCH 057/116] convert image fixture to use spectral manipulation, grid_2d --- tests/test_fourier_correlation.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index 67326a9f14..f0d491a6a2 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -12,7 +12,7 @@ from aspire.utils import ( FourierRingCorrelation, FourierShellCorrelation, - Rotation, + grid_2d, grid_3d, ) from aspire.volume import Volume @@ -60,25 +60,31 @@ def image_fixture(img_size, dtype): ).downsample(img_size) # Instantiate ASPIRE's Rotation class with a set of angles. - thetas = [0, 0.123] - rots = Rotation.about_axis("z", thetas, dtype=dtype) - # Contruct the Simulation source. noisy_src = Simulation( L=img_size, - n=2, + n=1, vols=v, offsets=0, amplitudes=1, C=1, - angles=rots.angles, noise_adder=BlueNoiseAdder.from_snr(2), dtype=dtype, ) - img, img_rot = noisy_src.clean_images[:] + img = noisy_src.clean_images[0] img_noisy = noisy_src.images[0] - return img, img_rot, img_noisy + # Invert correlation for some high frequency content + # Convert image to Fourier space. + img_trunc_f = fft.centered_fftn(img.asnumpy()[0]) + # Get high frequency indices + trunc_frq = grid_2d(img_size, normalized=True)["r"] > 1 / 2 + # Negate the power for high freq content + img_trunc_f[trunc_frq] *= -1.0 + # Convert imgume from Fourier space to real space Imgume. + img_trunc = Image(fft.centered_ifftn(img_trunc_f).real) + + return img, img_trunc, img_noisy @pytest.fixture @@ -115,18 +121,18 @@ def test_frc_id(image_fixture, method): assert np.allclose(frc, 1) -def test_frc_rot(image_fixture, method): +def test_frc_trunc(image_fixture, method): img_a, img_b, _ = image_fixture assert img_a.dtype == img_b.dtype frc_resolution, frc = img_a.frc(img_b, pixel_size=1, cutoff=0.143, method=method) - assert np.isclose(frc_resolution[0], 3.78, rtol=0.1) + assert frc_resolution[0] > 3.0 def test_frc_noise(image_fixture, method): img_a, _, img_n = image_fixture frc_resolution, frc = img_a.frc(img_n, pixel_size=1, cutoff=0.143, method=method) - assert np.isclose(frc_resolution[0], 3.5, rtol=0.2) + assert frc_resolution[0] > 3.5 def test_frc_img_plot(image_fixture): From 726e0ab01f5609f7e882bb494fb0692823750264 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 9 May 2023 13:21:46 -0400 Subject: [PATCH 058/116] Remove deprecated comment line. --- tests/test_fourier_correlation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index f0d491a6a2..24667e861e 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -59,8 +59,7 @@ def image_fixture(img_size, dtype): np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")), dtype=dtype ).downsample(img_size) - # Instantiate ASPIRE's Rotation class with a set of angles. - # Contruct the Simulation source. + # Contruct the Simulation source to generate a noisy image. noisy_src = Simulation( L=img_size, n=1, From 68bf94aca580eee7b562dd3a945a25dbd5efa802 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 11 Apr 2023 14:20:33 -0400 Subject: [PATCH 059/116] add expensive tests for (F)FBbases. --- tests/test_FBbasis2D.py | 30 ++++++++++++++++++++++++++++++ tests/test_FBbasis3D.py | 33 ++++++++++++++++++++++++++++++--- tests/test_FFBbasis2D.py | 29 +++++++++++++++++++++++++++++ tests/test_FFBbasis3D.py | 24 +++++++++++++++++++++++- 4 files changed, 112 insertions(+), 4 deletions(-) diff --git a/tests/test_FBbasis2D.py b/tests/test_FBbasis2D.py index 9b60c0e788..1012ca7fdd 100644 --- a/tests/test_FBbasis2D.py +++ b/tests/test_FBbasis2D.py @@ -7,6 +7,7 @@ from aspire.basis import FBBasis2D from aspire.image import Image +from aspire.source import Simulation from aspire.utils import complex_type, real_type from aspire.utils.coor_trans import grid_2d from aspire.utils.random import randn @@ -134,3 +135,32 @@ def testComplexCoversionErrorsToReal(self, basis): # Try a 0d vector, should not crash. _ = basis.to_real(cv1.reshape(-1)) + + +params = [pytest.param(256, np.float32, marks=pytest.mark.expensive)] + + +@pytest.mark.parametrize( + "L, dtype", + params, +) +def testHighResFBBasis2D(L, dtype): + seed = 42 + basis = FBBasis2D(L, dtype=dtype) + sim = Simulation( + n=1, + L=L, + dtype=dtype, + amplitudes=1, + offsets=0, + seed=seed, + ) + im = sim.images[0] + + # Round trip + coeff = basis.evaluate_t(im) + FB_im = basis.evaluate(coeff) + + # Mask to compare inside disk of radius 1. + mask = grid_2d(L, normalized=True)["r"] < 1 + assert np.allclose(FB_im.asnumpy()[0][mask], im.asnumpy()[0][mask], atol=3e-3) diff --git a/tests/test_FBbasis3D.py b/tests/test_FBbasis3D.py index 251adbc80d..3aff4ff5ab 100644 --- a/tests/test_FBbasis3D.py +++ b/tests/test_FBbasis3D.py @@ -4,8 +4,8 @@ import pytest from aspire.basis import FBBasis3D -from aspire.utils import utest_tolerance -from aspire.volume import Volume +from aspire.utils import grid_3d, utest_tolerance +from aspire.volume import AsymmetricVolume, Volume from ._basis_util import UniversalBasisMixin, basis_params_3d, show_basis_params @@ -18,7 +18,6 @@ class TestFBBasis3D(UniversalBasisMixin): def testFBBasis3DIndices(self, basis): indices = basis.indices() - assert np.allclose( indices["ells"], [ @@ -686,3 +685,31 @@ def testFBBasis3DExpand(self, basis): ], atol=utest_tolerance(basis.dtype), ) + + +# NOTE: This test is failing for L=64. `coeff_0` has a few NANs which propogate into `vol_1`. See GH issue #923 +params = [pytest.param(64, np.float32, marks=pytest.mark.expensive)] + + +@pytest.mark.parametrize( + "L, dtype", + params, +) +@pytest.mark.skip(reason="Failing for L=64 due to NaN values.") +def testHighResFBbasis3D(L, dtype): + seed = 42 + basis = FBBasis3D(L, dtype=dtype) + vol_0 = AsymmetricVolume(L=L, C=1, K=64, dtype=dtype, seed=seed).generate() + + # First round trip + coeff_0 = basis.evaluate_t(vol_0) + vol_1 = basis.evaluate(coeff_0) + + # Second round trip + coeff_1 = basis.evaluate_t(vol_1) + vol_2 = basis.evaluate(coeff_1) + + # Mask to compare inside sphere of radius 1. + mask = grid_3d(L, normalized=True)["r"] < 1 + + assert np.allclose(vol_2.asnumpy()[0][mask], vol_1.asnumpy()[0][mask], atol=0.007) diff --git a/tests/test_FFBbasis2D.py b/tests/test_FFBbasis2D.py index 05830176f2..80479b65a4 100644 --- a/tests/test_FFBbasis2D.py +++ b/tests/test_FFBbasis2D.py @@ -213,3 +213,32 @@ def testShift(self, basis): rmse = np.sqrt(np.mean(np.square(diff), axis=(1, 2))) logger.info(f"RMSE shifted image diffs {rmse}") assert np.allclose(rmse, 0, atol=1e-5) + + +params = [pytest.param(512, np.float32, marks=pytest.mark.expensive)] + + +@pytest.mark.parametrize( + "L, dtype", + params, +) +def testHighResFFBBasis2D(L, dtype): + seed = 42 + basis = FFBBasis2D(L, dtype=dtype) + sim = Simulation( + n=1, + L=L, + dtype=dtype, + amplitudes=1, + offsets=0, + seed=seed, + ) + im = sim.images[0] + + # Round trip + coeff = basis.evaluate_t(im) + FB_im = basis.evaluate(coeff) + + # Mask to compare inside disk of radius 1. + mask = grid_2d(L, normalized=True)["r"] < 1 + assert np.allclose(FB_im.asnumpy()[0][mask], im.asnumpy()[0][mask], atol=1e-4) diff --git a/tests/test_FFBbasis3D.py b/tests/test_FFBbasis3D.py index 5b3ce0d602..9b78758831 100644 --- a/tests/test_FFBbasis3D.py +++ b/tests/test_FFBbasis3D.py @@ -4,7 +4,8 @@ import pytest from aspire.basis import FFBBasis3D -from aspire.volume import Volume +from aspire.utils import grid_3d +from aspire.volume import AsymmetricVolume, Volume from ._basis_util import UniversalBasisMixin, basis_params_3d, show_basis_params @@ -486,3 +487,24 @@ def testFFBBasis3DExpand(self, basis): ] assert np.allclose(result, ref, atol=1e-2) + + +params = [pytest.param(256, np.float32, marks=pytest.mark.expensive)] + + +@pytest.mark.parametrize( + "L, dtype", + params, +) +def testHighResFFBbasis3D(L, dtype): + seed = 42 + basis = FFBBasis3D(L, dtype=dtype) + vol_0 = AsymmetricVolume(L=L, C=1, K=64, dtype=dtype, seed=seed).generate() + + # Round trip + coeff_0 = basis.evaluate_t(vol_0) + vol_1 = basis.evaluate(coeff_0) + + # Mask to compare inside sphere of radius 1. + mask = grid_3d(L, normalized=True)["r"] < 1 + assert np.allclose(vol_1.asnumpy()[0][mask], vol_0.asnumpy()[0][mask], atol=1e-3) From 4cc3503a27f990322705d83cd99d51d941d6f71b Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 4 May 2023 09:30:48 -0400 Subject: [PATCH 060/116] Use FBBasis2D.expand. Fix variable names. --- tests/test_FBbasis2D.py | 6 +++--- tests/test_FFBbasis2D.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_FBbasis2D.py b/tests/test_FBbasis2D.py index 1012ca7fdd..c5e9c555ce 100644 --- a/tests/test_FBbasis2D.py +++ b/tests/test_FBbasis2D.py @@ -158,9 +158,9 @@ def testHighResFBBasis2D(L, dtype): im = sim.images[0] # Round trip - coeff = basis.evaluate_t(im) - FB_im = basis.evaluate(coeff) + coeff = basis.expand(im) + im_fb = basis.evaluate(coeff) # Mask to compare inside disk of radius 1. mask = grid_2d(L, normalized=True)["r"] < 1 - assert np.allclose(FB_im.asnumpy()[0][mask], im.asnumpy()[0][mask], atol=3e-3) + assert np.allclose(im_fb.asnumpy()[0][mask], im.asnumpy()[0][mask], atol=2e-5) diff --git a/tests/test_FFBbasis2D.py b/tests/test_FFBbasis2D.py index 80479b65a4..1796bbf0e7 100644 --- a/tests/test_FFBbasis2D.py +++ b/tests/test_FFBbasis2D.py @@ -237,8 +237,8 @@ def testHighResFFBBasis2D(L, dtype): # Round trip coeff = basis.evaluate_t(im) - FB_im = basis.evaluate(coeff) + im_ffb = basis.evaluate(coeff) # Mask to compare inside disk of radius 1. mask = grid_2d(L, normalized=True)["r"] < 1 - assert np.allclose(FB_im.asnumpy()[0][mask], im.asnumpy()[0][mask], atol=1e-4) + assert np.allclose(im_ffb.asnumpy()[0][mask], im.asnumpy()[0][mask], atol=1e-4) From f214933262c123e236b0f9045e65b0bd37069b22 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 4 May 2023 15:34:35 -0400 Subject: [PATCH 061/116] use expand in FB3D test. Only one round trip. --- tests/test_FBbasis3D.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/test_FBbasis3D.py b/tests/test_FBbasis3D.py index 3aff4ff5ab..02ba88109a 100644 --- a/tests/test_FBbasis3D.py +++ b/tests/test_FBbasis3D.py @@ -699,17 +699,13 @@ def testFBBasis3DExpand(self, basis): def testHighResFBbasis3D(L, dtype): seed = 42 basis = FBBasis3D(L, dtype=dtype) - vol_0 = AsymmetricVolume(L=L, C=1, K=64, dtype=dtype, seed=seed).generate() + vol = AsymmetricVolume(L=L, C=1, K=64, dtype=dtype, seed=seed).generate() - # First round trip - coeff_0 = basis.evaluate_t(vol_0) - vol_1 = basis.evaluate(coeff_0) - - # Second round trip - coeff_1 = basis.evaluate_t(vol_1) - vol_2 = basis.evaluate(coeff_1) + # Round trip + coeff = basis.expand(vol) + vol_fb = basis.evaluate(coeff) # Mask to compare inside sphere of radius 1. mask = grid_3d(L, normalized=True)["r"] < 1 - assert np.allclose(vol_2.asnumpy()[0][mask], vol_1.asnumpy()[0][mask], atol=0.007) + assert np.allclose(vol_fb.asnumpy()[0][mask], vol.asnumpy()[0][mask], atol=4e-2) From 4b7b8c47d647310a38c7f1136d4cbc67f5340b78 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 5 May 2023 08:34:23 -0400 Subject: [PATCH 062/116] more consistent variable names. --- tests/test_FFBbasis3D.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_FFBbasis3D.py b/tests/test_FFBbasis3D.py index 9b78758831..c8879ab2a8 100644 --- a/tests/test_FFBbasis3D.py +++ b/tests/test_FFBbasis3D.py @@ -499,12 +499,12 @@ def testFFBBasis3DExpand(self, basis): def testHighResFFBbasis3D(L, dtype): seed = 42 basis = FFBBasis3D(L, dtype=dtype) - vol_0 = AsymmetricVolume(L=L, C=1, K=64, dtype=dtype, seed=seed).generate() + vol = AsymmetricVolume(L=L, C=1, K=64, dtype=dtype, seed=seed).generate() # Round trip - coeff_0 = basis.evaluate_t(vol_0) - vol_1 = basis.evaluate(coeff_0) + coeff = basis.evaluate_t(vol) + vol_ffb = basis.evaluate(coeff) # Mask to compare inside sphere of radius 1. mask = grid_3d(L, normalized=True)["r"] < 1 - assert np.allclose(vol_1.asnumpy()[0][mask], vol_0.asnumpy()[0][mask], atol=1e-3) + assert np.allclose(vol_ffb.asnumpy()[0][mask], vol.asnumpy()[0][mask], atol=1e-3) From 82a83d6457acdf1e8624b89b8109f806e9a4d42b Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 12 May 2023 13:17:23 -0400 Subject: [PATCH 063/116] Skip some problematic FLE fpu tests for now --- tests/test_FLEbasis2D.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_FLEbasis2D.py b/tests/test_FLEbasis2D.py index 96138ebff4..030690b337 100644 --- a/tests/test_FLEbasis2D.py +++ b/tests/test_FLEbasis2D.py @@ -22,7 +22,7 @@ def show_fle_params(basis): def gpu_ci_skip(): - pytest.skip("1e-7 precision for FLEBasis2D.evaluate()") + pytest.skip("1e-7 precision for FLEBasis2D") fle_params = [ @@ -71,6 +71,9 @@ def relerr(base, approx): class TestFLEBasis2D(UniversalBasisMixin): # check closeness guarantees for fast vs dense matrix method def testFastVDense_T(self, basis): + if backend_available("cufinufft") and basis.epsilon == 1e-7: + gpu_ci_skip() + dense_b = basis._create_dense_matrix() # create sample particle @@ -157,6 +160,9 @@ def testMatchFBDenseEvaluate(basis): @pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params) def testMatchFBEvaluate_t(basis): + if backend_available("cufinufft") and basis.epsilon == 1e-7: + gpu_ci_skip() + # ensure that coefficients are the same when evaluating images fb_basis = FBBasis2D(basis.nres, dtype=np.float64) From 69b4952fc4a38bbd11baac5f53ce7a2be8747fac Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 12 May 2023 13:34:11 -0400 Subject: [PATCH 064/116] loosen fc id case tolerance for cufinufft --- tests/test_fourier_correlation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index 24667e861e..13ef32e5f9 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -117,7 +117,7 @@ def test_frc_id(image_fixture, method): frc_resolution, frc = img.frc(img, pixel_size=1, cutoff=0.143, method=method) assert np.isclose(frc_resolution[0], 2, rtol=0.02) - assert np.allclose(frc, 1) + assert np.allclose(frc, 1, rtol=0.0001) def test_frc_trunc(image_fixture, method): @@ -179,7 +179,7 @@ def test_fsc_id(volume_fixture, method): fsc_resolution, fsc = vol.fsc(vol, pixel_size=1, cutoff=0.143, method=method) assert np.isclose(fsc_resolution[0], 2, rtol=0.02) - assert np.allclose(fsc, 1) + assert np.allclose(fsc, 1, rtol=0.0001) def test_fsc_trunc(volume_fixture, method): @@ -336,7 +336,7 @@ def test_frc_id_bcast(image_fixture, method): * k, rtol=0.02, ) - assert np.allclose(frc, 1.0) + assert np.allclose(frc, 1.0, rtol=0.0001) assert frc_resolution.shape == (3,) # (1) x (1,3) @@ -351,7 +351,7 @@ def test_frc_id_bcast(image_fixture, method): * k, rtol=0.02, ) - assert np.allclose(frc, 1.0) + assert np.allclose(frc, 1.0, rtol=0.0001) assert frc_resolution.shape == (1, 3) # (1) x (3,1) @@ -366,7 +366,7 @@ def test_frc_id_bcast(image_fixture, method): * k, rtol=0.02, ) - assert np.allclose(frc, 1.0) + assert np.allclose(frc, 1.0, rtol=0.0001) assert frc_resolution.shape == (3, 1) @@ -385,7 +385,7 @@ def test_fsc_id_bcast(volume_fixture, method): * k, rtol=0.02, ) - assert np.allclose(fsc, 1.0) + assert np.allclose(fsc, 1.0, rtol=0.0001) def test_frc_img_plot_bcast(image_fixture): From c3c2593f2c91c12fcbaf72d30e61e534dedd3caf Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 15 May 2023 11:47:20 -0400 Subject: [PATCH 065/116] loosen fc id case tolerance for cufinufft --- tests/test_fourier_correlation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index 13ef32e5f9..476e2f5548 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -117,7 +117,7 @@ def test_frc_id(image_fixture, method): frc_resolution, frc = img.frc(img, pixel_size=1, cutoff=0.143, method=method) assert np.isclose(frc_resolution[0], 2, rtol=0.02) - assert np.allclose(frc, 1, rtol=0.0001) + assert np.allclose(frc, 1, rtol=0.01) def test_frc_trunc(image_fixture, method): @@ -179,7 +179,7 @@ def test_fsc_id(volume_fixture, method): fsc_resolution, fsc = vol.fsc(vol, pixel_size=1, cutoff=0.143, method=method) assert np.isclose(fsc_resolution[0], 2, rtol=0.02) - assert np.allclose(fsc, 1, rtol=0.0001) + assert np.allclose(fsc, 1, rtol=0.01) def test_fsc_trunc(volume_fixture, method): @@ -336,7 +336,7 @@ def test_frc_id_bcast(image_fixture, method): * k, rtol=0.02, ) - assert np.allclose(frc, 1.0, rtol=0.0001) + assert np.allclose(frc, 1.0, rtol=0.01) assert frc_resolution.shape == (3,) # (1) x (1,3) @@ -351,7 +351,7 @@ def test_frc_id_bcast(image_fixture, method): * k, rtol=0.02, ) - assert np.allclose(frc, 1.0, rtol=0.0001) + assert np.allclose(frc, 1.0, rtol=0.01) assert frc_resolution.shape == (1, 3) # (1) x (3,1) @@ -366,7 +366,7 @@ def test_frc_id_bcast(image_fixture, method): * k, rtol=0.02, ) - assert np.allclose(frc, 1.0, rtol=0.0001) + assert np.allclose(frc, 1.0, rtol=0.01) assert frc_resolution.shape == (3, 1) @@ -385,7 +385,7 @@ def test_fsc_id_bcast(volume_fixture, method): * k, rtol=0.02, ) - assert np.allclose(fsc, 1.0, rtol=0.0001) + assert np.allclose(fsc, 1.0, rtol=0.01) def test_frc_img_plot_bcast(image_fixture): From bf1d85ab15247dd3b5d2c408a6adc6ea7cf3b71f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 10 May 2023 08:17:52 -0400 Subject: [PATCH 066/116] fix whitespace breaking intro table render --- gallery/tutorials/aspire_introduction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/tutorials/aspire_introduction.py b/gallery/tutorials/aspire_introduction.py index 3ca8d8dc6e..8022a72402 100644 --- a/gallery/tutorials/aspire_introduction.py +++ b/gallery/tutorials/aspire_introduction.py @@ -649,7 +649,7 @@ def noise_function(x, y): # %% # +----------------+--------------------+-----------------+----------------+---------------------+ -# | Image Processing | Ab initio | +# | Image Processing | Ab initio | # +----------------+--------------------+-----------------+----------------+---------------------+ # | Data | Preprocessing | Denoising | Orientation | 3D Reconstruction | # +================+====================+=================+================+=====================+ From e4f9d5dd51c77a4289e8e1904467295731cdb435 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 10 May 2023 08:18:46 -0400 Subject: [PATCH 067/116] change pipeline vol download path until we have data downloader --- gallery/tutorials/pipeline_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/tutorials/pipeline_demo.py b/gallery/tutorials/pipeline_demo.py index 03c4d2d64e..f286c1eac9 100644 --- a/gallery/tutorials/pipeline_demo.py +++ b/gallery/tutorials/pipeline_demo.py @@ -31,7 +31,7 @@ def download(url, save_path, chunk_size=1024 * 1024): fd.write(chunk) -file_path = os.path.join(os.getcwd(), "data", "emd_2660.map") +file_path = os.path.join(os.getcwd(), "emd_2660.map") if not os.path.exists(file_path): url = "https://ftp.ebi.ac.uk/pub/databases/emdb/structures/EMD-2660/map/emd_2660.map.gz" download(url, file_path) From 352bba06a7f7fa00a49e5de467b0127eac959e57 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 10 May 2023 08:23:01 -0400 Subject: [PATCH 068/116] remove unused params from pipeline demo. --- gallery/tutorials/pipeline_demo.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gallery/tutorials/pipeline_demo.py b/gallery/tutorials/pipeline_demo.py index f286c1eac9..2380541f5a 100644 --- a/gallery/tutorials/pipeline_demo.py +++ b/gallery/tutorials/pipeline_demo.py @@ -107,7 +107,6 @@ def download(url, save_path, chunk_size=1024 * 1024): from aspire.source import Simulation # set parameters -res = 41 n_imgs = 2500 # SNR target for white gaussian noise. @@ -170,7 +169,7 @@ def download(url, save_path, chunk_size=1024 * 1024): # practice, the selection is done by sorting class averages based on # some configurable notion of quality. -from aspire.classification import RIRClass2D, TopClassSelector +from aspire.classification import RIRClass2D # set parameters n_classes = 200 From 1e722f9816648d860bc453eacc8e8c58d451111e Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 23 May 2023 09:15:07 -0400 Subject: [PATCH 069/116] add sphinxcontrib.jquery extension. reformat conf.py. --- docs/source/conf.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index b4cd64d907..ca036350b3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -39,6 +39,7 @@ "sphinx.ext.autodoc", "sphinx.ext.mathjax", "sphinxcontrib.bibtex", + "sphinxcontrib.jquery", "sphinxcontrib.mermaid", "sphinx.ext.napoleon", "sphinx_gallery.gen_gallery", @@ -47,14 +48,20 @@ # Sphinx-Gallery Configuration sphinx_gallery_conf = { - 'examples_dirs': ['../../gallery/tutorials', '../../gallery/experiments'], # path to your example scripts - 'gallery_dirs': ['auto_tutorials', 'auto_experiments'], # path to where to save gallery generated output - 'download_all_examples': False, - 'plot_gallery': 'True', - 'remove_config_comments': True, - 'notebook_images': True, - 'within_subsection_order': ExampleTitleSortKey, - 'filename_pattern': r'/tutorials/.*\.py', # Parse all gallery python files, but only execute tutorials. + "examples_dirs": [ + "../../gallery/tutorials", + "../../gallery/experiments", + ], # path to your example scripts + "gallery_dirs": [ + "auto_tutorials", + "auto_experiments", + ], # path to where to save gallery generated output + "download_all_examples": False, + "plot_gallery": "True", + "remove_config_comments": True, + "notebook_images": True, + "within_subsection_order": ExampleTitleSortKey, + "filename_pattern": r"/tutorials/.*\.py", # Parse all gallery python files, but only execute tutorials. } # Add any paths that contain templates here, relative to this directory. @@ -86,7 +93,7 @@ # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = 'en' +language = "en" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. From b1bdc20e475bef16313a993f3b609980b7e9adba Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Mon, 1 May 2023 14:06:15 -0400 Subject: [PATCH 070/116] Always use FB ordering for FLE coef --- src/aspire/basis/fle_2d.py | 43 ++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 77f8d8d286..7f121747ff 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -47,11 +47,6 @@ def __init__( be set to that of `FBBasis2D`, and the FLE frequency thresholding procedure to reduce the number of functions will not be carried out. This means the number of basis functions for a given image size will be identical across the two bases. - - The signs of basis functions and coefficients with `sgn == 1` will be flipped relative to - the original FLE implementation, to match FB. - - The basis functions returned will be reordered according to the FB ordering, that is, first - by `ell`s, then by `sgn`s, then by `k`s. - """ if isinstance(size, int): size = (size, size) @@ -102,6 +97,8 @@ def _build_indices(self): self.angular_indices = np.abs(self.ells) self.radial_indices = self.ks - 1 self.signs_indices = np.sign(self.ells) + # Use the FB2D ells sign convention of `1` for `ell=0` + self.signs_indices[self.ells == 0] = 1 def indices(self): """ @@ -115,14 +112,13 @@ def indices(self): def _generate_fb_compat_indices(self): """ - Generate indices to shuffle basis function ordering and flip signs in order - to match `FBBasis2D`. + Generate indices to shuffle basis function ordering. """ ind = self.indices() - # basis function ordering - self.fb_compat_indices = np.lexsort((ind["ks"], ind["sgns"], ind["ells"])) - # flip signs - self.flip_sign_indices = np.where(self.signs_indices == 1) + # basis function ordering (used during evaluate_t output) + self.fle_to_fb_indices = np.lexsort((ind["ks"], ind["sgns"], ind["ells"])) + # store the reverse mapping (used during evaluate input) + self.fb_to_fle_indices = np.argsort(self.fle_to_fb_indices) def _precomp(self): """ @@ -460,11 +456,10 @@ def _evaluate(self, coeffs): be evaluated. The last dimension must be equal to `self.count` :return: An Image object containing the corresponding images. """ - if self.match_fb: - # sign of basis functions with positive indices flipped relative to FB2d - coeffs[self.flip_sign_indices] *= -1.0 - # reorder coefficients by FB2d ordering - coeffs = coeffs[self.fb_compat_indices] + # convert from FB order + coeffs = coeffs[..., self.fb_to_fle_indices] + inds = (self.signs_indices == 1) & (self.ells != 0) + coeffs[..., inds] = coeffs[..., inds] * -1 # See Remark 3.3 and Section 3.4 betas = self._step3(coeffs) @@ -486,10 +481,12 @@ def _evaluate_t(self, imgs): z = self._step1_t(imgs) b = self._step2_t(z) coeffs = self._step3_t(b) - if self.match_fb: - coeffs[:, self.flip_sign_indices] *= -1.0 - coeffs = coeffs[:, self.fb_compat_indices] - return coeffs.astype(self.coefficient_dtype) + + # return in FB order + inds = (self.signs_indices == 1) & (self.ells != 0) + coeffs[..., inds] = coeffs[..., inds] * -1 + coeffs = coeffs[..., self.fle_to_fb_indices] + return coeffs.astype(self.coefficient_dtype, copy=False) def _step1_t(self, im): """ @@ -636,9 +633,9 @@ def _create_dense_matrix(self): B = B.reshape(self.nres**2, self.count) B = transform_complex_to_real(np.conj(B), self.ells) B = B.reshape(self.nres**2, self.count) - if self.match_fb: - B[:, self.flip_sign_indices] *= -1.0 - B = B[:, self.fb_compat_indices] + inds = (self.signs_indices == 1) & (self.ells != 0) + B[..., inds] = B[..., inds] * -1 + B = B[..., self.fle_to_fb_indices] return B def lowpass(self, coeffs, bandlimit): From 24caea5755e52fbdb6fb19e6d3f760de74f7adde Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 3 May 2023 08:11:02 -0400 Subject: [PATCH 071/116] Add 45 deg rotation test for FLE --- tests/test_FLEbasis2D.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/tests/test_FLEbasis2D.py b/tests/test_FLEbasis2D.py index 030690b337..c80443caa0 100644 --- a/tests/test_FLEbasis2D.py +++ b/tests/test_FLEbasis2D.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from aspire.basis import FBBasis2D, FLEBasis2D +from aspire.basis import FBBasis2D, FFBBasis2D, FLEBasis2D from aspire.image import Image from aspire.nufft import backend_available from aspire.numeric import fft @@ -292,6 +292,34 @@ def testRotate(): _ = basis.lowpass(np.zeros((3, 3, 3)), np.pi) +def testRotate45(): + # test ability to accurately rotate images via + # FLE coefficients + dtype = np.float64 + + L = 128 + fb_basis = FBBasis2D(L, dtype=dtype) + basis = FLEBasis2D(L, match_fb=True, dtype=dtype) + + # sample image + ims = create_images(L, 1) + + # get FLE coefficients + fb_coeffs = fb_basis.evaluate_t(ims) + coeffs = basis.evaluate_t(ims) + + # rotate original image in FLE space using Steerable rotate method + fb_coeffs_rot = fb_basis.rotate(fb_coeffs, np.pi / 4) + coeffs_rot = basis.rotate(coeffs, np.pi / 4) + + # back to cartesian + fb_ims_rot = fb_basis.evaluate(fb_coeffs_rot) + ims_rot = basis.evaluate(coeffs_rot) + + # test close + assert np.allclose(ims_rot[0], fb_ims_rot[0], atol=1e-4) + + def testRadialConvolution(): # test ability to accurately convolve with a radial # (e.g. CTF) function via FLE coefficients From 13790ffc4b6e055eca673cd4affd10c665ff0f84 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 3 May 2023 10:07:30 -0400 Subject: [PATCH 072/116] Always use FB order for FLE coef and indices. Begin digging into rotation/conjugation issue --- src/aspire/basis/fle_2d.py | 113 +++++++++++++++++-------------- src/aspire/basis/fle_2d_utils.py | 2 +- tests/test_FLEbasis2D.py | 1 + 3 files changed, 63 insertions(+), 53 deletions(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 7f121747ff..d80d4f91c4 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -90,15 +90,29 @@ def _build(self): # Steerable basis indices self._build_indices() - # FB compatability indices - self._generate_fb_compat_indices() - def _build_indices(self): - self.angular_indices = np.abs(self.ells) - self.radial_indices = self.ks - 1 - self.signs_indices = np.sign(self.ells) + # FLE internal indices + self._fle_angular_indices = np.abs(self._ells) + self._fle_radial_indices = self._ks - 1 + self._fle_signs_indices = np.sign(self._ells) # Use the FB2D ells sign convention of `1` for `ell=0` - self.signs_indices[self.ells == 0] = 1 + self._fle_signs_indices[self._ells == 0] = 1 + + # basis function ordering (used during evaluate_t output) + self._fle_to_fb_indices = np.lexsort( + ( + self._fle_radial_indices, + self._fle_signs_indices, + self._fle_angular_indices, + ) + ) + # store the reverse mapping (used during evaluate input) + self._fb_to_fle_indices = np.argsort(self._fle_to_fb_indices) + + # + self.angular_indices = self._fle_angular_indices[self._fle_to_fb_indices] + self.radial_indices = self._fle_radial_indices[self._fle_to_fb_indices] + self.signs_indices = self._fle_signs_indices[self._fle_to_fb_indices] def indices(self): """ @@ -110,16 +124,6 @@ def indices(self): "sgns": self.signs_indices, } - def _generate_fb_compat_indices(self): - """ - Generate indices to shuffle basis function ordering. - """ - ind = self.indices() - # basis function ordering (used during evaluate_t output) - self.fle_to_fb_indices = np.lexsort((ind["ks"], ind["sgns"], ind["ells"])) - # store the reverse mapping (used during evaluate input) - self.fb_to_fle_indices = np.argsort(self.fle_to_fb_indices) - def _precomp(self): """ Precompute the basis functions and other objects used in the evaluation of @@ -133,22 +137,22 @@ def _precomp(self): # Some important constants self.smallest_lambda = np.min(self.bessel_zeros) self.greatest_lambda = np.max(self.bessel_zeros) - self.max_ell = np.max(np.abs(self.ells)) + self.max_ell = np.max(np.abs(self._ells)) self.h = 1 / (self.nres / 2) # give each ell a positive index increasing first in |ell| # then in sign, e.g. 0->1, -1->2, 1->3, -2->4, 2->5, etc. - self.ells_p = 2 * np.abs(self.ells) - (self.ells < 0) - self.ell_p_max = np.max(self.ells_p) + self._ells_p = 2 * np.abs(self._ells) - (self._ells < 0) + self.ell_p_max = np.max(self._ells_p) # idx_list[k] contains the indices j of ells_p where ells_p[j] = k idx_list = [[] for i in range(self.ell_p_max + 1)] for i in range(self.count): - ellp = self.ells_p[i] + ellp = self._ells_p[i] idx_list[ellp].append(i) self.idx_list = idx_list # real <-> complex - self.c2r = precomp_transform_complex_to_real(self.ells) + self.c2r = precomp_transform_complex_to_real(self._ells) self.r2c = sparse.csr_matrix(self.c2r.transpose().conj()) # create an ordered list of the original ell values @@ -319,27 +323,27 @@ def _lap_eig_disk(self): # 0 frequency plus pos and negative frequencies for each bessel function # num functions per frequency num_ells = 1 + 2 * max_ell - self.ells = np.zeros((num_ells, max_k), dtype=int) - self.ks = np.zeros((num_ells, max_k), dtype=int) + self._ells = np.zeros((num_ells, max_k), dtype=int) + self._ks = np.zeros((num_ells, max_k), dtype=int) self.bessel_zeros = np.ones((num_ells, max_k), dtype=np.float64) * np.Inf # keep track of which order Bessel function we're on - self.ells[0, :] = 0 + self._ells[0, :] = 0 # bessel_roots[0, m] is the m'th zero of J_0 self.bessel_zeros[0, :] = besselj_zeros(0, max_k) # table of values of which zero of J_0 we are finding - self.ks[0, :] = np.arange(max_k) + 1 + self._ks[0, :] = np.arange(max_k) + 1 # add roots of J_ell for ell>0 twice with +k and -k (frequencies) # iterate over Bessel function order for ell in range(1, max_ell + 1): - self.ells[2 * ell - 1, :] = -ell - self.ks[2 * ell - 1, :] = np.arange(max_k) + 1 + self._ells[2 * ell - 1, :] = -ell + self._ks[2 * ell - 1, :] = np.arange(max_k) + 1 self.bessel_zeros[2 * ell - 1, :max_k] = besselj_zeros(ell, max_k) - self.ells[2 * ell, :] = ell - self.ks[2 * ell, :] = self.ks[2 * ell - 1, :] + self._ells[2 * ell, :] = ell + self._ks[2 * ell, :] = self._ks[2 * ell - 1, :] self.bessel_zeros[2 * ell, :] = self.bessel_zeros[2 * ell - 1, :] # Reshape the arrays and order by the size of the Bessel function zeros @@ -353,30 +357,30 @@ def _lap_eig_disk(self): def _flatten_and_sort_bessel_zeros(self): """ - Reshapes arrays self.ells, self.ks, and self.bessel_zeros + Reshapes arrays self._ells, self._ks, and self.bessel_zeros """ # flatten list of zeros, ells and ks: - self.ells = self.ells.flatten() - self.ks = self.ks.flatten() + self._ells = self._ells.flatten() + self._ks = self._ks.flatten() self.bessel_zeros = self.bessel_zeros.flatten() idx = np.argsort(self.bessel_zeros) - self.ells = self.ells[idx] - self.ks = self.ks[idx] + self._ells = self._ells[idx] + self._ks = self._ks[idx] self.bessel_zeros = self.bessel_zeros[idx] # sort complex conjugate pairs: -ell first, +ell second idx = np.arange(self.max_basis_functions + 1) for i in range(self.max_basis_functions + 1): - if self.ells[i] >= 0: + if self._ells[i] >= 0: continue if np.abs(self.bessel_zeros[i] - self.bessel_zeros[i + 1]) < 1e-14: continue idx[i - 1] = i idx[i] = i - 1 - self.ells = self.ells[idx] - self.ks = self.ks[idx] + self._ells = self._ells[idx] + self._ks = self._ks[idx] self.bessel_zeros = self.bessel_zeros[idx] def _threshold_basis_functions(self): @@ -402,12 +406,12 @@ def _threshold_basis_functions(self): _final_num_basis_functions -= 1 # potentially subtract one to keep complex conjugate pairs - if self.ells[_final_num_basis_functions - 1] < 0: + if self._ells[_final_num_basis_functions - 1] < 0: _final_num_basis_functions -= 1 # discard zeros above the threshold - self.ells = self.ells[:_final_num_basis_functions] - self.ks = self.ks[:_final_num_basis_functions] + self._ells = self._ells[:_final_num_basis_functions] + self._ks = self._ks[:_final_num_basis_functions] self.bessel_zeros = self.bessel_zeros[:_final_num_basis_functions] return _final_num_basis_functions @@ -420,7 +424,7 @@ def _create_basis_functions(self): basis_functions = [None] * self.count for i in range(self.count): # parameters defining the basis function: bessel order and which bessel root - ell = self.ells[i] + ell = self._ells[i] bessel_zero = self.bessel_zeros[i] # compute normalization constant @@ -457,9 +461,7 @@ def _evaluate(self, coeffs): :return: An Image object containing the corresponding images. """ # convert from FB order - coeffs = coeffs[..., self.fb_to_fle_indices] - inds = (self.signs_indices == 1) & (self.ells != 0) - coeffs[..., inds] = coeffs[..., inds] * -1 + coeffs = coeffs[..., self._fb_to_fle_indices] # See Remark 3.3 and Section 3.4 betas = self._step3(coeffs) @@ -483,9 +485,7 @@ def _evaluate_t(self, imgs): coeffs = self._step3_t(b) # return in FB order - inds = (self.signs_indices == 1) & (self.ells != 0) - coeffs[..., inds] = coeffs[..., inds] * -1 - coeffs = coeffs[..., self.fle_to_fb_indices] + coeffs = coeffs[..., self._fle_to_fb_indices] return coeffs.astype(self.coefficient_dtype, copy=False) def _step1_t(self, im): @@ -631,11 +631,10 @@ def _create_dense_matrix(self): for i in range(self.count): B[:, :, i] = self.basis_functions[i](self.rs, ts) * self.h B = B.reshape(self.nres**2, self.count) - B = transform_complex_to_real(np.conj(B), self.ells) + B = transform_complex_to_real(B, self._ells) B = B.reshape(self.nres**2, self.count) - inds = (self.signs_indices == 1) & (self.ells != 0) - B[..., inds] = B[..., inds] * -1 - B = B[..., self.fle_to_fb_indices] + B = B[..., self._fle_to_fb_indices] + return B def lowpass(self, coeffs, bandlimit): @@ -662,6 +661,9 @@ def lowpass(self, coeffs, bandlimit): return coeffs + def rotate(self, coef, radians, refl=None): + return super().rotate(coef, -1 * radians, refl) + def radial_convolve(self, coeffs, radial_img): """ Convolve a stack of FLE coefficients with a 2D radial function. @@ -671,6 +673,10 @@ def radial_convolve(self, coeffs, radial_img): """ num_img = coeffs.shape[0] coeffs_conv = np.zeros(coeffs.shape) + + # Convert to internal FLE indices ordering + coeffs = coeffs[..., self._fb_to_fle_indices] + for k in range(num_img): _coeffs = coeffs[k, :] z = self._step1_t(radial_img) @@ -680,6 +686,9 @@ def radial_convolve(self, coeffs, radial_img): b = b.reshape(self.count) coeffs_conv[k, :] = np.real(self.c2r @ (b * (self.r2c @ _coeffs).flatten())) + # Convert from internal FLE ordering to FB convention + coeffs_conv = coeffs_conv[..., self._fle_to_fb_indices] + return coeffs_conv def _radial_convolve_weights(self, b): diff --git a/src/aspire/basis/fle_2d_utils.py b/src/aspire/basis/fle_2d_utils.py index e97d309524..67246ab497 100644 --- a/src/aspire/basis/fle_2d_utils.py +++ b/src/aspire/basis/fle_2d_utils.py @@ -87,7 +87,7 @@ def precomp_transform_complex_to_real(ells): A = sparse.csr_matrix((vals, (idx, jdx)), shape=(count, count), dtype=np.complex128) - return A + return A.conjugate() def barycentric_interp_sparse(target_points, known_points, numsparse): diff --git a/tests/test_FLEbasis2D.py b/tests/test_FLEbasis2D.py index c80443caa0..4d286a6a59 100644 --- a/tests/test_FLEbasis2D.py +++ b/tests/test_FLEbasis2D.py @@ -156,6 +156,7 @@ def testMatchFBDenseEvaluate(basis): # Matrix column reording in match_fb mode flips signs of some of the basis functions assert np.allclose(np.abs(fb_images), np.abs(fle_images), atol=1e-3) + assert np.allclose(fb_images, fle_images, atol=1e-3) @pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params) From 6ccd2017615e5008f040353d14c7d73023ca8afe Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 3 May 2023 10:31:57 -0400 Subject: [PATCH 073/116] Use FFB for FLE rotate 45 comparison (faster) --- tests/test_FLEbasis2D.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_FLEbasis2D.py b/tests/test_FLEbasis2D.py index 4d286a6a59..a8b70b2bb9 100644 --- a/tests/test_FLEbasis2D.py +++ b/tests/test_FLEbasis2D.py @@ -299,7 +299,7 @@ def testRotate45(): dtype = np.float64 L = 128 - fb_basis = FBBasis2D(L, dtype=dtype) + fb_basis = FFBBasis2D(L, dtype=dtype) basis = FLEBasis2D(L, match_fb=True, dtype=dtype) # sample image From cc7ca2222efadd80c31600c5c1eedabca965d191 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 May 2023 08:27:17 -0400 Subject: [PATCH 074/116] Add missing indices comment --- src/aspire/basis/fle_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index d80d4f91c4..c92d0db240 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -109,7 +109,7 @@ def _build_indices(self): # store the reverse mapping (used during evaluate input) self._fb_to_fle_indices = np.argsort(self._fle_to_fb_indices) - # + # User facing indices, should follow FB ordering. self.angular_indices = self._fle_angular_indices[self._fle_to_fb_indices] self.radial_indices = self._fle_radial_indices[self._fle_to_fb_indices] self.signs_indices = self._fle_signs_indices[self._fle_to_fb_indices] From 23748c3a54e91ca4fc30be2c94d29325170386f4 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 4 May 2023 09:31:13 -0400 Subject: [PATCH 075/116] Sign indices re-ordering. --- src/aspire/basis/fle_2d.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index c92d0db240..8709abdffb 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -94,7 +94,8 @@ def _build_indices(self): # FLE internal indices self._fle_angular_indices = np.abs(self._ells) self._fle_radial_indices = self._ks - 1 - self._fle_signs_indices = np.sign(self._ells) + # Negate all signs from FLE implementation + self._fle_signs_indices = -np.sign(self._ells) # Use the FB2D ells sign convention of `1` for `ell=0` self._fle_signs_indices[self._ells == 0] = 1 @@ -102,7 +103,9 @@ def _build_indices(self): self._fle_to_fb_indices = np.lexsort( ( self._fle_radial_indices, - self._fle_signs_indices, + # Reverse sign sorting order so +1 first, + # match `sgns = (1,) if ell == 0 else (1, -1)` from fb_2d.py + -self._fle_signs_indices, self._fle_angular_indices, ) ) @@ -112,6 +115,7 @@ def _build_indices(self): # User facing indices, should follow FB ordering. self.angular_indices = self._fle_angular_indices[self._fle_to_fb_indices] self.radial_indices = self._fle_radial_indices[self._fle_to_fb_indices] + # Note we negate the FLE signs? self.signs_indices = self._fle_signs_indices[self._fle_to_fb_indices] def indices(self): @@ -661,9 +665,6 @@ def lowpass(self, coeffs, bandlimit): return coeffs - def rotate(self, coef, radians, refl=None): - return super().rotate(coef, -1 * radians, refl) - def radial_convolve(self, coeffs, radial_img): """ Convolve a stack of FLE coefficients with a 2D radial function. From e49d9b984b7e45cde1204936ff14c8eb1b4ac41a Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 16 May 2023 11:39:00 -0400 Subject: [PATCH 076/116] string replace varname B --- src/aspire/basis/fle_2d_utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/aspire/basis/fle_2d_utils.py b/src/aspire/basis/fle_2d_utils.py index 67246ab497..cde0cd11bf 100644 --- a/src/aspire/basis/fle_2d_utils.py +++ b/src/aspire/basis/fle_2d_utils.py @@ -2,29 +2,29 @@ import scipy.sparse as sparse -def transform_complex_to_real(B_conj, ells): +def transform_complex_to_real(B, ells): """ Transforms coefficients of the matrix B (see Eq. 3) from complex to real. B is the linear transformation that takes FLE coefficients to images. - :param B_conj: Complex conjugate of the matrix B. + :param B: Complex matrix B. :param ells: List of ells (Bessel function orders) in this basis. :return: Transformed matrix. """ - num_basis_functions = B_conj.shape[1] - X = np.zeros(B_conj.shape, dtype=np.float64) + num_basis_functions = B.shape[1] + X = np.zeros(B.shape, dtype=np.float64) for i in range(num_basis_functions): ell = ells[i] if ell == 0: - X[:, i] = np.real(B_conj[:, i]) + X[:, i] = np.real(B[:, i]) # for each ell != 0, we can populate two entries of the matrix # by taking the complex conjugate of the ell with the opposite sign if ell < 0: s = (-1) ** np.abs(ell) - x0 = (B_conj[:, i] + s * B_conj[:, i + 1]) / np.sqrt(2) - x1 = (-B_conj[:, i] + s * B_conj[:, i + 1]) / (1j * np.sqrt(2)) + x0 = (B[:, i] + s * B[:, i + 1]) / np.sqrt(2) + x1 = (-B[:, i] + s * B[:, i + 1]) / (1j * np.sqrt(2)) X[:, i] = np.real(x0) X[:, i + 1] = np.real(x1) From 1212e0c1968927706c455bfe49e8d92732047648 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 25 May 2023 07:21:30 -0400 Subject: [PATCH 077/116] Refactor FLE docstring --- src/aspire/basis/fle_2d.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 8709abdffb..e9a595d727 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -40,13 +40,14 @@ def __init__( resolution of the basis. :param epsilon: Relative precision between FLE fast method and dense matrix multiplication. :param dtype: Datatype of images and coefficients represented. - :param match_fb: - With this flag set the following will ensure that the basis functions are - identical to `FBBasis2D`: - - The initial heuristic for the number of basis functions, based on the resolution, will - be set to that of `FBBasis2D`, and the FLE frequency thresholding procedure to reduce the - number of functions will not be carried out. This means the number of basis functions for - a given image size will be identical across the two bases. + :param match_fb: This flag constructs basis functions + identical to `FBBasis2D`. The initial heuristic for the + number of basis functions, based on the image size, will + be set to that of `FBBasis2D`, and the FLE frequency + thresholding procedure to reduce the number of functions + will not be carried out. This means the number of basis + functions for a given image size will be identical across + the two bases. """ if isinstance(size, int): size = (size, size) From 5737f0cc02e988354fda0f3d38a7fc60259459b4 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 25 Apr 2023 16:09:37 -0400 Subject: [PATCH 078/116] Remove vol_idx. Admit volume stacks for project. --- src/aspire/volume/volume.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index b04d559efa..e024c281f5 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -230,7 +230,7 @@ def __rtruediv__(self, otherL): """ return otherL * Volume(1.0 / self._data) - def project(self, vol_idx, rot_matrices): + def project(self, rot_matrices): """ Using the stack of rot_matrices, project images of Volume[vol_idx]. @@ -239,11 +239,6 @@ def project(self, vol_idx, rot_matrices): :param rot_matrices: Stack of rotations. Rotation or ndarray instance. :return: `Image` instance. """ - # See Issue #727 - if self.stack_ndim > 1: - raise NotImplementedError( - "`project` is currently limited to 1D Volume stacks." - ) # If we are an ASPIRE Rotation, get the numpy representation. if isinstance(rot_matrices, Rotation): @@ -257,9 +252,10 @@ def project(self, vol_idx, rot_matrices): " In the future this will raise an error." ) - data = self[vol_idx].asnumpy() - n = rot_matrices.shape[0] + return_stack_shape = self.stack_shape + (n,) + + data = self.stack_reshape(-1)._data pts_rot = rotated_grids(self.resolution, rot_matrices) @@ -275,8 +271,9 @@ def project(self, vol_idx, rot_matrices): im_f[:, :, 0] = 0 im_f = xp.asnumpy(fft.centered_ifft2(xp.asarray(im_f))) + im = aspire.image.Image(np.real(im_f)).stack_reshape(return_stack_shape) - return aspire.image.Image(np.real(im_f)) + return im def to_vec(self): """Returns an N x resolution ** 3 array.""" From 2794adff6935223afcc26a079d7a10fd8cb8abf2 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 28 Apr 2023 10:13:24 -0400 Subject: [PATCH 079/116] project with broadcast for n_vols n_rots. --- gallery/tutorials/aspire_introduction.py | 4 +-- src/aspire/source/image.py | 2 +- src/aspire/source/simulation.py | 2 +- src/aspire/volume/volume.py | 37 ++++++++++++++++-------- tests/test_volume.py | 4 +-- 5 files changed, 31 insertions(+), 18 deletions(-) diff --git a/gallery/tutorials/aspire_introduction.py b/gallery/tutorials/aspire_introduction.py index 8022a72402..c77ba94ad9 100644 --- a/gallery/tutorials/aspire_introduction.py +++ b/gallery/tutorials/aspire_introduction.py @@ -212,9 +212,9 @@ print(rots.matrices) # %% -# Using the zero-th (and in this case, only) volume, compute +# Using the ``Volume.project()`` method we compute # projections using the stack of rotations: -projections = vol.project(0, rots) +projections = vol.project(rots) print(projections) # %% diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 98e017fb11..16c1d4fbd3 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -856,7 +856,7 @@ def vol_forward(self, vol, start, num): if vol.dtype != self.dtype: logger.warning(f"Volume.dtype {vol.dtype} inconsistent with {self.dtype}") - im = vol.project(0, self.rotations[all_idx, :, :]) + im = vol.project(self.rotations[all_idx, :, :]) im = self._apply_source_filters(im, all_idx) im = im.shift(self.offsets[all_idx, :]) im *= self.amplitudes[all_idx, np.newaxis, np.newaxis] diff --git a/src/aspire/source/simulation.py b/src/aspire/source/simulation.py index 9e3dea758b..4161bfd92a 100644 --- a/src/aspire/source/simulation.py +++ b/src/aspire/source/simulation.py @@ -240,7 +240,7 @@ def _projections(self, indices): idx_k = np.where(states == k)[0] rot = self.rotations[indices[idx_k], :, :] - im_k = self.vols.project(vol_idx=k - 1, rot_matrices=rot) + im_k = self.vols[k - 1].project(rot_matrices=rot) im[idx_k, :, :] = im_k.asnumpy() return Image(im) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index e024c281f5..945098c5ee 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -232,13 +232,18 @@ def __rtruediv__(self, otherL): def project(self, rot_matrices): """ - Using the stack of rot_matrices, - project images of Volume[vol_idx]. + Using the stack of rot_matrices, project images of Volume. When projecting + over a stack of volumes, a singleton Rotation or a Rotation with stack size + self.n_vols must be used. - :param vol_idx: Volume index :param rot_matrices: Stack of rotations. Rotation or ndarray instance. :return: `Image` instance. """ + # See Issue #727 + if self.stack_ndim > 1: + raise NotImplementedError( + "`project` is currently limited to 1D Volume stacks." + ) # If we are an ASPIRE Rotation, get the numpy representation. if isinstance(rot_matrices, Rotation): @@ -252,17 +257,26 @@ def project(self, rot_matrices): " In the future this will raise an error." ) - n = rot_matrices.shape[0] - return_stack_shape = self.stack_shape + (n,) + data = self._data + n_rots = rot_matrices.shape[0] - data = self.stack_reshape(-1)._data + if not ((n_rots == self.n_vols) or (n_rots == 1) or (self.n_vols == 1)): + raise NotImplementedError( + f"Cannot broadcast with {n_rots} Rotations and {self.n_vols} Volumes." + ) pts_rot = rotated_grids(self.resolution, rot_matrices) - # TODO: rotated_grids might as well give us correctly shaped array in the first place - pts_rot = pts_rot.reshape((3, n * self.resolution**2)) - - im_f = nufft(data, pts_rot) / self.resolution + if n_rots == self.n_vols: + im_f = np.empty( + (self.n_vols, self.resolution**2), dtype=complex_type(self.dtype) + ) + pts_rot = pts_rot.reshape((3, n_rots, self.resolution**2)) + for i in range(self.n_vols): + im_f[i] = nufft(data[i], pts_rot[:, i]) / self.resolution + else: + pts_rot = pts_rot.reshape((3, n_rots * self.resolution**2)) + im_f = nufft(data, pts_rot) / self.resolution im_f = im_f.reshape(-1, self.resolution, self.resolution) @@ -271,9 +285,8 @@ def project(self, rot_matrices): im_f[:, :, 0] = 0 im_f = xp.asnumpy(fft.centered_ifft2(xp.asarray(im_f))) - im = aspire.image.Image(np.real(im_f)).stack_reshape(return_stack_shape) - return im + return aspire.image.Image(np.real(im_f)) def to_vec(self): """Returns an N x resolution ** 3 array.""" diff --git a/tests/test_volume.py b/tests/test_volume.py index 2ebac31e38..9a23748399 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -201,7 +201,7 @@ def testProject(self): # Project a Volume with all the test rotations vol_id = 1 # select a volume from Volume stack - img_stack = self.vols_1.project(vol_id, r_stack) + img_stack = self.vols_1[vol_id].project(r_stack) for r in range(len(r_stack)): # Get result of test projection at center of Image. @@ -221,7 +221,7 @@ def testProject(self): vols = Volume(np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy"))) rots = np.load(os.path.join(DATA_DIR, "rand_rot_matrices32.npy")) rots = np.moveaxis(rots, 2, 0) - imgs_clean = vols.project(0, rots).asnumpy() + imgs_clean = vols.project(rots).asnumpy() self.assertTrue(np.allclose(results, imgs_clean, atol=1e-7)) # Parameterize over even and odd resolutions From a64bba0c07d591bed383e7de90f19b05286e3ee6 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 28 Apr 2023 11:00:53 -0400 Subject: [PATCH 080/116] fix project in pipeline_demo --- gallery/tutorials/pipeline_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gallery/tutorials/pipeline_demo.py b/gallery/tutorials/pipeline_demo.py index 2380541f5a..f48741c4c2 100644 --- a/gallery/tutorials/pipeline_demo.py +++ b/gallery/tutorials/pipeline_demo.py @@ -293,7 +293,7 @@ def download(url, save_path, chunk_size=1024 * 1024): # Get projections from the estimated volume using the estimated # orientations. We instantiate the projections as an # ``ArrayImageSource`` to access the ``Image.show()`` method. -projections_est = ArrayImageSource(estimated_volume.project(0, rots_est)) +projections_est = ArrayImageSource(estimated_volume.project(rots_est)) # We view the first 10 projections of the estimated volume. projections_est.images[0:10].show() From 56a9f29553d62337e1eb78b933b0a605f6382cfb Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 28 Apr 2023 11:40:20 -0400 Subject: [PATCH 081/116] testProjectBroadcast --- tests/test_volume.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index 9a23748399..d1038602ab 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -9,7 +9,7 @@ from parameterized import parameterized from pytest import raises, skip -from aspire.utils import Rotation, powerset +from aspire.utils import Rotation, gaussian_3d, powerset from aspire.utils.matrix import anorm from aspire.utils.types import utest_tolerance from aspire.volume import Volume @@ -224,6 +224,35 @@ def testProject(self): imgs_clean = vols.project(rots).asnumpy() self.assertTrue(np.allclose(results, imgs_clean, atol=1e-7)) + def testProjectBroadcast(self): + L = 32 + + # Create stack of Volume with Gaussians stretched along varying axes. + blob_x = gaussian_3d(L, sigma=(3, 2, 1), dtype=self.dtype) + blob_y = gaussian_3d(L, sigma=(1, 3, 2), dtype=self.dtype) + blob_z = gaussian_3d(L, sigma=(1, 2, 3), dtype=self.dtype) + vols = Volume(np.vstack((blob_x, blob_y, blob_z)).reshape(3, L, L, L)) + + # Create singleton and stacks of identity Rotations. + I = np.eye(3, dtype=self.dtype) + eye = Rotation(I) + eyes_2 = Rotation(np.vstack((I,) * 2).reshape(2, 3, 3)) + eyes_3 = Rotation(np.vstack((I,) * 3).reshape(3, 3, 3)) + + # Broadcast Volume stack with singleton Rotation. + ims_3_1 = vols.project(eye) + + # Broadcast Volume stack with Rotation stack of same size. + ims_3_3 = vols.project(eyes_3) + + # These image stacks should be identical. + self.assertTrue(np.allclose(ims_3_1, ims_3_3)) + + # Check we raise an error for incompatible stacks. + msg = "Cannot broadcast with 2 Rotations and 3 Volumes." + with raises(NotImplementedError, match=msg): + _ = vols.project(eyes_2) + # Parameterize over even and odd resolutions @parameterized.expand([(res,), (res - 1,)]) def testRotate(self, L): From 917dd97f68a7005f7e797cb03cfa80605d3134a3 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 28 Apr 2023 11:44:44 -0400 Subject: [PATCH 082/116] I ~~> identity --- tests/test_volume.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index d1038602ab..0b43860837 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -234,10 +234,10 @@ def testProjectBroadcast(self): vols = Volume(np.vstack((blob_x, blob_y, blob_z)).reshape(3, L, L, L)) # Create singleton and stacks of identity Rotations. - I = np.eye(3, dtype=self.dtype) - eye = Rotation(I) - eyes_2 = Rotation(np.vstack((I,) * 2).reshape(2, 3, 3)) - eyes_3 = Rotation(np.vstack((I,) * 3).reshape(3, 3, 3)) + identity = np.eye(3, dtype=self.dtype) + eye = Rotation(identity) + eyes_2 = Rotation(np.vstack((identity,) * 2).reshape(2, 3, 3)) + eyes_3 = Rotation(np.vstack((identity,) * 3).reshape(3, 3, 3)) # Broadcast Volume stack with singleton Rotation. ims_3_1 = vols.project(eye) From ec3b3a861c6793708c5ddaef8e24da5d0097e3ad Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 1 May 2023 12:00:14 -0400 Subject: [PATCH 083/116] improve broadcast test --- tests/test_volume.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index 0b43860837..ae4fcdd1e8 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -233,25 +233,30 @@ def testProjectBroadcast(self): blob_z = gaussian_3d(L, sigma=(1, 2, 3), dtype=self.dtype) vols = Volume(np.vstack((blob_x, blob_y, blob_z)).reshape(3, L, L, L)) - # Create singleton and stacks of identity Rotations. - identity = np.eye(3, dtype=self.dtype) - eye = Rotation(identity) - eyes_2 = Rotation(np.vstack((identity,) * 2).reshape(2, 3, 3)) - eyes_3 = Rotation(np.vstack((identity,) * 3).reshape(3, 3, 3)) - - # Broadcast Volume stack with singleton Rotation. - ims_3_1 = vols.project(eye) + # Create a singleton and stack of Rotations. + rot = Rotation.about_axis("z", np.pi / 6, dtype=self.dtype) + rots = Rotation.about_axis( + "z", [np.pi / 4, np.pi / 3, np.pi / 2], dtype=self.dtype + ) - # Broadcast Volume stack with Rotation stack of same size. - ims_3_3 = vols.project(eyes_3) + # Broadcast Volume stack with singleton Rotation and compare with manual projection. + projs_3_1 = vols.project(rot) + vols_rot_3_1 = vols.rotate(rot) + manual_projs_3_1 = np.sum(vols_rot_3_1, axis=-1) / L + self.assertTrue(projs_3_1.shape[0] == 3) + self.assertTrue(np.allclose(projs_3_1, manual_projs_3_1)) - # These image stacks should be identical. - self.assertTrue(np.allclose(ims_3_1, ims_3_3)) + # Broadcast Volume stack with Rotation stack of same size and compare with manual projections. + projs_3_3 = vols.project(rots) + vols_rot_3_3 = vols.rotate(rots) + manual_projs_3_3 = np.sum(vols_rot_3_3, axis=-1) / L + self.assertTrue(projs_3_3.shape[0] == 3) + self.assertTrue(np.allclose(projs_3_3, manual_projs_3_3)) # Check we raise an error for incompatible stacks. msg = "Cannot broadcast with 2 Rotations and 3 Volumes." with raises(NotImplementedError, match=msg): - _ = vols.project(eyes_2) + _ = vols.project(rots[:2]) # Parameterize over even and odd resolutions @parameterized.expand([(res,), (res - 1,)]) From cef27eac21520deabe32e785e445125c5327a6ef Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 1 May 2023 14:18:11 -0400 Subject: [PATCH 084/116] test singles and doubles --- tests/test_volume.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index ae4fcdd1e8..623c656a5b 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -224,19 +224,20 @@ def testProject(self): imgs_clean = vols.project(rots).asnumpy() self.assertTrue(np.allclose(results, imgs_clean, atol=1e-7)) - def testProjectBroadcast(self): + @parameterized.expand([(np.float32,), (np.float64)]) + def testProjectBroadcast(self, dtype): L = 32 # Create stack of Volume with Gaussians stretched along varying axes. - blob_x = gaussian_3d(L, sigma=(3, 2, 1), dtype=self.dtype) - blob_y = gaussian_3d(L, sigma=(1, 3, 2), dtype=self.dtype) - blob_z = gaussian_3d(L, sigma=(1, 2, 3), dtype=self.dtype) + blob_x = gaussian_3d(L, sigma=(3, 2, 1), dtype=dtype) + blob_y = gaussian_3d(L, sigma=(1, 3, 2), dtype=dtype) + blob_z = gaussian_3d(L, sigma=(1, 2, 3), dtype=dtype) vols = Volume(np.vstack((blob_x, blob_y, blob_z)).reshape(3, L, L, L)) # Create a singleton and stack of Rotations. - rot = Rotation.about_axis("z", np.pi / 6, dtype=self.dtype) + rot = Rotation.about_axis("z", np.pi / 6, dtype=dtype) rots = Rotation.about_axis( - "z", [np.pi / 4, np.pi / 3, np.pi / 2], dtype=self.dtype + "z", [np.pi / 4, np.pi / 3, np.pi / 2], dtype=dtype ) # Broadcast Volume stack with singleton Rotation and compare with manual projection. From 2c12c771c284bb5e06ecf4bfed8cf04f42cb7145 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 1 May 2023 14:55:01 -0400 Subject: [PATCH 085/116] tox --- tests/test_volume.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index 623c656a5b..673a48101e 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -236,9 +236,7 @@ def testProjectBroadcast(self, dtype): # Create a singleton and stack of Rotations. rot = Rotation.about_axis("z", np.pi / 6, dtype=dtype) - rots = Rotation.about_axis( - "z", [np.pi / 4, np.pi / 3, np.pi / 2], dtype=dtype - ) + rots = Rotation.about_axis("z", [np.pi / 4, np.pi / 3, np.pi / 2], dtype=dtype) # Broadcast Volume stack with singleton Rotation and compare with manual projection. projs_3_1 = vols.project(rot) From e804051b23bc27b18ff67de41ddaf6642614be95 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 2 May 2023 08:23:39 -0400 Subject: [PATCH 086/116] missing comma --- tests/test_volume.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index 673a48101e..68efda5184 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -224,7 +224,7 @@ def testProject(self): imgs_clean = vols.project(rots).asnumpy() self.assertTrue(np.allclose(results, imgs_clean, atol=1e-7)) - @parameterized.expand([(np.float32,), (np.float64)]) + @parameterized.expand([(np.float32,), (np.float64,)]) def testProjectBroadcast(self, dtype): L = 32 From bddb6516aec2a828552cf428128827392817505a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 2 May 2023 08:58:25 -0400 Subject: [PATCH 087/116] utest_tolerance --- tests/test_volume.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index 68efda5184..5f54334733 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -243,14 +243,18 @@ def testProjectBroadcast(self, dtype): vols_rot_3_1 = vols.rotate(rot) manual_projs_3_1 = np.sum(vols_rot_3_1, axis=-1) / L self.assertTrue(projs_3_1.shape[0] == 3) - self.assertTrue(np.allclose(projs_3_1, manual_projs_3_1)) + self.assertTrue( + np.allclose(projs_3_1, manual_projs_3_1, atol=utest_tolerance(self.dtype)) + ) # Broadcast Volume stack with Rotation stack of same size and compare with manual projections. projs_3_3 = vols.project(rots) vols_rot_3_3 = vols.rotate(rots) manual_projs_3_3 = np.sum(vols_rot_3_3, axis=-1) / L self.assertTrue(projs_3_3.shape[0] == 3) - self.assertTrue(np.allclose(projs_3_3, manual_projs_3_3)) + self.assertTrue( + np.allclose(projs_3_3, manual_projs_3_3, atol=utest_tolerance(self.dtype)) + ) # Check we raise an error for incompatible stacks. msg = "Cannot broadcast with 2 Rotations and 3 Volumes." From 499aea03a6f48500c37a43aed0033ac40c17a512 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 8 May 2023 10:13:32 -0400 Subject: [PATCH 088/116] Remove duplicate logic. Expand_dim if singleton numpy array. --- src/aspire/volume/volume.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 945098c5ee..c312871976 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -248,6 +248,8 @@ def project(self, rot_matrices): # If we are an ASPIRE Rotation, get the numpy representation. if isinstance(rot_matrices, Rotation): rot_matrices = rot_matrices.matrices + elif rot_matrices.ndim == 2: + rot_matrices = np.expand_dims(rot_matrices, axis=0) if rot_matrices.dtype != self.dtype: logger.warning( @@ -259,24 +261,25 @@ def project(self, rot_matrices): data = self._data n_rots = rot_matrices.shape[0] - - if not ((n_rots == self.n_vols) or (n_rots == 1) or (self.n_vols == 1)): - raise NotImplementedError( - f"Cannot broadcast with {n_rots} Rotations and {self.n_vols} Volumes." - ) - pts_rot = rotated_grids(self.resolution, rot_matrices) if n_rots == self.n_vols: + # Apply rotations to Volumes element-wise. im_f = np.empty( (self.n_vols, self.resolution**2), dtype=complex_type(self.dtype) ) pts_rot = pts_rot.reshape((3, n_rots, self.resolution**2)) for i in range(self.n_vols): im_f[i] = nufft(data[i], pts_rot[:, i]) / self.resolution - else: + elif (n_rots == 1) or (self.n_vols == 1): + # Broadcast stack with singleton. pts_rot = pts_rot.reshape((3, n_rots * self.resolution**2)) im_f = nufft(data, pts_rot) / self.resolution + else: + # Currently not supporting broadcasting n Volumes with m rotations. + raise NotImplementedError( + f"Cannot broadcast with {n_rots} Rotations and {self.n_vols} Volumes." + ) im_f = im_f.reshape(-1, self.resolution, self.resolution) From abec28a95a92333a7256447e024161da3147ef67 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Mon, 8 May 2023 10:30:26 -0400 Subject: [PATCH 089/116] Check rot_matrices size --- src/aspire/volume/volume.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index c312871976..e0bfd0e894 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -248,7 +248,13 @@ def project(self, rot_matrices): # If we are an ASPIRE Rotation, get the numpy representation. if isinstance(rot_matrices, Rotation): rot_matrices = rot_matrices.matrices - elif rot_matrices.ndim == 2: + elif rot_matrices.shape[-2:] != (3, 3): + raise NotImplementedError( + f"`rot_matrices` must be a stack of 3x3 rotation matrices, found shape {rot_matrices.shape}." + ) + + # If singleton rotation array, expand to have stack format. + if rot_matrices.ndim == 2: rot_matrices = np.expand_dims(rot_matrices, axis=0) if rot_matrices.dtype != self.dtype: From cc92083952a2335d2fd77b910ad18caac109c9d8 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 9 May 2023 12:00:30 -0400 Subject: [PATCH 090/116] Use AsymmetricVolume in test. --- tests/test_volume.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index 5f54334733..15ce23baa1 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -9,10 +9,10 @@ from parameterized import parameterized from pytest import raises, skip -from aspire.utils import Rotation, gaussian_3d, powerset +from aspire.utils import Rotation, powerset from aspire.utils.matrix import anorm from aspire.utils.types import utest_tolerance -from aspire.volume import Volume +from aspire.volume import AsymmetricVolume, Volume DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") @@ -228,17 +228,14 @@ def testProject(self): def testProjectBroadcast(self, dtype): L = 32 - # Create stack of Volume with Gaussians stretched along varying axes. - blob_x = gaussian_3d(L, sigma=(3, 2, 1), dtype=dtype) - blob_y = gaussian_3d(L, sigma=(1, 3, 2), dtype=dtype) - blob_z = gaussian_3d(L, sigma=(1, 2, 3), dtype=dtype) - vols = Volume(np.vstack((blob_x, blob_y, blob_z)).reshape(3, L, L, L)) + # Create stack of 3 Volumes. + vols = AsymmetricVolume(L=L, C=3, dtype=dtype).generate() # Create a singleton and stack of Rotations. rot = Rotation.about_axis("z", np.pi / 6, dtype=dtype) rots = Rotation.about_axis("z", [np.pi / 4, np.pi / 3, np.pi / 2], dtype=dtype) - # Broadcast Volume stack with singleton Rotation and compare with manual projection. + # Broadcast Volume stack with singleton Rotation and compare with manually generated projection. projs_3_1 = vols.project(rot) vols_rot_3_1 = vols.rotate(rot) manual_projs_3_1 = np.sum(vols_rot_3_1, axis=-1) / L @@ -247,7 +244,7 @@ def testProjectBroadcast(self, dtype): np.allclose(projs_3_1, manual_projs_3_1, atol=utest_tolerance(self.dtype)) ) - # Broadcast Volume stack with Rotation stack of same size and compare with manual projections. + # Broadcast Volume stack with Rotation stack of same size and compare with manually generated projections. projs_3_3 = vols.project(rots) vols_rot_3_3 = vols.rotate(rots) manual_projs_3_3 = np.sum(vols_rot_3_3, axis=-1) / L @@ -256,7 +253,7 @@ def testProjectBroadcast(self, dtype): np.allclose(projs_3_3, manual_projs_3_3, atol=utest_tolerance(self.dtype)) ) - # Check we raise an error for incompatible stacks. + # Check we raise an error for incompatible stack sizes. msg = "Cannot broadcast with 2 Rotations and 3 Volumes." with raises(NotImplementedError, match=msg): _ = vols.project(rots[:2]) From 9b09af38a6a01ee7bc7139bc82ab07d04c7ed36a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 9 May 2023 12:12:10 -0400 Subject: [PATCH 091/116] Remove unnecessary checks and designate ndarray shape in docstring. --- src/aspire/volume/volume.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index e0bfd0e894..33b69c8da4 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -236,7 +236,7 @@ def project(self, rot_matrices): over a stack of volumes, a singleton Rotation or a Rotation with stack size self.n_vols must be used. - :param rot_matrices: Stack of rotations. Rotation or ndarray instance. + :param rot_matrices: Stack of rotations. Rotation or kx3x3 ndarray instance. :return: `Image` instance. """ # See Issue #727 @@ -248,14 +248,6 @@ def project(self, rot_matrices): # If we are an ASPIRE Rotation, get the numpy representation. if isinstance(rot_matrices, Rotation): rot_matrices = rot_matrices.matrices - elif rot_matrices.shape[-2:] != (3, 3): - raise NotImplementedError( - f"`rot_matrices` must be a stack of 3x3 rotation matrices, found shape {rot_matrices.shape}." - ) - - # If singleton rotation array, expand to have stack format. - if rot_matrices.ndim == 2: - rot_matrices = np.expand_dims(rot_matrices, axis=0) if rot_matrices.dtype != self.dtype: logger.warning( From ba2d74e37ef519fa3200adb0e34e7a356e5a2bb3 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 10 May 2023 09:24:21 -0400 Subject: [PATCH 092/116] use pytest parametrized --- tests/test_volume.py | 66 +++++++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index 15ce23baa1..437c29ce0d 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -5,6 +5,7 @@ from unittest import TestCase import numpy as np +import pytest from numpy import pi from parameterized import parameterized from pytest import raises, skip @@ -224,40 +225,6 @@ def testProject(self): imgs_clean = vols.project(rots).asnumpy() self.assertTrue(np.allclose(results, imgs_clean, atol=1e-7)) - @parameterized.expand([(np.float32,), (np.float64,)]) - def testProjectBroadcast(self, dtype): - L = 32 - - # Create stack of 3 Volumes. - vols = AsymmetricVolume(L=L, C=3, dtype=dtype).generate() - - # Create a singleton and stack of Rotations. - rot = Rotation.about_axis("z", np.pi / 6, dtype=dtype) - rots = Rotation.about_axis("z", [np.pi / 4, np.pi / 3, np.pi / 2], dtype=dtype) - - # Broadcast Volume stack with singleton Rotation and compare with manually generated projection. - projs_3_1 = vols.project(rot) - vols_rot_3_1 = vols.rotate(rot) - manual_projs_3_1 = np.sum(vols_rot_3_1, axis=-1) / L - self.assertTrue(projs_3_1.shape[0] == 3) - self.assertTrue( - np.allclose(projs_3_1, manual_projs_3_1, atol=utest_tolerance(self.dtype)) - ) - - # Broadcast Volume stack with Rotation stack of same size and compare with manually generated projections. - projs_3_3 = vols.project(rots) - vols_rot_3_3 = vols.rotate(rots) - manual_projs_3_3 = np.sum(vols_rot_3_3, axis=-1) / L - self.assertTrue(projs_3_3.shape[0] == 3) - self.assertTrue( - np.allclose(projs_3_3, manual_projs_3_3, atol=utest_tolerance(self.dtype)) - ) - - # Check we raise an error for incompatible stack sizes. - msg = "Cannot broadcast with 2 Rotations and 3 Volumes." - with raises(NotImplementedError, match=msg): - _ = vols.project(rots[:2]) - # Parameterize over even and odd resolutions @parameterized.expand([(res,), (res - 1,)]) def testRotate(self, L): @@ -510,3 +477,34 @@ def test_asnumpy_readonly(): # Attempt assignment with raises(ValueError, match=r".*destination is read-only.*"): vw[0, 0, 0, 0] = 123 + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def testProjectBroadcast(dtype): + L = 32 + + # Create stack of 3 Volumes. + vols = AsymmetricVolume(L=L, C=3, dtype=dtype).generate() + + # Create a singleton and stack of Rotations. + rot = Rotation.about_axis("z", np.pi / 6, dtype=dtype) + rots = Rotation.about_axis("z", [np.pi / 4, np.pi / 3, np.pi / 2], dtype=dtype) + + # Broadcast Volume stack with singleton Rotation and compare with manually generated projection. + projs_3_1 = vols.project(rot) + vols_rot_3_1 = vols.rotate(rot) + manual_projs_3_1 = np.sum(vols_rot_3_1, axis=-1) / L + assert projs_3_1.shape[0] == 3 + assert np.allclose(projs_3_1, manual_projs_3_1, atol=utest_tolerance(dtype)) + + # Broadcast Volume stack with Rotation stack of same size and compare with manually generated projections. + projs_3_3 = vols.project(rots) + vols_rot_3_3 = vols.rotate(rots) + manual_projs_3_3 = np.sum(vols_rot_3_3, axis=-1) / L + assert projs_3_3.shape[0] == 3 + assert np.allclose(projs_3_3, manual_projs_3_3, atol=utest_tolerance(dtype)) + + # Check we raise an error for incompatible stack sizes. + msg = "Cannot broadcast with 2 Rotations and 3 Volumes." + with raises(NotImplementedError, match=msg): + _ = vols.project(rots[:2]) From e9c6962ad25eddd8a1415d1e195aca253c9558e8 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 10 May 2023 09:54:25 -0400 Subject: [PATCH 093/116] asnumpy on images in np.allclose --- tests/test_volume.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index 437c29ce0d..7af8ff8175 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -495,14 +495,14 @@ def testProjectBroadcast(dtype): vols_rot_3_1 = vols.rotate(rot) manual_projs_3_1 = np.sum(vols_rot_3_1, axis=-1) / L assert projs_3_1.shape[0] == 3 - assert np.allclose(projs_3_1, manual_projs_3_1, atol=utest_tolerance(dtype)) + assert np.allclose(projs_3_1.asnumpy(), manual_projs_3_1, atol=utest_tolerance(dtype)) # Broadcast Volume stack with Rotation stack of same size and compare with manually generated projections. projs_3_3 = vols.project(rots) vols_rot_3_3 = vols.rotate(rots) manual_projs_3_3 = np.sum(vols_rot_3_3, axis=-1) / L assert projs_3_3.shape[0] == 3 - assert np.allclose(projs_3_3, manual_projs_3_3, atol=utest_tolerance(dtype)) + assert np.allclose(projs_3_3.asnumpy(), manual_projs_3_3, atol=utest_tolerance(dtype)) # Check we raise an error for incompatible stack sizes. msg = "Cannot broadcast with 2 Rotations and 3 Volumes." From 63c2c28aa93622a63680de2b02f4352c7e382cc1 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 10 May 2023 09:55:02 -0400 Subject: [PATCH 094/116] black --- tests/test_volume.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index 7af8ff8175..8b2e5e3272 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -495,14 +495,18 @@ def testProjectBroadcast(dtype): vols_rot_3_1 = vols.rotate(rot) manual_projs_3_1 = np.sum(vols_rot_3_1, axis=-1) / L assert projs_3_1.shape[0] == 3 - assert np.allclose(projs_3_1.asnumpy(), manual_projs_3_1, atol=utest_tolerance(dtype)) + assert np.allclose( + projs_3_1.asnumpy(), manual_projs_3_1, atol=utest_tolerance(dtype) + ) # Broadcast Volume stack with Rotation stack of same size and compare with manually generated projections. projs_3_3 = vols.project(rots) vols_rot_3_3 = vols.rotate(rots) manual_projs_3_3 = np.sum(vols_rot_3_3, axis=-1) / L assert projs_3_3.shape[0] == 3 - assert np.allclose(projs_3_3.asnumpy(), manual_projs_3_3, atol=utest_tolerance(dtype)) + assert np.allclose( + projs_3_3.asnumpy(), manual_projs_3_3, atol=utest_tolerance(dtype) + ) # Check we raise an error for incompatible stack sizes. msg = "Cannot broadcast with 2 Rotations and 3 Volumes." From 98ba24bdfb7b56a3241c635b14511225b25a8cce Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 11 May 2023 07:37:47 -0400 Subject: [PATCH 095/116] Try lower frequency volumes to pass on ampere. --- tests/test_volume.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index 8b2e5e3272..cdc0e62381 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -484,7 +484,7 @@ def testProjectBroadcast(dtype): L = 32 # Create stack of 3 Volumes. - vols = AsymmetricVolume(L=L, C=3, dtype=dtype).generate() + vols = AsymmetricVolume(L=L, C=3, K=16, dtype=dtype).generate() # Create a singleton and stack of Rotations. rot = Rotation.about_axis("z", np.pi / 6, dtype=dtype) From baada97521702a4528640ad3f32b7917ec57a53a Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 11 May 2023 10:52:50 -0400 Subject: [PATCH 096/116] Refactor test. Expand dims on singleton rots. --- src/aspire/volume/volume.py | 6 +++++- tests/test_volume.py | 31 +++++++++++++++---------------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index 33b69c8da4..ed9ab3acac 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -236,7 +236,7 @@ def project(self, rot_matrices): over a stack of volumes, a singleton Rotation or a Rotation with stack size self.n_vols must be used. - :param rot_matrices: Stack of rotations. Rotation or kx3x3 ndarray instance. + :param rot_matrices: Stack of rotations. Rotation or ndarray instance. :return: `Image` instance. """ # See Issue #727 @@ -257,6 +257,10 @@ def project(self, rot_matrices): " In the future this will raise an error." ) + # Handle singletons. `rotated_grids` expect shape kx3x3. + if rot_matrices.ndim == 2: + rot_matrices = np.expand_dims(rot_matrices, axis=0) + data = self._data n_rots = rot_matrices.shape[0] pts_rot = rotated_grids(self.resolution, rot_matrices) diff --git a/tests/test_volume.py b/tests/test_volume.py index cdc0e62381..6b8e549465 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -484,29 +484,28 @@ def testProjectBroadcast(dtype): L = 32 # Create stack of 3 Volumes. - vols = AsymmetricVolume(L=L, C=3, K=16, dtype=dtype).generate() + n_vols = 3 + vols = AsymmetricVolume(L=L, C=n_vols, dtype=dtype).generate() # Create a singleton and stack of Rotations. rot = Rotation.about_axis("z", np.pi / 6, dtype=dtype) rots = Rotation.about_axis("z", [np.pi / 4, np.pi / 3, np.pi / 2], dtype=dtype) - # Broadcast Volume stack with singleton Rotation and compare with manually generated projection. + # Broadcast Volume stack with singleton Rotation and compare with individually generated projections. projs_3_1 = vols.project(rot) - vols_rot_3_1 = vols.rotate(rot) - manual_projs_3_1 = np.sum(vols_rot_3_1, axis=-1) / L - assert projs_3_1.shape[0] == 3 - assert np.allclose( - projs_3_1.asnumpy(), manual_projs_3_1, atol=utest_tolerance(dtype) - ) - - # Broadcast Volume stack with Rotation stack of same size and compare with manually generated projections. + for i in range(n_vols): + proj_i = vols[i].project(rot) + assert np.allclose(projs_3_1[i], proj_i, atol=utest_tolerance(dtype)) + + assert projs_3_1.shape[0] == n_vols + + # Broadcast Volume stack with Rotation stack of same size and compare with individually generated projections. projs_3_3 = vols.project(rots) - vols_rot_3_3 = vols.rotate(rots) - manual_projs_3_3 = np.sum(vols_rot_3_3, axis=-1) / L - assert projs_3_3.shape[0] == 3 - assert np.allclose( - projs_3_3.asnumpy(), manual_projs_3_3, atol=utest_tolerance(dtype) - ) + for i in range(n_vols): + proj_i = vols[i].project(rots[i]) + assert np.allclose(projs_3_3[i], proj_i, atol=utest_tolerance(dtype)) + + assert projs_3_3.shape[0] == n_vols # Check we raise an error for incompatible stack sizes. msg = "Cannot broadcast with 2 Rotations and 3 Volumes." From 8390997b4932f31f62d98983a3ea72caf882afd9 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 11 May 2023 11:17:49 -0400 Subject: [PATCH 097/116] Use asnumpy in allclose --- tests/test_volume.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index 6b8e549465..c298863b3c 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -494,16 +494,16 @@ def testProjectBroadcast(dtype): # Broadcast Volume stack with singleton Rotation and compare with individually generated projections. projs_3_1 = vols.project(rot) for i in range(n_vols): - proj_i = vols[i].project(rot) - assert np.allclose(projs_3_1[i], proj_i, atol=utest_tolerance(dtype)) + proj_i = vols[i].project(rot).asnumpy() + assert np.allclose(projs_3_1[i].asnumpy(), proj_i, atol=utest_tolerance(dtype)) assert projs_3_1.shape[0] == n_vols # Broadcast Volume stack with Rotation stack of same size and compare with individually generated projections. projs_3_3 = vols.project(rots) for i in range(n_vols): - proj_i = vols[i].project(rots[i]) - assert np.allclose(projs_3_3[i], proj_i, atol=utest_tolerance(dtype)) + proj_i = vols[i].project(rots[i]).asnumpy() + assert np.allclose(projs_3_3[i].asnumpy(), proj_i, atol=utest_tolerance(dtype)) assert projs_3_3.shape[0] == n_vols From 8263708262ee701dae320ce83d6c83e7cb5789a0 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Thu, 11 May 2023 11:58:46 -0400 Subject: [PATCH 098/116] Add mask --- tests/test_volume.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index c298863b3c..bd8a753c77 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -10,7 +10,7 @@ from parameterized import parameterized from pytest import raises, skip -from aspire.utils import Rotation, powerset +from aspire.utils import Rotation, grid_2d, powerset from aspire.utils.matrix import anorm from aspire.utils.types import utest_tolerance from aspire.volume import AsymmetricVolume, Volume @@ -491,11 +491,18 @@ def testProjectBroadcast(dtype): rot = Rotation.about_axis("z", np.pi / 6, dtype=dtype) rots = Rotation.about_axis("z", [np.pi / 4, np.pi / 3, np.pi / 2], dtype=dtype) + # Test mask. + mask = grid_2d(L)["r"] < 1 + # Broadcast Volume stack with singleton Rotation and compare with individually generated projections. projs_3_1 = vols.project(rot) for i in range(n_vols): proj_i = vols[i].project(rot).asnumpy() - assert np.allclose(projs_3_1[i].asnumpy(), proj_i, atol=utest_tolerance(dtype)) + assert np.allclose( + projs_3_1[i].asnumpy()[0][mask], + proj_i[0][mask], + atol=utest_tolerance(dtype), + ) assert projs_3_1.shape[0] == n_vols @@ -503,7 +510,11 @@ def testProjectBroadcast(dtype): projs_3_3 = vols.project(rots) for i in range(n_vols): proj_i = vols[i].project(rots[i]).asnumpy() - assert np.allclose(projs_3_3[i].asnumpy(), proj_i, atol=utest_tolerance(dtype)) + assert np.allclose( + projs_3_3[i].asnumpy()[0][mask], + proj_i[0][mask], + atol=utest_tolerance(dtype), + ) assert projs_3_3.shape[0] == n_vols From d51bc57eb3ff84b9ebfd24ac7b1a274a042b2a61 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Fri, 12 May 2023 14:42:30 -0400 Subject: [PATCH 099/116] test cleanup --- tests/test_volume.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index bd8a753c77..531cfd7dcc 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -495,11 +495,11 @@ def testProjectBroadcast(dtype): mask = grid_2d(L)["r"] < 1 # Broadcast Volume stack with singleton Rotation and compare with individually generated projections. - projs_3_1 = vols.project(rot) + projs_3_1 = vols.project(rot).asnumpy() for i in range(n_vols): proj_i = vols[i].project(rot).asnumpy() assert np.allclose( - projs_3_1[i].asnumpy()[0][mask], + projs_3_1[i][mask], proj_i[0][mask], atol=utest_tolerance(dtype), ) @@ -507,11 +507,11 @@ def testProjectBroadcast(dtype): assert projs_3_1.shape[0] == n_vols # Broadcast Volume stack with Rotation stack of same size and compare with individually generated projections. - projs_3_3 = vols.project(rots) + projs_3_3 = vols.project(rots).asnumpy() for i in range(n_vols): proj_i = vols[i].project(rots[i]).asnumpy() assert np.allclose( - projs_3_3[i].asnumpy()[0][mask], + projs_3_3[i][mask], proj_i[0][mask], atol=utest_tolerance(dtype), ) From 0ed04b5843befc4e8b28c65833018119db71a840 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 16 May 2023 11:25:01 -0400 Subject: [PATCH 100/116] docstring clarification. --- src/aspire/volume/volume.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index ed9ab3acac..89490751c3 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -234,7 +234,9 @@ def project(self, rot_matrices): """ Using the stack of rot_matrices, project images of Volume. When projecting over a stack of volumes, a singleton Rotation or a Rotation with stack size - self.n_vols must be used. + self.n_vols must be used. In the case of a singleton Rotation, each Volume in + the stack will be projected using the single Rotation. In the case of a Volume stack + and a Rotation stack, the i'th Volume will be projected using the i'th Rotation. :param rot_matrices: Stack of rotations. Rotation or ndarray instance. :return: `Image` instance. From 2ad6fcb639de711dfacd52f7591d430441845330 Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Tue, 16 May 2023 11:29:20 -0400 Subject: [PATCH 101/116] comment about test case covered elsewhere. --- tests/test_volume.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_volume.py b/tests/test_volume.py index 531cfd7dcc..feee928200 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -518,6 +518,8 @@ def testProjectBroadcast(dtype): assert projs_3_3.shape[0] == n_vols + # Note: The test case for a single Volume and a stack of Rotations is covered above in testProject. + # Check we raise an error for incompatible stack sizes. msg = "Cannot broadcast with 2 Rotations and 3 Volumes." with raises(NotImplementedError, match=msg): From 4792eb2e0e402be57ab80d0260fa381aa44d4fa1 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 26 May 2023 09:32:20 -0400 Subject: [PATCH 102/116] Patch Starfile for change in gemmi 0.6.2 defaults --- src/aspire/storage/starfile.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/aspire/storage/starfile.py b/src/aspire/storage/starfile.py index 7f2c1fae28..19598c1e06 100644 --- a/src/aspire/storage/starfile.py +++ b/src/aspire/storage/starfile.py @@ -67,10 +67,12 @@ def _initialize_blocks(self): # populated if this block as a loop loop_tags = [] loop_data = [] - # correct for GEMMI default behavior - # if a block is called 'data_' in the .star file, GEMMI names it '#' - # but we want to name it '' for consistency - if gemmi_block.name == "#": + # Correct for GEMMI default behavior. + # If a block is called 'data_' in the .star file: + # gemmi>=0.6.2 names it ' ' + # gemmi<0.6.2 names it '#' + # Rename it '' consistently. + if gemmi_block.name in (" ", "#"): gemmi_block.name = "" for gemmi_item in gemmi_block: if gemmi_item.pair is not None: From 2cee0087b695a1f6abf5e11cb34577b96c22cc57 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 23 May 2023 09:20:13 -0400 Subject: [PATCH 103/116] Add wrapper to remap attributes inside IndexedSource --- src/aspire/source/image.py | 22 ++++++++++++++++++++++ tests/test_class_src.py | 10 ++++++++++ 2 files changed, 32 insertions(+) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 16c1d4fbd3..06650ac491 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -1309,6 +1309,13 @@ class IndexedSource(ImageSource): Map into another into ImageSource. """ + _indexed_attrs = [ + "selection_indices", + "class_indices", + "class_refl", + "class_distances", + ] + def __init__(self, src, indices, memory=None): """ Instantiates a new source along given `indices`. @@ -1350,6 +1357,21 @@ def __init__(self, src, indices, memory=None): # Any further operations should not mutate this instance. self._mutable = False + def __getattribute__(self, name): + """ + Overrides attribute getter to remap attributes listed in `_indexed_attrs`. + + :param name: Attribute name + """ + + # Avoid recursion + if name in super().__getattribute__("_indexed_attrs"): + # The attribute should be remapped from prior src + return getattr(self.src, name)[self.index_map] + + # Otherwise passthrough. + return super().__getattribute__(name) + def _images(self, indices): """ Returns images from `self.src` corresponding to `indices` diff --git a/tests/test_class_src.py b/tests/test_class_src.py index faa3460416..93d7eecdd0 100644 --- a/tests/test_class_src.py +++ b/tests/test_class_src.py @@ -142,6 +142,16 @@ class averages. np.linalg.norm((orig_imgs - test_imgs).asnumpy(), axis=(1, 2)), 0, atol=0.001 ) + # Check we can slice the source and retrieve remapped attributes + src2 = test_src[::3] + # Check we match selection between hidden and manual slice. + np.testing.assert_equal(src2.selection_indices, test_src.selection_indices[::3]) + # Check we match class indices between hidden and manual slice. + # Note that the class selection counts can be different under repulsion, + # so we will compare the subset that exists in both sources. + k = len(src2.class_indices) + np.testing.assert_equal(src2.class_indices, test_src.class_indices[::3][:k]) + # Test the _HeapItem helper class def test_heap_helper(): From 95d4ac3ae365e8499e7a1ed9d293a0ab594f439c Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 23 May 2023 09:51:18 -0400 Subject: [PATCH 104/116] close figure when plotting to file --- src/aspire/utils/resolution_estimation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 255e6e5cbb..73780e36be 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -360,6 +360,7 @@ def plot(self, cutoff, save_to_file=False, labels=None): if save_to_file: plt.savefig(save_to_file) + plt.close() else: plt.show() From 3e6f998910913e693da60d9cafafe9d7147a8190 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 23 May 2023 11:32:02 -0400 Subject: [PATCH 105/116] simplify install documentation --- README.md | 55 +++++----- docs/source/installation.rst | 195 +++++++++++++++-------------------- 2 files changed, 110 insertions(+), 140 deletions(-) diff --git a/README.md b/README.md index 6cd3797fb9..b030a9f730 100644 --- a/README.md +++ b/README.md @@ -26,49 +26,50 @@ ComputationalCryoEM/ASPIRE-Python: v0.11.0 https://doi.org/10.5281/zenodo.565728 ## Installation Instructions -For end-users -------------- +Getting Started +--------------- -ASPIRE is a pip-installable package that works on Linux/Mac/Windows, and requires Python 3.7-3.10. The recommended method of installation is to use Anaconda 64-bit for your platform to install Python 3.8 and `pip`, and then use `pip` to install `aspire` in that environment. +ASPIRE is a pip-installable package for Linux/Mac/Windows, and +requires Python 3.7-3.10. The recommended method of installation for +getting started is to use Anaconda (64-bit) for your platform to +install Python. Python's package manager `pip` can then be used to +install `aspire` safely in that environment. -``` -conda create -n aspire_env python=3.8 pip -conda activate aspire_env -pip install aspire -``` - -The final step above should install any dependent packages from `pip` automatically. To see what packages are required, browse `setup.py`. +If you are unfamiliar with `conda`, the +[Miniconda](https://docs.conda.io/en/latest/miniconda.html) +distribution is recommended. -Note that this step installs the base `aspire` package for you to work with, but not the unit tests/documentation. If you need to install ASPIRE for development purposes, read on. +Assuming you have `conda` and a compatible system, the following steps +will checkout current code release, create an environment, and install +ASPIRE. -For developers --------------- - -After cloning this repo, the simplest option is to use Anaconda 64-bit for your platform, and use the provided `environment.yml` file to build a Conda environment to run ASPIRE. This is very similar to above except you will be based off of your local checkout, and you are free to rename `aspire_dev` used in the commands below. The `pip` line will install aspire in a locally editable mode, and is equivalent to `python setup.py develop`: +Python 3.8 is used as an example, but the same procedure should work +for any of our supported Python versions. ``` -cd /path/to/git/clone/folder +# Clone the code +git clone https://github.com/ComputationalCryoEM/ASPIRE-Python.git +cd ASPIRE-Python -# Creates the conda environment and installs base dependencies. -conda env create -f environment-default.yml --name aspire_dev +# Create a fresh environment +conda create --name aspire python=3.8 pip # Enable the environment -conda activate aspire_dev +conda activate aspire -# Install the aspire package in a locally editable way, -# and additionally installs the developer tools extras: +# Install the `aspire` package from the checked out code, +# and additionally installs extra developer tools: pip install -e ".[dev]" - ``` -If you prefer not to use Anaconda, or want to manage environments yourself, you should be able to use `pip` with Python >= 3.7. -Please see the full documentation for details. +If you prefer not to use Anaconda, or have other preferences for managing environments, you should be able to directly use `pip` with Python >= 3.7 from the local checkout or via PyPI. +Please see the full documentation for details and advanced instructions. -### Make sure everything works +### Installation Testing -Once ASPIRE is installed, make sure the unit tests run correctly on your platform by doing: +To check the installation, a unit test suite is provided, +taking approximate 15 minutes on an average machine. ``` -cd /path/to/git/clone/folder pytest ``` diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 61d22be3ca..4d3cb0185f 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -1,9 +1,17 @@ Installation ============ -ASPIRE comes with an ``environment-default.yml`` for reproducing a working Conda environment based on Python 3.8 to run the package. -The package is tested on Linux/Windows/Mac OS X. Pre-built binaries are available for all platform-specific components. No manual -compilation should be needed. +The package is tested on Linux/Windows/Mac OS X. Pre-built binaries should be available for platform-specific dependencies. No manual compilation should be needed. + +For end users who simply want to use or run scripts depending on ASPIRE, simply installing the ``aspire`` package from PyPI is sufficient to use ASPIRE. + +.. note:: + Installing the package installs ASPIRE to the ``site-packages`` folder of your active environment. + This is only desirable if you are not going to be doing any development on ASPIRE, + but simply want to run scripts that depend on the ASPIRE package. + +For those who wish to develop, we recommend starting with the instructions on our README. Additionally some more advanced instructions are provided for installing with software and hardware optimizations. For developers and users not confident in software management, we strongly encourage the use of ``conda``. + Install Conda ************* @@ -16,95 +24,88 @@ distribution to view Conda's installation instructions. .. note:: If you're not sure which distribution is right for you, go with `Miniconda `__ -Install and Activate the environment +Getting Started ************************************ -For most end users, simply installing the package is sufficient to use ASPIRE. -The commands in this section should install ASPIRE directly from the ``Python Package Index`` into your activated environment. -This does not require checking out source code. -If you are interested in checking out and working with the source code, running tests, or a different flavor of install, -then skip to the next section now instead. - -Once ``conda`` is installed and available on the path, we can create a fresh ``conda`` environment. -Here we have chosen to name it ``aspire_env``, but you may choose any name you like so long as that name is used consistently in the following step. -After creating the environment, we activate it. -Finally, we install the ``aspire`` package inside the activated environment. This should install all supporting Python software required. - :: - conda create --name aspire_env python=3.8 pip - conda activate aspire_env - pip install aspire - -.. note:: - Installing the package installs ASPIRE to the ``site-packages`` folder of your active environment. - This is only desirable if you are not going to be doing any development on ASPIRE, - but simply want to run scripts that depend on the ASPIRE package. - + # Clone the code + git clone https://github.com/ComputationalCryoEM/ASPIRE-Python.git + cd ASPIRE-Python -Alternative Developer Installations -*********************************** + # Create a fresh environment + conda create --name aspire python=3.8 pip -Developers are expected to be able to manage their own code and environments. -However, for consistency and newcomers, we recommend the following procedure using `conda`. -Note that here the name ``aspire_dev`` was chosen, but you may choose any name for the ``conda`` environment. -Some people use different environment names for different features, -but this is personal preference and will largely depend on what type of changes you are making. -For example, if you are making changes to dependent package versions for testing, -you would probably want to keep that in a seperate environment. + # Enable the environment + conda activate aspire -:: + # Install the ``aspire`` package from the checked out code, + # and additionally installs extra developer tools: + pip install -e ".[dev]" - # Acquire the code. - git clone -b develop https://github.com/ComputationalCryoEM/ASPIRE-Python - cd ASPIRE-Python - # Create's the conda environment and installs base dependencies. - conda env create -f environment-default.yml --name aspire_dev +Test the package +**************** - # Activate the environment - conda activate aspire_dev +Make sure all unit tests run correctly by doing: - # Command to install the aspire package, along with developer extensions, in a locally editable way: - pip install -e ".[dev]" +:: -We recommend using ``conda`` or a ``virtualenv`` environment managing solutions because ASPIRE may have conflicts or change installed versions of Python packages on your system. + pytest -Again, we recommend the above for consistency. -However, ASPIRE is a ``pip`` package, -so you can attempt to install it using standard ``pip`` or ``setup.py`` commands. There are methods such as ``pip --no-deps`` that can leave your other packages undisturbed, but this is left to the developer. -ASPIRE should generally be compatible with newer version of Python, and newer dependent packages. We are currently testing 3.7, 3.8, 3.9, and 3.10 base Python as configured by ASPIRE, and with upgrading packages to the latest for each of those bases. -If you encounter an issue with a custom pip install, we will try to help, but you may be on your own for support of this method of installation. +Tests currently take around 15 minutes to run, but this depends on +your specific machine's resources and configuration. -:: +Optimized Numerical Backends +**************************** - # Standard pip site-packages installation command - cd path/to/aspire-repo - pip install . +For advanced users, ``conda`` provides optimized numerical backends +that offer significant performance improvements on appropriate +machines. The backends accelerate the performance of ``numpy``, +``scipy``, and ``scikit`` packages. ASPIRE ships several +``environment*.yml`` files which define tested package versions along +with these optimized numerical installations. - # Standard pip developer installation - cd path/to/aspire-repo - pip install -e . +The default ``environment-default.yml`` does not force a specific +backend, instead relying on ``conda`` to select something reasonable. +In the case of an Intel machine, the default ``conda`` install will +automatically install some optimizations for you. However, these +files can be used to specify a specific setup or as the basis for your +own customized ``conda`` environment. +.. list-table:: Suggested Conda Environments + :widths: 25 25 + :header-rows: 1 -Test the package -**************** + * - Architecture + - Recommended Environment File + * - Default + - environment-default.yml + * - Intel x86_64 + - environment-intel.yml + * - AMD x86_64 + - environment-openblas.yml + * - Apple M1 + - environment-accelerate.yml -Make sure all unit tests run correctly by doing: +Using any of these environments follows the same pattern outlined +below. As an example to specify using the ``accelerate`` backend on +an M1 laptop: :: - cd /path/to/git/clone/folder - pytest - -Tests currently take around 5 minutes to run, but this depends on your specific machine's resources. + cd ASPIRE-Python + conda env create -f environment-accelerate.yml --name aspire_acc + conda activate aspire_acc + pip install -e ".[dev]" Installing GPU Extensions ************************* -GPU extensions can be installed using pip. -Extensions are grouped based on CUDA versions. -To find the CUDA driver version, run ``nvidia-smi``. +ASPIRE does support GPUs, depending on several external packages. The +collection of GPU extensions can be installed using ``pip``. +Extensions are grouped based on CUDA versions. To find the CUDA +driver version, run ``nvidia-smi`` on the intended system. .. list-table:: CUDA GPU Extension Versions :widths: 25 25 @@ -132,62 +133,30 @@ the command below would install GPU packages required for ASPIRE. # From a local git repo pip install -e ".[gpu_11x]" -By default if GPU extensions are correctly installed, +By default if the required GPU extensions are correctly installed, ASPIRE should automatically begin using the GPU for select components (such as those using ``nufft``). +Because GPU extensions depend on several third party packages and +libraries, we can only offer limited support if one of the packages +has a problem on your system. + Generating Documentation ************************ -Sphinx Documentation of the source (a local copy of what you're looking at right now) can be generated using: +Sphinx Documentation of the source (a local copy of what you're +looking at right now) can be generated by using the following commands +from the root of the code repository. + +The ``make html`` command runs and renders the ``gallery/tutorials`` +examples, which takes several minutes. :: - cd /path/to/git/clone/folder/docs + cd docs sphinx-apidoc -f -o ./source ../src -H Modules make clean - make html - -The built html files can be found at ``/path/to/git/clone/folder/docs/build/html`` - - -Optimized Numerical Backends -**************************** - -Conda provides optimized numerical backends that can provide significant -performance improvements on appropriate machines. The backends accelerate -the performance of ``numpy``, ``scipy``, and ``scikit`` packages. -ASPIRE ships several ``environment*.yml`` files which define tested package -versions along with these optimized numerical installations. - -The default ``environment-default.yml`` does not force a specific backend, -instead relying on ``conda`` to select something reasonable. -In the case of an Intel machine, the default ``conda`` install -will automatically install some optimizations for you. -However, these files can be used to specify a specific setup -or as the basis for your own customized ``conda`` environment. + make html-noplot # Generate only documentation + make html # Generate documentation and gallery examples -.. list-table:: Suggested Conda Environments - :widths: 25 25 - :header-rows: 1 - - * - Architecture - - Recommended Environment File - * - Default - - environment-default.yml - * - Intel x86_64 - - environment-intel.yml - * - AMD x86_64 - - environment-openblas.yml - * - Apple M1 - - environment-accelerate.yml - -Using any of these environments follows the same pattern outlined above in the developer's section. -As an example to specify using the ``accelerate`` backend on an M1 laptop: - -:: - - cd ASPIRE-Python - conda env create -f environment-accelerate.yml --name aspire_acc - conda activate aspire_acc - pip install -e ".[dev]" +The resulting html files can be found at ``docs/build/html``. From 0cb1babc1971126ec6a728dacbecf9fd3cd0cffe Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 23 May 2023 13:59:22 -0400 Subject: [PATCH 106/116] second pass doc editing --- README.md | 7 ++---- docs/source/installation.rst | 49 +++++++++++++++++++++++------------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index b030a9f730..86e5aa340b 100644 --- a/README.md +++ b/README.md @@ -26,8 +26,8 @@ ComputationalCryoEM/ASPIRE-Python: v0.11.0 https://doi.org/10.5281/zenodo.565728 ## Installation Instructions -Getting Started ---------------- +Getting Started - Installation +------------------------------ ASPIRE is a pip-installable package for Linux/Mac/Windows, and requires Python 3.7-3.10. The recommended method of installation for @@ -43,9 +43,6 @@ Assuming you have `conda` and a compatible system, the following steps will checkout current code release, create an environment, and install ASPIRE. -Python 3.8 is used as an example, but the same procedure should work -for any of our supported Python versions. - ``` # Clone the code git clone https://github.com/ComputationalCryoEM/ASPIRE-Python.git diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 4d3cb0185f..d38817c7ac 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -1,32 +1,46 @@ Installation ============ -The package is tested on Linux/Windows/Mac OS X. Pre-built binaries should be available for platform-specific dependencies. No manual compilation should be needed. +This package is tested on Linux/Windows/Mac OS X. Pre-built binaries +should be available for platform-specific dependencies. No manual +compilation should be needed. -For end users who simply want to use or run scripts depending on ASPIRE, simply installing the ``aspire`` package from PyPI is sufficient to use ASPIRE. +For end users who simply want to use or run scripts depending on +ASPIRE, installing the ``aspire`` package from PyPI is sufficient. -.. note:: - Installing the package installs ASPIRE to the ``site-packages`` folder of your active environment. - This is only desirable if you are not going to be doing any development on ASPIRE, - but simply want to run scripts that depend on the ASPIRE package. +.. note:: Installing the package installs ASPIRE to the + ``site-packages`` folder of your active environment. This is only + desirable if you are not going to be doing any development on + ASPIRE, and only intend to run scripts that depend on the ASPIRE + package. -For those who wish to develop, we recommend starting with the instructions on our README. Additionally some more advanced instructions are provided for installing with software and hardware optimizations. For developers and users not confident in software management, we strongly encourage the use of ``conda``. +For those who wish to develop, we recommend starting with the +instructions on our README (copied below). Additionally some more +advanced instructions are provided here for installing with software +and hardware optimizations. Although not explicitly required, For +developers and users not confident in software management the use of +``conda`` is strongly encouraged. Install Conda ************* -To follow the suggested installation, you will need to install Conda for **Python3**, either -`Anaconda `__ or -`Miniconda `__, click on the right -distribution to view Conda's installation instructions. +To follow the suggested installation, you will need to install Conda +for **Python3**, either `Anaconda +`__ or `Miniconda +`__, click on the right distribution +to view Conda's installation instructions. -.. note:: - If you're not sure which distribution is right for you, go with `Miniconda `__ +.. note:: If you're not sure which distribution is right for you, go + with `Miniconda `__ -Getting Started +Getting Started - Installation ************************************ +Python 3.8 is used as an example, but the same procedure should work +for any of our supported Python versions 3.7-3.10. + + :: # Clone the code @@ -127,11 +141,12 @@ the command below would install GPU packages required for ASPIRE. :: - # From PyPI - pip install -e "aspire[gpu_11x]" - # From a local git repo pip install -e ".[gpu_11x]" + + # From PyPI + pip install "aspire[gpu_11x]" + By default if the required GPU extensions are correctly installed, ASPIRE should automatically begin using the GPU for select components From f7621b99046c0314453fb347cbe9e224dca705be Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 24 May 2023 09:47:38 -0400 Subject: [PATCH 107/116] remove the the third-person singular simple present indicative form of install. --- README.md | 2 +- docs/source/installation.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 86e5aa340b..f4e4810df2 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ conda create --name aspire python=3.8 pip conda activate aspire # Install the `aspire` package from the checked out code, -# and additionally installs extra developer tools: +# and additional developer tools: pip install -e ".[dev]" ``` diff --git a/docs/source/installation.rst b/docs/source/installation.rst index d38817c7ac..82104b7aa2 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -54,7 +54,7 @@ for any of our supported Python versions 3.7-3.10. conda activate aspire # Install the ``aspire`` package from the checked out code, - # and additionally installs extra developer tools: + # and additional developer tools: pip install -e ".[dev]" From 72f476c69d759a374c65824c227c7de7e508924e Mon Sep 17 00:00:00 2001 From: Josh Carmichael Date: Wed, 24 May 2023 10:54:36 -0400 Subject: [PATCH 108/116] Update installation.rst and README.md Co-authored-by: Garrett Wright <47759732+garrettwrong@users.noreply.github.com> --- README.md | 4 ++-- docs/source/installation.rst | 11 ++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index f4e4810df2..62ea4a0d8c 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,8 @@ conda create --name aspire python=3.8 pip # Enable the environment conda activate aspire -# Install the `aspire` package from the checked out code, -# and additional developer tools: +# Install the `aspire` package from the checked out code +# with the additional `dev` extension. pip install -e ".[dev]" ``` diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 82104b7aa2..70401b5db7 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -38,7 +38,12 @@ Getting Started - Installation ************************************ Python 3.8 is used as an example, but the same procedure should work -for any of our supported Python versions 3.7-3.10. +for any of our supported Python versions 3.7-3.10. Below we pip install +the ``aspire`` package using the ``-e`` flag to install the project in +editable mode. The ``".[dev]"`` command installs ``aspire`` from the local +path with additional development tools such as pytest and Jupyter Notebook. +See the `pip documentation `__ +for more details on using pip install. :: @@ -53,8 +58,8 @@ for any of our supported Python versions 3.7-3.10. # Enable the environment conda activate aspire - # Install the ``aspire`` package from the checked out code, - # and additional developer tools: + # Install the ``aspire`` package from the checked out code + # with the additional ``dev`` extension. pip install -e ".[dev]" From ffc9d25d73ab752c79a07b6701fa49144faf0fcd Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 25 May 2023 07:35:08 -0400 Subject: [PATCH 109/116] Clarify the doc building installation doc --- docs/source/installation.rst | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 70401b5db7..b4c3ed9c93 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -168,15 +168,24 @@ Sphinx Documentation of the source (a local copy of what you're looking at right now) can be generated by using the following commands from the root of the code repository. -The ``make html`` command runs and renders the ``gallery/tutorials`` -examples, which takes several minutes. +ASPIRE has both traditional documentation and a gallery of tutorial +scripts. To make only the documentation run ``make html-noplot``. +The ``make html`` command makes the traditonal documentation then runs +and renders the ``gallery/tutorials`` examples, which takes several +minutes. :: cd docs + + # Parse the code in ``src`` sphinx-apidoc -f -o ./source ../src -H Modules - make clean + make html-noplot # Generate only documentation + # or make html # Generate documentation and gallery examples + # To remove any documentation build artifacts + make distclean + The resulting html files can be found at ``docs/build/html``. From 58a0a495b4a832ea9a3283962fa2d473c94d8080 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 25 May 2023 08:38:11 -0400 Subject: [PATCH 110/116] Simplify code comment --- tests/test_class_src.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_class_src.py b/tests/test_class_src.py index 93d7eecdd0..7ab58d50f2 100644 --- a/tests/test_class_src.py +++ b/tests/test_class_src.py @@ -144,11 +144,9 @@ class averages. # Check we can slice the source and retrieve remapped attributes src2 = test_src[::3] - # Check we match selection between hidden and manual slice. + # Check we match selection between automatic and manual slice. np.testing.assert_equal(src2.selection_indices, test_src.selection_indices[::3]) - # Check we match class indices between hidden and manual slice. - # Note that the class selection counts can be different under repulsion, - # so we will compare the subset that exists in both sources. + # Check we match class indices between automatic and manual slice. k = len(src2.class_indices) np.testing.assert_equal(src2.class_indices, test_src.class_indices[::3][:k]) From 29181c7f0f9dfc6f06d22a08101a4e78d36fc11b Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 25 May 2023 16:02:45 -0400 Subject: [PATCH 111/116] Try to pack 2d class indices data into metadata table --- src/aspire/denoising/class_avg.py | 56 ++++++++++++++-------- src/aspire/source/image.py | 79 ++++++++++++++++++++++--------- 2 files changed, 92 insertions(+), 43 deletions(-) diff --git a/src/aspire/denoising/class_avg.py b/src/aspire/denoising/class_avg.py index 6f19685b55..716d9bfe77 100644 --- a/src/aspire/denoising/class_avg.py +++ b/src/aspire/denoising/class_avg.py @@ -64,12 +64,7 @@ def __init__( f"`averager` should be instance of `Averager2D`, found {self.averager}." ) - self._nn_classes = None - self._nn_reflections = None - self._nn_distances = None - # Flag for lazy eval, we'll classify once, on first touch. - # We could use self._nn_* vars, but might lose some flexibility later. self._classified = False self._selected = False @@ -114,16 +109,20 @@ def _classify(self): return ( - self._nn_classes, - self._nn_reflections, - self._nn_distances, + self.class_indices, + self.class_refl, + self.class_distances, ) = self.classifier.classify() self._classified = True @property def selection_indices(self): self._class_select() - return self._selection_indices + return super().selection_indices + + @selection_indices.setter + def selection_indices(self, value): + self.set_metadata(["selection_indices"], value) @property def class_indices(self): @@ -141,7 +140,11 @@ def class_indices(self): :return: Numpy array, integers. """ self._classify() - return self._nn_classes + return super().class_indices + + @class_indices.setter + def class_indices(self, table): + self.set_metadata(["class_indices"], [",".join(map(str, row)) for row in table]) @property def class_refl(self): @@ -158,7 +161,11 @@ def class_refl(self): :return: Numpy array, boolean. """ self._classify() - return self._nn_reflections + return super().class_refl + + @class_refl.setter + def class_refl(self, table): + self.set_metadata(["class_refl"], [",".join(map(str, row)) for row in table]) @property def class_distances(self): @@ -176,7 +183,13 @@ def class_distances(self): :return: Numpy array, self.dtype. """ self._classify() - return self._nn_distances + return super().class_distances + + @class_distances.setter + def class_distances(self, table): + self.set_metadata( + ["class_distances"], [",".join(map(str, row)) for row in table] + ) def _class_select(self): """ @@ -197,22 +210,23 @@ def _class_select(self): self._classify() # Perform class selection - self._selection_indices = self.class_selector.select( - self._nn_classes, - self._nn_reflections, - self._nn_distances, + _selection_indices = self.class_selector.select( + self.class_indices, + self.class_refl, + self.class_distances, ) # Override the initial self.n # Some selectors will (dramatically) reduce the space of classes. - if len(self._selection_indices) != self.n: + if len(_selection_indices) != self.n: logger.info( - f"After selection process, updating maximum {len(self._selection_indices)} classes from {self.n}." + f"After selection process, updating maximum {len(_selection_indices)} classes from {self.n}." ) # Note, setter should be inherited from base ImageSource. - self.n = len(self._selection_indices) + self.n = len(_selection_indices) self._selected = True + self.selection_indices = _selection_indices def _images(self, indices): """ @@ -296,9 +310,9 @@ def _images(self, indices): else: # Perform image averaging for the requested images (classes) - logger.debug(f"Averaging {len(indices)} images from source") + logger.debug(f"Averaging {len(_indices)} images from source") im = self.averager.average( - self._nn_classes[indices], self._nn_reflections[indices] + self.class_indices[_indices], self.class_refl[_indices] ) # Finally, apply transforms to resulting Images diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 06650ac491..bc9e2e6e18 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -467,6 +467,63 @@ def rotations(self, values): np.rad2deg(self._rotations.angles), ) + @property + def class_indices(self): + """ + Returns table of class image indices as `(src.n, n_nbors)` + Numpy array. + + Each row reprsents a class, with the columns ordered by + smallest `class_distances` from the reference image (zeroth + columm). + + Note `n_nbors` is managed by `self.classifier` and used here + for documentation. + + :return: Numpy array, integers. + """ + res = self.get_metadata(["class_indices"]) + return np.vstack([np.array(row.split(","), dtype=int) for row in res]) + + @property + def selection_indices(self): + return self.get_metadata(["selection_indices"]) + + @property + def class_refl(self): + """ + Returns table of class image reflections as `(src.n, n_nbors)` + Numpy array. + + Follows same layout as `class_indices` but holds booleans that + are True when the image should be reflected before averaging. + + Note `n_nbors` is managed by `self.classifier` and used here + for documentation. + + :return: Numpy array, boolean. + """ + res = self.get_metadata(["class_refl"]) + return np.vstack([np.array(row.split(","), dtype=bool) for row in res]) + + @property + def class_distances(self): + """ + Returns table of class image distances as `(src.n, n_nbors)` + Numpy array. + + Follows same layout as `class_indices` but holds floats + representing the distance (returned by classifier) to the + zeroth image in each class. + + Note `n_nbors` is managed by `self.classifier` and used here + for documentation. + + :return: Numpy array, self.dtype. + """ + res = self.get_metadata(["class_distances"]) + return np.vstack([np.array(row.split(","), dtype=self.dtype) for row in res]) + def set_metadata(self, metadata_fields, values, indices=None): """ Modify metadata field information of this ImageSource for selected indices @@ -1309,13 +1366,6 @@ class IndexedSource(ImageSource): Map into another into ImageSource. """ - _indexed_attrs = [ - "selection_indices", - "class_indices", - "class_refl", - "class_distances", - ] - def __init__(self, src, indices, memory=None): """ Instantiates a new source along given `indices`. @@ -1357,21 +1407,6 @@ def __init__(self, src, indices, memory=None): # Any further operations should not mutate this instance. self._mutable = False - def __getattribute__(self, name): - """ - Overrides attribute getter to remap attributes listed in `_indexed_attrs`. - - :param name: Attribute name - """ - - # Avoid recursion - if name in super().__getattribute__("_indexed_attrs"): - # The attribute should be remapped from prior src - return getattr(self.src, name)[self.index_map] - - # Otherwise passthrough. - return super().__getattribute__(name) - def _images(self, indices): """ Returns images from `self.src` corresponding to `indices` From b7d7a95605bf7bd4987cab1e69f7a9cd38e4288f Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 30 May 2023 14:38:39 -0400 Subject: [PATCH 112/116] Don't reuse var name, and fix Bool() refl get bug --- src/aspire/denoising/class_avg.py | 6 +++--- src/aspire/source/image.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/aspire/denoising/class_avg.py b/src/aspire/denoising/class_avg.py index 716d9bfe77..47a4f098d4 100644 --- a/src/aspire/denoising/class_avg.py +++ b/src/aspire/denoising/class_avg.py @@ -290,10 +290,10 @@ def _images(self, indices): # Recursively call `_images`. # `heap_inds` set should be empty in the recursive call, # and compute only remaining images (those not in heap). - _indices = list(indices_to_compute.keys()) + _compute_indices = list(indices_to_compute.keys()) # Skip when empty (everything requested in heap). - if len(_indices): - _imgs = self._images(_indices) + if len(_compute_indices): + _imgs = self._images(_compute_indices) # Pack images computed from `_images` recursive call. _inds = list(indices_to_compute.values()) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index bc9e2e6e18..b5ea32a736 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -504,7 +504,9 @@ def class_refl(self): :return: Numpy array, boolean. """ res = self.get_metadata(["class_refl"]) - return np.vstack([np.array(row.split(","), dtype=bool) for row in res]) + return np.vstack( + [np.array(row.split(",")) == "True" for row in res], dtype=bool + ) @property def class_distances(self): From 598be5ff26db9ac9a21d81e7a3d8fb83be5e0316 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Tue, 30 May 2023 16:13:01 -0400 Subject: [PATCH 113/116] try more portable bool cast --- src/aspire/source/image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index b5ea32a736..75164121d4 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -504,8 +504,8 @@ def class_refl(self): :return: Numpy array, boolean. """ res = self.get_metadata(["class_refl"]) - return np.vstack( - [np.array(row.split(",")) == "True" for row in res], dtype=bool + return np.vstack([np.array(row.split(",")) == "True" for row in res]).astype( + bool ) @property From 78da5835ad0c6e68593b42b429a4989c6ed60000 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Wed, 31 May 2023 08:55:06 -0400 Subject: [PATCH 114/116] prefer integers over True/False --- src/aspire/denoising/class_avg.py | 6 +++++- src/aspire/source/image.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/aspire/denoising/class_avg.py b/src/aspire/denoising/class_avg.py index 47a4f098d4..68b362608b 100644 --- a/src/aspire/denoising/class_avg.py +++ b/src/aspire/denoising/class_avg.py @@ -165,7 +165,11 @@ def class_refl(self): @class_refl.setter def class_refl(self, table): - self.set_metadata(["class_refl"], [",".join(map(str, row)) for row in table]) + # Convert boolean to (O, 1) integers. + array_int = np.array(table, dtype=int) + self.set_metadata( + ["class_refl"], [",".join(map(str, row)) for row in array_int] + ) @property def class_distances(self): diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 75164121d4..fe721bb437 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -504,7 +504,8 @@ def class_refl(self): :return: Numpy array, boolean. """ res = self.get_metadata(["class_refl"]) - return np.vstack([np.array(row.split(",")) == "True" for row in res]).astype( + # Read table of (0, 1) integers, cast to `bool`. + return np.vstack([np.array(row.split(","), dtype=int) for row in res]).astype( bool ) From 5bec1637a2fbffd28ddd83bc400aaa97d959288d Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 1 Jun 2023 12:16:41 -0400 Subject: [PATCH 115/116] Add starfile metadata save for new ClassAverageSource attributes --- src/aspire/denoising/class_avg.py | 36 +++++++++++++++++++++++++----- src/aspire/source/image.py | 10 ++++----- tests/test_class_src.py | 37 +++++++++++++++++++++++++++---- 3 files changed, 69 insertions(+), 14 deletions(-) diff --git a/src/aspire/denoising/class_avg.py b/src/aspire/denoising/class_avg.py index 68b362608b..d297841156 100644 --- a/src/aspire/denoising/class_avg.py +++ b/src/aspire/denoising/class_avg.py @@ -122,7 +122,7 @@ def selection_indices(self): @selection_indices.setter def selection_indices(self, value): - self.set_metadata(["selection_indices"], value) + self.set_metadata(["_selection_indices"], value) @property def class_indices(self): @@ -144,7 +144,9 @@ def class_indices(self): @class_indices.setter def class_indices(self, table): - self.set_metadata(["class_indices"], [",".join(map(str, row)) for row in table]) + self.set_metadata( + ["_class_indices"], [",".join(map(str, row)) for row in table] + ) @property def class_refl(self): @@ -168,7 +170,7 @@ def class_refl(self, table): # Convert boolean to (O, 1) integers. array_int = np.array(table, dtype=int) self.set_metadata( - ["class_refl"], [",".join(map(str, row)) for row in array_int] + ["_class_refl"], [",".join(map(str, row)) for row in array_int] ) @property @@ -192,7 +194,7 @@ def class_distances(self): @class_distances.setter def class_distances(self, table): self.set_metadata( - ["class_distances"], [",".join(map(str, row)) for row in table] + ["_class_distances"], [",".join(map(str, row)) for row in table] ) def _class_select(self): @@ -232,6 +234,18 @@ def _class_select(self): self._selected = True self.selection_indices = _selection_indices + def save(self, *args, **kwargs): + """ + Save metadata to STAR file. + + See `ImageSource.save` for documentation. + """ + # Evaluate any lazy actions. + # This should populate relevant metadata. + self._class_select() + # Call parent `save` method. + return super().save(*args, **kwargs) + def _images(self, indices): """ Output images @@ -241,8 +255,20 @@ def _images(self, indices): if not self._selected: self._class_select() + # Truncate the request if nessecary, + # ie, when selection reduces `self.n`. + indices = np.array(indices, dtype=int) + selected = indices[indices < self.n] + + if len(indices) != len(selected): + deselected = indices[indices >= self.n] + logger.debug( + f"Dropping requested indices {deselected} following to selection process." + ) + indices = selected + # Remap to the selected ordering - _indices = indices.copy() # Store original request + _indices = indices.copy() # Store original mapping indices = self.selection_indices[indices] # Check if there is a cache available from class selection component. diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index fe721bb437..87d5e26ed0 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -482,12 +482,12 @@ def class_indices(self): :return: Numpy array, integers. """ - res = self.get_metadata(["class_indices"]) + res = self.get_metadata(["_class_indices"]) return np.vstack([np.array(row.split(","), dtype=int) for row in res]) @property def selection_indices(self): - return self.get_metadata(["selection_indices"]) + return self.get_metadata(["_selection_indices"]) @property def class_refl(self): @@ -503,7 +503,7 @@ def class_refl(self): :return: Numpy array, boolean. """ - res = self.get_metadata(["class_refl"]) + res = self.get_metadata(["_class_refl"]) # Read table of (0, 1) integers, cast to `bool`. return np.vstack([np.array(row.split(","), dtype=int) for row in res]).astype( bool @@ -524,7 +524,7 @@ def class_distances(self): :return: Numpy array, self.dtype. """ - res = self.get_metadata(["class_distances"]) + res = self.get_metadata(["_class_distances"]) return np.vstack([np.array(row.split(","), dtype=self.dtype) for row in res]) def set_metadata(self, metadata_fields, values, indices=None): @@ -930,7 +930,7 @@ def save( overwrite=False, ): """ - Save the output metadata to STAR file and/or images to MRCS file + Save the output metadata to STAR file and/or images to MRCS file. :param starfile_filepath: Path to STAR file where we want to save metadata of image_source diff --git a/tests/test_class_src.py b/tests/test_class_src.py index 7ab58d50f2..6e1b9ca903 100644 --- a/tests/test_class_src.py +++ b/tests/test_class_src.py @@ -1,5 +1,6 @@ import logging import os +import tempfile from heapq import heappush, heappushpop from itertools import product, repeat @@ -25,7 +26,7 @@ from aspire.classification.class_selection import _HeapItem from aspire.denoising import DebugClassAvgSource, DefaultClassAvgSource from aspire.image import Image -from aspire.source import Simulation +from aspire.source import RelionSource, Simulation from aspire.utils import Rotation from aspire.volume import Volume @@ -116,9 +117,11 @@ def class_sim_fixture(dtype, img_size): return src -@pytest.mark.parametrize( - "test_src_cls", CLS_SRCS, ids=lambda param: f"ClassSource={param.__class__}" -) +@pytest.fixture(params=CLS_SRCS, ids=lambda param: f"ClassSource={param.__class__}") +def test_src_cls(request): + return request.param + + def test_basic_averaging(class_sim_fixture, test_src_cls): """ Test that the default `ClassAvgSource` implementations return @@ -272,3 +275,29 @@ def test_contrast_selector(dtype): # Compare indices and scores. assert np.all(selection == ref_class_ids) assert np.allclose(selector._quality_scores, ref_scores) + + +def test_avg_src_starfileio(class_sim_fixture, test_src_cls): + src = test_src_cls(src=class_sim_fixture, num_procs=NUM_PROCS) + + # Save and load the source as a STAR file. + # Saving should force classification and selection to occur, + # and the attributes will be checked below. + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "test.star") + + # Save + src.save(path, overwrite=True) + + # Load + saved_src = RelionSource(path) + + # Get entire metadata table + a = src.get_metadata(as_dict=True) + b = saved_src.get_metadata(as_dict=True) + + # Ensuring src has attributes following classification and selection. + for attr in ("_class_indices", "_class_refl", "_class_distances"): + assert attr in a.keys(), f"Attribute {attr} not in test Source." + assert attr in b.keys(), f"Attribute {attr} not in Source read from disk." + assert all(a[attr] == b[attr]), f"Attribute {attr} does not match." From 077b68ef2ef6d5d526d69cfafdf2dadea08e444b Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Fri, 2 Jun 2023 15:01:02 -0400 Subject: [PATCH 116/116] =?UTF-8?q?Bump=20version:=200.11.0=20=E2=86=92=20?= =?UTF-8?q?0.11.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- README.md | 4 ++-- docs/source/conf.py | 2 +- docs/source/index.rst | 2 +- setup.py | 2 +- src/aspire/__init__.py | 2 +- src/aspire/config_default.yaml | 2 +- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index abcd3f4309..b3131a6a39 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.11.0 +current_version = 0.11.1 commit = True tag = True diff --git a/README.md b/README.md index 62ea4a0d8c..b81a46b200 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![codecov](https://codecov.io/gh/ComputationalCryoEM/ASPIRE-Python/branch/master/graph/badge.svg?token=3XFC4VONX0)](https://codecov.io/gh/ComputationalCryoEM/ASPIRE-Python) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5657281.svg)](https://doi.org/10.5281/zenodo.5657281) -# ASPIRE - Algorithms for Single Particle Reconstruction - v0.11.0 +# ASPIRE - Algorithms for Single Particle Reconstruction - v0.11.1 This is the Python version to supersede the [Matlab ASPIRE](https://github.com/PrincetonUniversity/aspire). @@ -20,7 +20,7 @@ For more information about the project, algorithms, and related publications ple Please cite using the following DOI. This DOI represents all versions, and will always resolve to the latest one. ``` -ComputationalCryoEM/ASPIRE-Python: v0.11.0 https://doi.org/10.5281/zenodo.5657281 +ComputationalCryoEM/ASPIRE-Python: v0.11.1 https://doi.org/10.5281/zenodo.5657281 ``` diff --git a/docs/source/conf.py b/docs/source/conf.py index ca036350b3..8094849932 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -86,7 +86,7 @@ # built documents. # # The full version, including alpha/beta/rc tags. -release = version = "0.11.0" +release = version = "0.11.1" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/source/index.rst b/docs/source/index.rst index 6b40838066..74117120f3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,4 +1,4 @@ -Aspire v0.11.0 +Aspire v0.11.1 ============== Algorithms for Single Particle Reconstruction diff --git a/setup.py b/setup.py index 8d0e1f7a15..cf7db0bf34 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ def read(fname): setup( name="aspire", - version="0.11.0", + version="0.11.1", data_files=[ ("", ["src/aspire/config_default.yaml"]), ("", ["src/aspire/logging.conf"]), diff --git a/src/aspire/__init__.py b/src/aspire/__init__.py index 5fd1b36ca0..2fd907cbfe 100644 --- a/src/aspire/__init__.py +++ b/src/aspire/__init__.py @@ -11,7 +11,7 @@ from aspire.exceptions import handle_exception # version in maj.min.bld format -__version__ = "0.11.0" +__version__ = "0.11.1" # Setup `confuse` config diff --git a/src/aspire/config_default.yaml b/src/aspire/config_default.yaml index 0faadb95e2..f622ba5162 100644 --- a/src/aspire/config_default.yaml +++ b/src/aspire/config_default.yaml @@ -1,4 +1,4 @@ -version: 0.11.0 +version: 0.11.1 common: # numeric module to use - one of numpy/cupy numeric: numpy