Source code for eloy.ballet.training

"""
Training utilities and synthetic data generation for centroid regression.

This module provides:
    - Moffat2D: a class for generating synthetic 2D Moffat profiles and labels.
    - Training utilities for JAX/Flax models, including loss computation, batching, and training steps.
    - Functions for saving and loading model parameters.

Intended for use in training convolutional neural networks to predict centroids from image cutouts.
"""

import numpy as np
from tqdm import tqdm

try:
    from eloy.ballet.model import CNN
    import jax
    import jax.numpy as jnp
    from flax.training import train_state
    import optax
except ImportError:
    pass
    # raise ImportError(
    #     'jax-related packages are not installed. Use pip install "eloy[jax]"'
    # )


class Moffat2D:
    """
    Synthetic 2D Moffat image generator.

    Produces square cutouts containing a single (possibly elliptical and
    rotated) Moffat profile, together with the parameters used to draw them,
    for training and testing centroid-regression models.

    Parameters
    ----------
    cutout_size : int, optional
        Side length, in pixels, of the generated square cutouts
        (default is 21).
    **kwargs
        Extra keyword arguments, forwarded unchanged to the parent class
        initializer.
    """

    def __init__(self, cutout_size=21, **kwargs):
        super().__init__(**kwargs)
        self.model = None
        self.train_history = None
        self.cutout_size = cutout_size
        # Pixel index grids: ``self.x`` holds the axis-0 index of every pixel
        # and ``self.y`` the axis-1 index.
        grid = np.indices((cutout_size, cutout_size))
        self.x = grid[0]
        self.y = grid[1]

    def moffat2D_model(self, a, x0, y0, sx, sy, theta, b, beta):
        """
        Evaluate a rotated, elliptical 2D Moffat profile on the pixel grid.

        Parameters
        ----------
        a : float
            Peak amplitude above the background.
        x0, y0 : float
            Centre of the profile, in the coordinates of ``self.x`` and
            ``self.y`` respectively.
        sx, sy : float
            Scale (width) parameters along the two rotated axes.
        theta : float
            Rotation angle, in radians.
        b : float
            Constant background level.
        beta : float
            Moffat beta (power-law index).

        Returns
        -------
        numpy.ndarray
            Array of shape (cutout_size, cutout_size) with the profile values.
        """
        # https://pixinsight.com/doc/tools/DynamicPSF/DynamicPSF.html
        cos_t = np.cos(theta)
        sin_t = np.sin(theta)
        off_x = self.x - x0
        off_y = self.y - y0
        # Offsets expressed in the frame rotated by ``theta``.
        u = off_x * cos_t + off_y * sin_t
        v = -off_x * sin_t + off_y * cos_t
        radial = 1 + (u / sx) ** 2 + (v / sy) ** 2
        return b + a / np.power(radial, beta)

    def sigma_to_fwhm(self, beta):
        """
        Return the FWHM of a Moffat profile that has unit scale parameter.

        Multiplying the result by a scale parameter gives the corresponding
        FWHM; dividing a FWHM by the result gives the scale parameter.

        Parameters
        ----------
        beta : float
            Moffat beta parameter.

        Returns
        -------
        float
            ``2 * sqrt(2 ** (1 / beta) - 1)``.
        """
        half_width = np.sqrt(np.power(2, 1 / beta) - 1)
        return 2 * half_width

    def random_model_label(self, N=10000, flatten=False, return_all=False, sigma=1.0):
        """
        Draw random noisy Moffat cutouts and their labels.

        All random numbers come from the global ``numpy.random`` state.

        Parameters
        ----------
        N : int, optional
            Number of samples to draw (default is 10000).
        flatten : bool, optional
            When True and ``N == 1``, return the single image and the single
            label without the leading sample axis (default is False).
        return_all : bool, optional
            When True, labels contain every generation parameter instead of
            just the centre (default is False).
        sigma : float, optional
            Standard deviation of the normally distributed centre
            coordinates, which are centred on ``cutout_size / 2``
            (default is 1.0).

        Returns
        -------
        tuple
            ``(images, labels)``. ``images`` has shape
            (N, cutout_size, cutout_size, 1). ``labels`` has shape (N, 2)
            holding ``(x0, y0)``, or (N, 9) holding
            ``(a, x0, y0, sx, sy, theta, b, beta, noise)`` when
            ``return_all`` is True.
        """
        size = self.cutout_size

        # Unit amplitude and zero background for every sample.
        amplitude = np.ones(N)
        background = np.zeros(N)

        x0, y0 = np.random.normal(size / 2, sigma, (2, N))
        theta = np.random.uniform(0, np.pi / 8, size=N)
        beta = np.random.uniform(1, 8, size=N)
        # One FWHM per sample, uniform in [1.5, 20.5), converted to a scale.
        sx = np.array(
            [np.random.uniform(1.5, 20.5) / self.sigma_to_fwhm(b_i) for b_i in beta]
        )
        # Axis ratio sy / sx is uniform in [0.5, 1.5).
        sy = np.random.uniform(0.5, 1.5, size=N) * sx
        # Per-sample amplitude of the uniform pixel noise.
        noise = np.random.uniform(0, 0.1, size=N)

        def render(i):
            # The noise field is drawn before the profile is evaluated.
            pixel_noise = np.random.rand(size, size) * noise[i]
            profile = self.moffat2D_model(
                amplitude[i], x0[i], y0[i], sx[i], sy[i], theta[i], background[i], beta[i]
            )
            return profile + pixel_noise

        # Trailing singleton axis is the channel dimension.
        images = np.array([render(i) for i in range(N)])[:, :, :, None]

        if return_all:
            columns = [amplitude, x0, y0, sx, sy, theta, background, beta, noise]
        else:
            columns = [x0, y0]
        labels = np.array(columns).T

        if flatten and N == 1:
            return (np.array(images[0]), np.array(labels[0]))
        return (np.array(images), np.array(labels))
# --- Training utilities ---
class TrainState(train_state.TrainState):
    """
    Training state container used by this module.

    A plain subclass of ``flax.training.train_state.TrainState`` that adds no
    fields or methods of its own.
    """
def compute_loss(params, batch, *, apply_fn=None):
    """
    Compute the mean half-squared-error loss for a batch.

    The per-element loss is ``optax.l2_loss``, i.e.
    ``0.5 * (prediction - target) ** 2``, averaged over every element of the
    batch. This is half of the usual mean squared error.

    Parameters
    ----------
    params : dict
        Model parameters.
    batch : tuple
        Tuple (x, y) of input images and target labels.
    apply_fn : callable, optional
        Forward function called as ``apply_fn({"params": params}, x)``.
        When omitted, the ``apply`` method of the module-level ``model``
        object is used, which is the historical behaviour. Passing it
        explicitly lets the function be used when this module is imported
        rather than run as a script, since ``model`` is only created by the
        script entry point.

    Returns
    -------
    jax.Array
        Scalar loss value.

    Raises
    ------
    NameError
        If ``apply_fn`` is not given and no module-level ``model`` exists.
    """
    x, y = batch
    if apply_fn is None:
        # Backward-compatible default: rely on the global model set up by the
        # training script.
        apply_fn = model.apply
    preds = apply_fn({"params": params}, x)
    return jnp.mean(optax.l2_loss(preds, y))
# Alternative loss, kept for reference (Huber loss instead of squared error):
#
# def compute_loss(params, batch, delta=1.0):
#     """
#     Compute the Huber loss for a batch.
#
#     Parameters
#     ----------
#     params : dict
#         Model parameters.
#     batch : tuple
#         Tuple (x, y) of input images and target labels.
#     delta : float, optional
#         Huber loss delta parameter (default is 1.0).
#
#     Returns
#     -------
#     jax.numpy.DeviceArray
#         Huber loss.
#     """
#     x, y = batch
#     preds = model.apply({"params": params}, x)
#     diff = jnp.abs(preds - y)
#     loss = jnp.where(diff < delta, 0.5 * diff**2, delta * diff - 0.5 * delta**2)
#     return jnp.mean(loss)


@jax.jit
def train_step(state, batch):
    """
    Run one optimisation step on a single batch.

    The loss and its gradient with respect to ``state.params`` are obtained
    from ``compute_loss``; the gradients are then applied through
    ``state.apply_gradients``.

    Parameters
    ----------
    state : TrainState
        Training state before the update.
    batch : tuple
        Tuple (x, y) of input images and target labels.

    Returns
    -------
    tuple
        ``(new_state, loss)``: the updated training state and the loss
        evaluated with the parameters from before the update.
    """
    loss, grads = jax.value_and_grad(compute_loss)(state.params, batch)
    return state.apply_gradients(grads=grads), loss
@jax.jit
def eval_step(params, batch):
    """
    Compute the root mean squared error of the model on a batch.

    Predictions come from the module-level ``model`` object. The mean is
    taken over every element of the prediction array, so all output columns
    contribute to a single number.

    Parameters
    ----------
    params : dict
        Model parameters.
    batch : tuple
        Tuple (x, y) of input images and target labels.

    Returns
    -------
    jax.Array
        Scalar RMSE.
    """
    images, targets = batch
    residuals = model.apply({"params": params}, images) - targets
    mean_squared = jnp.mean(residuals ** 2)
    return jnp.sqrt(mean_squared)  # RMSE
def params_to_flat_dict(params):
    """
    Flatten per-layer parameters into a single-level dictionary.

    Each entry of ``params`` must be a mapping with ``"bias"`` and
    ``"kernel"`` items. For a layer named ``name`` the result contains
    ``"name_bias"`` and ``"name_kernel"``, both converted to NumPy arrays,
    which makes the dictionary suitable for ``numpy.savez``.

    Parameters
    ----------
    params : dict
        Mapping from layer name to that layer's parameters.

    Returns
    -------
    dict
        Flat dictionary of NumPy arrays, two entries per layer.
    """
    flat = {}
    for layer_name, layer in params.items():
        flat[f"{layer_name}_bias"] = np.array(layer["bias"])
        flat[f"{layer_name}_kernel"] = np.array(layer["kernel"])
    return flat
# --- Data batching ---
def get_batches(X, y, batch_size):
    """
    Iterate once over the data in shuffled mini-batches.

    A fresh random permutation of the samples is drawn (from the global
    ``numpy.random`` state) each time the generator is started. Every sample
    appears in exactly one batch; the last batch is smaller when the number
    of samples is not a multiple of ``batch_size``.

    Parameters
    ----------
    X : numpy.ndarray
        Input images, indexed along the first axis.
    y : numpy.ndarray
        Target labels, aligned with ``X`` along the first axis.
    batch_size : int
        Maximum number of samples per batch.

    Yields
    ------
    tuple
        ``(x_batch, y_batch)`` as jax.numpy arrays.
    """
    n_samples = len(X)
    order = np.random.permutation(n_samples)
    for offset in range(0, n_samples, batch_size):
        chosen = order[offset : offset + batch_size]
        yield jnp.array(X[chosen]), jnp.array(y[chosen])
# Script entry point: train the CNN on synthetic Moffat cutouts, run a short
# "adjustment" phase that picks the least biased parameter set, and save the
# selected weights to an ``.npz`` file in the current working directory.
if __name__ == "__main__":
    from datetime import datetime

    # Cutout side length, used both for the model's dummy init input and for
    # the synthetic image generator.
    size = 15
    rng = jax.random.PRNGKey(0)
    # ``model`` becomes a module-level global here; compute_loss and eval_step
    # read it.
    model = CNN()
    params = model.init(rng, jnp.ones([1, size, size, 1]))["params"]
    # This initial state is replaced, before any training step runs, by the
    # one created further down with ``learning_rate``.
    state = TrainState.create(apply_fn=model.apply, params=params, tx=optax.adamw(5e-5))

    moffat_gen = Moffat2D(size)

    train_size = 5000
    test_size = 10000
    batch_size = 100
    print("Generate test/train samples")
    X_train, y_train = moffat_gen.random_model_label(train_size)
    X_test, y_test = moffat_gen.random_model_label(test_size)

    learning_rate = 1e-3
    state = TrainState.create(
        apply_fn=model.apply, params=state.params, tx=optax.adamw(learning_rate)
    )

    print("Start model training")
    # --- Training loop ---
    for epoch in range(300):
        for batch in get_batches(X_train, y_train, batch_size):
            state, loss = train_step(state, batch)

        # Step-wise learning-rate schedule. The new value only reaches the
        # optimizer when the TrainState is rebuilt below; 100 and 150 are
        # multiples of 10, so that happens in the same epoch.
        if epoch == 100:
            learning_rate = 1e-4
        elif epoch == 150:
            learning_rate = 1e-5

        # Every 10 epochs: evaluate on the fixed test set, report, draw a new
        # training set and rebuild the training state.
        if epoch % 10 == 0:
            test_rmse = eval_step(state.params, (jnp.array(X_test), jnp.array(y_test)))
            # ``loss`` is the value returned for the last batch of this epoch.
            print(
                f"Epoch {epoch}: Loss = {loss:.4f}, Test RMSE = {test_rmse:.4f}, LR = {learning_rate:.1e}"
            )
            X_train, y_train = moffat_gen.random_model_label(train_size)
            # Keeps the current parameters but builds a new adamw optimizer
            # with the current learning rate.
            # NOTE(review): this presumably also resets the optimizer state
            # and step counter -- confirm against flax TrainState.create.
            state = TrainState.create(
                apply_fn=model.apply, params=state.params, tx=optax.adamw(learning_rate)
            )

        # Early stop once the most recent test RMSE is below 0.01.
        if test_rmse < 0.01:
            break

    print("Adjusting model")
    X_train, y_train = moffat_gen.random_model_label(20000)
    # Despite their names: ``adjust_params`` collects the mean signed
    # prediction error per output column after each pass, and ``adjust_mean``
    # collects the matching model parameters.
    adjust_params = []
    adjust_mean = []
    for i in range(10):
        for batch in get_batches(X_train, y_train, batch_size):
            state, loss = train_step(state, batch)
        # Whole training set in a single forward pass.
        predictions = model.apply({"params": state.params}, X_train)
        # Mean over axis 0 (samples).
        adjust_params.append(np.mean(predictions - y_train, 0))
        adjust_mean.append(state.params)
        print(f"{i} - (x,y) = {adjust_params[i]}")

    # Keep the pass whose largest absolute mean error is smallest.
    j = np.argmin([np.max(np.abs(d)) for d in adjust_params])
    final_model = adjust_mean[j]
    print(f"Best model: {j} - (x,y) = {adjust_params[j]}")

    print("Saving model file")
    # NOTE(review): isoformat() output contains ':' characters, which are not
    # valid in Windows file names.
    now = datetime.now().isoformat()
    file_name = f"{size}x{size}_{now}.npz"
    np.savez(file_name, **params_to_flat_dict(final_model))
    print("Model saved")