"""
This module provides spectral moment calculation utilities for IRIS spectrogram cubes.
"""
import numpy as np
import astropy.units as u
from astropy import constants
from irispy.spectrograph import RasterCollection
from irispy.utils._spectral import make_map_cube, make_spatial_template
# Public API of this module: only ``calculate_moments`` is exported by a
# star-import; underscore-prefixed helpers are left out.
__all__ = ["calculate_moments"]
def _parse_wings(wings):
    """
    Normalise the ``wings`` argument into a ``(lower, upper)`` pair of offsets.

    Parameters
    ----------
    wings : `astropy.units.Quantity`, `tuple` or `list`
        Either a scalar Quantity (used for both sides), a one-dimensional
        two-element Quantity, or a tuple/list of two scalar Quantities.

    Returns
    -------
    `tuple`
        ``(wing_low, wing_high)``, each a scalar `~astropy.units.Quantity`.

    Raises
    ------
    ValueError
        If a non-scalar Quantity does not have shape ``(2,)``, or if an
        element of a two-element tuple/list is not a scalar Quantity.
    TypeError
        If ``wings`` is neither a Quantity nor a two-element tuple/list, or
        if a tuple/list element is not a Quantity.
    """
    if isinstance(wings, u.Quantity):
        if wings.isscalar:
            return wings, wings
        # ``len`` only reports the leading dimension, so a (2, N) Quantity would
        # slip through and yield array-valued wings; compare the full shape.
        if wings.shape != (2,):
            msg = "wings must be a scalar Quantity or a two-element Quantity"
            raise ValueError(msg)
        return wings[0], wings[1]

    if isinstance(wings, (tuple, list)) and len(wings) == 2:
        if not all(isinstance(wing, u.Quantity) for wing in wings):
            msg = "wings tuple elements must be astropy.units.Quantity"
            raise TypeError(msg)
        # Each side of the window is a single offset, never an array.
        if not all(wing.isscalar for wing in wings):
            msg = "wings tuple elements must be scalar Quantities"
            raise ValueError(msg)
        return wings[0], wings[1]

    msg = "wings must be an astropy.units.Quantity or a tuple of two Quantities"
    raise TypeError(msg)
# NOTE: We do not use @u.quantity_input because it cannot validate a tuple of
# Quantities such as ``(0.05 * u.Angstrom, 0.15 * u.Angstrom)`` for the ``wings`` argument.
def calculate_moments(
    cube, *, rest_wavelength=None, wings=None, integrated=False, min_intensity=None, saturation_limit=None
):
    r"""
    Calculate the 0th, 1st, and 2nd spectral moments of a data cube.

    The moments are computed along the spectral (wavelength) axis for each
    spatial pixel:

    * 0th moment: total intensity, :math:`\sum I(\lambda_i)` (or :math:`\int I(\lambda) \, d\lambda` when ``integrated=True``)
    * 1st moment: centroid wavelength, :math:`\sum \lambda_i I(\lambda_i) / \sum I(\lambda_i)`
    * 2nd moment: standard deviation, :math:`\sqrt{\sum (\lambda_i - \lambda_0)^2 I(\lambda_i) / \sum I(\lambda_i)}`,
      where :math:`\lambda_0` is the centroid (1st moment)

    Parameters
    ----------
    cube : `irispy.spectrograph.SpectrogramCube`
        The input data cube. Must have a spectral (wavelength) axis.
    rest_wavelength : `astropy.units.Quantity`, optional
        The rest wavelength of the spectral line. Required if ``wings`` is given.
    wings : `astropy.units.Quantity` or `tuple` of two `astropy.units.Quantity`, optional
        The spectral range around ``rest_wavelength`` to include in the calculation.
        Must carry appropriate units (e.g., nm or Angstrom).
        If a scalar Quantity, it is applied symmetrically. If a two-element Quantity or a
        tuple of two Quantities, they are the lower and upper offsets respectively.
    integrated : `bool`, optional
        If `True`, the 0th moment is computed as :math:`\int I(\lambda) \, d\lambda`
        with units of the cube unit times nm (e.g. ``DN·nm``). If `False` (default), it is
        computed as :math:`\sum I(\lambda)` with the units of the cube (e.g. ``DN``; i.e.,
        per-pixel sum, matching the convention used in Gaussian fitting).
    min_intensity : `float` or `astropy.units.Quantity`, optional
        Minimum integrated (or per-pixel) intensity required for a pixel to be
        considered valid. Pixels with intensity **below** this value have all
        their moments (including intensity) set to NaN. Useful for excluding
        noisy low-signal pixels. A plain number is interpreted in the units of
        the 0th moment; a Quantity is converted to them.
    saturation_limit : `float` or `astropy.units.Quantity`, optional
        Maximum allowed peak intensity in any spectral bin. Pixels where any
        bin in the (cropped) profile exceeds this value are treated as
        saturated and have all their moments set to NaN. A plain number is
        interpreted in the units of the cube; a Quantity is converted to them.

    Returns
    -------
    `irispy.spectrograph.RasterCollection`
        A collection containing 2D `~irispy.spectrograph.SpectrogramCube`
        objects with the spatial WCS preserved from the input cube.

        Always present:

        * ``"intensity"`` — 0th moment (total intensity)
        * ``"centroid"`` — 1st moment (centroid wavelength)
        * ``"width"`` — 2nd moment (standard deviation)

        Additionally, if ``rest_wavelength`` is provided:

        * ``"velocity"`` — Doppler shift from the centroid in km/s
        * ``"velocity_width"`` — line width converted to velocity units in km/s

    Raises
    ------
    ValueError
        If no wavelength axis can be identified on ``cube``, if ``wings`` is
        given without ``rest_wavelength``, if no wavelength sample falls inside
        the requested wings, or if ``integrated=True`` and the cube has fewer
        than two wavelength samples (no spacing can be defined).
    TypeError
        If ``wings`` is not a Quantity or a pair of Quantities.

    Notes
    -----
    * Masked, negative and non-finite (NaN/inf) data values are set to zero before computing
      moments. The saturation check is evaluated on these cleaned values.
    * Wavelength coordinates are converted to **nm** internally, so ``centroid`` and ``width``
      are always returned in nm.
    * For a uniform spectral grid, the 1st and 2nd moments are identical regardless of the
      ``integrated`` setting because the pixel spacing cancels out in the ratio.
    * The spectral grid is assumed uniform: :math:`d\lambda` is the absolute mean spacing of
      the (cropped) wavelength samples. If the crop leaves a single sample, the spacing of the
      uncropped grid is used instead.

    References
    ----------
    * `Spectral-Cube moment maps <https://spectral-cube.readthedocs.io/en/latest/moments.html#moment-map-equations>`__
    * `arXiv:2005.02029, Section 3.1 <https://arxiv.org/abs/2005.02029>`__
    * `Færder et al. (2024), ApJ, Appendix C <https://iopscience.iop.org/article/10.3847/1538-4357/ac4223>`__
    """
    # Locate the array axis whose physical type is a wavelength ("em.wl").
    try:
        wavelength_axis = next(
            axis
            for axis, physical_types in enumerate(cube.array_axis_physical_types)
            if physical_types and "em.wl" in physical_types
        )
    except StopIteration as exc:
        msg = "Could not identify a spectral wavelength axis on the input cube"
        raise ValueError(msg) from exc

    wavelengths = cube.axis_world_coords(wavelength_axis)[0]
    if not isinstance(wavelengths, u.Quantity):
        wavelengths = wavelengths * u.one
    wavelengths = wavelengths.to(u.nm)
    # Keep the uncropped grid so a spacing is still available when the wings
    # select a single sample.
    full_wavelengths = wavelengths

    # Normalise once; used both for cropping and for the velocity maps.
    if rest_wavelength is not None:
        rest_wavelength = u.Quantity(rest_wavelength)

    data = np.asarray(cube.data)
    mask = None if cube.mask is None else np.asarray(cube.mask, dtype=bool)

    # Restrict the spectral axis to [rest - wing_low, rest + wing_high].
    if wings is not None:
        if rest_wavelength is None:
            msg = "rest_wavelength must be provided when wings is given"
            raise ValueError(msg)
        wing_low, wing_high = _parse_wings(wings)
        rest_unit = rest_wavelength.unit
        wavelengths_in_rest_unit = wavelengths.to(rest_unit)
        wvl_min = rest_wavelength - wing_low.to(rest_unit)
        wvl_max = rest_wavelength + wing_high.to(rest_unit)
        crop_mask = (wavelengths_in_rest_unit >= wvl_min) & (wavelengths_in_rest_unit <= wvl_max)
        crop_indices = np.flatnonzero(crop_mask)
        if crop_indices.size == 0:
            msg = "No wavelength points found within the specified wings"
            raise ValueError(msg)
        data = np.take(data, crop_indices, axis=wavelength_axis)
        if mask is not None:
            mask = np.take(mask, crop_indices, axis=wavelength_axis)
        wavelengths = wavelengths[crop_mask]

    # Work on a float copy so the cube's own data is never modified.
    data = np.array(data, dtype=float, copy=True)
    if mask is not None:
        data[mask] = 0
    data[(data < 0) | ~np.isfinite(data)] = 0

    # Move wavelength axis to the end for vectorised computation
    data_moved = np.moveaxis(data, wavelength_axis, -1)
    # Broadcast wavelengths to match moved data shape
    wvls_broadcast = wavelengths.value.reshape((1,) * (data_moved.ndim - 1) + (-1,))

    # Weights for moment calculation: integrated (x dλ) or per-pixel
    if integrated:
        # Wavelength step (assumed uniform). A single-sample crop has no
        # spacing of its own (np.diff would be empty and its mean NaN, which
        # silently zeroes every intensity), so fall back to the full grid.
        spacing_grid = wavelengths if wavelengths.size > 1 else full_wavelengths
        if spacing_grid.size < 2:
            msg = "At least two wavelength points are required to compute an integrated intensity"
            raise ValueError(msg)
        # abs(): a descending wavelength axis must not flip the sign of the intensity.
        dwvl = np.abs(np.mean(np.diff(spacing_grid)))
        weights = data_moved * dwvl.value
        intensity_unit = cube.unit * dwvl.unit
    else:
        weights = data_moved
        intensity_unit = cube.unit

    # 0th moment
    intensity_value = np.nansum(weights, axis=-1)

    # 1st and 2nd moments are ratios where the weighting cancels out,
    # so the result is the same for uniform dλ regardless of ``integrated``.
    intensity_nonzero = intensity_value != 0
    centroid_numerator = np.nansum(weights * wvls_broadcast, axis=-1)
    with np.errstate(invalid="ignore"):
        centroid_value = np.where(intensity_nonzero, centroid_numerator / intensity_value, np.nan)

    # 2nd moment (variance about the centroid)
    variance_numerator = np.nansum(((wvls_broadcast - centroid_value[..., np.newaxis]) ** 2) * weights, axis=-1)
    with np.errstate(invalid="ignore"):
        variance_value = np.where(intensity_nonzero, variance_numerator / intensity_value, np.nan)
        variance_value = np.where(variance_value < 0, np.nan, variance_value)
    stddev_value = np.sqrt(variance_value)

    # Collect every rejected pixel in one mask, then blank all moments at once.
    rejected = np.zeros(np.shape(intensity_value), dtype=bool)
    if min_intensity is not None:
        if isinstance(min_intensity, u.Quantity):
            min_intensity_value = min_intensity.to_value(intensity_unit)
        else:
            min_intensity_value = min_intensity
        rejected = rejected | (intensity_value < min_intensity_value)
    if saturation_limit is not None:
        if isinstance(saturation_limit, u.Quantity):
            saturation_limit_value = saturation_limit.to_value(cube.unit)
        else:
            saturation_limit_value = saturation_limit
        rejected = rejected | (np.max(data_moved, axis=-1) > saturation_limit_value)
    if min_intensity is not None or saturation_limit is not None:
        intensity_value = np.where(rejected, np.nan, intensity_value)
        centroid_value = np.where(rejected, np.nan, centroid_value)
        stddev_value = np.where(rejected, np.nan, stddev_value)

    template = make_spatial_template(cube, wavelength_axis)
    cubes = [
        ("intensity", make_map_cube(template, intensity_value, intensity_unit, mask_invalid=True)),
        ("centroid", make_map_cube(template, centroid_value, wavelengths.unit, mask_invalid=True)),
        ("width", make_map_cube(template, stddev_value, wavelengths.unit, mask_invalid=True)),
    ]

    # Compute velocity equivalents when rest_wavelength is known
    if rest_wavelength is not None:
        light_speed = constants.c.to(u.km / u.s)
        centroid_in_rest_unit = (centroid_value * wavelengths.unit).to(rest_wavelength.unit)
        width_in_rest_unit = (stddev_value * wavelengths.unit).to(rest_wavelength.unit)
        with np.errstate(invalid="ignore"):
            velocity_value = (centroid_in_rest_unit - rest_wavelength) / rest_wavelength * light_speed
            velocity_width_value = width_in_rest_unit / rest_wavelength * light_speed
        velocity_cube = make_map_cube(template, velocity_value.value, velocity_value.unit, mask_invalid=True)
        velocity_width_cube = make_map_cube(
            template, velocity_width_value.value, velocity_width_value.unit, mask_invalid=True
        )
        cubes.extend([("velocity", velocity_cube), ("velocity_width", velocity_width_cube)])

    return RasterCollection(cubes, aligned_axes=tuple(range(len(template.shape))))