Skip to content

Commit

Permalink
Reverted posterior sampling to online implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelklee committed Jul 1, 2024
1 parent 84de6b7 commit e32cb02
Show file tree
Hide file tree
Showing 7 changed files with 420 additions and 368 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from . import io_consts
from . import io_intervals_and_counts
from .. import config
from ..models import commons
from ..models.model_denoising_calling import CopyNumberCallingConfig, DenoisingModelConfig
from ..models.model_denoising_calling import DenoisingCallingWorkspace, DenoisingModel
from ..utils import math

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -152,21 +154,12 @@ def __call__(self):

# compute approximate denoised copy ratios
_logger.info("Sampling and approximating posteriors for denoised copy ratios...")
denoising_copy_ratios_st_symbolic_samples = self.denoising_model_approx.sample_node(
node=self.denoising_model['denoised_copy_ratio_st'],
size=self.denoising_config.num_samples_copy_ratio_approx)
# must use compile_pymc to pass random_seed for reproducible sampling
denoising_copy_ratios_st_mean_samples_func = pm.pytensorf.compile_pymc(
inputs=[],
outputs=denoising_copy_ratios_st_symbolic_samples.mean(axis=0),
random_seed=self.denoising_model_approx.rng.randint(2**30, dtype=np.int64))
denoising_copy_ratios_st_std_samples_func = pm.pytensorf.compile_pymc(
inputs=[],
outputs=denoising_copy_ratios_st_symbolic_samples.std(axis=0),
random_seed=self.denoising_model_approx.rng.randint(2**30, dtype=np.int64))

mu_denoised_copy_ratio_st = denoising_copy_ratios_st_mean_samples_func()
std_denoised_copy_ratio_st = denoising_copy_ratios_st_std_samples_func()
denoising_copy_ratios_st_approx_generator = commons.get_sampling_generator_for_model_approximation(
approx=self.denoising_model_approx, node=self.denoising_model['denoised_copy_ratio_st'],
num_samples=self.denoising_config.num_samples_copy_ratio_approx)
mu_denoised_copy_ratio_st, var_denoised_copy_ratio_st = \
math.calculate_mean_and_variance_online(denoising_copy_ratios_st_approx_generator)
std_denoised_copy_ratio_st = np.sqrt(var_denoised_copy_ratio_st)

for si, sample_name in enumerate(self.denoising_calling_workspace.sample_names):
sample_name_comment_line = [io_consts.sample_name_sam_header_prefix + sample_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import numpy as np
import logging
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import graph_replace
import pymc as pm
import pymc.distributions.dist_math as pm_dist_math
from pymc.util import makeiter
from pymc.variational.opvi import _known_scan_ignored_inputs
from typing import Tuple, Generator

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -183,3 +187,61 @@ def logsumexp(x, axis=None):
# Adapted from https://github.com/Theano/Theano/issues/1563
x_max = pt.max(x, axis=axis, keepdims=True)
return pt.log(pt.sum(pt.exp(x - x_max), axis=axis, keepdims=True)) + x_max

def stochastic_node_mean_symbolic(approx: pm.MeanField, node, size=100):
"""Symbolic mean of a given PyMC3 stochastic node with respect to a given variational
posterior approximation.
Args:
approx: an instance of PyMC3 approximation
node: stochastic node
size: the number of samples to use for calculating the mean
Returns:
Symbolic approximate mean of the stochastic node
"""

assert size > 0

cum_sum = pt.zeros(node.shape, node.dtype)
# see Approximation.sample_node()
node = approx.model.replace_rvs_by_values([node])
node = node[0]
node = approx.to_flat_input(node)
posterior_samples = approx.symbolic_random.astype(approx.symbolic_initial.dtype)
posterior_samples = pt.specify_shape(posterior_samples, approx.symbolic_initial.type.shape)
posterior_samples = approx.set_size_and_deterministic(posterior_samples, s=size, d=False)

def add_sample_to_cum_sum(posterior_sample, _cum_sum):
new_sample = graph_replace(node, {approx.input: posterior_sample}, strict=False)
return _cum_sum + new_sample

outputs, _ = pytensor.scan(add_sample_to_cum_sum,
sequences=posterior_samples,
non_sequences=_known_scan_ignored_inputs(makeiter(posterior_samples)),
outputs_info=cum_sum,
n_steps=size)

return outputs[-1] / size

def get_sampling_generator_for_model_approximation(approx: pm.MeanField, node,
num_samples: int = 20) -> Generator:
"""Get a generator that returns samples of a precomputed model approximation for a specific variable in that model
Args:
approx: an instance of PyMC3 mean-field approximation
node: a stochastic node in the model
num_samples: number of samples to draw
Returns:
A generator that will yield `num_samples` samples from an approximation to a posterior
"""

assert num_samples > 0

# see Approximation.sample_node()
node = approx.model.replace_rvs_by_values([node])
node = node[0]
node_sample = approx.symbolic_sample_over_posterior(node)
node_sample = approx.set_size_and_deterministic(node_sample, s=1, d=False)
# must use compile_pymc to pass random_seed for reproducible sampling
node_sample_func = pm.pytensorf.compile_pymc(inputs=[], outputs=node_sample,
random_seed=approx.rng.randint(2**30, dtype=np.int64))

return (node_sample_func()[0] for _ in range(num_samples))
Original file line number Diff line number Diff line change
Expand Up @@ -910,12 +910,11 @@ def draw(self) -> np.ndarray:
def _get_compiled_simultaneous_log_copy_number_emission_sampler(self, approx: pm.approximations.MeanField):
"""For a given variational approximation, returns a compiled pytensor function that draws posterior samples
from log copy number emission probabilities."""
log_copy_number_emission_stc = approx.sample_node(
node=self.denoising_model['log_copy_number_emission_stc'],
size=self.inference_params.log_emission_samples_per_round).mean(axis=0)
log_copy_number_emission_stc = commons.stochastic_node_mean_symbolic(
approx, self.denoising_model['log_copy_number_emission_stc'],
size=self.inference_params.log_emission_samples_per_round)
# must use compile_pymc to pass random_seed for reproducible sampling
return pm.pytensorf.compile_pymc(inputs=[],
outputs=log_copy_number_emission_stc,
return pm.pytensorf.compile_pymc(inputs=[], outputs=log_copy_number_emission_stc,
random_seed=approx.rng.randint(2**30, dtype=np.int64))


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,12 +310,9 @@ def draw(self) -> np.ndarray:
def _get_compiled_simultaneous_log_ploidy_emission_sampler(self, approx: pm.approximations.MeanField):
"""For a given variational approximation, returns a compiled pytensor function that draws posterior samples
from the log ploidy emission."""
log_ploidy_emission_sjk = approx.sample_node(
node=self.ploidy_model['logp_sjk'],
size=self.samples_per_round).mean(axis=0)
# must use compile_pymc to pass random_seed for reproducible sampling
return pm.pytensorf.compile_pymc(inputs=[],
outputs=log_ploidy_emission_sjk,
log_ploidy_emission_sjk = commons.stochastic_node_mean_symbolic(
approx, self.ploidy_model['logp_sjk'], size=self.samples_per_round)
return pm.pytensorf.compile_pymc(inputs=[], outputs=log_ploidy_emission_sjk,
random_seed=approx.rng.randint(2**30, dtype=np.int64))


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ public void testNumericalAccuracy() {
.add(GermlineCNVCaller.CONTIG_PLOIDY_CALLS_DIRECTORY_LONG_NAME,
CONTIG_PLOIDY_CALLS_OUTPUT_DIR.getAbsolutePath())
.add(StandardArgumentDefinitions.OUTPUT_LONG_NAME, outputDir)
.add(CopyNumberStandardArgument.OUTPUT_PREFIX_LONG_NAME, outputPrefix);
.add(CopyNumberStandardArgument.OUTPUT_PREFIX_LONG_NAME, outputPrefix)
.add(StandardArgumentDefinitions.VERBOSITY_NAME, "DEBUG");
runCommandLine(argsBuilder);

// Test that values of outputs are approximately numerically equivalent
Expand All @@ -182,7 +183,7 @@ public void testNumericalAccuracy() {
} catch (final IOException ex) {
throw new GATKException("Could not remove GermlineCNVCaller tracking files.");
}
IntStream.range(1, 20).forEach(
IntStream.range(1, TEST_COUNT_FILES.length).forEach(
s -> {
try {
FileUtils.deleteDirectory(new File(Paths.get(GCNV_TEST_OUTPUT_DIR, outputPrefix + "-calls", "SAMPLE_" + s).toString()));
Expand Down
Loading

0 comments on commit e32cb02

Please sign in to comment.