eloy.ballet.model
=================

.. py:module:: eloy.ballet.model


Classes
-------

.. autoapisummary::

   eloy.ballet.model.nn
   eloy.ballet.model.CNN
   eloy.ballet.model.Ballet


Functions
---------

.. autoapisummary::

   eloy.ballet.model.load_weights_file
   eloy.ballet.model.download_weights


Module Contents
---------------

.. py:class:: nn

   .. py:class:: Module

      .. py:method:: __call__(*args, **kwargs)
         :abstractmethod:




   .. py:method:: compact(func)
      :staticmethod:



.. py:class:: CNN

   Bases: :py:obj:`flax.linen.Module`


   Convolutional Neural Network for centroid regression.

   .. py:attribute:: params
      :type: None
      :value: None

      Placeholder for model parameters.



   .. py:method:: __call__(x)

      Forward pass of the CNN.

      :param x: Input image batch of shape (batch, height, width, channels).
      :type x: jax.numpy.ndarray

      :returns: Output predictions of shape (batch, 2).
      :rtype: jax.numpy.ndarray
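
      To illustrate the ``(batch, height, width, channels) -> (batch, 2)`` contract,
      here is a minimal sketch of a Flax Linen regressor built with ``nn.Module`` and
      ``nn.compact``. The class name ``TinyCentroidCNN`` and its layer sizes are
      hypothetical. They are not the architecture of this ``CNN``, only an example of
      a module with the same input and output shapes.

      .. code-block:: python

         import jax
         import jax.numpy as jnp
         import flax.linen as nn


         class TinyCentroidCNN(nn.Module):
             """Hypothetical stand-in; NOT the layer layout of eloy.ballet.model.CNN."""

             @nn.compact
             def __call__(self, x):
                 # x: (batch, height, width, channels)
                 x = nn.Conv(features=8, kernel_size=(3, 3))(x)  # default 'SAME' padding keeps H, W
                 x = nn.relu(x)
                 x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))  # halve H and W
                 x = x.reshape((x.shape[0], -1))  # flatten each sample
                 return nn.Dense(features=2)(x)  # two regressed coordinates per image


         model = TinyCentroidCNN()
         images = jnp.zeros((4, 16, 16, 1))  # dummy batch of 4 single-channel images
         variables = model.init(jax.random.PRNGKey(0), images)  # build the parameter tree
         out = model.apply(variables, images)  # forward pass, shape (4, 2)

      ``model.init`` creates the parameter tree and ``model.apply`` runs the forward
      pass with it. Flax Linen modules do not store their parameters internally.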



.. py:function:: load_weights_file(file)

   Load model weights from a ``.npz`` file.

   :param file: Path to the ``.npz`` weights file.
   :type file: str or Path

   :returns: Dictionary mapping layer names to their kernel and bias arrays.
   :rtype: dict
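
   The key layout of the weights file is not documented here. The following is a
   minimal sketch that assumes, hypothetically, that each array is stored under a
   key of the form ``"<layer>/kernel"`` or ``"<layer>/bias"``. It shows how such
   keys could be grouped into a per-layer dictionary. ``load_weights_sketch`` is an
   illustrative name, not part of this module.

   .. code-block:: python

      from collections import defaultdict
      from pathlib import Path

      import numpy as np


      def load_weights_sketch(file):
          """Group arrays keyed "<layer>/<name>" into {layer: {name: array}}.

          The "<layer>/<name>" key convention is an assumption for illustration.
          """
          weights = defaultdict(dict)
          with np.load(Path(file)) as data:  # NpzFile closes the archive on exit
              for key in data.files:
                  layer, name = key.rsplit("/", 1)  # "Conv_0/kernel" -> ("Conv_0", "kernel")
                  weights[layer][name] = data[key]  # read each array before the file closes
          return dict(weights)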


.. py:function:: download_weights()

   Download pretrained weights from the Hugging Face Hub.

   :returns: Path to the downloaded weights file.
   :rtype: str


.. py:class:: Ballet(model_file=None)

   Ballet interface for centroid prediction using a pretrained CNN.

   .. py:attribute:: cnn
      :type: CNN
      :value: None

      The CNN model instance.


   .. py:attribute:: params
      :type: dict
      :value: None

      Model parameters loaded from file.



   .. py:method:: centroid(x)

      Predict centroids for input images.

      :param x: Input images of shape (batch, height, width).
      :type x: numpy.ndarray

      :returns: Predicted centroids of shape (batch, 2), with coordinates (y, x).
      :rtype: numpy.ndarray
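
      The ``(batch, 2)`` output in ``(y, x)`` order can be compared with a classical
      intensity-weighted centroid (center of mass), which produces values in the same
      layout. The sketch below is a swapped-in baseline for comparison only. It is not
      the method used by ``Ballet``. ``center_of_mass_yx`` is a hypothetical helper, and
      coordinates are assumed to be pixel indices with the row (y) first.

      .. code-block:: python

         import numpy as np


         def center_of_mass_yx(images):
             """Intensity-weighted centroids of shape (batch, 2), ordered (y, x).

             Baseline for comparison only; assumes non-negative images with
             non-zero total flux.
             """
             images = np.asarray(images, dtype=float)  # (batch, height, width)
             _, height, width = images.shape
             ys, xs = np.mgrid[0:height, 0:width]  # row and column index grids
             total = images.sum(axis=(1, 2))  # total flux per image
             y = (images * ys).sum(axis=(1, 2)) / total
             x = (images * xs).sum(axis=(1, 2)) / total
             return np.stack([y, x], axis=1)


         stamps = np.zeros((1, 8, 8))
         stamps[0, 3, 5] = 1.0  # single bright pixel at row 3, column 5
         print(center_of_mass_yx(stamps))  # [[3. 5.]]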



