eloy.ballet.training#

Training utilities and synthetic data generation for centroid regression.

This module provides:
  • Moffat2D: a class for generating synthetic 2D Moffat profiles and labels.

  • Training utilities for JAX/Flax models, including loss computation, batching, and training steps.

  • Functions for saving and loading model parameters.

Intended for use in training convolutional neural networks to predict centroids from image cutouts.

Attributes#

size

Classes#

Moffat2D

Moffat 2D generator.

TrainState

Custom TrainState for model training.

Functions#

compute_loss(params, batch)

Compute the mean squared error loss for a batch.

train_step(state, batch)

Perform a single training step.

eval_step(params, batch)

Evaluate the model on a batch and compute RMSE.

params_to_flat_dict(params)

Flatten model parameters to a dictionary suitable for saving.

get_batches(X, y, batch_size)

Yield batches of data for training.

Module Contents#

class eloy.ballet.training.Moffat2D(cutout_size=21, **kwargs)#

Moffat 2D generator.

Generates synthetic 2D Moffat profiles for training and testing.

Parameters:
  • cutout_size (int, optional) – Size of the generated image cutouts (default is 21).

  • **kwargs – Additional keyword arguments.

moffat2D_model(a, x0, y0, sx, sy, theta, b, beta)#

Generate a 2D Moffat profile.

Parameters:
  • a (float) – Amplitude.

  • x0 (float) – x-coordinate of the profile center.

  • y0 (float) – y-coordinate of the profile center.

  • sx (float) – Scale parameter (width) along x.

  • sy (float) – Scale parameter (width) along y.

  • theta (float) – Rotation angle in radians.

  • b (float) – Background level.

  • beta (float) – Moffat beta parameter.

Returns:

2D Moffat profile of shape (cutout_size, cutout_size).

Return type:

numpy.ndarray
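The parameter list above matches a rotated, elliptical Moffat profile on a square pixel grid. The following is a minimal sketch of one common parameterization, not necessarily the one this class implements: the standalone function `moffat2d`, the rotation convention, and the way `sx` and `sy` enter the radius are all assumptions made for illustration.

```python
import numpy as np


def moffat2d(cutout_size, a, x0, y0, sx, sy, theta, b, beta):
    """Illustrative elliptical Moffat profile (hypothetical helper, not the eloy API)."""
    # np.indices returns (row, column) grids, i.e. (y, x) pixel coordinates.
    y, x = np.indices((cutout_size, cutout_size))
    dx, dy = x - x0, y - y0
    # Rotate the offsets into the frame aligned with the profile axes.
    xr = dx * np.cos(theta) + dy * np.sin(theta)
    yr = -dx * np.sin(theta) + dy * np.cos(theta)
    # Squared elliptical radius in units of the scale parameters.
    r2 = (xr / sx) ** 2 + (yr / sy) ** 2
    # Amplitude `a` on top of a constant background `b`.
    return b + a / (1.0 + r2) ** beta


image = moffat2d(21, a=1.0, x0=10.0, y0=10.0, sx=2.0, sy=3.0, theta=0.3, b=0.1, beta=2.5)
print(image.shape)
```

With this convention the peak value is `a + b` and sits on the pixel nearest to `(x0, y0)`.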

sigma_to_fwhm(beta)#

Convert Moffat beta parameter to FWHM.

Parameters:

beta (float) – Moffat beta parameter.

Returns:

Full width at half maximum (FWHM).

Return type:

float
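For a circular Moffat profile of the form (1 + r²/σ²)^(−β), the half-maximum radius follows from solving (1 + r²/σ²)^β = 2, which gives FWHM = 2σ·sqrt(2^(1/β) − 1). The sketch below implements that standard relation. Since the documented method takes only `beta`, it plausibly returns this factor for unit `sigma`, but that reading is an assumption, and `moffat_fwhm` is a hypothetical name.

```python
import numpy as np


def moffat_fwhm(sigma, beta):
    """FWHM of a circular Moffat profile (1 + r^2 / sigma^2) ** -beta."""
    return 2.0 * sigma * np.sqrt(2.0 ** (1.0 / beta) - 1.0)


sigma, beta = 2.0, 2.5
fwhm = moffat_fwhm(sigma, beta)

# Sanity check: the normalized profile at r = FWHM / 2 is one half.
half_radius = fwhm / 2.0
value_at_half = (1.0 + (half_radius / sigma) ** 2) ** (-beta)
print(fwhm, value_at_half)
```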

random_model_label(N=10000, flatten=False, return_all=False, sigma=1.0)#

Generate random Moffat images and labels.

Parameters:
  • N (int, optional) – Number of samples to generate (default is 10000).

  • flatten (bool, optional) – If True and N == 1, return a single image and label (default is False).

  • return_all (bool, optional) – If True, returns all model parameters as labels (default is False).

  • sigma (float, optional) – Standard deviation for center coordinates (default is 1.0).

Returns:

(images, labels), where images has shape (N, cutout_size, cutout_size, 1) and labels has shape (N, 2) or (N, 9), depending on return_all.

Return type:

tuple
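To illustrate the documented output shapes, here is a rough standalone sketch that draws random centers around the middle of the cutout and renders simplified circular Moffat images. The function `random_moffat_batch`, the sampling ranges for width and beta, the `(x0, y0)` label order, and the absence of noise are all assumptions. The real method may sample differently.

```python
import numpy as np


def random_moffat_batch(n, cutout_size=21, sigma=1.0, seed=0):
    """Hypothetical generator of (images, labels) with centroid labels."""
    rng = np.random.default_rng(seed)
    center = cutout_size // 2
    # Centers scattered around the middle pixel with standard deviation `sigma`.
    x0 = center + rng.normal(0.0, sigma, n)
    y0 = center + rng.normal(0.0, sigma, n)
    # Assumed sampling ranges, chosen only for illustration.
    width = rng.uniform(1.5, 4.0, n)
    beta = rng.uniform(1.5, 5.0, n)

    y, x = np.indices((cutout_size, cutout_size))
    # Broadcast the (H, W) grid against the per-sample parameters.
    dx = x[None] - x0[:, None, None]
    dy = y[None] - y0[:, None, None]
    r2 = (dx**2 + dy**2) / width[:, None, None] ** 2
    images = (1.0 + r2) ** (-beta[:, None, None])

    labels = np.stack([x0, y0], axis=1)
    # Trailing channel axis, as expected by convolutional layers.
    return images[..., None], labels


images, labels = random_moffat_batch(8)
print(images.shape, labels.shape)
```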

class eloy.ballet.training.TrainState#

Bases: flax.training.train_state.TrainState

Custom TrainState for model training.

Inherits from flax.training.train_state.TrainState.

eloy.ballet.training.compute_loss(params, batch)#

Compute the mean squared error loss for a batch.

Parameters:
  • params (dict) – Model parameters.

  • batch (tuple) – Tuple (x, y) of input images and target labels.

Returns:

Mean squared error loss.

Return type:

jax.Array (called DeviceArray in older JAX releases)
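A mean squared error loss over a batch can be sketched as follows. The model here is a hypothetical single linear layer (`apply_fn`) standing in for the actual network, which this page does not describe, and `mse_loss` is an illustrative name rather than the library function.

```python
import jax.numpy as jnp


def apply_fn(params, x):
    """Hypothetical stand-in model: one linear layer on flattened cutouts."""
    flat = x.reshape((x.shape[0], -1))
    return flat @ params["kernel"] + params["bias"]


def mse_loss(params, batch):
    """Mean squared error between predictions and target labels."""
    x, y = batch
    pred = apply_fn(params, x)
    return jnp.mean((pred - y) ** 2)


# Zero-initialized parameters predict (0, 0) for every cutout.
params = {"kernel": jnp.zeros((21 * 21, 2)), "bias": jnp.zeros((2,))}
batch = (jnp.ones((4, 21, 21, 1)), jnp.full((4, 2), 3.0))
loss = mse_loss(params, batch)
print(loss)
```

Because the mean runs over both the batch and the two coordinates, the result is a scalar that can be passed directly to `jax.grad` or `jax.value_and_grad`.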

eloy.ballet.training.train_step(state, batch)#

Perform a single training step.

Parameters:
  • state (TrainState) – Current training state.

  • batch (tuple) – Tuple (x, y) of input images and target labels.

Returns:

(new_state, loss)

Return type:

tuple
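A training step usually computes the loss and its gradients, then applies an update. To stay self-contained, the sketch below uses plain JAX with a hand-written gradient-descent update in place of a Flax `TrainState`. The names `loss_fn` and `sgd_step`, the linear model, and the learning rate are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    """MSE of a hypothetical linear model."""
    x, y = batch
    pred = x @ params["kernel"] + params["bias"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def sgd_step(params, batch, learning_rate=0.1):
    """One gradient-descent update; returns (new_params, loss)."""
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss


# Synthetic regression problem with a known linear mapping.
x = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
y = x @ jnp.array([[1.0, -1.0], [0.5, 2.0], [0.0, 1.0]])
params = {"kernel": jnp.zeros((3, 2)), "bias": jnp.zeros((2,))}

first_loss = None
for _ in range(100):
    params, loss = sgd_step(params, (x, y))
    if first_loss is None:
        first_loss = loss
print(first_loss, loss)
```

With a Flax `TrainState`, the manual `tree_map` update is typically replaced by `state.apply_gradients(grads=grads)`, which also advances the optimizer state.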

eloy.ballet.training.eval_step(params, batch)#

Evaluate the model on a batch and compute RMSE.

Parameters:
  • params (dict) – Model parameters.

  • batch (tuple) – Tuple (x, y) of input images and target labels.

Returns:

Root mean squared error (RMSE).

Return type:

jax.Array (called DeviceArray in older JAX releases)
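RMSE is the square root of the mean squared error, so for centroid labels it is expressed in the same units as the labels themselves. A minimal sketch, assuming the error is pooled over both coordinates and all samples (the library may reduce differently) and using a hypothetical helper `rmse`:

```python
import jax.numpy as jnp


def rmse(pred, target):
    """Root mean squared error pooled over all samples and coordinates."""
    return jnp.sqrt(jnp.mean((pred - target) ** 2))


pred = jnp.array([[10.0, 10.0], [11.0, 9.0]])
target = jnp.array([[10.0, 12.0], [11.0, 9.0]])
error = rmse(pred, target)
print(error)
```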

eloy.ballet.training.params_to_flat_dict(params)#

Flatten model parameters to a dictionary suitable for saving.

Parameters:

params (dict) – Model parameters.

Returns:

Flattened dictionary with keys for each layer’s kernel and bias.

Return type:

dict
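Flattening a nested parameter dictionary makes it easy to store with `numpy.savez`, which expects one array per keyword. The sketch below joins nested keys with underscores. The helper name `flatten_params`, the key format, and the layer name `Dense_0` are assumptions, and the actual function may name its keys differently.

```python
import io

import numpy as np


def flatten_params(params, prefix=""):
    """Recursively flatten a nested dict of arrays into {joined_key: array}."""
    flat = {}
    for name, value in params.items():
        key = f"{prefix}_{name}" if prefix else name
        if isinstance(value, dict):
            flat.update(flatten_params(value, key))
        else:
            flat[key] = np.asarray(value)
    return flat


params = {"Dense_0": {"kernel": np.zeros((4, 2)), "bias": np.zeros(2)}}
flat = flatten_params(params)

# Round trip through an in-memory .npz archive.
buffer = io.BytesIO()
np.savez(buffer, **flat)
buffer.seek(0)
loaded = np.load(buffer)
print(sorted(loaded.files))
```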

eloy.ballet.training.get_batches(X, y, batch_size)#

Yield batches of data for training.

Parameters:
  • X (numpy.ndarray) – Input images.

  • y (numpy.ndarray) – Target labels.

  • batch_size (int) – Batch size.

Yields:

tuple – (x_batch, y_batch) as jax.numpy arrays.
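A batching generator of this kind can be sketched as a loop over contiguous slices. Whether the library shuffles the data or drops a short final batch is not stated here. This version, with the hypothetical name `iter_batches`, does neither.

```python
import jax.numpy as jnp
import numpy as np


def iter_batches(X, y, batch_size):
    """Yield consecutive (x_batch, y_batch) pairs as jax.numpy arrays."""
    for start in range(0, len(X), batch_size):
        stop = start + batch_size
        # Slicing past the end is safe; the last batch may be smaller.
        yield jnp.asarray(X[start:stop]), jnp.asarray(y[start:stop])


X = np.zeros((10, 21, 21, 1), dtype=np.float32)
y = np.zeros((10, 2), dtype=np.float32)
sizes = [x_batch.shape[0] for x_batch, _ in iter_batches(X, y, batch_size=4)]
print(sizes)
```

A short final batch changes the input shape, which triggers one extra compilation if the training step is wrapped in `jax.jit`.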

eloy.ballet.training.size = 15#