# Source code for naimatic.main

"""A wrapper around naima to facilitate multiple model building and fitting."""

import argparse
import logging

import yaml
import naima


from pathlib import Path
from functools import partial

from astropy.io import ascii  # type: ignore

from .config import Config, Param, MagneticField
from .factory import (
    build_model,
    build_priors,
    extract_p0_labels,
    compute_metadata_blobs
)

# Configure root logging once, at import time, for the command-line tool.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)

# Command-line interface: one positional YAML path plus an optional
# output directory.
parser = argparse.ArgumentParser(
    description="Run Naima models from configuration."
)
parser.add_argument(
    "config",
    type=str,
    help="Path to the YAML configuration file for Naima models.",
)
parser.add_argument(
    "--output-path",
    type=str,
    default="./naimatic_results",
    help="Path to the output directory for results.",
)

# NOTE: the arguments are parsed when the module is imported; ``main`` reads
# this module-level ``args`` object.
args = parser.parse_args()
def load_data(data_dict):
    """Read every data file listed in ``data_dict``.

    Parameters
    ----------
    data_dict : mapping
        Maps a dataset key to the path of the file holding that dataset.

    Returns
    -------
    dict
        Maps each key of ``data_dict`` to the object returned by
        ``ascii.read`` for the resolved path.

    Raises
    ------
    FileNotFoundError
        If a resolved path does not exist. Files are processed in mapping
        order, so earlier entries have already been read at that point.
    """
    tables = {}
    for name, value in data_dict.items():
        resolved = Path(value).resolve()
        if resolved.exists():
            tables[name] = ascii.read(resolved)
            continue
        # The message reports the path as given, not the resolved one.
        raise FileNotFoundError(f"Data file {value} does not exist.")
    return tables
def model_func(data, model_cfg, radiative_models):
    """Return the summed flux of all radiative models evaluated on ``data``.

    Each element of ``radiative_models`` has its ``flux`` method called with
    ``data`` and ``distance=model_cfg.distance``. The results are added
    together starting from ``0``, so an empty collection yields ``0``.
    """
    combined = 0
    for component in radiative_models:
        contribution = component.flux(data, distance=model_cfg.distance)
        # Plain ``+`` rather than ``+=`` so that no returned flux object is
        # modified in place while accumulating.
        combined = combined + contribution
    return combined
def lnprior(pars, labels, priors_dict):
    """Return the total log-prior of ``pars``.

    ``pars`` and ``labels`` are walked in lockstep. For every label that has
    an entry in ``priors_dict``, that callable is applied to the matching
    parameter value and the results are summed, starting from ``0``. Labels
    with no entry (or an entry of ``None``) contribute nothing.
    """
    matched = (
        (priors_dict.get(label), value) for value, label in zip(pars, labels)
    )
    return sum(prior(value) for prior, value in matched if prior is not None)
def wrapped_model_func(pars, data, model_cfg, pdist, rmodels):
    """Push sampler parameters into the model objects and evaluate the flux.

    ``pars`` is consumed positionally: first by the ``Param`` attributes of
    ``model_cfg.particle_distribution``, then by the non-frozen ``Param``
    attributes of each entry in ``model_cfg.radiative_processes``, in
    attribute order. Each value is optionally exponentiated (``log10`` flag),
    multiplied by the unit of the parameter's ``init_value`` when it has one,
    and assigned onto ``pdist`` or the matching radiative model.

    Returns
    -------
    tuple
        ``(total_flux, *blobs)``, where ``blobs`` is empty unless
        ``model_cfg.metadata`` is truthy, in which case it is whatever
        ``compute_metadata_blobs`` returns.
    """
    cursor = 0  # index of the next unread entry of ``pars``

    # --- particle distribution: frozen parameters are still assigned -------
    for attr, spec in model_cfg.particle_distribution.__dict__.items():
        if not isinstance(spec, Param):
            continue
        if getattr(spec, "freeze", False):
            # Frozen: use the configured initial value and leave ``pars``
            # untouched.
            value = spec.init_value.value
        else:
            value = pars[cursor]
            cursor += 1
            if getattr(spec, "log10", False):
                value = 10**value
        unit = getattr(spec.init_value, "unit", None)
        if unit is not None:
            value = value * unit
        setattr(pdist, attr, value)

    # --- radiative processes: only free parameters are assigned ------------
    for process, rmodel in zip(model_cfg.radiative_processes, rmodels):
        for attr, spec in process.__dict__.items():
            if not isinstance(spec, Param) or spec.freeze:
                continue
            value = pars[cursor]
            cursor += 1
            if getattr(spec, "log10", False):
                value = 10**value
            unit = getattr(spec.init_value, "unit", None)
            if unit is not None:
                value = value * unit
            setattr(rmodel, attr, value)

    total_flux = model_func(data, model_cfg, rmodels)
    if model_cfg.metadata:
        blobs = compute_metadata_blobs(model_cfg.metadata, pdist, rmodels)
    else:
        blobs = ()
    return (total_flux, *blobs)
def wrapped_lnprior_func(pars, labels, priors):
    """Evaluate ``lnprior`` for ``pars`` using the given labels and priors.

    A thin adapter whose keyword names (``labels``, ``priors``) can be bound
    ahead of time, leaving ``pars`` as the only positional argument.
    """
    log_prior = lnprior(pars, labels=labels, priors_dict=priors)
    return log_prior
def _build_blob_labels(metadata_cfg):
    """Return the diagnostic-plot labels for the blobs of one model.

    The first label is always ``"Spectrum"``. One further label is appended
    for every entry of ``metadata_cfg.model_dump()`` whose ``save`` flag is
    truthy. A falsy ``metadata_cfg`` yields just ``["Spectrum"]``.
    """
    blob_labels = ["Spectrum"]  # the first return value is always the flux
    if not metadata_cfg:
        return blob_labels

    for key, cfg_entry in metadata_cfg.model_dump().items():
        if not (cfg_entry and cfg_entry.get("save", False)):
            continue
        logger.debug("Processing metadata key: %s", key)
        if key == "particle_distribution":
            blob_labels.append("Electron energy distribution")
        elif key == "total_particle_energy":
            e_min = cfg_entry.get("e_min")
            # Format nicely if the value looks like a Quantity.
            if hasattr(e_min, "value") and hasattr(e_min, "unit"):
                e_min_str = f"{e_min.value:g} {e_min.unit}"
            else:
                e_min_str = str(e_min)
            blob_labels.append(f"$W_e (E_e>{e_min_str})$")
        else:
            # Fallback label: use the metadata key itself.
            blob_labels.append(key)
    return blob_labels


def main():
    """Run every model of the YAML configuration named on the command line.

    The configuration is validated, the data tables are loaded once, and
    then each model is built, sampled with ``naima.run_sampler`` and saved
    (HDF5 run file, diagnostic plots, results table) under
    ``<output path>/<model name>/``.

    A failure while processing one model is logged and the loop moves on to
    the next model.

    Raises
    ------
    FileNotFoundError
        If the configuration file does not exist (data-file errors from
        ``load_data`` propagate as well).
    """
    yaml_path = Path(args.config).resolve()
    if not yaml_path.exists():
        raise FileNotFoundError(f"Configuration file {yaml_path} does not exist.")

    # Read the YAML inside a context manager so the handle is always closed.
    with open(yaml_path) as config_file:
        raw_config = yaml.safe_load(config_file)
    cfg = Config.model_validate(raw_config)

    # The configuration takes precedence over the command-line default.
    # Wrapping in ``Path`` is a no-op for a Path and keeps the ``/`` joins
    # below valid if the configuration supplies a plain string.
    output_path = Path(cfg.output_path if cfg.output_path else args.output_path)
    data_tables_dict = load_data(cfg.data)

    for model_cfg in cfg.models:
        try:
            # Initialize magnetic field value if present
            for proc_cfg in model_cfg.radiative_processes:
                if proc_cfg.name == "Synchrotron" and isinstance(
                    proc_cfg.B, MagneticField
                ):
                    if proc_cfg.B.estimate_from:
                        proc_cfg.B = proc_cfg.B.resolve(data_tables_dict)

            overwrite_flag = model_cfg.overwrite
            logger.info(f"Running model: {model_cfg.name}")

            pdist, rmodels = build_model(model_cfg)
            p0, labels = extract_p0_labels(model_cfg)
            logger.debug(f"Initial parameters: {p0}")
            logger.debug(f"Parameter labels: {labels}")
            priors = build_priors(model_cfg.particle_distribution)

            # Bind everything except the sampler-supplied arguments.
            model_callable = partial(
                wrapped_model_func,
                model_cfg=model_cfg,
                pdist=pdist,
                rmodels=rmodels,
            )
            prior_callable = partial(wrapped_lnprior_func, labels=labels, priors=priors)

            sampler, _ = naima.run_sampler(
                data_table=list(data_tables_dict.values()),
                p0=p0,
                labels=labels,
                model=model_callable,
                prior=prior_callable,
                **cfg.mcmc.model_dump(),
            )

            outdir = output_path / model_cfg.name
            outdir.mkdir(exist_ok=True, parents=True)
            results_file = outdir / f"{model_cfg.name}_results.hdf5"
            naima.save_run(results_file, sampler, clobber=overwrite_flag)
            logger.info(f"Results saved to: {results_file}")

            # Build blob labels dynamically from the metadata config
            blob_labels = _build_blob_labels(model_cfg.metadata)

            try:
                plot_prefix = model_cfg.name
                naima.save_diagnostic_plots(
                    outdir / plot_prefix,
                    sampler,
                    sed=model_cfg.sed,
                    blob_labels=blob_labels,
                )
                logger.info("Diagnostic plots saved.")
                naima.save_results_table(
                    outdir / plot_prefix, sampler, overwrite=overwrite_flag
                )
                logger.info("Results table saved")
            except Exception:
                logger.exception("Could not generate plots")
                # Bare ``raise`` keeps the original traceback; the outer
                # handler logs it and moves on to the next model.
                raise
        except Exception as e:
            logger.exception(f"Error running sampler for model {model_cfg.name}: {e}")
            continue
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__": main()