<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://wesley-demontigny.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://wesley-demontigny.github.io/" rel="alternate" type="text/html" /><updated>2026-03-29T20:57:23+00:00</updated><id>https://wesley-demontigny.github.io/feed.xml</id><title type="html">Wesley DeMontigny</title><subtitle>Probabilistic modeling, computational biology, and scientific software</subtitle><entry><title type="html">Sampling Markov Random Fields in NumPyro</title><link href="https://wesley-demontigny.github.io/probability/2026/03/29/MRF.html" rel="alternate" type="text/html" title="Sampling Markov Random Fields in NumPyro" /><published>2026-03-29T00:00:00+00:00</published><updated>2026-03-29T00:00:00+00:00</updated><id>https://wesley-demontigny.github.io/probability/2026/03/29/MRF</id><content type="html" xml:base="https://wesley-demontigny.github.io/probability/2026/03/29/MRF.html"><![CDATA[<script type="text/javascript" id="MathJax-script" async="" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js">
</script>

<p>Markov random fields (MRFs) are a class of statistical models that originated in statistical physics. An MRF can be thought of as a collection of random variables with arbitrary joint dependencies between them. These dependencies are usually not all-to-all and can be represented as an undirected graph, where random variables are nodes and edges indicate direct interactions. The interactions are typically written as potentials, whose logarithms contribute additively to the unnormalized log-probability of a state under the model.</p>

<p>MRFs have strong conceptual overlap with belief networks (or Bayesian networks). However, belief networks are often easier to work with: each factor is a normalized conditional distribution, so the joint distribution needs no global normalizing constant, and algorithms like belief propagation and variable elimination can compute likelihoods efficiently when the graph is sparse. Random fields, on the other hand, allow more general undirected dependency structures whose normalizing constant is usually intractable, and therefore often require approximate inference methods such as loopy belief propagation, mean-field approximations, or the Markov chain Monte Carlo sampling used here.</p>

<p>I am considering working with MRFs in some future work, so I thought I would write a quick introduction to sampling from them using the probabilistic programming library NumPyro. Here, I do not consider inference of model parameters and instead focus on sampling from these models. The visualizations associated with sampling these models can be quite cool, as I show below.</p>

<p>Let’s begin with a basic random field. This is simply a collection of random variables</p>

\[\{X_s : s \in S\}\]

<p>Typically, we choose \(S\) (the index set) to be a grid (see the visualizations below for examples of these grids evolving over time). We obtain a Markov random field by imposing a Markov property with respect to a graph \(G = (V, E)\), such that</p>

\[X_i \perp X_{V \setminus (\{i\} \cup N(i))} \mid X_{N(i)}\]

<p>where \(N(i)\) denotes the neighbors (the Markov blanket) of node \(i\). In other words, each node is conditionally independent of the rest of the graph given its neighbors. For strictly positive distributions, the Hammersley–Clifford theorem then lets us write the joint distribution of the full system as a product over the cliques of the graph:</p>

\[p(x) \propto \prod_{c \in C} \psi_c(x_c)\]

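<p>Here \(C\) is the set of cliques of \(G\) and each \(\psi_c\) is a non-negative potential on the variables in clique \(c\). It is worth noting that this definition is extremely general: a single clique could, in principle, span the entire graph.</p>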
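
<p>To make the factorization and the Markov property concrete, here is a minimal sketch that brute-forces a three-node chain \(X_1 - X_2 - X_3\). The potential \(\psi(a, b) = \exp(c \, a b)\) and the coupling value are assumptions chosen only for illustration, and the sketch uses plain NumPy rather than NumPyro because the state space is small enough to enumerate.</p>
<pre><code class="language-Python">import itertools
import numpy as np

# Hypothetical pairwise potential for the chain X1 - X2 - X3:
# psi(a, b) = exp(coupling * a * b), which favours equal neighbouring values.
coupling = 0.8

def psi(a, b):
    return np.exp(coupling * a * b)

# Enumerate all 2^3 configurations and multiply the clique (edge) potentials.
states = list(itertools.product([-1, 1], repeat=3))
unnormalized = np.array([psi(x1, x2) * psi(x2, x3) for x1, x2, x3 in states])
probs = unnormalized / unnormalized.sum()  # divide by the partition function

def prob(**fixed):
    """Total probability of all configurations matching the fixed values."""
    index = {"x1": 0, "x2": 1, "x3": 2}
    return sum(p for s, p in zip(states, probs)
               if all(s[index[name]] == value for name, value in fixed.items()))

# Markov property: given X2, the end nodes X1 and X3 are independent.
joint = prob(x1=1, x2=1, x3=1) / prob(x2=1)
product = (prob(x1=1, x2=1) / prob(x2=1)) * (prob(x3=1, x2=1) / prob(x2=1))
print(joint, product)  # the two values agree up to floating-point error
</code></pre>
<p>Enumeration like this only works for a handful of variables, since the number of configurations doubles with every node. That is why the larger models in this post are sampled instead.</p>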

<p>Below, we consider two cases of MRFs: the Ising model (a classic model from statistical physics with local interactions) and a Hopfield network (a fully connected model from classical machine learning).</p>

<hr />
<h3 id="the-ising-model">The Ising Model</h3>
<p>The Ising model is the prototypical MRF. It was originally developed to describe ferromagnetism in terms of atomic spins arranged on a lattice. Each variable interacts only with its nearest neighbors, resulting in a small Markov blanket. We consider a grid \(S\) of size \(n \times n\), where each variable takes values</p>

\[X_{ij} \in \{-1, +1\}, \quad i,j \in \{1,2,\dots,n\}\]

<p>We introduce two parameters:</p>
<ul>
  <li>\(J\): The interaction strength between neighboring spins</li>
  <li>\(h\): An external field bias</li>
</ul>

<p>A large positive \(J\) encourages neighboring variables to align (take the same sign), while \(h\) introduces a global bias towards one sign or the other. The probability of a realization of the field \(x\) is given by</p>

\[p(x) \propto \exp\left( \sum_{((i,j),(k,l)) \in E} J x_{ij} x_{kl} + \sum_{(i,j) \in S} h x_{ij} \right)\]

<p>where \(E\) is the set of pairs of adjacent cells, with each pair counted once.</p>
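
<p>To make the sign conventions concrete, the exponent can be evaluated directly for small grids. The sketch below is an illustration in plain NumPy rather than JAX. The function name <code class="language-plaintext highlighter-rouge">ising_energy</code> and the two \(4 \times 4\) test grids are assumptions for this example and are not part of the sampler. It returns the energy, defined as the negative of the exponent, so lower energy means higher probability.</p>
<pre><code class="language-Python">import numpy as np

def ising_energy(x, J, h):
    """Energy E(x) such that p(x) is proportional to exp(-E(x))."""
    # Each adjacent pair is counted once: vertical pairs, then horizontal pairs.
    vertical = np.sum(x[:-1, :] * x[1:, :])
    horizontal = np.sum(x[:, :-1] * x[:, 1:])
    return -J * (vertical + horizontal) - h * np.sum(x)

n = 4
aligned = np.ones((n, n), dtype=int)
# Checkerboard: every neighbouring pair disagrees.
checkerboard = (-1) ** np.add.outer(np.arange(n), np.arange(n))

print(ising_energy(aligned, J=1.0, h=0.0))       # -24.0: all 24 pairs agree
print(ising_energy(checkerboard, J=1.0, h=0.0))  # 24.0: all 24 pairs disagree
</code></pre>
<p>With a positive \(J\), the fully aligned grid sits at the bottom of the energy range and the checkerboard at the top, which is why a strong interaction pushes samples toward large aligned domains.</p>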

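<p>To sample from this model, we define its energy (the negative of the exponent above) and implement a Metropolis–Hastings sampler. Since each proposal changes a single cell, the code only evaluates the energy terms that involve that cell. Below is a NumPyro implementation:</p>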
<pre><code class="language-Python">import jax
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.infer.mcmc import MCMC, MCMCKernel
from numpyro.util import identity
from collections import namedtuple

interaction_strength = 100.0
external_strength = 0.0
def local_ising_energy(state, i, j, beta = 1.0):
    energy = -1.0 * external_strength * state[i,j]

    energy -= (i&gt;0) * interaction_strength * state[i,j] * state[i-1,j]
    energy -= (i &lt; state.shape[0]-1) * interaction_strength * state[i,j] * state[i+1,j]

    energy -= (j &gt; 0) * interaction_strength * state[i,j] * state[i,j-1]
    energy -= (j &lt; state.shape[1]-1) * interaction_strength * state[i,j] * state[i,j+1]

    return beta * energy

MHState = namedtuple("MHState", ["spins", "rng_key"])
class IsingMH(MCMCKernel):
    sample_field = "spins"

    def __init__(self, potential_fn=None, **kwargs):
        super().__init__()
        self.potential_fn = potential_fn

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        return MHState(init_params, rng_key)

    # Propose flipping one randomly chosen spin and accept or reject it with the Metropolis-Hastings rule
    def sample(self, state, model_args, model_kwargs):
        spins, rng_key = state
        rng_key, key_i, key_j, key_accept = jax.random.split(rng_key, 4)

        random_i = jax.random.randint(key_i, (), 0, spins.shape[0], dtype=jnp.int32)
        random_j = jax.random.randint(key_j, (), 0, spins.shape[1], dtype=jnp.int32)

        new_spins = spins.at[random_i, random_j].set(spins[random_i, random_j] * -1)

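        # Metropolis ratio exp(-(E_new - E_old)); only the terms touching the flipped spin change.
        accept_prob = jnp.exp(self.potential_fn(spins, random_i, random_j) - self.potential_fn(new_spins, random_i, random_j))
        # Keep the proposal if it is accepted, otherwise keep the current grid.
        new_spins = jnp.where(dist.Uniform().sample(key_accept) &lt; accept_prob, new_spins, spins)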

        return MHState(new_spins, rng_key)

    def postprocess_fn(self, model_args, model_kwargs):
        return identity

if __name__ == "__main__":
    # Split the key so the initial grid and the sampler use independent random streams.
    init_key, run_key = jax.random.split(jax.random.PRNGKey(1))
    random_state = dist.Bernoulli(0.5).sample(init_key, sample_shape=(64, 64)) * 2 - 1

    kernel = IsingMH(local_ising_energy)
    mcmc = MCMC(kernel, num_warmup=0, num_samples=150000)
    mcmc.run(run_key, init_params=random_state)
    posterior_samples = mcmc.get_samples()
</code></pre>
<p>Starting from the <code class="language-plaintext highlighter-rouge">__main__</code> block, we initialize a random grid using a transformed Bernoulli distribution and pass it to the MCMC sampler. The sampler uses a custom kernel implementing Metropolis–Hastings: at each iteration it picks a cell uniformly at random, proposes flipping it from \(x_{ij}\) to \(x'_{ij} = -x_{ij}\), and accepts the flip with probability</p>

\[\min \bigg(1, \frac{\exp(-\ell(x'_{ij}))}{\exp(-\ell(x_{ij}))} \bigg)\]

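<p>where \(\ell\) is the local energy contribution, that is, the terms of the energy that involve the flipped cell. All other terms cancel in the ratio, and because the proposal is symmetric no further correction is needed. In the long run, the Markov chain therefore visits each configuration in proportion to its probability under the model.</p>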
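
<p>The claim that the local energy is enough can be checked numerically. The following sketch, written in plain NumPy with hypothetical parameter values and helper names, flips every cell of a small random grid in turn and confirms that the change in local energy matches the change in total energy, including at corners and edges.</p>
<pre><code class="language-Python">import numpy as np

J, h = 1.5, 0.3  # hypothetical parameter values for this check

def total_energy(x):
    pairs = np.sum(x[:-1, :] * x[1:, :]) + np.sum(x[:, :-1] * x[:, 1:])
    return -J * pairs - h * np.sum(x)

def local_energy(x, i, j):
    # Zero-padding makes off-grid neighbours contribute nothing.
    padded = np.pad(x, 1)
    neighbours = (padded[i, j + 1] + padded[i + 2, j + 1]
                  + padded[i + 1, j] + padded[i + 1, j + 2])
    return -h * x[i, j] - J * x[i, j] * neighbours

rng = np.random.default_rng(0)
x = rng.choice([-1, 1], size=(6, 6))

# Flip every site in turn (corners and edges included) and compare.
for i in range(6):
    for j in range(6):
        flipped = x.copy()
        flipped[i, j] *= -1
        local_change = local_energy(flipped, i, j) - local_energy(x, i, j)
        total_change = total_energy(flipped) - total_energy(x)
        assert np.isclose(local_change, total_change)

print("local and total energy changes agree at every site")
</code></pre>
<p>This is what keeps each proposal cheap: the cost of one update does not grow with the size of the grid.</p>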

<p>We can animate the sampling process with the following code, which keeps every 100th sample as a frame and produces some interesting visual behavior:</p>
<pre><code class="language-Python">import numpy as np
import imageio

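# Keep every 100th sample so the animation has 1,500 frames.
data_np = np.array(posterior_samples)[::100, :, :]
# Map spins {-1, +1} to pixel values {0, 255}.
frames = ((data_np + 1) / 2 * 255).astype(np.uint8)
# Note: recent imageio releases may expect duration (milliseconds per frame) instead of fps for GIFs.
imageio.mimsave("ising.gif", frames, fps=50)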
</code></pre>
<p>Notice how the strong interaction drives neighboring cells toward a common state, so that domains of aligned spins form and grow.</p>

<p align="center">
<img src="/assets/ising.gif" width="200" />
</p>

<hr />
<h3 id="hopfield-networks">Hopfield Networks</h3>
<p>Hopfield networks are a classic model of associative memory from early machine learning. The key idea is to construct an energy function whose local minima correspond to stored patterns (for example, images), so that the system can recover a pattern by evolving toward an energy minimum. They were studied in the context of memory recall, where the model is given only part of an image and asked to recall the unobserved portion. Here, we do not focus on memory recall, but instead treat the Hopfield network as a distribution to sample from. The probability distribution of a Hopfield network takes the form</p>

\[p(x) \propto \exp(-E(x)), \qquad E(x) = -\frac{1}{2} \sum_{i \neq j} W_{ij} x_i x_j\]

<p>where \(W\), with entries \(W_{ij}\), is the symmetric weight matrix encoding the stored patterns.</p>
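
<p>One common way to build such a weight matrix, assumed here for illustration, is the Hebbian outer-product rule with a zeroed diagonal. The sketch below uses plain NumPy and a hypothetical \(4 \times 4\) pattern. It compares the energy \(-\frac{1}{2} \sum_{i \neq j} W_{ij} x_i x_j\) of the stored pattern, its negation, and a corrupted copy.</p>
<pre><code class="language-Python">import numpy as np

# A hypothetical 4x4 pattern standing in for a stored image.
pattern = np.array([[ 1,  1, -1, -1],
                    [ 1,  1, -1, -1],
                    [-1, -1,  1,  1],
                    [-1, -1,  1,  1]])

X = pattern.reshape(1, -1)          # one row per stored pattern
W = (X.T @ X) / X.shape[1]          # Hebbian outer-product rule
np.fill_diagonal(W, 0.0)            # no self-interaction (the sum skips i = j)

def energy(state):
    s = state.reshape(-1)
    return -0.5 * s @ W @ s

corrupted = pattern.copy()
corrupted[0, :] *= -1               # flip the four pixels of the first row

print(energy(pattern), energy(-pattern), energy(corrupted))
</code></pre>
<p>The stored pattern and its negation share the lowest energy, while the corrupted copy sits higher. The energy is quadratic in the state, so flipping every pixel at once leaves it unchanged.</p>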

<p>We can apply many of the same ideas from the Ising model. The primary difference is that the graph is fully connected, meaning every node interacts with every other node.</p>

<p>Below is a simple \(64\times 64\) pixel-art cat that I generated. It is the single stored pattern, and therefore an energy minimum, of the Hopfield network:</p>
<pre><code class="language-Python">cat_64x64 = jnp.array([[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 
-1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], dtype=jnp.int8)

patterns = jnp.stack([cat_64x64], axis=0)
</code></pre>

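<p>We can now implement the energy function and the sampler. The only difference from the Ising sampler is that the energy change of a flip no longer depends on just four neighbors: the Markov blanket of each node in a Hopfield network is the whole graph, so for simplicity the kernel below evaluates the full energy before and after each proposed flip.</p>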
<pre><code class="language-Python">import jax
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.infer.mcmc import MCMC, MCMCKernel
from numpyro.util import identity
from collections import namedtuple

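# This block assumes `patterns` from the cat-image code above is already defined; it is omitted here for readability.
# Cast to float so the products cannot overflow int8 when several patterns are stored.
X = patterns.reshape(patterns.shape[0], -1).astype(jnp.float32)
W = (X.T @ X) / X.shape[1]  # Hebbian weights, scaled by the number of pixels
W = W.at[jnp.diag_indices(W.shape[0])].set(0)  # remove self-interactions
def hopfield_energy(state, beta=100):
    s = state.reshape(-1)
    return -0.5 * beta * (s @ W @ s)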

  
MHState = namedtuple("MHState", ["spins", "rng_key"])
class HopfieldMH(MCMCKernel):
    sample_field = "spins"

    def __init__(self, potential_fn=None, **kwargs):
        super().__init__()
        self.potential_fn = potential_fn

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        return MHState(init_params, rng_key)
        
    # Attempt a simple flip of a random state using the Metropolis-Hastings algorithm.
    def sample(self, state, model_args, model_kwargs):
        spins, rng_key = state
        rng_key, key_i, key_j, key_accept = jax.random.split(rng_key, 4)

        random_i = jax.random.randint(key_i, (), 0, spins.shape[0], dtype=jnp.int32)
        random_j = jax.random.randint(key_j, (), 0, spins.shape[1], dtype=jnp.int32)
        
        new_spins = spins.at[random_i, random_j].set(spins[random_i, random_j] * -1)

        # Evaluate the full energy before and after the flip, since every spin interacts with every other.
        accept_prob = jnp.exp(self.potential_fn(spins) - self.potential_fn(new_spins))
        new_spins = jnp.where(dist.Uniform().sample(key_accept) &lt; accept_prob, new_spins, spins)

        return MHState(new_spins, rng_key)

    def postprocess_fn(self, model_args, model_kwargs):
        return identity
  

if __name__ == "__main__":
    # Split the key so the initial grid and the sampler use independent random streams.
    init_key, run_key = jax.random.split(jax.random.PRNGKey(1))
    random_state = dist.Bernoulli(0.5).sample(init_key, sample_shape=(64, 64)) * 2 - 1

    kernel = HopfieldMH(hopfield_energy)
    mcmc = MCMC(kernel, num_warmup=0, num_samples=20000)
    mcmc.run(run_key, init_params=random_state)
    posterior_samples = mcmc.get_samples()
</code></pre>
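<p>Starting the grid from pure noise, the chain settles into the stored cat pattern by the end of the run. Because the energy is unchanged when every spin is flipped, a different random seed could equally well settle into the color-inverted cat.</p>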

<p align="center">
<img src="/assets/hopfield.gif" width="200" />
</p>
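
<p>Evaluating the full energy twice per proposal costs \(O(N^2)\) operations for \(N\) pixels. Although the Markov blanket is the whole graph, the energy change of a single flip only involves one row of \(W\), which brings the cost down to \(O(N)\). The sketch below is a plain NumPy check of that identity on a small random network. The helper names, the network size, and the value of beta are assumptions for illustration, and this is not the kernel used to produce the animation.</p>
<pre><code class="language-Python">import numpy as np

rng = np.random.default_rng(0)
n_units = 64
beta = 2.0  # hypothetical inverse temperature for this check

# Hebbian weights for one random stored pattern, with a zero diagonal.
stored = rng.choice([-1, 1], size=n_units)
W = np.outer(stored, stored) / n_units
np.fill_diagonal(W, 0.0)

def full_energy(s):
    return -0.5 * beta * (s @ W @ s)

def flip_energy_change(s, k):
    # Only row k of W is needed because W is symmetric with W[k, k] = 0.
    return 2.0 * beta * s[k] * (W[k] @ s)

s = rng.choice([-1, 1], size=n_units)
for k in range(n_units):
    flipped = s.copy()
    flipped[k] *= -1
    assert np.isclose(full_energy(flipped) - full_energy(s),
                      flip_energy_change(s, k))

print("single-row update matches the full energy difference")
</code></pre>
<p>The same row-based difference could replace the two full energy evaluations in the acceptance step, which should matter more as the image grows.</p>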

<hr />
<p>Overall, MRFs provide a flexible framework for defining structured distributions over high-dimensional spaces. While exact inference is often intractable, even simple sampling schemes can produce  visually interesting behavior.</p>]]></content><author><name></name></author><category term="Probability" /><category term="probability" /><summary type="html"><![CDATA[Markov random fields (MRFs) are an interesting class of statistical models that originate in physics. They can be thought of as a collection of random variables where we allow arbitrary joint dependencies between variables. These dependencies are often not all-to-all and can frequently be represented as a graph, where random variables are nodes and edges indicate interactions. These interactions are typically written as potentials (log-likelihood contributions) of the state under the model. MRFs have strong conceptual overlap with belief networks (or Bayesian networks). However, belief networks are generally more tractable for inference, as their directed structure enables algorithms like belief propagation and variable elimination to compute likelihoods efficiently. Random fields, on the other hand, allow more general dependency structures and therefore often require approximate inference methods such as loopy belief propagation or mean-field approximations. I am considering working with MRFs in some future work, so I thought I would write a quick introduction to sampling from them using the probabilistic programming library NumPyro. Here, I do not consider inference of model parameters and instead focus on sampling from these models. The visualizations associated with sampling these models can be quite cool, as I show below. Let’s begin with a basic random field. This is simply a collection of random variables \[\{X_s : s \in S\}\] Typically, we choose \(S\) (the index set) to be a grid (see the visualizations below for examples of these grids evolving over time). 
We obtain a Markov random field by imposing a Markov property with respect to a graph \(G = (V, E)\), such that \[X_i \perp X_{V \setminus \{i \cup N(i)\}} \mid X_{N(i)}\] where \(N(i)\) denotes the neighbors (the Markov blanket) of node \(i\). In other words, each node is conditionally independent of the rest of the graph given its neighbors. The joint distribution of the full system can then be written as a product over cliques of the graph: \[p(x) \propto \prod_{c \in C} \psi_c(x_c)\] It is worth noting that this definition is extremely general and the set of cliques could, in principle, be the entire graph. Below, we consider two cases of MRFs: the Ising model (a classic model from statistical physics with local interactions) and a Hopfield network (a fully connected model from classical machine learning). The Ising Model The Ising model is the prototypical MRF. This model was originally developed to describe the spin of atoms in a lattice. In this model, each variable only interacts with its local neighbors, resulting in a small Markov blanket. We consider a grid \(S\) of size \(n \times n\), where each variable takes values \[X_{ij} \in \{-1, +1\}, \quad i,j \in \{1,2,\dots,n\}\] We introduce two parameters: \(J\): The interaction strength between neighboring spins \(h\): An external field bias A large \(J\) encourages neighboring variables to align (take the same sign), while \(h\) introduces a global bias towards one sign or the other. The probability of a realization of the field \(x\) is given by \[p(x) \propto \exp\left( \sum_{(i,j,k,l) \in E} J x_{ij} x_{kl} + \sum_{i,j \in S} h x_{ij} \right)\] where \(E\) indexes adjacent cells. To sample from this model, we need to define its energy (potential) and implement a Metropolis–Hastings sampler. 
Below is a NumPyro implementation: import jax import jax.numpy as jnp import numpyro.distributions as dist from numpyro.infer.mcmc import MCMC, MCMCKernel from numpyro.util import identity from collections import namedtuple interaction_strength = 100.0 external_strength = 0.0 def local_ising_energy(state, i, j, beta = 1.0): energy = -1.0 * external_strength * state[i,j] energy -= (i&gt;0) * interaction_strength * state[i,j] * state[i-1,j] energy -= (i &lt; state.shape[0]-1) * interaction_strength * state[i,j] * state[i+1,j] energy -= (j &gt; 0) * interaction_strength * state[i,j] * state[i,j-1] energy -= (j &lt; state.shape[1]-1) * interaction_strength * state[i,j] * state[i,j+1] return beta * energy MHState = namedtuple("MHState", ["spins", "rng_key"]) class IsingMH(MCMCKernel): sample_field = "spins" def __init__(self, potential_fn=None, **kwargs): super().__init__() self.potential_fn = potential_fn def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): return MHState(init_params, rng_key) # Attempt a simple flip of a random state using the Metropolis-Hastings algorithm def sample(self, state, model_args, model_kwargs): spins, rng_key = state rng_key, key_i, key_j, key_accept = jax.random.split(rng_key, 4) random_i = jax.random.randint(key_i, (), 0, spins.shape[0], dtype=jnp.int32) random_j = jax.random.randint(key_j, (), 0, spins.shape[1], dtype=jnp.int32) new_spins = spins.at[random_i, random_j].set(spins[random_i, random_j] * -1) accept_prob = jnp.exp(self.potential_fn(spins, random_i, random_j) - self.potential_fn(new_spins, random_i, random_j)) new_spins = jnp.where(dist.Uniform().sample(key_accept) &lt; accept_prob, new_spins, spins) return MHState(new_spins, rng_key) def postprocess_fn(self, model_args, model_kwargs): return identity if __name__ == "__main__": key = jax.random.PRNGKey(1) random_state = dist.Bernoulli(0.5).sample(key, sample_shape=(64, 64)) * 2 - 1 kernel = IsingMH(local_ising_energy) mcmc = MCMC(kernel, num_warmup=0, 
num_samples=150000) mcmc.run(key, init_params=random_state) posterior_samples = mcmc.get_samples() Starting from main, we initialize a random grid using a transformed Bernoulli distribution and pass it into our MCMC sampler. The sampler uses a custom kernel implementing Metropolis–Hastings, where each iteration proposes flipping a single spin. Each iteration, this algorithm attempts to flip the state of a random cell from state \(x_{ij}\) to \(x'_{ij}\) and accepts the new flip with probability \[\min \bigg(1, \frac{\exp(-\ell(x'_{ij}))}{\exp(-\ell(x_{ij}))} \bigg)\] where \(\ell\) is the local energy contribution. Because the single-flip proposal is symmetric, this ensures that, in the long run, the Markov chain visits each configuration in proportion to its probability under the model. We can animate our sampling process using the following code and get some interesting visual behavior: import numpy as np import imageio data_np = np.array(posterior_samples)[::100, :, :] frames = ((data_np + 1) / 2 * 255).astype(np.uint8) imageio.mimsave("ising.gif", frames, fps=50) Notice how the interaction strength encourages neighboring cells to share a common state. Hopfield Networks Hopfield networks are an interesting model from classical machine learning. The key idea is to construct an energy function whose local minima correspond to stored patterns (for example, images). The system can then recover these patterns by evolving toward energy minima. These were studied in the context of memory recall, where the model was given only part of an image and asked to recall the unobserved portion. Here, we do not focus on memory recall, but instead treat the Hopfield network as a distribution to sample from. The probability distribution of a Hopfield network takes the form \[p(x) \propto \exp\left(-E(x)\right), \quad E(x) = -\frac{1}{2} \sum_{i \neq j} W_{ij} x_i x_j\] where \(W\) is the symmetric weight matrix encoding the stored patterns. We can apply many of the same ideas from the Ising model. 
The primary difference is that the graph is fully connected, meaning every node interacts with every other node. Below, I generated a simple \(64\times 64\) pixel-art cat. This serves as an energy minimum of the Hopfield network: cat_64x64 = jnp.array([[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 
-1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], dtype=jnp.int8) patterns = jnp.stack([cat_64x64], axis=0) We can now implement our potential function and energy. The only difference here is that we can’t simply update based on the local energy, as the Markov blanket of the Hopfield network is the whole graph. import jax import jax.numpy as jnp import numpyro.distributions as dist from numpyro.infer.mcmc import MCMC, MCMCKernel from numpyro.util import identity from collections import namedtuple # Include the code from that cat image above. I did not include it in this block because it makes this code harder to read. 
X = patterns.reshape(patterns.shape[0], -1) W = (X.T @ X) / X.shape[1] W = W.at[jnp.diag_indices(W.shape[0])].set(0) def hopfield_energy(state, beta=100): s = state.reshape(-1) return -0.5 * s @ W @ s * beta MHState = namedtuple("MHState", ["spins", "rng_key"]) class HopfieldMH(MCMCKernel): sample_field = "spins" def __init__(self, potential_fn=None, **kwargs): super().__init__() self.potential_fn = potential_fn def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): return MHState(init_params, rng_key) # Attempt a simple flip of a random state using the Metropolis-Hastings algorithm. def sample(self, state, model_args, model_kwargs): spins, rng_key = state rng_key, key_i, key_j, key_accept = jax.random.split(rng_key, 4) random_i = jax.random.randint(key_i, (), 0, spins.shape[0], dtype=jnp.int32) random_j = jax.random.randint(key_j, (), 0, spins.shape[1], dtype=jnp.int32) new_spins = spins.at[random_i, random_j].set(spins[random_i, random_j] * -1) accept_prob = jnp.exp(self.potential_fn(spins) - self.potential_fn(new_spins)) new_spins = jnp.where(dist.Uniform().sample(key_accept) &lt; accept_prob, new_spins, spins) return MHState(new_spins, rng_key) def postprocess_fn(self, model_args, model_kwargs): return identity if __name__ == "__main__": key = jax.random.PRNGKey(1) random_state = dist.Bernoulli(0.5).sample(key, sample_shape=(64, 64)) * 2 - 1 kernel = HopfieldMH(hopfield_energy) mcmc = MCMC(kernel, num_warmup=0, num_samples=20000) mcmc.run(key, init_params=random_state) posterior_samples = mcmc.get_samples() As we can see, we can start the grid from complete noise, and obtain our cat at the end. Overall, MRFs provide a flexible framework for defining structured distributions over high-dimensional spaces. 
While exact inference is often intractable, even simple sampling schemes can produce visually interesting behavior.]]></summary></entry><entry><title type="html">When Are Claims Evidence?</title><link href="https://wesley-demontigny.github.io/bayesian/probability/2026/03/15/Claims_as_evidence.html" rel="alternate" type="text/html" title="When Are Claims Evidence?" /><published>2026-03-15T00:00:00+00:00</published><updated>2026-03-15T00:00:00+00:00</updated><id>https://wesley-demontigny.github.io/bayesian/probability/2026/03/15/Claims_as_evidence</id><content type="html" xml:base="https://wesley-demontigny.github.io/bayesian/probability/2026/03/15/Claims_as_evidence.html"><![CDATA[<script type="text/javascript" id="MathJax-script" async="" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js">
</script>

<p>I recently watched a back and forth between two YouTubers who generally engage in discussions on philosophy, epistemology, and debate. The topic was about whether or not claims are evidence. One of the YouTubers (<strong>A</strong>) adamantly claims that claims are not evidence, while the other (<strong>B</strong>) says that this is clearly not the case. In one moment, <strong>B</strong> brought up an example. He considered the case of a friend claiming that he has bought a soccer ball. To <strong>B</strong>, this is clearly evidence of his friend having bought a soccer ball. <strong>A</strong> points out that it is not that simple and that when he treats this as evidence he is using a LOT of background knowledge and circumstantial evidence to come to that conclusion, even if he doesn’t realize it.</p>

<p>I think <strong>A</strong> is on the right track here and that <strong>B</strong> is being a little simplistic, but the language of this discussion is not well suited for the topic. I think some of the language of probability will help us here, because although logical syllogisms of the form</p>

\[C \implies A; \quad C \therefore A\]

<p>are useful in philosophy, real life deals in degrees of uncertainty. Claims can be expressed as a kind of measurement in a latent variable problem.</p>

<p>Let’s say we are interested in the truth of some claim \(z\). Without loss of generality, let us treat \(z\) as a binary variable taking the value \(z=0\) or \(z=1\). We can observe some data that (hopefully) indirectly measures \(z\); we denote it by the vector \(\mathbf{x}\), which may contain several kinds of measurements \(x_i\) for \(i \in \{1,\dots,n\}\). We want to understand \(\mathbf{x}\) as arising from some kind of emission process from \(z\).</p>

<p>Let us consider the soccer ball example, with \(z=0\) denoting that the friend has not bought the soccer ball and \(z=1\) denoting that our friend has bought the ball. We will consider our measurement of this latent variable \(x\) to be one-dimensional, with \(x=0\) meaning that our friend claims he has not bought a soccer ball and \(x=1\) meaning he claims he has bought a soccer ball.</p>

<p>If our friend is perfectly truthful, then our result is clear and \(x=z\). However, there are a number of things that could complicate this. The easiest thing is to question our friend’s truthfulness. Suppose that our friend is embarrassed about his soccer interests: he will never falsely claim to have bought a ball, but he may falsely claim that he has not bought one. Let us say that he lies about his soccer interests with frequency \(1-p\). Suddenly, our observation \(x\) does not perfectly correspond to \(z\). Instead, \(x\) is emitted by \(z\) according to the process</p>

\[x \sim \text{Bernoulli}(p z).\]

<p>That is, when \(z=1\) we get \(x=1\) with probability \(p\), and when \(z=0\) we get \(x=0\) with probability \(1\).</p>
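<p>To make the emission process concrete, here is a minimal simulation sketch in plain Python. The function name <code>emit_claim</code>, the value of \(p\), and the number of draws are illustrative assumptions of mine, not part of the model itself.</p>

```python
import random

def emit_claim(z, p, rng):
    """Draw a claim x ~ Bernoulli(p * z) for a latent truth z in {0, 1}."""
    prob_claim = p * z
    # random.choices returns 1 with probability prob_claim and 0 otherwise.
    return rng.choices([0, 1], weights=[1.0 - prob_claim, prob_claim])[0]

rng = random.Random(0)
p = 0.8  # assumed truthfulness, chosen only for illustration

claims_if_bought = [emit_claim(1, p, rng) for _ in range(10000)]
claims_if_not = [emit_claim(0, p, rng) for _ in range(10000)]

# When z = 1 the fraction of x = 1 claims should be close to p;
# when z = 0 there should be no x = 1 claims at all.
print(sum(claims_if_bought) / len(claims_if_bought))
print(sum(claims_if_not))
```

<p>The asymmetry is the whole point of this toy model: a claim of \(x=1\) can only come from \(z=1\), while a claim of \(x=0\) can come from either state.</p>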

<p>If we are interested in the probability that \(z\) is true given some observation \(x\), we need to consider the conditional probability \(P(z \mid x)\). By Bayes’ rule, we have</p>

\[\begin{aligned}
P(z=0 \mid x) &amp;\propto P(z=0)P(x \mid z=0)
= 
\begin{cases}
(1-\psi) &amp; \text{if } x=0 \\
0 &amp; \text{if } x=1
\end{cases} \\\\
P(z=1 \mid x) &amp;\propto P(z=1)P(x \mid z=1)
=
\begin{cases}
\psi (1-p) &amp; \text{if } x=0 \\
\psi p &amp; \text{if } x=1
\end{cases}
\end{aligned}\]

<p>Notice that an extra term has appeared here, \(\psi\). This is the prior probability that our friend (or maybe anyone, depending on how you’d like to set it up) would buy a soccer ball. If we consider the case where \(x=0\), that is, he has claimed that he has not bought a soccer ball, we get a probability that he has bought the soccer ball equal to</p>

\[P(z=1 \mid x=0) = 
\frac{\psi (1-p)}{(1-\psi)+\psi (1-p)}
=
\frac{\psi - \psi p}{1 - \psi p}\]
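<p>The same calculation can be sketched in a few lines of Python by normalizing the two unnormalized weights from the cases above. The function name <code>posterior_bought</code> and the example values of \(\psi\) and \(p\) are assumptions made for illustration.</p>

```python
def posterior_bought(x, psi, p):
    """P(z=1 | x) under x ~ Bernoulli(p * z) with prior P(z=1) = psi."""
    if x == 1:
        weight_not_bought = 0.0          # he never falsely claims a purchase
        weight_bought = psi * p
    else:
        weight_not_bought = 1.0 - psi    # P(z=0) * P(x=0 | z=0)
        weight_bought = psi * (1.0 - p)  # P(z=1) * P(x=0 | z=1)
    return weight_bought / (weight_not_bought + weight_bought)

psi, p = 0.3, 0.8  # illustrative values only
print(posterior_bought(0, psi, p))
# The closed form from the text gives the same number.
print((psi - psi * p) / (1.0 - psi * p))
```

<p>Note that this sketch divides by zero when the observed claim is impossible under the model (for example \(x=1\) with \(\psi p = 0\)); a more careful version would handle that case explicitly.</p>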

<p>By this very simple introduction of untruthfulness, we have injected some very serious assumptions. If we do not have perfect correspondence of \(x\) to \(z\), we now need to rely not only on the truthfulness of our friend, \(p\), but also on how reasonable it is that he would buy a soccer ball in the first place, \(\psi\).</p>

<p>Notice that under this one-way model a claim of \(x=1\) is conclusive, because our friend never falsely claims a purchase. As soon as we allow even a small chance of a false claim, however, \(\psi\) matters a great deal: if \(\psi\) is extremely low, then our friend saying he has bought a soccer ball does very little to convince us. That is, even if our friend is the most truthful person we’ve ever met, if buying a soccer ball is extremely rare then we would probably conclude he’s lying or mistaken. If we change the soccer ball example to our friend saying he bought a purple alien on the black market, all of a sudden not only do we barely change our belief about \(z\), but we may also update our beliefs about \(p\), the truthfulness of our friend.</p>
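<p>The one-way model never produces a false claim of a purchase, so to illustrate how the prior \(\psi\) affects a positive claim, the sketch below assumes a hypothetical false-claim probability <code>eps</code> that is not part of the model above. With that assumption, \(P(z=1 \mid x=1) = \psi p / (\psi p + (1-\psi)\,\text{eps})\).</p>

```python
def posterior_given_claim(psi, p, eps):
    """P(z=1 | x=1) when true purchases are admitted with probability p
    and non-purchases are falsely claimed with probability eps (assumed)."""
    true_claim = psi * p             # P(z=1) * P(x=1 | z=1)
    false_claim = (1.0 - psi) * eps  # P(z=0) * P(x=1 | z=0)
    return true_claim / (true_claim + false_claim)

# A very truthful friend (p = 0.99, eps = 0.01) and increasingly rare events.
for psi in (0.5, 0.01, 0.000001):
    print(psi, posterior_given_claim(psi, p=0.99, eps=0.01))
```

<p>With these illustrative numbers the posterior falls from about 0.99 to 0.5 to roughly 0.0001 as \(\psi\) shrinks, even though the friend’s truthfulness never changes.</p>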

<p>Things get even more complicated when our friend could be either untruthful or just gullible (or both). Let us consider the case where we are interested in whether \(z=1\), but there is some distractor \(\tilde z\) that, when \(\tilde z = 1\), can lead our friend into thinking \(z=1\). Suppose in this case our friend believes he has won some sweepstakes from a random email he received, but we know that plenty of these fake emails circulate. Say he is exposed to a fake sweepstakes (\(\tilde z = 1\)) with probability \(\alpha\) and, when exposed, falls for it and claims \(x=1\) with probability \(q\). When he is not exposed, he may still lie about winning, resulting in \(x=1\) when \(z=0\) with probability \((1-p)\). We will assume, for simplicity, that \(x=1\) when \(z=1\); that is, our friend would always tell us if he had actually won a sweepstakes (maybe he is down on his luck and would be too excited not to share). This relatively simple real-life problem introduces a lot of different variables into our system. We now have unnormalized probabilities</p>

\[\begin{aligned}
P(z=0 \mid x) &amp;\propto P(z=0)P(x \mid z=0) \\
&amp;= \alpha P(z=0)P(x \mid z=0,\tilde z=1)
+ (1-\alpha)P(z=0)P(x \mid z=0,\tilde z=0) \\
&amp;=
\begin{cases}
(1-\psi)(\alpha (1-q) + (1-\alpha)p) &amp; \text{if } x=0 \\
(1-\psi)(\alpha q + (1-\alpha)(1-p)) &amp; \text{if } x=1
\end{cases}
\\\\
P(z=1 \mid x) &amp;\propto P(z=1)P(x \mid z=1)
=
\begin{cases}
0 &amp; \text{if } x=0 \\
\psi &amp; \text{if } x=1
\end{cases}
\end{aligned}\]

<p>Normalizing, the probability that our friend actually won given that he claims to have won is</p>

\[P(z=1 \mid x=1) =
\frac{\psi}{\psi + \alpha q(1-\psi) + (1-\alpha)(1-p)(1-\psi)}.\]
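<p>As a quick numerical check of this expression, here is a small Python sketch. The function name <code>posterior_won</code> and the parameter values are illustrative assumptions; the three terms in the denominator correspond to a real win, being fooled by a fake sweepstakes, and an outright lie.</p>

```python
def posterior_won(psi, alpha, q, p):
    """P(z=1 | x=1) for the sweepstakes model."""
    true_win = psi                                # P(z=1) * P(x=1 | z=1), with P(x=1 | z=1) = 1
    fooled = (1.0 - psi) * alpha * q              # exposed to a fake and fell for it
    lied = (1.0 - psi) * (1.0 - alpha) * (1.0 - p)  # not exposed, but lied anyway
    return true_win / (true_win + fooled + lied)

# Illustrative values: wins are rare, fakes are common, the friend is mostly honest.
print(posterior_won(psi=0.01, alpha=0.5, q=0.2, p=0.9))
```

<p>Setting <code>alpha=0.0</code> and <code>p=1.0</code> removes both sources of false claims, and the posterior returns to 1, as we would expect from a perfectly reliable friend.</p>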

<p>Notice that we now have a mixture of processes happening here. We have a friend that is exposed to fraudulent sweepstakes at a frequency \(\alpha\). Think about how simple this example was. The situation is not particularly complex, and yet we suddenly have a lot of factors to consider. I have actually simplified this a good amount by allowing our friend to perfectly detect true sweepstakes, but what if he couldn’t do that?</p>

<p>If we also had to assess uncertainty in each of the parameters mentioned above (say we aren’t certain about \(\alpha\) or \(p\)), or if we had uncertainty in the generative model itself, then the situation becomes much closer to claims contributing essentially no evidence. Here, I also made the basic assumption that the lying was one-way! This is also generally not the case and would significantly change how we update our beliefs given a claim.</p>
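<p>One way to account for uncertainty in the parameters is to average the likelihood of a false claim over a prior before normalizing. The sketch below does this by simple Monte Carlo. The Beta priors on \(\alpha\) and \(p\), the fixed \(q\), and the function name <code>marginal_posterior_won</code> are all assumptions chosen for illustration, not something derived above.</p>

```python
import random

def marginal_posterior_won(psi, q, n_draws=20000, seed=0):
    """P(z=1 | x=1) with alpha and p averaged over assumed Beta priors."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_draws):
        alpha = rng.betavariate(2, 2)  # assumed prior on exposure to fake sweepstakes
        p = rng.betavariate(8, 2)      # assumed prior on truthfulness
        # P(x=1 | z=0) for this draw of the parameters.
        total += alpha * q + (1.0 - alpha) * (1.0 - p)
    false_claim = total / n_draws      # Monte Carlo estimate of the marginal P(x=1 | z=0)
    return psi / (psi + (1.0 - psi) * false_claim)

print(marginal_posterior_won(psi=0.01, q=0.2))
```

<p>The design choice here is to average the likelihood rather than the posterior, since the parameters are nuisance variables that should be integrated out before applying Bayes’ rule.</p>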

<p>YouTuber <strong>A</strong> was right: the evidential value of claims is incredibly dependent on our prior knowledge of the subject. But I do not know if I would strictly phrase it as “claims not being evidence.” It is more like this: without strong background and domain knowledge on a given subject, claims themselves should not have much sway on our beliefs, as they are extremely weak measurements of the truth.</p>]]></content><author><name></name></author><category term="Bayesian" /><category term="Probability" /><category term="probability" /><summary type="html"><![CDATA[I recently watched a back and forth between two YouTubers who generally engage in discussions on philosophy, epistemology, and debate. The topic was about whether or not claims are evidence. One of the YouTubers (A) adamantly claims that claims are not evidence, while the other (B) says that this is clearly not the case. In one moment, B brought up an example. He considered the case of a friend claiming that he has bought a soccer ball. To B, this is clearly evidence of his friend having bought a soccer ball. A points out that it is not that simple and that when he treats this as evidence he is using a LOT of background knowledge and circumstantial evidence to come to that conclusion, even if he doesn’t realize it. I think A is on the right track here and that B is being a little simplistic, but the language of this discussion is not well suited for the topic. I think some of the language of probability will help us here, because although logical syllogisms of the form \[C \implies A; \quad C \therefore A\] are useful in philosophy, real life deals in degrees of uncertainty. Claims can be expressed as a kind of measurement in a latent variable problem. Let’s say we are interested in the truth of some claim \(z\). Without loss of generality, let us consider \(z\) as a binary variable, either taking on values \(z=0\) or \(z=1\). 
We can observe some data that (hopefully) indirectly measures \(z\), which we will denote as the vector \(\mathbf{x}\), which contains potentially multiple kinds of measurements \(x_i\) for \(i \in 1,\dots,n\). We want to understand \(\mathbf{x}\) as arising from some kind of emission process from \(z\). Let us consider the soccer ball example, with \(z=0\) denoting that the friend has not bought the soccer ball and \(z=1\) denoting that our friend has bought the ball. We will consider our measurement of this latent variable \(x\) to be one-dimensional, with \(x=0\) meaning that our friend claims he has not bought a soccer ball and \(x=1\) meaning he claims he has bought a soccer ball. If our friend is perfectly truthful, then our result is clear and \(x=z\). However, there are a number of things that could complicate this. The easiest thing is to question our friend’s truthfulness. Suppose that our friend is embarrassed about his soccer interests: he will never falsely claim to buy a ball, but may lie about not buying a ball. Let us consider that he lies about his soccer interests with frequency \(1-p\). Suddenly, our observation \(x\) does not perfectly correspond to \(z\). Instead, \(x\) is emitted by \(z\) according to the process \[x \sim \text{Bernoulli}(p z).\] That is, when \(z=1\) we get \(x=1\) with probability \(p\), and when \(z=0\) we get \(x=0\) with probability \(1\). If we are interested in the probability that \(z\) is true, we need to consider the conditional probability of \(z\) given \(x\), denoted \(P(z \mid x)\). 
By Bayes’ rule, the probability of \(z\) given some observation \(x\) satisfies \[\begin{aligned} P(z=0 \mid x) &amp;\propto P(z=0)P(x \mid z=0) = \begin{cases} (1-\psi) &amp; \text{if } x=0 \\ 0 &amp; \text{if } x=1 \end{cases} \\\\ P(z=1 \mid x) &amp;\propto P(z=1)P(x \mid z=1) = \begin{cases} \psi (1-p) &amp; \text{if } x=0 \\ \psi p &amp; \text{if } x=1 \end{cases} \end{aligned}\] Notice that an extra term has appeared here, \(\psi\). This is the prior probability that our friend (or maybe anyone, depending on how you’d like to set it up) would buy a soccer ball. If we consider the case where \(x=0\), that is, he has claimed that he has not bought a soccer ball, we get a probability that he has bought the soccer ball equal to \[P(z=1 \mid x=0) = \frac{\psi (1-p)}{(1-\psi)+\psi (1-p)} = \frac{\psi - \psi p}{1 - \psi p}\] By this very simple introduction of untruthfulness, we have injected some very serious assumptions. If we do not have perfect correspondence of \(x\) to \(z\), we now need to rely not only on the truthfulness of our friend, \(p\), but also on how reasonable it is that he would buy a soccer ball in the first place, \(\psi\). Notice that under this one-way model a claim of \(x=1\) is conclusive, because our friend never falsely claims a purchase. As soon as we allow even a small chance of a false claim, however, \(\psi\) matters a great deal: if \(\psi\) is extremely low, then our friend saying he has bought a soccer ball does very little to convince us. That is, even if our friend is the most truthful person we’ve ever met, if buying a soccer ball is extremely rare then we would probably conclude he’s lying or mistaken. If we change the soccer ball example to our friend saying he bought a purple alien on the black market, all of a sudden not only do we barely change our belief about \(z\), but we may also update our beliefs about \(p\), the truthfulness of our friend. Things get even more complicated when our friend could be either untruthful or just gullible (or both). 
Let us consider the case where we are interested in whether \(z=1\), but there is some distractor \(\tilde z\) that, when \(\tilde z = 1\), can lead our friend into thinking \(z=1\). Suppose in this case our friend believes he has won some sweepstakes from a random email he received, but we know that plenty of these fake emails circulate. In this case we will say that our friend may lie about winning, resulting in \(x=1\) when \(z=0\) with probability \((1-p)\), or that he may have fallen for a false sweepstakes \(\tilde z = 1\) with probability \(q\). We will assume, for simplicity, that \(x=1\) when \(z=1\); that is, our friend would always tell us if he had actually won a sweepstakes (maybe he is down on his luck and would be too excited not to share). This relatively simple real-life problem introduces a lot of different variables into our system. Writing \(\alpha\) for the frequency with which our friend is exposed to fraudulent sweepstakes, we now have unnormalized probabilities \[\begin{aligned} P(z=0 \mid x) &amp;\propto P(z=0)P(x \mid z=0) \\ &amp;= \alpha P(z=0)P(x \mid z=0,\tilde z=1) + (1-\alpha)P(z=0)P(x \mid z=0,\tilde z=0) \\ &amp;= \begin{cases} (1-\psi)(\alpha (1-q) + (1-\alpha)p) &amp; \text{if } x=0 \\ (1-\psi)(\alpha q + (1-\alpha)(1-p)) &amp; \text{if } x=1 \end{cases} \\\\ P(z=1 \mid x) &amp;\propto P(z=1)P(x \mid z=1) = \begin{cases} 0 &amp; \text{if } x=0 \\ \psi &amp; \text{if } x=1 \end{cases} \end{aligned}\] The normalized probability that our friend has actually won, given that he makes the claim, is then \[P(z=1 \mid x=1) = \frac{\psi}{\psi + \alpha q(1-\psi) + (1-\alpha)(1-p)(1-\psi)}.\] Notice that we now have a mixture of processes happening here. Our friend is exposed to fraudulent sweepstakes at a frequency \(\alpha\), falls for them with probability \(q\), and otherwise lies with probability \(1-p\). Think about how simple this example was. The situation is not particularly complex, and yet we suddenly have a lot of factors to consider. I have actually simplified this a good amount by allowing our friend to perfectly detect true sweepstakes, but what if he couldn’t do that? 
If we also had to assess uncertainty in each of the parameters mentioned above (say we aren’t certain about \(\alpha\) or \(p\)), or if we had uncertainty in the generative model itself, then the situation becomes much closer to claims contributing essentially no evidence. Here, I also made the basic assumption that the lying was one-way! This is also generally not the case and would significantly change how we update our beliefs given a claim. YouTuber A was right: the evidential value of claims is incredibly dependent on our prior knowledge of the subject. But I do not know if I would strictly phrase it as “claims not being evidence.” It is more like this: without strong background and domain knowledge on a given subject, claims themselves should not have much sway on our beliefs, as they are extremely weak measurements of the truth.]]></summary></entry><entry><title type="html">Sequential Neural Likelihood Estimation with C++</title><link href="https://wesley-demontigny.github.io/mcmc/neural%20networks/2025/12/22/SNLE_in_Cpp.html" rel="alternate" type="text/html" title="Sequential Neural Likelihood Estimation with C++" /><published>2025-12-22T00:00:00+00:00</published><updated>2025-12-22T00:00:00+00:00</updated><id>https://wesley-demontigny.github.io/mcmc/neural%20networks/2025/12/22/SNLE_in_Cpp</id><content type="html" xml:base="https://wesley-demontigny.github.io/mcmc/neural%20networks/2025/12/22/SNLE_in_Cpp.html"><![CDATA[<script type="text/javascript" id="MathJax-script" async="" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js">
</script>

<p>A recent pre-print by <a href="https://arxiv.org/abs/2510.12976">Blassel et al. (2025)</a> renewed my interest in simulation-based inference. In this post, I revisit my <a href="https://wesley-demontigny.github.io/mcmc/neural%20networks/2025/05/22/SNLE_Lotka_Volterra.html">sequential neural likelihood estimator</a> for the partially observed stochastic Lotka–Volterra model, reimplemented in C++ using LibTorch and Boost. I assume familiarity with that earlier post and with SNLE in general, and concentrate here on providing the C++ implementation. The model itself is nearly identical, with the exception that predator birth is now directly tied to predation in the Gillespie simulator. This time I also ran Markov chain Monte Carlo with the neural likelihood to see how well the posterior distributions recover the parameters \(\beta_{birth}\), \(\beta_{predation}\), and \(\beta_{death}\). The data were simulated with \((\beta_{birth},\beta_{predation}, \beta_{death}) = (0.75,0.01,0.9)\). The posterior mass concentrates around the true parameters, indicating that our neural likelihood captures the relevant structure of the simulator for these parameter choices.
<img src="/assets/lotka_volterra_posterior.png" alt="Neural Lotka–Volterra Dynamics" /></p>
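
<p>As a reference for the listing below, the three reaction channels of the Gillespie simulator can be written out explicitly. This is a sketch read off from the code; the symbols \(X\) (prey count), \(Y\) (predator count), and \(\lambda\) (total event rate) are notation introduced here for illustration. \[\begin{aligned} \text{prey birth:} \quad &amp; X \to X+1 &amp; \text{at rate } \beta_{birth} X \\ \text{predation:} \quad &amp; (X, Y) \to (X-1,\, Y+1) &amp; \text{at rate } \beta_{predation} X Y \\ \text{predator death:} \quad &amp; Y \to Y-1 &amp; \text{at rate } \beta_{death} Y \end{aligned}\] The waiting time to the next event is exponential with rate \(\lambda = \beta_{birth} X + \beta_{predation} X Y + \beta_{death} Y\), and each channel fires with probability equal to its rate divided by \(\lambda\). The code adds a floor of \(0.01\) to each rate before simulating.</p>

<p>The MCMC step in the listing updates one parameter at a time with a multiplicative proposal. Writing \(\hat{L}\) for the neural likelihood and \(\pi\) for the product of exponential priors (both symbols are assumptions made here for illustration), the proposal and acceptance probability are \[\theta_k' = s\,\theta_k, \quad s = e^{1.5(u - 0.5)}, \quad u \sim \text{Uniform}(0,1), \qquad A = \min\left(1,\; s\,\frac{\hat{L}(\mathbf{x} \mid \theta')\,\pi(\theta')}{\hat{L}(\mathbf{x} \mid \theta)\,\pi(\theta)}\right).\] The factor \(s\) is the Hastings correction for proposing on the log scale, which is why the code adds <code>std::log(scalingFactor)</code> to the log ratio.</p>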

<h3 id="cmakeliststxt">CMakeLists.txt</h3>
<pre><code class="language-CMake">cmake_minimum_required(VERSION 3.20)
project(SNLE-LV LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 23)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

set(CMAKE_PREFIX_PATH "/libtorch/")
find_package(Torch REQUIRED)
find_package(Boost REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(snle_lv snle_lv.cpp)
target_link_libraries(snle_lv PRIVATE "${TORCH_LIBRARIES}" Boost::boost)
</code></pre>
<h3 id="snle_lvcpp">snle_lv.cpp</h3>
<pre><code class="language-C++">#include &lt;iostream&gt;
#include &lt;fstream&gt;
#include &lt;torch/torch.h&gt;
#include &lt;boost/random/mersenne_twister.hpp&gt;
#include &lt;boost/random/exponential_distribution.hpp&gt;
#include &lt;boost/random/uniform_01.hpp&gt;
#include &lt;array&gt;
#include &lt;cmath&gt;
#include &lt;cstdlib&gt;
#include &lt;tuple&gt;

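/**
 * @brief This is the conditional autoregressive normalizing flow we will use to estimate the
 * likelihood. An embedding layer takes in the conditioning variables (the three rate parameters and
 * the previous prey and predator counts), and each flow block applies two affine transforms to the
 * two-dimensional data: one conditioned on the embedding alone, and one conditioned on the
 * embedding and the other, already transformed, dimension.
 */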
struct ConditionalAutoregressiveFlow : torch::nn::Module {
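    ConditionalAutoregressiveFlow(void) = delete;
    // No default argument here: it would make a zero-argument call ambiguous with the deleted constructor above.
    explicit ConditionalAutoregressiveFlow(int nF) : numFlows(nF),
                                                     l1e(register_module("l1e", torch::nn::Linear(5, 16))) {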

        register_module("flow1Lin", flow1Lin);
        register_module("flow2Lin", flow2Lin);
        register_module("flow1Param", flow1Param);
        register_module("flow2Param", flow2Param);
        
        for(int i = 0; i &lt; numFlows; i++){
            flow1Lin-&gt;push_back(torch::nn::Linear(16, 16));
            flow2Lin-&gt;push_back(torch::nn::Linear(17, 16));
            flow1Param-&gt;push_back(torch::nn::Linear(16, 2));
            flow2Param-&gt;push_back(torch::nn::Linear(16, 2));
        }
    }

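    /**
     * @brief Sends data in the normalizing direction and returns the per-row NLL under a
     * standard-normal base distribution (up to an additive constant), including the log-Jacobian terms.
     * Expects rows of [3 rate parameters, 2 previous counts, 2 multiplicative increases]; all entries must be positive.
     */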
    torch::Tensor forward(torch::Tensor augmentedData){
        auto e = torch::log(augmentedData.index({torch::indexing::Slice(), torch::indexing::Slice(0, 5)}).clone());
        auto data = torch::log(augmentedData.index({torch::indexing::Slice(), torch::indexing::Slice(5, 7)}).clone());
        int conditionalDim = 1;
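        // Log-Jacobian of the log transform: d(log x)/dx = 1/x, so log|det| = -sum(log x).
        // This term does not depend on the network weights or the rate parameters, so it only shifts the NLL by a constant.
        torch::Tensor jacobianSum = -torch::sum(data, -1);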

        e = torch::relu(l1e(e));

        for(int i = 0; i &lt; numFlows; i++){
            // Handle Non-Conditional Affine
            auto localEmbed = torch::relu(flow1Lin[i]-&gt;as&lt;torch::nn::Linear&gt;()-&gt;forward(e));
            auto p1 = flow1Param[i]-&gt;as&lt;torch::nn::Linear&gt;()-&gt;forward(localEmbed);
            auto scale1 = torch::softplus(p1.index({torch::indexing::Slice(), 1}));
            auto shift1 = p1.index({torch::indexing::Slice(), 0});
            auto nonCondView = data.index({torch::indexing::Slice(), conditionalDim ^ 1});
            nonCondView.div_(scale1);
            nonCondView.sub_(shift1);
            jacobianSum -= torch::log(scale1); 

            // Handle Conditional Affine
            localEmbed = torch::relu(
                flow2Lin[i]-&gt;as&lt;torch::nn::Linear&gt;()-&gt;forward(
                    torch::concat({e, nonCondView.unsqueeze(-1)}, -1)
                )
            );
            auto p2 = flow2Param[i]-&gt;as&lt;torch::nn::Linear&gt;()-&gt;forward(localEmbed);
            auto scale2 = torch::softplus(p2.index({torch::indexing::Slice(), 1}));
            auto shift2 = p2.index({torch::indexing::Slice(), 0});
            auto condView = data.index({torch::indexing::Slice(), conditionalDim});
            condView.div_(scale2);
            condView.sub_(shift2);
            jacobianSum -= torch::log(scale2);

            conditionalDim ^= 1; // Flip Conditionality
        }

        // Standard-normal base log-density, dropping the constant -0.5 * log(2 * pi) per dimension
        data = -0.5 * data * data;
        data = torch::sum(data, -1);
        if(torch::any(torch::isnan(data)).item&lt;bool&gt;()){
            std::cout &lt;&lt; augmentedData.index({torch::nonzero(torch::isnan(data)).squeeze()}).slice(-1, -4) &lt;&lt; std::endl;
            std::cout &lt;&lt; "Detected NaNs! Dumping all values ^" &lt;&lt; std::endl;
            std::exit(1);
        }
        return -1.0 * (data + jacobianSum); // Apply the Jacobian and return NLL
    }

    // @todo Add sampling function

    // Note: this overload hides torch::nn::Module::train(bool), which toggles training mode.
    void train(int epochs, torch::Tensor data, double threshold = 1e-6){
        torch::optim::Adam optim(
            this-&gt;parameters(),
            torch::optim::AdamOptions(1e-3).betas(std::make_tuple(0.5, 0.5))
        );

        int patience = 0;
        double lastBestLoss = INFINITY;

        for(int epoch = 1; epoch &lt;= epochs; epoch++){

            this-&gt;zero_grad();
            auto loss = this-&gt;forward(data).mean(); // Mean NLL Loss
            double lossScalar = loss.item&lt;double&gt;();
            loss.backward();
            torch::nn::utils::clip_grad_norm_(this-&gt;parameters(), 1.0); // Clip Gradients For Stability
            optim.step();

            if(lossScalar + threshold &lt; lastBestLoss){
                patience = 0;
                lastBestLoss = lossScalar;
            }
            else{
                patience++;
            }

            if(patience &gt; 50){
                std::cout &lt;&lt; "Early Stopping at " &lt;&lt; epoch &lt;&lt; std::endl;
                break;
            }

            if(epoch % 10 == 0)
                std::cout &lt;&lt; "Average Loss of {" &lt;&lt; lossScalar &lt;&lt; "} at Epoch " &lt;&lt; epoch &lt;&lt; std::endl;
        }
    }

    private:
        int numFlows;

        // Embedding Block (Generate an internal representation from the context)
        torch::nn::Linear l1e;
        // Flow Blocks
        torch::nn::ModuleList flow1Lin;
        torch::nn::ModuleList flow2Lin;
        torch::nn::ModuleList flow1Param;
        torch::nn::ModuleList flow2Param;
};

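/**
 * @brief Generate data under the stochastic Lotka-Volterra model.
 * Each returned row is [previous prey, previous predators, prey ratio, predator ratio],
 * where the ratios are current / previous counts at consecutive sampling times.
 */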
torch::Tensor generativeModel(boost::random::mt19937&amp; rng, double time, int numSamples, torch::Tensor initState, double birthRate, double predationRate, double deathRate){
    // Put a floor on these values
    birthRate += 1e-2;    
    predationRate += 1e-2;
    deathRate += 1e-2;

    torch::Tensor currentState = initState;
    double currentTime = 0.0;
    double sampleFrequency = time / (double)(numSamples);
    int currentSample = 1;
    torch::Tensor samples = torch::zeros({numSamples, 2});
    samples.index_put_({0}, initState);

    // Simulate with the Gillespie algorithm
    while(currentSample &lt; numSamples){

        double currentPrey = currentState[0].item&lt;double&gt;();
        double currentPredators = currentState[1].item&lt;double&gt;();
        if(currentPrey == 0.0 &amp;&amp; currentPredators == 0.0){ // Quit upon total extinction
            break;
        }

        double birthPoisson = birthRate * currentPrey;
        double predationPoisson = predationRate * currentPredators * currentPrey;
        double deathPoisson = deathRate * currentPredators;

        double exponentialRaceRate = birthPoisson + predationPoisson + deathPoisson;
        double waitingTime = boost::random::exponential_distribution&lt;double&gt;{exponentialRaceRate}(rng);

        currentTime += waitingTime;

        double randomUpdate = boost::random::uniform_01&lt;double&gt;{}(rng);
        if(randomUpdate &lt; birthPoisson / exponentialRaceRate){
            currentPrey++;
        }
        else if(randomUpdate &lt; (birthPoisson + predationPoisson) / exponentialRaceRate){
            currentPrey--;
            currentPredators++;
        }
        else{
            currentPredators--;
        }

        currentState.index_put_({0}, currentPrey);
        currentState.index_put_({1}, currentPredators);
            
        // Catch up on (potentially multiple) sampling events
        for(; currentTime &gt;= sampleFrequency * (double)currentSample &amp;&amp; currentSample &lt; numSamples; currentSample++){
            samples.index_put_({currentSample}, currentState);
        }
        
    }

    auto shiftedSamples = torch::roll(samples, 1, 0);
    auto multiplicativeIncrease = torch::divide(samples, shiftedSamples);
    multiplicativeIncrease = torch::where(torch::isnan(multiplicativeIncrease), 1e-6, multiplicativeIncrease);
    multiplicativeIncrease = torch::where(multiplicativeIncrease == 0.0, 1e-6, multiplicativeIncrease);
    shiftedSamples = torch::where(shiftedSamples == 0.0, 1e-6, shiftedSamples);

    auto augmentedData = torch::concat({shiftedSamples, multiplicativeIncrease}, -1);
    augmentedData = augmentedData.slice(0, 1, augmentedData.size(0));

    return augmentedData;
}

int main(int argc, char** argv){
    auto rng = boost::random::mt19937{1124};

    double initPrey = 60.0;
    double initPredator = 10.0;
    double truePreyBirth = 0.75;
    double truePredation = 0.01;
    double truePredatorDeath = 0.9;
    double simulationTime = 40;
    int numSamples = 200;
    auto initState = torch::tensor({initPrey, initPredator});

    std::cout &lt;&lt; "Simulating Data..." &lt;&lt; std::endl;
    auto simulatedData = generativeModel(rng, simulationTime, numSamples, initState.clone(), truePreyBirth, truePredation, truePredatorDeath);
    std::cout &lt;&lt; simulatedData &lt;&lt; std::endl;

    auto trainingData = torch::empty({0, 7});

    int numSNLEIterations = 4;
    int mcmcIterations = 20000;
    int samplingFrequency = 1000;
    int priorSamples = 100;
    int samplingIterations = 10;
    std::array&lt;double, 3&gt; lambdas = {3.0, 20.0, 3.0};
    auto unif = boost::random::uniform_01();

    for(int i = 1; i &lt;= numSNLEIterations; i++){
        std::cout &lt;&lt; "Starting SNLE Iteration " &lt;&lt; i &lt;&lt; std::endl;

        if(i &gt; 1){
            ConditionalAutoregressiveFlow neuralLikelihood(4);
            neuralLikelihood.train(5000, trainingData);
            torch::NoGradGuard noGrad; // The MCMC below only evaluates the flow, so skip autograd bookkeeping

            std::array&lt;double, 3&gt; currentParams = {
                boost::random::exponential_distribution{lambdas[0]}(rng), 
                boost::random::exponential_distribution{lambdas[1]}(rng), 
                boost::random::exponential_distribution{lambdas[2]}(rng)
            };
            double currentPrior = 0.0;
            for(int i = 0; i &lt; 3; i++){
                currentPrior += std::log(lambdas[i]) - lambdas[i] * currentParams[i];
            }
            
            auto simDim = simulatedData.size(0);
            double currentLL = -1.0 * neuralLikelihood.forward(
                torch::concat({
                    torch::full({simDim, 1}, currentParams[0]),
                    torch::full({simDim, 1}, currentParams[1]),
                    torch::full({simDim, 1}, currentParams[2]),
                    simulatedData
                }, 1)
            ).sum().item&lt;double&gt;();

            for(int j = 1; j &lt;= mcmcIterations; j++){

                for(int gibbsIter = 0; gibbsIter &lt; 3; gibbsIter++){
                    double scalingFactor = std::exp(1.5 * (unif(rng) - 0.5));
                    std::array&lt;double, 3&gt; newParams = currentParams;
                    newParams[gibbsIter] = currentParams[gibbsIter] * scalingFactor;
                    double newPrior = 0.0;
                    for(int i = 0; i &lt; 3; i++){
                        newPrior += std::log(lambdas[i]) - lambdas[i] * newParams[i];
                    }

                    double newLL = -1.0 * neuralLikelihood.forward(
                        torch::concat({
                            torch::full({simDim, 1}, newParams[0]),
                            torch::full({simDim, 1}, newParams[1]),
                            torch::full({simDim, 1}, newParams[2]),
                            simulatedData
                        }, 1)
                    ).sum().item&lt;double&gt;();

                    // log(scalingFactor) is the Hastings correction for the multiplicative (log-scale) proposal
                    double logRatio = std::log(scalingFactor) + newLL - currentLL + newPrior - currentPrior;
                    if(std::log(unif(rng)) &lt;= logRatio){
                        currentPrior = newPrior;
                        currentLL = newLL;
                        currentParams = newParams;
                    }
                }

                if(j % samplingFrequency == 0){
                    for(int simIter = 0; simIter &lt; samplingIterations; simIter++){
                        auto newData = generativeModel(rng, simulationTime/4.0, numSamples/4, initState.clone(), currentParams[0], currentParams[1], currentParams[2]);
                        auto newDataDim = newData.size(0);
                        newData = torch::concat({
                            torch::full({newDataDim, 1}, currentParams[0]),
                            torch::full({newDataDim, 1}, currentParams[1]),
                            torch::full({newDataDim, 1}, currentParams[2]),
                            newData
                        }, 1);

                        trainingData = torch::concat({trainingData, newData}, 0);
                    }

                    std::cout &lt;&lt; j &lt;&lt; " ( " &lt;&lt; currentLL &lt;&lt; " ):\tBirth: " &lt;&lt; currentParams[0] &lt;&lt; "\t Predation: " &lt;&lt; currentParams[1] &lt;&lt; "\tDeath: " &lt;&lt; currentParams[2] &lt;&lt; std::endl;
                }
            }
        }
        else {
            for(int j = 1; j &lt;= priorSamples; j++){
                std::array&lt;double, 3&gt; currentParams = {
                    boost::random::exponential_distribution{lambdas[0]}(rng), 
                    boost::random::exponential_distribution{lambdas[1]}(rng), 
                    boost::random::exponential_distribution{lambdas[2]}(rng)
                };

                for(int simIter = 0; simIter &lt; samplingIterations; simIter++){
                    auto newData = generativeModel(rng, simulationTime/4.0, numSamples/4, initState.clone(), currentParams[0], currentParams[1], currentParams[2]);
                    auto newDataDim = newData.size(0);
                    newData = torch::concat({
                        torch::full({newDataDim, 1}, currentParams[0]),
                        torch::full({newDataDim, 1}, currentParams[1]),
                        torch::full({newDataDim, 1}, currentParams[2]),
                        newData
                    }, 1);

                    trainingData = torch::concat({trainingData, newData}, 0);
                }

                std::cout &lt;&lt; j &lt;&lt; "\tBirth: " &lt;&lt; currentParams[0] &lt;&lt; "\t Predation: " &lt;&lt; currentParams[1] &lt;&lt; "\tDeath: " &lt;&lt; currentParams[2] &lt;&lt; std::endl;
            }
        }
    }

    // Run MCMC one final time to get a trace file
    ConditionalAutoregressiveFlow neuralLikelihood(4);
    neuralLikelihood.train(5000, trainingData);
    torch::NoGradGuard noGrad; // The MCMC below only evaluates the flow, so skip autograd bookkeeping

    int mcmcSamplingFreq = 100;
    int finalMcmcIter = 50000;
    std::ofstream outputFile("./output.trace");
    outputFile &lt;&lt; "Iteration\tPosterior\tBirthRate\tPredationRate\tDeathRate" &lt;&lt; std::endl;

    std::array&lt;double, 3&gt; currentParams = {
        boost::random::exponential_distribution{lambdas[0]}(rng), 
        boost::random::exponential_distribution{lambdas[1]}(rng), 
        boost::random::exponential_distribution{lambdas[2]}(rng)
    };
    double currentPrior = 0.0;
    for(int i = 0; i &lt; 3; i++){
        currentPrior += std::log(lambdas[i]) - lambdas[i] * currentParams[i];
    }
    
    auto simDim = simulatedData.size(0);
    double currentLL = -1.0 * neuralLikelihood.forward(
        torch::concat({
            torch::full({simDim, 1}, currentParams[0]),
            torch::full({simDim, 1}, currentParams[1]),
            torch::full({simDim, 1}, currentParams[2]),
            simulatedData
        }, 1)
    ).sum().item&lt;double&gt;();

    for(int j = 1; j &lt;= finalMcmcIter; j++){

        for(int gibbsIter = 0; gibbsIter &lt; 3; gibbsIter++){
            double scalingFactor = std::exp(1.5 * (unif(rng) - 0.5));
            std::array&lt;double, 3&gt; newParams = currentParams;
            newParams[gibbsIter] = currentParams[gibbsIter] * scalingFactor;
            double newPrior = 0.0;
            for(int i = 0; i &lt; 3; i++){
                newPrior += std::log(lambdas[i]) - lambdas[i] * newParams[i];
            }

            double newLL = -1.0 * neuralLikelihood.forward(
                torch::concat({
                    torch::full({simDim, 1}, newParams[0]),
                    torch::full({simDim, 1}, newParams[1]),
                    torch::full({simDim, 1}, newParams[2]),
                    simulatedData
                }, 1)
            ).sum().item&lt;double&gt;();

            // log(scalingFactor) is the Hastings correction for the multiplicative (log-scale) proposal
            double logRatio = std::log(scalingFactor) + newLL - currentLL + newPrior - currentPrior;
            if(std::log(unif(rng)) &lt;= logRatio){
                currentPrior = newPrior;
                currentLL = newLL;
                currentParams = newParams;
            }
        }

        if(j % mcmcSamplingFreq == 0){
            outputFile &lt;&lt; j &lt;&lt; "\t" &lt;&lt; currentLL + currentPrior &lt;&lt; "\t" &lt;&lt; currentParams[0] &lt;&lt; "\t" &lt;&lt; currentParams[1] &lt;&lt; "\t" &lt;&lt; currentParams[2] &lt;&lt; std::endl;
            std::cout &lt;&lt; j &lt;&lt; " ( " &lt;&lt; currentLL &lt;&lt; " ):\tBirth: " &lt;&lt; currentParams[0] &lt;&lt; "\t Predation: " &lt;&lt; currentParams[1] &lt;&lt; "\tDeath: " &lt;&lt; currentParams[2] &lt;&lt; std::endl;
        }
    }


    return 0;

}
</code></pre>]]></content><author><name></name></author><category term="MCMC" /><category term="Neural Networks" /><category term="mcmc" /><category term="neural networks" /><category term="c++" /><summary type="html"><![CDATA[A recent pre-print by Blassel et al. (2025) renewed my interest in simulation-based inference. In this post, I revisit my sequential neural likelihood estimator for the partially observed stochastic Lotka–Volterra model, reimplemented in C++ using LibTorch and Boost. I assume familiarity with that earlier post and with SNLE in general, and concentrate here on providing the C++ implementation. The model itself is nearly identical, with the exception that predator birth is now directly tied to predation in the Gillespie simulator. This time I also ran Markov chain Monte Carlo with the neural likelihood to see how well the posterior distributions recover the parameters \(\beta_{birth}\), \(\beta_{predation}\), and \(\beta_{death}\). The data were simulated with \((\beta_{birth},\beta_{predation}, \beta_{death}) = (0.75,0.01,0.9)\). The posterior mass concentrates around the true parameters, indicating that our neural likelihood captures the relevant structure of the simulator for these parameter choices. 
CMakeLists.txt cmake_minimum_required(VERSION 3.20) project(SNLE-LV LANGUAGES CXX) set(CMAKE_CXX_STANDARD 23) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_PREFIX_PATH "/libtorch/") find_package(Torch REQUIRED) find_package(Boost REQUIRED) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") add_executable(snle_lv snle_lv.cpp) target_link_libraries(snle_lv "${TORCH_LIBRARIES}") target_link_libraries(snle_lv Boost::boost) snle_lv.cpp #include &lt;iostream&gt; #include &lt;fstream&gt; #include &lt;torch/torch.h&gt; #include &lt;boost/random/mersenne_twister.hpp&gt; #include &lt;boost/random/exponential_distribution.hpp&gt; #include &lt;boost/random/uniform_01.hpp&gt; #include &lt;boost/random/exponential_distribution.hpp&gt; /** * @brief This is the conditional autoregressive normalizing flow we will use to estimate the * likelihood. The general layout here is that there are some embedding layer that takes in the conditionals * of the probability distribution */ struct ConditionalAutoregressiveFlow : torch::nn::Module { ConditionalAutoregressiveFlow(void) = delete; ConditionalAutoregressiveFlow(int nF = 3) : numFlows(nF), l1e(register_module("l1e", torch::nn::Linear(5, 16))) { register_module("flow1Lin", flow1Lin); register_module("flow2Lin", flow2Lin); register_module("flow1Param", flow1Param); register_module("flow2Param", flow2Param); for(int i = 0; i &lt; numFlows; i++){ flow1Lin-&gt;push_back(torch::nn::Linear(16, 16)); flow2Lin-&gt;push_back(torch::nn::Linear(17, 16)); flow1Param-&gt;push_back(torch::nn::Linear(16, 2)); flow2Param-&gt;push_back(torch::nn::Linear(16, 2)); } } /** * @brief Sends data in the normalizing direction and returns the NLL from the Base distribution */ torch::Tensor forward(torch::Tensor augmentedData){ auto e = torch::log(augmentedData.index({torch::indexing::Slice(), torch::indexing::Slice(0, 5)}).clone()); auto data = torch::log(augmentedData.index({torch::indexing::Slice(), 
torch::indexing::Slice(5, 7)}).clone()); int conditionalDim = 1; torch::Tensor jacobianSum = torch::sum(data, -1); e = torch::relu(l1e(e)); for(int i = 0; i &lt; numFlows; i++){ // Handle Non-Conditional Affine auto localEmbed = torch::relu(flow1Lin[i]-&gt;as&lt;torch::nn::Linear&gt;()-&gt;forward(e)); auto p1 = flow1Param[i]-&gt;as&lt;torch::nn::Linear&gt;()-&gt;forward(localEmbed); auto scale1 = torch::softplus(p1.index({torch::indexing::Slice(), 1})); auto shift1 = p1.index({torch::indexing::Slice(), 0}); auto nonCondView = data.index({torch::indexing::Slice(), conditionalDim ^ 1}); nonCondView.div_(scale1); nonCondView.sub_(shift1); jacobianSum -= torch::log(scale1); // Handle Conditional Affine localEmbed = torch::relu( flow2Lin[i]-&gt;as&lt;torch::nn::Linear&gt;()-&gt;forward( torch::concat({e, nonCondView.unsqueeze(-1)}, -1) ) ); auto p2 = flow2Param[i]-&gt;as&lt;torch::nn::Linear&gt;()-&gt;forward(localEmbed); auto scale2 = torch::softplus(p2.index({torch::indexing::Slice(), 1})); auto shift2 = p2.index({torch::indexing::Slice(), 0}); auto condView = data.index({torch::indexing::Slice(), conditionalDim}); condView.div_(scale2); condView.sub_(shift2); jacobianSum -= torch::log(scale2); conditionalDim ^= 1; // Flip Conditionality } torch::apply([](auto&amp; datum){ datum = -0.5 * (datum*datum); }, data); data = torch::sum(data, -1); if(torch::any(torch::isnan(data)).item&lt;bool&gt;()){ std::cout &lt;&lt; augmentedData.index({torch::nonzero(torch::isnan(data)).squeeze()}).slice(-1, -4) &lt;&lt; std::endl; std::cout &lt;&lt; "Detected NaNs! 
Dumping all values ^" &lt;&lt; std::endl; std::exit(1); } return -1.0 * (data + jacobianSum); // Apply the Jacobian and return NLL } // @todo Add sampling function void train(int epochs, torch::Tensor data, double threshold = 1e-6){ torch::optim::Adam optim( this-&gt;parameters(), torch::optim::AdamOptions(1e-3).betas(std::make_tuple(0.5, 0.5)) ); int patience = 0; double lastBestLoss = INFINITY; for(int epoch = 1; epoch &lt;= epochs; epoch++){ this-&gt;zero_grad(); auto loss = this-&gt;forward(data).mean(); // Mean NLL Loss double lossScalar = loss.item&lt;double&gt;(); loss.backward(); torch::nn::utils::clip_grad_norm_(this-&gt;parameters(), 1.0); // Clip Gradients For Stability optim.step(); if(lossScalar + threshold &lt; lastBestLoss){ patience = 0; lastBestLoss = lossScalar; } else{ patience++; } if(patience &gt; 50){ std::cout &lt;&lt; "Early Stopping at " &lt;&lt; epoch &lt;&lt; std::endl; break; } if(epoch % 10 == 0) std::cout &lt;&lt; "Average Loss of {" &lt;&lt; lossScalar &lt;&lt; "} at Epoch " &lt;&lt; epoch &lt;&lt; std::endl; } } private: int numFlows; // Embedding Block (Generate an internal representation from the context) torch::nn::Linear l1e; // Flow Blocks torch::nn::ModuleList flow1Lin; torch::nn::ModuleList flow2Lin; torch::nn::ModuleList flow1Param; torch::nn::ModuleList flow2Param; }; /** * @brief Generate data under the stochastic Lotka-Voltera model. 
*/ torch::Tensor generativeModel(boost::random::mt19937&amp; rng, double time, int numSamples, torch::Tensor initState, double birthRate, double predationRate, double deathRate){ // Put a floor on these values birthRate += 1e-2; predationRate += 1e-2; deathRate += 1e-2; torch::Tensor currentState = initState; double currentTime = 0.0; double sampleFrequency = time / (double)(numSamples); int currentSample = 1; torch::Tensor samples = torch::zeros({numSamples, 2}); samples.index_put_({0}, initState); // Simulate with the Gillespie algorithm while(currentSample &lt; numSamples){ double currentPrey = currentState[0].item&lt;double&gt;(); double currentPredators = currentState[1].item&lt;double&gt;(); if(currentPrey == 0.0 &amp;&amp; currentPredators == 0.0){ // Quit upon total extinction break; } double birthPoisson = birthRate * currentPrey; double predationPoisson = predationRate * currentPredators * currentPrey; double deathPoisson = deathRate * currentPredators; double exponentialRaceRate = birthPoisson + predationPoisson + deathPoisson; double waitingTime = boost::random::exponential_distribution&lt;double&gt;{exponentialRaceRate}(rng); currentTime += waitingTime; double randomUpdate = boost::random::uniform_01&lt;double&gt;{}(rng); if(randomUpdate &lt; birthPoisson / exponentialRaceRate){ currentPrey++; } else if(randomUpdate &lt; (birthPoisson + predationPoisson) / exponentialRaceRate){ currentPrey--; currentPredators++; } else{ currentPredators--; } currentState.index_put_({0}, currentPrey); currentState.index_put_({1}, currentPredators); // Catch up on (potentiall multiple) sampling events for(; currentTime &gt;= sampleFrequency * (double)currentSample &amp;&amp; currentSample &lt; numSamples; currentSample++){ samples.index_put_({currentSample}, currentState); } } auto shiftedSamples = torch::roll(samples, 1, 0); auto multiplicativeIncrease = torch::divide(samples, shiftedSamples); multiplicativeIncrease = torch::where(torch::isnan(multiplicativeIncrease), 
1e-6, multiplicativeIncrease); multiplicativeIncrease = torch::where(multiplicativeIncrease == 0.0, 1e-6, multiplicativeIncrease); shiftedSamples = torch::where(shiftedSamples == 0.0, 1e-6, shiftedSamples); auto augmentedData = torch::concat({shiftedSamples, multiplicativeIncrease}, -1); augmentedData = augmentedData.slice(0, 1, augmentedData.size(0)); return augmentedData; } int main(int argc, char** argv){ auto rng = boost::random::mt19937{1124}; double initPrey = 60.0; double initPredator = 10.0; double truePreyBirth = 0.75; double truePredation = 0.01; double truePredatorDeath = 0.9; double simulationTime = 40; int numSamples = 200; auto initState = torch::tensor({60.0, 10.0}); std::cout &lt;&lt; "Simulating Data..." &lt;&lt; std::endl; auto simulatedData = generativeModel(rng, simulationTime, numSamples, initState.clone(), truePreyBirth, truePredation, truePredatorDeath); std::cout &lt;&lt; simulatedData &lt;&lt; std::endl; auto trainingData = torch::empty({0, 7});; int numSNLEIterations = 4; int mcmcIterations = 20000; int samplingFrequency = 1000; int priorSamples = 100; int samplingIterations = 10; std::array&lt;double, 3&gt; lambdas = {3.0, 20.0, 3.0}; auto unif = boost::random::uniform_01(); for(int i = 1; i &lt;= numSNLEIterations; i++){ std::cout &lt;&lt; "Starting SNLE Iteration " &lt;&lt; i &lt;&lt; std::endl; if(i &gt; 1){ ConditionalAutoregressiveFlow neuralLikelihood(4); neuralLikelihood.train(5000, trainingData); std::array&lt;double, 3&gt; currentParams = { boost::random::exponential_distribution{lambdas[0]}(rng), boost::random::exponential_distribution{lambdas[1]}(rng), boost::random::exponential_distribution{lambdas[2]}(rng) }; double currentPrior = 0.0; for(int i = 0; i &lt; 3; i++){ currentPrior += std::log(lambdas[i]) - lambdas[i] * currentParams[i]; } auto simDim = simulatedData.size(0); double currentLL = -1.0 * neuralLikelihood.forward( torch::concat({ torch::full({simDim, 1}, currentParams[0]), torch::full({simDim, 1}, currentParams[1]), 
torch::full({simDim, 1}, currentParams[2]), simulatedData }, 1) ).sum().item&lt;double&gt;(); for(int j = 1; j &lt;= mcmcIterations; j++){ for(int gibbsIter = 0; gibbsIter &lt; 3; gibbsIter++){ double scalingFactor = std::exp(1.5 * (unif(rng) - 0.5)); std::array&lt;double, 3&gt; newParams = currentParams; newParams[gibbsIter] = currentParams[gibbsIter] * scalingFactor; double newPrior = 0.0; for(int i = 0; i &lt; 3; i++){ newPrior += std::log(lambdas[i]) - lambdas[i] * newParams[i]; } double newLL = -1.0 * neuralLikelihood.forward( torch::concat({ torch::full({simDim, 1}, newParams[0]), torch::full({simDim, 1}, newParams[1]), torch::full({simDim, 1}, newParams[2]), simulatedData }, 1) ).sum().item&lt;double&gt;(); double logRatio = std::log(scalingFactor) + newLL - currentLL + newPrior - currentPrior; if(std::log(unif(rng)) &lt;= logRatio){ currentPrior = newPrior; currentLL = newLL; currentParams = newParams; } } if(j % samplingFrequency == 0){ for(int simIter = 0; simIter &lt; samplingIterations; simIter++){ auto newData = generativeModel(rng, simulationTime/4.0, numSamples/4, initState.clone(), currentParams[0], currentParams[1], currentParams[2]); auto newDataDim = newData.size(0); newData = torch::concat({ torch::full({newDataDim, 1}, currentParams[0]), torch::full({newDataDim, 1}, currentParams[1]), torch::full({newDataDim, 1}, currentParams[2]), newData }, 1); trainingData = torch::concat({trainingData, newData}, 0); } std::cout &lt;&lt; j &lt;&lt; " ( " &lt;&lt; currentLL &lt;&lt; " ):\tBirth: " &lt;&lt; currentParams[0] &lt;&lt; "\t Predation: " &lt;&lt; currentParams[1] &lt;&lt; "\tDeath: " &lt;&lt; currentParams[2] &lt;&lt; std::endl; } } } else { for(int j = 1; j &lt;= priorSamples; j++){ std::array&lt;double, 3&gt; currentParams = { boost::random::exponential_distribution{lambdas[0]}(rng), boost::random::exponential_distribution{lambdas[1]}(rng), boost::random::exponential_distribution{lambdas[2]}(rng) }; for(int simIter = 0; simIter &lt; 
samplingIterations; simIter++){ auto newData = generativeModel(rng, simulationTime/4.0, numSamples/4, initState.clone(), currentParams[0], currentParams[1], currentParams[2]); auto newDataDim = newData.size(0); newData = torch::concat({ torch::full({newDataDim, 1}, currentParams[0]), torch::full({newDataDim, 1}, currentParams[1]), torch::full({newDataDim, 1}, currentParams[2]), newData }, 1); trainingData = torch::concat({trainingData, newData}, 0); } std::cout &lt;&lt; j &lt;&lt; "\tBirth: " &lt;&lt; currentParams[0] &lt;&lt; "\t Predation: " &lt;&lt; currentParams[1] &lt;&lt; "\tDeath: " &lt;&lt; currentParams[2] &lt;&lt; std::endl; } } } // Run MCMC one final time to get a trace file ConditionalAutoregressiveFlow neuralLikelihood(4); neuralLikelihood.train(5000, trainingData); int mcmcSamplingFreq = 100; int finalMcmcIter = 50000; std::ofstream outputFile("./output.trace"); outputFile &lt;&lt; "Iteration\tPosterior\tBirthRate\tPredationRate\tDeathRate" &lt;&lt; std::endl; std::array&lt;double, 3&gt; currentParams = { boost::random::exponential_distribution{lambdas[0]}(rng), boost::random::exponential_distribution{lambdas[1]}(rng), boost::random::exponential_distribution{lambdas[2]}(rng) }; double currentPrior = 0.0; for(int i = 0; i &lt; 3; i++){ currentPrior += std::log(lambdas[i]) - lambdas[i] * currentParams[i]; } auto simDim = simulatedData.size(0); double currentLL = -1.0 * neuralLikelihood.forward( torch::concat({ torch::full({simDim, 1}, currentParams[0]), torch::full({simDim, 1}, currentParams[1]), torch::full({simDim, 1}, currentParams[2]), simulatedData }, 1) ).sum().item&lt;double&gt;(); for(int j = 1; j &lt;= finalMcmcIter; j++){ for(int gibbsIter = 0; gibbsIter &lt; 3; gibbsIter++){ double scalingFactor = std::exp(1.5 * (unif(rng) - 0.5)); std::array&lt;double, 3&gt; newParams = currentParams; newParams[gibbsIter] = currentParams[gibbsIter] * scalingFactor; double newPrior = 0.0; for(int i = 0; i &lt; 3; i++){ newPrior += std::log(lambdas[i]) - 
lambdas[i] * newParams[i]; } double newLL = -1.0 * neuralLikelihood.forward( torch::concat({ torch::full({simDim, 1}, newParams[0]), torch::full({simDim, 1}, newParams[1]), torch::full({simDim, 1}, newParams[2]), simulatedData }, 1) ).sum().item&lt;double&gt;(); double logRatio = std::log(scalingFactor) + newLL - currentLL + newPrior - currentPrior; if(std::log(unif(rng)) &lt;= logRatio){ currentPrior = newPrior; currentLL = newLL; currentParams = newParams; } } if(j % mcmcSamplingFreq == 0){ outputFile &lt;&lt; j &lt;&lt; "\t" &lt;&lt; currentLL + currentPrior &lt;&lt; "\t" &lt;&lt; currentParams[0] &lt;&lt; "\t" &lt;&lt; currentParams[1] &lt;&lt; "\t" &lt;&lt; currentParams[2] &lt;&lt; std::endl; std::cout &lt;&lt; j &lt;&lt; " ( " &lt;&lt; currentLL &lt;&lt; " ):\tBirth: " &lt;&lt; currentParams[0] &lt;&lt; "\t Predation: " &lt;&lt; currentParams[1] &lt;&lt; "\tDeath: " &lt;&lt; currentParams[2] &lt;&lt; std::endl; } } return 0; }]]></summary></entry><entry><title type="html">Modeling Stochastic Lotka-Volterra using Sequential Neural Likelihood Estimation</title><link href="https://wesley-demontigny.github.io/mcmc/neural%20networks/2025/05/22/SNLE_Lotka_Volterra.html" rel="alternate" type="text/html" title="Modeling Stochastic Lotka-Volterra using Sequential Neural Likelihood Estimation" /><published>2025-05-22T00:00:00+00:00</published><updated>2025-05-22T00:00:00+00:00</updated><id>https://wesley-demontigny.github.io/mcmc/neural%20networks/2025/05/22/SNLE_Lotka_Volterra</id><content type="html" xml:base="https://wesley-demontigny.github.io/mcmc/neural%20networks/2025/05/22/SNLE_Lotka_Volterra.html"><![CDATA[<script type="text/javascript" id="MathJax-script" async="" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js">
</script>

<p>Many interesting systems in science can’t be described by closed-form probability distributions. This makes them hard to analyze using classical statistical methods. One common example in biology is the stochastic Lotka–Volterra model, which describes the dynamics of random predator-prey interactions. It’s a simple model to simulate using the Gillespie algorithm, even though it lacks a closed-form likelihood function.</p>

<p>Briefly, this is how the simulation works:</p>
<ol>
  <li>Initialize a population of predators (\(a\)) and prey (\(b\)), along with per-capita rates of birth (\(\beta_\text{prey}, \beta_\text{predator}\)), death (\(\gamma_\text{prey}, \gamma_\text{predator}\)), and predation (\(\epsilon\)).</li>
  <li>Compute the event rates:
    <ul>
      <li>Prey birth: \(b \cdot \beta_\text{prey}\)</li>
      <li>Prey death: \(b \cdot \gamma_\text{prey}\)</li>
      <li>Predator birth: \(a \cdot b \cdot \beta_\text{predator}\)</li>
      <li>Predation: \(a \cdot b \cdot \epsilon\)</li>
      <li>Predator death: \(a \cdot \gamma_\text{predator}\)</li>
    </ul>
  </li>
  <li>Sum all event rates to get the total event rate \(r\), and draw a waiting time \(t \sim \text{Exponential}(r)\).</li>
  <li>Choose an event with probability proportional to its rate (e.g., the probability of predation is \(\frac{a \cdot b \cdot \epsilon}{r}\)), then update the system accordingly.</li>
  <li>Repeat steps 2–4 until a desired simulation time is reached.</li>
</ol>
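<p>The steps above can be sketched in a few lines of plain Python. This is only a minimal illustration, not the simulator used for the results in this post: the function name simulate_lotka_volterra, the rates dictionary and its keys, and the state updates attached to each event (for example, predation removing one prey) are assumptions made for the example.</p>

```python
import random


def simulate_lotka_volterra(predators, prey, rates, t_max, seed=0):
    """Gillespie simulation; returns a list of (time, predators, prey) tuples."""
    rng = random.Random(seed)
    a, b, t = predators, prey, 0.0
    trajectory = [(t, a, b)]
    # Assumed state change (d_predators, d_prey) caused by each event.
    changes = {
        "prey_birth": (0, 1),
        "prey_death": (0, -1),
        "predator_birth": (1, 0),
        "predation": (0, -1),
        "predator_death": (-1, 0),
    }
    while True:
        # Step 2: event rates for the current state.
        event_rates = {
            "prey_birth": b * rates["beta_prey"],
            "prey_death": b * rates["gamma_prey"],
            "predator_birth": a * b * rates["beta_predator"],
            "predation": a * b * rates["epsilon"],
            "predator_death": a * rates["gamma_predator"],
        }
        total = sum(event_rates.values())
        if total == 0.0:  # both populations are extinct, nothing can happen
            break
        # Step 3: exponential waiting time with rate equal to the total rate.
        t += rng.expovariate(total)
        if t > t_max:
            break
        # Step 4: pick an event with probability proportional to its rate.
        names = list(event_rates)
        weights = [event_rates[name] for name in names]
        event = rng.choices(names, weights=weights)[0]
        da, db = changes[event]
        a, b = a + da, b + db
        trajectory.append((t, a, b))
    return trajectory
```

<p>Because an event with rate zero is never selected, neither population can drop below zero in this sketch.</p>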

<p>Despite being easy to simulate, this model has no closed-form likelihood, which makes traditional parameter estimation difficult. However, if we simulate the model many times under a fixed set of parameters, we get an empirical distribution of outcomes. How often those outcomes resemble the observed data is essentially a score for how well the parameters explain it.</p>
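<p>To make the idea of scoring parameters by simulation concrete, here is a toy sketch. It does not use the Lotka–Volterra model: toy_simulator is a hypothetical stand-in that counts successes in ten random trials, and empirical_likelihood is an assumed helper name. The point is only that the fraction of simulations matching the observation behaves like a likelihood.</p>

```python
import random


def toy_simulator(p, rng, flips=10):
    """Hypothetical stand-in simulator: number of successes in `flips` trials."""
    return sum(rng.random() < p for _ in range(flips))


def empirical_likelihood(observed, p, num_sims=20000, seed=0):
    """Fraction of simulations under parameter p that reproduce the observation."""
    rng = random.Random(seed)
    hits = sum(toy_simulator(p, rng) == observed for _ in range(num_sims))
    return hits / num_sims
```

<p>Comparing empirical_likelihood(5, 0.5) with empirical_likelihood(5, 0.9) shows that the first parameter value explains an observation of 5 far better. The cost is tens of thousands of simulations for every parameter value we want to score.</p>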

<p>Running millions of simulations for every parameter setting isn’t feasible. Instead, we can train a neural network to approximate this probability distribution based on a finite number of simulations. There are many kinds of these techniques, but here I’ll focus on one: Sequential Neural Likelihood Estimation (SNLE).</p>

<p>SNLE uses normalizing flows to transform a simple base distribution (like a multivariate Gaussian) into one that mimics the complex data-generating process of the simulator. Over successive rounds, the method refines its approximation by focusing simulations on more plausible regions of the parameter space, allowing efficient and accurate inference in models where the likelihood is intractable. In general, the algorithm looks like this:</p>
<ol>
  <li>Initialize a prior over parameters, \(P(\theta)\), and an autoregressive conditional normalizing flow (ACNF).
    <ul>
      <li>The ACNF transforms a base multivariate Gaussian into a more flexible distribution using a series of affine transformations. Each transformation is conditioned on the preceding dimensions and parameter values. For more details, see Papamakarios et al. (2018).</li>
    </ul>
  </li>
  <li>Generate training data by sampling parameter vectors \(\theta\) from the prior \(P(\theta)\) and simulating data from your generative model.</li>
  <li>Train the ACNF to approximate the likelihood \(L_x(\theta)\) based on these parameter–simulation pairs. This gives a crude estimate of the likelihood for the observed data \(x\).</li>
  <li>Run MCMC using the product \(P(\theta) L_x(\theta)\) as the unnormalized posterior. For each sampled \(\theta\), simulate new data under the model.</li>
  <li>Retrain the ACNF on all accumulated parameter–simulation pairs, including both prior samples and MCMC-based samples. This improves the likelihood approximation in regions of high posterior density.</li>
  <li>Repeat steps 4–5 until the ACNF has adequately learned the likelihood in the relevant region of the parameter space.</li>
</ol>
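<p>The loop above can be sketched end to end on a toy problem. Everything here is a simplification made for illustration: the simulator is a hypothetical one-dimensional Gaussian, and the ACNF is swapped for a single conditional Gaussian fitted by least squares (in effect a one-layer affine flow whose shift is linear in \(\theta\)). The prior is exponential, and the MCMC move is a multiplicative random walk with its Hastings correction. All function names are assumptions.</p>

```python
import math
import random


def simulator(theta, rng):
    """Toy stand-in simulator (assumed): one noisy observation centred on theta."""
    return rng.gauss(theta, 0.5)


def fit_surrogate(pairs):
    """Fit x ~ Normal(w * theta + c, s) to (theta, x) pairs by least squares."""
    n = len(pairs)
    mean_t = sum(t for t, _ in pairs) / n
    mean_x = sum(x for _, x in pairs) / n
    var_t = sum((t - mean_t) ** 2 for t, _ in pairs)
    w = sum((t - mean_t) * (x - mean_x) for t, x in pairs) / var_t
    c = mean_x - w * mean_t
    s = math.sqrt(sum((x - w * t - c) ** 2 for t, x in pairs) / n)

    def log_likelihood(x, theta):
        z = (x - w * theta - c) / s  # normalizing direction of the affine map
        return -0.5 * z * z - math.log(s) - 0.5 * math.log(2.0 * math.pi)

    return log_likelihood


def snle(x_obs, rate=1.0, rounds=3, sims_per_round=200, mcmc_steps=20, seed=0):
    rng = random.Random(seed)

    def log_prior(theta):
        return math.log(rate) - rate * theta  # Exponential(rate) prior

    # Step 2: initial training pairs simulated from prior draws.
    pairs = []
    for _ in range(sims_per_round):
        theta = rng.expovariate(rate)
        pairs.append((theta, simulator(theta, rng)))

    theta = rng.expovariate(rate)
    for _ in range(rounds):
        # Steps 3 and 5: (re)train the surrogate on all accumulated pairs.
        log_lik = fit_surrogate(pairs)
        current = log_lik(x_obs, theta) + log_prior(theta)
        # Step 4: MCMC on prior times surrogate likelihood, simulating as we go.
        for _ in range(sims_per_round):
            for _ in range(mcmc_steps):
                scale = math.exp(1.5 * (rng.random() - 0.5))
                proposal = theta * scale
                proposed = log_lik(x_obs, proposal) + log_prior(proposal)
                # log(scale) is the Hastings correction for a multiplicative move.
                log_ratio = math.log(scale) + proposed - current
                if math.log(1.0 - rng.random()) <= log_ratio:
                    theta, current = proposal, proposed
            pairs.append((theta, simulator(theta, rng)))
    return theta, pairs, fit_surrogate(pairs)
```

<p>Calling snle(x_obs=2.0) returns the last MCMC state, the accumulated training pairs, and the final surrogate. In a real application the surrogate would be a conditional flow trained by gradient descent, but the control flow stays the same.</p>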

<p>I spent some time earlier this year trying to understand these methods, so I thought I would share a project of mine on the blog. Although I am unsure what I would use this method for in my own research, it is such a clever application of MCMC that I couldn’t help but implement it. The code below trains a conditional autoregressive flow neural network to learn an autoregressive multiplicative random walk that produces Lotka-Volterra dynamics. While Papamakarios et al. (2019) model the full trajectory as a static vector, I chose to model the incremental dynamics autoregressively. This better respects the sequential structure of the Lotka–Volterra system and allows the learned flow to generate new trajectories step-by-step. Below is an image of four draws from the trained conditional autoregressive flow; the positions and magnitudes of the oscillations follow the classic Lotka-Volterra pattern. Although I will not print the output here, SNLE also does a pretty good job of inferring the true parameter values in many of my test simulations.
<img src="/assets/lotka_volterra_trajectories.png" alt="Neural Lotka–Volterra Dynamics" /></p>
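<p>As a reading aid for the code that follows, here is the density that an affine flow of this kind assigns, written out as I understand the call method below. The layer index \(k\) and the symbols \(y^{(k)}\), \(\alpha_{k,i}\), \(\beta_{k,i}\), and \(p_Z\) are notation introduced here, not taken from the code. Each of the \(K\) layers rescales every dimension \(i\) with a positive scale \(\alpha_{k,i}\) and a shift \(\beta_{k,i}\), both produced by a small network that sees the parameters \(\theta\) and the dimensions already normalized in that layer:</p>

<p>\[ y^{(0)} = y, \qquad y^{(k)}_i = \frac{y^{(k-1)}_i - \beta_{k,i}}{\alpha_{k,i}}, \qquad z = y^{(K)} \]</p>

<p>\[ \log p(y \mid \theta) = \log p_Z(z) - \sum_{k=1}^{K} \sum_{i} \log \alpha_{k,i} \]</p>

<p>Here \(p_Z\) is the base Gaussian density. The second term is the log-determinant of the Jacobian of the normalizing map, which is triangular within each layer, so only the diagonal scales contribute.</p>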
<h3 id="conditional_autoregressive_flowpy">conditional_autoregressive_flow.py</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">from</span> <span class="nn">tensorflow</span> <span class="kn">import</span> <span class="n">keras</span>
<span class="kn">import</span> <span class="nn">tensorflow_probability</span> <span class="k">as</span> <span class="n">tfp</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>

<span class="n">tfd</span> <span class="o">=</span> <span class="n">tfp</span><span class="p">.</span><span class="n">distributions</span>

<span class="k">class</span> <span class="nc">ConditionalAutoregressiveFlow</span><span class="p">(</span><span class="n">keras</span><span class="p">.</span><span class="n">Model</span><span class="p">):</span>
    <span class="s">"""
    A neural network that autoregressively applies affine transformations to an N dimensional
    normal distribution conditional on some provided parameter. Rather than doing masking like in
    masked autoregressive flows, we are just explicitly enforcing the autoregressive nature of the 
    network in the structure of the network itself.
    """</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_params</span><span class="p">,</span> <span class="n">num_dimensions</span><span class="p">,</span> <span class="n">num_flows</span><span class="p">,</span> <span class="n">internal_dim</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">ConditionalAutoregressiveFlow</span><span class="p">,</span> <span class="bp">self</span><span class="p">).</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">optimizers</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="n">lr</span><span class="p">,</span> <span class="n">clipnorm</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span> 
        <span class="bp">self</span><span class="p">.</span><span class="n">num_params</span> <span class="o">=</span> <span class="n">num_params</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span> <span class="o">=</span> <span class="n">num_dimensions</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">internal_dim</span> <span class="o">=</span> <span class="n">internal_dim</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_flows</span> <span class="o">=</span> <span class="n">num_flows</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">autoregressive_networks</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_flows</span><span class="p">):</span>
            <span class="n">permutation</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num_dimensions</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">int32</span><span class="p">))</span>
            <span class="k">if</span><span class="p">(</span><span class="n">n</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">):</span>
                <span class="k">while</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">reduce_all</span><span class="p">(</span><span class="n">permutation</span> <span class="o">==</span> <span class="bp">self</span><span class="p">.</span><span class="n">autoregressive_networks</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="s">"permutation"</span><span class="p">]):</span> <span class="c1"># We want to make sure each layer is different from the last
</span>                    <span class="n">permutation</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">num_dimensions</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">int32</span><span class="p">))</span> 
            <span class="n">network</span> <span class="o">=</span> <span class="p">[]</span>
            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_dimensions</span><span class="p">):</span>
                <span class="n">network</span><span class="p">.</span><span class="n">append</span><span class="p">({</span>
                    <span class="s">"dense"</span><span class="p">:</span> <span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">internal_dim</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s">"relu"</span><span class="p">),</span>
                    <span class="s">"alpha"</span><span class="p">:</span> <span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s">"softplus"</span><span class="p">,</span><span class="n">kernel_initializer</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">initializers</span><span class="p">.</span><span class="n">RandomNormal</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">stddev</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="bp">None</span><span class="p">)),</span>
                    <span class="s">"beta"</span><span class="p">:</span> <span class="n">keras</span><span class="p">.</span><span class="n">layers</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s">"linear"</span><span class="p">,</span> <span class="n">kernel_initializer</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">keras</span><span class="p">.</span><span class="n">initializers</span><span class="p">.</span><span class="n">RandomNormal</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">stddev</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="bp">None</span><span class="p">)),</span>
                <span class="p">})</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">autoregressive_networks</span><span class="p">.</span><span class="n">append</span><span class="p">({</span><span class="s">"network"</span><span class="p">:</span> <span class="n">network</span><span class="p">,</span> <span class="s">"permutation"</span><span class="p">:</span> <span class="n">permutation</span><span class="p">})</span>
        
        <span class="bp">self</span><span class="p">.</span><span class="n">base_dist</span> <span class="o">=</span> <span class="n">tfd</span><span class="p">.</span><span class="n">MultivariateNormalDiag</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">ones</span><span class="p">([</span><span class="n">num_dimensions</span><span class="p">]))</span>

    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">flow_input</span><span class="p">):</span>
        <span class="s">"""Normalizing Direction"""</span>
        <span class="n">conditionals</span> <span class="o">=</span> <span class="n">flow_input</span><span class="p">[:,</span> <span class="p">:</span><span class="bp">self</span><span class="p">.</span><span class="n">num_params</span><span class="p">]</span>
        <span class="n">data</span> <span class="o">=</span> <span class="n">flow_input</span><span class="p">[:,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_params</span><span class="p">:]</span>
        <span class="n">tf</span><span class="p">.</span><span class="n">debugging</span><span class="p">.</span><span class="n">assert_equal</span><span class="p">(</span><span class="n">data</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="p">,</span> 
                                  <span class="sa">f</span><span class="s">"Error: </span><span class="si">{</span><span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="si">}</span><span class="s"> dimensions of data were expected, but we only got </span><span class="si">{</span><span class="n">data</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>

        <span class="n">output_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">data</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="p">)]</span>

        <span class="n">jacobian_sum</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">num_flows</span><span class="p">):</span>
            <span class="n">temp_conditionals</span> <span class="o">=</span> <span class="n">conditionals</span>

            <span class="n">network_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">autoregressive_networks</span><span class="p">[</span><span class="n">n</span><span class="p">]</span>
            <span class="n">autoregressive_network</span> <span class="o">=</span> <span class="n">network_obj</span><span class="p">[</span><span class="s">"network"</span><span class="p">]</span>
            <span class="n">network_permutation</span> <span class="o">=</span> <span class="n">network_obj</span><span class="p">[</span><span class="s">"permutation"</span><span class="p">]</span>

            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="p">):</span>
                <span class="n">index</span> <span class="o">=</span> <span class="n">network_permutation</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
                <span class="n">y</span> <span class="o">=</span> <span class="n">output_list</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
                <span class="n">flow</span> <span class="o">=</span> <span class="n">autoregressive_network</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>

                <span class="n">x</span> <span class="o">=</span> <span class="n">flow</span><span class="p">[</span><span class="s">"dense"</span><span class="p">](</span><span class="n">temp_conditionals</span><span class="p">)</span>
                <span class="n">alpha</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">flow</span><span class="p">[</span><span class="s">"alpha"</span><span class="p">](</span><span class="n">x</span><span class="p">))</span>
                <span class="n">beta</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">flow</span><span class="p">[</span><span class="s">"beta"</span><span class="p">](</span><span class="n">x</span><span class="p">))</span>
                <span class="n">y</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">divide</span><span class="p">(</span><span class="n">y</span> <span class="o">-</span> <span class="n">beta</span><span class="p">,</span> <span class="n">alpha</span><span class="p">)</span>

                <span class="n">output_list</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span>
                <span class="n">temp_conditionals</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">temp_conditionals</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

                <span class="n">jacobian_sum</span> <span class="o">+=</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">alpha</span><span class="p">)</span>

        <span class="n">output</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span><span class="n">output_list</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

        <span class="n">log_prob</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">base_dist</span><span class="p">.</span><span class="n">log_prob</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>
        <span class="n">log_prob</span> <span class="o">-=</span> <span class="n">jacobian_sum</span>

        <span class="k">return</span> <span class="n">log_prob</span>
    
    <span class="k">def</span> <span class="nf">transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">flow_input</span><span class="p">):</span>
        <span class="s">"""Generative Direction"""</span>
        <span class="n">conditionals</span> <span class="o">=</span> <span class="n">flow_input</span><span class="p">[:,</span> <span class="p">:</span><span class="bp">self</span><span class="p">.</span><span class="n">num_params</span><span class="p">]</span>
        <span class="n">data</span> <span class="o">=</span> <span class="n">flow_input</span><span class="p">[:,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_params</span><span class="p">:]</span>

        <span class="n">tf</span><span class="p">.</span><span class="n">debugging</span><span class="p">.</span><span class="n">assert_equal</span><span class="p">(</span><span class="n">data</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="p">,</span> 
                                  <span class="sa">f</span><span class="s">"Error: </span><span class="si">{</span><span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="si">}</span><span class="s"> dimensions of data were expected, but we only got </span><span class="si">{</span><span class="n">data</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>


        <span class="n">output_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">data</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="p">)]</span>


        <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">num_flows</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
            <span class="n">temp_conditionals</span> <span class="o">=</span> <span class="n">conditionals</span>

            <span class="n">network_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">autoregressive_networks</span><span class="p">[</span><span class="n">n</span><span class="p">]</span>
            <span class="n">autoregressive_network</span> <span class="o">=</span> <span class="n">network_obj</span><span class="p">[</span><span class="s">"network"</span><span class="p">]</span>
            <span class="n">network_permutation</span> <span class="o">=</span> <span class="n">network_obj</span><span class="p">[</span><span class="s">"permutation"</span><span class="p">]</span>

            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="p">):</span>
                <span class="n">index</span> <span class="o">=</span> <span class="n">network_permutation</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
                <span class="n">y</span> <span class="o">=</span> <span class="n">output_list</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
                <span class="n">flow</span> <span class="o">=</span> <span class="n">autoregressive_network</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>

                <span class="n">x</span> <span class="o">=</span> <span class="n">flow</span><span class="p">[</span><span class="s">"dense"</span><span class="p">](</span><span class="n">temp_conditionals</span><span class="p">)</span>
                <span class="n">alpha</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">flow</span><span class="p">[</span><span class="s">"alpha"</span><span class="p">](</span><span class="n">x</span><span class="p">))</span>
                <span class="n">beta</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">flow</span><span class="p">[</span><span class="s">"beta"</span><span class="p">](</span><span class="n">x</span><span class="p">))</span>
                <span class="n">y</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">multiply</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">alpha</span><span class="p">)</span> <span class="o">+</span> <span class="n">beta</span>

                <span class="n">temp_conditionals</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">temp_conditionals</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">output_list</span><span class="p">[</span><span class="n">index</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
                <span class="n">output_list</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span>

        <span class="n">output</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span><span class="n">output_list</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">output</span>

    <span class="k">def</span> <span class="nf">draw</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">parameters</span><span class="p">,</span> <span class="n">num_draws</span> <span class="o">=</span> <span class="mi">1</span><span class="p">):</span>
        <span class="s">"""Generative Direction"""</span>
        <span class="n">conditionals</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">tile</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">parameters</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="p">[</span><span class="n">num_draws</span><span class="p">,</span><span class="mi">1</span><span class="p">])</span>
        <span class="n">data</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">base_dist</span><span class="p">.</span><span class="n">sample</span><span class="p">(</span><span class="n">num_draws</span><span class="p">)</span>
        <span class="n">flow_input</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">conditionals</span><span class="p">,</span> <span class="n">data</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">flow_input</span><span class="p">)</span>
        

    <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
        <span class="s">"""Negative log-likelihood loss"""</span>
        
        <span class="k">with</span> <span class="n">tf</span><span class="p">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span>
            <span class="n">log_likelihood</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
            <span class="n">loss</span> <span class="o">=</span> <span class="o">-</span><span class="n">tf</span><span class="p">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">log_likelihood</span><span class="p">)</span>
            
        <span class="n">trainable_vars</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">trainable_variables</span>
        <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="p">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span>

        <span class="bp">self</span><span class="p">.</span><span class="n">optimizer</span><span class="p">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">))</span>

        <span class="k">return</span> <span class="p">{</span><span class="s">"loss"</span><span class="p">:</span> <span class="n">loss</span><span class="p">}</span>

<span class="k">if</span> <span class="n">__name__</span><span class="o">==</span><span class="s">"__main__"</span><span class="p">:</span>
    <span class="n">scale</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">([</span><span class="mi">500</span><span class="p">],</span> <span class="mf">4.0</span><span class="p">,</span> <span class="mf">8.0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
    <span class="n">scale</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span> <span class="mi">1000</span><span class="p">),</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
    <span class="n">quadrant_flag</span> <span class="o">=</span> <span class="n">tfd</span><span class="p">.</span><span class="n">Bernoulli</span><span class="p">(</span><span class="n">probs</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="nb">bool</span><span class="p">).</span><span class="n">sample</span><span class="p">(</span><span class="n">scale</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    <span class="n">mean_vec</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">quadrant_flag</span><span class="p">,</span> <span class="n">scale</span> <span class="o">*</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">scale</span><span class="p">)</span>
    <span class="n">training_samples</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">tfd</span><span class="p">.</span><span class="n">MultivariateNormalDiag</span><span class="p">([</span><span class="n">mean_vec</span><span class="p">,</span> <span class="n">mean_vec</span><span class="o">*-</span><span class="mi">1</span><span class="p">]).</span><span class="n">sample</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span>
    <span class="n">training_input</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">scale</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="n">tf</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">training_samples</span><span class="p">)],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

    <span class="n">test_af</span> <span class="o">=</span> <span class="n">ConditionalAutoregressiveFlow</span><span class="p">(</span><span class="n">num_params</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_dimensions</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> <span class="n">num_flows</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span> <span class="n">internal_dim</span> <span class="o">=</span> <span class="mi">64</span><span class="p">)</span>

    <span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">500</span><span class="p">):</span>
        <span class="n">loss</span> <span class="o">=</span> <span class="n">test_af</span><span class="p">.</span><span class="n">train_step</span><span class="p">(</span><span class="n">training_input</span><span class="p">)[</span><span class="s">'loss'</span><span class="p">]</span>
        <span class="n">losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
        <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Loss at Epoch </span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="n">loss</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>

    <span class="n">transformed_data</span> <span class="o">=</span> <span class="n">test_af</span><span class="p">.</span><span class="n">draw</span><span class="p">([</span><span class="mf">5.5</span><span class="p">],</span> <span class="n">num_draws</span><span class="o">=</span><span class="mi">1000</span><span class="p">)</span>

    <span class="n">dummy_data</span> <span class="o">=</span> <span class="n">tfd</span><span class="p">.</span><span class="n">MultivariateNormalDiag</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">ones</span><span class="p">([</span><span class="mi">2</span><span class="p">])).</span><span class="n">sample</span><span class="p">(</span><span class="mi">1000</span><span class="p">)</span>
    <span class="n">dummy_data_input</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">tf</span><span class="p">.</span><span class="n">fill</span><span class="p">([</span><span class="n">dummy_data</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">],</span> <span class="mf">7.5</span><span class="p">),</span> <span class="n">dummy_data</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">transformed_data_2</span> <span class="o">=</span> <span class="n">test_af</span><span class="p">.</span><span class="n">transform</span><span class="p">(</span><span class="n">dummy_data_input</span><span class="p">)</span>

    <span class="n">fig1</span><span class="p">,</span> <span class="n">ax1</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">()</span>
    <span class="n">ax1</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">x</span> <span class="o">=</span> <span class="n">transformed_data</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">y</span> <span class="o">=</span> <span class="n">transformed_data</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s">"Scale = 5.5"</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
    <span class="n">ax1</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">x</span> <span class="o">=</span> <span class="n">transformed_data_2</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">y</span> <span class="o">=</span> <span class="n">transformed_data_2</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s">"Scale = 7.5"</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
    <span class="n">ax1</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
    <span class="n">ax1</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"Y"</span><span class="p">)</span>
    <span class="n">ax1</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"X"</span><span class="p">)</span>
    <span class="n">ax1</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Transformation"</span><span class="p">)</span>
    <span class="n">fig1</span><span class="p">.</span><span class="n">savefig</span><span class="p">(</span><span class="s">"./conditional_af_transformation.png"</span><span class="p">)</span>
    
    <span class="n">x_min</span> <span class="o">=</span> <span class="o">-</span><span class="mi">10</span>
    <span class="n">x_max</span> <span class="o">=</span> <span class="mi">10</span>
    <span class="n">y_min</span> <span class="o">=</span> <span class="o">-</span><span class="mi">10</span>
    <span class="n">y_max</span> <span class="o">=</span> <span class="mi">10</span>
    <span class="n">num_x</span> <span class="o">=</span> <span class="mi">100</span>
    <span class="n">num_y</span> <span class="o">=</span> <span class="mi">100</span>

    <span class="n">x</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">x_min</span><span class="p">,</span> <span class="n">x_max</span><span class="p">,</span> <span class="n">num_x</span><span class="p">)</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">y_min</span><span class="p">,</span> <span class="n">y_max</span><span class="p">,</span> <span class="n">num_y</span><span class="p">)</span>

    <span class="n">X</span><span class="p">,</span> <span class="n">Y</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">meshgrid</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
    <span class="n">coordinates</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">stack</span><span class="p">([</span><span class="n">tf</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="n">tf</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">Y</span><span class="p">,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">coordinates</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">cast</span><span class="p">(</span><span class="n">coordinates</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>

    <span class="n">coordinate_input</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">tf</span><span class="p">.</span><span class="n">fill</span><span class="p">([</span><span class="n">coordinates</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">],</span> <span class="mf">7.5</span><span class="p">),</span> <span class="n">coordinates</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">probs</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">test_af</span><span class="p">(</span><span class="n">coordinate_input</span><span class="p">))</span>
    <span class="n">heatmap</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">probs</span><span class="p">,</span> <span class="p">(</span><span class="n">num_y</span><span class="p">,</span> <span class="n">num_x</span><span class="p">))</span>

    <span class="n">fig2</span><span class="p">,</span> <span class="n">ax2</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">()</span>
    <span class="n">ax2</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">heatmap</span><span class="p">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">extent</span><span class="o">=</span><span class="p">[</span><span class="n">x_min</span><span class="p">,</span> <span class="n">x_max</span><span class="p">,</span> <span class="n">y_min</span><span class="p">,</span> <span class="n">y_max</span><span class="p">],</span> <span class="n">origin</span><span class="o">=</span><span class="s">'lower'</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s">'viridis'</span><span class="p">)</span>
    <span class="n">ax2</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"Y"</span><span class="p">)</span>
    <span class="n">ax2</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"X"</span><span class="p">)</span>
    <span class="n">ax2</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="s">"Transform Probability"</span><span class="p">)</span>
    <span class="n">fig2</span><span class="p">.</span><span class="n">savefig</span><span class="p">(</span><span class="s">"./conditional_af_pdf.png"</span><span class="p">)</span>
</code></pre></div></div>
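<p>The affine step at the heart of the flow above can be checked in isolation. What follows is a minimal scalar sketch in plain Python rather than TensorFlow. Here <code>alpha</code> and <code>beta</code> are fixed, hypothetical stand-ins for the outputs of the <code>"alpha"</code> and <code>"beta"</code> layers, and <code>alpha</code> is assumed to be positive. The sketch checks that the normalizing direction undoes the generative one, and that subtracting <code>log(alpha)</code> from the base log-density gives the log-density of a normal with mean <code>beta</code> and standard deviation <code>alpha</code>.</p>

```python
import math


def standard_normal_log_prob(z):
    """Log-density of a standard normal at z."""
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)


def generative(z, alpha, beta):
    """Generative direction: scale and shift a base sample."""
    return z * alpha + beta


def normalizing(y, alpha, beta):
    """Normalizing direction: invert the affine map and score y."""
    z = (y - beta) / alpha
    # dz/dy = 1 / alpha, so the log-density picks up -log(alpha).
    log_prob = standard_normal_log_prob(z) - math.log(alpha)
    return z, log_prob


# Hypothetical values standing in for the network outputs (assumption).
alpha, beta = 2.0, 1.5
z0 = 0.3

y0 = generative(z0, alpha, beta)
z_back, log_prob = normalizing(y0, alpha, beta)

# Closed-form log-density of Normal(mean=beta, std=alpha) at y0.
expected = (
    -0.5 * ((y0 - beta) / alpha) ** 2
    - math.log(alpha)
    - 0.5 * math.log(2.0 * math.pi)
)

print(abs(z_back - z0), abs(log_prob - expected))
```

<p>Both printed differences should be zero up to floating-point error. The same bookkeeping, summed over dimensions and stacked flows, is what the <code>jacobian_sum</code> term tracks in the class.</p>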
<h3 id="sequential_neural_likelihood_lotka_volterrapy">sequential_neural_likelihood_lotka_volterra.py</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="n">tf</span>
<span class="kn">from</span> <span class="nn">tensorflow</span> <span class="kn">import</span> <span class="n">keras</span>
<span class="kn">import</span> <span class="nn">tensorflow_probability</span> <span class="k">as</span> <span class="n">tfp</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">from</span> <span class="nn">conditional_autoregressive_flow</span> <span class="kn">import</span> <span class="n">ConditionalAutoregressiveFlow</span>
<span class="kn">import</span> <span class="nn">math</span>

<span class="n">tfd</span> <span class="o">=</span> <span class="n">tfp</span><span class="p">.</span><span class="n">distributions</span>

<span class="k">class</span> <span class="nc">LogConditionalAutoregressiveFlow</span><span class="p">(</span><span class="n">ConditionalAutoregressiveFlow</span><span class="p">):</span>
    <span class="s">"""
    Our conditionals interact on a multiplicative scale, and we are
    modeling a multiplicative random walk, so the network should reflect
    that. We log-transform the conditionals and the data before they enter the flow.
    """</span>
    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">flow_input</span><span class="p">):</span>
        <span class="s">"""Normalizing Direction"""</span>
        <span class="n">conditionals</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">flow_input</span><span class="p">[:,</span> <span class="p">:</span><span class="bp">self</span><span class="p">.</span><span class="n">num_params</span><span class="p">])</span>
        <span class="n">data</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">flow_input</span><span class="p">[:,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_params</span><span class="p">:])</span>
        <span class="n">tf</span><span class="p">.</span><span class="n">debugging</span><span class="p">.</span><span class="n">assert_equal</span><span class="p">(</span><span class="n">data</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="p">,</span> 
                                  <span class="sa">f</span><span class="s">"Error: expected </span><span class="si">{</span><span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="si">}</span><span class="s"> data dimensions but got </span><span class="si">{</span><span class="n">data</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>

        <span class="n">output_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">data</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="p">)]</span>

        <span class="c1"># Log transform u = log(x): log p(x) = log p(u) - sum(u), so sum(u) joins the term subtracted from log_prob
</span>        <span class="n">jacobian_sum</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">num_flows</span><span class="p">):</span>
            <span class="n">temp_conditionals</span> <span class="o">=</span> <span class="n">conditionals</span>

            <span class="n">network_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">autoregressive_networks</span><span class="p">[</span><span class="n">n</span><span class="p">]</span>
            <span class="n">autoregressive_network</span> <span class="o">=</span> <span class="n">network_obj</span><span class="p">[</span><span class="s">"network"</span><span class="p">]</span>
            <span class="n">network_permutation</span> <span class="o">=</span> <span class="n">network_obj</span><span class="p">[</span><span class="s">"permutation"</span><span class="p">]</span>

            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="p">):</span>
                <span class="n">index</span> <span class="o">=</span> <span class="n">network_permutation</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
                <span class="n">y</span> <span class="o">=</span> <span class="n">output_list</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
                <span class="n">flow</span> <span class="o">=</span> <span class="n">autoregressive_network</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>

                <span class="n">x</span> <span class="o">=</span> <span class="n">flow</span><span class="p">[</span><span class="s">"dense"</span><span class="p">](</span><span class="n">temp_conditionals</span><span class="p">)</span>
                <span class="n">alpha</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">flow</span><span class="p">[</span><span class="s">"alpha"</span><span class="p">](</span><span class="n">x</span><span class="p">))</span>
                <span class="n">beta</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">flow</span><span class="p">[</span><span class="s">"beta"</span><span class="p">](</span><span class="n">x</span><span class="p">))</span>
                <span class="n">y</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">divide</span><span class="p">(</span><span class="n">y</span> <span class="o">-</span> <span class="n">beta</span><span class="p">,</span> <span class="n">alpha</span><span class="p">)</span>

                <span class="n">output_list</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span>
                <span class="n">temp_conditionals</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">temp_conditionals</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

                <span class="n">jacobian_sum</span> <span class="o">+=</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">alpha</span><span class="p">)</span>

        <span class="n">output</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span><span class="n">output_list</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

        <span class="n">log_prob</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">base_dist</span><span class="p">.</span><span class="n">log_prob</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>
        <span class="n">log_prob</span> <span class="o">-=</span> <span class="n">jacobian_sum</span>

        <span class="k">return</span> <span class="n">log_prob</span>
    
    <span class="k">def</span> <span class="nf">transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">flow_input</span><span class="p">):</span>
        <span class="s">"""Generative Direction"""</span>
        <span class="n">conditionals</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">flow_input</span><span class="p">[:,</span> <span class="p">:</span><span class="bp">self</span><span class="p">.</span><span class="n">num_params</span><span class="p">])</span>
        <span class="n">data</span> <span class="o">=</span> <span class="n">flow_input</span><span class="p">[:,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_params</span><span class="p">:]</span>

        <span class="n">tf</span><span class="p">.</span><span class="n">debugging</span><span class="p">.</span><span class="n">assert_equal</span><span class="p">(</span><span class="n">data</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="p">,</span> 
                                  <span class="sa">f</span><span class="s">"Error: expected </span><span class="si">{</span><span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="si">}</span><span class="s"> data dimensions but got </span><span class="si">{</span><span class="n">data</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>


        <span class="n">output_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">data</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="p">)]</span>

        <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">num_flows</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
            <span class="n">temp_conditionals</span> <span class="o">=</span> <span class="n">conditionals</span>

            <span class="n">network_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">autoregressive_networks</span><span class="p">[</span><span class="n">n</span><span class="p">]</span>
            <span class="n">autoregressive_network</span> <span class="o">=</span> <span class="n">network_obj</span><span class="p">[</span><span class="s">"network"</span><span class="p">]</span>
            <span class="n">network_permutation</span> <span class="o">=</span> <span class="n">network_obj</span><span class="p">[</span><span class="s">"permutation"</span><span class="p">]</span>

            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">num_dimensions</span><span class="p">):</span>
                <span class="n">index</span> <span class="o">=</span> <span class="n">network_permutation</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
                <span class="n">y</span> <span class="o">=</span> <span class="n">output_list</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
                <span class="n">flow</span> <span class="o">=</span> <span class="n">autoregressive_network</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>

                <span class="n">x</span> <span class="o">=</span> <span class="n">flow</span><span class="p">[</span><span class="s">"dense"</span><span class="p">](</span><span class="n">temp_conditionals</span><span class="p">)</span>
                <span class="n">alpha</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">flow</span><span class="p">[</span><span class="s">"alpha"</span><span class="p">](</span><span class="n">x</span><span class="p">))</span>
                <span class="n">beta</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">flow</span><span class="p">[</span><span class="s">"beta"</span><span class="p">](</span><span class="n">x</span><span class="p">))</span>
                <span class="n">y</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">multiply</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">alpha</span><span class="p">)</span> <span class="o">+</span> <span class="n">beta</span>

                <span class="n">temp_conditionals</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">temp_conditionals</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">output_list</span><span class="p">[</span><span class="n">index</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
                <span class="n">output_list</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">=</span> <span class="n">y</span>

        <span class="n">output</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span><span class="n">output_list</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>


<span class="k">def</span> <span class="nf">generative_model</span><span class="p">(</span><span class="n">prey_birth</span><span class="p">,</span> <span class="n">predation</span><span class="p">,</span> <span class="n">predator_birth</span><span class="p">,</span> <span class="n">predator_death</span><span class="p">,</span> <span class="n">init_predators</span><span class="o">=</span><span class="mi">30</span><span class="p">,</span> <span class="n">init_prey</span><span class="o">=</span><span class="mi">60</span><span class="p">,</span> <span class="n">total_time</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">samples_per_unit_time</span><span class="o">=</span><span class="mi">5</span><span class="p">):</span>
    <span class="s">"""
    Simulate a time series of stochastic Lotka-Volterra dynamics.
    Each output row holds the previous (predator, prey) counts followed by
    the multiplicative increase from that step to the next.
    """</span>

    <span class="c1"># Enforce minimum for simplicity
</span>    <span class="n">prey_birth</span> <span class="o">=</span> <span class="n">prey_birth</span> <span class="o">+</span> <span class="mf">1e-2</span>
    <span class="n">predation</span> <span class="o">=</span> <span class="n">predation</span> <span class="o">+</span> <span class="mf">1e-2</span>
    <span class="n">predator_birth</span> <span class="o">=</span> <span class="n">predator_birth</span> <span class="o">+</span> <span class="mf">1e-2</span>
    <span class="n">predator_death</span> <span class="o">=</span> <span class="n">predator_death</span> <span class="o">+</span> <span class="mf">1e-2</span>

    <span class="n">times</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">]</span>
    <span class="n">predators</span> <span class="o">=</span> <span class="p">[</span><span class="n">init_predators</span><span class="p">]</span>
    <span class="n">prey</span> <span class="o">=</span> <span class="p">[</span><span class="n">init_prey</span><span class="p">]</span>

    <span class="k">while</span> <span class="bp">True</span><span class="p">:</span>
        <span class="c1"># Get current population
</span>        <span class="n">prey_count</span> <span class="o">=</span> <span class="n">prey</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
        <span class="n">predator_count</span> <span class="o">=</span> <span class="n">predators</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>

        <span class="c1"># Determine event rates
</span>        <span class="n">prey_birth_rate</span> <span class="o">=</span> <span class="n">prey_birth</span><span class="o">*</span><span class="n">prey_count</span> 
        <span class="n">predation_rate</span> <span class="o">=</span>  <span class="n">predation</span><span class="o">*</span><span class="n">prey_count</span><span class="o">*</span><span class="n">predator_count</span>
        <span class="n">predator_death_rate</span> <span class="o">=</span> <span class="n">predator_death</span><span class="o">*</span><span class="n">predator_count</span>
        <span class="n">predator_birth_rate</span> <span class="o">=</span> <span class="n">predator_birth</span><span class="o">*</span><span class="n">predator_count</span><span class="o">*</span><span class="n">prey_count</span>

        <span class="c1"># Get total "race" rate
</span>        <span class="n">total_rate</span> <span class="o">=</span> <span class="n">prey_birth_rate</span> <span class="o">+</span> <span class="n">predation_rate</span> <span class="o">+</span> <span class="n">predator_death_rate</span> <span class="o">+</span> <span class="n">predator_birth_rate</span>
        <span class="k">if</span><span class="p">(</span><span class="n">total_rate</span> <span class="o">==</span> <span class="mi">0</span><span class="p">):</span> <span class="c1"># Total extinction
</span>            <span class="k">break</span>

        <span class="c1"># Cumulative event probabilities that the uniform draw is compared against
</span>        <span class="n">prey_birth_cumulative</span> <span class="o">=</span> <span class="n">prey_birth_rate</span><span class="o">/</span><span class="n">total_rate</span>
        <span class="n">predation_cumulative</span> <span class="o">=</span> <span class="n">predation_rate</span><span class="o">/</span><span class="n">total_rate</span> <span class="o">+</span> <span class="n">prey_birth_cumulative</span>
        <span class="n">predator_death_cumulative</span> <span class="o">=</span> <span class="n">predator_death_rate</span><span class="o">/</span><span class="n">total_rate</span> <span class="o">+</span> <span class="n">predation_cumulative</span>

        <span class="n">waiting_time</span> <span class="o">=</span> <span class="n">tfd</span><span class="p">.</span><span class="n">Exponential</span><span class="p">(</span><span class="n">total_rate</span><span class="p">).</span><span class="n">sample</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>

        <span class="n">new_time</span> <span class="o">=</span> <span class="n">times</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">waiting_time</span> 
        <span class="k">if</span><span class="p">(</span><span class="n">new_time</span> <span class="o">&lt;</span> <span class="n">total_time</span><span class="p">):</span>
            <span class="n">times</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">new_time</span><span class="p">)</span>
            <span class="c1"># Draw a random event (prey birth, predation, predator death, or predator birth)
</span>            <span class="n">action_draw</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">([</span><span class="mi">1</span><span class="p">],</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span>
            <span class="k">if</span><span class="p">(</span><span class="n">action_draw</span> <span class="o">&lt;</span> <span class="n">prey_birth_cumulative</span><span class="p">):</span>
                <span class="n">predators</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">predator_count</span><span class="p">)</span>
                <span class="n">prey</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">prey_count</span> <span class="o">+</span> <span class="mf">1.0</span><span class="p">)</span>
            <span class="k">elif</span><span class="p">(</span><span class="n">action_draw</span> <span class="o">&lt;</span> <span class="n">predation_cumulative</span><span class="p">):</span>
                <span class="n">predators</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">predator_count</span><span class="p">)</span>
                <span class="n">prey</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">prey_count</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">)</span>
            <span class="k">elif</span><span class="p">(</span><span class="n">action_draw</span> <span class="o">&lt;</span> <span class="n">predator_death_cumulative</span><span class="p">):</span>
                <span class="n">predators</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">predator_count</span> <span class="o">-</span> <span class="mf">1.0</span><span class="p">)</span>
                <span class="n">prey</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">prey_count</span><span class="p">)</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">predators</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">predator_count</span> <span class="o">+</span> <span class="mf">1.0</span><span class="p">)</span>
                <span class="n">prey</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">prey_count</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">break</span>

    <span class="n">time_sampled_prey</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">time_sampled_predator</span> <span class="o">=</span> <span class="p">[]</span>

    <span class="n">increment</span> <span class="o">=</span> <span class="mi">1</span><span class="o">/</span><span class="n">samples_per_unit_time</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">total_time</span> <span class="o">*</span> <span class="n">samples_per_unit_time</span><span class="p">):</span>
        <span class="n">last_event</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">times</span><span class="p">)):</span>
            <span class="k">if</span><span class="p">(</span><span class="n">times</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">&lt;=</span> <span class="n">i</span> <span class="o">*</span> <span class="n">increment</span><span class="p">):</span>
                <span class="n">last_event</span> <span class="o">=</span> <span class="n">t</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="k">break</span>
        
        <span class="c1"># Use the state left by the last event at or before this sample time
</span>        <span class="n">time_sampled_predator</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">predators</span><span class="p">[</span><span class="n">last_event</span><span class="p">])</span>
        <span class="n">time_sampled_prey</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">prey</span><span class="p">[</span><span class="n">last_event</span><span class="p">])</span>

    <span class="n">current_dim</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">stack</span><span class="p">([</span><span class="n">time_sampled_predator</span><span class="p">,</span> <span class="n">time_sampled_prey</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">prev_dim</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">roll</span><span class="p">(</span><span class="n">current_dim</span><span class="p">,</span> <span class="n">shift</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>

    <span class="c1"># Now we get the increase
</span>    <span class="n">multiplicative_increase</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">divide</span><span class="p">(</span><span class="n">current_dim</span><span class="p">,</span> <span class="n">prev_dim</span><span class="p">)</span>
    <span class="n">multiplicative_increase</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">is_nan</span><span class="p">(</span><span class="n">multiplicative_increase</span><span class="p">),</span> <span class="mf">1e-6</span><span class="p">,</span> <span class="n">multiplicative_increase</span><span class="p">)</span>
    <span class="n">multiplicative_increase</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">multiplicative_increase</span> <span class="o">&lt;=</span> <span class="mi">0</span><span class="p">,</span> <span class="mf">1e-6</span><span class="p">,</span> <span class="n">multiplicative_increase</span><span class="p">)</span>
    <span class="n">prev_dim</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">where</span><span class="p">(</span><span class="n">prev_dim</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="mf">1e-6</span><span class="p">,</span> <span class="n">prev_dim</span><span class="p">)</span>

    <span class="c1"># Roll moves the last element to the first, so we shave that off
</span>    <span class="n">output</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">prev_dim</span><span class="p">,</span> <span class="n">multiplicative_increase</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)[</span><span class="mi">1</span><span class="p">:,</span> <span class="p">:]</span>

    <span class="k">return</span> <span class="n">output</span>

<span class="k">def</span> <span class="nf">parameter_update</span><span class="p">(</span><span class="n">parameter</span><span class="p">):</span>
    <span class="n">param_choice</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">([</span><span class="mi">1</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">int32</span><span class="p">)</span> <span class="c1"># Pick random param
</span>    <span class="n">param_choice</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">param_choice</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>

    <span class="c1"># Symmetric log-scale move: log(scaler) ~ Uniform(-1, 1), so the Hastings ratio is the scaler itself
</span>    <span class="n">scaler</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">exp</span><span class="p">(</span><span class="mf">2.0</span> <span class="o">*</span> <span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">([</span><span class="mi">1</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="mf">0.5</span><span class="p">))</span>

    <span class="n">updated_value</span> <span class="o">=</span> <span class="n">parameter</span><span class="p">[</span><span class="n">param_choice</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]]</span> <span class="o">*</span> <span class="n">scaler</span>

    <span class="n">parameter</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">tensor_scatter_nd_update</span><span class="p">(</span><span class="n">parameter</span><span class="p">,</span> <span class="n">param_choice</span><span class="p">,</span> <span class="n">updated_value</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">parameter</span><span class="p">,</span> <span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">scaler</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">mcmc</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">prior</span><span class="p">,</span> <span class="n">likelihood</span> <span class="o">=</span> <span class="bp">None</span><span class="p">,</span> <span class="n">iterations</span> <span class="o">=</span> <span class="mi">1000</span><span class="p">,</span> <span class="n">debug</span> <span class="o">=</span> <span class="bp">True</span><span class="p">,</span> <span class="n">sample_iter</span> <span class="o">=</span> <span class="mi">1000</span><span class="p">):</span>
    <span class="s">"""
    If no likelihood is passed, we sample from the prior alone by drawing
    from it directly. Otherwise we run Metropolis-Hastings on the product of
    the prior and the model, for now using only scale moves on the parameters.
    This function returns the chain states recorded every sample_iter
    iterations, which can be plugged into the generative model.
    """</span>
    <span class="n">return_params</span> <span class="o">=</span> <span class="p">[]</span>

    <span class="k">if</span><span class="p">(</span><span class="n">likelihood</span> <span class="ow">is</span> <span class="bp">None</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">iterations</span><span class="o">/</span><span class="n">sample_iter</span><span class="p">)):</span>
            <span class="n">param</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">prior</span><span class="p">.</span><span class="n">sample</span><span class="p">())</span>
            <span class="k">if</span><span class="p">(</span><span class="n">debug</span><span class="p">):</span>
                <span class="n">tf</span><span class="p">.</span><span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Drawing parameter from prior: </span><span class="si">{</span><span class="n">param</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
            <span class="n">return_params</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">param</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">param</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">prior</span><span class="p">.</span><span class="n">sample</span><span class="p">())</span>
        <span class="n">old_posterior</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">prior</span><span class="p">.</span><span class="n">log_prob</span><span class="p">(</span><span class="n">param</span><span class="p">))</span> <span class="o">+</span>  <span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">likelihood</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">tf</span><span class="p">.</span><span class="n">tile</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="p">[</span><span class="n">data</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">]),</span> <span class="n">data</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)))</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">iterations</span><span class="o">+</span><span class="mi">1</span><span class="p">):</span>
            <span class="n">new_param</span><span class="p">,</span> <span class="n">hastings</span> <span class="o">=</span> <span class="n">parameter_update</span><span class="p">(</span><span class="n">param</span><span class="p">)</span>
            <span class="n">new_posterior</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">prior</span><span class="p">.</span><span class="n">log_prob</span><span class="p">(</span><span class="n">new_param</span><span class="p">))</span> <span class="o">+</span> <span class="n">tf</span><span class="p">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">likelihood</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">tf</span><span class="p">.</span><span class="n">tile</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">new_param</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="p">[</span><span class="n">data</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">1</span><span class="p">]),</span> <span class="n">data</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)))</span>
            <span class="n">posterior_ratio</span> <span class="o">=</span> <span class="n">hastings</span> <span class="o">+</span> <span class="n">new_posterior</span> <span class="o">-</span> <span class="n">old_posterior</span>

            <span class="n">draw</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">uniform</span><span class="p">([</span><span class="mi">1</span><span class="p">],</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
            <span class="k">if</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">math</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">draw</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">posterior_ratio</span><span class="p">):</span>
                <span class="n">param</span> <span class="o">=</span> <span class="n">new_param</span>
                <span class="n">old_posterior</span> <span class="o">=</span> <span class="n">new_posterior</span>
                <span class="k">if</span><span class="p">(</span><span class="n">debug</span><span class="p">):</span>
                    <span class="n">tf</span><span class="p">.</span><span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Accepted new parameter value at iteration </span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="n">param</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
            
            <span class="k">if</span><span class="p">(</span><span class="n">i</span> <span class="o">%</span> <span class="n">sample_iter</span> <span class="o">==</span> <span class="mi">0</span><span class="p">):</span>
                <span class="n">return_params</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">param</span><span class="p">)</span>
    
    <span class="k">return</span> <span class="n">return_params</span>


<span class="k">def</span> <span class="nf">train_flow_model</span><span class="p">(</span><span class="n">training_input</span><span class="p">,</span> <span class="n">settings</span><span class="p">,</span> <span class="n">debug</span> <span class="o">=</span> <span class="bp">True</span><span class="p">):</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">LogConditionalAutoregressiveFlow</span><span class="p">(</span><span class="n">settings</span><span class="p">[</span><span class="s">"num_params"</span><span class="p">],</span> 
    <span class="n">settings</span><span class="p">[</span><span class="s">"num_dimensions"</span><span class="p">],</span> <span class="n">settings</span><span class="p">[</span><span class="s">"num_flows"</span><span class="p">],</span> <span class="n">settings</span><span class="p">[</span><span class="s">"internal_dim"</span><span class="p">],</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span>

    <span class="n">patience</span> <span class="o">=</span> <span class="mi">25</span>
    <span class="n">min_delta</span> <span class="o">=</span> <span class="mf">1e-2</span>
    <span class="n">best_loss</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="s">'inf'</span><span class="p">)</span>
    <span class="n">waited_epochs</span> <span class="o">=</span> <span class="mi">0</span>

    <span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">settings</span><span class="p">[</span><span class="s">"training_iterations"</span><span class="p">]):</span>
        <span class="n">loss</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">train_step</span><span class="p">(</span><span class="n">training_input</span><span class="p">)[</span><span class="s">'loss'</span><span class="p">]</span>
        <span class="n">losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
        <span class="k">if</span><span class="p">(</span><span class="n">debug</span><span class="p">):</span>
            <span class="n">tf</span><span class="p">.</span><span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Loss at Epoch </span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="n">loss</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>

        <span class="k">if</span><span class="p">(</span><span class="n">loss</span> <span class="o">&lt;</span> <span class="n">best_loss</span> <span class="o">-</span> <span class="n">min_delta</span><span class="p">):</span>
            <span class="n">best_loss</span> <span class="o">=</span> <span class="n">loss</span>
            <span class="n">waited_epochs</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">waited_epochs</span> <span class="o">+=</span> <span class="mi">1</span>

        <span class="k">if</span><span class="p">(</span><span class="n">waited_epochs</span> <span class="o">&gt;=</span> <span class="n">patience</span><span class="p">):</span>
            <span class="k">if</span><span class="p">(</span><span class="n">debug</span><span class="p">):</span>
                <span class="n">tf</span><span class="p">.</span><span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Stopping early at epoch </span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
            <span class="k">break</span>
    
    <span class="k">return</span> <span class="n">model</span>

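<span class="k">def</span> <span class="nf">snle_step</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">prior</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">settings</span><span class="p">):</span>
    <span class="c1"># One SNLE round: draw parameters (from the prior when model is None, otherwise via MCMC</span>
    <span class="c1"># on the prior times the learned likelihood), then simulate num_simulations datasets per draw.</span>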
    <span class="n">training_data_list</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">training_param_list</span> <span class="o">=</span> <span class="p">[]</span>

    <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">settings</span><span class="p">[</span><span class="s">"num_mcmc_chains"</span><span class="p">]):</span>
        <span class="n">param</span> <span class="o">=</span> <span class="n">mcmc</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">prior</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">settings</span><span class="p">[</span><span class="s">"mcmc_iterations"</span><span class="p">])</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">settings</span><span class="p">[</span><span class="s">"num_simulations"</span><span class="p">]):</span>
            <span class="k">for</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">param</span><span class="p">:</span>
                <span class="n">t_data</span> <span class="o">=</span> <span class="n">generative_model</span><span class="p">(</span><span class="n">y</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">y</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">y</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">y</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">total_time</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span> <span class="n">samples_per_unit_time</span> <span class="o">=</span> <span class="mi">10</span><span class="p">)</span>
                <span class="n">training_param_list</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
                <span class="n">training_data_list</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">t_data</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">training_param_list</span><span class="p">,</span> <span class="n">training_data_list</span>

<span class="k">if</span> <span class="n">__name__</span> <span class="o">==</span> <span class="s">"__main__"</span><span class="p">:</span>
    <span class="n">settings</span> <span class="o">=</span> <span class="p">{</span><span class="s">"num_params"</span><span class="p">:</span> <span class="mi">6</span><span class="p">,</span> <span class="s">"num_dimensions"</span><span class="p">:</span> <span class="mi">2</span><span class="p">,</span> 
                <span class="s">"num_flows"</span><span class="p">:</span> <span class="mi">3</span><span class="p">,</span> <span class="s">"internal_dim"</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span>
                <span class="s">"num_retrainings"</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span> <span class="s">"num_simulations"</span><span class="p">:</span> <span class="mi">5</span><span class="p">,</span>
                <span class="s">"num_mcmc_chains"</span><span class="p">:</span> <span class="mi">100</span><span class="p">,</span> <span class="s">"mcmc_iterations"</span><span class="p">:</span> <span class="mi">2000</span><span class="p">,</span>
                <span class="s">"training_iterations"</span><span class="p">:</span> <span class="mi">2500</span><span class="p">}</span>
    
    <span class="n">training_data_list</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">training_param_list</span> <span class="o">=</span> <span class="p">[]</span>

    <span class="n">model</span> <span class="o">=</span> <span class="bp">None</span>
    <span class="n">prior</span> <span class="o">=</span> <span class="n">tfd</span><span class="p">.</span><span class="n">Exponential</span><span class="p">(</span><span class="n">rate</span> <span class="o">=</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">20</span><span class="p">,</span> <span class="mi">4</span><span class="p">])</span>
    <span class="n">observed_data</span> <span class="o">=</span> <span class="n">generative_model</span><span class="p">(</span><span class="mf">0.75</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> 
                                    <span class="n">init_predators</span> <span class="o">=</span> <span class="mi">30</span><span class="p">,</span> <span class="n">init_prey</span> <span class="o">=</span> <span class="mi">60</span><span class="p">,</span> <span class="n">total_time</span> <span class="o">=</span> <span class="mi">25</span><span class="p">,</span> <span class="n">samples_per_unit_time</span> <span class="o">=</span> <span class="mi">10</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"We generated:"</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="n">observed_data</span><span class="p">)</span>

    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">settings</span><span class="p">[</span><span class="s">"num_retrainings"</span><span class="p">]):</span>
        <span class="n">param</span><span class="p">,</span> <span class="n">data</span> <span class="o">=</span> <span class="n">snle_step</span><span class="p">(</span><span class="n">observed_data</span><span class="p">,</span> <span class="n">prior</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">settings</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">data</span><span class="p">:</span>
            <span class="n">training_data_list</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">d</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">param</span><span class="p">:</span>
            <span class="n">training_param_list</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">tile</span><span class="p">(</span><span class="n">tf</span><span class="p">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="p">[</span><span class="mi">99</span><span class="p">,</span><span class="mi">1</span><span class="p">]))</span> <span class="c1"># 99 time samples per param because one is sliced off
</span>        
        <span class="n">training_data</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">(</span><span class="n">training_data_list</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="n">training_param</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">(</span><span class="n">training_param_list</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>

        <span class="n">training_input</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">training_param</span><span class="p">,</span> <span class="n">training_data</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

        <span class="n">model</span> <span class="o">=</span> <span class="n">train_flow_model</span><span class="p">(</span><span class="n">training_input</span><span class="p">,</span> <span class="n">settings</span><span class="p">)</span>

    <span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">)</span>

    <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">ax</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">axes</span><span class="p">.</span><span class="n">flat</span><span class="p">,</span> <span class="n">start</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
        <span class="c1"># Conditionals are ordered (parameters, previous predators, previous prey), matching generative_model's output</span>
        <span class="n">param_input</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">constant</span><span class="p">([</span><span class="mf">0.75</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">30.0</span><span class="p">,</span> <span class="mf">60.0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>
        <span class="n">prey_counts</span> <span class="o">=</span> <span class="p">[</span><span class="mf">60.0</span><span class="p">]</span>
        <span class="n">predator_counts</span> <span class="o">=</span> <span class="p">[</span><span class="mf">30.0</span><span class="p">]</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">250</span><span class="p">):</span>
            <span class="n">next_step</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">draw</span><span class="p">(</span><span class="n">param_input</span><span class="p">,</span> <span class="n">num_draws</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
            <span class="n">predator</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">next_step</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">].</span><span class="n">numpy</span><span class="p">()</span> <span class="o">*</span> <span class="n">predator_counts</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
            <span class="n">prey</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">next_step</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">].</span><span class="n">numpy</span><span class="p">()</span> <span class="o">*</span> <span class="n">prey_counts</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>

            <span class="n">predator_counts</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">predator</span><span class="p">)</span>
            <span class="n">prey_counts</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">prey</span><span class="p">)</span>

            <span class="n">param_input</span> <span class="o">=</span> <span class="n">tf</span><span class="p">.</span><span class="n">constant</span><span class="p">([</span><span class="mf">0.75</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">predator</span><span class="p">,</span> <span class="n">prey</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="p">.</span><span class="n">float32</span><span class="p">)</span>

        <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">prey_counts</span><span class="p">[</span><span class="mi">1</span><span class="p">:],</span> <span class="n">label</span><span class="o">=</span><span class="s">"Prey"</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">predator_counts</span><span class="p">[</span><span class="mi">1</span><span class="p">:],</span> <span class="n">label</span><span class="o">=</span><span class="s">"Predators"</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">"Count"</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">"Time"</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>

    <span class="n">fig</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
    <span class="n">fig</span><span class="p">.</span><span class="n">savefig</span><span class="p">(</span><span class="s">"./lotka_volterra_trajectories.png"</span><span class="p">)</span>
</code></pre></div></div>]]></content><author><name></name></author><category term="MCMC" /><category term="Neural Networks" /><category term="mcmc" /><category term="neural networks" /><summary type="html"><![CDATA[Many interesting systems in science can’t be described by closed-form probability distributions. This makes them hard to analyze using classical statistical methods. One common example in biology is the stochastic Lotka–Volterra model, which describes the dynamics of random predator-prey interactions. It’s a simple model to simulate using the Gillespie algorithm, even though it lacks a closed-form likelihood function. Briefly, this is how the simulation works: Initialize a population of predators (\(a\)) and prey (\(b\)), along with per-capita rates of birth (\(\beta_\text{prey}, \beta_\text{predator}\)), death (\(\gamma_\text{prey}, \gamma_\text{predator}\)), and predation (\(\epsilon\)). Compute the event rates: Prey birth: \(b \cdot \beta_\text{prey}\) Prey death: \(b \cdot \gamma_\text{prey}\) Predator birth: \(a \cdot b \cdot \beta_\text{predator}\) Predation: \(a \cdot b \cdot \epsilon\) Predator death: \(a \cdot \gamma_\text{predator}\) Sum all event rates to get the total event rate \(r\), and draw a waiting time \(t \sim \text{Exponential}(r)\). Choose an event with probability proportional to its rate (e.g., the probability of predation is \(\frac{a \cdot b \cdot \epsilon}{r}\)), then update the system accordingly. Repeat steps 2–4 until a desired simulation time is reached. Despite being easy to simulate, this model has no closed-form likelihood, which makes traditional parameter estimation difficult. However, if we simulate the model many times under a fixed set of parameters, we get an empirical distribution of likely outcomes, which is essentially a way to score how well different parameters explain the data. Running millions of simulations for every parameter setting isn’t feasible. 
Instead, we can train a neural network to approximate this probability distribution based on a finite number of simulations. There are many kinds of these techniques, but here I’ll focus on one: Sequential Neural Likelihood Estimation (SNLE). SNLE uses normalizing flows to transform a simple base distribution (like a multivariate Gaussian) into one that mimics the complex data-generating process of the simulator. Over successive rounds, the method refines its approximation by focusing simulations on more plausible regions of the parameter space, allowing efficient and accurate inference in models where there is no likelihood. In general, the algorithm looks like this: Initialize a prior over parameters, \(P(\theta)\), and an autoregressive conditional normalizing flow (ACNF). The ACNF transforms a base multivariate Gaussian into a more flexible distribution using a series of affine transformations. Each transformation is conditioned on the preceding dimensions and parameter values. For more details, see Papamakarios et al. (2018). Generate training data by sampling parameter vectors \(\theta\) from the prior \(P(\theta)\) and simulating data from your generative model. Train the ACNF to approximate the likelihood \(L_x(\theta)\) based on these parameter–simulation pairs. This gives a crude estimate of the likelihood for the observed data \(x\). Run MCMC using the product \(P(\theta) L_x(\theta)\) as the unnormalized posterior. For each sampled \(\theta\), simulate new data under the model. Retrain the ACNF on all accumulated parameter–simulation pairs, including both prior samples and MCMC-based samples. This improves the likelihood approximation in regions of high posterior density. Repeat steps 4–5 until the ACNF has adequately learned the likelihood in the relevant region of the parameter space. I spent some time earlier this year trying to understand these methods and so I thought I would share a project of mine on the blog. 
Although I am unsure what I would use this method for in my own research, it is such a clever application of MCMC that I couldn’t help but implement it. The code below trains a conditional autoregressive flow neural network to learn an autoregressive multiplicative random walk that produces Lotka-Volterra dynamics. While Papamakarios et al. (2019) model the full trajectory as a static vector, I chose to model the incremental dynamics autoregressively. This better respects the sequential structure of the Lotka–Volterra system and allows the learned flow to generate new trajectories step-by-step. Below is an image of four draws from the trained conditional autoregressive flow; the positions and magnitudes of the oscillations are a classic Lotka-Volterra pattern. Although I will not print the output here, SNLE also does a pretty good job of inferring the true parameter values in many of my test simulations. conditional_autoregressive_flow.py import tensorflow as tf from tensorflow import keras import tensorflow_probability as tfp import matplotlib.pyplot as plt tfd = tfp.distributions class ConditionalAutoregressiveFlow(keras.Model): """ A neural network that autoregressively applies affine transformations to an N dimensional normal distribution conditional on some provided parameter. Rather than doing masking like in masked autoregressive flows, we are just explicitly enforcing the autoregressive nature of the network in the structure of the network itself. 
""" def __init__(self, num_params, num_dimensions, num_flows, internal_dim=16, lr=1e-2): super(ConditionalAutoregressiveFlow, self).__init__() self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr, clipnorm=1.0) self.num_params = num_params self.num_dimensions = num_dimensions self.internal_dim = internal_dim self.num_flows = num_flows self.autoregressive_networks = [] for n in range(num_flows): permutation = tf.random.shuffle(tf.range(0, num_dimensions, dtype=tf.int32)) if(n &gt; 0): while tf.math.reduce_all(permutation == self.autoregressive_networks[-1]["permutation"]): # We want to make sure each layer is different from the last permutation = tf.random.shuffle(tf.range(0, num_dimensions, dtype=tf.int32)) network = [] for i in range(num_dimensions): network.append({ "dense": keras.layers.Dense(internal_dim, activation="relu"), "alpha": keras.layers.Dense(1, activation="softplus",kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.5, stddev=0.1, seed=None)), "beta": keras.layers.Dense(1, activation="linear", kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=1.0, seed=None)), }) self.autoregressive_networks.append({"network": network, "permutation": permutation}) self.base_dist = tfd.MultivariateNormalDiag(loc=tf.ones([num_dimensions])) def call(self, flow_input): """Normalizing Direction""" conditionals = flow_input[:, :self.num_params] data = flow_input[:, self.num_params:] tf.debugging.assert_equal(data.shape[1], self.num_dimensions, f"Error: {self.num_dimensions} dimensions of data were expected, but we only got {data.shape[1]}") output_list = [data[:, i] for i in range(self.num_dimensions)] jacobian_sum = 0 for n in range(self.num_flows): temp_conditionals = conditionals network_obj = self.autoregressive_networks[n] autoregressive_network = network_obj["network"] network_permutation = network_obj["permutation"] for i in range(self.num_dimensions): index = network_permutation[i] y = output_list[index] flow = 
autoregressive_network[index] x = flow["dense"](temp_conditionals) alpha = tf.squeeze(flow["alpha"](x)) beta = tf.squeeze(flow["beta"](x)) y = tf.divide(y - beta, alpha) output_list[index] = y temp_conditionals = tf.concat([temp_conditionals, tf.expand_dims(y, axis=-1)], axis=-1) jacobian_sum += tf.math.log(alpha) output = tf.stack(output_list, axis=-1) log_prob = self.base_dist.log_prob(output) log_prob -= jacobian_sum return log_prob def transform(self, flow_input): """Generative Direction""" conditionals = flow_input[:, :self.num_params] data = flow_input[:, self.num_params:] tf.debugging.assert_equal(data.shape[1], self.num_dimensions, f"Error: {self.num_dimensions} dimensions of data were expected, but we only got {data.shape[1]}") output_list = [data[:, i] for i in range(self.num_dimensions)] for n in range(self.num_flows-1, -1, -1): temp_conditionals = conditionals network_obj = self.autoregressive_networks[n] autoregressive_network = network_obj["network"] network_permutation = network_obj["permutation"] for i in range(self.num_dimensions): index = network_permutation[i] y = output_list[index] flow = autoregressive_network[index] x = flow["dense"](temp_conditionals) alpha = tf.squeeze(flow["alpha"](x)) beta = tf.squeeze(flow["beta"](x)) y = tf.multiply(y, alpha) + beta temp_conditionals = tf.concat([temp_conditionals, tf.expand_dims(output_list[index], axis=-1)], axis=-1) output_list[index] = y output = tf.stack(output_list, axis=-1) return output def draw(self, parameters, num_draws = 1): """Generative Direction""" conditionals = tf.tile(tf.expand_dims(parameters, axis=0), [num_draws,1]) data = self.base_dist.sample(num_draws) flow_input = tf.concat([conditionals, data], axis=-1) return self.transform(flow_input) def train_step(self, data): """Negative log-likelihood loss""" with tf.GradientTape() as tape: log_likelihood = self(data) loss = -tf.reduce_mean(log_likelihood) trainable_vars = self.trainable_variables gradients = tape.gradient(loss, 
trainable_vars) self.optimizer.apply_gradients(zip(gradients, trainable_vars)) return {"loss": loss} if __name__=="__main__": scale = tf.random.uniform([500], 4.0, 8.0, dtype=tf.float32) scale = tf.reshape(tf.repeat(scale, 1000), [-1]) quandrant_flag = tfd.Bernoulli(0.5, dtype=tf.bool).sample(scale.shape[0]) mean_vec = tf.where(quandrant_flag, scale * -1.0, scale) training_samples = tf.squeeze(tfd.MultivariateNormalDiag([mean_vec, mean_vec*-1]).sample(1)) training_input = tf.concat([tf.expand_dims(scale, axis=-1), tf.transpose(training_samples)], axis=-1) test_af = ConditionalAutoregressiveFlow(num_params = 1, num_dimensions = 2, num_flows = 5, internal_dim = 64) losses = [] for i in range(500): loss = test_af.train_step(training_input)['loss'] losses.append(loss) print(f"Loss at Epoch {i}: {loss}") transformed_data = test_af.draw([5.5], num_draws=1000) dummy_data = tfd.MultivariateNormalDiag(loc=tf.ones([2])).sample(1000) dummy_data_input = tf.concat([tf.fill([dummy_data.shape[0], 1], 7.5), dummy_data], axis=-1) transformed_data_2 = test_af.transform(dummy_data_input) fig1, ax1 = plt.subplots() ax1.scatter(x = transformed_data[:, 0], y = transformed_data[:, 1], label="Scale = 5.5", alpha=0.5) ax1.scatter(x = transformed_data_2[:, 0], y = transformed_data_2[:, 1], label="Scale = 7.5", alpha=0.5) ax1.legend() ax1.set_ylabel("Y") ax1.set_xlabel("X") ax1.set_title("Transformation") fig1.savefig("./conditional_af_transformation.png") x_min = -10 x_max = 10 y_min = -10 y_max = 10 num_x = 100 num_y = 100 x = tf.linspace(x_min, x_max, num_x) y = tf.linspace(y_min, y_max, num_y) X, Y = tf.meshgrid(x, y) coordinates = tf.stack([tf.reshape(X, [-1]), tf.reshape(Y, [-1])], axis=-1) coordinates = tf.cast(coordinates, tf.float32) coordinate_input = tf.concat([tf.fill([coordinates.shape[0], 1], 7.5), coordinates], axis=-1) probs = tf.math.exp(test_af(coordinate_input)) heatmap = tf.reshape(probs, (num_y, num_x)) fig2, ax2 = plt.subplots() ax2.imshow(heatmap.numpy(), 
extent=[x_min, x_max, y_min, y_max], origin='lower', cmap='viridis') ax2.set_ylabel("Y") ax2.set_xlabel("X") ax2.set_title("Transform Probability") fig2.savefig("./conditional_af_pdf.png") sequential_neural_likelihood_lotka_volterra.py import tensorflow as tf from tensorflow import keras import tensorflow_probability as tfp import matplotlib.pyplot as plt from conditional_autoregressive_flow import ConditionalAutoregressiveFlow import math tfd = tfp.distributions class LogConditionalAutoregressiveFlow(ConditionalAutoregressiveFlow): """ Our conditionals interact on a multiplicative scale, and we are modeling a multiplicative random walk - so our neural network should reflect that. We log transform all of our conditionals to express that. """ def call(self, flow_input): """Normalizing Direction""" conditionals = tf.math.log(flow_input[:, :self.num_params]) data = tf.math.log(flow_input[:, self.num_params:]) tf.debugging.assert_equal(data.shape[1], self.num_dimensions, f"Error: {self.num_dimensions} dimensions of data were expected, but we only got {data.shape[1]}") output_list = [data[:, i] for i in range(self.num_dimensions)] # Adjust for the log transform (-1 * log(original data)) jacobian_sum = -1 * tf.reduce_sum(data, axis=-1) for n in range(self.num_flows): temp_conditionals = conditionals network_obj = self.autoregressive_networks[n] autoregressive_network = network_obj["network"] network_permutation = network_obj["permutation"] for i in range(self.num_dimensions): index = network_permutation[i] y = output_list[index] flow = autoregressive_network[index] x = flow["dense"](temp_conditionals) alpha = tf.squeeze(flow["alpha"](x)) beta = tf.squeeze(flow["beta"](x)) y = tf.divide(y - beta, alpha) output_list[index] = y temp_conditionals = tf.concat([temp_conditionals, tf.expand_dims(y, axis=-1)], axis=-1) jacobian_sum += tf.math.log(alpha) output = tf.stack(output_list, axis=-1) log_prob = self.base_dist.log_prob(output) log_prob -= jacobian_sum return log_prob def 
transform(self, flow_input): """Generative Direction""" conditionals = tf.math.log(flow_input[:, :self.num_params]) data = flow_input[:, self.num_params:] tf.debugging.assert_equal(data.shape[1], self.num_dimensions, f"Error: {self.num_dimensions} dimensions of data were expected, but we only got {data.shape[1]}") output_list = [data[:, i] for i in range(self.num_dimensions)] for n in range(self.num_flows-1, -1, -1): temp_conditionals = conditionals network_obj = self.autoregressive_networks[n] autoregressive_network = network_obj["network"] network_permutation = network_obj["permutation"] for i in range(self.num_dimensions): index = network_permutation[i] y = output_list[index] flow = autoregressive_network[index] x = flow["dense"](temp_conditionals) alpha = tf.squeeze(flow["alpha"](x)) beta = tf.squeeze(flow["beta"](x)) y = tf.multiply(y, alpha) + beta temp_conditionals = tf.concat([temp_conditionals, tf.expand_dims(output_list[index], axis=-1)], axis=-1) output_list[index] = y output = tf.stack(output_list, axis=-1) return tf.math.exp(output) def generative_model(prey_birth, predation, predator_birth, predator_death, init_predators=30, init_prey=60, total_time=100, samples_per_unit_time=5): """ Simulate a time series of stochastic Lotka-Volterra dynamics. Here we will output the multiplicative increase from the previous step. 
""" # Enforce minimum for simplicity prey_birth = prey_birth + 1e-2 predation = predation + 1e-2 predator_birth = predator_birth + 1e-2 predator_death = predator_death + 1e-2 times = [0.0] predators = [init_predators] prey = [init_prey] while True: # Get current population prey_count = prey[-1] predator_count = predators[-1] # Determine event rates prey_birth_rate = prey_birth*prey_count predation_rate = predation*prey_count*predator_count predator_death_rate = predator_death*predator_count predator_birth_rate = predator_birth*predator_count*prey_count # Get total "race" rate total_rate = prey_birth_rate + predation_rate + predator_death_rate + predator_birth_rate if(total_rate == 0): # Total extinction break # These are what we based the uniform draw on prey_birth_cumulative = prey_birth_rate/total_rate predation_cumulative = predation_rate/total_rate + prey_birth_cumulative predator_death_cumulative = predator_death_rate/total_rate + predation_cumulative waiting_time = tfd.Exponential(total_rate).sample(1) new_time = times[-1] + waiting_time if(new_time &lt; total_time): times.append(new_time) # Draw a random event (birth, death, predation) action_draw = tf.random.uniform([1], 0.0, 1.0) if(action_draw &lt; prey_birth_cumulative): predators.append(predator_count) prey.append(prey_count + 1.0) elif(action_draw &lt; predation_cumulative): predators.append(predator_count) prey.append(prey_count - 1.0) elif(action_draw &lt; predator_death_cumulative): predators.append(predator_count - 1.0) prey.append(prey_count) else: predators.append(predator_count + 1.0) prey.append(prey_count) else: break time_sampled_prey = [] time_sampled_predator = [] increment = 1/samples_per_unit_time for i in range(total_time * samples_per_unit_time): last_event = 0 for t in range(len(times)): if(times[t] &lt;= i * increment): last_event = t else: break time_sampled_predator.append(predators[t]) time_sampled_prey.append(prey[t]) current_dim = tf.stack([time_sampled_predator, 
time_sampled_prey], axis=-1) prev_dim = tf.roll(current_dim, shift=1, axis=0) # Now we get the increase multiplicative_increase = tf.math.divide(current_dim, prev_dim) multiplicative_increase = tf.where(tf.math.is_nan(multiplicative_increase), 1e-6, multiplicative_increase) multiplicative_increase = tf.where(multiplicative_increase &lt;= 0, 1e-6, multiplicative_increase) prev_dim = tf.where(prev_dim == 0, 1e-6, prev_dim) # Roll moves the last element to the first, so we shave that off output = tf.concat([prev_dim, multiplicative_increase], axis=-1)[1:, :] return output def parameter_update(parameter): param_choice = tf.random.uniform([1], 0, 4, dtype=tf.int32) # Pick random param param_choice = tf.reshape(param_choice, [1, 1]) scaler = tf.math.exp(2.0 * tf.random.uniform([1], 0, 1) - 0.5) updated_value = parameter[param_choice[0, 0]] * scaler parameter = tf.tensor_scatter_nd_update(parameter, param_choice, updated_value) return parameter, tf.math.log(scaler) def mcmc(data, prior, likelihood = None, iterations = 1000, debug = True, sample_iter = 1000): """ If a likelihood is not passed we will do Metropolis Hastings on the prior alone. For now we are just going to be doing scale moves on the parameter when it comes to working with the product of the prior and model. For the prior alone we will just draw directly from the prior. This function the final state of the Markov chain, which can be plugged into the generative model. 
""" return_params = [] if(likelihood is None): for i in range(int(iterations/sample_iter)): param = tf.squeeze(prior.sample()) if(debug): tf.print(f"Drawing parameter from prior: {param}") return_params.append(param) else: param = tf.squeeze(prior.sample()) old_posterior = tf.reduce_sum(prior.log_prob(param)) + tf.reduce_sum(likelihood(tf.concat([tf.tile(tf.expand_dims(param, axis=0), [data.shape[0], 1]), data], axis=-1))) for i in range(1, iterations+1): new_param, hastings = parameter_update(param) new_posterior = tf.reduce_sum(prior.log_prob(new_param)) + tf.reduce_sum(likelihood(tf.concat([tf.tile(tf.expand_dims(new_param, axis=0), [data.shape[0], 1]), data], axis=-1))) posterior_ratio = hastings + new_posterior - old_posterior draw = tf.random.uniform([1], 0, 1) if(tf.math.log(draw) &lt; posterior_ratio): param = new_param old_posterior = new_posterior if(debug): tf.print(f"Accepted new parameter value at iteration {i}: {param}") if(i % sample_iter == 0): return_params.append(param) return return_params def train_flow_model(training_input, settings, debug = True): model = LogConditionalAutoregressiveFlow(settings["num_params"], settings["num_dimensions"], settings["num_flows"], settings["internal_dim"], lr=1e-3) patience = 25 min_delta = 1e-2 best_loss = float('inf') waited_epochs = 0 losses = [] for i in range(settings["training_iterations"]): loss = model.train_step(training_input)['loss'] losses.append(loss) if(debug): tf.print(f"Loss at Epoch {i}: {loss}") if(loss &lt; best_loss - min_delta): best_loss = loss waited_epochs = 0 else: waited_epochs += 1 if(waited_epochs &gt;= patience): if(debug): tf.print(f"Stopping early at epoch {i}") break return model def snle_step(data, prior, model, settings): training_data_list = [] training_param_list = [] for n in range(settings["num_mcmc_chains"]): param = mcmc(data, prior, model, settings["mcmc_iterations"]) for i in range(settings["num_simulations"]): for y in param: t_data = generative_model(y[0], y[1], y[2], 
y[3], total_time = 10, samples_per_unit_time = 10) training_param_list.append(y) training_data_list.append(t_data) return training_param_list, training_data_list if __name__ == "__main__": settings = {"num_params": 6, "num_dimensions": 2, "num_flows": 3, "internal_dim": 64, "num_retrainings": 4, "num_simulations": 5, "num_mcmc_chains": 100, "mcmc_iterations": 2000, "training_iterations": 2500} training_data_list = [] training_param_list = [] model = None prior = tfd.Exponential(rate = [4, 20, 20, 4]) observed_data = generative_model(0.75, 0.01, 0.01, 1.0, init_predators = 30, init_prey = 60, total_time = 25, samples_per_unit_time = 10) print("We generated:") print(observed_data) for i in range(settings["num_retrainings"]): param, data = snle_step(observed_data, prior, model, settings) for d in data: training_data_list.append(d) for p in param: training_param_list.append(tf.tile(tf.expand_dims(p, axis=0), [99,1])) # 99 time samples per param because one is sliced off training_data = tf.concat(training_data_list, axis=0) training_param = tf.concat(training_param_list, axis=0) training_input = tf.concat([training_param, training_data], axis=-1) model = train_flow_model(training_input, settings) fig, axes = plt.subplots(2,2) for n, ax in enumerate(axes.flat, start=1): param_input = tf.constant([0.75, 0.01, 0.01, 1.0, 60.0, 30.0], dtype=tf.float32) prey_counts = [60.0] predator_counts = [30.0] for i in range(250): next_step = model.draw(param_input, num_draws=1) predator = max(1e-6, next_step[0, 0].numpy() * predator_counts[-1]) prey = max(1e-6, next_step[0, 1].numpy() * prey_counts[-1]) predator_counts.append(predator) prey_counts.append(prey) param_input = tf.constant([0.75, 0.01, 0.01, 1.0, predator, prey], dtype=tf.float32) ax.plot(prey_counts[1:], label="Prey") ax.plot(predator_counts[1:], label="Predators") ax.set_ylabel("Count") ax.set_xlabel("Time") ax.legend() fig.tight_layout() fig.savefig("./lotka_volterra_trajectories.png")]]></summary></entry><entry><title 
type="html">Variational Bayes from a Generalized Bayesian Inference Perspective</title><link href="https://wesley-demontigny.github.io/bayesian/variational%20inference/2025/04/09/Generalized_Bayesian_Inference_VI.html" rel="alternate" type="text/html" title="Variational Bayes from a Generalized Bayesian Inference Perspective" /><published>2025-04-09T00:00:00+00:00</published><updated>2025-04-09T00:00:00+00:00</updated><id>https://wesley-demontigny.github.io/bayesian/variational%20inference/2025/04/09/Generalized_Bayesian_Inference_VI</id><content type="html" xml:base="https://wesley-demontigny.github.io/bayesian/variational%20inference/2025/04/09/Generalized_Bayesian_Inference_VI.html"><![CDATA[<script type="text/javascript" id="MathJax-script" async="" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js">
</script>

<p>Markov Chain Monte Carlo (MCMC) is an amazing tool for Bayesian inference, but it’s unfortunately quite computationally intensive. Often, that’s fine, and we’re willing to wait for an asymptotically exact answer. But that isn’t always the case. Certain models are so complicated that the Markov chain will fail to converge until we are all dead. And depending on the application, even a day may be far too long to wait for a model to be fit! In these situations, we need an alternative to MCMC.</p>

<p>Variational inference (VI) is one such alternative. It aims to approximate the posterior distribution by using a tractable “variational” distribution to represent the true posterior. This turns Bayesian inference into an optimization problem: instead of sampling from the posterior, we optimize the parameters of the variational distribution to make it as close to the true posterior as possible. For example, suppose we believe that our posterior is roughly Gaussian-shaped. In that case, we may use a Gaussian as our variational distribution for our parameter \(\theta\), and optimize the mean \(\mu\) and variance \(\sigma^2\) accordingly. While this approach has its downsides, it is generally much faster than MCMC.</p>

<p>When I first encountered these methods, I thought it was interesting that the derivation for the objective function, the <strong>E</strong>vidence <strong>L</strong>ower <strong>Bo</strong>und or ELBO, wasn’t immediately obvious. Sure, it all makes sense once you know it, but it would take some thought if I were the first to derive it. Lately, I’ve been reading a lot of material on Generalized Bayesian Inference (GBI), which I will review later in this post. Once you really understand that framework, I think it offers a much stronger motivation for the ELBO than the traditional derivation. Additionally, under GBI, VI can be seen as a fully Bayesian approach, rather than an approximation.</p>
<h3 id="traditional-derivations-for-variational-inference">Traditional Derivations for Variational Inference</h3>
<p>Suppose we have some variational distribution \(q(\theta | \psi)\) that we want to use as an approximation for the posterior distribution \(p(\theta\mid x)\). The most sensible thing to do would be to pick the parameter \(\psi\) such that the KL divergence between the variational distribution and posterior is minimized:</p>

\[KL(q(\theta|\psi) ||p(\theta|x)) = E_{\theta \sim q}(\text{log }\frac{q(\theta|\psi)}{p(\theta|x)})\]

<p>We can expand \(p(\theta \mid x) = \frac{p(\theta, x)}{p(x)}\), so that we get:</p>

\[=E_{\theta \sim q}(\text{log }\frac{q(\theta|\psi)}{p(\theta,x)} + \text{log }p(x))\]

\[=E_{\theta \sim q}(\text{log }\frac{q(\theta|\psi)}{p(\theta,x)}) + \text{log }p(x) = \text{log } p(x) - \text{ELBO}\]

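<p>Since \(p(x)\), the marginal likelihood, does not depend on \(\psi\), we can ignore it for our optimization and focus on the ELBO. By maximizing the ELBO, we will minimize the KL divergence. And because the KL divergence is never negative, the ELBO can never exceed \(\text{log } p(x)\), which is where the name “evidence lower bound” comes from. We can rewrite the ELBO like so:</p>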

\[\text{ELBO} = E_{\theta\sim q}(\text{log } p(\theta,x)) - E_{\theta\sim q}(\text{log } q(\theta|\psi))\]

\[= E_{\theta \sim q}(\text{log }p(x|\theta)) + E_{\theta\sim q}(\text{log } p(\theta)) - E_{\theta\sim q}(\text{log } q(\theta|\psi))\]

\[= E_{\theta \sim q}(\text{log }p(x|\theta)) - KL(q(\theta|\psi) || p(\theta))\]

<p>So, to do variational inference, we pick a \(\psi\) that maximizes the expected log-likelihood while keeping the variational distribution close to the prior in KL divergence. This all makes sense, but if I had been the first one to derive it, I think I would have stumbled around a bit before coming to the solution. Let’s review GBI and see what insights it has to offer.</p>
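<p>To make the identity \(\text{log } p(x) - \text{ELBO} = KL(q(\theta|\psi) || p(\theta|x))\) concrete, here is a minimal NumPy sketch. It is not from the derivation above; it assumes a toy conjugate model (prior \(\theta \sim N(0,1)\), likelihood \(x_i \sim N(\theta, 1)\), variational family \(q = N(m, s^2)\)) chosen only because every term has a closed form. The data and the function names are hypothetical.</p>

```python
import numpy as np

# Hypothetical toy data; the model below is an assumption made for illustration:
# prior theta ~ N(0, 1), likelihood x_i ~ N(theta, 1), q(theta) = N(m, s^2).
rng = np.random.default_rng(0)
x = rng.normal(loc=1.5, scale=1.0, size=20)


def elbo(m, s, x):
    """ELBO = E_q[log p(x|theta)] - KL(q || prior), in closed form."""
    # Under q, E[(x_i - theta)^2] = (x_i - m)^2 + s^2.
    expected_loglik = np.sum(-0.5 * np.log(2 * np.pi) - 0.5 * ((x - m) ** 2 + s ** 2))
    # KL between N(m, s^2) and the N(0, 1) prior.
    kl_to_prior = 0.5 * (s ** 2 + m ** 2 - 1.0) - np.log(s)
    return expected_loglik - kl_to_prior


def log_evidence(x):
    """log p(x): marginally, x is Gaussian with covariance I + 11^T."""
    n = x.size
    quad = np.sum(x ** 2) - np.sum(x) ** 2 / (1 + n)
    return -0.5 * n * np.log(2 * np.pi) - 0.5 * np.log(1 + n) - 0.5 * quad


def kl_to_posterior(m, s, x):
    """KL(q || p(theta|x)); the exact posterior is N(sum(x)/(n+1), 1/(n+1))."""
    n = x.size
    post_mean = np.sum(x) / (n + 1)
    post_var = 1.0 / (n + 1)
    return (0.5 * np.log(post_var) - np.log(s)
            + (s ** 2 + (m - post_mean) ** 2) / (2 * post_var) - 0.5)


m, s = 0.3, 0.8  # an arbitrary, deliberately poor choice of psi = (m, s)
gap = log_evidence(x) - elbo(m, s, x)
print("log p(x) - ELBO :", gap)
print("KL(q || post)   :", kl_to_posterior(m, s, x))
```

<p>The two printed numbers should agree up to floating-point error for any choice of \(m\) and \(s\), and the gap closes when \(q\) is set to the exact posterior.</p>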
<h3 id="the-generalized-bayesian-inference-approach">The Generalized Bayesian Inference Approach</h3>
<p>This is just going to be a quick overview of GBI, but it is SO cool. For a comprehensive treatment, I really recommend Bissiri, Holmes &amp; Walker, 2016 (<a href="https://arxiv.org/abs/1306.6430">A General Framework for Updating Belief Distributions</a>). In brief, GBI aims to extend Bayesian updating to models that do not necessarily have a likelihood but do have some loss function \(\ell_\theta(x)\) that measures how poorly some parameter \(\theta\) agrees with the observed data \(x\). In such a scenario, how does one update some prior belief \(p(\theta)\)? We can think about all hypothetical posterior distributions \(\pi\), and we will aim to choose the updating scheme resulting in:</p>

\[\pi^* = argmin_\pi R(\pi, p(\theta), x), \quad R(\pi, p(\theta), x)=f(\pi,p(\theta)) + g(\pi,x)\]

<p>where \(R\) is a risk function composed of a loss function \(f\) that measures the posterior’s agreement with the prior and a loss function \(g\) that measures the posterior’s agreement with the data. This looks like something that would be insanely difficult to solve, but it actually is pretty trivial if we make some clever choices for \(f\) and \(g\). We will say:</p>

\[f(\pi,p(\theta))=KL(\pi(\theta)||p(\theta))\]

\[g(\pi,x)=E_{\theta\sim\pi}(\ell_\theta(x))\]

<p>We can therefore represent the risk function \(R\) as:</p>

\[R(\pi,p(\theta),x)=\int\pi(\theta)(\ell_\theta(x)+\text{log }\frac{\pi(\theta)}{p(\theta)})d\theta\]

<p>We can move \(\ell_\theta(x)\) inside the log so we get:</p>

\[R(\pi,p(\theta),x)=\int\pi(\theta)\text{log }\frac{\pi(\theta)}{exp(-\ell_\theta(x))p(\theta)}d\theta = KL(\pi(\theta) || exp(-\ell_\theta(x))p(\theta))\]

<p>Writing \(Z = \int exp(-\ell_\theta(x))p(\theta)d\theta\) for the normalizing constant, the risk equals \(KL(\pi(\theta) || \pi^*(\theta)) - \text{log } Z\) with \(\pi^*(\theta) = exp(-\ell_\theta(x))p(\theta)/Z\). Since \(Z\) does not depend on \(\pi\), the posterior that minimizes the risk is just \(\pi^*(\theta) \propto exp(-\ell_\theta(x))p(\theta)\)! We call this the “Gibbs posterior” because it looks like a Gibbs distribution with an energy function \(\ell_\theta\) multiplied by some prior. If we choose a loss function \(\ell_\theta(x) = - \text{log } p(x\mid\theta)\), we get the classical posterior, \(\pi(\theta) \propto p(x\mid\theta)p(\theta)\)! However, something really amazing here is that we could have chosen another loss function, and we would still be able to do Bayesian updating. This isn’t exactly a trivial thing to do, but it is totally legitimate under GBI.</p>
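<p>The Gibbs posterior is easy to play with on a grid. The sketch below is only an illustration under assumptions that are not part of the post: a one-dimensional \(\theta\), a \(N(0,1)\) prior, a small made-up dataset with one outlier, and an absolute-error loss as the “other” loss function. It checks numerically that the negative log-likelihood loss recovers the classical posterior, and that the Gibbs posterior attains a risk of minus the log of its normalizing constant.</p>

```python
import numpy as np

# Grid over a one-dimensional parameter (an assumption made for illustration).
theta = np.linspace(-6.0, 6.0, 2401)
dtheta = theta[1] - theta[0]
x = np.array([0.8, 1.1, 1.9, 7.0])  # hypothetical data with one outlier
n = x.size


def integrate(values):
    """Riemann-sum approximation of an integral over the theta grid."""
    return np.sum(values) * dtheta


# N(0, 1) prior, normalized on the grid.
prior = np.exp(-0.5 * theta ** 2)
prior = prior / integrate(prior)

# Two loss functions evaluated at every grid point.
residuals = x[:, None] - theta[None, :]
nll_loss = np.sum(0.5 * residuals ** 2 + 0.5 * np.log(2 * np.pi), axis=0)  # -log N(x | theta, 1)
abs_loss = np.sum(np.abs(residuals), axis=0)                                # no likelihood needed


def gibbs_posterior(loss):
    """Return exp(-loss) * prior, normalized, together with its normalizer Z."""
    unnormalized = np.exp(-loss) * prior
    z = integrate(unnormalized)
    return unnormalized / z, z


def risk(pi, loss):
    """R = E_pi[loss] + KL(pi || prior), computed on the grid."""
    return integrate(pi * (loss + np.log(pi / prior)))


# With the negative log-likelihood loss we recover the classical conjugate posterior.
gibbs_nll, z_nll = gibbs_posterior(nll_loss)
post_mean, post_var = np.sum(x) / (n + 1), 1.0 / (n + 1)
classical = np.exp(-0.5 * (theta - post_mean) ** 2 / post_var) / np.sqrt(2 * np.pi * post_var)
print("max abs difference from classical posterior:", np.max(np.abs(gibbs_nll - classical)))

# With the absolute loss we still get a valid belief update.
gibbs_abs, z_abs = gibbs_posterior(abs_loss)
risk_gibbs = risk(gibbs_abs, abs_loss)
risk_prior = risk(prior, abs_loss)  # risk if we refuse to update at all
print("risk of Gibbs posterior:", risk_gibbs, " -log Z:", -np.log(z_abs))
print("risk of leaving the prior unchanged:", risk_prior)
print("posterior means (nll, abs):", integrate(theta * gibbs_nll), integrate(theta * gibbs_abs))
```

<p>Comparing the two posterior means gives a feel for how the choice of loss changes the update; the grid approach is only practical in very low dimensions.</p>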

<p>Now, how does this connect to variational inference? Instead of considering all possible posteriors, let’s assume the posterior belongs to a parametric family \(q(\theta\mid\psi)\) for our parameters of interest \(\theta\) and tunable parameters \(\psi\). We will also choose the classical loss function \(\ell_\theta(x) = -\text{log } p(x\mid\theta)\). In that case, our risk minimization looks like:</p>

\[q^*=argmin_\psi R(\psi,p(\theta),x) = argmin_\psi (- E_{\theta\sim q}(\text{log } p(x|\theta)) + KL(q(\theta|\psi)||p(\theta)))\]

<p>Minimizing the risk is the same as maximizing its negative, so we can rewrite this as:</p>

\[q^*=argmax_\psi -R(\psi, p(\theta), x)= argmax_\psi (E_{\theta\sim q}(\text{log } p(x|\theta)) - KL(q(\theta|\psi)||p(\theta)))\]

<p>That objective is exactly the ELBO from the traditional derivation, so we maximize the ELBO! So, under GBI, VI is just a special case where we constrain the updated belief to a particular parametric family.</p>
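<p>As a rough sketch of this constrained risk minimization, the code below runs plain gradient descent on \(R\) over \(\psi\). It assumes the same kind of toy setup as before, which is not part of the post: a \(N(0,1)\) prior, a \(N(\theta, 1)\) likelihood, made-up data, and a Gaussian family \(q = N(m, s^2)\) parameterized by \(\psi = (m, \text{log } s)\). The gradients are derived by hand for this model only. Because the exact posterior happens to lie inside the Gaussian family here, the minimizer should coincide with it.</p>

```python
import numpy as np

# Hypothetical data; the model is an assumption made for illustration:
# prior theta ~ N(0, 1), likelihood x_i ~ N(theta, 1), q(theta) = N(m, s^2).
rng = np.random.default_rng(1)
x = rng.normal(loc=1.5, scale=1.0, size=20)
n = x.size


def risk(m, rho):
    """R = -E_q[log p(x|theta)] + KL(q || prior), with s = exp(rho)."""
    s2 = np.exp(2.0 * rho)
    expected_nll = np.sum(0.5 * np.log(2 * np.pi) + 0.5 * ((x - m) ** 2 + s2))
    kl_to_prior = 0.5 * (s2 + m ** 2 - 1.0) - rho
    return expected_nll + kl_to_prior


def risk_gradient(m, rho):
    """Hand-derived gradient of the risk with respect to (m, rho)."""
    s2 = np.exp(2.0 * rho)
    d_m = (n + 1) * m - np.sum(x)
    d_rho = (n + 1) * s2 - 1.0
    return d_m, d_rho


# Plain gradient descent on the risk, i.e. gradient ascent on the ELBO.
m, rho = 0.0, 0.0
learning_rate = 0.01
for _ in range(5000):
    d_m, d_rho = risk_gradient(m, rho)
    m -= learning_rate * d_m
    rho -= learning_rate * d_rho

print("variational mean, variance:", m, np.exp(2.0 * rho))
print("exact posterior mean, variance:", np.sum(x) / (n + 1), 1.0 / (n + 1))
print("ELBO at the optimum:", -risk(m, rho))
```

<p>For models without closed-form expectations, the same loop would need Monte Carlo gradient estimates (for example via the reparameterization trick) in place of <code>risk_gradient</code>.</p>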
<h3 id="concluding-remarks">Concluding Remarks</h3>
<p>I hope that all made sense and was useful! I have been thinking about GBI a lot, so this felt more intuitive. Under the GBI perspective, VI isn’t a cheap method for faster inference but the natural method for those willing to constrain their updated belief to a family of distributions. In this sense, VI can be viewed as a fully Bayesian update procedure.</p>]]></content><author><name></name></author><category term="Bayesian" /><category term="Variational Inference" /><category term="generalized bayesian" /><category term="VI" /><category term="inference" /><summary type="html"><![CDATA[Markov Chain Monte Carlo (MCMC) is an amazing tool for Bayesian inference, but it’s unfortunately quite computationally intensive. Often, that’s fine, and we’re willing to wait for an asymptotically exact answer. But that isn’t always the case. Certain models are so complicated that the Markov chain will fail to converge until we are all dead. And depending on the application, even a day may be far too long to wait for a model to be fit! In these situations, we need an alternative to MCMC. Variational inference (VI) is one such alternative. It aims to approximate the posterior distribution by using a tractable “variational” distribution to represent the true posterior. This turns Bayesian inference into an optimization problem: instead of sampling from the posterior, we optimize the parameters of the variational distribution to make it as close to the true posterior as possible. For example, suppose we believe that our posterior is roughly Gaussian-shaped. In that case, we may use a Gaussian as our variational distribution for our parameter \(\theta\), and optimize the mean \(\mu\) and variance \(\sigma^2\) accordingly. While this approach has its downsides, it is generally much faster than MCMC. When I first encountered these methods, I thought it was interesting that the derivation for the objective function- the Evidence Lower Bound or ELBO - wasn’t immediately obvious. 
Sure, it all makes sense once you know it, but it would take some thought if I were the first to derive it. Lately, I’ve been reading a lot of material on Generalized Bayesian Inference (GBI), which I will review later in this post. Once you really understand that framework, I think it offers a much stronger motivation for the ELBO than the traditional derivation. Additionally, under GBI, VI can be seen as a fully Bayesian approach, rather than an approximation. Traditional Derivations for Variational Inference Suppose we have some variational distribution \(q(\theta | \psi)\) that we want to use as an approximation for the posterior distribution \(p(\theta\mid x)\). The most sensible thing to do would be to pick the parameter \(\psi\) such that the KL divergence between the variational distribution and posterior is minimized: \[KL(q(\theta|\psi) ||p(\theta|x)) = E_{\theta \sim q}(\text{log }\frac{q(\theta|\psi)}{p(\theta|x)})\] We can expand \(p(\theta \mid x) = \frac{p(\theta, x)}{p(x)}\), so that we get: \[=E_{\theta \sim q}(\text{log }\frac{q(\theta|\psi)}{p(\theta,x)} + \text{log }p(x))\] \[=E_{\theta \sim q}(\text{log }\frac{q(\theta|\psi)}{p(\theta,x)}) + \text{log }p(x) = \text{log } p(x) - \text{ELBO}\] Since \(p(x)\), the marginal likelihood, is a constant, we ignore it for our optimization and focus on the ELBO. By maximizing the ELBO, we will minimize the KL divergence. We can rewrite the ELBO like so: \[\text{ELBO} = E_{\theta\sim q}(\text{log } p(\theta,x)) - E_{\theta\sim q}(\text{log } q(\theta|\psi))\] \[= E_{\theta \sim q}(\text{log }p(x|\theta)) + E_{\theta\sim q}(\text{log } p(\theta)) - E_{\theta\sim q}(\text{log } q(\theta|\psi))\] \[= E_{\theta \sim q}(\text{log }p(x|\theta)) - KL(q(\theta|\psi) || p(\theta))\] So, to do variational inference, we pick a \(\psi\) to maximize the expected likelihood and minimize the divergence of the variational distribution and the prior. 
This all makes sense, but if I had been the first one to derive it, I think I would have stumbled around a bit before coming to the solution. Let’s review GBI and see what insights it has to offer. The Generalized Bayesian Inference Approach This is just going to be a quick overview of GBI, but it is SO cool. For a comprehensive treatment, I really recommend Bissiri, Holmes &amp; Walker, 2016 (A General Framework for Updating Belief Distributions). In brief, GBI aims to extend Bayesian updating to models that do not necessarily have a likelihood but do have some loss function \(\ell_\theta(x)\) that indicates the agreement between some parameter \(\theta\) and observed data \(x\). In such a scenario, how does one update some prior belief \(p(\theta)\)? We can think about all hypothetical posterior distributions \(\pi\), and we will aim to choose the updating scheme resulting in: \[\pi^* = argmin_\pi R(\pi, p(\theta), x)=f(\pi,p(\theta)) + g(\pi,x)\] where \(R\) is a risk function composed of a loss function \(f\) that measures the posterior’s agreement with the prior and a loss function \(g\) that measures the posterior’s agreement with the data. This looks like something that would be insanely difficult to solve, but it actually is pretty trivial if we pick some clever choices for \(f\) and \(g\). We will say: \[f(\pi,p(\theta))=KL(\pi(\theta)||p(\theta))\] \[g(\pi,x)=E_{\theta\sim\pi}(\ell_\theta(x))\] We can therefore represent the risk function \(R\) as: \[R(\pi,p(\theta),x)=\int\pi(\theta)(\ell_\theta(x)+\text{log }\frac{\pi(\theta)}{p(\theta)})d\theta\] We can move \(\ell_\theta(x)\) inside the log so we get: \[R(\pi,p(\theta),x)=\int\pi(\theta)\text{log }\frac{\pi(\theta)}{exp(-\ell_\theta(x))p(\theta)}d\theta = KL(\pi(\theta) || exp(-\ell_\theta(x))p(\theta))\] The posterior that minimizes this KL divergence is just \(\pi^*(\theta) \propto exp(-\ell_\theta(x))p(\theta)\)! 
We call this the “Gibbs posterior” because it looks like a Gibbs distribution with an energy function \(\ell_\theta\) multiplied by some prior. If we choose a loss function \(\ell_\theta(x) = - \text{log } p(x\mid\theta)\), we get the classical posterior, \(\pi(\theta) \propto p(x\mid\theta)p(\theta)\)! However, something really amazing here is that we could have chosen another loss function, and we would still be able to do Bayesian updating. This isn’t exactly a trivial thing to do, but it is totally legitimate under GBI. Now, how does this connect to variational inference? Instead of considering all possible posteriors, let’s assume the posterior belongs to a parametric family \(q(\theta\mid\psi)\) for our parameters of interest \(\theta\) and tunable parameters \(\psi\). We will also choose the classical loss function \(\ell_\theta(x) = -\text{log } p(x\mid\theta)\). In that case, our risk minimization looks like: \[q^*=argmin_\psi R(\psi,p(\theta),x) = argmin_\psi (- E_{\theta\sim q}(\text{log } p(x|\theta)) + KL(q(\theta|\psi)||p(\theta)))\] We can rewrite this as: \[q^*=argmax_\psi -R(\psi, p(\theta), x)= argmax_\psi (E_{\theta\sim q}(\text{log } p(x|\theta)) - KL(q(\theta|\psi)||p(\theta)))\] That means we maximize the ELBO! So, under GBI, VI is just a special case where we assume that the distribution takes a particular form. Concluding Remarks I hope that all made sense and was useful! I have been thinking about GBI a lot, so this felt more intuitive. Under the GBI perspective, VI isn’t a cheap method for faster inference but the natural method for those willing to constrain their updated belief to a family of distributions. In this sense, VI can be viewed as a fully Bayesian update procedure.]]></summary></entry></feed>