Centroiding with ballet

Important

You need to install the JAX dependencies to use this model:

pip install "eloy[jax]"

To use the Ballet centroiding model in your pipelines, first instantiate the model:

from eloy.ballet import Ballet

cnn = Ballet()

Then pass it to eloy.centroid.ballet_centroid. Here is an example using one image from the calibrated image dataset:

from astropy.io import fits
from eloy import detection, centroid
import numpy as np

image = fits.getdata("calibrated_images/ESPC WASP-1245.000secs00000009.fits").astype(
    float
)
regions = detection.stars_detection(image)
coords = np.array([region.centroid_weighted[::-1] for region in regions])

centroids = centroid.ballet_centroid(image, coords, cnn)
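The `[::-1]` in the snippet above reverses each weighted centroid before it is passed on. The sketch below illustrates that coordinate convention with plain NumPy. It assumes the detected regions behave like scikit-image region properties, whose `centroid_weighted` is given in (row, column) order, so that the reversal yields (x, y). `weighted_centroid_rc` is a hypothetical helper written for illustration, not part of eloy.

```python
import numpy as np


def weighted_centroid_rc(cutout):
    """Intensity-weighted centroid in (row, column) order (hypothetical helper)."""
    rows, cols = np.indices(cutout.shape)
    total = cutout.sum()
    return np.array([(rows * cutout).sum() / total, (cols * cutout).sum() / total])


# Toy image with a single bright pixel at row 2, column 5
toy = np.zeros((8, 8))
toy[2, 5] = 1.0

rc = weighted_centroid_rc(toy)  # (row, column) order
xy = rc[::-1]                   # reversed to (x, y) order
print("row, col:", rc, "-> x, y:", xy)
```

Mixing up the two orders is easy to miss on square images, so it can be worth checking a single bright star by eye before running a full pipeline.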

The model

Ballet is a convolutional neural network that computes the centroid of star cutouts. It is as accurate as fitting a 2D Gaussian to the profile but substantially faster (roughly 9x in the timing measured on this image). Let’s show that on our image.

Speed

We compare the evaluation time of the photutils centroid_2dg function against that of eloy.centroid.ballet_centroid:

from photutils.centroids import centroid_2dg

%timeit centroid.photutils_centroid(image, coords, centroid_fun=centroid_2dg)
%timeit centroid.ballet_centroid(image, coords, cnn)
256 ms ± 26.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
29.2 ms ± 720 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
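The speedup implied by the two timings above can be worked out directly. The sketch below does that arithmetic and also shows one way to take similar measurements outside a notebook with the standard `timeit` module. `best_time_ms` is a hypothetical helper, and the timed workload is a stand-in rather than an eloy call.

```python
import timeit

# Mean timings reported by %timeit above, in milliseconds
t_photutils_ms = 256.0
t_ballet_ms = 29.2

speedup = t_photutils_ms / t_ballet_ms
print(f"ballet is about {speedup:.1f}x faster on this image")


def best_time_ms(func, repeat=7, number=1):
    """Best wall-clock time of func() in milliseconds (hypothetical helper)."""
    timings = timeit.repeat(func, repeat=repeat, number=number)
    return min(timings) / number * 1e3


# Stand-in workload; in a script you would time a lambda wrapping the real call
elapsed_ms = best_time_ms(lambda: sum(range(10_000)))
```

The ratio will vary with hardware, the number of detected stars, and whether JAX has already compiled the model, so it is best treated as indicative.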

Precision

As for the precision, let’s simulate a set of Moffat profiles and compare the photutils centroid_quadratic, centroid_1dg and centroid_2dg functions to Ballet.

from eloy.ballet.training import Moffat2D
from photutils.centroids import centroid_2dg, centroid_1dg, centroid_quadratic

cutouts, truth = Moffat2D(15).random_model_label(10000)

# these will take a while
centroid_c2dg = np.array([centroid_2dg(cutout)[::-1] for cutout in cutouts[..., 0]])
centroid_c1dg = np.array([centroid_1dg(cutout)[::-1] for cutout in cutouts[..., 0]])
centroid_qd = np.array([centroid_quadratic(cutout)[::-1] for cutout in cutouts[..., 0]])

centroid_ballet = cnn.centroid(cutouts[..., 0])[:, ::-1]
WARNING: The fit may be unsuccessful; check: 
    The maximum number of function evaluations is exceeded. [astropy.modeling.fitting]
WARNING: quadratic fit does not have a maximum [photutils.centroids.core]
WARNING: quadratic polynomial maximum value falls outside of the image [photutils.centroids.core]
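For readers unfamiliar with the profile being simulated, here is a minimal NumPy sketch of a circular Moffat profile, I(r) = A (1 + r²/α²)^(−β). It is not the implementation behind `Moffat2D`: the function name `moffat_cutout`, the parameter values, and the reading of `15` as the cutout size in pixels are all assumptions made for illustration.

```python
import numpy as np


def moffat_cutout(size, x0, y0, alpha=3.0, beta=2.5, amplitude=1.0):
    """Circular Moffat profile on a size x size grid (hypothetical helper)."""
    y, x = np.indices((size, size))
    r2 = (x - x0) ** 2 + (y - y0) ** 2
    return amplitude * (1.0 + r2 / alpha**2) ** (-beta)


# A star slightly off the central pixel of an assumed 15 x 15 cutout
cutout = moffat_cutout(15, x0=7.3, y0=6.8)

# The brightest pixel only locates the star to the nearest pixel;
# the centroiding methods compared here aim for sub-pixel accuracy.
peak_y, peak_x = np.unravel_index(np.argmax(cutout), cutout.shape)
print("brightest pixel (x, y):", (peak_x, peak_y))
```

Because a Moffat profile has heavier wings than a Gaussian, it makes a reasonable test case for Gaussian-based centroid fits.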
import matplotlib.pyplot as plt


fig, axes = plt.subplots(1, 4, sharey=True, figsize=(7, 3))
bins = np.linspace(-0.2, 0.2, 200)
offset = 0.01

methods = [centroid_qd, centroid_c1dg, centroid_c2dg, centroid_ballet]
names = [
    "Quadratic\n(photutils)",
    "1D Gaussian\n(photutils)",
    "2D Gaussian\n(photutils)",
    "ballet",
]
for i, method in enumerate(methods):
    axes[i].step(
        bins[1:],
        np.histogram((method - truth), bins)[0],
    )
    axes[i].set_title(names[i])
    axes[i].set_xlabel("Centroid error (pixels)")


axes[0].set_ylabel("Counts")

plt.tight_layout()
[Figure: histograms of centroid error (pixels) against counts, one panel each for Quadratic (photutils), 1D Gaussian (photutils), 2D Gaussian (photutils), and ballet]
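The histograms can be complemented with a single number per method, such as the root-mean-square centroid error. The sketch below is one way to compute it. `centroid_rms` is a hypothetical helper, the data are synthetic, and NaN-aware averaging is used on the assumption that a failed fit may return NaN.

```python
import numpy as np


def centroid_rms(estimates, truth):
    """RMS centroid error in pixels, ignoring NaN entries (hypothetical helper)."""
    err = np.asarray(estimates, dtype=float) - np.asarray(truth, dtype=float)
    return np.sqrt(np.nanmean(err**2))


# Synthetic stand-ins for the `truth` and estimate arrays used in the plot
rng = np.random.default_rng(0)
fake_truth = rng.uniform(6, 8, size=(1000, 2))
fake_estimates = fake_truth + rng.normal(0, 0.05, size=(1000, 2))
fake_estimates[0] = np.nan  # mimic one failed fit

rms = centroid_rms(fake_estimates, fake_truth)
print(f"RMS error: {rms:.3f} pixels")
```

With the arrays from the cells above, the same helper could be applied as `centroid_rms(centroid_ballet, truth)` and `centroid_rms(centroid_c2dg, truth)` to compare the two methods directly.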

Ballet is as precise as the photutils 2D Gaussian fit, but much faster!