2 - Model specification and fitting

JAX
Tutorial
model fitting
model building
This post introduces the general set-up that we will use in this tutorial for specifying models and fitting them to the data.
Author

Valerio Bonometti

Published

August 24, 2023

%load_ext watermark

import numpy as np
import matplotlib.pyplot as plt

from jax.debug import print as jprint

In order to specify models in JAX, we first need to identify the core functionalities we have to implement. We will focus on a specific set of models that, given an input \(X\), a target \(y\) and parameters \(\theta\), aim to approximate functions of the form \(f(X; \theta) \mapsto y\).

Specifically, we need to specify:

  1. Parameter-related functionalities:
    • Storage, how to best keep records of our parameters.
    • Initialisation, how to set our parameters to good starting points.
    • Sharing, how to make the parameters available to the model.
  2. Model-related functionalities:
    • Forward computations, how to move from an input to an estimate of the target.
    • Objective computations, how to define a suitable loss function along with any regularizing penalties.
    • Backward computations, how to derive the gradient of the model’s objective with respect to the parameters.
  3. Optimization-related functionalities:
    • Optimization routines, how to find the optimal values for the parameters using suitable algorithms.
    • Parameter updates, how to use the information derived from the backward computations to update the parameters.
    • Fitting routines, how to connect the input data, the model and the optimization routines.
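To make the list above concrete, here is a minimal end-to-end sketch. It assumes a linear model, a mean-squared-error objective with an L2 penalty, and plain full-batch gradient descent. The names `init_params`, `forward`, `objective`, `gradient_descent_step` and `fit` are hypothetical and are not necessarily the structure used later in this tutorial. Each piece is labelled with the functionality it stands in for.

```python
import jax
import jax.numpy as jnp


# Parameters (storage + initialisation): a dict is a JAX pytree, so JAX
# transformations can traverse it. The initialisation scheme is an assumption.
def init_params(key, n_features):
    return {
        "w": 0.01 * jax.random.normal(key, (n_features,)),
        "b": jnp.zeros(()),
    }


# Model (forward): parameters are shared by passing them explicitly.
def forward(params, X):
    return X @ params["w"] + params["b"]


# Model (objective): mean squared error plus an L2 penalty on the weights.
def objective(params, X, y, l2=1e-3):
    residuals = forward(params, X) - y
    return jnp.mean(residuals**2) + l2 * jnp.sum(params["w"] ** 2)


# Model (backward): gradient of the objective with respect to every leaf of
# the params pytree, returned together with the objective value.
objective_and_grad = jax.value_and_grad(objective)


# Optimisation (parameter update): one full-batch gradient-descent step,
# compiled with jit and applied leaf by leaf with tree_map.
@jax.jit
def gradient_descent_step(params, X, y, learning_rate=0.1):
    loss, grads = objective_and_grad(params, X, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss


# Fitting routine: connects data, model and optimiser in a Python loop.
def fit(params, X, y, n_steps=500):
    history = []
    for _ in range(n_steps):
        params, loss = gradient_descent_step(params, X, y)
        history.append(float(loss))
    return params, history


# Usage on synthetic data generated from y = 2*x1 - 1*x2 + 0.5.
X = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
true_w = jnp.array([2.0, -1.0])
y = X @ true_w + 0.5
params = init_params(jax.random.PRNGKey(1), n_features=2)
params, history = fit(params, X, y)
```

Passing `params` explicitly keeps every function pure. That purity is what allows `jax.value_and_grad` and `jax.jit` to be applied without any extra bookkeeping.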

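We also need to make sure that these functionalities leverage the optimisations provided by JAX (such as just-in-time compilation, automatic differentiation and vectorisation) while avoiding its sharp edges, such as the requirement for pure functions and immutable arrays.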
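Two well-known sharp edges are easy to show in isolation. First, JAX arrays cannot be modified in place. Second, Python side effects inside a `jax.jit`-compiled function run only while the function is traced, whereas `jax.debug.print` (imported above as `jprint`) runs on every call. The following is a short sketch of both. The function `scaled_sum` and the list `trace_log` are hypothetical illustrations.

```python
import jax
import jax.numpy as jnp
from jax.debug import print as jprint

# Sharp edge 1: arrays are immutable, so `x[0] = 10.0` raises a TypeError.
# The functional .at[...].set(...) update returns a new array instead.
x = jnp.zeros(3)
x_new = x.at[0].set(10.0)

# Sharp edge 2: jit-compiled functions should be pure.
trace_log = []


@jax.jit
def scaled_sum(v, scale):
    trace_log.append("traced")  # Python side effect: runs at trace time only
    total = jnp.sum(v) * scale
    jprint("total = {t}", t=total)  # staged into the computation: every call
    return total


first = scaled_sum(x_new, 2.0)  # traces, compiles and runs
second = scaled_sum(x_new, 3.0)  # same shapes/dtypes: reuses the compiled code
```

After both calls, `trace_log` holds a single entry while `jprint` has printed twice. Any state a model relies on should therefore be passed in and returned explicitly rather than mutated inside jitted code.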

3 Hardware and Requirements

Here you can find the hardware and Python requirements used to build this post.

%watermark
Last updated: 2025-03-28T09:22:13.055626+00:00

Python implementation: CPython
Python version       : 3.13.2
IPython version      : 9.0.2

Compiler    : Clang 18.1.8 
OS          : Darwin
Release     : 24.3.0
Machine     : arm64
Processor   : arm
CPU cores   : 14
Architecture: 64bit
%watermark --iversions
jax       : 0.5.2
seaborn   : 0.13.2
matplotlib: 3.10.1
numpy     : 2.2.4