Analyzing A/B Test Data using Gaussian Processes

JAX
Numpyro
Bayesian Statistics
A/B Test
Gaussian Process
This post illustrates how to analyze longitudinal A/B test data using Gaussian Processes.
Published

March 29, 2025

%load_ext watermark

NUM_CHAINS = 4

from typing import Dict, Any, Callable, Tuple

from functools import partial

import pandas as pd
import numpy as np

import scipy
from scipy.stats import lognorm, median_abs_deviation

import jax
from jax.typing import ArrayLike
import jax.random as random
from jax import jit
from jax import numpy as jnp
from jax import vmap

import numpyro
numpyro.set_host_device_count(NUM_CHAINS)
import numpyro.distributions as dist
from numpyro.infer import (
    MCMC,
    NUTS,
    Predictive
)

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib import gridspec

import seaborn as sns

def plot_univariate_series(
        series_data: Dict[Any, Any], 
        ax: plt.Axes, 
        **plot_kwargs: Any
    ) -> plt.Axes:
    ax.scatter(
        series_data["x"]["data"],
        series_data["y"]["data"],
        **plot_kwargs
    )
    ax.plot(
        series_data["x"]["data"],
        series_data["y"]["data"],
        **plot_kwargs
    )
    ax.set_xlabel(series_data["x"]["label"])
    ax.set_ylabel(series_data["y"]["label"])

    ax.tick_params(
        direction="in",
        top=True, axis="x",
        rotation=45,
    )
    ax.grid(
        visible=True,
        which="major",
        axis="x",
        color="k",
        alpha=0.25,
        linestyle="--",
    )
    return ax

1 Premise

What will be illustrated in this post is strongly inspired by the content of the books “Statistical Rethinking” (McElreath 2018) and “Bayesian Data Analysis 3rd Edition” (Gelman et al. 1995). In particular, the idea of separately coding and illustrating the behavior of different covariance kernel functions comes from the excellent “Kernel Cookbook” and PhD thesis of David Duvenaud (Duvenaud 2014).

This post assumes some familiarity with Bayesian statistics and probabilistic programming.

1.1 What we will cover

  1. Very brief illustration of longitudinal A/B tests within an observational paradigm.
  2. Very brief illustration of Gaussian processes and their application to analyzing A/B test data.
  3. Overview of how to implement a Gaussian process model using Numpyro and JAX.
  4. Simulating A/B test data.
  5. Analyzing A/B test data within a modelling setting.

1.2 What we will not cover

  1. Detailed coverage of A/B tests (e.g., sampling, randomization, etc.).
  2. Hypothesis testing (we will focus on modelling).
  3. Fundamentals of Bayesian statistics.
  4. Detailed overview of Gaussian processes.
  5. Probabilistic programming and sampling algorithms.

2 Introduction

2.1 Longitudinal A/B tests in observational settings

When we talk about A/B tests we usually refer to a research method used for evaluating whether a given intervention has an impact on a pre-defined outcome variable measured in a sample. To do so, we can draw two distinct samples (the A and B groups) from a population of interest, subject one of the two to the intervention, and then measure observed differences in the outcome variable. Subject to several assumptions and pre-conditions (we suggest reading part V of “Regression and Other Stories” (Gelman, Hill, and Vehtari 2021)), if we observe a difference between the two groups we can conclude that our intervention might have had an impact on our outcome variable.
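As a toy illustration of this difference-in-means logic (group sizes and the effect size of +0.5 are hypothetical, not part of the analysis below), we can sketch:

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical outcomes: group B receives an intervention that
# shifts the mean outcome by +0.5 (assumed effect size).
group_a = rng.normal(loc=0.0, scale=1.0, size=10_000)
group_b = rng.normal(loc=0.5, scale=1.0, size=10_000)

# The observed difference in means estimates the intervention effect.
observed_delta = group_b.mean() - group_a.mean()
```

With this many draws the estimate lands close to the simulated effect, although in real settings the assumptions mentioned above (randomization, comparability of the samples) do the heavy lifting.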

2.2 Gaussian Process

The Gaussian Process \(GP\) can be thought of as the continuous generalization of basis function regression (see Chapter 20 of (Gelman et al. 1995)). We can think of it as a stochastic process where any set of points drawn from it, \(x_1, \dots, x_n\), comes from a multi-dimensional Gaussian. In other words, it is a prior distribution over an unknown function \(\mu(x)\) defined as

\[ \mu(x_1), \dots, \mu(x_n) \sim \mathcal{N}\big( (m(x_1), \dots, m(x_n)),\, K \big), \qquad K_{ij} = k(x_i, x_j) \]

or more compactly

\[ \mu(X) \sim GP(m, k) \]

where \(m\) is a mean function and \(k\) is a covariance function. We can already get an intuition of how defining the \(GP\) in terms of mean and covariance functions gives us quite some flexibility, as it allows us to produce a model that can interpolate for all values of \(x\).

The \(m\) function provides the most likely guess for the \(GP\), like the mean vector of a multi-dimensional Gaussian; deviations from this expected model are then handled by the covariance function \(k\).

The \(k\) function (often called the kernel) structurally defines the \(GP\) behavior at any two points by producing an \(n \times n\) covariance matrix, obtained by evaluating \(k(x, x')\) for every pair of points in \(x_1, \dots, x_n\).

One convenient property of \(GP\)s is that the sum and the product of two or more \(GP\)s is itself a \(GP\); this allows us to combine different types of kernels to impose specific structural constraints.
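To make the definitions above concrete, here is a minimal sketch (plain NumPy, with an assumed RBF covariance and a zero mean function) of drawing functions from a \(GP\) prior:

```python
import numpy as np

def rbf(x, x_prime, variance=1.0, length=5.0):
    # k(x, x') = variance * exp(-(x - x')^2 / (2 * length^2))
    return variance * np.exp(
        -((x[:, None] - x_prime[None, :]) ** 2) / (2.0 * length**2)
    )

x = np.linspace(0.0, 10.0, 50)
# A small jitter on the diagonal keeps the matrix numerically positive-definite.
K = rbf(x, x) + 1e-6 * np.eye(len(x))
draws = np.random.default_rng(0).multivariate_normal(
    mean=np.zeros(len(x)), cov=K, size=3
)
```

Each row of `draws` is one function evaluated at the points in `x`; the kernel alone controls how smooth those functions are.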

2.3 Gaussian Process for A/B test

Although we were not able to find many papers illustrating how \(GP\)s can be used for analyzing A/B test data, we found this interesting work by (Benavoli and Mangili 2015) from IDSIA, which we decided to adapt to our use-case.

3 Implementing a Gaussian Process Model in Numpyro

In this section we will illustrate how to implement a \(GP\) model using Numpyro. Numpyro offers a NumPy-like backend for Pyro, a probabilistic programming language (PPL). Other than offering the flexibility of specifying models using the familiar NumPy interface, Numpyro is tightly integrated with JAX, allowing us to tap into its JIT compilation capabilities.

BLUE = (0, 83, 159)
BLUE = tuple(value / 255. for value in BLUE)

RED = (238, 28, 46)
RED = tuple(value / 255. for value in RED)

TIME_IDX = np.arange(7*4)
DATES = pd.date_range(
    start="01-01-2023",
    periods=len(TIME_IDX)
).values
TIME = "date"
MODELS = [
    "Branch A",
    "Branch B",
    "Branch C",
    "Branch D",
    "Branch E",
]

3.1 Kernel Functions

Implementing kernel functions in Numpyro is as easy as writing the functional form of the covariance function in pure Python. It is sufficient that the Python function has the following signature:

def my_kernel_function(X: ArrayLike, X_prime: ArrayLike, **kwargs) -> ArrayLike:
    ...

Here X and X_prime are the same input covariate used for creating the covariance matrix, while the keyword arguments are the parameters used for computing the pairwise relationship between each pair of values in X (hence why we duplicate it as X_prime). This is the basis for giving structure to our covariance matrix. We will now present a series of conventional covariance functions usually found in standard Gaussian process applications.

def simulate_kernel(
    X: ArrayLike, 
    X_prime:ArrayLike, 
    kernel_function: Callable, 
    samples: int,
    mean: float = 0.
) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
    generated_covariance = kernel_function(
        X=X,
        X_prime=X_prime,
    )
    distance = generated_covariance[:, len(X) // 2]
    sampled_functions = np.random.multivariate_normal(
        mean=np.zeros(
            shape=generated_covariance.shape[0]
        ) + mean,
        cov=generated_covariance,
        size=samples
    )
    return generated_covariance, distance, sampled_functions

def visualize_kernel(
    X: ArrayLike, 
    generated_covariance: ArrayLike, 
    distance: ArrayLike, 
    sampled_functions: ArrayLike, 
    kernel_name: str
) -> Figure:

    fig = plt.figure(
        figsize=(8, 8),
        tight_layout=True
    )
    grid = gridspec.GridSpec(
        nrows=2,
        ncols=2
    )
    ax_functions = fig.add_subplot(grid[0, :])
    ax_distance = fig.add_subplot(grid[1, 0])
    ax_covariance = fig.add_subplot(grid[1, 1])

    for index in range(sampled_functions.shape[0]):

        ax_functions = plot_univariate_series(
            series_data={
                "x": {
                    "data": X,
                    "label": "x"
                },
                "y": {
                    "data": sampled_functions[index, :],
                    "label": "y"
                },

            },
            ax=ax_functions,
            alpha=0.5
        )

    ax_functions.set_title("Sampled Functions \n from MvNormal")
    ax_functions.axhline(0, linestyle="--", c="k")

    ax_distance.plot(
        X - (len(X) // 2),
        distance
    )
    ax_distance.grid(alpha=0.5)
    ax_distance.set_title(f"Distance Function \n Determined by {kernel_name} Kernel")
    ax_distance.set_ylabel("Similarity")
    ax_distance.set_xlabel("Distance")

    ax_covariance.imshow(
        generated_covariance
    )
    ax_covariance.set_ylabel("x'")
    ax_covariance.set_xlabel("x")
    ax_covariance.set_title(f"Covariance \n Determined by {kernel_name} Kernel")


    plt.suptitle(f"{kernel_name} Kernel")
    return fig

3.1.1 Radial Basis Function Kernel

Following Duvenaud (2014), we start from the Radial Basis Function (RBF) kernel.
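For reference, the covariance function implemented below has the form

\[ k(x, x') = \sigma^2 \exp\left( -\frac{(x - x')^2}{2\ell^2} \right) \]

where \(\sigma^2\) is the variance and \(\ell\) the length scale controlling how quickly correlation decays with distance.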

@jit
def RBF_kernel(
    X: ArrayLike, 
    X_prime: ArrayLike, 
    variance: float, 
    length: float, 
    noise: float,
    jitter: float = 1.0e-6,
    include_noise: bool = True
) -> ArrayLike:
    squared_differences = jnp.power(
        (X[:, None] - X_prime),
        2.0
    )
    squared_length_scale = 2 * jnp.power(
        length,
        2.0
    )
    covariance_matrix = jnp.exp(
        - (squared_differences / squared_length_scale)
    )
    scaled_covariance_matrix = variance * covariance_matrix

    scaled_covariance_matrix = jnp.where(
        include_noise, 
        scaled_covariance_matrix + (noise + jitter) * jnp.eye(X.shape[0]), 
        scaled_covariance_matrix,
    )

    return scaled_covariance_matrix
kernel_function = partial(
    RBF_kernel,
    variance=1,
    length=10,
    noise=0.001
)
generated_covariance, distance, sampled_functions = simulate_kernel(
    X=TIME_IDX,
    X_prime=TIME_IDX,
    kernel_function=kernel_function,
    samples=3,
)
fig = visualize_kernel(
    TIME_IDX,
    generated_covariance=generated_covariance,
    distance=distance,
    sampled_functions=sampled_functions,
    kernel_name="RBF"
)
plt.show()

3.1.2 Rational Quadratic Kernel
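The covariance function implemented below has the form

\[ k(x, x') = \sigma^2 \left( 1 + \frac{(x - x')^2}{2\alpha\ell^2} \right)^{-\alpha} \]

which can be seen as a scale mixture of RBF kernels with different length scales; as \(\alpha \to \infty\) it recovers the RBF kernel.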

@jit
def rational_quadratic_kernel(
    X: ArrayLike, 
    X_prime: ArrayLike, 
    variance: float, 
    length: float, 
    noise: float, 
    alpha: float, 
    jitter: float = 1.0e-6,
    include_noise: bool = True
) -> ArrayLike:

    squared_differences = jnp.power(
        (X[:, None] - X_prime),
        2.0
    )
    squared_length_scale = 2 * alpha * jnp.power(
        length,
        2.0
    )
    covariance_matrix = jnp.power(
        1 + (squared_differences / squared_length_scale),
        - alpha
    )
    scaled_covariance_matrix = variance * covariance_matrix

    scaled_covariance_matrix = jnp.where(
        include_noise, 
        scaled_covariance_matrix + (noise + jitter) * jnp.eye(X.shape[0]), 
        scaled_covariance_matrix,
    )

    return scaled_covariance_matrix
kernel_function = partial(
    rational_quadratic_kernel,
    variance=1,
    length=10,
    alpha=3,
    noise=0.001
)
generated_covariance, distance, sampled_functions = simulate_kernel(
    X=TIME_IDX,
    X_prime=TIME_IDX,
    kernel_function=kernel_function,
    samples=3

)
fig = visualize_kernel(
    TIME_IDX,
    generated_covariance=generated_covariance,
    distance=distance,
    sampled_functions=sampled_functions,
    kernel_name="Rational Quadratic"
)
plt.show()

3.1.3 Periodic Kernel
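The covariance function implemented below has the form

\[ k(x, x') = \sigma^2 \exp\left( -\frac{2 \sin^2\!\left( \pi |x - x'| / p \right)}{\ell^2} \right) \]

where \(p\) is the period and \(\ell\) controls how smooth the function is within one period.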

@jit
def periodic_kernel(
    X: ArrayLike, 
    X_prime: ArrayLike, 
    variance: float, 
    length: float, 
    noise: float, 
    period: float, 
    jitter: float = 1.0e-6, 
    include_noise: bool = True
) -> ArrayLike:

    periodic_difference = jnp.pi * jnp.abs(X[:, None] - X_prime) / period
    sine_squared_difference = 2 * jnp.power(
        jnp.sin(periodic_difference),
        2.0
    )
    squared_length_scale = jnp.power(
        length,
        2.0
    )
    covariance_matrix = jnp.exp(
        - (sine_squared_difference / squared_length_scale)
    )
    scaled_covariance_matrix = variance * covariance_matrix

    scaled_covariance_matrix = jnp.where(
        include_noise, 
        scaled_covariance_matrix + (noise + jitter) * jnp.eye(X.shape[0]), 
        scaled_covariance_matrix,
    )
    return scaled_covariance_matrix
kernel_function = partial(
    periodic_kernel,
    variance=1,
    length=5,
    period=10,
    noise=0.001
)

generated_covariance, distance, sampled_functions = simulate_kernel(
    X=TIME_IDX,
    X_prime=TIME_IDX,
    kernel_function=kernel_function,
    samples=3,
)
fig = visualize_kernel(
    TIME_IDX,
    generated_covariance=generated_covariance,
    distance=distance,
    sampled_functions=sampled_functions,
    kernel_name="Periodic"
)
plt.show()

3.1.4 Linear Kernel
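The covariance function implemented below has the form

\[ k(x, x') = \sigma_b^2 + x x' \]

where \(\sigma_b^2\) acts as an offset variance; draws from this kernel are straight lines with random slopes and intercepts.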

@jit
def linear_kernel(
    X: ArrayLike, 
    X_prime: ArrayLike, 
    variance: float, 
    noise: float, 
    jitter: float = 1.0e-6, 
    include_noise: bool = True
) -> ArrayLike:
    scaled_covariance_matrix = variance + (X[:, None] * X_prime)
    scaled_covariance_matrix = jnp.where(
        include_noise, 
        scaled_covariance_matrix + (noise + jitter) * jnp.eye(X.shape[0]), 
        scaled_covariance_matrix,
    )

    return scaled_covariance_matrix
kernel_function = partial(
    linear_kernel,
    variance=0.01,
    noise=0.001
)
generated_covariance, distance, sampled_functions = simulate_kernel(
    X=TIME_IDX,
    X_prime=TIME_IDX,
    kernel_function=kernel_function,
    samples=3,
)
fig = visualize_kernel(
    TIME_IDX,
    generated_covariance=generated_covariance,
    distance=distance,
    sampled_functions=sampled_functions,
    kernel_name="Linear"
)
plt.show()

3.1.5 White Noise Kernel
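Here the covariance function is simply a scaled identity,

\[ k(x, x') = \sigma^2 \, \delta_{x x'} \]

i.e. \(\sigma^2 I\): points are mutually uncorrelated and each has variance \(\sigma^2\).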

@jit
def white_noise_kernel(
    X: ArrayLike, 
    X_prime: ArrayLike, 
    variance: float
) -> ArrayLike:
    covariance_matrix = variance * jnp.eye(X.shape[0])
    return covariance_matrix
kernel_function = partial(
    white_noise_kernel,
    variance=0.01,
)
generated_covariance, distance, sampled_functions = simulate_kernel(
    X=TIME_IDX,
    X_prime=TIME_IDX,
    kernel_function=kernel_function,
    samples=3,
)
fig = visualize_kernel(
    TIME_IDX,
    generated_covariance=generated_covariance,
    distance=distance,
    sampled_functions=sampled_functions,
    kernel_name="White Noise"
)
plt.show()

3.1.6 Combining Kernels

@partial(jit, static_argnums=(2,3))
def additive_combined_kernel(
    X: ArrayLike, 
    X_prime: ArrayLike, 
    first_kernel: Callable, 
    second_kernel: Callable
) -> ArrayLike:
    return first_kernel(X=X, X_prime=X_prime) + second_kernel(X=X, X_prime=X_prime)

@partial(jit, static_argnums=(2,3))
def multiplicative_combined_kernel(
    X: ArrayLike, 
    X_prime: ArrayLike, 
    first_kernel: Callable, 
    second_kernel: Callable
) -> ArrayLike:
    return first_kernel(X=X, X_prime=X_prime) * second_kernel(X=X, X_prime=X_prime)
kernel_function_linear = partial(
    linear_kernel,
    variance=0.0001,
    noise=0.000
)
kernel_function_rbf = partial(
    RBF_kernel,
    variance=30,
    length=2,
    noise=10
)

kernel_function = partial(
    additive_combined_kernel,
    first_kernel=kernel_function_linear,
    second_kernel=kernel_function_rbf
)

generated_covariance, distance, sampled_functions = simulate_kernel(
    X=TIME_IDX,
    X_prime=TIME_IDX,
    kernel_function=kernel_function,
    samples=3,
)
fig = visualize_kernel(
    TIME_IDX,
    generated_covariance=generated_covariance,
    distance=distance,
    sampled_functions=sampled_functions,
    kernel_name="Additive Linear+RBF"
)
plt.show()

kernel_function_periodic = partial(
    periodic_kernel,
    variance=5,
    length=5,
    period=10,
    noise=0.0001
)
kernel_function_rbf = partial(
    RBF_kernel,
    variance=.1,
    length=50,
    noise=0.0001
)

kernel_function = partial(
    multiplicative_combined_kernel,
    first_kernel=kernel_function_periodic,
    second_kernel=kernel_function_rbf
)

generated_covariance, distance, sampled_functions = simulate_kernel(
    X=TIME_IDX,
    X_prime=TIME_IDX,
    kernel_function=kernel_function,
    samples=3,
)
fig = visualize_kernel(
    TIME_IDX,
    generated_covariance=generated_covariance,
    distance=distance,
    sampled_functions=sampled_functions,
    kernel_name="Multiplicative Periodic RBF"
)
plt.show()

4 Simulating the data

4.1 Generating the underlying process

def generate_underlying(
    time_index, 
    periodic_variance, 
    periodic_length, 
    period, 
    rbf_variance, 
    rbf_length
):
    kernel_function_seasonality = partial(
        periodic_kernel,
        variance=periodic_variance,
        length=periodic_length,
        period=period, 
        noise=0.0001
    )
    kernel_function_slow_variation = partial(
        RBF_kernel,
        variance=rbf_variance,
        length=rbf_length,
        noise=0.0001
    )

    kernel_function_underlying = partial(
        additive_combined_kernel,
        first_kernel=kernel_function_seasonality,
        second_kernel=kernel_function_slow_variation
    )

    generated_covariance, distance, sampled_function_underlying = simulate_kernel(
        X=time_index,
        X_prime=time_index,
        kernel_function=kernel_function_underlying,
        samples=1,
    )
    return generated_covariance, distance, sampled_function_underlying


generated_covariance, distance, sampled_based_function = generate_underlying(
    time_index=TIME_IDX, 
    periodic_variance=0.25, 
    periodic_length=7, 
    period=7, # every seven days we have weekly variations
    rbf_variance=0.1, 
    rbf_length=14,  # slow variation with 14-day decay
)

sampled_based_function = (
    (sampled_based_function - sampled_based_function.mean()) / sampled_based_function.std()
)
fig = visualize_kernel(
    TIME_IDX,
    generated_covariance=generated_covariance,
    distance=distance,
    sampled_functions=sampled_based_function,
    kernel_name="Simulated Underlying Process"
)
plt.show()

4.2 Adding the experimental effects

# We assume a constant effect expressed as a fraction of the standard deviation.
# This is somewhat unrealistic.

@partial(jit, static_argnums=2)
def exponential_decay_function(initial_value, rate_decay, max_t):
    decayed_function = jnp.repeat(initial_value, max_t)
    return decayed_function * jnp.power(1 - rate_decay, jnp.arange(max_t))
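As a quick sanity check (a NumPy mirror of the helper above, with hypothetical parameter values), a 5% daily decay roughly halves the initial effect within two weeks:

```python
import numpy as np

def exponential_decay(initial_value, rate_decay, max_t):
    # Mirrors the JAX helper above: initial_value * (1 - rate)^t for t = 0..max_t-1
    return initial_value * np.power(1.0 - rate_decay, np.arange(max_t))

decayed = exponential_decay(initial_value=1.0, rate_decay=0.05, max_t=28)
```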

models_parameters = {
    "Branch A": {
        "experimental_delta": 0, 
    }, 
    "Branch B": {
        "experimental_delta": sampled_based_function.std() * 1., 
    }, 
    "Branch C": {
        "experimental_delta": sampled_based_function.std() * 1.25, 
    }, 
    "Branch D": {
        "experimental_delta": sampled_based_function.std() * 1.5,
    }, 
    "Branch E": {
        "experimental_delta": sampled_based_function.std() * 1.75, 
    }, 
}

sampled_functions_underlying = []
sampled_functions_effect = []
for arm in ["Branch A", "Branch B", "Branch C", "Branch D", "Branch E"]:

    kernel_function = partial(
        white_noise_kernel,
        variance=.1,
    )
    exponential_decay_mean = exponential_decay_function(
        initial_value=models_parameters[arm]["experimental_delta"],
        rate_decay=0.05,
        max_t=sampled_based_function.shape[1],
    )
    
    _, _, sampled_experimental_effect_function = simulate_kernel(
        X=TIME_IDX,
        X_prime=TIME_IDX,
        kernel_function=kernel_function,
        mean=exponential_decay_mean,
        samples=1
    )
    
    sampled_functions_effect.append(
         sampled_experimental_effect_function.flatten()
    )
    sampled_functions_underlying.append(
        (
            sampled_based_function + sampled_experimental_effect_function
        ).flatten()
    )

sampled_functions_effect = np.array(sampled_functions_effect)
sampled_functions_underlying = np.array(sampled_functions_underlying)
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
for idx, series in enumerate(range(sampled_functions_underlying.shape[0])):

    ax = plot_univariate_series(
        series_data={
            "x": {
                "data": TIME_IDX,
                "label": "x"
            },
            "y": {
                "data":  sampled_functions_underlying[series, :],
                "label": "y"
            },

        },
        ax=ax,
        alpha=0.25 if idx != 0 else 1.
    )
plt.show()

fig, ax = plt.subplots(1, 1, figsize=(8, 4))
for idx, series in enumerate(range(sampled_functions_effect.shape[0])):

    ax = plot_univariate_series(
        series_data={
            "x": {
                "data": TIME_IDX,
                "label": "x"
            },
            "y": {
                "data":  sampled_functions_effect[series, :],
                "label": "y"
            },

        },
        ax=ax,
        alpha=0.25 if idx != 0 else 1.
    )
plt.axhline(0, linestyle=":", c="k")
plt.show()

4.3 Adding WeekDay Effects

weekday_effect = np.array(
    [
        0.,
        0.,
        0.,
        0.,
        0.15,
        -0.1,
        -0.05
    ]
)
weekday_effect = np.tile(weekday_effect, TIME_IDX.shape[0] // 7 + 1)
weekday_effect = weekday_effect[:TIME_IDX.shape[0]]

fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax = plot_univariate_series(
    series_data={
        "x": {
            "data": TIME_IDX,
            "label": "x"
        },
        "y": {
            "data":  np.ones(shape=(TIME_IDX.shape[0])) * weekday_effect,
            "label": "y"
        },

    },
    ax=ax,
    alpha=0.5
)
plt.show()

weekday_deltas = (sampled_functions_underlying) * weekday_effect
simulated_series = sampled_functions_underlying + weekday_deltas

fig, ax = plt.subplots(1, 1, figsize=(8, 4))
for idx, series in enumerate(range(simulated_series.shape[0])):

    ax = plot_univariate_series(
        series_data={
            "x": {
                "data": TIME_IDX,
                "label": "x"
            },
            "y": {
                "data":  simulated_series[series, :],
                "label": "y"
            },

        },
        ax=ax,
        alpha=0.25 if idx != 0 else 1.
    )
plt.show()

5 Fitting Gaussian Process Models to A/B test data

df = pd.DataFrame(
    simulated_series.T,
    # the names here indicate the 5 branches of an A/B test.
    columns=MODELS
)
df["date"] = DATES
df["day_week"] = df["date"].dt.day_name()
df.head()
Branch A Branch B Branch C Branch D Branch E date day_week
0 2.100637 2.364098 3.293391 3.427072 3.219653 2023-01-01 Sunday
1 1.974662 2.498444 2.764741 2.890823 3.090374 2023-01-02 Monday
2 0.698670 2.277956 2.125152 2.282583 2.842012 2023-01-03 Tuesday
3 0.641749 1.941783 1.901703 2.220936 2.211287 2023-01-04 Wednesday
4 1.051130 1.669099 2.232579 2.115660 2.571655 2023-01-05 Thursday
def _set_x_axis_grid(ax):
    ax.tick_params(
        direction="in",
        top=True,
        axis="x"
    )
    ax.grid(
        visible=True,
        which="major",
        axis="x",
        color="k",
        alpha=0.25,
        linestyle="--"
    )
    return ax

def plot_arms_time_series(time, arm_a, arm_b, df, ax, y_label, title):
    ax.scatter(
        df[time].values,
        df[arm_a].values,
        marker="x",
        c="blue",
        label=arm_a
    )
    ax.scatter(
        df[time].values,
        df[arm_b].values,
        marker="*",
        c="red",
        label=arm_b
    )
    ax = _set_x_axis_grid(ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Time")
    ax.set_ylabel(y_label)
    return ax
plt.figure(figsize=(15, 5))
sns.violinplot(
    data=df.melt(
        id_vars=("date"),
        value_vars=MODELS,
        var_name="Arm",
        value_name="Dependent Variable"
    ),
    x="Arm",
    y="Dependent Variable",
    hue="Arm",
)
plt.xlabel("AB Test Arms")
plt.grid(alpha=0.5)
plt.show()

plt.figure(figsize=(8, 4))
sns.violinplot(
    data=df.melt(
        id_vars=("date", "day_week"),
        value_vars=MODELS,
        var_name="Arm",
        value_name="Dependent Variable"
    ),
    x="day_week",
    y="Dependent Variable",
    order=[
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday"
    ],
    hue="day_week"
)
plt.xlabel("Day of the Week")
plt.grid(alpha=0.5)
plt.show()

fig = plt.figure(figsize=(8, 4))
ax = sns.violinplot(
    data=df.melt(
        id_vars=("date", "day_week"),
        value_vars=MODELS,
        var_name="Arm",
        value_name="Dependent Variable"
    ),
    x="day_week",
    y="Dependent Variable",
    hue="Arm",
    order=[
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday"
    ]
)
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
plt.xlabel("Day of the Week")
plt.grid(alpha=0.5)
plt.show()

fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12, 6), sharex=True, sharey=True)
arms = MODELS.copy()
arms.remove("Branch A")

for arm_b, ax in zip(arms, axs.flatten()):

    ax = plot_arms_time_series(
        time="date",
        arm_a="Branch A",
        arm_b=arm_b,
        df=df,
        ax=ax,
        y_label="Dependent Variable",
        title="AB Test"
    )
    ax.legend()

plt.tight_layout()
plt.show()

arms = MODELS.copy()
arms.remove("Branch A")

fig, axs = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(12, 6),
    sharex=True,
    sharey=True
)
for arm_b, ax in zip(arms, axs.flatten()):

    delta = df[arm_b] - df["Branch A"]
    time = df[TIME]

    semed = median_abs_deviation(delta) / np.sqrt(len(time))

    ax.scatter(
        time,
        delta,
        marker="o",
        facecolors='none',
        edgecolors='r',
        linestyle=":"
    )
    ax.set_title(arm_b)
    ax.plot(
        time,
        [np.median(delta)] * len(time),
        c='r',
    )
    ax.fill_between(
        time,
        [np.median(delta) - (1.96 * semed)] * len(time),
        [np.median(delta) + (1.96 * semed)] * len(time),
        color='r',
        alpha=0.1
    )
    ax.axhline(
        0,
        linestyle=":",
        c="k"
    )
    ax.set_xlabel("Date")
    ax.set_ylabel(r"$\Delta$" + "\nDependent Variable")

plt.tight_layout()
plt.show()

5.1 The Gaussian Process Model

def gaussian_process_model(
        trend_kernel_variance_dist,
        trend_kernel_noise_dist,
        trend_kernel_length_dist,

        seasonal_kernel_variance_dist,
        seasonal_kernel_noise_dist,
        seasonal_kernel_length_dist,
        seasonal_kernel_period_dist,

        effect_initial_dist,
        effect_decay_rate_dist,

        model_noise_dist,
        coefficients_days_dist,
    ):

    def build_trend_covariance(
            X, 
            trend_kernel_variance_dist, 
            trend_kernel_noise_dist, 
            trend_kernel_length_dist,
        ):
        trend_kernel_variance = numpyro.sample(
            "trend_kernel_variance",
            trend_kernel_variance_dist,
        )
        trend_kernel_noise = numpyro.sample(
            "trend_kernel_noise",
            trend_kernel_noise_dist,
        )
        trend_kernel_length = numpyro.sample(
            "trend_kernel_length",
            trend_kernel_length_dist,
        )
        covariance_matrix_trend = RBF_kernel(
            X=X,
            X_prime=X,
            variance=trend_kernel_variance,
            length=trend_kernel_length,
            noise=trend_kernel_noise,
        )
        return covariance_matrix_trend
    
    def build_seasonal_covariance(
            X, 
            seasonal_kernel_variance_dist, 
            seasonal_kernel_noise_dist,
            seasonal_kernel_length_dist,
            seasonal_kernel_period_dist,
        ):
        seasonal_kernel_variance = numpyro.sample(
            "seasonal_kernel_variance",
            seasonal_kernel_variance_dist,
        )
        seasonal_kernel_noise = numpyro.sample(
            "seasonal_kernel_noise",
            seasonal_kernel_noise_dist,
        )
        seasonal_kernel_length = numpyro.sample(
            "seasonal_kernel_length",
            seasonal_kernel_length_dist,
        )
        seasonal_kernel_period = numpyro.sample(
            "seasonal_kernel_length_periodic",
            seasonal_kernel_period_dist,
        )
        covariance_matrix_seasonal = periodic_kernel(
            X=X,
            X_prime=X,
            variance=seasonal_kernel_variance,
            length=seasonal_kernel_length,
            period=seasonal_kernel_period,
            noise=seasonal_kernel_noise,
        )
        return covariance_matrix_seasonal
    
    def build_effect_component(
            X, 
            effect_initial_dist, 
            effect_decay_rate_dist,
        ):
        effect_initial = numpyro.sample(
            "effect_initial",
            effect_initial_dist,
        )
        effect_decay_rate = numpyro.sample(
            "effect_decay_rate",
            effect_decay_rate_dist,
        )
        effect_component = numpyro.deterministic(
            "effect_component",
            exponential_decay_function(
                initial_value=effect_initial,
                rate_decay=effect_decay_rate,
                max_t=X.shape[0],
            )
        )
        return effect_component


    def get_model(X, y, days):
        # The parameters of the kernel function are sampled from their
        # associated prior distributions

        trend_covariance = build_trend_covariance(
            X=X, 
            trend_kernel_variance_dist=trend_kernel_variance_dist, 
            trend_kernel_noise_dist=trend_kernel_noise_dist, 
            trend_kernel_length_dist=trend_kernel_length_dist,
        )
        seasonal_covariance = build_seasonal_covariance(
            X=X, 
            seasonal_kernel_variance_dist=seasonal_kernel_variance_dist, 
            seasonal_kernel_noise_dist=seasonal_kernel_noise_dist,
            seasonal_kernel_length_dist=seasonal_kernel_length_dist,
            seasonal_kernel_period_dist=seasonal_kernel_period_dist,
        )
        effect_component = build_effect_component(
            X=X, 
            effect_initial_dist=effect_initial_dist, 
            effect_decay_rate_dist=effect_decay_rate_dist,
        )
        intercept =  numpyro.sample(
            "intercept",
            dist.MultivariateNormal(
                loc=jnp.zeros(X.shape[0]),
                covariance_matrix=trend_covariance +  seasonal_covariance,
            )
        )
        coefficients_days = numpyro.sample(
            "coefficients_days",
            coefficients_days_dist,
        )
        days_effect = numpyro.deterministic(
            "days_effect", 
            jnp.dot(days, coefficients_days),
        )
        model_noise = numpyro.sample(
            "model_noise",
            model_noise_dist,
        )

        numpyro.sample(
            "y",
            dist.Normal(
                intercept + days_effect + effect_component,
                model_noise
            ),
            obs=y,
        )

    return build_trend_covariance, build_seasonal_covariance, get_model
def standardize_targets_to_common_statistics(reference_X, targets_X):
    standardizing_mean, standardizing_std = _derive_mean_std(
        X=reference_X
    )
    standardized_reference_X = (reference_X - standardizing_mean) / standardizing_std
    standardized_targets_X = {}
    for target_name, target_X in targets_X.items():

        standardized_targets_X[target_name] = (target_X - standardizing_mean) / standardizing_std

    return standardizing_mean, standardizing_std, standardized_reference_X, standardized_targets_X

def inverse_standardize(X, mean, std):
    return X * std + mean

def _derive_mean_std(X):
    return np.mean(X), np.std(X)
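As a quick sanity check, the standardization helpers above should round-trip: standardizing a series with the reference arm's statistics and then applying `inverse_standardize` recovers the original values. A minimal sketch with synthetic arrays (the values are hypothetical, not from the post's dataset):

```python
import numpy as np

# Synthetic "reference" (control) and "target" (treatment) series.
reference = np.array([10.0, 12.0, 9.0, 11.0])
target = np.array([11.0, 13.0, 10.0, 12.0])

mean, std = np.mean(reference), np.std(reference)

# Standardize both series with the *reference* statistics,
# mirroring standardize_targets_to_common_statistics.
reference_std = (reference - mean) / std
target_std = (target - mean) / std

# inverse_standardize undoes the transformation exactly.
recovered = reference_std * std + mean
assert np.allclose(recovered, reference)

# The reference series is now zero-mean and unit-variance; the target
# lives on the same scale, so arm comparisons remain meaningful.
print(np.round(np.mean(reference_std), 6), np.round(np.std(reference_std), 6))
```

Sharing the control arm's mean and standard deviation across all arms is what makes the later posterior differences directly comparable after `inverse_standardize`.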
from sklearn.preprocessing import OneHotEncoder

y_control = df["Branch A"].values.copy()
y_models = {
    model_name: df[model_name].values.copy() for model_name in arms
}

(
    standardizing_mean,
    standardizing_std,
    y_control_standardized,
    y_models_standardized
) = standardize_targets_to_common_statistics(
    reference_X=y_control,
    targets_X=y_models
)

# Normalized time axis: the N days of the A/B test, rescaled to [0, 1]
X = np.linspace(0, 1, len(y_control))

days = OneHotEncoder(sparse_output=False).fit_transform(df["day_week"].values.reshape(-1, 1))
Show supplementary code
def plot_prior_marginal(ax, prior_frozen, prior_name, samples=1000):
    support = np.sort(prior_frozen.rvs(samples))
    density = prior_frozen.pdf(support)
    ax.plot(
        support,
        density,
        linewidth=2.5,
        c=BLUE
    )
    ax.set_title(f"Prior \n {prior_name}")
    ax.set_xlabel("Support")
    ax.set_ylabel("Density")
    return ax

def plot_sampled_functions(ax, X, X_prime, kernel, priors, n_samples=10):
    alphas = np.linspace(0, 1, n_samples * 2)
    for sample in range(n_samples):

        priors_sample = {}

        for prior_name, prior_frozen_dist in priors.items():

            priors_sample[prior_name] = prior_frozen_dist.rvs(1)

        k = kernel(X, X_prime, **priors_sample)
        sampled_f = np.random.multivariate_normal(np.zeros(len(X)), k)
        ax.plot(
            X,
            sampled_f,
            alpha=alphas[sample],
            c=RED
        )

    ax.set_title("Prior over $f$")
    ax.set_xlabel("$X$")
    ax.set_ylabel("$f(X)$")
    return ax
def sample_posterior(model, sampler_kernel_algo, rng_key, X, y, days,
                     **mcmc_kwargs):
    # Instantiate the MCMC transition kernel (e.g. NUTS) for the model
    sampler_kernel = sampler_kernel_algo(model)
    mcmc = MCMC(
        sampler_kernel,
        **mcmc_kwargs
    )
    mcmc.run(
        rng_key=rng_key,
        X=X,
        y=y,
        days=days,
    )
    mcmc.print_summary()
    posterior_samples = mcmc.get_samples()
    return posterior_samples


def compute_posterior_predictive(rng_key, X, y, X_test, variance,
                                 length, noise):
    # The suffix "xp" marks covariances computed with the test inputs
    # X_test ("x prime"); "x" alone refers to the training inputs.
    covariance_xp_xp = RBF_kernel(
        X=X_test,
        X_prime=X_test,
        variance=variance,
        length=length,
        noise=noise,
        include_noise=True
    )
    covariance_xp_x = RBF_kernel(
        X=X_test,
        X_prime=X,
        variance=variance,
        length=length,
        noise=noise,
        include_noise=False
    )
    covariance_x_x = RBF_kernel(
        X=X,
        X_prime=X,
        variance=variance,
        length=length,
        noise=noise,
        include_noise=True
    )

    covariance_x_x_inverse = jnp.linalg.inv(covariance_x_x)
    full_covariance = covariance_xp_xp - jnp.matmul(
        covariance_xp_x,
        jnp.matmul(
            covariance_x_x_inverse,
            jnp.transpose(covariance_xp_x)
        )
    )

    predictive_noise = jnp.sqrt(
        # Clip tiny negative diagonal entries caused by floating-point error
        jnp.clip(jnp.diag(full_covariance), 0.0)
    ) * jax.random.normal(rng_key, X_test.shape[:1])

    predictive_mean = jnp.matmul(
        covariance_xp_x,
        jnp.matmul(covariance_x_x_inverse, y)
    )

    return predictive_mean, predictive_mean + predictive_noise
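
For reference, `compute_posterior_predictive` implements the standard Gaussian process conditional: with training inputs $X$, test inputs $X_*$, observations $y$, and kernel matrices $K(\cdot, \cdot)$, the posterior over $f(X_*)$ is Gaussian with

$$
\begin{aligned}
\mu_* &= K(X_*, X)\, K(X, X)^{-1} y \\
\Sigma_* &= K(X_*, X_*) - K(X_*, X)\, K(X, X)^{-1} K(X, X_*)
\end{aligned}
$$

Note that the code forms the explicit inverse with `jnp.linalg.inv`; for larger or ill-conditioned kernel matrices, a Cholesky-based solve (e.g. `jax.scipy.linalg.cho_solve`) is the numerically safer way to evaluate the same expressions.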

def sample_posterior_predictive(rng_key, X, X_test, y, samples,
                                compute_posterior_predictive):
    vmap_args = (
        random.split(rng_key, samples["trend_kernel_variance"].shape[0]),
        samples["trend_kernel_variance"],
        samples["trend_kernel_length"],
        samples["trend_kernel_noise"],
    )
    predictive_mean, predictive_distribution = vmap(
        lambda rng_key, variance, length, noise: compute_posterior_predictive(
            rng_key, X, y, X_test, variance, length, noise
        )
    )(*vmap_args)
    return predictive_mean, predictive_distribution
rng_key, rng_key_predictive = random.split(random.PRNGKey(0))

posteriors_models = {}
build_trend_component, build_seasonal_component, get_model = gaussian_process_model(
    trend_kernel_variance_dist=dist.LogNormal(0., 1.),
    trend_kernel_noise_dist=dist.LogNormal(0., 1.),
    trend_kernel_length_dist=dist.LogNormal(0., 1.),

    seasonal_kernel_variance_dist=dist.LogNormal(0., 1.),
    seasonal_kernel_noise_dist=dist.LogNormal(0., 1.),
    seasonal_kernel_length_dist=dist.LogNormal(0., .1),
    seasonal_kernel_period_dist=dist.LogNormal(0., 1.),

    effect_initial_dist=dist.Normal(0., 1.),
    effect_decay_rate_dist=dist.Beta(1., 4.),

    coefficients_days_dist=dist.Normal(0., jnp.ones(days.shape[1]) * 1.),

    model_noise_dist=dist.LogNormal(0., .05),
)

posterior_samples = sample_posterior(
    model=get_model,
    sampler_kernel_algo=NUTS,
    rng_key=rng_key,
    X=X,
    y=y_control_standardized,
    days=days,
    num_warmup=2000,
    num_samples=2000,
    num_chains=4,
    chain_method='parallel',
)

for model_name in arms:

    print(f"Fitting for model {model_name}")

    posterior_samples_model = sample_posterior(
        model=get_model,
        sampler_kernel_algo=NUTS,
        rng_key=rng_key,
        X=X,
        days=days,
        y=y_models_standardized[model_name],
        num_warmup=2000,
        num_samples=2000,
        num_chains=NUM_CHAINS
    )
    posteriors_models[model_name] = posterior_samples_model

                                       mean       std    median      5.0%     95.0%     n_eff     r_hat
             coefficients_days[0]      0.02      0.62      0.01     -0.98      1.01    600.27      1.00
             coefficients_days[1]      0.12      0.60      0.12     -0.83      1.12   2311.81      1.00
             coefficients_days[2]      0.18      0.62      0.18     -0.90      1.15    842.41      1.00
             coefficients_days[3]      0.26      0.60      0.26     -0.73      1.23   2126.69      1.00
             coefficients_days[4]      0.02      0.60      0.03     -1.01      0.97    875.29      1.00
             coefficients_days[5]     -0.15      0.60     -0.14     -1.09      0.90   1477.32      1.00
             coefficients_days[6]     -0.41      0.60     -0.41     -1.36      0.59   1522.91      1.00
                effect_decay_rate      0.20      0.16      0.15      0.00      0.43   3346.10      1.00
                   effect_initial      0.49      0.88      0.49     -0.81      2.02   1033.70      1.00
                     intercept[0]      0.87      1.01      0.86     -0.69      2.65    818.33      1.00
                     intercept[1]      0.90      0.97      0.89     -0.82      2.33    857.05      1.00
                     intercept[2]      0.61      0.92      0.63     -0.87      2.12    969.14      1.00
                     intercept[3]      0.68      0.89      0.69     -0.76      2.18    620.09      1.00
                     intercept[4]      0.64      0.87      0.64     -0.84      2.06    582.23      1.00
                     intercept[5]      0.65      0.88      0.68     -0.70      2.10    416.12      1.01
                     intercept[6]      0.66      0.85      0.67     -0.73      2.11    594.94      1.00
                     intercept[7]      0.57      0.83      0.59     -0.80      1.87    755.22      1.00
                     intercept[8]      0.56      0.81      0.58     -0.93      1.71    785.59      1.00
                     intercept[9]      0.46      0.81      0.45     -0.80      1.85   1159.61      1.00
                    intercept[10]      0.29      0.80      0.31     -1.02      1.66    653.89      1.00
                    intercept[11]      0.37      0.82      0.38     -0.93      1.73    424.54      1.01
                    intercept[12]      0.18      0.81      0.19     -1.15      1.48    451.41      1.00
                    intercept[13]     -0.10      0.78     -0.10     -1.37      1.17   1079.34      1.00
                    intercept[14]     -0.15      0.78     -0.14     -1.52      1.04    838.16      1.00
                    intercept[15]     -0.41      0.77     -0.40     -1.68      0.81   1155.83      1.00
                    intercept[16]     -0.47      0.76     -0.45     -1.66      0.83   1493.72      1.00
                    intercept[17]     -0.73      0.75     -0.73     -1.96      0.49   1347.54      1.00
                    intercept[18]     -0.80      0.77     -0.77     -2.06      0.47    630.61      1.00
                    intercept[19]     -1.02      0.76     -1.02     -2.19      0.30    887.48      1.00
                    intercept[20]     -0.74      0.76     -0.73     -1.99      0.47    758.70      1.00
                    intercept[21]     -0.99      0.75     -0.98     -2.20      0.21   1151.92      1.00
                    intercept[22]     -1.07      0.76     -1.06     -2.26      0.21   1142.17      1.00
                    intercept[23]     -0.99      0.73     -0.97     -2.15      0.25   1540.78      1.00
                    intercept[24]     -1.05      0.78     -1.03     -2.31      0.21    500.25      1.00
                    intercept[25]     -1.08      0.77     -1.06     -2.35      0.18    823.04      1.00
                    intercept[26]     -0.89      0.79     -0.88     -2.38      0.25    806.12      1.00
                    intercept[27]     -0.96      0.80     -0.96     -2.20      0.39   1516.32      1.00
                      model_noise      0.97      0.05      0.97      0.89      1.05   3133.61      1.00
           seasonal_kernel_length      1.01      0.10      1.01      0.85      1.17   4394.14      1.00
  seasonal_kernel_length_periodic      2.54      2.37      1.99      0.10      4.75   3978.98      1.00
            seasonal_kernel_noise      0.26      0.18      0.21      0.03      0.50   2939.42      1.00
         seasonal_kernel_variance      1.07      1.13      0.73      0.06      2.23   5245.47      1.00
              trend_kernel_length      1.46      2.04      0.86      0.09      3.09   4435.30      1.00
               trend_kernel_noise      0.26      0.18      0.21      0.04      0.50   3413.61      1.00
            trend_kernel_variance      1.36      1.49      0.92      0.04      2.88   3969.77      1.00

Number of divergences: 106
Fitting for model Branch B

                                       mean       std    median      5.0%     95.0%     n_eff     r_hat
             coefficients_days[0]     -0.03      0.60     -0.03     -0.98      0.99   2401.86      1.00
             coefficients_days[1]      0.18      0.60      0.18     -0.80      1.18   2216.15      1.00
             coefficients_days[2]      0.19      0.61      0.20     -0.79      1.19   1931.56      1.00
             coefficients_days[3]      0.25      0.60      0.25     -0.75      1.20   2296.56      1.00
             coefficients_days[4]     -0.01      0.60     -0.00     -1.06      0.91   2045.11      1.00
             coefficients_days[5]     -0.05      0.60     -0.04     -1.02      0.93   2270.06      1.00
             coefficients_days[6]     -0.33      0.59     -0.34     -1.34      0.61   2286.42      1.00
                effect_decay_rate      0.19      0.15      0.14      0.00      0.41   2764.25      1.00
                   effect_initial      0.53      0.90      0.55     -0.91      2.05   1300.25      1.00
                     intercept[0]      1.30      1.02      1.30     -0.35      2.97   1149.93      1.00
                     intercept[1]      1.40      0.99      1.37     -0.15      3.09   1107.94      1.00
                     intercept[2]      1.44      0.95      1.43     -0.07      3.06   1164.34      1.00
                     intercept[3]      1.42      0.91      1.41     -0.07      2.93   1095.59      1.00
                     intercept[4]      1.20      0.89      1.20     -0.21      2.71   1067.58      1.00
                     intercept[5]      1.02      0.88      1.03     -0.43      2.46   1179.55      1.00
                     intercept[6]      1.13      0.87      1.12     -0.33      2.52   1110.25      1.00
                     intercept[7]      1.28      0.84      1.29     -0.04      2.68   1146.04      1.00
                     intercept[8]      0.98      0.84      0.99     -0.44      2.28   1169.86      1.00
                     intercept[9]      0.81      0.83      0.82     -0.50      2.20   1275.46      1.00
                    intercept[10]      0.67      0.81      0.69     -0.65      2.02   1085.99      1.00
                    intercept[11]      0.64      0.81      0.65     -0.64      1.99   1138.20      1.00
                    intercept[12]      0.51      0.79      0.52     -0.78      1.78   1233.90      1.00
                    intercept[13]      0.31      0.80      0.31     -1.02      1.63   1300.77      1.00
                    intercept[14]      0.19      0.78      0.20     -1.07      1.46   1287.46      1.00
                    intercept[15]      0.05      0.77      0.05     -1.28      1.23   1342.79      1.00
                    intercept[16]     -0.18      0.77     -0.17     -1.45      1.08   1429.08      1.00
                    intercept[17]     -0.34      0.77     -0.33     -1.58      0.94   1401.26      1.00
                    intercept[18]     -0.49      0.75     -0.48     -1.67      0.79   1358.44      1.00
                    intercept[19]     -0.48      0.76     -0.48     -1.71      0.76   1449.94      1.00
                    intercept[20]     -0.57      0.77     -0.57     -1.82      0.70   1318.36      1.00
                    intercept[21]     -0.85      0.76     -0.83     -2.08      0.36   1477.93      1.00
                    intercept[22]     -0.86      0.77     -0.85     -2.14      0.39   1466.38      1.00
                    intercept[23]     -0.87      0.75     -0.86     -2.07      0.40   1555.59      1.00
                    intercept[24]     -1.04      0.77     -1.03     -2.33      0.17   1400.06      1.00
                    intercept[25]     -0.80      0.77     -0.80     -2.08      0.45   1558.39      1.00
                    intercept[26]     -0.77      0.79     -0.77     -2.04      0.53   1614.80      1.00
                    intercept[27]     -0.74      0.81     -0.74     -2.02      0.64   1702.85      1.00
                      model_noise      0.97      0.05      0.97      0.89      1.05   5852.26      1.00
           seasonal_kernel_length      1.01      0.10      1.00      0.85      1.18   7488.12      1.00
  seasonal_kernel_length_periodic      2.57      2.12      2.07      0.12      4.75   3853.72      1.00
            seasonal_kernel_noise      0.26      0.18      0.21      0.03      0.49   3222.86      1.00
         seasonal_kernel_variance      1.20      1.24      0.84      0.03      2.51   5037.33      1.00
              trend_kernel_length      1.41      2.03      0.82      0.08      2.93   4307.25      1.00
               trend_kernel_noise      0.26      0.18      0.21      0.03      0.49   3465.29      1.00
            trend_kernel_variance      1.48      1.75      0.98      0.04      3.15   4715.38      1.00

Number of divergences: 94
Fitting for model Branch C

                                       mean       std    median      5.0%     95.0%     n_eff     r_hat
             coefficients_days[0]      0.14      0.60      0.14     -0.84      1.11   2460.65      1.00
             coefficients_days[1]      0.08      0.61      0.08     -0.91      1.11   1941.57      1.00
             coefficients_days[2]      0.30      0.60      0.31     -0.69      1.27   2254.09      1.00
             coefficients_days[3]      0.33      0.60      0.33     -0.67      1.31   2650.03      1.00
             coefficients_days[4]     -0.16      0.59     -0.16     -1.18      0.75   2358.62      1.00
             coefficients_days[5]     -0.21      0.60     -0.20     -1.19      0.79   2283.02      1.00
             coefficients_days[6]     -0.26      0.60     -0.27     -1.19      0.78   2449.23      1.00
                effect_decay_rate      0.19      0.15      0.15      0.00      0.41   3032.63      1.00
                   effect_initial      0.68      0.89      0.68     -0.77      2.12   1526.09      1.00
                     intercept[0]      1.66      1.04      1.66      0.06      3.47   1286.46      1.00
                     intercept[1]      1.60      0.99      1.59     -0.04      3.16   1189.00      1.00
                     intercept[2]      1.54      0.95      1.54      0.05      3.17   1167.40      1.00
                     intercept[3]      1.48      0.93      1.48      0.02      3.02   1115.09      1.00
                     intercept[4]      1.54      0.91      1.54      0.06      3.06   1104.23      1.00
                     intercept[5]      1.29      0.88      1.32     -0.16      2.74   1160.62      1.00
                     intercept[6]      1.32      0.87      1.31     -0.03      2.80   1119.05      1.00
                     intercept[7]      1.23      0.85      1.24     -0.13      2.66   1143.87      1.00
                     intercept[8]      0.93      0.84      0.95     -0.38      2.36   1199.28      1.00
                     intercept[9]      0.98      0.83      0.99     -0.41      2.30   1289.15      1.00
                    intercept[10]      0.80      0.81      0.82     -0.55      2.11   1200.20      1.00
                    intercept[11]      0.64      0.81      0.64     -0.64      2.01   1210.30      1.00
                    intercept[12]      0.70      0.81      0.71     -0.64      1.98   1267.18      1.00
                    intercept[13]      0.46      0.80      0.47     -0.83      1.82   1306.94      1.00
                    intercept[14]      0.41      0.79      0.42     -0.81      1.79   1366.38      1.00
                    intercept[15]      0.25      0.78      0.27     -1.07      1.48   1428.01      1.00
                    intercept[16]     -0.05      0.78     -0.05     -1.28      1.26   1492.88      1.00
                    intercept[17]     -0.31      0.77     -0.31     -1.58      0.94   1459.19      1.00
                    intercept[18]     -0.28      0.76     -0.28     -1.52      0.97   1462.45      1.00
                    intercept[19]     -0.32      0.78     -0.33     -1.58      0.97   1313.20      1.00
                    intercept[20]     -0.37      0.76     -0.36     -1.61      0.89   1359.80      1.00
                    intercept[21]     -0.72      0.76     -0.70     -1.98      0.49   1487.05      1.00
                    intercept[22]     -0.66      0.77     -0.66     -1.93      0.60   1432.00      1.00
                    intercept[23]     -0.85      0.76     -0.83     -2.05      0.45   1647.37      1.00
                    intercept[24]     -0.67      0.75     -0.67     -1.86      0.60   1474.23      1.00
                    intercept[25]     -0.88      0.78     -0.88     -2.11      0.43   1734.27      1.00
                    intercept[26]     -0.85      0.79     -0.83     -2.15      0.41   1757.67      1.00
                    intercept[27]     -0.81      0.80     -0.81     -2.12      0.49   1919.05      1.00
                      model_noise      0.97      0.05      0.97      0.89      1.05   6601.94      1.00
           seasonal_kernel_length      1.01      0.10      1.01      0.86      1.17   7467.93      1.00
  seasonal_kernel_length_periodic      2.72      2.45      2.18      0.12      5.12   4147.26      1.00
            seasonal_kernel_noise      0.25      0.18      0.21      0.03      0.48   3321.23      1.00
         seasonal_kernel_variance      1.25      1.34      0.85      0.04      2.69   4941.00      1.00
              trend_kernel_length      1.35      1.89      0.80      0.11      2.81   4949.16      1.00
               trend_kernel_noise      0.26      0.18      0.21      0.03      0.49   3509.86      1.00
            trend_kernel_variance      1.59      1.86      1.03      0.06      3.46   4166.34      1.00

Number of divergences: 114
Fitting for model Branch D

                                       mean       std    median      5.0%     95.0%     n_eff     r_hat
             coefficients_days[0]      0.26      0.59      0.27     -0.79      1.16   2239.16      1.00
             coefficients_days[1]      0.08      0.60      0.08     -0.93      1.05   2312.88      1.00
             coefficients_days[2]      0.33      0.60      0.34     -0.68      1.29   2318.50      1.00
             coefficients_days[3]      0.41      0.61      0.41     -0.61      1.37   1983.59      1.00
             coefficients_days[4]     -0.06      0.59     -0.05     -1.03      0.92   2187.46      1.00
             coefficients_days[5]     -0.24      0.60     -0.24     -1.22      0.73   2365.62      1.00
             coefficients_days[6]     -0.44      0.59     -0.45     -1.37      0.58   2410.09      1.00
                effect_decay_rate      0.19      0.15      0.14      0.00      0.40   2734.62      1.00
                   effect_initial      0.74      0.90      0.73     -0.79      2.20   1294.58      1.00
                     intercept[0]      1.68      1.03      1.65      0.05      3.37   1119.85      1.00
                     intercept[1]      1.67      0.99      1.64      0.08      3.32   1115.84      1.00
                     intercept[2]      1.61      0.97      1.61      0.15      3.28   1063.15      1.00
                     intercept[3]      1.64      0.92      1.64      0.08      3.10   1026.81      1.00
                     intercept[4]      1.46      0.90      1.45     -0.07      2.86   1052.70      1.00
                     intercept[5]      1.39      0.87      1.40     -0.06      2.79   1099.93      1.00
                     intercept[6]      1.21      0.87      1.20     -0.17      2.64   1060.65      1.00
                     intercept[7]      1.05      0.86      1.07     -0.33      2.43   1024.30      1.00
                     intercept[8]      1.00      0.84      1.00     -0.31      2.42   1157.97      1.00
                     intercept[9]      0.93      0.82      0.93     -0.41      2.29   1201.12      1.00
                    intercept[10]      0.74      0.80      0.76     -0.59      2.01   1147.61      1.00
                    intercept[11]      0.76      0.80      0.75     -0.50      2.10   1202.42      1.00
                    intercept[12]      0.63      0.79      0.62     -0.67      1.90   1155.70      1.00
                    intercept[13]      0.54      0.80      0.54     -0.76      1.86   1228.85      1.00
                    intercept[14]      0.33      0.79      0.34     -0.94      1.64   1205.89      1.00
                    intercept[15]      0.11      0.78      0.12     -1.18      1.38   1243.86      1.00
                    intercept[16]     -0.08      0.77     -0.07     -1.27      1.25   1343.16      1.00
                    intercept[17]     -0.32      0.77     -0.31     -1.56      0.97   1314.20      1.00
                    intercept[18]     -0.40      0.77     -0.39     -1.69      0.84   1340.50      1.00
                    intercept[19]     -0.45      0.76     -0.45     -1.65      0.84   1458.84      1.00
                    intercept[20]     -0.35      0.77     -0.35     -1.65      0.86   1306.94      1.00
                    intercept[21]     -0.43      0.75     -0.43     -1.66      0.79   1500.43      1.00
                    intercept[22]     -0.59      0.76     -0.59     -1.75      0.75   1595.46      1.00
                    intercept[23]     -0.80      0.75     -0.79     -2.06      0.38   1409.25      1.00
                    intercept[24]     -0.80      0.74     -0.80     -2.01      0.38   1555.13      1.00
                    intercept[25]     -0.73      0.77     -0.72     -2.07      0.46   1398.41      1.00
                    intercept[26]     -0.59      0.78     -0.58     -1.84      0.72   1674.77      1.00
                    intercept[27]     -0.71      0.80     -0.73     -2.04      0.57   1815.45      1.00
                      model_noise      0.97      0.05      0.97      0.89      1.05   6332.76      1.00
           seasonal_kernel_length      1.01      0.10      1.01      0.86      1.19   4351.16      1.00
  seasonal_kernel_length_periodic      2.67      2.28      2.15      0.14      5.02   3540.77      1.00
            seasonal_kernel_noise      0.25      0.18      0.21      0.03      0.48   3304.63      1.00
         seasonal_kernel_variance      1.27      1.41      0.86      0.04      2.70   4757.60      1.00
              trend_kernel_length      1.39      2.00      0.84      0.09      2.88   4508.90      1.00
               trend_kernel_noise      0.25      0.18      0.21      0.04      0.48   3258.94      1.00
            trend_kernel_variance      1.65      2.03      1.03      0.03      3.55    840.58      1.01

Number of divergences: 135
Fitting for model Branch E

                                       mean       std    median      5.0%     95.0%     n_eff     r_hat
             coefficients_days[0]      0.10      0.60      0.10     -0.89      1.07   2276.13      1.00
             coefficients_days[1]      0.22      0.61      0.22     -0.80      1.19   2620.93      1.00
             coefficients_days[2]      0.11      0.61      0.11     -0.87      1.14   2169.29      1.00
             coefficients_days[3]      0.30      0.60      0.31     -0.68      1.30   2471.07      1.00
             coefficients_days[4]     -0.04      0.59     -0.03     -1.03      0.90   2463.33      1.00
             coefficients_days[5]      0.04      0.59      0.05     -0.92      1.01   2275.83      1.00
             coefficients_days[6]     -0.28      0.59     -0.29     -1.27      0.67   2212.04      1.00
                effect_decay_rate      0.18      0.15      0.14      0.00      0.41   3199.26      1.00
                   effect_initial      0.65      0.90      0.65     -0.81      2.14   1489.64      1.00
                     intercept[0]      1.79      1.02      1.78      0.15      3.47   1221.63      1.00
                     intercept[1]      1.81      0.99      1.80      0.24      3.46   1161.49      1.00
                     intercept[2]      1.81      0.96      1.80      0.21      3.35   1066.32      1.00
                     intercept[3]      1.72      0.91      1.71      0.15      3.13   1030.42      1.00
                     intercept[4]      1.72      0.89      1.72      0.30      3.24   1100.25      1.00
                     intercept[5]      1.61      0.88      1.62      0.18      3.08   1038.94      1.00
                     intercept[6]      1.32      0.86      1.32     -0.01      2.83   1055.96      1.00
                     intercept[7]      1.50      0.84      1.51      0.15      2.90   1062.59      1.00
                     intercept[8]      1.27      0.83      1.28     -0.13      2.57   1144.41      1.00
                     intercept[9]      1.07      0.83      1.08     -0.26      2.45   1130.59      1.00
                    intercept[10]      0.99      0.82      0.99     -0.26      2.42   1083.41      1.00
                    intercept[11]      0.84      0.80      0.84     -0.50      2.13   1137.21      1.00
                    intercept[12]      0.66      0.80      0.66     -0.55      2.03   1156.39      1.00
                    intercept[13]      0.61      0.79      0.61     -0.70      1.87   1235.88      1.00
                    intercept[14]      0.36      0.79      0.36     -0.95      1.65   1351.45      1.00
                    intercept[15]      0.33      0.78      0.34     -0.96      1.59   1261.37      1.00
                    intercept[16]     -0.01      0.77      0.00     -1.27      1.24   1382.18      1.00
                    intercept[17]     -0.14      0.77     -0.13     -1.45      1.08   1359.10      1.00
                    intercept[18]     -0.26      0.76     -0.25     -1.50      0.99   1386.48      1.00
                    intercept[19]     -0.20      0.75     -0.20     -1.34      1.12   1410.90      1.00
                    intercept[20]     -0.29      0.77     -0.29     -1.55      0.95   1311.00      1.00
                    intercept[21]     -0.55      0.76     -0.54     -1.75      0.70   1519.70      1.00
                    intercept[22]     -0.68      0.75     -0.67     -1.95      0.52   1515.70      1.00
                    intercept[23]     -0.59      0.76     -0.58     -1.86      0.61   1552.69      1.00
                    intercept[24]     -0.75      0.75     -0.74     -2.07      0.40   1471.96      1.00
                    intercept[25]     -0.73      0.77     -0.72     -2.00      0.48   1554.32      1.00
                    intercept[26]     -0.77      0.77     -0.76     -2.00      0.52   1500.53      1.00
                    intercept[27]     -0.63      0.80     -0.63     -1.95      0.68   1716.52      1.00
                      model_noise      0.97      0.05      0.97      0.89      1.05   5649.22      1.00
           seasonal_kernel_length      1.01      0.10      1.01      0.85      1.18   6571.70      1.00
  seasonal_kernel_length_periodic      2.65      2.11      2.18      0.12      4.89   4428.06      1.00
            seasonal_kernel_noise      0.25      0.18      0.21      0.03      0.48   3557.00      1.00
         seasonal_kernel_variance      1.33      1.44      0.90      0.04      2.85   4853.71      1.00
              trend_kernel_length      1.38      1.94      0.80      0.14      2.88   4402.17      1.00
               trend_kernel_noise      0.25      0.17      0.21      0.03      0.47   3548.47      1.00
            trend_kernel_variance      1.61      1.91      1.07      0.06      3.49   5695.54      1.00

Number of divergences: 106
samples_mean_control, _ = sample_posterior_predictive(
    rng_key_predictive,
    X,
    X,
    y_control_standardized,
    posterior_samples,
    compute_posterior_predictive
)
samples_mean_control = inverse_standardize(
    X=samples_mean_control,
    mean=standardizing_mean,
    std=standardizing_std
)

samples_mean_models = {}
samples_mean_differences = {}
for model_name in arms:

    samples_mean_model, _ = sample_posterior_predictive(
        rng_key_predictive,
        X,
        X,
        y_models_standardized[model_name],
        posteriors_models[model_name],
        compute_posterior_predictive
    )
    samples_mean_model = inverse_standardize(
        X=samples_mean_model,
        mean=standardizing_mean,
        std=standardizing_std
    )

    samples_mean_models[model_name] = samples_mean_model
    samples_mean_differences[model_name] = samples_mean_model - samples_mean_control

6 Visualize the results

Show supplementary code
def compute_mean_percentiles(time_series_samples):
    expected_time_series = np.mean(time_series_samples, axis=0)
    percentiles_time_series = np.percentile(
        time_series_samples,
        [2.5, 97.5],
        axis=0
    )
    return expected_time_series, percentiles_time_series

def plot_estimated_time_series(ax, color, time, expected_values,
                               percentiles_values, **plot_kwargs):
    ax.plot(
        time,
        expected_values,
        color=color,
        **plot_kwargs
    )
    ax.fill_between(
        time,
        percentiles_values[0, :],
        percentiles_values[1, :],
        color=color,
        alpha=0.1
    )
    return ax
expected_mean_control, percentiles_mean_control = compute_mean_percentiles(
    time_series_samples=samples_mean_control
)
expected_residuals_control, percentiles_residuals_control = compute_mean_percentiles(
    time_series_samples=samples_mean_control - y_control
)

means_models = {}
residuals_models = {}
deltas_models = {}
statistics_differences_models = {}
for model_name in arms:

    expected_mean_model, percentiles_mean_model = compute_mean_percentiles(
        time_series_samples=samples_mean_models[model_name]
    )
    expected_residuals_model, percentiles_residuals_model = compute_mean_percentiles(
        time_series_samples=samples_mean_models[model_name] - y_models[model_name]
    )
    expected_mean_difference, percentiles_mean_difference = compute_mean_percentiles(
        time_series_samples=samples_mean_differences[model_name]
    )

    samples_expected_difference = np.mean(samples_mean_differences[model_name], axis=1)
    percentiles_difference = np.percentile(samples_expected_difference, [2.5, 97.5], axis=0)

    means_models[model_name] = {
        "mean": expected_mean_model,
        "percentiles": percentiles_mean_model
    }
    residuals_models[model_name] = {
        "mean": expected_residuals_model,
        "percentiles": percentiles_residuals_model
    }
    deltas_models[model_name] = {
        "mean": expected_mean_difference,
        "percentiles": percentiles_mean_difference
    }
    statistics_differences_models[model_name] = {
        "expected_difference": samples_expected_difference,
        "percentiles_difference": percentiles_difference
    }
fig, axs = plt.subplots(
    2,
    2,
    tight_layout=True,
    figsize=(12, 8),
    sharey=True
)

for model_name, ax in zip(arms, axs.flatten()):

    expected_residuals = round(np.mean(residuals_models[model_name]["mean"]), 2)
    ax_residuals = plot_estimated_time_series(
        ax=ax,
        time=df["date"].values,
        expected_values=residuals_models[model_name]["mean"],
        percentiles_values=residuals_models[model_name]["percentiles"],
        color="k",
        label=f"Mean Residuals {expected_residuals}"
    )
    ax_residuals.set_title(model_name)
    ax_residuals.set_ylabel("Residuals")
    ax.legend()

for ax in fig.axes:

    ax.axhline(0, linestyle=":", c="k")
    ax = _set_x_axis_grid(ax=ax)

plt.tight_layout()
plt.show()

for model in arms:

    fig = plt.figure(tight_layout=True, figsize=(10, 7))
    gs = gridspec.GridSpec(2, 3, figure=fig)

    ax_time_series = fig.add_subplot(gs[0, :])
    ax_difference = fig.add_subplot(gs[1, :2])
    ax_distribution = fig.add_subplot(gs[1, -1])

    ax_time_series = plot_arms_time_series(
        time="date",
        arm_a="Branch A",
        arm_b=model,
        df=df,
        y_label="Dependent Variable",
        title="Estimated Dependent Variable",
        ax=ax_time_series
    )
    _set_x_axis_grid(ax=ax_time_series)

    ax_time_series = plot_estimated_time_series(
        ax=ax_time_series,
        color=BLUE,
        time=df["date"].values,
        expected_values=expected_mean_control,
        percentiles_values=percentiles_mean_control,
    )
    ax_time_series = plot_estimated_time_series(
        ax=ax_time_series,
        color=RED,
        time=df["date"].values,
        expected_values=means_models[model]["mean"],
        percentiles_values=means_models[model]["percentiles"]
    )

    ax_difference = plot_estimated_time_series(
        ax=ax_difference,
        color="k",
        time=df["date"].values,
        expected_values=deltas_models[model]["mean"],
        percentiles_values=deltas_models[model]["percentiles"],
        linestyle="--"
    )

    ax_difference.set_title(
        "Series of Differences in Dependent Variable"
    )
    ax_difference.axhline(0, linestyle=":", c="r")
    ax_difference.set_xlabel("Date")
    ax_difference.set_ylabel("Difference Dependent Variable")

    ax_time_series.axvline(
        df["date"].max(),
        c="k",
        linestyle=":",
        alpha=0.5
    )

    ax_distribution.hist(
        statistics_differences_models[model]["expected_difference"],
        bins=400,
        color="k",
        alpha=0.25,
        density=True
    )
    ax_distribution.set_title("Distribution of\nExpected Difference")
    ax_distribution.set_xlabel("Expected Difference\nin Dependent Variable")
    ax_distribution.axvline(
        np.median(statistics_differences_models[model]["expected_difference"]),
        c="red",
        linestyle="--"
    )
    ax_distribution.axvline(
        statistics_differences_models[model]["percentiles_difference"][0],
        c="red",
        linestyle=":"
    )
    ax_distribution.axvline(
        statistics_differences_models[model]["percentiles_difference"][1],
        c="red",
        linestyle=":"
    )

    ax_time_series.legend()

    ax_time_series = _set_x_axis_grid(ax=ax_time_series)
    ax_difference = _set_x_axis_grid(ax=ax_difference)

    plt.suptitle("AB Test")
    plt.tight_layout()
    plt.show()

7 Conclusion

This post illustrated how longitudinal A/B test data can be analyzed with Gaussian Processes: a GP was fit to each arm, the posterior predictive mean of every arm was sampled, and each treatment arm was compared to control through the posterior distribution of their expected difference. Framing the comparison this way yields a full posterior over the effect of interest, rather than a single point estimate.

8 Hardware and Requirements

Here you can find the hardware and Python requirements used to build this post.

%watermark
Last updated: 2025-10-07T19:16:34.769342+02:00

Python implementation: CPython
Python version       : 3.13.2
IPython version      : 9.0.2

Compiler    : Clang 18.1.8 
OS          : Darwin
Release     : 24.5.0
Machine     : arm64
Processor   : arm
CPU cores   : 14
Architecture: 64bit
%watermark --iversions
jax       : 0.5.2
numpy     : 2.2.4
pandas    : 2.2.3
scipy     : 1.15.2
sklearn   : 1.6.1
matplotlib: 3.10.1
numpyro   : 0.18.0
seaborn   : 0.13.2

References

Benavoli, Alessio, and Francesca Mangili. 2015. “Gaussian Processes for Bayesian Hypothesis Tests on Regression Functions.” In Artificial Intelligence and Statistics, 74–82. PMLR.
Duvenaud, David. 2014. “Automatic Model Construction with Gaussian Processes.” PhD thesis.
Gelman, Andrew, John B. Carlin, Hal S. Stern, and Donald B. Rubin. 1995. Bayesian Data Analysis. Chapman and Hall/CRC.
Gelman, Andrew, Jennifer Hill, and Aki Vehtari. 2021. Regression and Other Stories. Cambridge University Press.
McElreath, Richard. 2018. Statistical Rethinking: A Bayesian Course with Examples in R and Stan. Chapman and Hall/CRC.