Hi all,
I’m a complete newbie at this, so it may be a simple fix. I’m currently fitting an evidence accumulation model, specifically investigating post-decision evidence accumulation as an index of certainty/confidence. I’m running this on BayesFlow 2.0.6 with the JAX backend and Keras 3.11.2.
When I attempt to train the model with a SetTransformer summary network, I get this error:
ValueError: Layers added to a Sequential model can only have a single positional argument, the input tensor. Layer SetAttentionBlock has multiple positional arguments: [<Parameter "input_set: ~Tensor">, <Parameter "**kwargs">]
I tried a DeepSet summary network instead and it works. I initially ran the same model in BayesFlow v1 (given that there are more resources on DDM models for it), and preliminary diagnostics seem to suggest good parameter recovery, so I don’t think the model itself is the problem. Maybe something went wrong when I translated my arguments to BayesFlow v2, as my main references for the second version have been the linear regression and t-test examples.
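For reference, the working DeepSet setup is just a one-line swap, something like this (everything else below stays the same):

summary_network = bf.networks.DeepSet(summary_dim=32)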
Here’s the simulator
# Imports used throughout the snippets below
import numba as nb
import numpy as np
import scipy.stats
import bayesflow as bf

@nb.jit(nopython=True, cache=True)
def meta_trial(v, a, ter, v_bias, a_bias, m_bias, v_ratio, dt=0.001, s=1, max_iter=15000):
    """Generates trial response time, choice, and confidence rating
    from a postdecisional model given a set of parameters for one trial."""
    # Constant for diffusion process
    c = np.sqrt(dt) * s
    # Initialise settings: accumulation starts midway between the boundaries
    pre_evidence = a * 0.5
    rt = ter
    # Define stimulus being presented
    stim = np.random.choice(np.array([-1, 1]))
    # --- 1. Initial Decision Phase ---
    # Define initial v to be stimulus dependent
    v_trial_dt = v * dt * stim
    # Euler-Maruyama
    for _ in range(max_iter):
        if 0 < pre_evidence < a:
            pre_evidence += v_trial_dt + c * np.random.randn()
            rt += dt
        else:
            break
    # Response and accuracy
    if pre_evidence >= a:
        resp = 1
        post_evidence = a
        acc = 1 if stim == 1 else 0
    else:
        # Lower boundary; also covers the (rare) case where max_iter is
        # exhausted without absorption, so resp/post_evidence/acc are always defined
        resp = -1
        post_evidence = 0.0
        acc = 1 if stim == -1 else 0
    # --- 2. Post Decisional Phase ---
    # Post-decision accumulation time, truncated at 5 s
    post_rt = np.random.gamma(2, 0.7)
    while post_rt > 5:
        post_rt = np.random.gamma(2, 0.7)
    v_post = v * stim * v_ratio + v_bias * resp
    post_evidence = post_evidence + post_rt * v_post + np.sqrt(post_rt) * np.random.randn() * s
    conf_ev = (post_evidence - a) if resp == 1 else (-post_evidence)
    # --- 3. Calculate Confidence ---
    # Map continuous confidence evidence onto a 1-6 rating via a logistic
    conf_cont = 1 + 5.0 * (1.0 / (1.0 + np.exp(-(conf_ev - a_bias) * m_bias)))
    conf_resp = int(max(1.0, min(np.rint(conf_cont), 6.0)))
    return rt, resp, acc, stim, conf_resp, post_rt
def meta_ddm(v, a, ter, v_bias, a_bias, m_bias, v_ratio, num_obs):
    rt_list, resp_list, acc_list, stim_list, conf_resp_list, post_rt_list = [], [], [], [], [], []
    for _ in range(num_obs):
        rt, resp, acc, stim, conf_resp, post_rt = meta_trial(v=v, a=a, ter=ter,
                                                             v_bias=v_bias, a_bias=a_bias,
                                                             m_bias=m_bias, v_ratio=v_ratio)
        rt_list.append(rt)
        resp_list.append(resp)
        acc_list.append(acc)
        stim_list.append(stim)
        conf_resp_list.append(conf_resp)
        post_rt_list.append(post_rt)
    return dict(rt=np.asarray(rt_list), resp=np.asarray(resp_list), acc=np.asarray(acc_list),
                stim=np.asarray(stim_list), conf_resp=np.asarray(conf_resp_list),
                post_rt=np.asarray(post_rt_list))
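In case it helps anyone poke at the simulator on its own, a quick smoke test (parameter values here are arbitrary, just to check shapes and dtypes):

out = meta_ddm(v=1.5, a=1.4, ter=0.4, v_bias=0.0, a_bias=-1.0, m_bias=1.0, v_ratio=1.0, num_obs=100)
print({k: (val.shape, val.dtype) for k, val in out.items()})  # six arrays of length 100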
Priors and observation number
def priors():
    v = scipy.stats.truncnorm.rvs(-1.74/1.5, 16.75/1.5, loc=1.76, scale=1.51)
    a = np.random.gamma(11.7, 0.12)
    ter = scipy.stats.truncnorm.rvs(-0.34/0.08, 3.25/0.08, loc=0.44, scale=0.08)
    # z = scipy.stats.truncnorm.rvs(-0.46/0.05, 0.46/0.05, loc=0.5, scale=0.05)
    v_bias = np.random.normal(0, 1)
    a_bias = np.random.normal(-1, 1)
    m_bias = np.random.gamma(2, 0.5)
    v_ratio = np.random.gamma(2, 0.5)
    return dict(v=v, a=a, ter=ter, v_bias=v_bias, a_bias=a_bias,
                m_bias=m_bias, v_ratio=v_ratio)

def meta():
    # N: number of observations in a dataset
    num_obs = np.random.randint(10, 29) * 40
    return dict(num_obs=num_obs)
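A single joint draw shows the keys match what the adapter expects below:

print(priors())  # dict with keys v, a, ter, v_bias, a_bias, m_bias, v_ratio
print(meta())    # e.g. {'num_obs': 680}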
Simulator, adapter, and summary network
simulator = bf.make_simulator([priors, meta_ddm], meta_fn=meta)

adapter = (bf.Adapter()
           .broadcast("num_obs", to="v")
           .rename("num_obs", "inference_conditions")
           .constrain(["v", "a", "ter", "m_bias", "v_ratio"], lower=0.0)
           .as_set(["rt", "resp", "acc", "stim", "conf_resp", "post_rt"])
           .concatenate(["v", "a", "ter", "v_bias", "a_bias", "m_bias", "v_ratio"], into="inference_variables")
           .concatenate(["rt", "resp", "acc", "stim", "conf_resp", "post_rt"], into="summary_variables")
           .convert_dtype("float64", "float32"))

summary_network = bf.networks.SetTransformer(summary_dim=32, dropout=None)
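The inference_network referenced in the workflow is defined earlier in my script; a minimal stand-in in the style of the examples would be a default coupling flow (the exact choice shouldn’t matter here, since the error points at the summary network):

inference_network = bf.networks.CouplingFlow()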
Here are the workflow and training call that produce the error above:
workflow = bf.BasicWorkflow(simulator=simulator,
                            adapter=adapter,
                            inference_network=inference_network,
                            summary_network=summary_network,
                            standardize=["inference_variables", "summary_variables"])

history = workflow.fit_online(epochs=60, num_batches_per_epoch=500, batch_size=64)