Source code for amplfi.train.cli.base

from lightning.pytorch.cli import LightningCLI

from ..callbacks import SaveConfigCallback
import torch


class AmplfiBaseCLI(LightningCLI):
    """LightningCLI subclass with fixed parser and config-saving settings.

    On construction it forces the ``omegaconf`` parser mode, installs
    ``SaveConfigCallback`` as the config-saving callback and enables
    overwriting of a previously saved config. It also registers a
    ``--matmul_precision`` option and links several ``data`` arguments
    to the waveform sampler and the model.
    """

    def __init__(self, *args, **kwargs):
        """Build the CLI, overriding three keyword arguments.

        Any caller-supplied values for ``parser_kwargs``,
        ``save_config_callback`` and ``save_config_kwargs`` are replaced.
        All other positional and keyword arguments are forwarded to
        ``LightningCLI.__init__`` unchanged.
        """
        # Hook into init so every subclass gets the same hardcoded
        # parser mode and config-saving behaviour.
        kwargs.update(
            parser_kwargs={"parser_mode": "omegaconf"},
            save_config_callback=SaveConfigCallback,
            save_config_kwargs={"overwrite": True},
        )
        super().__init__(*args, **kwargs)

    def after_instantiate_classes(self) -> None:
        """Apply the configured float32 matmul precision.

        Runs the parent hook first, then reads ``matmul_precision`` from
        the parsed config and hands it to
        ``torch.set_float32_matmul_precision``.
        """
        super().after_instantiate_classes()
        precision = self._get(self.config, "matmul_precision")
        torch.set_float32_matmul_precision(precision)

    def add_arguments_to_parser(self, parser):
        """Register the extra CLI option and the argument links.

        Args:
            parser: The argument parser to extend; it must provide
                ``add_argument`` and ``link_arguments``.
        """
        parser.add_argument("--matmul_precision", type=str, default="highest")

        # (source, target) pairs, linked in this order at parse time.
        argument_links = (
            (
                "data.init_args.sample_rate",
                "data.init_args.waveform_sampler.init_args.sample_rate",
            ),
            (
                "data.init_args.kernel_length",
                "data.init_args.waveform_sampler.init_args.kernel_length",
            ),
            (
                "data.init_args.fduration",
                "data.init_args.waveform_sampler.init_args.fduration",
            ),
            (
                "data.init_args.inference_params",
                "data.init_args.waveform_sampler.init_args.inference_params",
            ),
            (
                "data.init_args.inference_params",
                "model.init_args.inference_params",
            ),
            (
                "seed_everything",
                "data.init_args.seed",
            ),
        )
        for source, target in argument_links:
            parser.link_arguments(source, target, apply_on="parse")