# Source code for ctis.inverters._iterative._mart._mart
import warnings
import dataclasses
import numpy as np
import named_arrays as na
import ctis
from .. import AbstractIterativeInverter, IterativeInversionResult
# Public API of this module: only the MART inverter class is exported.
__all__ = ["MartInverter"]
[docs]
@dataclasses.dataclass
class MartInverter(
    AbstractIterativeInverter,
):
    """
    An inversion routine based on the multiplicative algebraic reconstruction
    technique (MART) :cite:t:`Gordon1970`.

    For further information, see the discussion :doc:`../discussions/mart-discussion`.
    """

    instrument: ctis.instruments.AbstractInstrument = dataclasses.MISSING
    """
    A model of a CTIS instrument which transforms the radiance of an observed
    scene to photons measured by the sensors.
    """

    gamma: None | float = None
    r"""
    Learning rate, :math:`\gamma`.

    At every iteration, the current correction, :math:`C`, is replaced by
    :math:`C^\gamma`.
    If :obj:`None`, :math:`\gamma = 2 / N`, where :math:`N` is the number of
    channels.
    """

    threshold_convergence: float = 1e-3
    r"""
    The convergence threshold, :math:`T`, which halts the iteration.

    If :math:`\langle \chi_{i-1}^2 \rangle - \langle \chi_{i}^2 \rangle < T`,
    then the algorithm is considered to be converged.
    """

    def __post_init__(self):
        # Resolve the default learning rate once, at construction time, so the
        # value actually used by `__call__` is stored on the instance.
        if self.gamma is None:
            self.gamma = 2 / self.instrument.num_channel

    def __call__(
        self,
        images: na.FunctionArray[na.SpectralPositionalVectorArray, na.ScalarArray],
        guess: None | na.ScalarArray = None,
        verbose: bool = False,
    ) -> IterativeInversionResult:
        """
        Reconstruct a scene using the observed images.

        Parameters
        ----------
        images
            The observed images used to calculate the reconstruction.
            Must be evaluated on the same position coordinates as
            :attr:`~ctis.instruments.AbstractInstrument.coordinates_sensor`
            attribute of :attr:`instrument`.
        guess
            The initial guess at the reconstructed scene.
            Must be evaluated on the same coordinates as
            :attr:`~ctis.instruments.AbstractInstrument.coordinates_scene`
            attribute of :attr:`instrument`.
            If :obj:`None`, a uniform scene is used whose value is the mean of
            the channel-averaged backprojection of `images`.
            The guess is copied and is never modified.
        verbose
            If :obj:`True`, print the iteration index and the merit
            (channel-averaged mean chi-squared) at every iteration.

        Returns
        -------
        IterativeInversionResult
            The reconstructed scene(s), the input images, a success flag and
            message, the number of iterations performed, and the per-iteration
            mean chi-squared and correlation residual.

        Raises
        ------
        ValueError
            If the position coordinates of `images` differ from the sensor
            position coordinates of :attr:`instrument`.
        """
        instrument = self.instrument
        axis_channel = instrument.axis_channel

        position_images = images.inputs.position
        position_sensor = instrument.coordinates_sensor.position
        if not np.all(position_images == position_sensor):
            raise ValueError(
                "`images.inputs.position` and "
                "`self.instrument.coordinates_sensor.position` are not equal."
            )

        images_inputs = images.inputs
        images_outputs = images.outputs

        # The backprojection of the observed images is needed both for the
        # default initial guess and for every correction factor, so compute
        # it only once.
        backprojected = instrument.backproject(images_outputs).outputs

        if guess is None:
            # Uniform starting scene: average over channels, then set every
            # element to the global mean of that average.
            scene = backprojected.mean(axis_channel)
            scene.ndarray[:] = scene.ndarray.mean()
        else:
            # Copy so the in-place update below never touches the caller's array.
            scene = guess.copy()

        num_channel = instrument.num_channel
        gamma = self.gamma

        # Clip negative values so the multiplicative correction is never
        # negative.
        backprojected = np.maximum(backprojected, 0)

        intermediate = []
        chi2_history = []
        residual_history = []
        merit_old = np.inf

        for i in range(self.num_iteration):
            if self.intermediate:
                intermediate.append(scene)

            if verbose:  # pragma: nocover
                print(f"{i=}")

            # Forward-model the current scene and compare with the observations.
            images_new = instrument.image(scene, noise=False).outputs

            chi2_ij = self.mean_chi_squared(images_outputs, images_new)
            r_ij = self.correlation_residual(images_outputs, images_new)
            chi2_history.append(chi2_ij)
            residual_history.append(r_ij)

            merit = chi2_ij.mean(axis_channel)
            if verbose:  # pragma: nocover
                print(f"merit: {merit}")

            if merit > merit_old:  # pragma: nocover
                message = "Failure: merit increasing."
                success = False
                num_iteration = i + 1
                warnings.warn(message)
                break
            elif (merit_old - merit) < self.threshold_convergence:
                message = f"Decrease in merit less than {self.threshold_convergence}."
                success = True
                num_iteration = i + 1
                break

            backprojected_new = instrument.backproject(images_new).outputs
            backprojected_new = np.maximum(backprojected_new, 0)

            # Zero denominators are expected here; their non-finite results
            # are replaced by 1 (i.e. "no correction") just below, so the
            # floating-point warnings would only be noise.
            with np.errstate(divide="ignore", invalid="ignore"):
                correction = backprojected / backprojected_new
            correction = np.nan_to_num(
                x=correction,
                nan=1,
                posinf=1,
                neginf=1,
            )

            # Geometric mean over channels of the damped correction C**gamma.
            correction = correction**gamma
            correction = np.prod(correction, axis=axis_channel)
            correction = correction ** (1 / num_channel)

            if self.intermediate:
                # Out-of-place, so scenes already stored in `intermediate`
                # are not mutated.
                scene = scene * correction
            else:
                scene *= correction

            merit_old = merit
        else:
            message = f"Max number of iterations ({self.num_iteration}) exceeded."
            warnings.warn(message)
            success = False
            num_iteration = self.num_iteration

        if self.intermediate:
            solutions = na.stack(intermediate, axis=self.axis_iteration)
        else:
            solutions = scene.add_axes(self.axis_iteration)

        solutions = na.FunctionArray(
            inputs=instrument.coordinates_scene,
            outputs=solutions,
        )
        images_result = na.FunctionArray(
            inputs=images_inputs,
            outputs=images_outputs,
        )

        mean_chi_squared = na.stack(chi2_history, axis=self.axis_iteration)
        correlation_residual = na.stack(residual_history, axis=self.axis_iteration)

        return IterativeInversionResult(
            solutions=solutions,
            success=success,
            images=images_result,
            inverter=self,
            message=message,
            num_iteration=num_iteration,
            mean_chi_squared=mean_chi_squared,
            correlation_residual=correlation_residual,
        )