Error running SetTransformer summary network

Hi all,

I’m a complete newbie at this, so this may be a simple fix. I’m currently running an evidence accumulation model, specifically investigating post-decision evidence accumulation as an index of certainty/confidence. I’m running this on BayesFlow 2.0.6 with the JAX backend and Keras 3.11.2.

When I attempt to train the model with SetTransformer, I get this error:

ValueError: Layers added to a Sequential model can only have a single positional argument, the input tensor. Layer SetAttentionBlock has multiple positional arguments: [<Parameter "input_set: ~Tensor">, <Parameter "**kwargs">]

I tried with a DeepSet summary network instead and it works. I initially ran the same model in BayesFlow v1 (given that there are more resources on DDM models there), and preliminary diagnostics seem to suggest good recovery, so I don’t think it’s the model. Maybe something is wrong with my arguments in translating them to BayesFlow v2, as my main references for the second version have been the regression and t-test examples.
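For reference, a minimal sketch of the DeepSet version that works (assuming the same summary_dim and otherwise default arguments):

summary_network = bf.networks.DeepSet(summary_dim = 32)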

Here’s the simulator

import numpy as np
import numba as nb
import scipy.stats
import bayesflow as bf

@nb.jit(nopython=True, cache=True)
def meta_trial(v, a, ter, v_bias, a_bias, m_bias, v_ratio, dt = 0.001, s = 1, max_iter = 15000):
    """Generates trial response time and choice, and confidence rating"""
    """from a postdecisional model given a set of parameters for one trial"""

    # Constant for diffusion process
    c = np.sqrt(dt) * s
        
    # Initialise settings
    pre_evidence = a * 0.5
    rt = ter

    # Define stimulus being presented 
    stim = np.random.choice(np.array([-1, 1]))

    # --- 1. Initial Decision Phase ---
    # Define initial v to be stimulus dependent
    v_trial_dt = v * dt * stim
    
    # Euler-Maruyama
    for _ in range(max_iter):   
        if 0 < pre_evidence < a:    
            pre_evidence += v_trial_dt + c * np.random.randn()
            rt += dt
        else:
            break
    
    # Response and accuracy
    if pre_evidence >= a:
        resp = 1
        post_evidence = a
        acc = 1 if stim == 1 else 0
    elif pre_evidence <= 0:
        resp = -1
        post_evidence = 0.0
        acc = 1 if stim == -1 else 0
    else:
        # Guard: boundary not reached within max_iter; fall back to the nearer
        # boundary so resp, post_evidence, and acc are defined on every path
        resp = 1 if pre_evidence >= a * 0.5 else -1
        post_evidence = a if resp == 1 else 0.0
        acc = 1 if stim == resp else 0

    # --- 2. Post Decisional Phase ---
    post_rt = np.random.gamma(2, 0.7)
    while post_rt > 5:
        post_rt = np.random.gamma(2, 0.7)
    
    v_post = v * stim * v_ratio + v_bias * resp
    post_evidence = post_evidence + post_rt * v_post + np.sqrt(post_rt) * np.random.randn() * s
    conf_ev = (post_evidence - a) if resp == 1 else (-post_evidence)

    # --- 3. Calculate Confidence ---
    conf_cont = 1 + 5.0 * (1.0 / (1.0 + np.exp(-(conf_ev - a_bias) * m_bias)))
    conf_resp = int(max(1.0, min(np.rint(conf_cont), 6.0)))
    
    return rt, resp, acc, stim, conf_resp, post_rt

def meta_ddm(v, a, ter, v_bias, a_bias, m_bias, v_ratio, num_obs):

    rt_list, resp_list, acc_list, stim_list, conf_resp_list, post_rt_list = [], [], [], [], [], []
    
    for n in range(num_obs):
        rt, resp, acc, stim, conf_resp, post_rt = meta_trial(v = v, a = a, ter = ter, 
                                                             v_bias = v_bias, a_bias = a_bias, 
                                                             m_bias = m_bias, v_ratio = v_ratio)

        rt_list.append(rt)
        resp_list.append(resp)
        acc_list.append(acc)
        stim_list.append(stim)
        conf_resp_list.append(conf_resp)
        post_rt_list.append(post_rt)


    return dict(rt = np.asarray(rt_list), resp = np.asarray(resp_list), acc = np.asarray(acc_list), 
                stim = np.asarray(stim_list), conf_resp = np.asarray(conf_resp_list), post_rt = np.asarray(post_rt_list))

Priors and number of observations

def priors():
    v = scipy.stats.truncnorm.rvs(-1.74/1.5, 16.75/1.5, loc = 1.76, scale = 1.51)
    a = np.random.gamma(11.7, 0.12)
    ter = scipy.stats.truncnorm.rvs(-0.34/0.08, 3.25/0.08, loc = 0.44, scale = 0.08)
    # z = scipy.stats.truncnorm.rvs(-0.46/0.05, 0.46/0.05, loc = 0.5, scale = 0.05)
    v_bias = np.random.normal(0, 1)
    a_bias = np.random.normal(-1, 1)
    m_bias = np.random.gamma(2, 0.5)
    v_ratio = np.random.gamma(2, 0.5)
    return dict(v = v, a = a, ter = ter, v_bias = v_bias, a_bias = a_bias,
                m_bias = m_bias, v_ratio = v_ratio)

def meta():
    # N: number of observations in a dataset
    num_obs = np.random.randint(10, 29) * 40
    return dict(num_obs = num_obs)

Simulator, adapter, and summary network

simulator = bf.make_simulator([priors, meta_ddm], meta_fn = meta)

adapter = (bf.Adapter()
           .broadcast("num_obs", to = "v")
           .rename("num_obs", "inference_conditions")
           .constrain(["v", "a", "ter", "m_bias", "v_ratio"], lower=0.0)
           .as_set(["rt","resp", "acc", "stim", "conf_resp", "post_rt"])
           .concatenate(["v", "a", "ter", "v_bias", "a_bias", "m_bias", "v_ratio"], into = "inference_variables")
           .concatenate(["rt", "resp", "acc", "stim", "conf_resp", "post_rt"], into = "summary_variables")
           .convert_dtype("float64", "float32"))

summary_network = bf.networks.SetTransformer(summary_dim = 32, dropout = None)
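For what it’s worth, pushing a simulated batch through the adapter looks sensible to me. A minimal sketch of the shape check, using the objects defined above:

sims = simulator.sample(8)
processed = adapter(sims)
print({k: v.shape for k, v in processed.items()})
# roughly: inference_variables (8, 7), summary_variables (8, num_obs, 6), inference_conditions (8, 1)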

However, when I try training the model with the workflow below, I get the error above:

workflow = bf.BasicWorkflow(simulator = simulator,
                            adapter = adapter,
                            inference_network = inference_network,
                            summary_network = summary_network,
                            standardize = ["inference_variables", "summary_variables"])

history = workflow.fit_online(epochs = 60, num_batches_per_epoch = 500, batch_size = 64)

Hey,
thanks for reaching out and welcome to the forum. This looks like a weird bug that seems to be triggered by setting dropout=None instead of dropout=0.0 in the SetTransformer constructor. Using

summary_network = bf.networks.SetTransformer(summary_dim = 32, dropout = 0.0)

instead should fix this. We will look into whether we can make this more robust. Let us know if this resolves the issue for you.
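For context, the restriction itself comes from Keras: a Sequential model only accepts layers whose call() takes a single positional input tensor. A minimal sketch of just the Keras behavior (not BayesFlow’s actual internals; TwoArgLayer is a made-up layer):

import keras

class TwoArgLayer(keras.layers.Layer):
    # call() takes two positional arguments, which Sequential rejects
    def call(self, x, y):
        return x + y

model = keras.Sequential([TwoArgLayer()])
model.build((None, 3))  # raises the same "single positional argument" ValueError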


Hi,

Thank you, that seems to work.