Invert a Synthetic Scene with MART
In this tutorial, we’ll use the Multiplicative Algebraic Reconstruction Technique (MART) [Gordon et al., 1970] to invert the gaussians() sample scene developed by Amy R. Winebarger.
We will use a simple instrument with four projections to image this scene and then use the MART algorithm to reconstruct the original scene.
[1]:
import IPython.display
import matplotlib.pyplot as plt
import astropy.units as u
import astropy.visualization
import named_arrays as na
import ctis
Start by defining a grid of Doppler velocities on which to reconstruct the scene.
[2]:
# Doppler-velocity grid: 21 evenly spaced samples between -500 and +500 km/s,
# laid out along the "wavelength" axis.
velocity = na.linspace(
    -500,
    500,
    axis="wavelength",
    num=21,
) * u.km / u.s
Define the rest wavelength for converting between velocity and wavelength.
[3]:
# Rest wavelength (171 Å) used as the reference for velocity <-> wavelength conversion.
wavelength_rest = u.Quantity(171, u.AA)
Now define a grid of positions on which to reconstruct the scene,
[4]:
# Scene-plane grid spanning -10 to +10 arcsec in both directions, with
# 64 + 1 points per axis (presumably the vertices of 64 cells per axis).
position_scene = na.Cartesian2dVectorLinearSpace(
    start=-10 * u.arcsec,
    stop=10 * u.arcsec,
    axis=na.Cartesian2dVectorArray("scene_x", "scene_y"),
    num=na.Cartesian2dVectorArray(64 + 1, 64 + 1),
)
and a grid of positions on the sensor representing the vertices of each pixel.
[5]:
# Sensor pixel-vertex grid in pixel units: 128 + 1 vertices along x and
# 64 + 1 vertices along y.
position_sensor = na.Cartesian2dVectorArray(
    x=na.arange(0, 128 + 1, axis="sensor_x") * u.pix,
    y=na.arange(0, 64 + 1, axis="sensor_y") * u.pix,
)
Combine the 1D velocity grid and the 2D position grid into a single 3D grid for both the scene and sensor coordinates.
[6]:
# 3D scene coordinates: the velocity grid (converted via the rest wavelength)
# combined with the 2D scene-plane positions.
coordinates_scene = na.DopplerPositionalVectorArray.from_velocity(
    velocity=velocity,
    wavelength_rest=wavelength_rest,
    position=position_scene,
)
[7]:
# 3D sensor coordinates: the same velocity grid combined with the sensor
# pixel-vertex positions.
coordinates_sensor = na.DopplerPositionalVectorArray.from_velocity(
    velocity=velocity,
    wavelength_rest=wavelength_rest,
    position=position_sensor,
)
Create a synthetic scene composed of spatial/spectral 3D Gaussians with various Doppler shifts.
[8]:
# Build the synthetic test scene on the scene coordinate grid.
scene = ctis.scenes.gaussians(coordinates_scene)
Add a small background equal to 1 percent of the maximum value of the scene.
[9]:
# Add a uniform background equal to 1% of the scene's maximum value.
# NOTE(review): presumably this keeps every voxel strictly positive, which a
# multiplicative update like MART needs (a zero can never grow) — confirm.
scene = scene + scene.max() / 100
Display the scene as a false-color image.
[10]:
# Show the scene as a false-color image, with a color legend on a narrow
# side axis. The body must be indented under the `with` block to be valid Python.
with astropy.visualization.quantity_support():
    fig, axs = plt.subplots(
        ncols=2,
        gridspec_kw=dict(width_ratios=[.9, .1]),
        constrained_layout=True,
    )
    ax, cax = axs

    # rgbmesh draws the image on `ax`; its return value is then rendered
    # on `cax` as the wavelength color legend.
    colorbar = na.plt.rgbmesh(
        C=scene,
        axis_wavelength="wavelength",
        ax=ax,
        vmin=0,
        vmax=scene.outputs.max(),
    )
    na.plt.pcolormesh(
        C=colorbar,
        axis_rgb="wavelength",
        ax=cax,
    )

    ax.set_aspect("equal")
    ax.set_xlabel(f"scene $x$ ({ax.get_xlabel()})")
    ax.set_ylabel(f"scene $y$ ({ax.get_ylabel()})")

    # Move the legend's ticks/labels to the outer edges of the side axis.
    cax.xaxis.set_ticks_position("top")
    cax.xaxis.set_label_position("top")
    cax.yaxis.tick_right()
    cax.yaxis.set_label_position("right")
Compute the average spectrum of the scene.
[11]:
# Spatially averaged spectrum: mean over both scene position axes.
spectrum = scene.outputs.mean(("scene_x", "scene_y"))
Plot the average spectrum of the scene.
[12]:
# Plot the average spectrum against Doppler velocity (bottom axis) and
# wavelength (top axis). The body must be indented under the `with` block.
with astropy.visualization.quantity_support():
    fig, ax = plt.subplots(constrained_layout=True)

    # Twin x-axis so the same spectrum can be read in both coordinates.
    ax2 = ax.twiny()
    na.plt.stairs(
        velocity,
        spectrum,
        ax=ax,
    )
    na.plt.stairs(
        scene.inputs.wavelength,
        spectrum,
        ax=ax2,
    )

    # quantity_support fills in unit-only labels; prefix them with a name.
    ax.set_xlabel(f"Doppler velocity ({ax.get_xlabel()})")
    ax2.set_xlabel(f"wavelength ({ax2.get_xlabel()})")
    ax.set_ylabel(f"average radiance ({ax.get_ylabel()})")
Define the dispersion angles for our instrument. In this case we’ll define four channels, each separated by \(90^\circ\), with all four rotated by a common offset of \(5.64^\circ\).
[13]:
# Four dispersion directions 90 deg apart (0, 90, 180, 270 deg along the
# "channel" axis), each shifted by the same 5.64 deg offset.
angle = (
    na.linspace(0, 360, num=4, axis="channel", endpoint=False) * u.deg
    + 5.64 * u.deg
)
Define the magnitude of dispersion for our instrument in terms of Doppler velocity and then convert to wavelength units.
[14]:
# Dispersion of 10 km/s per pixel, expressed as a wavelength shift per pixel:
# convert the velocity to a wavelength with the optical Doppler equivalency,
# subtract the rest wavelength to get the shift, then divide by one pixel.
dispersion = (
    (10 * u.km / u.s).to(u.AA, equivalencies=u.doppler_optical(wavelength_rest))
    - wavelength_rest
) / u.pix
# Display the dispersion in mÅ/pix as the cell output.
dispersion.to(u.mAA / u.pix)
[14]:
Create an ideal CTIS using the dispersion magnitude and angles.
[15]:
# Ideal CTIS forward model built from the dispersion magnitude/angles and the
# scene and sensor grids defined above.
instrument = ctis.instruments.IdealInstrument(
    area_effective=1 * u.cm ** 2,
    timedelta_exposure=20 * u.s,
    plate_scale=.4 * u.arcsec / u.pix,
    dispersion=dispersion,
    angle=angle,
    wavelength_ref=wavelength_rest,
    position_ref=na.Cartesian2dVectorArray(64, 32) * u.pix,
    coordinates_scene=coordinates_scene,
    coordinates_sensor=coordinates_sensor,
    # Human-readable per-channel labels, e.g. used in plot legends/titles.
    channel="dispersion angle = " + angle.to_string_array("%03d"),
    # Names of the logical axes the model operates over.
    axis_channel="channel",
    axis_wavelength="wavelength",
    axis_scene_xy=("scene_x", "scene_y"),
    axis_sensor_xy=("sensor_x", "sensor_y"),
)
Apply the forward model of this instrument to the scene to calculate the observed images.
[16]:
# Apply the instrument's forward model to the scene to simulate the observed images.
images = instrument.image(scene)
Display the images as an animation, where each frame represents a different channel / dispersion direction.
[17]:
# Animate the simulated images, one frame per channel (dispersion direction).
# The `with` body must be indented; the final `result` stays at top level so
# the notebook displays it.
with astropy.visualization.quantity_support():
    fig, ax = plt.subplots(
        constrained_layout=True,
        figsize=(9.2, 4),
    )

    # One shared normalization so every frame uses the same gray scale.
    norm = plt.Normalize(
        vmin=0,
        vmax=images.outputs.value.ndarray.max(),
    )
    colorizer = plt.Colorizer(
        cmap="gray",
        norm=norm,
    )

    ani = na.plt.pcolormovie(
        instrument.channel,
        images.inputs.position.x,
        images.inputs.position.y,
        C=images.outputs.value,
        axis_time="channel",
        ax=ax,
        kwargs_pcolormesh=dict(
            colorizer=colorizer,
        ),
    )
    plt.colorbar(
        mappable=plt.cm.ScalarMappable(colorizer=colorizer),
        ax=ax,
        label=f"signal ({images.outputs.unit:latex_inline})",
    )

    ax.set_aspect("equal")
    ax.set_xlabel(f"sensor $x$ ({images.inputs.position.x.unit})")
    ax.set_ylabel(f"sensor $y$ ({images.inputs.position.y.unit})")

    result = ani.to_jshtml(fps=2)
    result = IPython.display.HTML(result)

    # Close the static figure (the one the animation draws on) so only the
    # HTML animation is shown; use our own handle instead of the private
    # `ani._fig` attribute.
    plt.close(fig)

result
[17]:
Initialize the MART inversion algorithm with the instrument model. We’ll also enable saving intermediate results so that we can visualize the behavior of the algorithm.
[18]:
# MART inverter wrapping the instrument model; intermediate=True saves the
# intermediate results so the algorithm's behavior can be visualized.
mart = ctis.inverters.MartInverter(instrument=instrument, intermediate=True)
Invert the images using our instance of MART and the initial guess.
/home/docs/checkouts/readthedocs.org/user_builds/ctis/envs/stable/lib/python3.11/site-packages/astropy/units/quantity.py:648: RuntimeWarning: invalid value encountered in divide
result = super().__array_ufunc__(function, method, *arrays, **kwargs)
Display the results as a false-color movie, where each frame represents subsequent iterations of the MART algorithm.
[20]:
# Side-by-side false-color comparison: the original scene (static) and the
# reconstruction animated over MART iterations, with a shared color legend.
# The `with` body must be indented; the final `result` stays at top level so
# the notebook displays it.
with astropy.visualization.quantity_support():
    fig, axs = plt.subplots(
        ncols=3,
        gridspec_kw=dict(width_ratios=[.5, .5, .1]),
        constrained_layout=True,
        figsize=(10, 4.5),
    )
    ax1, ax2, cax = axs
    ax2.set_yticklabels([])

    # Both panels use the same vmin/vmax so their colors are comparable.
    na.plt.rgbmesh(
        C=scene,
        axis_wavelength="wavelength",
        ax=ax1,
        vmin=0,
        vmax=scene.outputs.max(),
    )

    # Per-frame caption: iteration number and the channel-averaged <chi^2>.
    label = "iteration = " + inversion.iteration.to_string_array("%d") + "\n"
    name = r"$\langle \chi^2 \rangle$"
    label = label + f"{name} = " + inversion.mean_chi_squared.mean(instrument.axis_channel).to_string_array()

    ani, colorbar = na.plt.rgbmovie(
        label,
        scene.inputs.wavelength,
        scene.inputs.position.x,
        scene.inputs.position.y,
        C=inversion.solutions.outputs,
        axis_time=inversion.inverter.axis_iteration,
        axis_wavelength="wavelength",
        ax=ax2,
        vmin=0,
        vmax=scene.outputs.max(),
    )
    na.plt.pcolormesh(
        C=colorbar,
        axis_rgb="wavelength",
        ax=cax,
    )

    ax1.set_title("original")
    ax2.set_title("reconstructed")
    unit_x = scene.inputs.position.x.unit
    unit_y = scene.inputs.position.y.unit
    ax1.set_xlabel(f"scene $x$ ({unit_x:latex_inline})")
    ax2.set_xlabel(f"scene $x$ ({unit_x:latex_inline})")
    ax1.set_ylabel(f"scene $y$ ({unit_y:latex_inline})")

    # Move the legend's ticks/labels to the outer edges of the side axis.
    cax.xaxis.set_ticks_position("top")
    cax.xaxis.set_label_position("top")
    cax.yaxis.tick_right()
    cax.yaxis.set_label_position("right")

    result = ani.to_jshtml(fps=20)
    result = IPython.display.HTML(result)

    # Close the static figure so only the HTML animation is shown; use our own
    # handle instead of the private `ani._fig` attribute.
    plt.close(fig)

result
[20]:
Isolate the solution array from the inversion result object.
[21]:
# Reconstructed scene taken from the inversion result
# (presumably the final iterate, as opposed to `inversion.solutions` — confirm).
solution = inversion.solution
Compute the average spectrum of the reconstructed scene.
[22]:
# Spatially averaged spectrum of the reconstruction: mean over both scene position axes.
spectrum_inverted = solution.outputs.mean(("scene_x", "scene_y"))
Plot the average spectrum of the original scene vs. the average spectrum of the reconstructed scene.
[23]:
# Overlay the spatially averaged spectra of the original and reconstructed
# scenes on a single wavelength axis. The `with` body must be indented.
with astropy.visualization.quantity_support():
    fig, ax = plt.subplots(constrained_layout=True)
    na.plt.stairs(
        scene.inputs.wavelength,
        spectrum,
        ax=ax,
        label="original",
    )
    na.plt.stairs(
        scene.inputs.wavelength,
        spectrum_inverted,
        ax=ax,
        label="reconstructed",
    )

    # Only `ax` belongs to this figure. The former `ax2.set_xlabel(...)`
    # relabeled a stale axis left over from an earlier plot, so it is removed.
    ax.set_xlabel(f"wavelength ({ax.get_xlabel()})")
    ax.set_ylabel(f"average radiance ({ax.get_ylabel()})")
    ax.legend()
Plot 2D histograms of the true vs. reconstructed value of the total radiance, median (Doppler shift), and interquartile range (Doppler width) for every pixel in the scene.
[24]:
# True vs. reconstructed moment histograms; the trailing ";" suppresses the
# cell's textual return-value output.
inversion.plot_moments(scene);
/home/docs/checkouts/readthedocs.org/user_builds/ctis/envs/stable/lib/python3.11/site-packages/named_arrays/_scalars/scalars.py:596: RuntimeWarning: invalid value encountered in divide
result_ndarray = getattr(function, method)(*inputs_ndarray, **kwargs_ndarray)
Plot \(\langle \chi^2 \rangle\) and the signal-correlated residual as a function of iteration.
[25]:
# Convergence diagnostics vs. iteration: <chi^2> on top, the signal-correlated
# residual below, one curve per channel in each panel.
fig, ax = plt.subplots(nrows=2, sharex=True, constrained_layout=True)

for panel, metric in zip(
    ax,
    (inversion.mean_chi_squared, inversion.correlation_residual),
):
    na.plt.plot(
        inversion.iteration,
        metric,
        ax=panel,
        axis=inversion.inverter.axis_iteration,
        label=instrument.channel,
    )

ax[0].set_ylabel(r"$\langle \chi^2 \rangle$")
ax[0].set_yscale("log")
ax[1].set_xlabel("iteration")
ax[1].set_ylabel("signal-correlated residual")
# Keep the legend call last; its trailing ";" suppresses the cell's text output.
ax[0].legend();