from typing import Callable, Optional
import torch
[docs]
class AmplfiPrior:
def __init__(
self,
priors: dict[str, torch.distributions.Distribution],
conversion_function: Optional[Callable] = None,
):
"""
A class for sampling parameters from a prior distribution
Args:
priors:
A dictionary of parameter samplers that take an integer N
and return a tensor of shape (N, ...) representing
samples from the prior distribution
conversion_function:
A callable that takes a dictionary of sampled parameters
and returns a dictionary of waveform generation parameters
"""
super().__init__()
self.priors = priors
self.conversion_function = conversion_function or (lambda x: x)
def __call__(
self,
N: int,
device: str = "cpu",
) -> dict[str, torch.Tensor]:
"""
Generates random samples from the prior
Args:
N: Number of samples to generate
device: Device to place the samples
"""
# sample parameters from prior
parameters = {
k: v.sample((N,)).to(device) for k, v in self.priors.items()
}
# perform any necessary conversions
# to from sampled parameters to
# waveform generation parameters
parameters = self.conversion_function(parameters)
return parameters
[docs]
def log_prob(self, samples: dict[str, torch.Tensor]) -> torch.Tensor:
"""
Calculate the log probability of samples under the prior
Args:
samples:
Dictionary where key is parameter and
value is tensor of samples
"""
first = samples[list(samples.keys())[0]]
log_probs = torch.ones(len(first), device=first.device)
for param, tensor in samples.items():
log_probs += self.priors[param].log_prob(tensor).to(first.device)
return log_probs