Poor parameter recovery for hybrid reinforcement learning models in a two-step task

Hello,

I am currently attempting to use BayesFlow (v1) to fit a hybrid reinforcement learning model to a two-step RL task, but the parameter recovery is quite poor.

To briefly describe the model: we assume that the first-stage choice is driven by a weighted combination (weight w) of a model-based component and a model-free component, plus a perseveration component:

c_{1,t} \sim \text{BernoulliLogit}( \beta w \Delta Q_{MB,t} + \beta (1-w) \Delta Q_{MF,t} + \beta_c \text{past choice}_t)

where the model-based value for each first-stage action (a_1) is the sum of the maximum second-stage Q-values, each weighted by its transition probability:

Q_{MB}(a_1) = 0.7 \cdot \max_{a_2} Q_2(s_{common}, a_2) + 0.3 \cdot \max_{a_2} Q_2(s_{rare}, a_2)

with the common and rare second-stage states depending on the first-stage action.
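As a minimal sketch of the two formulas above (not the poster's actual code), the model-based values and the first-stage choice probability could be computed as follows. The function names are hypothetical. The sign convention (differences taken as action 1 minus action 0, previous choice coded as -1/+1) is an assumption that follows the simulator posted later in the thread.

```python
import numpy as np

def model_based_values(q2, p_common=0.7):
    """Model-based value of each first-stage action.

    q2[s, a2] holds second-stage Q-values. Assumption: first-stage
    action 0 commonly leads to state 0 and action 1 to state 1.
    """
    best = q2.max(axis=1)  # best second-stage value in each state
    return np.array([
        p_common * best[0] + (1 - p_common) * best[1],
        p_common * best[1] + (1 - p_common) * best[0],
    ])

def first_stage_prob(q_mb, q_mf, prev_choice, beta, w, beta_c):
    """P(first-stage choice = 1) under the BernoulliLogit rule."""
    logit = (beta * w * (q_mb[1] - q_mb[0])
             + beta * (1 - w) * (q_mf[1] - q_mf[0])
             + beta_c * prev_choice)
    return 1.0 / (1.0 + np.exp(-logit))

q2 = np.array([[0.2, 0.8], [0.5, 0.4]])
q_mb = model_based_values(q2)
p1 = first_stage_prob(q_mb, np.array([0.5, 0.5]), prev_choice=1.0,
                      beta=3.0, w=0.5, beta_c=0.2)
```

Note that beta and w only enter through the products beta * w and beta * (1 - w), so the two value differences have to vary independently across trials for both parameters to be separable.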

The Q-value of each {state, action} pair at the second stage is updated via the observed reward:

Q_2(s_t, a_{2,t}) = Q_2(s_t, a_{2,t}) + \alpha (\text{reward}_t - Q_2(s_t, a_{2,t}))

The model free process uses the first stage cached values, where:

Q_{MF}(a_{1,t}) = Q_{MF}(a_{1,t}) + \alpha (\text{reward}_t - Q_{MF}(a_{1,t}))

Second-stage choices are modeled with a separate inverse temperature, using the difference in value between the two second-stage actions in the current state:

c_{2,t} \sim \text{BernoulliLogit}(\beta_2 \Delta Q_2(s_t, a_{2, t}))

Non-chosen options are decayed/devalued.
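The learning and decay rules above can be sketched for a single trial as below. This is an illustration with a hypothetical function name, not the poster's implementation. Multiplicative decay toward zero by a factor (1 - decay) is an assumption taken from the simulator posted later in the thread.

```python
import numpy as np

def update_values(q_mf, q2, a1, s, a2, reward, alpha, decay):
    """Apply one trial of learning and decay; returns updated copies."""
    q_mf, q2 = q_mf.copy(), q2.copy()

    # Chosen options move toward the observed reward.
    q_mf[a1] += alpha * (reward - q_mf[a1])
    q2[s, a2] += alpha * (reward - q2[s, a2])

    # Every option that was not chosen on this trial decays toward zero.
    unchosen_mf = np.ones(2, dtype=bool)
    unchosen_mf[a1] = False
    unchosen_q2 = np.ones((2, 2), dtype=bool)
    unchosen_q2[s, a2] = False
    q_mf[unchosen_mf] *= 1 - decay
    q2[unchosen_q2] *= 1 - decay
    return q_mf, q2

q_mf_new, q2_new = update_values(
    q_mf=np.full(2, 0.5), q2=np.full((2, 2), 0.5),
    a1=1, s=1, a2=0, reward=1, alpha=0.4, decay=0.2,
)
```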

Currently I am using a time-series transformer for the summary network and an invertible network for the inference network:

summary_net = bf.networks.TimeSeriesTransformer(input_dim = 9, summary_dim = 64, num_dense_fc = 4, name = "timeseries_summary", attention_settings=dict(num_heads=4, key_dim=32))

spline2 = bf.networks.InvertibleNetwork(
    num_params = len(prior.param_names),
    coupling_settings = {'dropout': False, 'bins' : 16, "dense_args": dict(units=128, kernel_regularizer=None, activation="elu"), "num_dense": 2},
    num_coupling_layers = 12, 
    coupling_design = "spline",
    permutation = "fixed",
    name = "timeseries_inference")

The input is a one-hot encoding of the behavioural data (first-stage choice, second-stage transition, second-stage choice, reward) plus positional encodings.
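The exact encoding is not shown in the post. One layout that would match input_dim = 9 is eight one-hot columns (four binary variables, two columns each) plus a single normalized time index. The sketch below assumes that layout; the function name and the linear positional encoding are hypothetical.

```python
import numpy as np

def encode_trials(c1, st, c2, r):
    """Encode four binary trial sequences as a (T, 9) float32 array."""
    obs = np.stack([c1, st, c2, r], axis=1)      # (T, 4) integers in {0, 1}
    one_hot = np.eye(2, dtype=np.float32)[obs]   # (T, 4, 2)
    one_hot = one_hot.reshape(len(c1), -1)       # (T, 8)
    # Assumed positional encoding: trial index rescaled to [0, 1].
    pos = np.linspace(0.0, 1.0, len(c1), dtype=np.float32)[:, None]
    return np.concatenate([one_hot, pos], axis=1)

x = encode_trials(
    c1=np.array([0, 1, 1]), st=np.array([0, 1, 0]),
    c2=np.array([1, 1, 0]), r=np.array([0, 1, 1]),
)
```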

Here is the parameter recovery for this model, which is quite poor for most parameters.

At first I wondered whether this was due to insufficient information in the behavioural data, so I also tried training with the Q-values fed in as well, as a test case with maximal information. Recovery is still poor for beta and w.

Priors used were empirical priors derived from STAN. Trials were set to 200, but I have tried increasing trial numbers to no avail. I have also tried different parameterisation approaches using separate betas, and also a RL drift diffusion approach, neither of which were able to recover the parameters well.

At this point I think this has to do with a parameter identifiability issue in the model that is specific to SBI. I assume STAN can recover these parameters well because it evaluates the likelihood on a trial-by-trial basis, whereas the summary network in SBI may not capture the trial-level granularity needed for identification. But I am not sure. Does anyone have any ideas as to why this is not working?

Hi pavgreen, before I make any suggestions for improvements, I strongly suggest you switch to BayesFlow 2, which is a much more powerful framework than the legacy v1 version. It also supports TensorFlow on GPU if that’s your backend setup. I would also need to see the model and training config.

Hi, thank you for the reply, and sorry for the late response.

Indeed, after switching to BayesFlow v2 the recovery of parameters has improved significantly!

This is the data-generating model:

import numba as nb
import numpy as np

@nb.jit(nopython=True, fastmath=False, parallel=False)
def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

@nb.jit(nopython=True, fastmath=False, parallel=False)
def timeseries_likelihood(betam, w, betac, betas, alpha, decay, ntrials, reward_list = reward_list):

    betam = sigmoid(betam) * 10
    w = sigmoid(w)
    # betat = sigmoid(betat) * 10 # Uncomment this and comment w if using separate betas
    betas = sigmoid(betas) * 10
    alpha = sigmoid(alpha)
    decay = sigmoid(decay)

    # Outcome that needs storing
    c1 = np.zeros(ntrials, dtype=np.int32)
    c2 = np.zeros(ntrials, dtype=np.int32)
    r  = np.zeros(ntrials, dtype=np.int32)
    st = np.zeros(ntrials, dtype=np.int32)
    ctx = np.zeros(ntrials, dtype=np.int32)
    
    # Values that need updating
    qt1 = np.zeros(2, dtype=np.float64) + 0.5
    qt2 = np.zeros((2, 2), dtype=np.float64) + 0.5
    pc = 0.0
    
    # Initiate time series
    for n in range(ntrials):
    
        qs1 = np.max(qt2[0])
        qs2 = np.max(qt2[1])

        qc1 = 0.7 * qs1 + 0.3 * qs2
        qc2 = 0.7 * qs2 + 0.3 * qs1

        y1 = sigmoid(betam * w * (qc2 - qc1) + betam * (1-w) * (qt1[1] - qt1[0]) + betac * pc)
        # y1 = sigmoid(betam * (qc2 - qc1) + betat * (qt1[1] - qt1[0]) + betac * pc) # Uncomment this and comment previous line if using separate betas
        c1[n] = np.random.binomial(1, y1)
    
        # Update the states
        if np.random.random() < 0.7:
            st[n] = c1[n]
        else:
            st[n] = 1 - c1[n]
    
        # Update the second choice
        y2 = sigmoid(betas * (qt2[st[n], 1] - qt2[st[n], 0]))
        c2[n] = np.random.binomial(1, y2)
    
        pc = 2 * c1[n] - 1
    
        # Generate reward based on stage + choice
        idx = st[n] * 2 + c2[n]
        p = reward_list[n, idx]
        r[n] = np.random.binomial(1, p)

        # Prediction errors (not rescaled); their sum equals r[n] - qt1[c1[n]]
        delta_1 = qt2[st[n], c2[n]] - qt1[c1[n]]
        delta_2 = r[n] - qt2[st[n], c2[n]]

        qt1[c1[n]] = qt1[c1[n]] + alpha * (delta_1 + delta_2)
        qt2[st[n], c2[n]] = qt2[st[n], c2[n]] + alpha * (delta_2)
    
        # Decay all non-chosen options
        nc1 = 1 - c1[n]
        nst = 1 - st[n]
        nc2 = 1 - c2[n]
        
        qt1[nc1]        *= (1 - decay)
        qt2[st[n], nc2] *= (1 - decay)
        qt2[nst, 0]     *= (1 - decay)
        qt2[nst, 1]     *= (1 - decay)
        
    return c1, st, c2, r
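
The simulator above reads reward_list[n, idx] with idx = st[n] * 2 + c2[n], so it expects an array of shape (ntrials, 4) holding one reward probability per second-stage state and action. The post does not show how reward_list is built. The sketch below is one possible construction, assuming slowly drifting Gaussian random walks with reflecting bounds; the function name, step size, and bounds are all hypothetical choices.

```python
import numpy as np

def make_reward_list(ntrials, sd=0.025, lo=0.25, hi=0.75, seed=None):
    """Drifting reward probabilities, shape (ntrials, 4)."""
    rng = np.random.default_rng(seed)
    probs = np.empty((ntrials, 4), dtype=np.float64)
    probs[0] = rng.uniform(lo, hi, size=4)
    for t in range(1, ntrials):
        step = probs[t - 1] + rng.normal(0.0, sd, size=4)
        # Reflect values that cross a bound back into [lo, hi].
        step = np.where(step > hi, 2 * hi - step, step)
        step = np.where(step < lo, 2 * lo - step, step)
        # Clip as a safeguard in case a very large step overshoots twice.
        probs[t] = np.clip(step, lo, hi)
    return probs

reward_list = make_reward_list(200, seed=1)
```

Because the simulator takes reward_list as a default argument, the array would need to exist before the function is defined, or be passed explicitly on each call.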
    

Here are the details of the adapter, summary and inference network

obs_keys = ["c1", "r", "c2", "st"]
par_keys = ["betam", "w", "betac", "betas", "alpha", "decay"]
# par_keys = ["betam", "betat", "betac", "betas", "alpha", "decay"]

adapter = (
    bf.Adapter()
    .broadcast("ntrials", to="c1")
    .sqrt("ntrials")
    .as_time_series(["c1", "r", "c2", "st"])
    .convert_dtype("int32", "float32", include = obs_keys)
    .convert_dtype("float64", "float32", include = par_keys + ["ntrials"])
    .concatenate(["betam", "w", "betac", "betas", "alpha", "decay"], into = "inference_variables")
    # .concatenate(["betam", "betat", "betac", "betas", "alpha", "decay"], into = "inference_variables")
    .concatenate(["c1", "r", "c2", "st"], into = "summary_variables")
    .rename("ntrials", "inference_conditions"))

summary_network = bf.networks.TimeSeriesTransformer(summary_dim = 64)
inference_network = bf.networks.CouplingFlow(transform = "spline") # tried affine with similar results

The priors used for training were empirical priors obtained from STAN modelling of an existing dataset. Training uses 200 trials. I still need to run STAN recovery on exactly the same datasets used in the SBI recovery plot shown, but the results look close to what you would get with STAN.


I have a couple of general questions; sorry if they seem stupid. First, regarding the difference in performance between v1 and v2: I had assumed that the time-series transformer and the coupling-flow/invertible networks were fairly similar across the two versions? Second, in my experience uniform priors work fairly poorly with SBI. Why is that?

Thank you for taking the time to read and reply

That makes sense, the new library has vastly improved internals. If you are on JAX backend, you should also see considerable speed improvements. As of version 2.0.12, there were some very notable improvements specific to the transformers as well.

You can always do a quick heuristic check with a small vs. large summary network to see if increasing the model’s size improves performance (see "5. Summary Networks" in the BayesFlow documentation).

You can also quickly summarize performance holistically using workflow.compute_default_diagnostics or workflow.plot_default_diagnostics.

As of v2, we also recommend FlowMatching as the most performant inference architecture instead of CouplingFlow.

I have not observed a universal linear link between performance and prior width. For example, if the prior is very informative, recovery would often not be strong (if that’s the main performance outcome; the triple of recovery x calibration x sharpness is usually the most sensible for these models), because the data provides little information. With a diffuse prior, the data generally exerts a large influence so recovery should be “better”. But that of course depends on the interaction between simulator and prior (i.e., “the prior should only be viewed in the context of the likelihood”).
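
To illustrate the point about prior width with a toy case (not the two-step model), consider a conjugate normal model where the exact posterior mean is available in closed form. The sketch below simulates the correlation between true parameters and posterior means under a narrow and a wide prior. All names and settings are hypothetical.

```python
import numpy as np

def recovery_correlation(prior_sd, noise_sd=1.0, n_obs=20,
                         n_sims=20000, seed=0):
    """Correlation between true theta and its exact posterior mean.

    Model: theta ~ N(0, prior_sd^2), y_i ~ N(theta, noise_sd^2).
    """
    rng = np.random.default_rng(seed)
    theta = rng.normal(0.0, prior_sd, size=n_sims)
    # The sample mean of n_obs observations is a sufficient statistic.
    ybar = rng.normal(theta, noise_sd / np.sqrt(n_obs))
    shrink = prior_sd**2 / (prior_sd**2 + noise_sd**2 / n_obs)
    post_mean = shrink * ybar
    return np.corrcoef(theta, post_mean)[0, 1]

narrow = recovery_correlation(prior_sd=0.1)
wide = recovery_correlation(prior_sd=2.0)
```

In this toy model the correlation equals prior_sd / sqrt(prior_sd^2 + noise_sd^2 / n_obs) in expectation, so it is lower under the narrow prior even though the posterior is exact in both cases. That is one reason recovery plots are best read alongside calibration and sharpness.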