1 Conditional Routing Experiment with THRML

1.1 Overview

This notebook implements a conditional routing experiment using synthetic data to evaluate the performance of THRML against traditional multi-armed bandit baselines in a contextual decision-making setting.

1.2 Problem Setting

The experiment simulates an order routing scenario in which an agent must select the best venue to route orders. At each step: (1) a context is observed — one venue’s outcome is revealed for free; (2) the agent selects one of the remaining venues to route to; (3) the agent receives a reward and updates its belief model.

1.3 Agents Compared

  • Contextual ε-Greedy: Maintains context-specific success/count statistics with ε=0.1 exploration
  • Contextual Thompson Sampling: Uses Beta-distributed posteriors conditioned on context
  • THRML: Leverages an Ising model to capture correlations between venues and performs probabilistic inference using Gibbs sampling

NOTE — Fair Information Sharing: Baselines (ε-Greedy and Thompson Sampling) update only on the selected venue’s outcome, while THRML performs joint updates on multiple nodes (context + routed). This asymmetry reflects each algorithm’s natural learning structure and is consistent across all scenarios.

1.4 Scenarios

  1. IID Venues: No correlation between venues (correlation_weight=0.0)
  2. Correlated Venues: Venues have positive correlations (correlation_weight=0.4)
  3. Regime Shift: Correlations exist, and venue biases change mid-experiment (step 5000)

1.5 Context Modes

  • Fixed Context: Always observe Venue 0’s outcome as context
  • Random Context: Randomly select which venue provides context each step

1.6 Key Hyperparameters

| Parameter | Value | Description |
|---|---|---|
| n_venues | 5 | Number of trading venues |
| n_steps | 10,000 | Steps per experiment run |
| n_seeds | 200 | Independent runs for statistical significance |
| discount_factor | 0.995 | Forgetting factor for non-stationary adaptation |
| learning_rate | 0.05 | THRML learning rate |
| steps_per_sample | 4 | Gibbs sampling thinning parameter |
| propagation_damping | 0.3 | Mean-field signal propagation factor |
| epsilon | 0.1 | ε-Greedy exploration rate |

1.7 Output

The notebook produces regret plots comparing all three agents across the three scenarios for both context modes.

Show the code
# Install dependencies for Colab environment
# Install dependencies for Colab T4 GPU
# Note: Run this cell only on Colab
!pip install --quiet jax jaxlib
!pip install --quiet thrml
!pip install --quiet matplotlib
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 182.1/182.1 kB 14.0 MB/s eta 0:00:00

   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.3/56.3 kB 7.3 MB/s eta 0:00:00

Show the code
# Imports for JAX, THRML, and visualization
import jax
import jax.numpy as jnp
from jax import random, lax, vmap, jit
import numpy as np
import matplotlib.pyplot as plt
import time
from typing import NamedTuple, Tuple, Optional, List, Dict
from functools import partial

# THRML imports (JAX-compatible)
from thrml import SpinNode, Block, SamplingSchedule, sample_states
from thrml.models import IsingEBM, IsingSamplingProgram, hinton_init
Show the code
# Verify GPU availability for optimized JAX execution
print(f"JAX Backend: {jax.default_backend()}")
try:
    print(f"Devices: {jax.devices()}")
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
    # propagate; jax.devices() raises RuntimeError when no backend device
    # can be initialized.
    print("Warning: No GPU devices found. Running on CPU with JAX optimization.")
JAX Backend: gpu
Devices: [CudaDevice(id=0)]
Show the code
# Configuration settings for the experiments and agents
class ExperimentConfig(NamedTuple):
    """Shared hyperparameters for the market simulator and all three agents."""
    n_venues: int = 5  # Number of trading venues participating in the simulation
    n_context_venues: int = 1  # How many venues reveal their outcome as free context each step
    n_steps: int = 10000  # Total simulation steps per independent run
    n_seeds: int = 200  # Independent seeds per scenario (statistical significance)
    window_size: int = 200  # Memory depth for incremental covariance tracking
    beta: float = 1.0  # Inverse temperature passed to the Ising models
    n_warmup: int = 50  # Gibbs warm-up sweeps discarded before collecting samples
    n_samples: int = 100  # Gibbs samples retained per inference call
    steps_per_sample: int = 4  # Thinning: Gibbs steps between retained samples
    discount_factor: float = 0.995  # Exponential decay factor for adapting to non-stationary shifts
    learning_rate: float = 0.05  # Step size for bias and edge weight updates
    propagation_damping: float = 0.3  # Damping of mean-field signal propagated into biases (thrml_update)
    context_mode: str = "fixed"  # "fixed": venues 0..K-1 are context; "random": resampled each step
    damp_coupling: bool = True  # If False, edge-weight innovations are zeroed in thrml_update

class ScenarioConfig(NamedTuple):
    """Ground-truth market definition for one experimental scenario."""
    name: str  # Human-readable label (used as plot title / results key)
    correlation_weight: float  # Uniform Ising coupling applied to every venue pair
    biases: jnp.ndarray  # Per-venue Ising biases at the start of the run
    regime_shift_step: Optional[int]  # Step at which biases switch; None for stationary runs
    regime_shift_biases: Optional[jnp.ndarray]  # Biases used from regime_shift_step onward (if any)
Show the code
class AblationConfig(NamedTuple):
    """Flags that selectively disable THRML mechanisms for ablation analysis.

    Each flag removes exactly one mechanism while leaving the rest of the
    pipeline intact, so regret differences isolate that mechanism's value.
    """
    name: str        = "THRML-Full"
    no_clamping:  bool = False  # If True: use unclamped joint sampling (context ignored)
    no_couplings: bool = False  # If True: zero edge weights (bias-only Ising)
    no_damping:   bool = False  # If True: disable mean-field propagation damping
Show the code
class AgentState_CEG(NamedTuple):
    """Tabular state for the contextual ε-Greedy agent."""
    successes: jnp.ndarray  # [n_contexts, n_venues] discounted success counts
    counts: jnp.ndarray  # [n_contexts, n_venues] discounted pull counts

class AgentState_CTS(NamedTuple):
    """Beta-posterior state for the contextual Thompson Sampling agent."""
    alphas: jnp.ndarray  # [n_contexts, n_venues] Beta alpha parameters
    betas: jnp.ndarray  # [n_contexts, n_venues] Beta beta parameters

class AgentState_THRML(NamedTuple):
    """Learned Ising parameters plus incremental covariance bookkeeping."""
    biases: jnp.ndarray  # [n_venues] per-venue bias terms
    weights: jnp.ndarray  # [n_venues*(n_venues-1)//2] upper-triangular edge weights
    history_buffer: jnp.ndarray  # [window_size, n_venues] ring buffer of masked outcomes
    history_ptr: jnp.ndarray  # scalar int32: next write position in the ring buffer
    full_history_count: jnp.ndarray  # scalar int32: valid rows so far (capped at window_size)
    cov_sum: jnp.ndarray  # [n_venues, n_venues] running sum of observation outer products
    pair_counts: jnp.ndarray  # [n_venues, n_venues] running count of co-observed pairs
Show the code
def select_among_max(key, scores, mask):
    """Pick the argmax of `scores + mask`, breaking exact ties randomly.

    A tiny uniform perturbation (±1e-6) is added so that venues with equal
    scores are chosen at random instead of always the lowest index winning.
    """
    jitter = random.uniform(key, scores.shape, minval=-1e-6, maxval=1e-6)
    perturbed = scores + mask + jitter
    return jnp.argmax(perturbed)

def thrml_init(n_venues, window_size=200):
    """
    Build a zeroed AgentState_THRML with a fixed-size observation window.

    Centralizing construction here keeps the covariance memory depth
    consistent across all experiments.
    """
    n_edges = (n_venues * (n_venues - 1)) // 2
    zero_i32 = jnp.array(0, dtype=jnp.int32)
    return AgentState_THRML(
        biases=jnp.zeros(n_venues),
        weights=jnp.zeros(n_edges),
        history_buffer=jnp.zeros((window_size, n_venues)),
        history_ptr=zero_i32,
        full_history_count=zero_i32,
        cov_sum=jnp.zeros((n_venues, n_venues)),
        pair_counts=jnp.zeros((n_venues, n_venues)),
    )

def build_thrml_infra(n_venues, config):
    """
    Pre-calculates JAX-optimized graph structures.
    Uses a fixed program structure that clamps the first n_context_venues.
    Selection logic will permute nodes to satisfy this structure.

    Returns a dict with:
        nodes:      list of SpinNode, one per venue
        edges:      all-to-all node pairs, upper triangle only (i < j)
        sched:      SamplingSchedule from config warmup/samples/thinning
        prog:       conditional program (first n_context_venues clamped)
        joint_prog: unclamped program with serial single-node blocks
        full_block: [Block(nodes)] used to read out all nodes when sampling
    """
    nodes = [SpinNode() for _ in range(n_venues)]
    # Fully connected coupling graph, enumerated over the upper triangle.
    edges = [(nodes[i], nodes[j]) for i in range(n_venues) for j in range(i+1, n_venues)]
    
    schedule = SamplingSchedule(
        n_warmup=config.n_warmup, 
        n_samples=config.n_samples, 
        steps_per_sample=config.steps_per_sample
    )
    
    # Static program structure for conditional selection: always clamp first K nodes
    clamped_block = Block(nodes[:config.n_context_venues])
    free_blocks = [Block([nodes[i]]) for i in range(config.n_context_venues, n_venues)]

    # Dummy parameters only fix the program structure; real biases/weights
    # are substituted by rebuilding the model at selection/update time.
    dummy_model = IsingEBM(nodes, edges, jnp.zeros(n_venues), jnp.zeros(len(edges)), jnp.array(config.beta))
    prog_conditional = IsingSamplingProgram(dummy_model, free_blocks, clamped_blocks=[clamped_block])
    
    # Static program structure for joint update: no clamped nodes
    # Note: Use serial schedule because the graph is fully connected (all-to-all).
    # Nodes must be updated sequentially to satisfy Gibbs validity.
    serial_blocks = [Block([n]) for n in nodes]
    prog_joint = IsingSamplingProgram(dummy_model, serial_blocks, clamped_blocks=[])
    
    return {
        'nodes': nodes, 
        'edges': edges, 
        'sched': schedule, 
        'prog': prog_conditional,
        'joint_prog': prog_joint,
        'full_block': [Block(nodes)]
    }

def thrml_update(state, outcomes, obs_mask, model_node_moms, model_edge_moms, 
                 discount_factor, beta, learning_rate,
                 propagation_damping=0.3, damp_coupling=True):
    """
    Perform THRML weight update with incremental covariance tracking.
    Reduces complexity from O(W*N^2) to O(N^2) per step.

    Args:
        state: current AgentState_THRML.
        outcomes: [n_venues] observed values, 0 where unobserved (callers pass
            masked ±1 spins — the presence test below relies on nonzero values).
        obs_mask: [n_venues] 1.0 for observed venues, 0.0 otherwise.
        model_node_moms: [n_venues] node moments from unclamped model sampling.
        model_edge_moms: [n_edges] edge moments (upper triangle), likewise.
        discount_factor: exponential forgetting applied to biases and weights.
        beta: inverse temperature scaling the gradient steps.
        learning_rate: step size for bias/weight updates.
        propagation_damping: strength of the mean-field signal pushed from
            observed venues into unobserved venues' biases.
        damp_coupling: if False, edge-weight innovations are zeroed out.

    Returns:
        Updated AgentState_THRML.
    """
    n_venues = state.biases.shape[0]
    triu_idx = jnp.triu_indices(n_venues, 1)
    
    # 1. Update Biases
    # Rebuild the symmetric coupling matrix J from the flat triangular weights.
    J = jnp.zeros((n_venues, n_venues)).at[triu_idx].set(state.weights)
    J = J + J.T
    # Mean-Field Signal Propagation: observed outcomes influence the biases
    # of *unobserved* venues only (factor (1.0 - obs_mask)) via the couplings.
    influence = propagation_damping * learning_rate * beta * (J @ (outcomes * obs_mask)) * (1.0 - obs_mask)
    new_biases = (state.biases * discount_factor) + (learning_rate * beta * (outcomes * obs_mask - model_node_moms * obs_mask)) + influence
    
    # 2. Incremental Covariance Update
    # The ring-buffer row about to be overwritten is removed from the sums.
    old_obs = state.history_buffer[state.history_ptr]
    old_present = (old_obs != 0).astype(jnp.float32)
    
    new_obs = outcomes * obs_mask
    new_present = obs_mask
    
    # Subtract old contribution, Add new contribution
    new_cov_sum = state.cov_sum - jnp.outer(old_obs, old_obs) + jnp.outer(new_obs, new_obs)
    new_pair_counts = state.pair_counts - jnp.outer(old_present, old_present) + jnp.outer(new_present, new_present)
    
    # 3. Update Buffer
    new_buffer = state.history_buffer.at[state.history_ptr].set(new_obs)
    new_ptr = (state.history_ptr + 1) % state.history_buffer.shape[0]
    new_count = jnp.minimum(state.full_history_count + 1, state.history_buffer.shape[0])
    
    # Calculate empirical correlations
    emp_cov = new_cov_sum / jnp.maximum(new_pair_counts, 1.0)
    emp = emp_cov[triu_idx]
    
    # 4. Update Weights
    # Only pairs actually co-observed within the window receive an update.
    pairs_observed = new_pair_counts[triu_idx] > 0
    weights_grad = (emp - model_edge_moms)
    innovation = beta * learning_rate * weights_grad
    innovation = jnp.where(damp_coupling, innovation, 0.0)

    new_weights = (state.weights * discount_factor) + jnp.where(pairs_observed, innovation, 0.0)

    return AgentState_THRML(new_biases, new_weights, new_buffer, new_ptr, new_count, new_cov_sum, new_pair_counts)
Show the code
def get_context_index(context_venues, context_outcomes, n_venues, n_context_venues):
    """Encode (venue, outcome) context pairs into one integer table index.

    Each pair maps to state = venue * 2 + outcome_bit, a value in
    [0, 2*n_venues); the pairs are then combined positionally via Horner's
    method in base 2*n_venues, yielding an index in
    [0, (2*n_venues) ** n_context_venues).
    """
    outcome_bits = (context_outcomes > 0).astype(jnp.int32)
    pair_states = context_venues * 2 + outcome_bits

    def fold_digit(acc, digit):
        return acc * (2 * n_venues) + digit, None

    encoded, _ = lax.scan(fold_digit, 0, pair_states)
    return encoded
Show the code
def ceg_init(config: ExperimentConfig) -> AgentState_CEG:
    """Zero-initialized ε-Greedy tables over the (2N)^K context space."""
    # Context space scales as (2 * n_venues) ** n_context_venues so that
    # multiple context venues are supported.
    n_contexts = (2 * config.n_venues) ** config.n_context_venues
    table_shape = (n_contexts, config.n_venues)
    return AgentState_CEG(successes=jnp.zeros(table_shape), counts=jnp.zeros(table_shape))
Show the code
def cts_init(config: ExperimentConfig, prior_alpha: float=1.0, prior_beta: float=1.0) -> AgentState_CTS:
    """Beta(prior_alpha, prior_beta) posteriors for every (context, venue) cell."""
    n_contexts = (2 * config.n_venues) ** config.n_context_venues
    table_shape = (n_contexts, config.n_venues)
    return AgentState_CTS(
        alphas=jnp.ones(table_shape) * prior_alpha,
        betas=jnp.ones(table_shape) * prior_beta,
    )
Show the code
def ceg_select(state, key, cidx, routing_mask, epsilon=0.1):
    """ε-Greedy venue choice over context-conditioned empirical means."""
    n_venues = state.counts.shape[1]
    pulls = state.counts[cidx]
    # Unvisited (context, venue) cells get an optimistic-neutral mean of 0.5.
    means = jnp.where(pulls > 0, state.successes[cidx] / pulls, 0.5)
    k_explore, k_choice = random.split(key)
    explore_act = random.categorical(k_choice, jnp.zeros(n_venues) + routing_mask)
    greedy_act = select_among_max(k_choice, means, routing_mask)
    return jnp.where(random.uniform(k_explore) < epsilon, explore_act, greedy_act)

def ceg_update(state, cidx, venue, outcome, discount_factor):
    """Decay all statistics globally, then record the observed (context, venue) pull."""
    decayed_successes = state.successes * discount_factor
    decayed_counts = state.counts * discount_factor
    win = jnp.where(outcome > 0, 1.0, 0.0)
    return AgentState_CEG(
        decayed_successes.at[cidx, venue].add(win),
        decayed_counts.at[cidx, venue].add(1.0),
    )

def cts_select(state, key, cidx, routing_mask):
    """Thompson step: draw Beta posterior samples for this context, mask, argmax."""
    posterior_draws = random.beta(key, state.alphas[cidx], state.betas[cidx])
    return jnp.argmax(posterior_draws + routing_mask)

def cts_update(state, cidx, venue, outcome, discount_factor):
    """Decay Beta posteriors toward the Beta(1, 1) prior, then add the new evidence."""
    decayed_alphas = (state.alphas - 1.0) * discount_factor + 1.0
    decayed_betas = (state.betas - 1.0) * discount_factor + 1.0
    win = jnp.where(outcome > 0, 1.0, 0.0)
    return AgentState_CTS(
        decayed_alphas.at[cidx, venue].add(win),
        decayed_betas.at[cidx, venue].add(1.0 - win),
    )
Show the code
def thrml_select_conditional(state, key, cvs, cos, infra, config, ablation=None):
    """Select a venue via conditional Ising inference given context observations.

    The pre-built sampling program always clamps the FIRST n_context_venues
    nodes, so the model's biases/weights are permuted to move the actual
    context venues into those slots, and the sampled marginals are permuted
    back before selection.

    Args:
        state: AgentState_THRML holding current biases and edge weights.
        key: PRNG key (split for init, sampling, and tie-breaking).
        cvs: [n_context_venues] indices of the context venues.
        cos: [n_context_venues] their observed rewards (sign gives the clamped spin).
        infra: structures from build_thrml_infra.
        config: ExperimentConfig.
        ablation: optional AblationConfig (no_couplings / no_clamping flags,
            resolved at Python trace time).

    Returns:
        (selected_venue, probs): the routed venue index (context venues are
        masked out), and per-venue marginal probabilities of spin +1.
    """
    n = config.n_venues
    triu_idx = jnp.triu_indices(n, 1)

    # 1. Create permutation mapping CVs to the first K indices
    priorities = jnp.zeros(n, dtype=jnp.int32)
    priorities = priorities.at[cvs].set(jnp.arange(config.n_context_venues))
    is_context = jnp.zeros(n, dtype=jnp.bool_).at[cvs].set(True)
    # Non-context venues get priorities >= n so argsort places them last.
    priorities = jnp.where(is_context, priorities, jnp.arange(n) + n)
    perm = jnp.argsort(priorities)
    inv_perm = jnp.argsort(perm)

    def get_permuted_weights():
        J = jnp.zeros((n, n)).at[triu_idx].set(state.weights)
        J = J + J.T
        J_p = J[perm][:, perm]
        return J_p[triu_idx]

    b_p = state.biases[perm]
    # Skip the matrix rebuild when the permutation is the identity.
    w_p = lax.cond(jnp.all(perm == jnp.arange(n)), lambda: state.weights, get_permuted_weights)

    # Ablation: No Couplings — zero out edge weights before inference
    if ablation is not None and ablation.no_couplings:
        w_p = jnp.zeros_like(w_p)

    # 2. Build model
    model = IsingEBM(infra['nodes'], infra['edges'], b_p, w_p, jnp.array(config.beta))

    # Ablation: No Clamping — use unclamped joint program (ignore context)
    if ablation is not None and ablation.no_clamping:
        updated_prog = IsingSamplingProgram(
            model,
            infra['joint_prog'].gibbs_spec.superblocks,
            infra['joint_prog'].gibbs_spec.clamped_blocks
        )
        clamped_state = []
    else:
        updated_prog = IsingSamplingProgram(
            model,
            infra['prog'].gibbs_spec.superblocks,
            infra['prog'].gibbs_spec.clamped_blocks
        )
        clamped_state = [(cos > 0).astype(jnp.bool_)]

    k_init, k_sample, k_tie = random.split(key, 3)
    init = hinton_init(k_init, model, updated_prog.gibbs_spec.free_blocks, ())
    samples = sample_states(k_sample, updated_prog, infra['sched'], init, clamped_state, infra['full_block'])[0]

    # 3. Decode results
    # Map {0,1} samples to {-1,+1} spins and undo the node permutation.
    spins = (2 * samples.astype(jnp.float32) - 1).reshape(samples.shape[0], -1)[:, inv_perm]
    margs = jnp.mean(spins, axis=0)
    probs = (margs + 1) / 2

    routing_mask = jnp.zeros(n).at[cvs].set(-1e9)
    # k_tie already assigned via split above
    return select_among_max(k_tie, probs, routing_mask), probs


def thrml_sample_joint(state, key, infra, config):
    """Perform UNCLAMPED sampling for unbiased weight updates."""
    n = state.biases.shape[0]
    model = IsingEBM(infra['nodes'], infra['edges'], state.biases, state.weights, jnp.array(config.beta))
    prog = IsingSamplingProgram(
        model,
        infra['joint_prog'].gibbs_spec.superblocks,
        infra['joint_prog'].gibbs_spec.clamped_blocks
    )

    k_init, k_draw = random.split(key)
    init_state = hinton_init(k_init, model, prog.gibbs_spec.free_blocks, ())
    raw = sample_states(k_draw, prog, infra['sched'], init_state, [], infra['full_block'])[0]

    # Map {0,1} samples to {-1,+1} spins: one row per retained Gibbs sample.
    spins = (2 * raw.astype(jnp.float32) - 1).reshape(raw.shape[0], -1)
    node_moms = jnp.mean(spins, axis=0)
    edge_moms = ((spins.T @ spins) / config.n_samples)[jnp.triu_indices(n, 1)]
    return node_moms, edge_moms
Show the code
def sample_outcomes_jit(biases, correlation_weight, beta, key, sim_struct_helper):
    """
    Generates venue outcomes using an Ising model to introduce correlations.

    Args:
        biases: [n_venues] ground-truth per-venue bias terms.
        correlation_weight: single coupling value applied to every venue pair.
        beta: inverse temperature of the generator model.
        key: PRNG key (split for init, sampling, and tie-breaking noise).
        sim_struct_helper: (nodes, edges, full_block) reused from
            build_thrml_infra so the graph is constructed once.

    Returns:
        (discrete_outcomes, latent_scores): per-venue spins in {-1, +1}, and
        local-field scores (plus a tiny tie-breaker) defining the oracle.
    """
    nodes, edges, full_block = sim_struct_helper
    
    # Weights represent correlations between all venues
    n_venues = biases.shape[0]
    n_edges = (n_venues * (n_venues - 1)) // 2
    weights = jnp.full((n_edges,), correlation_weight)
    
    model = IsingEBM(nodes, edges, biases, weights, jnp.array(beta))
    # Note: Use serial schedule for valid Gibbs sampling on fully connected graph
    serial_blocks = [Block([n]) for n in nodes]
    prog = IsingSamplingProgram(model, serial_blocks, [])
    
    # Single Gibbs sample to get the current market state
    sched = SamplingSchedule(n_warmup=100, n_samples=1, steps_per_sample=1)
    k1, k2, k3 = random.split(key, 3)
    init = hinton_init(k1, model, serial_blocks, ())
    
    samples = sample_states(k2, prog, sched, init, [], full_block)[0]
    
    # Return both discrete outcomes AND continuous latent scores for argmax
    # Latent score approximation: Bias + Field influence + Noise (using uniform noise for tie-breaking)
    # Since we can't easily extract internal fields, we add small noise to spins to break ties uniquely
    discrete_outcomes = 2 * samples[0].astype(jnp.float32) - 1
    
    
    # THRML-Based Oracle Definition: 
    # The 'best' venue is defined by the Local Field (gamma) as per 
    # the Ising EBM specification: gamma_i = bias_i + sum_{j != i} J_ij * s_j
    # Because all couplings equal correlation_weight, the pairwise sum reduces
    # to correlation_weight * (sum of all other spins).
    total_spin_sum = jnp.sum(discrete_outcomes)
    neighbor_influence = correlation_weight * (total_spin_sum - discrete_outcomes)
    local_field = biases + neighbor_influence
    
    # Add small tie-breaker for unique argmax

    tie_breaker = jax.random.uniform(k3, (n_venues,), minval=-1e-5, maxval=1e-5)
    
    latent_scores = local_field + tie_breaker
    
    return discrete_outcomes, latent_scores
Show the code
def run_single_seed_experiment(
    master_seed: jax.Array,
    config: ExperimentConfig,
    scenario_config: ScenarioConfig,
    infra: Dict,
    sim_struct_helper: Tuple,
    ablation=None
) -> dict:
    """Run one full routing experiment (all three agents) for a single seed.

    Args:
        master_seed: PRNG key driving every random choice in the run.
        config: ExperimentConfig hyperparameters.
        scenario_config: ground-truth market (biases, couplings, optional
            regime shift). Treated as static w.r.t. jit tracing.
        infra: THRML graph structures from build_thrml_infra.
        sim_struct_helper: (nodes, edges, full_block) for the market simulator.
        ablation: optional AblationConfig disabling THRML mechanisms.

    Returns:
        dict with per-agent cumulative-regret curves ([n_steps]) plus the
        final learned THRML biases and edge weights.
    """

    rng = master_seed
    agent_ceg   = ceg_init(config)
    agent_cts   = cts_init(config)
    agent_thrml = thrml_init(config.n_venues, window_size=config.window_size)

    Carry = NamedTuple("Carry", [
        ("rng",         jax.Array),
        ("agent_ceg",   AgentState_CEG),
        ("agent_cts",   AgentState_CTS),
        ("agent_thrml", AgentState_THRML)
    ])
    init_carry = Carry(rng, agent_ceg, agent_cts, agent_thrml)

    # Resolve ablation flags once at Python/compile time (static relative to jit)
    eff_prop_damping  = 0.0   if (ablation is not None and ablation.no_damping)   else config.propagation_damping
    eff_damp_coupling = False if (ablation is not None and ablation.no_couplings) else config.damp_coupling

    def step_fn(carry: Carry, step_idx: int):
        rng = carry.rng

        # The Python conditionals resolve at trace time (scenario_config is
        # static); lax.select then switches on step_idx at run time.
        current_biases = lax.select(
            step_idx >= scenario_config.regime_shift_step if scenario_config.regime_shift_step is not None else False,
            scenario_config.regime_shift_biases if scenario_config.regime_shift_biases is not None else scenario_config.biases,
            scenario_config.biases
        )

        rng, k_context, k_sim, k_agents, k_update = random.split(rng, 5)

        outcomes, latent_scores = sample_outcomes_jit(
            current_biases, scenario_config.correlation_weight,
            1.0, k_sim, sim_struct_helper
        )

        is_fixed_mode = (config.context_mode == "fixed")
        context_venues = lax.cond(
            is_fixed_mode,
            lambda: jnp.arange(config.n_context_venues),
            lambda: jax.random.permutation(k_context, jnp.arange(config.n_venues))[:config.n_context_venues]
        )

        # Winner-takes-all rewards: only the venue with the highest latent
        # score pays +1; every other venue pays -1.
        oracle_best_global = jnp.argmax(latent_scores)
        rewards = jnp.where(jnp.arange(config.n_venues) == oracle_best_global, 1.0, -1.0)

        context_outcomes = rewards[context_venues]
        context_idx     = get_context_index(context_venues, context_outcomes, config.n_venues, config.n_context_venues)
        routing_mask    = jnp.zeros(config.n_venues).at[context_venues].set(-1e9)

        # Regret baseline: best venue among those the agent may route to
        # (context venues are excluded by the mask).
        oracle_best_available = jnp.argmax(latent_scores + routing_mask)
        oracle_reward         = rewards[oracle_best_available]

        k_ceg, k_cts, k_thrml = random.split(k_agents, 3)

        act_ceg = ceg_select(carry.agent_ceg, k_ceg, context_idx, routing_mask)
        act_cts = cts_select(carry.agent_cts, k_cts, context_idx, routing_mask)

        # THRML Selection: passes ablation config for conditional/coupling flags
        act_thrml, _ = thrml_select_conditional(
            carry.agent_thrml, k_thrml,
            context_venues, context_outcomes,
            infra, config, ablation=ablation
        )

        # Model-expectation moments from unclamped sampling, consumed by
        # thrml_update's learning rule below.
        model_node_moms, model_edge_moms = thrml_sample_joint(carry.agent_thrml, k_update, infra, config)

        r_ceg   = oracle_reward - rewards[act_ceg]
        r_cts   = oracle_reward - rewards[act_cts]
        r_thrml = oracle_reward - rewards[act_thrml]

        next_ceg = ceg_update(carry.agent_ceg, context_idx, act_ceg,   rewards[act_ceg],   config.discount_factor)
        next_cts = cts_update(carry.agent_cts, context_idx, act_cts,   rewards[act_cts],   config.discount_factor)

        # THRML observes the context venues plus its own routed venue
        # (the information-sharing asymmetry described in the notebook text).
        obs_mask       = jnp.zeros(config.n_venues).at[context_venues].set(1.0).at[act_thrml].set(1.0)
        observed_data  = rewards * obs_mask

        next_thrml = thrml_update(
            carry.agent_thrml, observed_data, obs_mask,
            model_node_moms=model_node_moms, model_edge_moms=model_edge_moms,
            discount_factor=config.discount_factor,
            beta=config.beta,
            learning_rate=config.learning_rate,
            propagation_damping=eff_prop_damping,
            
            damp_coupling=eff_damp_coupling
        )

        next_carry = Carry(rng, next_ceg, next_cts, next_thrml)
        regrets    = jnp.array([r_ceg, r_cts, r_thrml])
        return next_carry, regrets

    steps = jnp.arange(config.n_steps)
    final_carry, step_regrets = lax.scan(step_fn, init_carry, steps)

    return {
        "Contextual ε-Greedy":        jnp.cumsum(step_regrets[:, 0]),
        "Contextual Thompson Sampling": jnp.cumsum(step_regrets[:, 1]),
        "THRML":                       jnp.cumsum(step_regrets[:, 2]),
        "thrml_biases":  final_carry.agent_thrml.biases,
        "thrml_weights": final_carry.agent_thrml.weights
    }
Show the code
def run_experiment_vmapped(
    config: ExperimentConfig,
    scenario: ScenarioConfig,
    ablation=None
):
    """Run one scenario across config.n_seeds seeds under jit(vmap(...)).

    Seeds are processed in fixed-size batches to bound device memory; the
    per-batch results are concatenated along the seed axis before returning.

    Returns:
        dict of arrays with a leading [n_seeds] axis: cumulative-regret
        curves per agent plus final THRML parameters per seed.
    """
    abl_name = ablation.name if ablation is not None else 'THRML-Full'
    print(f"Compiling and running scenario: {scenario.name}")
    print(f"  Context mode: {config.context_mode} | Ablation: {abl_name}")
    start_t = time.time()

    infra = build_thrml_infra(config.n_venues, config)
    sim_struct_helper = (infra['nodes'], infra['edges'], infra['full_block'])

    # Deterministic seed fan-out: the same master key is used for every
    # scenario, so runs are reproducible and comparable across scenarios.
    master = random.key(42)
    seeds  = random.split(master, config.n_seeds)

    run_one = partial(run_single_seed_experiment,
                      config=config,
                      scenario_config=scenario,
                      infra=infra,
                      sim_struct_helper=sim_struct_helper,
                      ablation=ablation)

    BATCH_SIZE = 50
    n_total_seeds = config.n_seeds
    run_batch = jit(vmap(run_one))
    
    all_res_list = []
    print(f"  Starting execution in batches of {BATCH_SIZE} seeds...")
    for i in range(0, n_total_seeds, BATCH_SIZE):
        batch_seeds = seeds[i : i + BATCH_SIZE]
        print(f"   - Processing seeds {i} to {i + len(batch_seeds)}... ", end="")
        batch_start = time.time()
        
        res = run_batch(batch_seeds)
        # Block on device completion so the reported per-batch time reflects
        # actual compute, not just async dispatch.
        res = jax.tree_util.tree_map(lambda x: x.block_until_ready(), res)
        all_res_list.append(res)
        print(f"[{time.time()-batch_start:.1f}s]")

    # Combine results across batches
    results = {}
    for key in all_res_list[0].keys():
        results[key] = jnp.concatenate([r[key] for r in all_res_list], axis=0)

    elapsed = time.time() - start_t
    print(f"Finished {config.n_seeds} seeds x {config.n_steps} steps in {elapsed:.4f}s")
    return results
Show the code
def plot_all_results(all_results, context_mode: str):
    """Plot mean cumulative regret (±1 std across seeds) for every scenario.

    Args:
        all_results: dict mapping scenario name -> result dict, where each
            result dict holds per-agent arrays of shape [n_seeds, n_steps].
        context_mode: label used in plot titles and the output filename
            ("fixed" or "random").

    Saves the figure to images/conditional_sor_<context_mode>_results.png
    and displays it.
    """
    import os

    n = len(all_results)
    fig, axes = plt.subplots(1, n, figsize=(5 * n, 5))
    if n == 1:
        axes = [axes]  # plt.subplots returns a bare Axes when n == 1

    for ax, (name, res) in zip(axes, all_results.items()):
        steps = jnp.arange(res['THRML'].shape[1])

        for agent in ['Contextual ε-Greedy', 'Contextual Thompson Sampling', 'THRML']:
            data = res[agent]  # [n_seeds, n_steps]
            mean = jnp.mean(data, axis=0)
            std = jnp.std(data, axis=0)
            ax.plot(steps, mean, label=agent)
            ax.fill_between(steps, mean - std, mean + std, alpha=0.2)

        ax.set_title(f"{name} - Context Mode: {context_mode}")
        ax.set_xlabel("Step")
        ax.set_ylabel("Cumulative Regret")
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    filename = f'conditional_sor_{context_mode}_results.png'
    os.makedirs('images', exist_ok=True)
    out_path = os.path.join('images', filename)
    plt.savefig(out_path)
    # Fix: previously printed a literal "(unknown)" placeholder instead of
    # the actual output path.
    print(f"Saved plot to {out_path}")
    plt.show()
Show the code
# Scenarios for N=5 venues: each entry defines one ground-truth market.
scenarios = [
    # No inter-venue coupling: context carries no information about others.
    ScenarioConfig(
        name="IID Venues",
        correlation_weight=0.0,
        biases=jnp.linspace(0.8, -0.8, 5),
        regime_shift_step=None,
        regime_shift_biases=None,
    ),
    # Positively coupled venues: observing one venue is informative.
    ScenarioConfig(
        name="Correlated Venues",
        correlation_weight=0.4,
        biases=jnp.linspace(0.8, -0.8, 5),
        regime_shift_step=None,
        regime_shift_biases=None,
    ),
    # Coupled venues whose bias ordering flips halfway through the run.
    ScenarioConfig(
        name="Regime Shift",
        correlation_weight=0.4,
        biases=jnp.linspace(0.8, -0.8, 5),
        regime_shift_step=5000,
        regime_shift_biases=jnp.flip(jnp.linspace(0.8, -0.8, 5)),
    ),
]
Show the code
# Fixed-context sweep: venue 0 always provides the observed context.
print("=" * 80)
print("CONDITIONAL ROUTING EXPERIMENT - FIXED CONTEXT (Venue 0)")
print("=" * 80)

all_results_fixed = {}
for scenario in scenarios:
    run_config = ExperimentConfig(context_mode="fixed")
    all_results_fixed[scenario.name] = run_experiment_vmapped(run_config, scenario)

plot_all_results(all_results_fixed, "fixed")
================================================================================
CONDITIONAL ROUTING EXPERIMENT - FIXED CONTEXT (Venue 0)
================================================================================
Compiling and running scenario: IID Venues
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [249.6s]
   - Processing seeds 50 to 100... [223.4s]
   - Processing seeds 100 to 150... [223.4s]
   - Processing seeds 150 to 200... [223.5s]
Finished 200 seeds x 10000 steps in 921.2694s
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [249.3s]
   - Processing seeds 50 to 100... [223.0s]
   - Processing seeds 100 to 150... [223.2s]
   - Processing seeds 150 to 200... [223.2s]
Finished 200 seeds x 10000 steps in 918.7051s
Compiling and running scenario: Regime Shift
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [250.5s]
   - Processing seeds 50 to 100... [224.2s]
   - Processing seeds 100 to 150... [223.5s]
   - Processing seeds 150 to 200... [224.0s]
Finished 200 seeds x 10000 steps in 922.1878s
Saved plot to conditional_sor_fixed_results.png

Show the code
# Random-context sweep: the context venue is resampled every step.
print("=" * 80)
print("CONDITIONAL ROUTING EXPERIMENT - RANDOM CONTEXT")
print("=" * 80)

all_results_random = {}
for scenario in scenarios:
    run_config = ExperimentConfig(context_mode="random")
    all_results_random[scenario.name] = run_experiment_vmapped(run_config, scenario)

plot_all_results(all_results_random, "random")
================================================================================
CONDITIONAL ROUTING EXPERIMENT - RANDOM CONTEXT
================================================================================
Compiling and running scenario: IID Venues
  Context mode: random | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [253.6s]
   - Processing seeds 50 to 100... [226.2s]
   - Processing seeds 100 to 150... [226.7s]
   - Processing seeds 150 to 200... [226.4s]
Finished 200 seeds x 10000 steps in 932.8930s
Compiling and running scenario: Correlated Venues
  Context mode: random | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [254.3s]
   - Processing seeds 50 to 100... [228.2s]
   - Processing seeds 100 to 150... [227.4s]
   - Processing seeds 150 to 200... [227.9s]
Finished 200 seeds x 10000 steps in 937.8307s
Compiling and running scenario: Regime Shift
  Context mode: random | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [253.9s]
   - Processing seeds 50 to 100... [226.5s]
   - Processing seeds 100 to 150... [226.8s]
   - Processing seeds 150 to 200... [226.8s]
Finished 200 seeds x 10000 steps in 934.1107s
Saved plot to conditional_sor_random_results.png

Show the code
# Final Regret Summary Table: mean final cumulative regret per agent,
# across both context modes and all scenarios.
labels = ["Contextual ε-Greedy", "Contextual Thompson Sampling", "THRML"]
print("=" * 80)
print(f"{'SCENARIO':<25} | {'MODE':<10} | {'ε-Greedy':<10} | {'Thompson':<10} | {'THRML':<10}")
print("-" * 80)

for mode_name, results_dict in [("fixed", all_results_fixed), ("random", all_results_random)]:
    for scenario_name, res in results_dict.items():
        cells = [f"{scenario_name:<25}", f"{mode_name:<10}"]
        for agent in labels:
            mean_final = float(jnp.mean(res[agent][:, -1]))
            cells.append(f"{mean_final:<10.4f}")
        print(" | ".join(cells))
print("=" * 80)
================================================================================
SCENARIO                  | MODE       | ε-Greedy   | Thompson   | THRML     
--------------------------------------------------------------------------------
IID Venues                | fixed      | 0.0000     | 0.0000     | 0.0000    
Correlated Venues         | fixed      | 767.7300   | 939.8100   | 624.8300  
Regime Shift              | fixed      | 2415.1299  | 2043.7400  | 1762.1799 
IID Venues                | random     | 1207.1599  | 513.6400   | 3.2000    
Correlated Venues         | random     | 2779.1299  | 2948.9800  | 1784.1899 
Regime Shift              | random     | 3133.2500  | 3061.0300  | 1841.0499 
================================================================================

1.8 Generative Proof: Learning Raw Market Distributions

1.8.1 Motivation

This section is a completely independent experiment that tests THRML’s ability to learn the raw Ising data-generating process, stripped of all routing logic.

Unlike the conditional routing experiment above — where THRML only observes partial, context-masked outcomes through winner-takes-all labels — here we train a fresh THRML agent on the full, unmasked raw market outcomes.

1.8.2 Protocol

  1. Initialize a new THRML agent with zeroed parameters (no carryover from routing)
  2. At each step:
    • Generate raw venue outcomes from the ground-truth Ising model (sample_outcomes_jit)
    • Show all venues’ raw outcomes to THRML (obs_mask = all ones)
    • Update the agent’s biases and weights via thrml_update
  3. After training, sample from the learned model and compare against ground truth

1.8.3 What This Proves

If the learned model’s marginals and correlations match the ground truth, it demonstrates that THRML isn’t just a routing heuristic: it is a genuine generative model that recovers the underlying joint distribution of market outcomes.

Show the code
def run_generative_proof_single(
    master_seed: jax.Array,
    config: ExperimentConfig,
    scenario_config: ScenarioConfig,
    infra: dict,
    sim_struct_helper: tuple,
    gt_biases: jnp.ndarray,
    n_gen_steps: int = 10000
):
    """
    Train a FRESH THRML agent on raw Ising outcomes using lax.scan.
    No routing, no winner-takes-all, no context/selection logic.
    """
    n_venues = config.n_venues

    # Fully observed training: every venue's raw outcome is visible.
    full_mask = jnp.ones(n_venues)
    fresh_agent = thrml_init(n_venues, window_size=config.window_size)

    # Scan carry: PRNG state plus the evolving agent parameters.
    ScanCarry = NamedTuple("ScanCarry", [
        ("rng", jax.Array),
        ("agent_thrml", AgentState_THRML),
    ])

    def advance(carry: ScanCarry, _step: int):
        next_rng, k_sim, k_update = random.split(carry.rng, 3)

        # Draw one outcome vector from the ground-truth Ising process.
        raw_outcomes, _ = sample_outcomes_jit(
            gt_biases,
            scenario_config.correlation_weight,
            config.beta,
            k_sim,
            sim_struct_helper
        )

        # Model-side (negative phase) moments for the parameter update.
        node_moms, edge_moms = thrml_sample_joint(
            carry.agent_thrml,
            k_update,
            infra,
            config
        )

        updated_agent = thrml_update(
            carry.agent_thrml,
            raw_outcomes,
            full_mask,
            model_node_moms=node_moms,
            model_edge_moms=edge_moms,
            discount_factor=config.discount_factor,
            beta=config.beta,
            learning_rate=config.learning_rate,
            propagation_damping=config.propagation_damping,
            damp_coupling=config.damp_coupling
        )

        return ScanCarry(next_rng, updated_agent), None

    final_carry, _ = lax.scan(
        advance,
        ScanCarry(master_seed, fresh_agent),
        jnp.arange(n_gen_steps)
    )

    return {
        'learned_biases': final_carry.agent_thrml.biases,
        'learned_weights': final_carry.agent_thrml.weights,
    }

# Stationary generative target: disable forgetting attenuation.
# discount_factor=1.0 keeps all past sufficient statistics at full weight,
# which is appropriate here because the training distribution never shifts.
gen_config = ExperimentConfig(discount_factor=1.0)
# Number of independent seeds for the generative-proof training runs.
gen_n_seeds = 64
Show the code
def run_generative_proof_training(
    scenario: ScenarioConfig,
    config: ExperimentConfig,
    n_gen_steps: int = 10000,
    n_seeds: int = 64,
    batch_size: int = 50,
    master_seed: int = 0
):
    """
    Train fresh THRML agents on raw (fully observed) Ising outcomes.

    Runs ``n_seeds`` independent trainings via ``vmap``, executed in
    batches of ``batch_size`` seeds to bound device memory, then prints
    learned parameters next to the ground truth.

    Args:
        scenario: Scenario supplying ground-truth biases / correlation weight.
        config: Experiment configuration (use discount_factor=1.0 so the
            stationary generative target is learned without forgetting).
        n_gen_steps: Training steps per seed.
        n_seeds: Number of independent seeds.
        batch_size: Seeds per jitted batch (previously hard-coded to 50).
        master_seed: Seed for the master PRNG key (previously hard-coded to 0).

    Returns:
        dict with per-seed learned biases/weights, the ground-truth biases,
        the ground-truth correlation weight, and n_seeds.
    """
    n = config.n_venues
    print(f"\n{'─'*60}")
    print(f"Generative Proof Training: {scenario.name}")
    print(f"  n_venues={n}, n_steps={n_gen_steps}, n_seeds={n_seeds}")
    print(f"  Stationary settings: discount_factor={config.discount_factor}")

    # For regime-shift scenarios, target the post-shift biases so the
    # generative target is a single stationary distribution.
    gt_biases = (
        scenario.regime_shift_biases
        if scenario.regime_shift_step is not None
        else scenario.biases
    )
    print(f"  GT biases={np.array(gt_biases)}, GT corr_w={scenario.correlation_weight}")
    print(f"{'─'*60}")

    infra = build_thrml_infra(n, config)
    sim_struct_helper = (infra['nodes'], infra['edges'], infra['full_block'])

    master = random.key(master_seed)
    seeds = random.split(master, n_seeds)

    run_one = partial(
        run_generative_proof_single,
        config=config,
        scenario_config=scenario,
        infra=infra,
        sim_struct_helper=sim_struct_helper,
        gt_biases=gt_biases,
        n_gen_steps=n_gen_steps
    )

    run_batch = jit(vmap(run_one))

    start_t = time.time()
    all_res_list = []
    print(f"  Starting execution in batches of {batch_size} seeds...")
    for i in range(0, n_seeds, batch_size):
        batch_seeds = seeds[i : i + batch_size]
        print(f"   - Processing seeds {i} to {i + len(batch_seeds)}... ", end="")
        batch_start = time.time()

        res = run_batch(batch_seeds)
        # Block on device completion so the per-batch timing is honest.
        res = jax.tree_util.tree_map(lambda x: x.block_until_ready(), res)
        all_res_list.append(res)
        print(f"[{time.time()-batch_start:.1f}s]")

    # Stitch per-batch results back into one [n_seeds, ...] array per key.
    results = {
        key: jnp.concatenate([r[key] for r in all_res_list], axis=0)
        for key in all_res_list[0].keys()
    }

    elapsed = time.time() - start_t

    learned_biases = results['learned_biases']
    learned_weights = results['learned_weights']

    print(f"  Training complete in {elapsed:.2f}s")
    print(f"  Learned biases mean±std:  {np.round(np.array(jnp.mean(learned_biases, axis=0)), 4)} ± {np.round(np.array(jnp.std(learned_biases, axis=0)), 4)}")
    print(f"  Learned weights mean±std: {np.round(np.array(jnp.mean(learned_weights, axis=0)), 4)} ± {np.round(np.array(jnp.std(learned_weights, axis=0)), 4)}")
    print(f"  GT biases:                {np.round(np.array(gt_biases), 4)}")
    print(f"  GT weights:               {scenario.correlation_weight}")

    return {
        'learned_biases': learned_biases,
        'learned_weights': learned_weights,
        'gt_biases': gt_biases,
        'gt_correlation_weight': scenario.correlation_weight,
        'n_seeds': n_seeds,
    }
Show the code
# ── Run Generative Proof Training ──
banner = "=" * 80
print(banner)
print("GENERATIVE PROOF — Training THRML on Raw Ising Outcomes")
print(banner)

# The regime-shift scenario is non-stationary, so it is left out of the
# stationary generative proof.
gen_scenarios = [s for s in scenarios if s.name != "Regime Shift"]
excluded = [s.name for s in scenarios if s.name == "Regime Shift"]
if excluded:
    print(f"Excluded from Generative Proof: {excluded}")

gen_proof_results = {
    sc.name: run_generative_proof_training(
        sc,
        gen_config,
        n_gen_steps=10000,
        n_seeds=gen_n_seeds,
    )
    for sc in gen_scenarios
}
================================================================================
GENERATIVE PROOF — Training THRML on Raw Ising Outcomes
================================================================================
Excluded from Generative Proof: ['Regime Shift']

────────────────────────────────────────────────────────────
Generative Proof Training: IID Venues
  n_venues=5, n_steps=10000, n_seeds=64
  Stationary settings: discount_factor=1.0
  GT biases=[ 0.8         0.40000004  0.         -0.40000004 -0.8       ], GT corr_w=0.0
────────────────────────────────────────────────────────────
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [154.9s]
   - Processing seeds 50 to 64... [154.8s]
  Training complete in 309.91s
  Learned biases mean±std:  [ 0.8654  0.4644 -0.0321 -0.4531 -0.9048] ± [0.1844 0.2234 0.2509 0.215  0.2327]
  Learned weights mean±std: [-0.017   0.0224  0.0184  0.0186  0.011   0.0139  0.0168  0.0027 -0.0026
 -0.0123] ± [0.1042 0.1291 0.1099 0.0979 0.0855 0.0819 0.1174 0.0781 0.1143 0.1146]
  GT biases:                [ 0.8  0.4  0.  -0.4 -0.8]
  GT weights:               0.0

────────────────────────────────────────────────────────────
Generative Proof Training: Correlated Venues
  n_venues=5, n_steps=10000, n_seeds=64
  Stationary settings: discount_factor=1.0
  GT biases=[ 0.8         0.40000004  0.         -0.40000004 -0.8       ], GT corr_w=0.4
────────────────────────────────────────────────────────────
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [154.8s]
   - Processing seeds 50 to 64... [154.5s]
  Training complete in 309.23s
  Learned biases mean±std:  [ 0.8627  0.4522 -0.0154 -0.4356 -0.89  ] ± [0.2143 0.2111 0.1683 0.1705 0.2165]
  Learned weights mean±std: [0.3992 0.4204 0.407  0.4473 0.4141 0.4102 0.4164 0.4111 0.3982 0.4027] ± [0.1123 0.1051 0.1316 0.1058 0.1021 0.0984 0.1147 0.1264 0.1142 0.0915]
  GT biases:                [ 0.8  0.4  0.  -0.4 -0.8]
  GT weights:               0.4
Show the code
# ── Generative Proof: Evaluation — Sampling, Metrics & Comparison Plots ──

def run_generative_proof_evaluation(
    gen_proof_results,
    config,
    n_eval_samples=1500,
    mae_tolerance=0.08
):
    """
    Evaluate generative quality of THRML trained on raw outcomes.
    Reports mean±std MAE across seeds for marginals and pairwise correlations,
    plus pass/fail against a tolerance.
    """
    infra = build_thrml_infra(config.n_venues, config)

    eval_sched = SamplingSchedule(
        n_warmup=200,
        n_samples=n_eval_samples,
        steps_per_sample=config.steps_per_sample
    )
    # One single-node block per venue → serial Gibbs sweep.
    one_node_blocks = [Block([node]) for node in infra['nodes']]

    def sample_model_stats(key, biases, weights):
        # Build an Ising model from the given parameters and Gibbs-sample it.
        ebm = IsingEBM(
            infra['nodes'],
            infra['edges'],
            biases,
            weights,
            jnp.array(config.beta)
        )
        program = IsingSamplingProgram(ebm, one_node_blocks, [])
        k_init, k_draw, _unused = random.split(key, 3)
        state0 = hinton_init(k_init, ebm, one_node_blocks, ())
        raw = sample_states(k_draw, program, eval_sched, state0, [], infra['full_block'])[0]
        # Map {0,1} samples to ±1 spins, one row per drawn configuration.
        spins = (2 * raw.astype(jnp.float32) - 1).reshape(-1, config.n_venues)

        # Per-venue P(+1), and E[s_i·s_j] on the strict upper triangle.
        marginals = (jnp.mean(spins, axis=0) + 1) / 2
        upper = jnp.triu_indices(config.n_venues, 1)
        correlations = ((spins.T @ spins) / spins.shape[0])[upper]
        return marginals, correlations

    evaluations = {}
    for name, res in gen_proof_results.items():
        print(f"\nEvaluating generative quality: {name}")

        learned_biases = res['learned_biases']
        learned_weights = res['learned_weights']
        gt_biases = res['gt_biases']

        n_edges = (config.n_venues * (config.n_venues - 1)) // 2
        gt_weights = jnp.full((n_edges,), res['gt_correlation_weight'])
        zero_biases = jnp.zeros_like(gt_biases)
        zero_weights = jnp.zeros_like(gt_weights)

        k_gt, k_base, k_seed = random.split(random.key(99), 3)

        gt_marg, gt_corr = sample_model_stats(k_gt, gt_biases, gt_weights)
        base_marg, base_corr = sample_model_stats(k_base, zero_biases, zero_weights)

        # Sample each seed's learned model in parallel.
        per_seed_keys = random.split(k_seed, learned_biases.shape[0])
        learned_marg, learned_corr = vmap(
            lambda b, w, k: sample_model_stats(k, b, w)
        )(learned_biases, learned_weights, per_seed_keys)

        marg_mae_per_seed = jnp.mean(jnp.abs(learned_marg - gt_marg), axis=1)
        corr_mae_per_seed = jnp.mean(jnp.abs(learned_corr - gt_corr), axis=1)

        marg_mae_mean = float(jnp.mean(marg_mae_per_seed))
        marg_mae_std = float(jnp.std(marg_mae_per_seed))
        corr_mae_mean = float(jnp.mean(corr_mae_per_seed))
        corr_mae_std = float(jnp.std(corr_mae_per_seed))

        baseline_marg_mae = float(jnp.mean(jnp.abs(base_marg - gt_marg)))
        baseline_corr_mae = float(jnp.mean(jnp.abs(base_corr - gt_corr)))

        pass_marg = marg_mae_mean < mae_tolerance
        pass_corr = corr_mae_mean < mae_tolerance

        evaluations[name] = {
            'gt_marginals': np.array(gt_marg),
            'gt_correlations': np.array(gt_corr),
            'baseline_marginals': np.array(base_marg),
            'baseline_correlations': np.array(base_corr),
            'learned_marginals_all': np.array(learned_marg),
            'learned_correlations_all': np.array(learned_corr),
            'learned_marginals_mean': np.array(jnp.mean(learned_marg, axis=0)),
            'learned_correlations_mean': np.array(jnp.mean(learned_corr, axis=0)),
            'marg_mae_per_seed': np.array(marg_mae_per_seed),
            'corr_mae_per_seed': np.array(corr_mae_per_seed),
            'marg_mae_mean': marg_mae_mean,
            'marg_mae_std': marg_mae_std,
            'corr_mae_mean': corr_mae_mean,
            'corr_mae_std': corr_mae_std,
            'baseline_marg_mae': baseline_marg_mae,
            'baseline_corr_mae': baseline_corr_mae,
            'pass_marg': pass_marg,
            'pass_corr': pass_corr,
            'mae_tolerance': mae_tolerance,
        }

        print(f"  Marginal MAE (mean±std):    {marg_mae_mean:.4f} ± {marg_mae_std:.4f}  | pass<{mae_tolerance}: {pass_marg}")
        print(f"  Correlation MAE (mean±std): {corr_mae_mean:.4f} ± {corr_mae_std:.4f}  | pass<{mae_tolerance}: {pass_corr}")
        print(f"  Untrained baseline MAE:     marg={baseline_marg_mae:.4f}, corr={baseline_corr_mae:.4f}")

    return evaluations


def plot_generative_proof(evaluations):
    """
    For each scenario, create a 2-panel plot comparing learned mean vs ground truth,
    with an untrained (all-zeros) baseline shown as a dotted reference.

    Args:
        evaluations: dict from run_generative_proof_evaluation — one entry per
            scenario with ground-truth / baseline / learned-mean marginals and
            pairwise correlations.

    Side effects:
        Saves the figure to images/generative_proof_results.png and shows it.
    """
    n = len(evaluations)
    fig, axes = plt.subplots(n, 2, figsize=(10, 4 * n))
    if n == 1:
        # plt.subplots returns a 1-D array for a single row; normalize to 2-D.
        axes = axes.reshape(1, -1)

    for idx, (name, data) in enumerate(evaluations.items()):
        learned_m = data['learned_marginals_mean']
        gt_m = data['gt_marginals']
        baseline_m = data['baseline_marginals']

        learned_c = data['learned_correlations_mean']
        gt_c = data['gt_correlations']
        baseline_c = data['baseline_correlations']

        # Left panel: per-venue marginals vs ground truth.
        ax = axes[idx, 0]
        ax.scatter(gt_m, learned_m, s=120, zorder=3,
                   c='royalblue', edgecolors='navy', alpha=0.85, label='Learned (mean across seeds)')
        ax.plot(gt_m, baseline_m, linestyle=':', marker='x', markersize=8,
                color='gray', alpha=0.95, label='Untrained baseline')
        # Shared axis padding (also reused by the correlations panel below).
        pad = 0.05
        lo = min(gt_m.min(), learned_m.min(), baseline_m.min()) - pad
        hi = max(gt_m.max(), learned_m.max(), baseline_m.max()) + pad
        ax.plot([lo, hi], [lo, hi], 'k--', alpha=0.4, label='Perfect match')
        ax.set_xlim(lo, hi)
        ax.set_ylim(lo, hi)
        for i, (x, y) in enumerate(zip(gt_m, learned_m)):
            ax.annotate(f'V{i}', (x, y), textcoords='offset points',
                        xytext=(8, 8), fontsize=9, fontweight='bold')
        ax.set_xlabel('Ground Truth  P(venue = +1)')
        ax.set_ylabel('Learned / Baseline  P(venue = +1)')
        ax.set_title(f"{name} — Marginals", fontweight='bold')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
        ax.set_aspect('equal')

        # Right panel: pairwise correlations vs ground truth.
        ax = axes[idx, 1]
        ax.scatter(gt_c, learned_c, s=120, zorder=3,
                   c='crimson', edgecolors='darkred', alpha=0.85, label='Learned (mean across seeds)')
        ax.plot(gt_c, baseline_c, linestyle=':', marker='x', markersize=8,
                color='gray', alpha=0.95, label='Untrained baseline')
        lo = min(gt_c.min(), learned_c.min(), baseline_c.min()) - pad
        hi = max(gt_c.max(), learned_c.max(), baseline_c.max()) + pad
        ax.plot([lo, hi], [lo, hi], 'k--', alpha=0.4, label='Perfect match')
        ax.set_xlim(lo, hi)
        ax.set_ylim(lo, hi)
        # Label each point with its (i, j) venue pair in upper-triangle order.
        n_venues = len(gt_m)
        edge_labels = [f'({i},{j})'
                       for i in range(n_venues)
                       for j in range(i + 1, n_venues)]
        for i, (x, y) in enumerate(zip(gt_c, learned_c)):
            ax.annotate(edge_labels[i], (x, y), textcoords='offset points',
                        xytext=(8, 8), fontsize=9, fontweight='bold')
        ax.set_xlabel('Ground Truth  E[s_i · s_j]')
        ax.set_ylabel('Learned / Baseline  E[s_i · s_j]')
        ax.set_title(f"{name} — Pairwise Correlations", fontweight='bold')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
        ax.set_aspect('equal')

    fig.suptitle(
        'Generative Proof — Raw Distribution Learning (No Routing)',
        fontsize=14,
        fontweight='bold',
        y=1.02
    )
    plt.tight_layout()
    import os
    os.makedirs('images', exist_ok=True)
    filename = 'images/generative_proof_results.png'
    plt.savefig(filename, bbox_inches='tight', dpi=150)
    # Fix: report the actual save path (previously printed a literal
    # "(unknown)" placeholder instead of interpolating `filename`).
    print(f"\nSaved plot to {filename}")
    plt.show()


# ── Run evaluation ──
print("\n" + "=" * 80)
print("GENERATIVE PROOF — Evaluation")
print("=" * 80)
gen_proof_eval = run_generative_proof_evaluation(
    gen_proof_results,
    gen_config,
    n_eval_samples=1500,
    mae_tolerance=0.08
)
plot_generative_proof(gen_proof_eval)

print("\n" + "=" * 80)
print("GENERATIVE PROOF — MAE SUMMARY")
print("=" * 80)
print(f"{'SCENARIO':<25} | {'MARG_MAE (mean±std)':<24} | {'CORR_MAE (mean±std)':<24} | {'PASS'}")
print("-" * 95)
for scenario_name, data in gen_proof_eval.items():
    # Build each mean±std cell first, then pad the WHOLE cell to the header
    # width (24). Previously only the std was padded (`:<14.4f`), producing
    # 19-character cells under 24-character headers and a misaligned table.
    marg_cell = f"{data['marg_mae_mean']:.4f}±{data['marg_mae_std']:.4f}"
    corr_cell = f"{data['corr_mae_mean']:.4f}±{data['corr_mae_std']:.4f}"
    overall_pass = bool(data['pass_marg'] and data['pass_corr'])
    print(f"{scenario_name:<25} | {marg_cell:<24} | {corr_cell:<24} | {overall_pass}")
print("=" * 95)

================================================================================
GENERATIVE PROOF — Evaluation
================================================================================

Evaluating generative quality: IID Venues
  Marginal MAE (mean±std):    0.0433 ± 0.0145  | pass<0.08: True
  Correlation MAE (mean±std): 0.0703 ± 0.0195  | pass<0.08: True
  Untrained baseline MAE:     marg=0.2297, corr=0.1840

Evaluating generative quality: Correlated Venues
  Marginal MAE (mean±std):    0.1097 ± 0.0737  | pass<0.08: False
  Correlation MAE (mean±std): 0.0514 ± 0.0172  | pass<0.08: True
  Untrained baseline MAE:     marg=0.0832, corr=0.6789

Saved plot to images/generative_proof_results.png


================================================================================
GENERATIVE PROOF — MAE SUMMARY
================================================================================
SCENARIO                  | MARG_MAE (mean±std)      | CORR_MAE (mean±std)      | PASS
-----------------------------------------------------------------------------------------------
IID Venues                | 0.0433±0.0145         | 0.0703±0.0195         | True
Correlated Venues         | 0.1097±0.0737         | 0.0514±0.0172         | False
===============================================================================================

1.9 Ablation Study

To understand which THRML mechanisms drive the performance gains, we run a controlled ablation study on the two informative scenarios (Correlated Venues and Regime Shift) in Fixed context mode. Each variant disables exactly one component of the full agent:

Variant What is disabled
THRML-Full Nothing — full agent (reference)
THRML-NoClamping Context nodes not clamped; joint (unclamped) sampling used for selection
THRML-NoCouplings Edge weights zeroed and frozen; only node biases are used and learned
THRML-NoDamping Mean-field propagation damping set to 0.0
Show the code
# --- Ablation variants ---
# Each row: (name, no_clamping, no_couplings, no_damping) — exactly one
# mechanism disabled per variant (none for the full reference agent).
_variant_flags = [
    ("THRML-Full",        False, False, False),
    ("THRML-NoClamping",  True,  False, False),
    ("THRML-NoCouplings", False, True,  False),
    ("THRML-NoDamping",   False, False, True),
]
ablation_variants = [
    AblationConfig(name=nm, no_clamping=nc, no_couplings=nw, no_damping=nd)
    for nm, nc, nw, nd in _variant_flags
]

ablation_config = ExperimentConfig(context_mode="fixed")

# Only the two scenarios where mechanism differences are informative
ablation_scenario_names = ["Correlated Venues", "Regime Shift"]
ablation_scenarios = {s.name: s for s in scenarios if s.name in ablation_scenario_names}

# ablation_results[scenario_name][ablation_name] = {'final_regret': float, 'curves': array}
ablation_results = {}

for scenario_name, scenario in ablation_scenarios.items():
    ablation_results[scenario_name] = {}
    print(f"\n=== {scenario_name} ===")
    for abl in ablation_variants:
        res = run_experiment_vmapped(ablation_config, scenario, ablation=abl)
        final_regret = float(jnp.mean(res['THRML'][:, -1]))
        ablation_results[scenario_name][abl.name] = {
            'final_regret': final_regret,
            'curves': res['THRML']   # shape: [n_seeds, n_steps]
        }
        print(f"  {abl.name:28s} | final cumulative regret: {final_regret:.2f}")

=== Correlated Venues ===
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [251.8s]
   - Processing seeds 50 to 100... [225.1s]
   - Processing seeds 100 to 150... [224.9s]
   - Processing seeds 150 to 200... [224.8s]
Finished 200 seeds x 10000 steps in 926.7745s
  THRML-Full                   | final cumulative regret: 624.83
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-NoClamping
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [289.8s]
   - Processing seeds 50 to 100... [263.6s]
   - Processing seeds 100 to 150... [263.4s]
   - Processing seeds 150 to 200... [263.7s]
Finished 200 seeds x 10000 steps in 1080.5780s
  THRML-NoClamping             | final cumulative regret: 2445.36
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-NoCouplings
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [253.5s]
   - Processing seeds 50 to 100... [227.7s]
   - Processing seeds 100 to 150... [227.8s]
   - Processing seeds 150 to 200... [227.8s]
Finished 200 seeds x 10000 steps in 936.7551s
  THRML-NoCouplings            | final cumulative regret: 1803.61
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-NoDamping
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [255.2s]
   - Processing seeds 50 to 100... [227.7s]
   - Processing seeds 100 to 150... [227.4s]
   - Processing seeds 150 to 200... [227.8s]
Finished 200 seeds x 10000 steps in 938.0799s
  THRML-NoDamping              | final cumulative regret: 1091.54

=== Regime Shift ===
Compiling and running scenario: Regime Shift
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [251.3s]
   - Processing seeds 50 to 100... [225.2s]
   - Processing seeds 100 to 150... [225.2s]
   - Processing seeds 150 to 200... [225.2s]
Finished 200 seeds x 10000 steps in 926.9179s
  THRML-Full                   | final cumulative regret: 1762.18
Compiling and running scenario: Regime Shift
  Context mode: fixed | Ablation: THRML-NoClamping
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [290.2s]
   - Processing seeds 50 to 100... [264.2s]
   - Processing seeds 100 to 150... [264.0s]
   - Processing seeds 150 to 200... [264.0s]
Finished 200 seeds x 10000 steps in 1082.4216s
  THRML-NoClamping             | final cumulative regret: 2581.83
Compiling and running scenario: Regime Shift
  Context mode: fixed | Ablation: THRML-NoCouplings
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [252.0s]
   - Processing seeds 50 to 100... [226.7s]
   - Processing seeds 100 to 150... [226.9s]
   - Processing seeds 150 to 200... [226.8s]
Finished 200 seeds x 10000 steps in 932.4727s
  THRML-NoCouplings            | final cumulative regret: 2248.69
Compiling and running scenario: Regime Shift
  Context mode: fixed | Ablation: THRML-NoDamping
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [251.6s]
   - Processing seeds 50 to 100... [226.7s]
   - Processing seeds 100 to 150... [226.7s]
   - Processing seeds 150 to 200... [226.6s]
Finished 200 seeds x 10000 steps in 931.6394s
  THRML-NoDamping              | final cumulative regret: 1917.66
Show the code
def plot_ablation_results(ablation_results):
    """Regret curves for each ablation variant, one panel per scenario."""
    names = list(ablation_results.keys())
    n_panels = len(names)
    fig, axes = plt.subplots(1, n_panels, figsize=(7 * n_panels, 5))
    if n_panels == 1:
        axes = [axes]

    # Fixed colour/linestyle per variant so panels stay visually comparable.
    palette = {
        'THRML-Full':        ('#1f77b4', '-'),
        'THRML-NoClamping':  ('#ff7f0e', '--'),
        'THRML-NoCouplings': ('#2ca02c', '-.'),
        'THRML-NoDamping':   ('#d62728', ':'),
    }

    for ax, scenario_name in zip(axes, names):
        for variant, payload in ablation_results[scenario_name].items():
            regret = payload['curves']
            mu = jnp.mean(regret, axis=0)
            sigma = jnp.std(regret, axis=0)
            xs = jnp.arange(mu.shape[0])
            color, linestyle = palette.get(variant, ('grey', '-'))
            ax.plot(xs, mu, label=variant, color=color, linestyle=linestyle, linewidth=2)
            # ±1 std band across seeds.
            ax.fill_between(xs, mu - sigma, mu + sigma, alpha=0.15, color=color)

        ax.set_title(f'Ablation Study — {scenario_name}\n(Fixed Context, 200 seeds)')
        ax.set_xlabel('Step')
        ax.set_ylabel('Cumulative Regret')
        ax.legend(framealpha=0.9)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    import os
    os.makedirs('images', exist_ok=True)
    plt.savefig('images/ablation_study_results.png', dpi=150, bbox_inches='tight')
    print('Saved: ablation_study_results.png')
    plt.show()

plot_ablation_results(ablation_results)
Saved: ablation_study_results.png

1.9.1 Sampling Budget Sweep

We sweep n_samples ∈ {10, 25, 50, 100, 200} on the Correlated Venues scenario (Fixed context) to characterise the regret–compute tradeoff. In the implementation below, n_warmup is also scaled as max(5, n_samples // 4) so that shorter budgets do not spend disproportionate effort on burn-in while larger budgets still receive additional equilibration. This is directly relevant to thermodynamic hardware, where sampling throughput is a first-class physical resource.

Show the code
budget_values = [10, 25, 50, 100, 200]
correlated_scenario = [s for s in scenarios if s.name == "Correlated Venues"][0]
budget_sweep_results = {}  # {n_samples: {'final_regret': float, 'curves': array}}

print("=== Sampling Budget Sweep (Correlated, Fixed Context) ===")
for budget in budget_values:
    # Scale warm-up with the budget (min 5) so small budgets are not
    # dominated by burn-in.
    cfg = ExperimentConfig(
        context_mode="fixed",
        n_warmup=max(5, budget // 4),
        n_samples=budget,
        steps_per_sample=4
    )
    outcome = run_experiment_vmapped(cfg, correlated_scenario, ablation=None)
    regret = float(jnp.mean(outcome['THRML'][:, -1]))
    budget_sweep_results[budget] = {'final_regret': regret, 'curves': outcome['THRML']}
    print(f"  n_samples={budget:4d} | final cumulative regret: {regret:.2f}")
=== Sampling Budget Sweep (Correlated, Fixed Context) ===
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [75.0s]
   - Processing seeds 50 to 100... [49.0s]
   - Processing seeds 100 to 150... [49.1s]
   - Processing seeds 150 to 200... [49.1s]
Finished 200 seeds x 10000 steps in 222.2398s
  n_samples=  10 | final cumulative regret: 707.28
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [102.0s]
   - Processing seeds 50 to 100... [76.0s]
   - Processing seeds 100 to 150... [76.1s]
   - Processing seeds 150 to 200... [76.0s]
Finished 200 seeds x 10000 steps in 330.0238s
  n_samples=  25 | final cumulative regret: 633.38
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [147.9s]
   - Processing seeds 50 to 100... [122.6s]
   - Processing seeds 100 to 150... [122.6s]
   - Processing seeds 150 to 200... [122.7s]
Finished 200 seeds x 10000 steps in 515.8542s
  n_samples=  50 | final cumulative regret: 627.93
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [244.6s]
   - Processing seeds 50 to 100... [219.2s]
   - Processing seeds 100 to 150... [219.4s]
   - Processing seeds 150 to 200... [219.2s]
Finished 200 seeds x 10000 steps in 902.4749s
  n_samples= 100 | final cumulative regret: 625.15
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [430.9s]
   - Processing seeds 50 to 100... [405.3s]
   - Processing seeds 100 to 150... [404.9s]
   - Processing seeds 150 to 200... [404.9s]
Finished 200 seeds x 10000 steps in 1646.0363s
  n_samples= 200 | final cumulative regret: 623.70
Show the code
def plot_budget_sweep(budget_sweep_results):
    """Visualize the Gibbs sampling budget sweep.

    Left panel: mean cumulative-regret curves (±1 std band) for each
    sampling budget. Right panel: mean final cumulative regret as a
    function of n_samples. Saves the figure under images/.
    """
    fig, (ax_curves, ax_final) = plt.subplots(1, 2, figsize=(13, 5))

    budgets = sorted(budget_sweep_results.items())
    palette = plt.cm.viridis
    denom = max(len(budgets) - 1, 1)

    for idx, (budget, payload) in enumerate(budgets):
        regret = payload['curves']
        avg = jnp.mean(regret, axis=0)
        spread = jnp.std(regret, axis=0)
        xs = jnp.arange(avg.shape[0])
        shade = palette(idx / denom)
        ax_curves.plot(xs, avg, label=f'n_samples={budget}', color=shade, linewidth=2)
        ax_curves.fill_between(xs, avg - spread, avg + spread, alpha=0.12, color=shade)

    ax_curves.set_title('Regret Curves by Sampling Budget\n(Correlated, Fixed, 200 seeds)')
    ax_curves.set_xlabel('Step')
    ax_curves.set_ylabel('Cumulative Regret')
    ax_curves.legend()
    ax_curves.grid(True, alpha=0.3)

    budget_vals = [b for b, _ in budgets]
    final_regrets = [p['final_regret'] for _, p in budgets]
    ax_final.plot(budget_vals, final_regrets, 'o-', color='#1f77b4', linewidth=2, markersize=8)
    for b, r in zip(budget_vals, final_regrets):
        ax_final.annotate(f'{r:.1f}', (b, r), textcoords='offset points',
                          xytext=(0, 10), ha='center', fontsize=9)
    ax_final.set_title('Final Cumulative Regret vs Sampling Budget\n(Correlated, Fixed)')
    ax_final.set_xlabel('n_samples (Gibbs samples per step)')
    ax_final.set_ylabel('Mean Final Cumulative Regret')
    ax_final.grid(True, alpha=0.3)

    plt.tight_layout()
    import os
    os.makedirs('images', exist_ok=True)
    plt.savefig('images/sampling_budget_sweep.png', dpi=150, bbox_inches='tight')
    print('Saved: sampling_budget_sweep.png')
    plt.show()

plot_budget_sweep(budget_sweep_results)
Saved: sampling_budget_sweep.png

1.9.2 Temperature (β) Sensitivity Sweep

The inverse temperature \(\beta\) is the central thermodynamic parameter of the Ising model. It controls the sharpness of the Boltzmann distribution — low \(\beta\) (hot system) produces a flat, exploratory distribution over venue states, while high \(\beta\) (cold system) concentrates probability mass on the lowest-energy configuration.

In this sweep, the environment’s data-generating process uses a fixed \(\beta=1.0\), and only the agent’s inference \(\beta\) is varied. This isolates the effect of the agent’s distributional sharpness on routing performance:

  • Agent β: controls how decisively the THRML agent commits to its highest-probability venue — higher β makes routing sharper (greedier), lower β more exploratory

We sweep \(\beta \in \{0.1, 0.5, 1.0, 2.0, 5.0\}\) on the Correlated Venues and Regime Shift scenarios (Fixed context, 200 seeds). The IID scenario is excluded because changing β is not expected to materially change routing performance in a structureless setting where venues are independent.

Show the code
beta_values = [0.1, 0.5, 1.0, 2.0, 5.0]

# Run on Correlated and Regime Shift — same scenarios as the ablation study
beta_scenario_names = ["Correlated Venues", "Regime Shift"]
beta_scenarios = {s.name: s for s in scenarios if s.name in beta_scenario_names}

# beta_sweep_results[scenario_name][beta] = {'final_regret': float, 'curves': array}
beta_sweep_results = {name: {} for name in beta_scenarios}

for name, scen in beta_scenarios.items():
    print(f"\n=== {name} ===")
    for b in beta_values:
        # Only the agent's inference beta varies; everything else is fixed.
        cfg = ExperimentConfig(context_mode="fixed", beta=b)
        out = run_experiment_vmapped(cfg, scen, ablation=None)
        curves = out['THRML']  # shape: [n_seeds, n_steps]
        regret = float(jnp.mean(curves[:, -1]))
        beta_sweep_results[name][b] = {
            'final_regret': regret,
            'curves': curves,
        }
        print(f"  beta={b:.1f} | final cumulative regret: {regret:.2f}")

=== Correlated Venues ===
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [252.5s]
   - Processing seeds 50 to 100... [226.0s]
   - Processing seeds 100 to 150... [226.5s]
   - Processing seeds 150 to 200... [226.6s]
Finished 200 seeds x 10000 steps in 932.9363s
  beta=0.1 | final cumulative regret: 1905.18
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [252.5s]
   - Processing seeds 50 to 100... [227.3s]
   - Processing seeds 100 to 150... [227.7s]
   - Processing seeds 150 to 200... [227.2s]
Finished 200 seeds x 10000 steps in 934.7665s
  beta=0.5 | final cumulative regret: 1046.72
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [250.2s]
   - Processing seeds 50 to 100... [224.9s]
   - Processing seeds 100 to 150... [224.6s]
   - Processing seeds 150 to 200... [224.7s]
Finished 200 seeds x 10000 steps in 924.4627s
  beta=1.0 | final cumulative regret: 624.83
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [253.3s]
   - Processing seeds 50 to 100... [227.7s]
   - Processing seeds 100 to 150... [227.7s]
   - Processing seeds 150 to 200... [228.1s]
Finished 200 seeds x 10000 steps in 936.8187s
  beta=2.0 | final cumulative regret: 634.34
Compiling and running scenario: Correlated Venues
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [253.1s]
   - Processing seeds 50 to 100... [227.3s]
   - Processing seeds 100 to 150... [227.3s]
   - Processing seeds 150 to 200... [227.4s]
Finished 200 seeds x 10000 steps in 935.0790s
  beta=5.0 | final cumulative regret: 1538.98

=== Regime Shift ===
Compiling and running scenario: Regime Shift
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [251.5s]
   - Processing seeds 50 to 100... [225.5s]
   - Processing seeds 100 to 150... [225.0s]
   - Processing seeds 150 to 200... [225.0s]
Finished 200 seeds x 10000 steps in 927.0813s
  beta=0.1 | final cumulative regret: 3517.46
Compiling and running scenario: Regime Shift
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [252.2s]
   - Processing seeds 50 to 100... [226.7s]
   - Processing seeds 100 to 150... [226.6s]
   - Processing seeds 150 to 200... [226.6s]
Finished 200 seeds x 10000 steps in 932.2503s
  beta=0.5 | final cumulative regret: 1919.56
Compiling and running scenario: Regime Shift
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [253.4s]
   - Processing seeds 50 to 100... [228.3s]
   - Processing seeds 100 to 150... [228.0s]
   - Processing seeds 150 to 200... [228.0s]
Finished 200 seeds x 10000 steps in 937.8048s
  beta=1.0 | final cumulative regret: 1762.18
Compiling and running scenario: Regime Shift
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [252.4s]
   - Processing seeds 50 to 100... [226.9s]
   - Processing seeds 100 to 150... [227.1s]
   - Processing seeds 150 to 200... [226.8s]
Finished 200 seeds x 10000 steps in 933.1656s
  beta=2.0 | final cumulative regret: 1815.75
Compiling and running scenario: Regime Shift
  Context mode: fixed | Ablation: THRML-Full
  Starting execution in batches of 50 seeds...
   - Processing seeds 0 to 50... [251.8s]
   - Processing seeds 50 to 100... [226.8s]
   - Processing seeds 100 to 150... [226.7s]
   - Processing seeds 150 to 200... [226.7s]
Finished 200 seeds x 10000 steps in 932.0655s
  beta=5.0 | final cumulative regret: 2184.24
Show the code
def plot_beta_sweep(beta_sweep_results):
    """
    Plot the inverse-temperature (β) sensitivity sweep.

    For each scenario, draws a two-panel row: mean cumulative-regret
    curves (±1 std band) per β on the left, and mean final cumulative
    regret vs β on the right. Saves the figure to images/beta_sweep.png.

    Parameters
    ----------
    beta_sweep_results : dict
        Mapping scenario_name -> {beta: {'final_regret': float,
        'curves': array of shape [n_seeds, n_steps]}}.
    """
    scenario_names = list(beta_sweep_results.keys())
    n_scenarios = len(scenario_names)
    fig, axes = plt.subplots(n_scenarios, 2, figsize=(13, 5 * n_scenarios))
    if n_scenarios == 1:
        # subplots squeezes a single row to a 1-D array; re-wrap so
        # axes[row] always yields the (curves, final) pair of axes.
        axes = [axes]

    cmap = plt.cm.coolwarm  # Cool (blue) = cold/high-beta, warm (red) = hot/low-beta
    beta_vals = sorted(next(iter(beta_sweep_results.values())).keys())
    n_betas = len(beta_vals)

    def _beta_color(i):
        # Map high beta (cold) to the blue end of coolwarm and low beta
        # (hot) to the red end. BUGFIX: the previous fraction
        # i / (n_betas - 1) colored the LOWEST beta blue, inverting the
        # stated cold=blue convention.
        return cmap(1.0 - i / max(n_betas - 1, 1))

    for row, scenario_name in enumerate(scenario_names):
        ax_curves, ax_final = axes[row]

        for i, beta_val in enumerate(beta_vals):
            data   = beta_sweep_results[scenario_name][beta_val]
            curves = data['curves']
            mean   = jnp.mean(curves, axis=0)
            std    = jnp.std(curves,  axis=0)
            steps  = jnp.arange(mean.shape[0])
            color  = _beta_color(i)
            ax_curves.plot(steps, mean, label=f'β={beta_val}', color=color, linewidth=2)
            ax_curves.fill_between(steps, mean - std, mean + std, alpha=0.12, color=color)

        ax_curves.set_title(f'β Sensitivity — {scenario_name}\n(Fixed Context, 200 seeds)')
        ax_curves.set_xlabel('Step')
        ax_curves.set_ylabel('Cumulative Regret')
        ax_curves.legend(title='Inverse Temp. β')
        ax_curves.grid(True, alpha=0.3)

        # Final regret vs beta; marker colors reuse the same cold=blue mapping.
        finals = [beta_sweep_results[scenario_name][b]['final_regret'] for b in beta_vals]
        colors = [_beta_color(i) for i in range(n_betas)]
        ax_final.plot(beta_vals, finals, 'o-', color='#444', linewidth=2, zorder=1)
        for bv, fv, col in zip(beta_vals, finals, colors):
            ax_final.scatter(bv, fv, color=col, s=80, zorder=2)
            ax_final.annotate(f'{fv:.1f}', (bv, fv),
                              textcoords='offset points', xytext=(0, 10),
                              ha='center', fontsize=9)
        ax_final.set_title(f'Final Cumulative Regret vs β\n{scenario_name}')
        ax_final.set_xlabel('β (inverse temperature)')
        ax_final.set_ylabel('Mean Final Cumulative Regret')
        ax_final.grid(True, alpha=0.3)

    plt.tight_layout()
    import os
    os.makedirs('images', exist_ok=True)
    plt.savefig('images/beta_sweep.png', dpi=150, bbox_inches='tight')
    print('Saved: beta_sweep.png')
    plt.show()

plot_beta_sweep(beta_sweep_results)
Saved: beta_sweep.png