Any help appreciated!

Hello! I’m in the process of trying to make a nice basic 3 parameter DDM and I hope to be able to nail the basic version so I can branch out in the future. I had some initial success but there were some flaws in the logic that I wanted to smooth out. Attempting to fix this, I’ve gotten into trouble. I’m quite poor at coding (particularly python) so trying to debug it has been quite difficult. I’m very new to this all so any help on anything in general that stands out as a red flag or anything more specific based on this error that I keep getting would be greatly appreciated. Thank you!

Script 1 (which worked)

===== 1. Setup =====

import os
os.environ[“KERAS_BACKEND”]=“jax”
os.environ[“TF_CUDNN_DETERMINISTIC”]=“1”
os.environ[“XLA_PYTHON_CLIENT_PREALLOCATE”]=“false”

import numpy as np, bayesflow as bf, keras
import jax, jax.numpy as jnp

print(“=== JAX GPU check ===”)
print(“JAX version:”, jax.version)
print(“Devices:”, jax.devices())
gpu=[d for d in jax.devices() if d.platform==“gpu”]
print(“GPU detected:” if gpu else “:warning: No GPU detected.”, gpu)
print(“====================\n”)

===== 2. Priors =====

def prior():
return dict(
v=float(np.random.uniform(0.2,2.5)),
a=float(np.random.uniform(0.5,2.5)),
t0=float(np.random.gamma(2,0.25)),
)

===== 3. JAX Simulator =====

@jax.jit(static_argnames=[“n_obs”,“max_steps”,“dt”])
def _simulate_batch(key,v,a,t0,n_obs,max_steps,dt):
z=0.5*a
time_steps=jnp.arange(max_steps,dtype=jnp.float32)*dt
key_noise,key_tie=jax.random.split(key)
noise=jax.random.normal(key_noise,(n_obs,max_steps))*jnp.sqrt(dt)

traj=z+jnp.cumsum(v*dt+noise,axis=1)
up,low=traj>=a,traj<=0.0

up_idx,low_idx=jnp.argmax(up,1),jnp.argmax(low,1)
hit_up,hit_low=jnp.any(up,1),jnp.any(low,1)

tie=(up_idx==low_idx)&hit_up&hit_low
rand=jax.random.bernoulli(key_tie,0.5,shape=up_idx.shape)
up_first=jnp.where(tie,rand,up_idx<low_idx)

up_rt=jnp.where(hit_up,time_steps[up_idx]+t0,jnp.nan)
low_rt=jnp.where(hit_low,time_steps[low_idx]+t0,jnp.nan)

rt=jnp.where(hit_up&hit_low,
jnp.where(up_first,up_rt,low_rt),
jnp.where(hit_up,up_rt,
jnp.where(hit_low,low_rt,jnp.nan)))

resp=jnp.where(hit_up&hit_low,
jnp.where(up_first,1.0,0.0),
jnp.where(hit_up,1.0,
jnp.where(hit_low,0.0,-1.0)))
return rt,resp

def simulator(v,a,t0,n_obs,dt=0.001,max_rt=2.0,min_rt=0.2):
key=jax.random.PRNGKey(np.random.randint(0,2**31-1))
rt,resp=_simulate_batch(key,v,a,t0,n_obs,int(max_rt/dt),dt)

rt,resp=np.asarray(rt),np.asarray(resp)
valid=(~np.isnan(rt))&(resp>=0)&(rt>=min_rt)&(rt<=max_rt)

rt_v,resp_v=rt[valid],resp[valid]
if len(rt_v)>0:
idx=np.random.choice(len(rt_v),n_obs,
replace=len(rt_v)<n_obs)
rt_v,resp_v=rt_v[idx],resp_v[idx]
else:
rt_v=np.full(n_obs,max_rt,np.float32)
resp_v=np.zeros(n_obs,np.float32)

return dict(x=np.stack([rt_v,resp_v],1).astype(np.float32))

===== 4. Simulator Object =====

def meta():
return dict(n_obs=np.int32(np.random.randint(120,161)))

simulator_obj=bf.make_simulator([prior,simulator],meta_fn=meta)

===== 5. Workflow =====

adapter=(bf.Adapter()
.constrain(“a”,lower=0)
.constrain(“t0”,lower=0)
.as_set(“x”)
.concatenate([“v”,“a”,“t0”],into=“inference_variables”)
.rename(“x”,“summary_variables”))

workflow=bf.BasicWorkflow(
simulator=simulator_obj,
adapter=adapter,
inference_network=bf.networks.FlowMatching(),
summary_network=bf.networks.SetTransformer(summary_dim=256),
inference_variables=[“v”,“a”,“t0”],
summary_variables=[“x”],
)

===== 6. Training =====

opt=keras.optimizers.AdamW(learning_rate=1e-4)
history=workflow.fit_online(
epochs=200,
num_batches_per_epoch=2500,
batch_size=32,
optimizer=opt,
)

===== 7. Save =====

workflow.approximator.save(“accuracy_ddm_9_approximator.keras”)
print(“Model saved.”)

Script 2: Has not been working, tried to make it more principled but I am having issues with it. (All I changed when pasting here was some of the formatting)

---- Section 1 ----- Import/Set Up
import os
os.environ[“KERAS_BACKEND”] = “jax”
os.environ[“TF_CUDNN_DETERMINISTIC”] = “1”
os.environ[“XLA_PYTHON_CLIENT_PREALLOCATE”] = “false”

import numpy as np
import bayesflow as bf
import keras
from scipy.stats import truncnorm

import jax
import jax.numpy as jnp

print(“=== JAX GPU check ===”)
print(“JAX version:”, jax.version)
print(“Available devices:”, jax.devices())

gpu_devices = [d for d in jax.devices() if d.platform == ‘gpu’]
if gpu_devices:
print(f"GPU detected: {gpu_devices}“)
else:
print(”:warning: No GPU detected. JAX will run on CPU.“)
print(”====================\n")

print(“JAX devices:”, jax.devices())

---- Section 2 ---- Priors

def truncated_normal(mean, sd, lower, upper):
a = (lower - mean) / sd
b = (upper - mean) / sd
return truncnorm.rvs(a, b, loc=mean, scale=sd)

def prior():
return dict(
v=np.float32(truncated_normal(mean=1.75, sd=0.70, lower=0.0, upper=2.5)),
a=np.float32(truncated_normal(mean=1.75, sd=0.50, lower=0.5, upper=2.5)),
t0=np.float32(truncated_normal(mean=0.60, sd=0.15, lower=0.2, upper=1.0)),
)

---- Section 3 ---- Simulator

@jax.jit(static_argnames=[“n_sim”, “max_steps”, “dt”])
def _simulate_batch(key, v, a, t0, n_sim, max_steps, dt):
z = 0.5 * a
time_steps = jnp.arange(max_steps, dtype=jnp.float32) * dt

key_noise, key_tie = jax.random.split(key)
s = 1.0
noise = jax.random.normal(key_noise, (n_sim, max_steps)) * (s * jnp.sqrt(dt))

increments = v * dt + noise
trajectories = z + jnp.cumsum(increments, axis=1)

upper_cross = trajectories >= a
lower_cross = trajectories <= 0.0

upper_idx = jnp.argmax(upper_cross, axis=1)
lower_idx = jnp.argmax(lower_cross, axis=1)

hit_upper = jnp.any(upper_cross, axis=1)
hit_lower = jnp.any(lower_cross, axis=1)

tie = (upper_idx == lower_idx) & hit_upper & hit_lower
rand = jax.random.bernoulli(key_tie, 0.5, shape=upper_idx.shape)
upper_first = jnp.where(tie, rand, upper_idx < lower_idx)

upper_rt = jnp.where(hit_upper, time_steps[upper_idx] + t0, jnp.nan)
lower_rt = jnp.where(hit_lower, time_steps[lower_idx] + t0, jnp.nan)

rt = jnp.where(
    hit_upper & hit_lower,
    jnp.where(upper_first, upper_rt, lower_rt),
    jnp.where(hit_upper, upper_rt,
              jnp.where(hit_lower, lower_rt, jnp.nan))
)

resp = jnp.where(
    hit_upper & hit_lower,
    jnp.where(upper_first, 1.0, 0.0),
    jnp.where(hit_upper, 1.0,
              jnp.where(hit_lower, 0.0, -1.0))
)

return rt, resp

def simulator(v, a, t0, n_obs, dt=0.001, max_rt=10.0):
n_sim = 300
max_steps = int(max_rt / dt)

key = jax.random.PRNGKey(np.random.randint(0, 2**31 - 1))
rt, resp = _simulate_batch(key, v, a, t0, n_sim, max_steps, dt)

rt = np.asarray(rt)
resp = np.asarray(resp)

return dict(rt=rt, resp=resp)

---- Section 4 ---- Preprocessor

def preprocess_simulator_output(sim_output, n_obs):
rt = sim_output[“rt”].reshape(-1)
resp = sim_output[“resp”].reshape(-1)

valid = (
    (~np.isnan(rt)) &
    (resp >= 0) &
    (rt >= 0.2) &
    (rt <= 2.0)
)

rt_valid = rt[valid]
resp_valid = resp[valid]

if len(rt_valid) == 0:
    mask = ~np.isnan(rt)
    rt_valid = rt[mask]
    resp_valid = resp[mask]

idx = np.random.choice(len(rt_valid), size=n_obs, replace=True)
rt_sampled = rt_valid[idx]
resp_sampled = resp_valid[idx]

log_rt = np.log(rt_sampled)

x = np.stack([log_rt, resp_sampled], axis=1).astype(np.float32)

return dict(x=x)

---- Section 5 ---- Meta + Full Simulator

def meta():
return dict(n_obs=np.int32(160))

def full_simulator():
theta = prior()
n_obs = int(meta()[“n_obs”])

sim_raw = simulator(theta["v"], theta["a"], theta["t0"], n_obs)
x_dict = preprocess_simulator_output(sim_raw, n_obs)

return dict(
    v=theta["v"],
    a=theta["a"],
    t0=theta["t0"],
    x=x_dict["x"],
)

---- Section 6 ---- Workflow (FIXED SHAPES + NAMES)

adapter = (
bf.Adapter()
.constrain(“v”, lower=0.0, upper=2.5)
.constrain(“a”, lower=0.5, upper=2.5)
.constrain(“t0”, lower=0.2, upper=1.0)
.as_set(“x”) # treat x as set
.concatenate([“v”, “a”, “t0”], into=“inference_variables”)
.rename(“x”, “summary_variables”) # ← critical rename
)

workflow = bf.BasicWorkflow(
simulator=simulator_obj,
adapter=adapter,
inference_network=bf.networks.FlowMatching(),
summary_network=bf.networks.SetTransformer(summary_dim=64),
inference_variables=[“v”, “a”, “t0”], # parameter names
summary_variables=[“summary_variables”], # ← matches rename
)

---- Section 7 ---- Training

optimizer = keras.optimizers.AdamW(learning_rate=1e-4)

history = workflow.fit_online(
epochs=200,
num_batches_per_epoch=1000,
batch_size=32,
optimizer=optimizer,
)

---- Section 8 ---- Saving

workflow.approximator.save(“accuracy_ddm_10_approximator.keras”)
print(“Model saved.”)

The key traceback is:

ValueError: Incompatible shapes for broadcasting: (96,) and requested shape (1,)

This happens inside:

time = keras.ops.broadcast_to(time, keras.ops.shape(xz)[:-1] + (1,))

So it looks like time has shape (96,) but is being broadcast to something ending in (1,), which fails.

Again, I’m quite inexperienced so if anything seems like it needs a total redo please let me know, I’ve added the constraints to match the task I want to apply the approximator to. Feel free to reach out to me. Sorry for the ugly posting!

Hi Will, and welcome to our forums!

First off, since there is no GPU, I suggest instructing the LLM to code up your simulator in pure numpy. Also, there is no need to use as_set in the adapter, as this is only used for arrays of shape (num_repeats, ) that are actually unordered sets and need to be represented as arrays of shape (num_repeats, 1).

For your setup, you need to make sure that what goes into the networks is:

  • inference_variables → array of shape (num_sims, num_params)
  • summary_variables → array of shape (num_sims, num_obs, num_columns)

I couple of practical suggestions:

  1. Always check the shapes of all adapted outputs as a sanity check:
test_sims = simulator(2)
adapted_sims = adapter(test_sims)
for k, v in adapted_sims.items():
    print(f"Shape of {k} is {v.shape}")
  1. When you are starting out, you want to train as fast as possible, so don’t use varying n_obs and if possible, instruct the LLM to compute summary statistics of the data (e.g., RT quantiles and accuracies per condition). That way, you can forego using a summary network and can pass the “hand-crafted” summaries directly as inference_conditions (will reduce your training time to < 1 minute).
  2. Related: do not use online training (fit_online) at first, as the total time for any amortized workflow is always:
T_{total} = T_{sim} + T_{train} + T_{infer}

Online training entangles T_{train} and T_{train} and should only be used as a last step in a workflow (i.e., before writing down results for a paper and your model is final). A good heuristic for DDM-like models is that around 10,000 offline simulations are more than sufficient to give you an accurate idea of the recoverability of your model (i.e., using fit_offline).

Feel free to check out some of our resources out there on cognitive modeling. For example, this notebooks demonstrates a few workflows with a DDM using pre-built simulators from the excellent HSSM library: bayesflow_workshops/carney_comp_2025/notebooks/ssms_bayesflow.ipynb at main · bayesflow-org/bayesflow_workshops · GitHub

Let us know if you encounter more roadblocks along the way and happy amortizing!

Thanks so much I really appreciate it!

1 Like