Poor parameter recovery for hybrid reinforcement learning models in a two-step task

Hello,

I am currently attempting to use BayesFlow (v1) to fit a hybrid reinforcement learning model for a two-step RL task, but the parameter recovery is quite poor.

To briefly describe the model: we assume that the first-stage choice is driven by a weighting (w) of a model-based component, a model-free component, and a perseveration component:

c_{1,t} \sim \text{BernoulliLogit}( \beta w \Delta Q_{MB,t} + \beta (1-w) \Delta Q_{MF,t} + \beta_c \text{past choice}_t)

where the model-based value for each first-stage action (a_1) is calculated as the sum of the maximum second-stage Q-values, weighted by their transition probabilities:

Q_{MB}(a_1) = 0.7 \max_{a_2} Q_2(s_{common}, a_2) + 0.3 \max_{a_2} Q_2(s_{rare}, a_2)

with the common and rare second-stage states depending on the first-stage action.

The Q-value of each {state, action} pair at the second stage is updated via the observed reward:

Q_2(s_t, a_{2,t}) = Q_2(s_t, a_{2,t}) + \alpha (r_t - Q_2(s_t, a_{2,t}))

The model-free process uses cached first-stage values, where:

Q_{MF}(a_{1,t}) = Q_{MF}(a_{1,t}) + \alpha (r_t - Q_{MF}(a_{1,t}))

Second-stage choices are modeled with a separate inverse temperature, using the difference in value between the two second-stage actions in the visited state:

c_{2,t} \sim \text{BernoulliLogit}(\beta_2 \Delta Q_2(s_t, a_{2, t}))

Non-chosen options are decayed/devalued.
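To make the generative process concrete, here is a minimal numpy sketch of a trial-by-trial simulator for the model as described above. The Bernoulli reward probabilities, the `decay` parameter for devaluing non-chosen options, and the ±1 coding of the past choice are illustrative assumptions, since the post does not specify them:

```python
import numpy as np

def simulate_two_step(beta, w, beta_2, beta_c, alpha,
                      n_trials=200, decay=0.05, rng=None):
    """Simulate the hybrid MB/MF two-step model sketched in the post.

    Assumed for illustration: Bernoulli rewards with fixed per-(state, action)
    probabilities, a multiplicative `decay` of unchosen Q-values toward zero,
    and +1/-1 coding of the previous first-stage choice.
    """
    rng = np.random.default_rng() if rng is None else rng
    reward_probs = np.array([[0.8, 0.2], [0.3, 0.6]])  # assumed P(reward | s, a2)
    Q2 = np.zeros((2, 2))   # second-stage values Q_2(s, a2)
    Q_mf = np.zeros(2)      # model-free first-stage cached values
    prev_c1 = 0.0           # perseveration regressor (0 on the first trial)
    data = np.zeros((n_trials, 4))
    for t in range(n_trials):
        # model-based values: transition-weighted max over second-stage actions
        Q_mb = np.array([0.7 * Q2[0].max() + 0.3 * Q2[1].max(),
                         0.3 * Q2[0].max() + 0.7 * Q2[1].max()])
        logit1 = (beta * w * (Q_mb[1] - Q_mb[0])
                  + beta * (1 - w) * (Q_mf[1] - Q_mf[0])
                  + beta_c * prev_c1)
        c1 = int(rng.random() < 1.0 / (1.0 + np.exp(-logit1)))
        # common transition (p = 0.7) leads to the state matching the choice
        s = c1 if rng.random() < 0.7 else 1 - c1
        logit2 = beta_2 * (Q2[s, 1] - Q2[s, 0])
        c2 = int(rng.random() < 1.0 / (1.0 + np.exp(-logit2)))
        r = float(rng.random() < reward_probs[s, c2])
        # delta-rule updates, as in the equations above
        Q2[s, c2] += alpha * (r - Q2[s, c2])
        Q_mf[c1] += alpha * (r - Q_mf[c1])
        # decay non-chosen options toward zero (assumed functional form)
        Q2[1 - s, :] *= (1 - decay)
        Q2[s, 1 - c2] *= (1 - decay)
        Q_mf[1 - c1] *= (1 - decay)
        prev_c1 = 1.0 if c1 == 1 else -1.0
        data[t] = [c1, s, c2, r]
    return data
```

Writing the simulator this way also makes it easy to check that simulated behaviour actually varies with beta and w before worrying about the networks.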

Currently I am using a time-series transformer for the summary network and an invertible network for the inference network:

summary_net = bf.networks.TimeSeriesTransformer(
    input_dim=9,
    summary_dim=64,
    num_dense_fc=4,
    attention_settings=dict(num_heads=4, key_dim=32),
    name="timeseries_summary",
)

spline2 = bf.networks.InvertibleNetwork(
    num_params=len(prior.param_names),
    num_coupling_layers=12,
    coupling_design="spline",
    coupling_settings={
        "dropout": False,
        "bins": 16,
        "dense_args": dict(units=128, kernel_regularizer=None, activation="elu"),
        "num_dense": 2,
    },
    permutation="fixed",
    name="timeseries_inference",
)
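For context, in the v1 API these two networks are typically wired together roughly as follows; `model` is assumed to be the `bf.simulation.GenerativeModel` wrapping the prior and simulator, and the epoch/iteration numbers are placeholders:

```python
amortizer = bf.amortizers.AmortizedPosterior(spline2, summary_net)
trainer = bf.trainers.Trainer(amortizer=amortizer, generative_model=model)

# online training: fresh simulations every batch
history = trainer.train_online(epochs=50, iterations_per_epoch=1000, batch_size=32)
```

This is only a sketch of the standard v1 workflow, not the exact training config used for the results below.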

where I feed in one-hot encodings of the behavioural data (first-stage choice, second-stage transition, second-stage choice, reward) plus positional encodings.
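For concreteness, here is one hypothetical way those features could be packed into the 9-dimensional per-trial input that `input_dim=9` implies: three one-hot pairs (choice, state, choice), one reward column, and a 2-dim sinusoidal positional encoding. The exact layout is my assumption; it should be matched to whatever the configurator actually produces:

```python
import numpy as np

def encode_trials(data):
    """Encode a (T, 4) array of (c1, state, c2, reward) trials into a
    (T, 9) float array: one-hot c1 (2) + one-hot state (2) + one-hot c2 (2)
    + reward (1) + sin/cos positional encoding (2). Layout is illustrative."""
    T = data.shape[0]
    out = np.zeros((T, 9), dtype=np.float32)
    t_idx = np.arange(T)
    out[t_idx, data[:, 0].astype(int)] = 1.0      # first-stage choice
    out[t_idx, 2 + data[:, 1].astype(int)] = 1.0  # second-stage state
    out[t_idx, 4 + data[:, 2].astype(int)] = 1.0  # second-stage choice
    out[:, 6] = data[:, 3]                         # reward
    out[:, 7] = np.sin(2 * np.pi * t_idx / T)      # positional encoding
    out[:, 8] = np.cos(2 * np.pi * t_idx / T)
    return out
```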

Here is the parameter recovery for this model, which is quite poor across most parameters.

At first I wondered whether this was due to insufficient evidence in the behavioural data, so as a test case with maximal information I also tried training the model with the Q-values fed in directly, but recovery is still poor for beta and w.

The priors used were empirical priors derived from Stan fits. Trials were set to 200, but increasing the trial count did not help. I have also tried different parameterisations using separate betas, as well as an RL drift-diffusion approach, neither of which recovered the parameters well.
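For reference, the recovery plots come from the standard v1 diagnostic workflow, roughly as sketched below; `trainer`, `amortizer`, `generative_model`, and `prior` are the assumed names from that workflow, and the batch/sample sizes are placeholders:

```python
# fresh validation simulations, passed through the trainer's configurator
val_sims = trainer.configurator(generative_model(300))

# amortized posterior draws for each simulated dataset
post_samples = amortizer.sample(val_sims, n_samples=1000)

# scatter of posterior means against the true generating parameters
fig = bf.diagnostics.plot_recovery(post_samples, val_sims["parameters"],
                                   param_names=prior.param_names)
```

This is a sketch of the generic v1 diagnostic call, not the exact script that produced the figures above.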

At this point I suspect a parameter identifiability problem that is specific to SBI. I assume Stan can recover these parameters because it evaluates the likelihood on a trial-by-trial basis, whereas the summary network in SBI may not capture the trial-level granularity needed for identification. But I am not sure. Does anyone have ideas as to why this is not working?

Hi pavgreen, before I make any suggestions for improvements, I strongly suggest you switch to BayesFlow 2, which is a much more powerful framework than the legacy v1 version. It also supports TensorFlow on GPU if that's your backend setup. I would also need to see the model and training config.