Skip to content

Commit

Permalink
Merge pull request #695 from ComputationalCryoEM/batched_cls_avg
Browse files Browse the repository at this point in the history
Batched cls avg
  • Loading branch information
garrettwrong authored Oct 26, 2022
2 parents 3effbbe + d4fa8ea commit 9571203
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 121 deletions.
207 changes: 107 additions & 100 deletions src/aspire/basis/fspca.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import copy
import logging
from collections import OrderedDict

import numpy as np

from aspire.basis import FFBBasis2D, SteerableBasis2D
from aspire.covariance import RotCov2D
from aspire.covariance import BatchedRotCov2D
from aspire.operators import BlkDiagMatrix
from aspire.utils import complex_type, fix_signs, real_type

Expand Down Expand Up @@ -45,6 +44,7 @@ def __init__(
Default value of `None` will estimate noise with WhiteNoiseEstimator.
Use 0 when using clean images so cov2d skips applying noisy covar coeffs.
:param batch_size: Batch size for computing basis coefficients.
`batch_size` is also passed to BatchedRotCov2D.
"""

self.src = src
Expand Down Expand Up @@ -135,21 +135,16 @@ def build(self):
This may take some time for large image stacks.
"""

coef = np.empty((self.src.n, self.basis.count), dtype=self.dtype)
num_batches = (self.src.n + self.batch_size - 1) // self.batch_size
for i in range(num_batches):
start = i * self.batch_size
finish = min((i + 1) * self.batch_size, self.src.n)
coef[start:finish] = self.basis.evaluate_t(self.src.images[start:finish])

if self.noise_var is None:
from aspire.noise import WhiteNoiseEstimator

logger.info("Estimating the noise of images.")
self.noise_var = WhiteNoiseEstimator(self.src).estimate()
logger.info(f"Setting noise_var={self.noise_var}")

cov2d = RotCov2D(self.basis)
cov2d = BatchedRotCov2D(
src=self.src, basis=self.basis, batch_size=self.batch_size
)
covar_opt = {
"shrinker": "frobenius_norm",
"verbose": 0,
Expand All @@ -160,9 +155,8 @@ def build(self):
"precision": "float64",
"preconditioner": "identity",
}
self.mean_coef_est = cov2d.get_mean(coef)
self.mean_coef_est = cov2d.get_mean()
self.covar_coef_est = cov2d.get_covar(
coef,
mean_coeff=self.mean_coef_est,
noise_var=self.noise_var,
covar_est_opt=covar_opt,
Expand All @@ -173,68 +167,26 @@ def build(self):

self.eigvecs = BlkDiagMatrix.empty(2 * self.basis.ell_max + 1, dtype=self.dtype)

self.spca_coef = np.zeros((self.src.n, self.basis.count), dtype=self.dtype)
# Perform the PCA over batches, storing the compressed coefficients.
self._compute_spca()

self._compute_spca(coef)
# Complete compression by mutating class
self._compress()

self._compress(self.components)

def _compute_spca(self, coef):
def _compute_spca(self):
"""
Algorithm 2 from paper.
It has been adapted to use ASPIRE-Python's
cov2d (real) covariance estimation.
"""

# Compute coefficient vector of mean image at zeroth component
self.mean_coef_zero = self.mean_coef_est[self.angular_indices == 0]

# Make the Data matrix (A_k)
# # Construct A_k, matrix of expansion coefficients a^i_k_q
# # for image i, angular index k, radial index q,
# # (around eq 31-33)
# # Rows radial indices, columns image i.
# #
# # We can extract this directly (up to transpose) from
# # fb coef matrix where ells == angular_index
# # then use the transpose so image stack becomes columns.

# Initialize a totally empty BlkDiagMatrix, then build incrementally.
A = BlkDiagMatrix.empty(0, dtype=coef.dtype)

# Zero angular index is special case of indexing.
mask = self.basis._indices["ells"] == 0
A_0 = coef[:, mask] - self.mean_coef_zero
A.append(A_0)

# Remaining angular indices have positive and negative entries in real representation.
for ell in range(
1, self.basis.ell_max + 1
): # `ell` in this code is `k` from paper
mask = self.basis._indices["ells"] == ell
mask_pos = [
mask[i] and (self.basis._indices["sgns"][i] == +1)
for i in range(len(mask))
]
mask_neg = [
mask[i] and (self.basis._indices["sgns"][i] == -1)
for i in range(len(mask))
]

A.append(coef[:, mask_pos])
A.append(coef[:, mask_neg])

if len(A) != len(self.covar_coef_est):
raise RuntimeError(
"Data matrix A should have same number of blocks as Covar matrix.",
f" {len(A)} != {len(self.covar_coef_est)}",
)

# -- Compute the spectrum blockwise. --
# For each angular frequency (`ells` in FB code, `k` from paper)
# we use the properties of Block Diagonal Matrices to work
# on the corresponding block.
eigval_index = 0
basis_inds = []
for angular_index, C_k in enumerate(self.covar_coef_est):

# # Eigen/SVD, covariance block C_k should be symmetric.
Expand All @@ -251,20 +203,15 @@ def _compute_spca(self, coef):
)

# These are the dense basis indices for this block.
basis_inds = np.arange(eigval_index, eigval_index + len(eigvals_k))
_basis_inds = np.arange(eigval_index, eigval_index + len(eigvals_k))
basis_inds.append(_basis_inds)

# Store the eigvals for this block, note this is a flat array.
self.eigvals[basis_inds] = eigvals_k
self.eigvals[_basis_inds] = eigvals_k

# Store the eigvecs, note this is a BlkDiagMatrix and is assigned incrementally.
self.eigvecs[angular_index] = eigvecs_k

# To compute new expansion coefficients using spca basis
# we combine the basis coefs using the eigen decomposition.
# Note image stack slow moving axis, otherwise this is just a
# block by block matrix multiply.
self.spca_coef[:, basis_inds] = A[angular_index] @ eigvecs_k

eigval_index += len(eigvals_k)

# Sanity check we have same dimension of eigvals and basis coefs.
Expand All @@ -282,6 +229,80 @@ def _compute_spca(self, coef):
# the coefs. This is used later for compression and index re-generation.
self.sorted_indices = np.argsort(-np.abs(self.eigvals))

compressed_indices = self._get_compressed_indices()

self.spca_coef = np.zeros(
(self.src.n, len(compressed_indices)), dtype=self.dtype
)

# Compute coefficient vector of mean image at zeroth component
self.mean_coef_zero = self.mean_coef_est[self.angular_indices == 0]

# Define mask for zero angular mode, used in loop below
zero_ell_mask = self.basis._indices["ells"] == 0

# Apply Data matrix batchwise
num_batches = (self.src.n + self.batch_size - 1) // self.batch_size
for i in range(num_batches):

# Compute the coefficients for this batch
start = i * self.batch_size
finish = min((i + 1) * self.batch_size, self.src.n)
batch_coef = self.basis.evaluate_t(self.src.images[start:finish])

# Make the Data matrix (A_k)
# # Construct A_k, matrix of expansion coefficients a^i_k_q
# # for image i, angular index k, radial index q,
# # (around eq 31-33)
# # Rows radial indices, columns image i.
# #
# # We can extract this directly (up to transpose) from
# # fb coef matrix where ells == angular_index
# # then use the transpose so image stack becomes columns.

# Initialize a totally empty BlkDiagMatrix, then build incrementally.
A = BlkDiagMatrix.empty(0, dtype=batch_coef.dtype)

# Zero angular index is special case of indexing.
A_0 = batch_coef[:, zero_ell_mask] - self.mean_coef_zero
A.append(A_0)

# Remaining angular indices have positive and negative entries in real representation.
for ell in range(
1, self.basis.ell_max + 1
): # `ell` in this code is `k` from paper
mask_ell = self.basis._indices["ells"] == ell
mask_pos = mask_ell & (self.basis._indices["sgns"] == +1)
mask_neg = mask_ell & (self.basis._indices["sgns"] == -1)

A.append(batch_coef[:, mask_pos])
A.append(batch_coef[:, mask_neg])

if len(A) != len(self.covar_coef_est):
raise RuntimeError(
"Data matrix A should have same number of blocks as Covar matrix.",
f" {len(A)} != {len(self.covar_coef_est)}",
)

# -- Compute new FSPCA coefficients. --
# For each batch
# For each angular frequency (`ells` in FB code, `k` from paper)
# Use the properties of Block Diagonal Matrices to work
# on the corresponding block.
blk_spca_coef = np.empty_like(batch_coef)
for angular_index, a_blk in enumerate(A):

# To compute new expansion coefficients using spca basis
# we combine the basis coefs using the eigen decomposition.
# Note image stack slow moving axis, otherwise this is just a
# block by block matrix multiply.
blk_spca_coef[:, basis_inds[angular_index]] = (
a_blk @ self.eigvecs[angular_index]
)

# Assign truncated block to global spca_coef
self.spca_coef[start:finish, :] = blk_spca_coef[:, compressed_indices]

def expand_from_image_basis(self, x):
"""
Take an image in the standard coordinate basis and express as FSPCA coefs.
Expand Down Expand Up @@ -347,12 +368,13 @@ def evaluate(self, c):

return c @ eigvecs.T

def _get_compressed_indices(self, n):
# TODO: Python>=3.8 @cached_property
def _get_compressed_indices(self):
"""
Return the sorted compressed (truncated) indices into the full FSPCA basis.
Note that we return some number of indices in the real representation (in +- pairs)
required to cover the `n` components in the complex representation.
required to cover the `self.components` in the complex representation.
"""

unsigned_components = zip(
Expand All @@ -369,7 +391,7 @@ def _get_compressed_indices(self, n):
ordered_components.setdefault((k, q)) # inserts when not exists yet

# Select the top n (k,q) pairs
top_components = list(ordered_components)[:n]
top_components = list(ordered_components)[: self.components]

# Now we need to find the locations of both the + and - sgns.
pos_mask = self.basis._indices["sgns"] == 1
Expand All @@ -387,46 +409,38 @@ def _get_compressed_indices(self, n):
compressed_indices.append(neg_index)
return compressed_indices

# # Noting this is awful, but I'm still trying to work out how we can push the complex arithmetic out and away...
def _compress(self, n):
def _compress(self):
"""
Use the eigendecomposition to select the most powerful
coefficients.
Using those coefficients, new index mappings are constructed.
Mutates `self`.
:param n: Number of components (coef)
:return: New FSPCABasis instance
"""

if n >= self.count:
if self.components >= self.count:
logger.warning(
f"Requested compression to {n} components,"
f"Requested compression to {self.components} components,"
f" but already {self.count}."
" Skipping compression."
)
return self

# Create a deepcopy.
old = copy.deepcopy(self)

# Create compressed mapping
compressed_indices = old._get_compressed_indices(n)
logger.debug(f"compressed_indices {compressed_indices}")
compressed_indices = self._get_compressed_indices()
self.count = len(compressed_indices)
logger.debug(f"n {n} compressed count {self.count}")

# NOTE, no longer blk_diag! ugh
# Note can copy from old or self, should be same...
self.eigvals = old.eigvals[compressed_indices]
if isinstance(old.eigvecs, BlkDiagMatrix):
old.eigvecs = old.eigvecs.dense()
self.eigvecs = old.eigvecs[:, compressed_indices]
self.spca_coef = old.spca_coef[:, compressed_indices]
self.eigvals = self.eigvals[compressed_indices]
if isinstance(self.eigvecs, BlkDiagMatrix):
self.eigvecs = self.eigvecs.dense()
self.eigvecs = self.eigvecs[:, compressed_indices]

self.angular_indices = old.angular_indices[compressed_indices]
self.radial_indices = old.radial_indices[compressed_indices]
self.signs_indices = old.signs_indices[compressed_indices]
self.angular_indices = self.angular_indices[compressed_indices]
self.radial_indices = self.radial_indices[compressed_indices]
self.signs_indices = self.signs_indices[compressed_indices]

self.complex_indices_map = self._get_complex_indices_map()
self.complex_count = len(self.complex_indices_map)
Expand All @@ -437,13 +451,6 @@ def _compress(self, n):
self.complex_angular_indices[i] = ang
self.complex_radial_indices[i] = rad

logger.debug(
f"complex_radial_indices: {self.complex_radial_indices} {len(self.complex_radial_indices)}"
)
logger.debug(
f"complex_angular_indices: {self.complex_angular_indices} {len(self.complex_angular_indices)}"
)

def to_complex(self, coef):
"""
Return complex valued representation of coefficients.
Expand Down
Loading

0 comments on commit 9571203

Please sign in to comment.