Source code for eloy.ballet.model

# Optional dependencies (installed through the ``eloy[jax]`` extra).
#
# A missing package is deliberately NOT turned into an import-time error:
# stand-ins are defined instead so that this module can still be imported,
# and the failure is reported only when the missing feature is used.

try:
    from flax import linen as nn
    import jax.numpy as jnp
except ImportError:

    class nn:
        """Stand-in for ``flax.linen`` used when Flax/JAX cannot be imported.

        It provides only the two names this module needs at class-definition
        time: ``nn.Module`` (as a base class) and ``nn.compact`` (as a
        decorator).
        """

        class Module:
            """Base-class stub whose instances refuse to be called."""

            def __call__(self, *args, **kwargs):
                raise NotImplementedError(
                    "Flax is not installed. Use pip install 'eloy[jax]'"
                )

        @staticmethod
        def compact(func):
            """Return ``func`` unchanged (no-op replacement for ``nn.compact``)."""
            return func


# Imported separately from Flax/JAX: a missing ``huggingface_hub`` must not
# replace a working ``flax.linen`` with the stub above, since the hub is only
# needed to download the default weights.
try:
    from huggingface_hub import hf_hub_download
except ImportError:

    def hf_hub_download(*args, **kwargs):
        """Stand-in that reports the missing dependency when called."""
        raise ImportError(
            "huggingface_hub is not installed. Use pip install 'eloy[jax]'"
        )
import numpy as np
class CNN(nn.Module):
    """
    Small convolutional regressor that maps image cutouts to 2-D centroids.

    The network is three 3x3 convolutions (64, 128 and 256 features) with
    ReLU activations, the first two each followed by a 2x2 max-pool, then
    two sigmoid-activated dense layers (2048 and 512 units) and a final
    linear layer with 2 outputs.

    Attributes
    ----------
    params : None
        Placeholder for model parameters.
    """

    params: None = None

    @nn.compact
    def __call__(self, x):
        """
        Run the network on a batch of images.

        Parameters
        ----------
        x : jax.numpy.ndarray
            Input image batch of shape (batch, height, width, channels).

        Returns
        -------
        jax.numpy.ndarray
            Output predictions of shape (batch, 2).
        """
        # Rescale every image independently: shift its minimum to zero, then
        # divide by the resulting maximum so that values span [0, 1].
        # NOTE(review): an image whose pixels are all equal has a zero peak
        # here, so the division yields NaN for that image -- confirm that
        # callers filter such inputs.
        per_image = (1, 2, 3)
        x = x - jnp.min(x, axis=per_image, keepdims=True)
        peak = jnp.max(x, axis=per_image, keepdims=True)
        x = x / peak

        # Two convolution + ReLU + 2x2 max-pool stages. Layers are created in
        # the same order as a fully unrolled version would create them.
        for features in (64, 128):
            x = nn.Conv(features, (3, 3), padding="SAME")(x)
            x = nn.relu(x)
            x = nn.max_pool(x, (2, 2), strides=(2, 2), padding="SAME")

        # Last convolution stage has no pooling.
        x = nn.Conv(256, (3, 3), padding="SAME")(x)
        x = nn.relu(x)

        # Flatten everything but the batch axis before the dense head.
        x = x.reshape((x.shape[0], -1))
        for width in (2048, 512):
            x = nn.sigmoid(nn.Dense(width)(x))

        return nn.Dense(2)(x)
def load_weights_file(file):
    """
    Load model weights from a .npz file.

    Every layer is expected to be stored as two arrays named
    ``<layer>_kernel`` and ``<layer>_bias``.

    Parameters
    ----------
    file : str or Path
        Path to the .npz weights file (anything accepted by ``numpy.load``).

    Returns
    -------
    dict
        Dictionary mapping layer names (plain ``str``, in sorted order) to a
        ``{"kernel": ..., "bias": ...}`` dictionary of arrays.

    Raises
    ------
    KeyError
        If a layer is missing its ``_kernel`` or ``_bias`` entry.
    """
    suffixes = ("_kernel", "_bias")

    def layer_of(key):
        # Strip the suffix only at the END of the key, so that a layer whose
        # own name contains "_bias" or "_kernel" is left intact.
        for suffix in suffixes:
            if key.endswith(suffix):
                return key[: -len(suffix)]
        # Unrecognised keys are kept as-is; the lookup below then raises
        # KeyError rather than silently dropping them.
        return key

    # The archive keeps a file handle open until it is closed, so read all
    # arrays inside the context manager.
    with np.load(file) as weights:
        layers = sorted({layer_of(key) for key in weights.files})
        return {
            layer: {
                "kernel": weights[f"{layer}_kernel"],
                "bias": weights[f"{layer}_bias"],
            }
            for layer in layers
        }
def download_weights():
    """
    Fetch the pretrained centroid weights from the HuggingFace Hub.

    Returns
    -------
    str
        Path to the downloaded weights file.
    """
    repository = "lgrcia/ballet"
    weights_name = "centroid_15x15.npz"
    local_path = hf_hub_download(repo_id=repository, filename=weights_name)
    return local_path
class Ballet:
    """
    Centroid predictor wrapping a pretrained CNN.

    Attributes
    ----------
    cnn : CNN
        The CNN model instance.
    params : dict
        Model parameters loaded from file.
    """

    # Class-level defaults; both are replaced per instance in __init__.
    cnn: None = None
    params: None = None

    def __init__(self, model_file=None):
        """
        Build the model and load its weights.

        Parameters
        ----------
        model_file : str or Path, optional
            Path to the model weights file. If None, the default weights are
            downloaded first.
        """
        weights_path = download_weights() if model_file is None else model_file
        self.cnn = CNN()
        self.params = load_weights_file(weights_path)

    def centroid(self, x):
        """
        Predict centroids for a batch of images.

        Parameters
        ----------
        x : numpy.ndarray
            Input images of shape (batch, height, width).

        Returns
        -------
        array
            Predicted centroids of shape (batch, 2), with coordinates (y, x).
            The array type is whatever ``self.cnn.apply`` returns.
        """
        # Append a trailing channel axis: (batch, height, width, 1).
        with_channel = x[..., None]
        predicted = self.cnn.apply({"params": self.params}, with_channel)
        # Swap the two output columns to get the documented (y, x) order.
        return predicted[:, ::-1]