import abc
import dataclasses
import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u
import astropy.visualization
import named_arrays as na
import ctis
# Public names exported by ``from <this module> import *``.
__all__ = ["InversionResult"]
@dataclasses.dataclass
class AbstractInversionResult(
    abc.ABC,
):
    """An interface describing the results of an inversion attempt."""

    @property
    @abc.abstractmethod
    def solution(
        self,
    ) -> na.FunctionArray[
        na.AbstractDopplerPositionalVectorArray,
        na.ScalarArray,
    ]:
        """The reconstructed scene found by the inversion."""

    @property
    @abc.abstractmethod
    def success(self) -> bool:
        """Whether the inversion was successful."""

    # NOTE(review): the concrete InversionResult annotates ``images`` with
    # na.SpectralPositionalVectorArray inputs rather than
    # na.AbstractDopplerPositionalVectorArray; confirm which is intended.
    @property
    @abc.abstractmethod
    def images(
        self,
    ) -> na.FunctionArray[
        na.AbstractDopplerPositionalVectorArray,
        na.ScalarArray,
    ]:
        """The observed images used to calculate the inversion."""

    @property
    @abc.abstractmethod
    def inverter(self) -> "ctis.inverters.AbstractInverter":
        """The inversion algorithm instance that produced these results."""

    @property
    @abc.abstractmethod
    def message(self) -> str:
        """Any message from the inverter regarding these results."""

    @staticmethod
    def _spectral_moments(
        scene: na.FunctionArray[
            na.AbstractDopplerPositionalVectorArray,
            na.ScalarArray,
        ],
        axis: str,
    ) -> tuple[na.ScalarArray, na.ScalarArray, na.ScalarArray]:
        """
        Compute the radiance, median velocity and velocity IQR of a scene.

        Parameters
        ----------
        scene
            The scene to reduce.
        axis
            The name of the logical axis along which ``scene.inputs.wavelength``
            varies; all three moments are reduced along it.

        Returns
        -------
        The wavelength-integrated radiance (``outputs`` times the wavelength
        cell width, summed along `axis`), followed by the median and the
        interquartile range of ``scene.inputs.velocity`` weighted by
        ``scene.outputs``, as computed by :func:`named_arrays.pdf.median` and
        :func:`named_arrays.pdf.iqr`.
        """
        velocity = scene.inputs.velocity
        outputs = scene.outputs
        dw = scene.inputs.wavelength.volume_cell(axis)
        radiance = (outputs * dw).sum(axis)
        median = na.pdf.median(x=velocity, f=outputs, axis=axis)
        iqr = na.pdf.iqr(x=velocity, f=outputs, axis=axis)
        return radiance, median, iqr

    @staticmethod
    def _conditional_histogram(
        x_truth: na.ScalarArray,
        x_recon: na.ScalarArray,
        bins: dict[str, int],
        lower,
        upper,
    ):
        """
        Histogram reconstructed values against true values, normalized per true bin.

        Each ``"true"`` bin is divided by its total over the ``"reconstructed"``
        axis, so every populated true bin sums to one. Empty true bins produce
        ``0 / 0``; those NaNs are replaced with zero.

        Parameters
        ----------
        x_truth
            The true values, binned along the ``"true"`` axis.
        x_recon
            The reconstructed values, binned along the ``"reconstructed"`` axis.
        bins
            The number of bins for the ``"true"`` and ``"reconstructed"`` axes.
        lower
            The lower edge of the histogram range.
        upper
            The upper edge of the histogram range.
        """
        hist = na.histogram2d(x_truth, x_recon, bins=bins, min=lower, max=upper)
        hist = hist / hist.sum("reconstructed")
        hist.outputs = np.nan_to_num(hist.outputs, posinf=0, neginf=0)
        return hist

    def plot_moments(
        self,
        truth: na.FunctionArray[
            na.AbstractDopplerPositionalVectorArray,
            na.ScalarArray,
        ],
        num_bins: int = 50,
        range_radiance: None | tuple[u.Quantity, u.Quantity] = None,
        range_median: None | tuple[u.Quantity, u.Quantity] = None,
        range_iqr: None | tuple[u.Quantity, u.Quantity] = None,
    ) -> tuple[plt.Figure, np.ndarray]:
        """
        Plot 2D histograms comparing the spectral moments of :attr:`solution`
        with those of a known ground truth.

        Three moments are computed for both scenes along
        ``self.inverter.instrument.axis_wavelength``:

        - the wavelength-integrated radiance;
        - the median of ``inputs.velocity`` weighted by ``outputs``;
        - the interquartile range (IQR) of that same distribution.

        Each panel is a histogram of reconstructed value against true value.
        It is normalized so that every populated true-value bin sums to one
        over the reconstructed bins. A dashed red line marks
        reconstructed = true.

        Parameters
        ----------
        truth
            The ground-truth scene to compare against :attr:`solution`.
            It must vary in wavelength along the same logical axis as the
            solution.
        num_bins
            The number of histogram bins along both the true and the
            reconstructed axis of every panel.
        range_radiance
            The ``(min, max)`` range of the radiance histogram.
            A :obj:`None` range or element means: minimum 0 in the radiance
            unit, maximum the largest non-NaN true radiance.
        range_median
            The ``(min, max)`` range of the median histogram.
            :obj:`None` means the smallest and largest non-NaN true median.
        range_iqr
            The ``(min, max)`` range of the IQR histogram.
            :obj:`None` means a minimum of 0 in the IQR unit and a maximum of
            the largest non-NaN true IQR.

        Returns
        -------
        The figure and its array of three axes (radiance, median, IQR).
        """
        axis_wavelength = self.inverter.instrument.axis_wavelength

        # Reduce both scenes along the instrument's wavelength axis. All three
        # moments use the same axis, so they agree even when that axis is not
        # literally named "wavelength".
        radiance_truth, median_truth, iqr_truth = self._spectral_moments(
            truth,
            axis_wavelength,
        )
        radiance_recon, median_recon, iqr_recon = self._spectral_moments(
            self.solution,
            axis_wavelength,
        )

        min_radiance, max_radiance = (
            (None, None) if range_radiance is None else range_radiance
        )
        min_median, max_median = (
            (None, None) if range_median is None else range_median
        )
        min_iqr, max_iqr = (None, None) if range_iqr is None else range_iqr

        # Fill unspecified limits from the ground truth. Every data-derived
        # limit uses a NaN-aware reduction; otherwise a single NaN moment
        # would make that histogram's range NaN.
        if min_radiance is None:
            min_radiance = 0 * radiance_truth.unit
        if max_radiance is None:
            max_radiance = np.nanmax(radiance_truth)
        if min_median is None:
            min_median = np.nanmin(median_truth)
        if max_median is None:
            max_median = np.nanmax(median_truth)
        if min_iqr is None:
            min_iqr = 0 * iqr_truth.unit
        if max_iqr is None:
            max_iqr = np.nanmax(iqr_truth)

        bins = dict(true=num_bins, reconstructed=num_bins)
        panels = [
            ("radiance", radiance_truth, radiance_recon, min_radiance, max_radiance),
            ("median", median_truth, median_recon, min_median, max_median),
            ("IQR", iqr_truth, iqr_recon, min_iqr, max_iqr),
        ]
        hists = [
            self._conditional_histogram(x_truth, x_recon, bins, lower, upper)
            for _, x_truth, x_recon, lower, upper in panels
        ]

        with astropy.visualization.quantity_support():
            fig, axs = plt.subplots(
                constrained_layout=True,
                figsize=(10, 4),
                ncols=3,
            )
            for ax, hist, (name, x_truth, _, _, _) in zip(axs, hists, panels):
                img = na.plt.pcolormesh(
                    C=hist,
                    ax=ax,
                    # Clip the colour scale at the 99.5th percentile so a few
                    # very dense bins do not wash out the rest of the panel.
                    vmax=np.nanpercentile(hist.outputs, 99.5),
                )
                # The line y = x is where a perfect reconstruction would fall.
                # The anchor is just one point on that line.
                anchor = np.nanmean(x_truth).ndarray.value
                ax.axline(
                    (anchor, anchor),
                    slope=1,
                    color="tab:red",
                    linestyle="dashed",
                )
                plt.colorbar(
                    img.ndarray.item(),
                    ax=ax,
                    location="top",
                    label="probability",
                )
                # Wrap the existing axis label (presumably the unit label set
                # under quantity_support) in a descriptive name.
                ax.set_xlabel(f"true {name} ({ax.get_xlabel()})")
                ax.set_ylabel(f"reconstructed {name} ({ax.get_ylabel()})")
                ax.set_aspect("equal")

        return fig, axs
@dataclasses.dataclass
class InversionResult(
    AbstractInversionResult,
):
    """
    The results of an inversion attempt.
    """

    # Each field defaults to ``dataclasses.MISSING``, which dataclasses treats
    # as "no default", so every field is a required ``__init__`` argument.
    # The class attribute this creates also shadows the abstract property of
    # the same name inherited from AbstractInversionResult, so the class is
    # no longer abstract and ``__init__`` can assign the instance attributes.

    solution: na.FunctionArray[
        na.AbstractDopplerPositionalVectorArray, na.ScalarArray
    ] = dataclasses.MISSING
    """The reconstructed scene found by the inversion."""

    success: bool = dataclasses.MISSING
    """A boolean flag indicating whether the inversion was successful."""

    images: na.FunctionArray[na.SpectralPositionalVectorArray, na.ScalarArray] = (
        dataclasses.MISSING
    )
    """The observed images on which the inversion was performed."""

    inverter: "ctis.inverters.AbstractInverter" = dataclasses.MISSING
    """The inversion algorithm instance that produced these results."""

    message: str = dataclasses.MISSING
    """Any message from the inversion routine concerning the results."""