Source code for ctis.inverters._iterative._iterative
from typing import ClassVar
import dataclasses
import named_arrays as na
import ctis
from .. import AbstractInverter, AbstractInversionResult
# Names exported by ``from <this module> import *``.
__all__ = ["AbstractIterativeInverter", "IterativeInversionResult"]
[docs]
@dataclasses.dataclass
class AbstractIterativeInverter(
    AbstractInverter,
):
    """
    An abstract inversion algorithm which reconstructs an observed scene
    using iterative methods.

    Concrete subclasses apply some operation over and over until a
    specified convergence criterion is met.
    """

    axis_iteration: ClassVar[str] = "iteration"
    """The logical axis associated with changing iteration index."""

    num_iteration: int = dataclasses.field(kw_only=True, default=100)
    """
    The maximum number of iterations to perform.
    If convergence is not reached before this number is exceeded,
    a warning is raised and an unsuccessful result is returned.
    """

    intermediate: bool = dataclasses.field(kw_only=True, default=False)
    """
    Whether to save intermediate solutions.
    This is :obj:`False` during normal operation, but enabling it can be
    useful for debugging or demonstration purposes.
    """

    def mean_chi_squared(
        self,
        images_observed: na.ScalarArray,
        images_predicted: na.ScalarArray,
    ) -> na.ScalarArray:
        r"""
        Evaluate :math:`\langle \chi^2 \rangle` for each observed/predicted
        image pair.

        The uncertainty is computed from the predicted images using
        ``self.instrument.uncertainty``, and the statistic is evaluated along
        the sensor axes ``self.instrument.axis_sensor_xy``.

        Parameters
        ----------
        images_observed
            The actual measured images.
        images_predicted
            The images predicted by the inversion.

        Returns
        -------
        The result of :func:`ctis.inverters.merit.mean_chi_squared`.
        """
        instrument = self.instrument
        sigma = instrument.uncertainty(images_predicted)
        merit = ctis.inverters.merit
        return merit.mean_chi_squared(
            observed=images_observed,
            expected=images_predicted,
            uncertainty=sigma,
            axis=instrument.axis_sensor_xy,
        )

    def correlation_residual(
        self,
        images_observed: na.ScalarArray,
        images_predicted: na.ScalarArray,
    ) -> na.ScalarArray:
        """
        Evaluate the correlation between the predicted images and the residual.

        The correlation is evaluated along the sensor axes
        ``self.instrument.axis_sensor_xy``.

        Parameters
        ----------
        images_observed
            The actual measured images.
        images_predicted
            The images predicted by the inversion.

        Returns
        -------
        The result of :func:`ctis.inverters.merit.correlation_residual`.
        """
        merit = ctis.inverters.merit
        sensor_axes = self.instrument.axis_sensor_xy
        return merit.correlation_residual(
            observed=images_observed,
            expected=images_predicted,
            axis=sensor_axes,
        )
[docs]
@dataclasses.dataclass
class IterativeInversionResult(
    AbstractInversionResult,
):
    """The results of an iterative inversion attempt."""

    solutions: na.FunctionArray[na.SpectralPositionalVectorArray, na.ScalarArray]
    """
    Intermediate solutions from each iteration.
    If :attr:`AbstractIterativeInverter.intermediate` is set to :obj:`True`,
    this has up to :attr:`~AbstractIterativeInverter.num_iteration` elements
    along the :attr:`~AbstractIterativeInverter.axis_iteration` logical axis.
    Otherwise this has only one element along the
    :attr:`~AbstractIterativeInverter.axis_iteration` axis.
    """

    # The remaining fields use ``dataclasses.MISSING`` as their "default".
    # The dataclass machinery treats that sentinel as *no* default, so every
    # one of them is still a required constructor argument. Assigning it
    # explicitly also prevents any same-named class attribute inherited from
    # a base class from being picked up as the field's default. Presumably
    # that is why it is used here (members declared on AbstractInversionResult
    # -- confirm against the parent class).
    success: bool = dataclasses.MISSING
    """A boolean flag indicating whether the inversion was successful."""

    images: na.FunctionArray[
        na.SpectralPositionalVectorArray,
        na.ScalarArray,
    ] = dataclasses.MISSING
    """The observed images on which the inversion was performed."""

    # Annotated with the iterative inverter type (as a string, so it is not
    # evaluated at class-creation time). The properties below rely on
    # ``inverter.axis_iteration``, which is declared on
    # AbstractIterativeInverter.
    inverter: "AbstractIterativeInverter" = dataclasses.MISSING
    """The inversion algorithm instance that produced these results."""

    message: str = dataclasses.MISSING
    """Any message from the inversion routine concerning the results."""

    num_iteration: int = dataclasses.MISSING
    """The number of iterations performed by the inverter."""

    mean_chi_squared: na.ScalarArray = dataclasses.MISSING
    """The mean chi squared statistic for each iteration."""

    correlation_residual: na.ScalarArray = dataclasses.MISSING
    """
    The correlation between the predicted images and the residuals
    for each iteration.
    """

    @property
    def solution(
        self,
    ) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.ScalarArray]:
        """
        The final solution of the inversion.

        This is the last element of :attr:`solutions` along the
        :attr:`~AbstractIterativeInverter.axis_iteration` logical axis.
        """
        axis_iteration = self.inverter.axis_iteration
        return self.solutions[{axis_iteration: -1}]

    @property
    def iteration(self) -> na.ScalarArray:
        """
        The iteration value for each iteration.

        The values are produced by :func:`named_arrays.arange` with
        ``start=0`` and ``stop=num_iteration``, along the
        :attr:`~AbstractIterativeInverter.axis_iteration` logical axis.
        """
        return na.arange(
            start=0,
            stop=self.num_iteration,
            axis=self.inverter.axis_iteration,
        )