amplfi.train.architectures module
- class amplfi.train.architectures.flows.FlowArchitecture(num_params, embedding_net, embedding_weights=None, freeze_embedding=False)
Bases: Module
Base class for normalizing flow architectures that provides an interface for interacting with embedding networks.
- log_prob(x, context)
Wrapper around the log_prob() method of the underlying zuko.lazy.Flow object.
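The following minimal PyTorch sketch shows how a base class with this signature could connect an embedding network to a conditional flow. It is not the amplfi implementation. It assumes that the `Module` base is `torch.nn.Module` and that `log_prob` embeds the context before conditioning the flow. `FlowArchitectureSketch`, `ToyLazyFlow` and `ToyArch` are hypothetical names. `ToyLazyFlow` stands in for a `zuko.lazy.Flow`: calling it with a context returns a distribution. Treating `embedding_weights` as a path to a saved `state_dict` is also an assumption.

```python
import torch
from torch import nn
from torch.distributions import Independent, Normal


class ToyLazyFlow(nn.Module):
    """Hypothetical stand-in for a zuko lazy flow: calling it with a context
    tensor returns a distribution over the parameters (a diagonal Gaussian)."""

    def __init__(self, num_params: int, context_dim: int):
        super().__init__()
        self.net = nn.Linear(context_dim, 2 * num_params)

    def forward(self, context: torch.Tensor):
        loc, log_scale = self.net(context).chunk(2, dim=-1)
        return Independent(Normal(loc, log_scale.exp()), 1)


class FlowArchitectureSketch(nn.Module):
    """Hypothetical base class: owns the embedding network; subclasses build self.flow."""

    def __init__(self, num_params, embedding_net, embedding_weights=None, freeze_embedding=False):
        super().__init__()
        self.num_params = num_params
        self.embedding_net = embedding_net
        if embedding_weights is not None:
            # Assumption: embedding_weights is a path to a saved state_dict.
            state_dict = torch.load(embedding_weights, map_location="cpu")
            self.embedding_net.load_state_dict(state_dict)
        if freeze_embedding:
            # Exclude the embedding parameters from gradient updates.
            for param in self.embedding_net.parameters():
                param.requires_grad_(False)
        self.flow = None  # set by subclasses

    def log_prob(self, x, context):
        # Embed the raw context, condition the flow on it, then evaluate log p(x | context).
        embedded = self.embedding_net(context)
        return self.flow(embedded).log_prob(x)


class ToyArch(FlowArchitectureSketch):
    """Hypothetical subclass that builds its flow after the base class is initialised."""

    def __init__(self, *args, context_dim: int, **kwargs):
        super().__init__(*args, **kwargs)
        self.flow = ToyLazyFlow(self.num_params, context_dim)


# Usage: 3 parameters, context of 2 channels x 16 samples embedded to 8 features.
torch.manual_seed(0)
embedding = nn.Sequential(nn.Flatten(), nn.Linear(2 * 16, 8))
model = ToyArch(3, embedding, freeze_embedding=True, context_dim=8)
x = torch.randn(4, 3)
context = torch.randn(4, 2, 16)
log_prob = model.log_prob(x, context)  # one log-density per sample, shape (4,)
```

Keeping the embedding network in the base class lets every flow subclass share the same loading and freezing logic. Subclasses only need to construct the flow.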
- class amplfi.train.architectures.flows.NSF(*args, transforms, hidden_features=(64, 64), passes=None, bins=8, randperm=False, residual=False, **kwargs)
Bases: FlowArchitecture
Light wrapper around the NSF flow from the zuko library, for compatibility with the FlowArchitecture interface.
See https://zuko.readthedocs.io/stable/api/zuko.flows.spline.html#zuko.flows.spline.NSF
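NSF stands for neural spline flow. Each transformation applies a monotonic rational-quadratic spline whose bin widths, heights and knot derivatives are predicted by the hypernetwork, and `bins` sets the number of spline segments. The pure-Python sketch below evaluates one such spline for a single scalar, following the formula of Durkan et al. (2019). It illustrates the transform rather than reproducing zuko's code. The function name `rq_spline`, the interval [-1, 1] and the identity map outside it are assumptions made for this example.

```python
import bisect


def rq_spline(x, widths, heights, derivatives, left=-1.0, right=1.0):
    """Evaluate a monotonic rational-quadratic spline with K = len(widths) bins.

    widths, heights: K positive numbers, rescaled here to span [left, right].
    derivatives: K + 1 positive slopes, one per knot.
    Outside [left, right] the map is the identity (an assumption of this sketch).
    """
    num_bins = len(widths)
    if len(heights) != num_bins or len(derivatives) != num_bins + 1:
        raise ValueError("need K widths, K heights and K + 1 derivatives")
    if x <= left or x >= right:
        return x

    span = right - left
    w = [span * v / sum(widths) for v in widths]
    h = [span * v / sum(heights) for v in heights]

    # Cumulative knot positions on the input (xs) and output (ys) axes.
    xs, ys = [left], [left]
    for wk, hk in zip(w, h):
        xs.append(xs[-1] + wk)
        ys.append(ys[-1] + hk)

    # Locate the bin containing x (clamped against rounding at the right edge).
    k = min(bisect.bisect_right(xs, x) - 1, num_bins - 1)

    xi = (x - xs[k]) / w[k]  # relative position within the bin
    s = h[k] / w[k]          # average slope of the bin
    d0, d1 = derivatives[k], derivatives[k + 1]
    numerator = h[k] * (s * xi**2 + d0 * xi * (1 - xi))
    denominator = s + (d0 + d1 - 2 * s) * xi * (1 - xi)
    return ys[k] + numerator / denominator


# Three bins (bins=3); boundary slopes of 1 join the identity tails smoothly.
params = dict(widths=[1, 2, 1], heights=[2, 1, 1], derivatives=[1.0, 0.5, 2.0, 1.0])
grid = [-1 + 2 * i / 200 for i in range(201)]
values = [rq_spline(x, **params) for x in grid]
```

Because every slope is positive, the spline is strictly increasing and therefore invertible. More bins give the hypernetwork more knots to shape the density, at the cost of more outputs per feature.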
- Parameters:
  - transforms (int) – Number of transformations in the flow.
  - hidden_features (Optional[Sequence[int]]) – Number of hidden units in each layer of the hypernetwork.
  - passes (Optional[int]) – Number of passes in each autoregressive transformation. The default, None, corresponds to a fully autoregressive flow; a value of 2 corresponds to a coupling flow.
  - bins (Optional[int]) – Number of bins in the spline.
  - randperm (Optional[bool]) – Whether to randomly permute the features between transformation layers.
  - residual (Optional[bool]) – Whether to use residual connections in the hypernetwork.
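The sketch below makes `passes` and `randperm` concrete by building the dependency pattern of one masked autoregressive layer. Features are split into `passes` ordered groups, and the transform of a feature may depend only on features in earlier groups. With `passes=None`, every feature forms its own group, which gives a fully autoregressive layer. With `passes=2`, the second half depends only on the first half, which gives a coupling layer. The grouping rule, the hypothetical helpers `dependency_mask` and `layer_orders`, and the alternating reversed ordering used when `randperm=False` are assumptions in the spirit of zuko's masked autoregressive flows, not a copy of its code.

```python
import math
import random


def dependency_mask(order, passes=None):
    """mask[i][j] is True if the transform of feature i may depend on feature j.

    order: a permutation giving each feature's rank in the autoregressive ordering.
    passes: number of sequential groups; None means one group per feature.
    """
    features = len(order)
    if passes is None:
        passes = features  # fully autoregressive
    passes = min(max(passes, 1), features)
    group_size = math.ceil(features / passes)
    group = [rank // group_size for rank in order]
    return [[group[i] > group[j] for j in range(features)] for i in range(features)]


def layer_orders(features, transforms, randperm=False, seed=0):
    """Hypothetical per-layer orderings: random permutations if randperm,
    otherwise alternate between the identity ordering and its reverse."""
    rng = random.Random(seed)
    orders = []
    for t in range(transforms):
        order = list(range(features))
        if randperm:
            rng.shuffle(order)
        elif t % 2 == 1:
            order.reverse()
        orders.append(order)
    return orders


identity = list(range(4))
autoregressive = dependency_mask(identity)      # strictly lower-triangular
coupling = dependency_mask(identity, passes=2)  # features 2 and 3 depend on 0 and 1
orders = layer_orders(4, transforms=3)          # [[0, 1, 2, 3], [3, 2, 1, 0], [0, 1, 2, 3]]
```

Inverting a masked autoregressive layer, as when sampling, takes one sequential pass per group. A fully autoregressive layer over D features therefore needs D passes, while a coupling layer needs only 2. Varying the ordering between layers lets each feature eventually depend on all of the others.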