# Source code for naimatic.factory
"""Library of factory functions."""
import logging
import astropy.units as u # type: ignore
import naima
import naima.models as nm
import numpy as np
from functools import partial
from .config import (
ModelConfig,
Param,
ParticleDistributionConfig,
RadiativeProcessConfig,
)
# Names exported by ``from <this module> import *``; these are the factory
# functions defined below.
__all__ = [
    "build_priors",
    "build_particle_distribution",
    "build_radiative_process",
    "build_model",
    "extract_p0_labels",
    "compute_metadata_blobs",
]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _prior2_caller(x, fn, a, b):
    """Evaluate a two-hyperparameter prior function at ``x``.

    ``fn`` is invoked as ``fn(x, a, b)``, where ``a`` and ``b`` are the
    prior's hyperparameters (for instance min/max or mu/sigma). Being a
    plain module-level function, it can be bound with ``functools.partial``
    using keyword arguments.
    """
    hyperparameters = (a, b)
    return fn(x, *hyperparameters)
[docs]
def build_priors(particle_dist_cfg: ParticleDistributionConfig):
    """Build a dictionary of priors from the particle distribution configuration.

    Every ``Param`` attribute of ``particle_dist_cfg`` that carries a prior
    is mapped to a one-argument callable ``prior(x)`` which evaluates the
    matching ``naima.<name>_prior`` function with the configured
    hyperparameters.

    Supported prior names:

    * ``"uniform"``, ``"loguniform"`` / ``"log_uniform"``: hyperparameters
      are ``min`` and ``max``; both are converted with ``log10`` when the
      parameter has a truthy ``log10`` attribute.
    * ``"normal"``: hyperparameters are ``mu`` and ``sigma``; only ``mu`` is
      converted with ``log10`` when the parameter has a truthy ``log10``
      attribute.

    Hyperparameters exposing a ``.value`` attribute (e.g. quantities) are
    reduced to that value before conversion to ``float``.

    Returns:
        dict mapping parameter name to the prior callable.

    Raises:
        ValueError: if the prior name does not resolve to a naima prior
            function, or resolves to one this factory cannot parameterize.
    """
    # naima names its log-uniform prior ``log_uniform_prior``; the
    # underscore-free config spelling is accepted as an alias for it.
    naima_aliases = {"loguniform": "log_uniform"}
    bounded_priors = {"uniform", "loguniform", "log_uniform"}

    priors = {}
    for param_name, param_cfg in particle_dist_cfg.__dict__.items():
        if not isinstance(param_cfg, Param) or param_cfg.prior is None:
            continue

        prior_cfg = param_cfg.prior
        # Try the configured spelling first, then the known alias.
        prior_func = getattr(naima, f"{prior_cfg.name}_prior", None)
        if prior_func is None and prior_cfg.name in naima_aliases:
            prior_func = getattr(
                naima, f"{naima_aliases[prior_cfg.name]}_prior", None
            )
        if prior_func is None:
            raise ValueError(f"Unknown prior type: {prior_cfg.name}")

        in_log_space = getattr(param_cfg, "log10", False)

        # Pull numeric bounds/hyperparameters
        if prior_cfg.name in bounded_priors:
            a = float(getattr(prior_cfg.min, "value", prior_cfg.min))
            b = float(getattr(prior_cfg.max, "value", prior_cfg.max))
            if in_log_space:
                a = np.log10(a)
                b = np.log10(b)
        elif prior_cfg.name == "normal":
            a = float(getattr(prior_cfg.mu, "value", prior_cfg.mu))
            b = float(getattr(prior_cfg.sigma, "value", prior_cfg.sigma))
            if in_log_space:
                # Only the mean is moved to log space; sigma is used as given.
                a = np.log10(a)
        else:
            # A resolvable prior with no known hyperparameter layout must not
            # be dropped silently: the fit would then run without it.
            raise ValueError(f"Unsupported prior type: {prior_cfg.name}")

        priors[param_name] = partial(_prior2_caller, fn=prior_func, a=a, b=b)
    return priors
[docs]
def build_particle_distribution(cfg: ParticleDistributionConfig):
    """Instantiate the ``naima.models`` class named by ``cfg.name``.

    The constructor receives, as keyword arguments, the ``init_value`` of
    every ``Param`` attribute of ``cfg`` whose ``init_value`` is not None.

    Raises:
        ValueError: if ``naima.models`` has no attribute called ``cfg.name``.
    """
    distribution_cls = getattr(nm, cfg.name, None)
    if distribution_cls is None:
        raise ValueError(f"Unknown distribution: {cfg.name}")

    init_kwargs = {}
    for field_name, field in cfg.__dict__.items():
        # Only Param fields with a concrete initial value become arguments.
        if not isinstance(field, Param):
            continue
        if field.init_value is None:
            continue
        init_kwargs[field_name] = field.init_value
    return distribution_cls(**init_kwargs)
[docs]
def build_radiative_process(cfg: RadiativeProcessConfig, particle_distribution):
    """Instantiate the ``naima.models`` radiative class named by ``cfg.name``.

    ``particle_distribution`` is passed as the first positional argument.
    Keyword arguments are collected from the attributes of ``cfg``:

    * ``Param`` attributes contribute their ``init_value`` when it is not None;
    * any other attribute except ``name`` and ``particle_distribution`` is
      passed through as-is when it is not None.

    Raises:
        ValueError: if ``naima.models`` has no attribute called ``cfg.name``.
    """
    process_cls = getattr(nm, cfg.name, None)
    if process_cls is None:
        raise ValueError(f"Unknown radiative model: {cfg.name}")

    excluded = ("name", "particle_distribution")
    init_kwargs = {}
    for field_name, field in cfg.__dict__.items():
        if isinstance(field, Param):
            candidate = field.init_value
        elif field_name in excluded:
            # Consumed by this factory itself, not by the naima constructor.
            continue
        else:
            candidate = field
        if candidate is not None:
            init_kwargs[field_name] = candidate
    return process_cls(particle_distribution, **init_kwargs)
[docs]
def build_model(model_cfg: ModelConfig):
    """Build the particle distribution and radiative processes of a model.

    A shared particle distribution is built from
    ``model_cfg.particle_distribution``. Each entry of
    ``model_cfg.radiative_processes`` is built with its own particle
    distribution when it has a truthy ``particle_distribution`` attribute,
    and with the shared one otherwise.

    Returns:
        tuple ``(shared_particle_distribution, processes)`` where
        ``processes`` is a list in the order of the configuration.
    """
    shared_pd = build_particle_distribution(model_cfg.particle_distribution)

    def _distribution_for(proc_cfg):
        # A process-level distribution config overrides the shared one.
        own_cfg = getattr(proc_cfg, "particle_distribution", None)
        return build_particle_distribution(own_cfg) if own_cfg else shared_pd

    processes = [
        build_radiative_process(proc_cfg, _distribution_for(proc_cfg))
        for proc_cfg in model_cfg.radiative_processes
    ]
    return shared_pd, processes
[docs]
def extract_p0_labels(model_cfg):
    """Extract initial parameter values and their labels from the model configuration.

    Only ``Param`` attributes that are not frozen are collected, first from
    ``model_cfg.particle_distribution`` and then from each entry of
    ``model_cfg.radiative_processes``, in attribute order.

    Particle-distribution parameters with a truthy ``log10`` attribute are
    stored as ``log10(value)`` and labelled ``"log10(<name>)"``. If the
    value is missing, non-positive or otherwise unusable, an error is logged
    and ``nan`` is stored so the remaining parameters are still returned.
    A parameter without a ``log10`` attribute is treated as linear, matching
    how ``build_priors`` reads the same flag.

    Radiative-process parameters are stored as-is (the ``log10`` flag is not
    applied to them here) and labelled ``"<process name>.<name>"``.

    Values exposing a ``.value`` attribute (e.g. quantities) are reduced to
    that value.

    Returns:
        tuple ``(p0, labels)``: a NumPy array of initial values and the list
        of matching labels.
    """
    p0 = []
    labels = []

    for param_name, param_cfg in model_cfg.particle_distribution.__dict__.items():
        if not isinstance(param_cfg, Param) or param_cfg.freeze:
            continue

        # Extract raw float value from Quantity or use as-is
        raw_value = getattr(param_cfg.init_value, "value", param_cfg.init_value)

        if not getattr(param_cfg, "log10", False):
            p0.append(raw_value)
            labels.append(param_name)
            continue

        try:
            if raw_value is None or raw_value <= 0:
                logger.error(
                    "Cannot take log10 of non-positive value for %s: %s",
                    param_name,
                    raw_value,
                )
                logval = np.nan
            else:
                logval = np.log10(raw_value)
        except (TypeError, ValueError):
            # Non-numeric or non-scalar value: keep going with a NaN marker.
            logger.exception(
                "np.log10 failed for %s with value %s", param_name, raw_value
            )
            logval = np.nan
        p0.append(logval)
        labels.append(f"log10({param_name})")

    for proc_cfg in model_cfg.radiative_processes:
        for param_name, param_cfg in proc_cfg.__dict__.items():
            if isinstance(param_cfg, Param) and not param_cfg.freeze:
                # Same unwrapping as above so plain numbers work too.
                p0.append(
                    getattr(param_cfg.init_value, "value", param_cfg.init_value)
                )
                labels.append(f"{proc_cfg.name}.{param_name}")

    return np.array(p0), labels
[docs]
def _total_particle_energy(rmodel, e_min):
    """Return the total particle energy of one radiative model above ``e_min``."""
    # Use ``compute_We`` when the model provides it and ``compute_Wp``
    # otherwise. Checking for the method (rather than catching
    # AttributeError around the call) keeps errors raised inside the
    # computation from being mistaken for a missing method.
    compute_we = getattr(rmodel, "compute_We", None)
    if compute_we is not None:
        return compute_we(Eemin=e_min)
    return rmodel.compute_Wp(Epmin=e_min)


def compute_metadata_blobs(metadata_cfg, pdist, rmodels):
    """Compute the metadata blobs enabled in ``metadata_cfg``.

    ``metadata_cfg.model_dump()`` is iterated in order; entries that are
    None or whose ``"save"`` item is falsy are skipped. Recognised keys:

    * ``"particle_distribution"``: appends ``(energy_range,
      pdist(energy_range))`` using the entry's ``"energy_range"`` item.
    * ``"total_particle_energy"``: appends the total particle energy above
      the entry's ``"e_min"`` item (1 TeV when the item is absent). When
      every model in ``rmodels`` shares the very same
      ``particle_distribution`` object the energy is computed once, from the
      first model; otherwise the per-model energies are summed. Each model
      is queried through ``compute_We`` or, lacking that, ``compute_Wp``,
      so lists mixing both kinds of model are supported.

    Other keys are ignored.

    Returns:
        list of blobs, in the order of the enabled configuration entries.

    Raises:
        Exception: any error raised while computing the total particle
            energy is logged and re-raised unchanged (this includes the
            IndexError produced by an empty ``rmodels``).
    """
    blobs = []
    for key, cfg_entry in metadata_cfg.model_dump().items():
        if (cfg_entry is None) or (not cfg_entry.get("save", False)):
            continue

        if key == "particle_distribution":
            energy_range = cfg_entry.get("energy_range")
            blobs.append((energy_range, pdist(energy_range)))
        elif key == "total_particle_energy":
            # Build the default threshold only when it is actually needed.
            if "e_min" in cfg_entry:
                e_min = cfg_entry["e_min"]
            else:
                e_min = 1 * u.TeV

            # Models sharing one particle distribution object describe the
            # same particle population, which must be counted only once.
            pdists = [getattr(r, "particle_distribution", None) for r in rmodels]
            all_same = all(p is not None and p is pdists[0] for p in pdists)

            try:
                if all_same:
                    total_energy = _total_particle_energy(rmodels[0], e_min)
                else:
                    total_energy = sum(
                        _total_particle_energy(r, e_min) for r in rmodels
                    )
            except Exception:
                logger.exception(
                    "Failed to compute total particle energy with e_min=%s", e_min
                )
                raise
            blobs.append(total_energy)
    return blobs