eloy.ballet.training
Training utilities and synthetic data generation for centroid regression.
This module provides:
- Moffat2D: a class for generating synthetic 2D Moffat profiles and labels.
- Training utilities for JAX/Flax models, including loss computation, batching, and training steps.
- Functions for saving and loading model parameters.
Intended for use in training convolutional neural networks to predict centroids from image cutouts.
Attributes
Classes
Moffat2D – Moffat 2D generator.
TrainState – Custom TrainState for model training.
Functions
compute_loss(params, batch) – Compute the mean squared error loss for a batch.
train_step(state, batch) – Perform a single training step.
eval_step(params, batch) – Evaluate the model on a batch and compute RMSE.
params_to_flat_dict(params) – Flatten model parameters to a dictionary suitable for saving.
Yield batches of data for training.
Module Contents
- class eloy.ballet.training.Moffat2D(cutout_size=21, **kwargs)
Moffat 2D generator.
Generates synthetic 2D Moffat profiles for training and testing.
- Parameters:
cutout_size (int, optional) – Size of the generated image cutouts (default is 21).
**kwargs – Additional keyword arguments.
- moffat2D_model(a, x0, y0, sx, sy, theta, b, beta)
Generate a 2D Moffat profile.
- Parameters:
a (float) – Amplitude.
x0 (float) – Center coordinate along x.
y0 (float) – Center coordinate along y.
sx (float) – Scale parameter (width) along x.
sy (float) – Scale parameter (width) along y.
theta (float) – Rotation angle in radians.
b (float) – Background level.
beta (float) – Moffat beta parameter.
- Returns:
2D Moffat profile of shape (cutout_size, cutout_size).
- Return type:
numpy.ndarray
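The parameter list above matches a rotated, elliptical Moffat profile. The following is a minimal NumPy sketch of that standard form, not the library's own code: the function name `moffat2d_sketch`, the rotation convention, and the pixel-grid layout are assumptions made for illustration.

```python
import numpy as np

def moffat2d_sketch(a, x0, y0, sx, sy, theta, b, beta, cutout_size=21):
    """Rotated elliptical Moffat profile on a square grid (illustrative only)."""
    # Pixel coordinate grids: y indexes rows, x indexes columns.
    y, x = np.indices((cutout_size, cutout_size))
    dx, dy = x - x0, y - y0
    # Rotate the offsets into the principal axes of the profile (assumed convention).
    u = dx * np.cos(theta) + dy * np.sin(theta)
    v = -dx * np.sin(theta) + dy * np.cos(theta)
    # Squared radius normalised by the widths along each axis.
    r2 = (u / sx) ** 2 + (v / sy) ** 2
    # Moffat law: amplitude * (1 + r^2)^(-beta), plus a constant background.
    return a * (1.0 + r2) ** (-beta) + b

image = moffat2d_sketch(a=1.0, x0=10.0, y0=10.0, sx=2.5, sy=2.0,
                        theta=0.3, b=0.1, beta=3.0)
```

With the center on a pixel, the peak value is the amplitude plus the background, which gives a quick sanity check on any implementation of this form.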
- sigma_to_fwhm(beta)
Convert Moffat beta parameter to FWHM.
- Parameters:
beta (float) – Moffat beta parameter.
- Returns:
Full width at half maximum (FWHM).
- Return type:
float
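For a circular Moffat profile of the form (1 + (r/sigma)^2)^(-beta), the half-maximum radius follows from solving for the value 1/2, which gives FWHM = 2 sigma sqrt(2^(1/beta) - 1). The sketch below implements that textbook relation. Whether the method above returns exactly this quantity (for example, the factor for unit sigma) is an assumption, and the name `moffat_fwhm_sketch` and its `sigma` argument are hypothetical.

```python
import numpy as np

def moffat_fwhm_sketch(beta, sigma=1.0):
    """FWHM of a circular Moffat profile (1 + (r / sigma)**2) ** (-beta)."""
    # Half maximum is reached where (1 + (r / sigma)**2) ** (-beta) == 0.5.
    return 2.0 * sigma * np.sqrt(2.0 ** (1.0 / beta) - 1.0)

fwhm = moffat_fwhm_sketch(beta=3.0, sigma=2.0)
# Evaluate the profile at r = FWHM / 2 to confirm it sits at half maximum.
half_max_value = (1.0 + (fwhm / 2.0 / 2.0) ** 2) ** (-3.0)
```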
- random_model_label(N=10000, flatten=False, return_all=False, sigma=1.0)
Generate random Moffat images and labels.
- Parameters:
N (int, optional) – Number of samples to generate (default is 10000).
flatten (bool, optional) – If True and N==1, returns single image and label (default is False).
return_all (bool, optional) – If True, returns all model parameters as labels (default is False).
sigma (float, optional) – Standard deviation for center coordinates (default is 1.0).
- Returns:
(images, labels) where images is (N, cutout_size, cutout_size, 1) and labels is (N, 2) or (N, 9) depending on return_all.
- Return type:
tuple
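To make the documented output shapes concrete, here is a self-contained NumPy sketch that draws random centers around the middle of the cutout and renders one Moffat image per sample. It is not the library's sampler: the function name `random_moffat_batch`, the parameter ranges, the fixed unit amplitude, and the zero background are all assumptions chosen for brevity.

```python
import numpy as np

def random_moffat_batch(n=4, cutout_size=21, sigma=1.0, seed=0):
    """Return (images, labels) with shapes (n, size, size, 1) and (n, 2)."""
    rng = np.random.default_rng(seed)
    center = (cutout_size - 1) / 2  # assumed: centers scatter around the middle pixel
    x0 = rng.normal(center, sigma, n)
    y0 = rng.normal(center, sigma, n)
    sx = rng.uniform(1.5, 4.0, n)      # hypothetical width range
    sy = rng.uniform(1.5, 4.0, n)
    theta = rng.uniform(0.0, np.pi, n)
    beta = rng.uniform(1.5, 6.0, n)    # hypothetical beta range

    # Broadcast per-sample parameters of shape (n, 1, 1) against the pixel grid.
    y, x = np.indices((cutout_size, cutout_size))
    dx = x[None] - x0[:, None, None]
    dy = y[None] - y0[:, None, None]
    c, s = np.cos(theta)[:, None, None], np.sin(theta)[:, None, None]
    u, v = dx * c + dy * s, -dx * s + dy * c
    r2 = (u / sx[:, None, None]) ** 2 + (v / sy[:, None, None]) ** 2
    images = (1.0 + r2) ** (-beta[:, None, None])

    labels = np.stack([x0, y0], axis=1)  # centroid labels (x0, y0)
    return images[..., None], labels     # trailing channel axis for a CNN

images, labels = random_moffat_batch(n=4)
```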
- class eloy.ballet.training.TrainState
Bases: flax.training.train_state.TrainState
Custom TrainState for model training.
Inherits from flax.training.train_state.TrainState.
- eloy.ballet.training.compute_loss(params, batch)
Compute the mean squared error loss for a batch.
- Parameters:
params (dict) – Model parameters.
batch (tuple) – Tuple (x, y) of input images and target labels.
- Returns:
Mean squared error loss.
- Return type:
jax.Array (named jax.numpy.DeviceArray in older JAX releases)
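A mean squared error over a batch of centroid predictions can be sketched in JAX as follows. The model here is a hypothetical single dense layer (`apply_fn`) standing in for the actual network, which this page does not show, and `mse_loss` is an illustrative name rather than the module's function.

```python
import jax.numpy as jnp

def apply_fn(params, x):
    """Hypothetical stand-in model: flatten each image, apply one dense layer."""
    flat = x.reshape((x.shape[0], -1))
    return flat @ params["kernel"] + params["bias"]

def mse_loss(params, batch):
    """Mean squared error between predicted and target centroids."""
    x, y = batch
    pred = apply_fn(params, x)
    return jnp.mean((pred - y) ** 2)

# Zero weights predict (0, 0) for every image, so targets of 1 give a loss of 1.
params = {"kernel": jnp.zeros((21 * 21, 2)), "bias": jnp.zeros(2)}
batch = (jnp.ones((8, 21, 21, 1)), jnp.ones((8, 2)))
loss = mse_loss(params, batch)
```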
- eloy.ballet.training.train_step(state, batch)
Perform a single training step.
- Parameters:
state (TrainState) – Current training state.
batch (tuple) – Tuple (x, y) of input images and target labels.
- Returns:
(new_state, loss)
- Return type:
tuple
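A training step of this kind typically computes the loss and its gradients, then applies an update. The sketch below shows that pattern in plain JAX with a hand-written SGD update. It is a simplification under stated assumptions: the documented function takes a Flax TrainState rather than a bare parameter dictionary, and `apply_fn`, `mse_loss`, `sgd_step`, and `LEARNING_RATE` are hypothetical names.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-3  # hypothetical value for this toy problem

def apply_fn(params, x):
    """Hypothetical stand-in model: flatten each image, apply one dense layer."""
    flat = x.reshape((x.shape[0], -1))
    return flat @ params["kernel"] + params["bias"]

def mse_loss(params, batch):
    x, y = batch
    return jnp.mean((apply_fn(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, batch):
    """One gradient step: returns (new_params, loss evaluated before the update)."""
    loss, grads = jax.value_and_grad(mse_loss)(params, batch)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss

# Toy regression data generated from a known linear model.
key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (16, 21, 21, 1))
true_params = {"kernel": jax.random.normal(key_k, (21 * 21, 2)), "bias": jnp.ones(2)}
y = apply_fn(true_params, x)

params = {"kernel": jnp.zeros((21 * 21, 2)), "bias": jnp.zeros(2)}
losses = []
for _ in range(5):
    params, loss = sgd_step(params, (x, y))
    losses.append(float(loss))
```

With a Flax TrainState, the manual `tree_map` update would usually be replaced by `state.apply_gradients(grads=grads)`, which delegates the update rule to the configured optimizer.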
- eloy.ballet.training.eval_step(params, batch)
Evaluate the model on a batch and compute RMSE.
- Parameters:
params (dict) – Model parameters.
batch (tuple) – Tuple (x, y) of input images and target labels.
- Returns:
Root mean squared error (RMSE).
- Return type:
jax.Array (named jax.numpy.DeviceArray in older JAX releases)
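RMSE is the square root of the mean squared error, which puts the metric back in the units of the labels (pixels, for centroids). A minimal JAX sketch follows, again using a hypothetical dense-layer `apply_fn` in place of the real model; `rmse` is an illustrative name.

```python
import jax.numpy as jnp

def apply_fn(params, x):
    """Hypothetical stand-in model: flatten each image, apply one dense layer."""
    flat = x.reshape((x.shape[0], -1))
    return flat @ params["kernel"] + params["bias"]

def rmse(params, batch):
    """Root mean squared error between predictions and targets."""
    x, y = batch
    pred = apply_fn(params, x)
    return jnp.sqrt(jnp.mean((pred - y) ** 2))

# Zero weights predict (0, 0); every target component is 3, so every error is 3.
params = {"kernel": jnp.zeros((21 * 21, 2)), "bias": jnp.zeros(2)}
batch = (jnp.ones((8, 21, 21, 1)), jnp.full((8, 2), 3.0))
error = rmse(params, batch)
```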
- eloy.ballet.training.params_to_flat_dict(params)
Flatten model parameters to a dictionary suitable for saving.
- Parameters:
params (dict) – Model parameters.
- Returns:
Flattened dictionary with keys for each layer’s kernel and bias.
- Return type:
dict
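Flattening a nested parameter tree into a single-level dictionary could look like the sketch below. The key format (layer and parameter names joined with a dot), the example layer names, and the function name `flatten_params_sketch` are assumptions; the library's actual key scheme is not shown on this page.

```python
import numpy as np

def flatten_params_sketch(params, prefix=""):
    """Recursively flatten a nested dict of arrays into {"layer.name": array}."""
    flat = {}
    for name, value in params.items():
        key = f"{prefix}{name}"
        if isinstance(value, dict):
            # Descend into sub-dictionaries, extending the key prefix.
            flat.update(flatten_params_sketch(value, prefix=key + "."))
        else:
            # Leaves are converted to NumPy arrays so they can be written to disk.
            flat[key] = np.asarray(value)
    return flat

# Hypothetical two-layer parameter tree.
params = {
    "Conv_0": {"kernel": np.zeros((3, 3, 1, 8)), "bias": np.zeros(8)},
    "Dense_0": {"kernel": np.zeros((8, 2)), "bias": np.zeros(2)},
}
flat = flatten_params_sketch(params)
```

A flat dictionary of this shape can be passed to `np.savez(path, **flat)`, which stores one array per key.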