eloy.ballet.training
====================

.. py:module:: eloy.ballet.training

.. autoapi-nested-parse::

   Training utilities and synthetic data generation for centroid regression.

   This module provides:

   - ``Moffat2D``: a class for generating synthetic 2D Moffat profiles and their labels.
   - Training utilities for JAX/Flax models, including loss computation, batching, and training and evaluation steps.
   - Helpers for preparing model parameters for saving.

   Intended for use in training convolutional neural networks to predict centroids from image cutouts.



Attributes
----------

.. autoapisummary::

   eloy.ballet.training.size


Classes
-------

.. autoapisummary::

   eloy.ballet.training.Moffat2D
   eloy.ballet.training.TrainState


Functions
---------

.. autoapisummary::

   eloy.ballet.training.compute_loss
   eloy.ballet.training.train_step
   eloy.ballet.training.eval_step
   eloy.ballet.training.params_to_flat_dict
   eloy.ballet.training.get_batches


Module Contents
---------------

.. py:class:: Moffat2D(cutout_size=21, **kwargs)

   Moffat 2D generator.

   Generates synthetic 2D Moffat profiles for training and testing.

   :param cutout_size: Side length, in pixels, of the square image cutouts (default is 21).
   :type cutout_size: int, optional
   :param \*\*kwargs: Additional keyword arguments.


   .. py:method:: moffat2D_model(a, x0, y0, sx, sy, theta, b, beta)

      Generate a 2D Moffat profile.

      :param a: Amplitude.
      :type a: float
      :param x0: x coordinate of the profile center.
      :type x0: float
      :param y0: y coordinate of the profile center.
      :type y0: float
      :param sx: Scale parameter (width) along x.
      :type sx: float
      :param sy: Scale parameter (width) along y.
      :type sy: float
      :param theta: Rotation angle in radians.
      :type theta: float
      :param b: Background level.
      :type b: float
      :param beta: Moffat beta parameter.
      :type beta: float

      :returns: 2D Moffat profile of shape (cutout_size, cutout_size).
      :rtype: numpy.ndarray



   .. py:method:: sigma_to_fwhm(beta)

      Convert Moffat beta parameter to FWHM.

      :param beta: Moffat beta parameter.
      :type beta: float

      :returns: Full width at half maximum (FWHM).
      :rtype: float



   .. py:method:: random_model_label(N=10000, flatten=False, return_all=False, sigma=1.0)

      Generate random Moffat images and labels.

      :param N: Number of samples to generate (default is 10000).
      :type N: int, optional
      :param flatten: If True and ``N == 1``, return a single image and a single label rather than batched arrays (default is False).
      :type flatten: bool, optional
      :param return_all: If True, return all model parameters as labels instead of only the center coordinates (default is False).
      :type return_all: bool, optional
      :param sigma: Standard deviation used to draw the random center coordinates (default is 1.0).
      :type sigma: float, optional

      :returns: (images, labels) where images is (N, cutout_size, cutout_size, 1)
                and labels is (N, 2) or (N, 9) depending on return_all.
      :rtype: tuple

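The profile that ``moffat2D_model`` documents can be sketched in NumPy as below. This is a minimal illustration under assumptions: it uses the common parameterization ``b + a * (1 + r**2) ** (-beta)`` with ``r**2`` the rotated, width-scaled squared distance from ``(x0, y0)``, which may differ from the module's exact form. ``moffat2d`` and ``moffat_fwhm`` are hypothetical names; the latter uses the standard Moffat relation FWHM = 2 alpha sqrt(2^(1/beta) - 1) for a scale ``alpha``, and whether ``sigma_to_fwhm`` returns this value or only the beta-dependent factor is not stated here.

```python
import numpy as np


def moffat2d(a, x0, y0, sx, sy, theta, b, beta, cutout_size=21):
    """Rotated elliptical Moffat profile on a square pixel grid.

    Hypothetical illustration: the parameterization is assumed and may
    differ from Moffat2D.moffat2D_model.
    """
    # Pixel coordinates: x runs along columns, y along rows.
    y, x = np.indices((cutout_size, cutout_size), dtype=float)
    dx, dy = x - x0, y - y0
    # Rotate the offsets into the profile frame by theta (radians).
    xr = dx * np.cos(theta) + dy * np.sin(theta)
    yr = -dx * np.sin(theta) + dy * np.cos(theta)
    # Squared elliptical radius in units of the scale parameters.
    r2 = (xr / sx) ** 2 + (yr / sy) ** 2
    return b + a * (1.0 + r2) ** (-beta)


def moffat_fwhm(alpha, beta):
    """Standard Moffat FWHM for scale alpha and power index beta."""
    return 2.0 * alpha * np.sqrt(2.0 ** (1.0 / beta) - 1.0)


image = moffat2d(a=1.0, x0=10.0, y0=10.0, sx=2.0, sy=2.0, theta=0.0, b=0.0, beta=2.5)
print(image.shape)  # (21, 21)
print(moffat_fwhm(2.0, 2.5))
```

Evaluating the profile at half the FWHM from its center returns half the peak amplitude, which is a quick consistency check between the two functions.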

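A labelled training set with the shapes documented for ``random_model_label`` (images ``(N, cutout_size, cutout_size, 1)``, center labels ``(N, 2)``) could be drawn as in the sketch below. The parameter ranges, the circular profile, the Gaussian noise level, the ``(x, y)`` label order, and drawing centers around the cutout middle with standard deviation ``sigma`` are all assumptions; ``random_centroid_dataset`` is a hypothetical name.

```python
import numpy as np


def random_centroid_dataset(n=1000, cutout_size=21, sigma=1.0, seed=0):
    """Hypothetical generator of (images, labels) for centroid regression."""
    rng = np.random.default_rng(seed)
    y, x = np.indices((cutout_size, cutout_size), dtype=float)
    middle = (cutout_size - 1) / 2
    # Assumed: true (x, y) centers scattered around the cutout middle.
    centers = middle + sigma * rng.standard_normal((n, 2))
    alpha = rng.uniform(1.5, 4.0, size=(n, 1, 1))  # assumed width range
    beta = rng.uniform(2.0, 5.0, size=(n, 1, 1))   # assumed beta range
    amp = rng.uniform(0.5, 2.0, size=(n, 1, 1))    # assumed amplitude range
    cx = centers[:, 0, None, None]
    cy = centers[:, 1, None, None]
    # Circular Moffat profiles broadcast to shape (n, cutout_size, cutout_size).
    r2 = ((x - cx) ** 2 + (y - cy) ** 2) / alpha**2
    images = amp * (1.0 + r2) ** (-beta)
    images += 0.01 * rng.standard_normal(images.shape)  # assumed noise level
    # Trailing channel axis gives the documented (N, H, W, 1) layout.
    return images[..., None].astype(np.float32), centers.astype(np.float32)


images, labels = random_centroid_dataset(n=8)
print(images.shape, labels.shape)  # (8, 21, 21, 1) (8, 2)
```

Broadcasting the per-sample parameters as ``(n, 1, 1)`` arrays against the ``(H, W)`` pixel grid avoids a Python loop over samples.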

.. py:class:: TrainState

   Bases: :py:obj:`flax.training.train_state.TrainState`


   Custom TrainState for model training. Like its Flax base class, it bundles the
   step count, ``apply_fn``, the parameters, and the optimizer with its state.


.. py:function:: compute_loss(params, batch)

   Compute the mean squared error loss for a batch.

   :param params: Model parameters.
   :type params: dict
   :param batch: Tuple (x, y) of input images and target labels.
   :type batch: tuple

   :returns: Mean squared error loss.
   :rtype: jax.Array

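A minimal JAX sketch of a mean squared error loss with the documented ``compute_loss(params, batch)`` signature follows. In the module the forward pass presumably comes from a Flax CNN; here ``apply_fn`` is a hypothetical linear stand-in so the example runs on its own.

```python
import jax.numpy as jnp


def apply_fn(params, x):
    """Hypothetical linear stand-in for the CNN: flattened pixels -> (x, y)."""
    return x.reshape(x.shape[0], -1) @ params["w"] + params["b"]


def compute_loss(params, batch):
    """Mean squared error between predicted and target centroids."""
    x, y = batch
    preds = apply_fn(params, x)
    return jnp.mean((preds - y) ** 2)


params = {"w": jnp.zeros((21 * 21, 2)), "b": jnp.zeros(2)}
x = jnp.ones((4, 21, 21, 1))
y = jnp.full((4, 2), 2.0)
print(compute_loss(params, (x, y)))  # 4.0
```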

.. py:function:: train_step(state, batch)

   Perform a single training step.

   :param state: Current training state.
   :type state: TrainState
   :param batch: Tuple (x, y) of input images and target labels.
   :type batch: tuple

   :returns: ``(new_state, loss)``, the updated training state and the loss on the batch.
   :rtype: tuple

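A gradient step of this kind can be sketched with ``jax.value_and_grad``. This is an illustration under assumptions: ``sgd_step`` is a hypothetical function that updates a bare parameter pytree with plain SGD, whereas the documented ``train_step`` works on a ``TrainState`` and would typically apply its optimizer through ``state.apply_gradients(grads=grads)``. ``apply_fn`` is again a hypothetical linear stand-in for the CNN.

```python
import jax
import jax.numpy as jnp


def apply_fn(params, x):
    """Hypothetical linear stand-in for the CNN: pixels -> (x, y)."""
    return x.reshape(x.shape[0], -1) @ params["w"] + params["b"]


def compute_loss(params, batch):
    """Mean squared error on one batch."""
    x, y = batch
    return jnp.mean((apply_fn(params, x) - y) ** 2)


@jax.jit
def sgd_step(params, batch, lr=0.01):
    """One gradient step on a bare parameter pytree; returns (new_params, loss)."""
    loss, grads = jax.value_and_grad(compute_loss)(params, batch)
    # Plain SGD. With a Flax TrainState this would instead be
    # state = state.apply_gradients(grads=grads), which runs the optimizer.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (16, 21, 21, 1))
y = jnp.ones((16, 2))
params = {"w": jnp.zeros((21 * 21, 2)), "b": jnp.zeros(2)}
initial_loss = compute_loss(params, (x, y))
for _ in range(50):
    params, loss = sgd_step(params, (x, y))
print(float(initial_loss), float(loss))
```

The returned loss is the one computed before the update in that step, which is the usual convention for logging.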

.. py:function:: eval_step(params, batch)

   Evaluate the model on a batch and compute RMSE.

   :param params: Model parameters.
   :type params: dict
   :param batch: Tuple (x, y) of input images and target labels.
   :type batch: tuple

   :returns: Root mean squared error (RMSE).
   :rtype: jax.Array

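The RMSE reported by ``eval_step`` can be sketched as below, assuming it averages the squared error over every element of the ``(N, 2)`` prediction array. The docstring does not say whether it averages over coordinates or over per-source Euclidean distances; the latter is larger by a factor of sqrt(2). ``rmse`` is a hypothetical helper.

```python
import jax.numpy as jnp


def rmse(preds, targets):
    """Root mean squared error over all elements (assumed convention)."""
    return jnp.sqrt(jnp.mean((preds - targets) ** 2))


preds = jnp.array([[10.0, 10.0], [11.0, 9.0]])
targets = jnp.array([[10.0, 13.0], [11.0, 9.0]])
print(float(rmse(preds, targets)))  # 1.5
```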

.. py:function:: params_to_flat_dict(params)

   Flatten model parameters to a dictionary suitable for saving.

   :param params: Model parameters.
   :type params: dict

   :returns: Flattened dictionary with keys for each layer's kernel and bias.
   :rtype: dict

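Flattening a nested parameter tree for saving can be sketched as follows. The exact key names produced by ``params_to_flat_dict`` are not documented; this hypothetical ``flatten_params`` joins nested names with ``"."`` (for example ``"Conv_0.kernel"``), assuming Flax-style nesting of ``{"layer": {"kernel": ..., "bias": ...}}``. With Flax installed, ``flax.traverse_util.flatten_dict(params, sep="/")`` performs a similar flattening.

```python
import io

import numpy as np


def flatten_params(params, prefix="", sep="."):
    """Hypothetical flattener: {"Conv_0": {"kernel": k}} -> {"Conv_0.kernel": k}."""
    flat = {}
    for name, value in params.items():
        key = f"{prefix}{sep}{name}" if prefix else name
        if isinstance(value, dict):
            flat.update(flatten_params(value, key, sep))
        else:
            flat[key] = np.asarray(value)  # convert device arrays to NumPy
    return flat


params = {
    "Conv_0": {"kernel": np.zeros((3, 3, 1, 8)), "bias": np.zeros(8)},
    "Dense_0": {"kernel": np.zeros((8, 2)), "bias": np.zeros(2)},
}
flat = flatten_params(params)
print(sorted(flat))  # ['Conv_0.bias', 'Conv_0.kernel', 'Dense_0.bias', 'Dense_0.kernel']

# Round-trip through NumPy's .npz format (in memory here; a filename also works).
buffer = io.BytesIO()
np.savez(buffer, **flat)
buffer.seek(0)
restored = dict(np.load(buffer))
```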

.. py:function:: get_batches(X, y, batch_size)

   Yield batches of data for training.

   :param X: Input images.
   :type X: numpy.ndarray
   :param y: Target labels.
   :type y: numpy.ndarray
   :param batch_size: Batch size.
   :type batch_size: int

   :Yields: *tuple* -- (x_batch, y_batch) as jax.numpy arrays.

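A minimal NumPy sketch of the batching generator follows. It assumes sequential, unshuffled batches in which the final batch may be smaller; the docstring does not state whether ``get_batches`` shuffles or drops a partial final batch. The documented function yields ``jax.numpy`` arrays, which would amount to wrapping each slice in ``jnp.asarray``.

```python
import numpy as np


def get_batches(X, y, batch_size):
    """Yield (x_batch, y_batch) slices in order; the last batch may be smaller."""
    for start in range(0, len(X), batch_size):
        stop = start + batch_size
        yield X[start:stop], y[start:stop]


X = np.zeros((10, 21, 21, 1), dtype=np.float32)
y = np.zeros((10, 2), dtype=np.float32)
print([xb.shape[0] for xb, _ in get_batches(X, y, batch_size=4)])  # [4, 4, 2]
```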

.. py:data:: size
   :value: 15


