eloy.ballet.model

Classes

nn

CNN

Convolutional Neural Network for centroid regression.

Ballet

Ballet interface for centroid prediction using a pretrained CNN.

Functions

load_weights_file(file)

Load model weights from a .npz file.

download_weights()

Download pretrained weights from the Hugging Face Hub.

Module Contents

class eloy.ballet.model.nn
class Module
abstract __call__(*args, **kwargs)
static compact(func)
class eloy.ballet.model.CNN

Bases: flax.linen.Module

Convolutional Neural Network for centroid regression.

params: None = None

Placeholder for model parameters.

Type:

None
__call__(x)

Forward pass of the CNN.

Parameters:

x (jax.numpy.ndarray) – Input image batch of shape (batch, height, width, channels).

Returns:

Output predictions of shape (batch, 2).

Return type:

jax.numpy.ndarray
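CNN.__call__ documents a four-dimensional (batch, height, width, channels) input, whereas Ballet.centroid documents a three-dimensional (batch, height, width) stack. A minimal sketch, assuming single-channel images, of adding the trailing channel axis with NumPy so that a stack of cutouts matches the CNN's documented input layout. The helper name `to_cnn_input` is hypothetical and not part of eloy, and the 15 x 15 size in the example is arbitrary rather than the size the pretrained model expects.

```python
import numpy as np


def to_cnn_input(images: np.ndarray) -> np.ndarray:
    """Hypothetical helper: turn a (batch, height, width) stack into
    (batch, height, width, 1), the channels-last layout CNN.__call__ documents."""
    images = np.asarray(images, dtype=np.float32)
    if images.ndim != 3:
        raise ValueError(f"expected (batch, height, width), got shape {images.shape}")
    # Append a single channel axis at the end (channels-last layout).
    return images[..., np.newaxis]


stack = np.zeros((8, 15, 15))
x = to_cnn_input(stack)
print(x.shape)  # (8, 15, 15, 1)
```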

eloy.ballet.model.load_weights_file(file)

Load model weights from a .npz file.

Parameters:

file (str or Path) – Path to the .npz weights file.

Returns:

Dictionary mapping layer names to their kernel and bias arrays.

Return type:

dict
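load_weights_file returns a dictionary mapping layer names to their kernel and bias arrays. The sketch below shows one way such a mapping could be built from an .npz archive with `numpy.load`. The key convention `"<layer>/kernel"` and `"<layer>/bias"` and the function name `load_weights_sketch` are assumptions made for illustration; the actual layout of eloy's pretrained file is not documented here.

```python
import os
import tempfile

import numpy as np


def load_weights_sketch(file):
    """Group arrays stored as "<layer>/<param>" in an .npz file into
    {layer: {"kernel": ..., "bias": ...}}. The key format is hypothetical."""
    weights = {}
    with np.load(file) as archive:
        for key in archive.files:
            layer, param = key.rsplit("/", 1)
            # Copy each array out of the lazily loaded archive before it closes.
            weights.setdefault(layer, {})[param] = np.array(archive[key])
    return weights


# Round-trip example with a toy dense layer (shapes are arbitrary).
path = os.path.join(tempfile.mkdtemp(), "toy_weights.npz")
np.savez(path, **{"Dense_0/kernel": np.ones((4, 2)), "Dense_0/bias": np.zeros(2)})
weights = load_weights_sketch(path)
print(sorted(weights["Dense_0"]))  # ['bias', 'kernel']
```

Grouping by layer keeps each kernel next to its bias, which mirrors the nested dictionaries Flax modules use for their parameters.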

eloy.ballet.model.download_weights()

Download pretrained weights from the Hugging Face Hub.

Returns:

Path to the downloaded weights file.

Return type:

str
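download_weights fetches the pretrained file from the Hugging Face Hub and returns its local path as a string. A minimal sketch of that pattern uses `huggingface_hub.hf_hub_download`, which downloads a single file from a Hub repository into the local cache (or reuses a cached copy) and returns the file path. The repository ID and filename are placeholders, not the ones eloy uses, and `download_weights_sketch` is a hypothetical name. The import is deferred so the function can be defined without the package installed.

```python
def download_weights_sketch(repo_id: str, filename: str) -> str:
    """Return a local path to `filename` from the Hub repository `repo_id`.

    Both arguments are placeholders here; eloy's real repository and
    filename are not documented in this section.
    """
    # Imported lazily: huggingface_hub is only required when this is called.
    from huggingface_hub import hf_hub_download

    # Downloads into the local Hugging Face cache on first use and
    # returns the cached file path on later calls.
    return hf_hub_download(repo_id=repo_id, filename=filename)


# Hypothetical usage (needs network access and a real repository):
# path = download_weights_sketch("<user>/<repo>", "<weights>.npz")
```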

class eloy.ballet.model.Ballet(model_file=None)

Ballet interface for centroid prediction using a pretrained CNN.

cnn: None = None

The CNN model instance.

Type:

CNN

params: None = None

Model parameters loaded from file.

Type:

dict
centroid(x)

Predict centroids for input images.

Parameters:

x (numpy.ndarray) – Input images of shape (batch, height, width).

Returns:

Predicted centroids of shape (batch, 2); each row holds the (y, x) coordinates of one image's centroid.

Return type:

numpy.ndarray
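Ballet.centroid returns one (y, x) pair per image. For sanity-checking its predictions, it can help to compare against a classical estimate. The sketch below computes the intensity-weighted centroid (center of mass) of each image in the same (batch, 2), (y, x) convention. This is a swapped-in classical technique for comparison, not the CNN method Ballet implements. It assumes non-negative, background-subtracted images, with coordinates measured in pixel indices from the top-left pixel at (0, 0); `weighted_centroids` is a hypothetical name.

```python
import numpy as np


def weighted_centroids(images: np.ndarray) -> np.ndarray:
    """Intensity-weighted (y, x) centroid for each image in a
    (batch, height, width) stack; returns an array of shape (batch, 2)."""
    images = np.asarray(images, dtype=np.float64)
    _, height, width = images.shape
    # Pixel-index grids: ys varies along rows, xs along columns.
    ys, xs = np.mgrid[0:height, 0:width]
    total = images.sum(axis=(1, 2))
    if np.any(total <= 0):
        raise ValueError("each image needs a positive total flux")
    y = (images * ys).sum(axis=(1, 2)) / total
    x = (images * xs).sum(axis=(1, 2)) / total
    # Stack as columns so each row is (y, x), matching Ballet.centroid.
    return np.stack([y, x], axis=1)


# A single bright pixel at row 3, column 5 in a 9 x 9 cutout.
cutouts = np.zeros((1, 9, 9))
cutouts[0, 3, 5] = 1.0
print(weighted_centroids(cutouts))  # [[3. 5.]]
```

With eloy installed, `Ballet().centroid(...)` takes a stack in the same (batch, height, width) layout (the image size the pretrained model expects is not stated here). Comparing its rows against this baseline is one quick way to confirm the (y, x) ordering is being read correctly.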