Chapter 9 Bayesian models of cognition

9.1 Load the packages

pacman::p_load(
  tidyverse,
  brms,
  cmdstanr,
  patchwork
)

9.2 Design the model

In this chapter we build Bayesian models of cognition: models of how an agent combines multiple sources of information, for instance direct evidence and social information, into a single belief. Working on the log-odds scale makes Bayesian integration simple: the posterior belief is the inverse logit of the sum of a bias term (the agent's baseline tendency to believe one way or the other) and the logit-transformed sources, that is, logit(posterior) = bias + logit(Source1) + logit(Source2). The belief can then be mapped onto different experimental outcomes: a binary choice between two options, a continuous rating, or a discrete rating on a bounded scale (e.g., re-rating the trustworthiness of a face after seeing the group's rating); the simulation below generates all three. Finally, note the distinction between a simple Bayes model, where every source enters with a weight of 1 so that evidence accumulates (two agreeing sources produce a more extreme posterior than either alone), and a weighted Bayes model, where each source's contribution is scaled by a weight; if the weights sum to 1, the model integrates, i.e. averages, the evidence instead of accumulating it. We start with the simple model.

# Simple Bayes: on the log-odds scale, Bayesian integration of independent
# sources is just addition. The posterior belief is the inverse logit of the
# bias plus the logit-transformed sources.
SimpleBayes_f <- function(bias, Source1, Source2){
  outcome <- inv_logit_scaled(bias + logit_scaled(Source1) + logit_scaled(Source2))
  return(outcome)
}

# The same model generalized to any number of sources.
SimpleBayes_MultiSource_f <- function(bias, sources) {
  outcome <- inv_logit_scaled(bias + sum(logit_scaled(sources)))
  return(outcome)
}
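As a quick sanity check (the input values are arbitrary), two agreeing sources accumulate: the posterior is more extreme than either source alone, and a third agreeing source pushes it further still.

SimpleBayes_f(bias = 0, Source1 = 0.7, Source2 = 0.7)             # roughly 0.84
SimpleBayes_MultiSource_f(bias = 0, sources = c(0.7, 0.7, 0.7))   # roughly 0.93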

9.3 Create the data

bias <- 0
trials <- seq(10)
Source1 <- seq(0.1,0.9, 0.1)
Source2 <- seq(0.1,0.9, 0.1)

db <- expand.grid(bias = bias, trials = trials, Source1 = Source1, Source2 = Source2)

for (n in seq(nrow(db))) {
  db$belief[n] <- SimpleBayes_f(db$bias[n], db$Source1[n], db$Source2[n])
  db$choice[n] <- rbinom(1, 1, db$belief[n])    # binary choice
  db$continuous[n] <- db$belief[n] * 9          # continuous rating on a 0-9 scale
  db$discrete[n] <- round(db$belief[n] * 9, 0)  # discrete rating on a 0-9 scale
}

9.4 Visualize

Before fitting anything, we inspect the simulated data. For each outcome type we plot its overall distribution and its relation to the sources: the outcome as a function of Source1, with one line (or smooth) per level of Source2. Agreeing sources push the belief towards the extremes, and the binary, continuous, and discrete outcomes all inherit this pattern, with added noise for the binary choice.

ggplot(db, aes(belief)) +
  geom_histogram(bins = 10, alpha = 0.3, color = "black") +
  theme_bw()

ggplot(db, aes(Source1, belief, color = Source2, group = Source2)) +
  geom_line() +
  theme_bw()

ggplot(db, aes(choice)) +
  geom_histogram(bins = 10, alpha = 0.3, color = "black") +
  theme_bw()

ggplot(db, aes(Source1, choice, color = Source2, group = Source2)) +
  geom_smooth(se = FALSE) +
  theme_bw()
## `geom_smooth()` using method = 'loess' and formula = 'y ~ x'

ggplot(db, aes(continuous)) +
  geom_histogram(bins = 10, alpha = 0.3, color = "black") +
  theme_bw()

ggplot(db, aes(Source1, continuous, color = Source2, group = Source2)) +
  geom_smooth() +
  theme_bw()
## `geom_smooth()` using method = 'loess' and formula = 'y ~ x'

ggplot(db, aes(discrete)) +
  geom_histogram(bins = 10, alpha = 0.3, color = "black") +
  theme_bw()

ggplot(db, aes(Source1, discrete, color = Source2, group = Source2)) +
  geom_smooth() +
  theme_bw()
## `geom_smooth()` using method = 'loess' and formula = 'y ~ x'

9.5 Data for Stan

data_simpleBayes <- list(
  N = nrow(db),
  y = db$choice,
  Source1 = db$Source1,
  Source2 = db$Source2
)

9.6 Create the Stan Model

stan_simpleBayes_model <- "
data {
  int<lower=0> N;
  array[N] int<lower=0, upper=1> y;             // binary choices
  array[N] real<lower=0, upper=1> Source1;      // sources on the probability scale
  array[N] real<lower=0, upper=1> Source2;
}

transformed data{
  // transform the sources to log-odds once, outside the sampler
  array[N] real l_Source1;
  array[N] real l_Source2;
  l_Source1 = logit(Source1);
  l_Source2 = logit(Source2);
}

parameters {
  real bias;  // the only free parameter: a baseline tendency on the log-odds scale
}

model {
  target += normal_lpdf(bias | 0, 1);
  // vectorized Bernoulli: logit(theta) = bias + logit(Source1) + logit(Source2)
  target += bernoulli_logit_lpmf(y | bias + to_vector(l_Source1) + to_vector(l_Source2));
}

generated quantities{
  real bias_prior;        // prior predictive draw for prior-posterior plots
  array[N] real log_lik;  // pointwise log-likelihood for LOO
  
  bias_prior = normal_rng(0, 1);
  
  for (n in 1:N){  
    log_lik[n] = bernoulli_logit_lpmf(y[n] | bias + l_Source1[n] + l_Source2[n]);
  }
  
}

"

write_stan_file(
  stan_simpleBayes_model,
  dir = "stan/",
  basename = "W9_SimpleBayes.stan")
## [1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/W9_SimpleBayes.stan"
file <- file.path("stan/W9_SimpleBayes.stan")
mod_simpleBayes <- cmdstan_model(file, cpp_options = list(stan_threads = TRUE),
                     stanc_options = list("O1"))

9.7 Fitting the model

samples_simple <- mod_simpleBayes$sample(
  data = data_simpleBayes,
  seed = 123,
  chains = 2,
  parallel_chains = 2,
  threads_per_chain = 2,
  iter_warmup = 1500,
  iter_sampling = 3000,
  refresh = 500
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
## 
## Chain 1 Iteration:    1 / 4500 [  0%]  (Warmup) 
## Chain 1 Iteration:  500 / 4500 [ 11%]  (Warmup) 
## Chain 1 Iteration: 1000 / 4500 [ 22%]  (Warmup) 
## Chain 1 Iteration: 1500 / 4500 [ 33%]  (Warmup) 
## Chain 1 Iteration: 1501 / 4500 [ 33%]  (Sampling) 
## Chain 2 Iteration:    1 / 4500 [  0%]  (Warmup) 
## Chain 2 Iteration:  500 / 4500 [ 11%]  (Warmup) 
## Chain 2 Iteration: 1000 / 4500 [ 22%]  (Warmup) 
## Chain 2 Iteration: 1500 / 4500 [ 33%]  (Warmup) 
## Chain 2 Iteration: 1501 / 4500 [ 33%]  (Sampling) 
## Chain 1 Iteration: 2000 / 4500 [ 44%]  (Sampling) 
## Chain 2 Iteration: 2000 / 4500 [ 44%]  (Sampling) 
## Chain 1 Iteration: 2500 / 4500 [ 55%]  (Sampling) 
## Chain 2 Iteration: 2500 / 4500 [ 55%]  (Sampling) 
## Chain 1 Iteration: 3000 / 4500 [ 66%]  (Sampling) 
## Chain 2 Iteration: 3000 / 4500 [ 66%]  (Sampling) 
## Chain 1 Iteration: 3500 / 4500 [ 77%]  (Sampling) 
## Chain 2 Iteration: 3500 / 4500 [ 77%]  (Sampling) 
## Chain 1 Iteration: 4000 / 4500 [ 88%]  (Sampling) 
## Chain 2 Iteration: 4000 / 4500 [ 88%]  (Sampling) 
## Chain 1 Iteration: 4500 / 4500 [100%]  (Sampling) 
## Chain 1 finished in 1.0 seconds.
## Chain 2 Iteration: 4500 / 4500 [100%]  (Sampling) 
## Chain 2 finished in 1.0 seconds.
## 
## Both chains finished successfully.
## Mean chain execution time: 1.0 seconds.
## Total execution time: 1.3 seconds.

9.8 Basic evaluation

samples_simple$cmdstan_diagnose()
## Processing csv files: /var/folders/lt/zspkqnxd5yg92kybm5f433_cfjr0d6/T/RtmpP3dlfk/W9_SimpleBayes-202402150650-1-1b8ab4.csv, /var/folders/lt/zspkqnxd5yg92kybm5f433_cfjr0d6/T/RtmpP3dlfk/W9_SimpleBayes-202402150650-2-1b8ab4.csv
## 
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
## 
## Checking sampler transitions for divergences.
## No divergent transitions found.
## 
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
## 
## Effective sample size satisfactory.
## 
## Split R-hat values satisfactory all parameters.
## 
## Processing complete, no problems detected.
samples_simple$summary()
## # A tibble: 813 × 10
##    variable      mean   median      sd     mad       q5      q95  rhat ess_bulk ess_tail
##    <chr>        <dbl>    <dbl>   <dbl>   <dbl>    <dbl>    <dbl> <dbl>    <dbl>    <dbl>
##  1 lp__      -3.99e+2 -3.99e+2 6.97e-1 3.05e-1 -4.01e+2 -3.99e+2  1.00    3053.    3376.
##  2 bias      -1.56e-1 -1.55e-1 8.87e-2 8.84e-2 -3.05e-1 -9.42e-3  1.00    2150.    3351.
##  3 bias_pri… -2.66e-2 -2.85e-2 1.01e+0 9.92e-1 -1.71e+0  1.65e+0  1.00    5342.    5967.
##  4 log_lik[… -1.05e-2 -1.05e-2 9.30e-4 9.27e-4 -1.22e-2 -9.06e-3  1.00    2150.    3351.
##  5 log_lik[… -1.05e-2 -1.05e-2 9.30e-4 9.27e-4 -1.22e-2 -9.06e-3  1.00    2150.    3351.
##  6 log_lik[… -1.05e-2 -1.05e-2 9.30e-4 9.27e-4 -1.22e-2 -9.06e-3  1.00    2150.    3351.
##  7 log_lik[… -1.05e-2 -1.05e-2 9.30e-4 9.27e-4 -1.22e-2 -9.06e-3  1.00    2150.    3351.
##  8 log_lik[… -1.05e-2 -1.05e-2 9.30e-4 9.27e-4 -1.22e-2 -9.06e-3  1.00    2150.    3351.
##  9 log_lik[… -1.05e-2 -1.05e-2 9.30e-4 9.27e-4 -1.22e-2 -9.06e-3  1.00    2150.    3351.
## 10 log_lik[… -1.05e-2 -1.05e-2 9.30e-4 9.27e-4 -1.22e-2 -9.06e-3  1.00    2150.    3351.
## # ℹ 803 more rows
samples_simple$loo()
## 
## Computed from 6000 by 810 log-likelihood matrix
## 
##          Estimate   SE
## elpd_loo   -398.8 15.8
## p_loo         1.0  0.0
## looic       797.7 31.7
## ------
## Monte Carlo SE of elpd_loo is 0.0.
## 
## All Pareto k estimates are good (k < 0.5).
## See help('pareto-k-diagnostic') for details.
draws_df <- as_draws_df(samples_simple$draws())

ggplot(draws_df, aes(.iteration, bias, group = .chain, color = .chain)) +
  geom_line(alpha = 0.5) +
  theme_classic()

ggplot(draws_df) +
  geom_density(aes(bias), alpha = 0.6, fill = "lightblue") +
  geom_density(aes(bias_prior), alpha = 0.6, fill = "pink") +
  geom_vline(xintercept = db$bias[1]) +
  theme_bw()

The posterior for bias is clearly narrower than the prior and covers the true value of 0, which is a first, minimal recovery check. A proper parameter recovery varies the true bias, refits the model on each simulated dataset, and compares the estimates to the generating values.
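A minimal sketch of such a loop (the range of true bias values and the sampler settings are arbitrary choices):

recovery_df <- NULL
for (trueBias in seq(-1, 1, 0.5)) {
  # simulate a fresh dataset at this bias
  db_sim <- expand.grid(trials = seq(10),
                        Source1 = seq(0.1, 0.9, 0.1),
                        Source2 = seq(0.1, 0.9, 0.1))
  db_sim$belief <- SimpleBayes_f(trueBias, db_sim$Source1, db_sim$Source2)
  db_sim$choice <- rbinom(nrow(db_sim), 1, db_sim$belief)
  
  fit <- mod_simpleBayes$sample(
    data = list(N = nrow(db_sim), y = db_sim$choice,
                Source1 = db_sim$Source1, Source2 = db_sim$Source2),
    chains = 2, parallel_chains = 2, threads_per_chain = 2,
    iter_warmup = 1500, iter_sampling = 3000, refresh = 0
  )
  
  recovery_df <- rbind(recovery_df,
                       tibble(trueBias = trueBias,
                              estBias = as_draws_df(fit$draws())$bias))
}

ggplot(recovery_df, aes(trueBias, estBias, group = trueBias)) +
  geom_violin(fill = "lightblue", alpha = 0.5) +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  theme_bw()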

9.9 Weighted Bayes

The simple Bayes model takes every source at face value. The weighted Bayes model relaxes this assumption: each source's log-odds is multiplied by a weight before the sources are combined, so the agent can discount one source relative to the other. Keep an eye on the scale of the weights: on the log-odds scale, meaningful weights run from 0 (the source is ignored) to 1 (the source is taken at face value); expressed on the probability scale, the same weights run from 0.5 (the source is treated as uninformative) to 1. The first implementation below takes weights on the 0.5-1 probability scale and linearly rescales them to 0-1 before multiplying the log-odds; the second applies the weighting directly on the log-odds scale via the weight_f function.

WeightedBayes_f <- function(bias, Source1, Source2, w1, w2){
  # rescale weights from the 0.5-1 probability scale to the 0-1 log-odds scale
  w1 <- (w1 - 0.5) * 2
  w2 <- (w2 - 0.5) * 2
  outcome <- inv_logit_scaled(bias + w1 * logit_scaled(Source1) + w2 * logit_scaled(Source2))
  return(outcome)
}

## This version of the model is from Taking others into account (in the syllabus)
## It takes two sources of information and weights them, then adds a bias
## to generate a posterior on a 0-1 scale
WeightedBayes_f1 <- function(bias, Source1, Source2, w1, w2){
  outcome <- inv_logit_scaled(bias + weight_f(logit_scaled(Source1), w1) +
                                weight_f(logit_scaled(Source2), w2))
  return(outcome)
}

## The weight_f formula comes from https://www.nature.com/articles/ncomms14218
## and ensures that we get the right weighting even while working on the
## log-odds scale. It accepts any value of L (-Inf to +Inf). Theoretically the
## only valid values for w are 0.5 (ignore the evidence) to 1 (take the
## evidence at face value). In practice the function also accepts 0-0.5
## (invert the evidence: fully at 0, decreasingly as w grows towards 0.5) and
## values slightly above 1 (overweight the evidence, but this is unstable and
## quickly gives NaN).
weight_f <- function(L, w){
  return(log((w * exp(L) + 1 - w) /
               ((1 - w) * exp(L) + w)))
}
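To see how the two formulations differ, we can compare them on a single source (an illustrative check: Source2 is fixed at 0.5 so it contributes nothing, and the weight values are arbitrary):

comparison <- tibble(
  Source1 = seq(0.01, 0.99, 0.01),
  rescaled = WeightedBayes_f(0, Source1, 0.5, w1 = 0.75, w2 = 0.5),
  weight_f_based = WeightedBayes_f1(0, Source1, 0.5, w1 = 0.75, w2 = 0.5)
)

ggplot(comparison) +
  geom_line(aes(Source1, rescaled), color = "red") +
  geom_line(aes(Source1, weight_f_based), color = "blue") +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  theme_bw()

Both curves pull the belief towards 0.5 relative to the identity line, but the weight_f version saturates: its weighted log-odds is bounded by plus or minus logit(w), so extreme evidence is capped rather than merely shrunk.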


bias <- 0
trials <- seq(10)
Source1 <- seq(0.1,0.9, 0.1)
Source2 <- seq(0.1,0.9, 0.1)
w1 <- seq(0.5, 1, 0.1)
w2 <- seq(0.5, 1, 0.1)

db <- expand.grid(bias = bias, trials = trials, Source1 = Source1, Source2 = Source2, w1 = w1, w2 = w2)

for (n in seq(nrow(db))) {
  db$belief[n] <- WeightedBayes_f(db$bias[n], db$Source1[n], db$Source2[n],db$w1[n], db$w2[n])
  db$belief1[n] <- WeightedBayes_f1(db$bias[n], db$Source1[n], db$Source2[n],db$w1[n], db$w2[n])
  db$binary[n] <- rbinom(1,1, db$belief[n])
  db$binary1[n] <- rbinom(1,1, db$belief1[n])
  db$continuous[n] <- db$belief[n] * 9
  db$continuous1[n] <- db$belief1[n] * 9
  db$discrete[n] <- round(db$belief[n] * 9,0)
  db$discrete1[n] <- round(db$belief1[n] * 9,0)
}

9.10 Visualize

ggplot(db, aes(belief, belief1)) +
  geom_point() +
  theme_bw()

ggplot(db) +
  geom_histogram(aes(belief), bins = 10, alpha = 0.3, color = "black", fill = "red") +
  geom_histogram(aes(belief1), bins = 10, alpha = 0.3, color = "black", fill = "blue") +
  theme_bw()

p1 <- ggplot(db, aes(Source1, belief, color = Source2, group = Source2)) +
  geom_line() +
  theme_bw() +
  facet_wrap(w1~w2)

p2 <- ggplot(db, aes(Source1, belief1, color = Source2, group = Source2)) +
  geom_line() +
  theme_bw() +
  facet_wrap(w1~w2)

p1 + p2

9.11 Build the Weighted Bayes Stan model (simple formula)

stan_WB_model <- "
data {
  int<lower=0> N;
  array[N] int y;
  array[N] real <lower = 0, upper = 1> Source1; 
  array[N] real <lower = 0, upper = 1> Source2; 
}

transformed data {
  array[N] real l_Source1;
  array[N] real l_Source2;
  l_Source1 = logit(Source1);
  l_Source2 = logit(Source2);
}
parameters {
  real bias;
  // meaningful weights are btw 0.5 and 1 (theory reasons)
  real<lower = 0.5, upper = 1> w1; 
  real<lower = 0.5, upper = 1> w2;
}
transformed parameters {
  real<lower = 0, upper = 1> weight1;
  real<lower = 0, upper = 1> weight2;
  // weight parameters are rescaled to be on a 0-1 scale (0 -> no effects; 1 -> face value)
  weight1 = (w1 - 0.5) * 2;  
  weight2 = (w2 - 0.5) * 2;
}
model {
  target += normal_lpdf(bias | 0, 1);
  target += beta_lpdf(weight1 | 1, 1);
  target += beta_lpdf(weight2 | 1, 1);
  for (n in 1:N)
    target += bernoulli_logit_lpmf(y[n] | bias + weight1 *l_Source1[n] + weight2 * l_Source2[n]);
}
generated quantities{
  array[N] real log_lik;
  real bias_prior;
  real w1_prior;
  real w2_prior;
  bias_prior = normal_rng(0, 1);
  // weight ~ beta(1, 1) is uniform on 0-1, i.e. w is uniform on 0.5-1
  w1_prior = uniform_rng(0.5, 1);
  w2_prior = uniform_rng(0.5, 1);
  for (n in 1:N)
    log_lik[n]= bernoulli_logit_lpmf(y[n] | bias + weight1 * l_Source1[n] + weight2 * l_Source2[n]);
}

"

write_stan_file(
  stan_WB_model,
  dir = "stan/",
  basename = "W9_WB.stan")
## [1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/W9_WB.stan"
file <- file.path("stan/W9_WB.stan")
mod_wb <- cmdstan_model(file, cpp_options = list(stan_threads = TRUE),
                     stanc_options = list("O1"))

db1 <-  db %>% subset(w1 == 0.7 & w2 == 0.9) 


p3 <- ggplot(db1, aes(Source1, belief, color = Source2, group = Source2)) +
  geom_line() +
  theme_bw()
p3 

ggplot(db1, aes(Source1, binary)) +
  geom_smooth() +
  theme_bw()
## `geom_smooth()` using method = 'loess' and formula = 'y ~ x'

data_weightedBayes <- list(
  N = nrow(db1),
  y = db1$binary,
  Source1 = db1$Source1,
  Source2 = db1$Source2
)

samples_weighted <- mod_wb$sample(
  data = data_weightedBayes,
  seed = 123,
  chains = 2,
  parallel_chains = 2,
  threads_per_chain = 2,
  iter_warmup = 1500,
  iter_sampling = 3000,
  refresh = 500
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
## 
## Chain 1 Iteration:    1 / 4500 [  0%]  (Warmup) 
## Chain 2 Iteration:    1 / 4500 [  0%]  (Warmup) 
## Chain 1 Iteration:  500 / 4500 [ 11%]  (Warmup) 
## Chain 2 Iteration:  500 / 4500 [ 11%]  (Warmup) 
## Chain 1 Iteration: 1000 / 4500 [ 22%]  (Warmup) 
## Chain 2 Iteration: 1000 / 4500 [ 22%]  (Warmup) 
## Chain 1 Iteration: 1500 / 4500 [ 33%]  (Warmup) 
## Chain 1 Iteration: 1501 / 4500 [ 33%]  (Sampling) 
## Chain 2 Iteration: 1500 / 4500 [ 33%]  (Warmup) 
## Chain 2 Iteration: 1501 / 4500 [ 33%]  (Sampling) 
## Chain 1 Iteration: 2000 / 4500 [ 44%]  (Sampling) 
## Chain 2 Iteration: 2000 / 4500 [ 44%]  (Sampling) 
## Chain 1 Iteration: 2500 / 4500 [ 55%]  (Sampling) 
## Chain 2 Iteration: 2500 / 4500 [ 55%]  (Sampling) 
## Chain 1 Iteration: 3000 / 4500 [ 66%]  (Sampling) 
## Chain 2 Iteration: 3000 / 4500 [ 66%]  (Sampling) 
## Chain 1 Iteration: 3500 / 4500 [ 77%]  (Sampling) 
## Chain 2 Iteration: 3500 / 4500 [ 77%]  (Sampling) 
## Chain 1 Iteration: 4000 / 4500 [ 88%]  (Sampling) 
## Chain 2 Iteration: 4000 / 4500 [ 88%]  (Sampling) 
## Chain 1 Iteration: 4500 / 4500 [100%]  (Sampling) 
## Chain 1 finished in 3.6 seconds.
## Chain 2 Iteration: 4500 / 4500 [100%]  (Sampling) 
## Chain 2 finished in 3.7 seconds.
## 
## Both chains finished successfully.
## Mean chain execution time: 3.7 seconds.
## Total execution time: 4.1 seconds.

9.12 Model evaluation

samples_weighted$cmdstan_diagnose()
## Processing csv files: /var/folders/lt/zspkqnxd5yg92kybm5f433_cfjr0d6/T/RtmpP3dlfk/W9_WB-202402150650-1-470abd.csv, /var/folders/lt/zspkqnxd5yg92kybm5f433_cfjr0d6/T/RtmpP3dlfk/W9_WB-202402150650-2-470abd.csv
## 
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
## 
## Checking sampler transitions for divergences.
## No divergent transitions found.
## 
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
## 
## Effective sample size satisfactory.
## 
## Split R-hat values satisfactory all parameters.
## 
## Processing complete, no problems detected.
samples_weighted$summary()
## # A tibble: 819 × 10
##    variable        mean   median     sd    mad       q5      q95  rhat ess_bulk ess_tail
##    <chr>          <dbl>    <dbl>  <dbl>  <dbl>    <dbl>    <dbl> <dbl>    <dbl>    <dbl>
##  1 lp__        -4.71e+2 -4.71e+2 1.36   1.10   -474.    -4.70e+2  1.00    2628.    3124.
##  2 bias        -6.41e-3 -6.56e-3 0.0794 0.0785   -0.135  1.25e-1  1.00    5023.    3770.
##  3 w1           7.03e-1  7.03e-1 0.0331 0.0331    0.649  7.57e-1  1.00    4493.    3661.
##  4 w2           8.95e-1  8.94e-1 0.0353 0.0364    0.838  9.54e-1  1.00    4204.    2727.
##  5 weight1      4.06e-1  4.05e-1 0.0662 0.0662    0.298  5.14e-1  1.00    4493.    3661.
##  6 weight2      7.89e-1  7.87e-1 0.0705 0.0727    0.675  9.07e-1  1.00    4204.    2727.
##  7 log_lik[1]  -7.14e-2 -6.99e-2 0.0169 0.0168   -0.101 -4.66e-2  1.00    3913.    3436.
##  8 log_lik[2]  -7.14e-2 -6.99e-2 0.0169 0.0168   -0.101 -4.66e-2  1.00    3913.    3436.
##  9 log_lik[3]  -7.14e-2 -6.99e-2 0.0169 0.0168   -0.101 -4.66e-2  1.00    3913.    3436.
## 10 log_lik[4]  -2.70e+0 -2.70e+0 0.229  0.233    -3.09  -2.34e+0  1.00    3913.    3436.
## # ℹ 809 more rows
samples_weighted$loo()
## 
## Computed from 6000 by 810 log-likelihood matrix
## 
##          Estimate   SE
## elpd_loo   -467.0 12.8
## p_loo         3.1  0.2
## looic       933.9 25.5
## ------
## Monte Carlo SE of elpd_loo is 0.0.
## 
## All Pareto k estimates are good (k < 0.5).
## See help('pareto-k-diagnostic') for details.
draws_df <- as_draws_df(samples_weighted$draws())

ggplot(draws_df, aes(.iteration, bias, group = .chain, color = .chain)) +
  geom_line(alpha = 0.5) +
  theme_classic()

ggplot(draws_df, aes(.iteration, w1, group = .chain, color = .chain)) +
  geom_line(alpha = 0.5) +
  theme_classic()

ggplot(draws_df, aes(.iteration, w2, group = .chain, color = .chain)) +
  geom_line(alpha = 0.5) +
  theme_classic()

p1 <- ggplot(draws_df) +
  geom_density(aes(bias), alpha = 0.6, fill = "lightblue") +
  geom_density(aes(bias_prior), alpha = 0.6, fill = "pink") +
  geom_vline(xintercept = db1$bias[1]) +
  theme_bw()

p2 <- ggplot(draws_df) +
  geom_density(aes(w1), alpha = 0.6, fill = "lightblue") +
  geom_density(aes(w1_prior), alpha = 0.6, fill = "pink") +
  geom_vline(xintercept = db1$w1[1]) +
  theme_bw()

p3 <- ggplot(draws_df) +
  geom_density(aes(w2), alpha = 0.6, fill = "lightblue") +
  geom_density(aes(w2_prior), alpha = 0.6, fill = "pink") +
  geom_vline(xintercept = db1$w2[1]) +
  theme_bw()

p1 + p2 + p3

ggplot(draws_df) +
  geom_point(aes(w1, w2), alpha = 0.3) +
  theme_bw()

ggplot(draws_df) +
  geom_point(aes(bias, w1), alpha = 0.3) +
  theme_bw()

ggplot(draws_df) +
  geom_point(aes(bias, w2), alpha = 0.3) +
  theme_bw()

9.13 Model recovery

bias <- 0
trials <- seq(10)
Source1 <- seq(0.1,0.9, 0.1)
Source2 <- seq(0.1,0.9, 0.1)
w1 <- 0.7
w2 <- 0.9

db <- expand.grid(bias = bias, trials = trials, Source1 = Source1, Source2 = Source2, w1 = w1, w2 = w2)

for (n in seq(nrow(db))) {
  db$simple_belief[n] <- SimpleBayes_f(db$bias[n], db$Source1[n], db$Source2[n])
  db$weighted_belief[n] <- WeightedBayes_f(db$bias[n], db$Source1[n], db$Source2[n],db$w1[n], db$w2[n])
  db$simple_binary[n] <- rbinom(1,1, db$simple_belief[n])
  db$weighted_binary[n] <- rbinom(1,1, db$weighted_belief[n])
}

data_SB <- list(
  N = nrow(db),
  y = db$simple_binary,
  Source1 = db$Source1,
  Source2 = db$Source2
)

data_WB <- list(
  N = nrow(db),
  y = db$weighted_binary,
  Source1 = db$Source1,
  Source2 = db$Source2
)

## On the simple data
simple2simple <- mod_simpleBayes$sample(
  data = data_SB,
  seed = 123,
  chains = 2,
  parallel_chains = 2,
  threads_per_chain = 2,
  iter_warmup = 1500,
  iter_sampling = 3000,
  refresh = 500
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
## 
## Chain 1 Iteration:    1 / 4500 [  0%]  (Warmup) 
## Chain 1 Iteration:  500 / 4500 [ 11%]  (Warmup) 
## Chain 1 Iteration: 1000 / 4500 [ 22%]  (Warmup) 
## Chain 1 Iteration: 1500 / 4500 [ 33%]  (Warmup) 
## Chain 1 Iteration: 1501 / 4500 [ 33%]  (Sampling) 
## Chain 1 Iteration: 2000 / 4500 [ 44%]  (Sampling) 
## Chain 2 Iteration:    1 / 4500 [  0%]  (Warmup) 
## Chain 2 Iteration:  500 / 4500 [ 11%]  (Warmup) 
## Chain 2 Iteration: 1000 / 4500 [ 22%]  (Warmup) 
## Chain 2 Iteration: 1500 / 4500 [ 33%]  (Warmup) 
## Chain 2 Iteration: 1501 / 4500 [ 33%]  (Sampling) 
## Chain 2 Iteration: 2000 / 4500 [ 44%]  (Sampling) 
## Chain 1 Iteration: 2500 / 4500 [ 55%]  (Sampling) 
## Chain 1 Iteration: 3000 / 4500 [ 66%]  (Sampling) 
## Chain 2 Iteration: 2500 / 4500 [ 55%]  (Sampling) 
## Chain 2 Iteration: 3000 / 4500 [ 66%]  (Sampling) 
## Chain 1 Iteration: 3500 / 4500 [ 77%]  (Sampling) 
## Chain 1 Iteration: 4000 / 4500 [ 88%]  (Sampling) 
## Chain 2 Iteration: 3500 / 4500 [ 77%]  (Sampling) 
## Chain 1 Iteration: 4500 / 4500 [100%]  (Sampling) 
## Chain 2 Iteration: 4000 / 4500 [ 88%]  (Sampling) 
## Chain 1 finished in 1.0 seconds.
## Chain 2 Iteration: 4500 / 4500 [100%]  (Sampling) 
## Chain 2 finished in 1.1 seconds.
## 
## Both chains finished successfully.
## Mean chain execution time: 1.1 seconds.
## Total execution time: 1.4 seconds.
weighted2simple <- mod_wb$sample(
  data = data_SB,
  seed = 123,
  chains = 2,
  parallel_chains = 2,
  threads_per_chain = 2,
  iter_warmup = 1500,
  iter_sampling = 3000,
  refresh = 500
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
## 
## Chain 1 Iteration:    1 / 4500 [  0%]  (Warmup) 
## Chain 2 Iteration:    1 / 4500 [  0%]  (Warmup) 
## Chain 1 Iteration:  500 / 4500 [ 11%]  (Warmup) 
## Chain 2 Iteration:  500 / 4500 [ 11%]  (Warmup) 
## Chain 1 Iteration: 1000 / 4500 [ 22%]  (Warmup) 
## Chain 2 Iteration: 1000 / 4500 [ 22%]  (Warmup) 
## Chain 1 Iteration: 1500 / 4500 [ 33%]  (Warmup) 
## Chain 1 Iteration: 1501 / 4500 [ 33%]  (Sampling) 
## Chain 2 Iteration: 1500 / 4500 [ 33%]  (Warmup) 
## Chain 2 Iteration: 1501 / 4500 [ 33%]  (Sampling) 
## Chain 1 Iteration: 2000 / 4500 [ 44%]  (Sampling) 
## Chain 2 Iteration: 2000 / 4500 [ 44%]  (Sampling) 
## Chain 1 Iteration: 2500 / 4500 [ 55%]  (Sampling) 
## Chain 2 Iteration: 2500 / 4500 [ 55%]  (Sampling) 
## Chain 2 Iteration: 3000 / 4500 [ 66%]  (Sampling) 
## Chain 1 Iteration: 3000 / 4500 [ 66%]  (Sampling) 
## Chain 2 Iteration: 3500 / 4500 [ 77%]  (Sampling) 
## Chain 1 Iteration: 3500 / 4500 [ 77%]  (Sampling) 
## Chain 2 Iteration: 4000 / 4500 [ 88%]  (Sampling) 
## Chain 1 Iteration: 4000 / 4500 [ 88%]  (Sampling) 
## Chain 2 Iteration: 4500 / 4500 [100%]  (Sampling) 
## Chain 2 finished in 4.5 seconds.
## Chain 1 Iteration: 4500 / 4500 [100%]  (Sampling) 
## Chain 1 finished in 4.8 seconds.
## 
## Both chains finished successfully.
## Mean chain execution time: 4.7 seconds.
## Total execution time: 5.0 seconds.
simple2weighted <- mod_simpleBayes$sample(
  data = data_WB,
  seed = 123,
  chains = 2,
  parallel_chains = 2,
  threads_per_chain = 2,
  iter_warmup = 1500,
  iter_sampling = 3000,
  refresh = 500
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
## 
## Chain 1 Iteration:    1 / 4500 [  0%]  (Warmup) 
## Chain 1 Iteration:  500 / 4500 [ 11%]  (Warmup) 
## Chain 1 Iteration: 1000 / 4500 [ 22%]  (Warmup) 
## Chain 1 Iteration: 1500 / 4500 [ 33%]  (Warmup) 
## Chain 1 Iteration: 1501 / 4500 [ 33%]  (Sampling) 
## Chain 2 Iteration:    1 / 4500 [  0%]  (Warmup) 
## Chain 2 Iteration:  500 / 4500 [ 11%]  (Warmup) 
## Chain 2 Iteration: 1000 / 4500 [ 22%]  (Warmup) 
## Chain 2 Iteration: 1500 / 4500 [ 33%]  (Warmup) 
## Chain 2 Iteration: 1501 / 4500 [ 33%]  (Sampling) 
## Chain 1 Iteration: 2000 / 4500 [ 44%]  (Sampling) 
## Chain 1 Iteration: 2500 / 4500 [ 55%]  (Sampling) 
## Chain 2 Iteration: 2000 / 4500 [ 44%]  (Sampling) 
## Chain 2 Iteration: 2500 / 4500 [ 55%]  (Sampling) 
## Chain 1 Iteration: 3000 / 4500 [ 66%]  (Sampling) 
## Chain 1 Iteration: 3500 / 4500 [ 77%]  (Sampling) 
## Chain 2 Iteration: 3000 / 4500 [ 66%]  (Sampling) 
## Chain 2 Iteration: 3500 / 4500 [ 77%]  (Sampling) 
## Chain 1 Iteration: 4000 / 4500 [ 88%]  (Sampling) 
## Chain 1 Iteration: 4500 / 4500 [100%]  (Sampling) 
## Chain 2 Iteration: 4000 / 4500 [ 88%]  (Sampling) 
## Chain 1 finished in 1.0 seconds.
## Chain 2 Iteration: 4500 / 4500 [100%]  (Sampling) 
## Chain 2 finished in 1.0 seconds.
## 
## Both chains finished successfully.
## Mean chain execution time: 1.0 seconds.
## Total execution time: 1.3 seconds.
weighted2weighted <- mod_wb$sample(
  data = data_WB,
  seed = 123,
  chains = 2,
  parallel_chains = 2,
  threads_per_chain = 2,
  iter_warmup = 1500,
  iter_sampling = 3000,
  refresh = 500
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
## 
## Chain 1 Iteration:    1 / 4500 [  0%]  (Warmup) 
## Chain 2 Iteration:    1 / 4500 [  0%]  (Warmup) 
## Chain 1 Iteration:  500 / 4500 [ 11%]  (Warmup) 
## Chain 2 Iteration:  500 / 4500 [ 11%]  (Warmup) 
## Chain 1 Iteration: 1000 / 4500 [ 22%]  (Warmup) 
## Chain 2 Iteration: 1000 / 4500 [ 22%]  (Warmup) 
## Chain 1 Iteration: 1500 / 4500 [ 33%]  (Warmup) 
## Chain 1 Iteration: 1501 / 4500 [ 33%]  (Sampling) 
## Chain 2 Iteration: 1500 / 4500 [ 33%]  (Warmup) 
## Chain 2 Iteration: 1501 / 4500 [ 33%]  (Sampling) 
## Chain 1 Iteration: 2000 / 4500 [ 44%]  (Sampling) 
## Chain 2 Iteration: 2000 / 4500 [ 44%]  (Sampling) 
## Chain 1 Iteration: 2500 / 4500 [ 55%]  (Sampling) 
## Chain 2 Iteration: 2500 / 4500 [ 55%]  (Sampling) 
## Chain 1 Iteration: 3000 / 4500 [ 66%]  (Sampling) 
## Chain 2 Iteration: 3000 / 4500 [ 66%]  (Sampling) 
## Chain 1 Iteration: 3500 / 4500 [ 77%]  (Sampling) 
## Chain 2 Iteration: 3500 / 4500 [ 77%]  (Sampling) 
## Chain 1 Iteration: 4000 / 4500 [ 88%]  (Sampling) 
## Chain 2 Iteration: 4000 / 4500 [ 88%]  (Sampling) 
## Chain 1 Iteration: 4500 / 4500 [100%]  (Sampling) 
## Chain 1 finished in 3.6 seconds.
## Chain 2 Iteration: 4500 / 4500 [100%]  (Sampling) 
## Chain 2 finished in 3.7 seconds.
## 
## Both chains finished successfully.
## Mean chain execution time: 3.7 seconds.
## Total execution time: 4.0 seconds.
Loo_simple2simple <- simple2simple$loo(save_psis = TRUE, cores = 4)
p1 <- plot(Loo_simple2simple)

Loo_weighted2simple <- weighted2simple$loo(save_psis = TRUE, cores = 4)
p2 <- plot(Loo_weighted2simple)

Loo_simple2weighted <- simple2weighted$loo(save_psis = TRUE, cores = 4)
p3 <- plot(Loo_simple2weighted)

Loo_weighted2weighted <- weighted2weighted$loo(save_psis = TRUE, cores = 4)
p4 <- plot(Loo_weighted2weighted)

elpd <- tibble(
  n = seq(810),
  simple_diff_elpd = 
  Loo_simple2simple$pointwise[, "elpd_loo"] - 
  Loo_weighted2simple$pointwise[, "elpd_loo"],
  weighted_diff_elpd = 
  Loo_weighted2weighted$pointwise[, "elpd_loo"] -
  Loo_simple2weighted$pointwise[, "elpd_loo"])

p1 <- ggplot(elpd, aes(x = n, y = simple_diff_elpd)) +
  geom_point(alpha = .1) +
  #xlim(.5,1.01) +
  #ylim(-1.5,1.5) +
  geom_hline(yintercept = 0, color = "red", linetype = "dashed") +
  theme_bw()

p2 <- ggplot(elpd, aes(x = n, y = weighted_diff_elpd)) +
  geom_point(alpha = .1) +
  #xlim(.5,1.01) +
  #ylim(-1.5,1.5) +
  geom_hline(yintercept = 0, color = "red", linetype = "dashed") +
  theme_bw()

library(patchwork)
p1 + p2

loo_compare(Loo_simple2simple, Loo_weighted2simple)
##        elpd_diff se_diff
## model1  0.0       0.0   
## model2 -1.3       1.0
loo_compare(Loo_weighted2weighted, Loo_simple2weighted)
##        elpd_diff se_diff
## model1   0.0       0.0  
## model2 -36.9       9.4
loo_model_weights(list(Loo_simple2simple, Loo_weighted2simple))
## Method: stacking
## ------
##        weight
## model1 1.000 
## model2 0.000
loo_model_weights(list(Loo_weighted2weighted, Loo_simple2weighted))
## Method: stacking
## ------
##        weight
## model1 1.000 
## model2 0.000
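On the weighted data the true (weighted) model wins clearly, while on the simple data the two models are nearly indistinguishable: the weighted model can mimic the simple one by pushing its weights towards 1. A single simulation is only one draw, though; a fuller model-recovery check repeats the simulate-and-compare cycle. A minimal sketch of such a loop (an assumed helper, using default sampler settings and few repetitions for speed):

n_sims <- 5
recovered <- tibble(sim = seq(n_sims),
                    simple_elpd_diff = NA,    # > 0: simple model wins on simple data
                    weighted_elpd_diff = NA)  # > 0: weighted model wins on weighted data

for (i in seq(n_sims)) {
  # re-simulate the choices from the same beliefs
  db$simple_binary   <- rbinom(nrow(db), 1, db$simple_belief)
  db$weighted_binary <- rbinom(nrow(db), 1, db$weighted_belief)
  
  d_SB <- list(N = nrow(db), y = db$simple_binary,
               Source1 = db$Source1, Source2 = db$Source2)
  d_WB <- list(N = nrow(db), y = db$weighted_binary,
               Source1 = db$Source1, Source2 = db$Source2)
  
  s2s <- mod_simpleBayes$sample(data = d_SB, chains = 2, parallel_chains = 2,
                                threads_per_chain = 2, refresh = 0)
  w2s <- mod_wb$sample(data = d_SB, chains = 2, parallel_chains = 2,
                       threads_per_chain = 2, refresh = 0)
  s2w <- mod_simpleBayes$sample(data = d_WB, chains = 2, parallel_chains = 2,
                                threads_per_chain = 2, refresh = 0)
  w2w <- mod_wb$sample(data = d_WB, chains = 2, parallel_chains = 2,
                       threads_per_chain = 2, refresh = 0)
  
  recovered$simple_elpd_diff[i] <- s2s$loo()$estimates["elpd_loo", "Estimate"] -
    w2s$loo()$estimates["elpd_loo", "Estimate"]
  recovered$weighted_elpd_diff[i] <- w2w$loo()$estimates["elpd_loo", "Estimate"] -
    s2w$loo()$estimates["elpd_loo", "Estimate"]
}

recovered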

9.14 Multilevel

Multilevel versions of the simple and weighted Bayes models, with participant-level parameters, are developed in sections 9.19 and 9.20 below.

9.15 Bonus: Build the Weighted Bayes Stan model (fancier formula)

stan_WB1_model <- "

functions{
  // weight a log-odds value: L_raw is evidence on the log-odds scale,
  // w_raw an unconstrained weight mapped onto the 0.5-1 probability scale
  real weight_f(real L_raw, real w_raw) {
    real L;
    real w;
    L = exp(L_raw);
    w = 0.5 + inv_logit(w_raw)/2;
    return log((w * L + 1 - w) / ((1 - w) * L + w));
  }
}


data {
  int<lower=0> N;
  array[N] int y;
  vector[N] Source1;
  vector[N] Source2;
}

parameters {
  real bias;
  real weight1;
  real weight2;
}

model {
  target += normal_lpdf(bias | 0, 1);
  target += normal_lpdf(weight1 | 0, 1.5);
  target += normal_lpdf(weight2 | 0, 1.5);
  
  for (n in 1:N){  
  target += bernoulli_logit_lpmf(y[n] | bias + weight_f(Source1[n], weight1) + weight_f(Source2[n], weight2));
  }
}

generated quantities{
  array[N] real log_lik;
  real bias_prior;
  real w1_prior;
  real w2_prior;
  real w1;
  real w2;
  
  bias_prior = normal_rng(0, 1);
  // prior draws for the weights, mapped onto the 0.5-1 scale
  // (matching the normal(0, 1.5) prior on the raw weights)
  w1_prior = 0.5 + inv_logit(normal_rng(0, 1.5))/2;
  w2_prior = 0.5 + inv_logit(normal_rng(0, 1.5))/2;
  w1 = 0.5 + inv_logit(weight1)/2;
  w2 = 0.5 + inv_logit(weight2)/2;
  
  for (n in 1:N){  
    log_lik[n] = bernoulli_logit_lpmf(y[n] | bias + weight_f(Source1[n], weight1) +
      weight_f(Source2[n], weight2));
  }
  
}

"

write_stan_file(
  stan_WB1_model,
  dir = "stan/",
  basename = "W9_WB1.stan")
## [1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/W9_WB1.stan"
file <- file.path("stan/W9_WB1.stan")
mod_wb1 <- cmdstan_model(file, cpp_options = list(stan_threads = TRUE),
                     stanc_options = list("O1"))
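The fitting below reuses the weighted-Bayes grid from section 9.9 (with the belief1 and binary1 columns). Since db was overwritten in the model-recovery section, we first re-simulate that grid (a re-run of the earlier code, shown for completeness):

bias <- 0
trials <- seq(10)
Source1 <- seq(0.1, 0.9, 0.1)
Source2 <- seq(0.1, 0.9, 0.1)
w1 <- seq(0.5, 1, 0.1)
w2 <- seq(0.5, 1, 0.1)

db <- expand.grid(bias = bias, trials = trials, Source1 = Source1,
                  Source2 = Source2, w1 = w1, w2 = w2)

for (n in seq(nrow(db))) {
  db$belief1[n] <- WeightedBayes_f1(db$bias[n], db$Source1[n], db$Source2[n],
                                    db$w1[n], db$w2[n])
  db$binary1[n] <- rbinom(1, 1, db$belief1[n])
}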

db1 <-  db %>% subset(w1 == 0.7 & w2 == 1) %>% mutate(
  l1 = logit_scaled(Source1),
  l2 = logit_scaled(Source2)
)



data_weightedBayes1 <- list(
  N = nrow(db1),
  y = db1$binary1,
  Source1 = logit_scaled(db1$Source1),
  Source2 = logit_scaled(db1$Source2)
)

samples_weighted1 <- mod_wb1$sample(
  data = data_weightedBayes1,
  seed = 123,
  chains = 2,
  parallel_chains = 2,
  threads_per_chain = 2,
  iter_warmup = 1500,
  iter_sampling = 3000,
  refresh = 500
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
## 
## Chain 1 Iteration:    1 / 4500 [  0%]  (Warmup) 
## Chain 1 Iteration:  500 / 4500 [ 11%]  (Warmup) 
## Chain 1 Iteration: 1000 / 4500 [ 22%]  (Warmup) 
## Chain 1 Iteration: 1500 / 4500 [ 33%]  (Warmup) 
## Chain 1 Iteration: 1501 / 4500 [ 33%]  (Sampling) 
## Chain 1 Iteration: 2000 / 4500 [ 44%]  (Sampling) 
## Chain 1 Iteration: 2500 / 4500 [ 55%]  (Sampling) 
## Chain 1 Iteration: 3000 / 4500 [ 66%]  (Sampling) 
## Chain 1 Iteration: 3500 / 4500 [ 77%]  (Sampling) 
## Chain 1 Iteration: 4000 / 4500 [ 88%]  (Sampling) 
## Chain 1 Iteration: 4500 / 4500 [100%]  (Sampling) 
## Chain 2 Iteration:    1 / 4500 [  0%]  (Warmup) 
## Chain 2 Iteration:  500 / 4500 [ 11%]  (Warmup) 
## Chain 2 Iteration: 1000 / 4500 [ 22%]  (Warmup) 
## Chain 2 Iteration: 1500 / 4500 [ 33%]  (Warmup) 
## Chain 2 Iteration: 1501 / 4500 [ 33%]  (Sampling) 
## Chain 2 Iteration: 2000 / 4500 [ 44%]  (Sampling) 
## Chain 2 Iteration: 2500 / 4500 [ 55%]  (Sampling) 
## Chain 2 Iteration: 3000 / 4500 [ 66%]  (Sampling) 
## Chain 2 Iteration: 3500 / 4500 [ 77%]  (Sampling) 
## Chain 2 Iteration: 4000 / 4500 [ 88%]  (Sampling) 
## Chain 2 Iteration: 4500 / 4500 [100%]  (Sampling) 
## Chain 1 finished in 0.0 seconds.
## Chain 2 finished in 0.0 seconds.
## 
## Both chains finished successfully.
## Mean chain execution time: 0.0 seconds.
## Total execution time: 0.4 seconds.

9.16 Evaluation

draws_df <- as_draws_df(samples_weighted1$draws())

ggplot(draws_df, aes(.iteration, bias, group = .chain, color = .chain)) +
  geom_line(alpha = 0.5) +
  theme_classic()

ggplot(draws_df, aes(.iteration, w1, group = .chain, color = .chain)) +
  geom_line(alpha = 0.5) +
  theme_classic()

ggplot(draws_df, aes(.iteration, w2, group = .chain, color = .chain)) +
  geom_line(alpha = 0.5) +
  theme_classic()

ggplot(draws_df) +
  geom_histogram(aes(bias), alpha = 0.6, fill = "lightblue") +
  geom_histogram(aes(bias_prior), alpha = 0.6, fill = "pink") +
  geom_vline(xintercept = db1$bias[1]) +
  theme_bw()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 1 rows containing missing values (`geom_vline()`).

ggplot(draws_df) +
  geom_histogram(aes(w1), alpha = 0.6, fill = "lightblue") +
  geom_histogram(aes(w1_prior), alpha = 0.6, fill = "pink") +
  geom_vline(xintercept = db1$w1[1]) +
  theme_bw()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 1 rows containing missing values (`geom_vline()`).

ggplot(draws_df) +
  geom_density(aes(w2), alpha = 0.6, fill = "lightblue") +
  geom_density(aes(w2_prior), alpha = 0.6, fill = "pink") +
  geom_vline(xintercept = db1$w2[1]) +
  theme_bw()
## Warning: Removed 1 rows containing missing values (`geom_vline()`).

ggplot(draws_df) +
  geom_point(aes(w1, w2), alpha = 0.3) +
  theme_bw()

ggplot(draws_df) +
  geom_point(aes(bias, w1), alpha = 0.3) +
  theme_bw()

ggplot(draws_df) +
  geom_point(aes(bias, w2), alpha = 0.3) +
  theme_bw()

9.17 Build a temporal Bayes (storing current belief as prior for next turn)

## The updating function is the same as WeightedBayes_f; the temporal twist is
## in how it is called: Source2 carries the belief from the previous trial
## (the prior), while Source1 is the new evidence.
WeightedTimeBayes_f <- function(bias, Source1, Source2, w1, w2){
  w1 <- (w1 - 0.5) * 2
  w2 <- (w2 - 0.5) * 2
  outcome <- inv_logit_scaled(bias + w1 * logit_scaled(Source1) + w2 * logit_scaled(Source2))
  return(outcome)
}

bias <- 0
trials <- seq(10)
Source1 <- seq(0.1,0.9, 0.1)
w1 <- seq(0.5, 1, 0.1)
w2 <- seq(0.5, 1, 0.1)

db <- expand.grid(bias = bias, trials = trials, Source1 = Source1, w1 = w1, w2 = w2) %>%
  mutate(Source2 = NA, belief = NA, binary = NA)

for (n in seq(nrow(db))) {
  if (n == 1) {db$Source2[1] <- 0.5}  # flat prior on the first trial
  db$belief[n] <- WeightedTimeBayes_f(db$bias[n], db$Source1[n], db$Source2[n], db$w1[n], db$w2[n])
  db$binary[n] <- rbinom(1, 1, db$belief[n])
  if (n < nrow(db)) {db$Source2[n + 1] <- db$belief[n]}  # posterior becomes the next prior
}
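A quick check of the bookkeeping (illustrative, not part of the original text): each trial's Source2 equals the belief from the previous row, i.e., the posterior is carried forward as the next prior.

db %>%
  select(Source1, Source2, belief) %>%
  head(5)
all.equal(db$Source2[-1], db$belief[-nrow(db)])  # should be TRUE

Note that this demo loop chains beliefs across the entire grid, including across different weight settings; in a real experiment one would reset the prior at the start of each sequence.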

stan_TB_model <- "
data {
  int<lower=0> N;
  array[N] int y;
  array[N] real <lower = 0, upper = 1> Source1; 
}

transformed data {
  array[N] real l_Source1;
  l_Source1 = logit(Source1);
}

parameters {
  real bias;
  // meaningful weights are btw 0.5 and 1 (theory reasons)
  real<lower = 0.5, upper = 1> w1; 
  real<lower = 0.5, upper = 1> w2;
}

transformed parameters {
  real<lower = 0, upper = 1> weight1;
  real<lower = 0, upper = 1> weight2;
  array[N] real l_Source2;

  // weight parameters are rescaled to be on a 0-1 scale (0 -> no effects; 1 -> face value)
  weight1 = (w1 - 0.5) * 2;  
  weight2 = (w2 - 0.5) * 2;
  
  l_Source2[1] = 0;
  
  for (n in 2:N){
    l_Source2[n] = bias + weight1 * l_Source1[n] + weight2 * l_Source2[n-1];
    }
}

model {
  target += normal_lpdf(bias | 0, 1);
  target += beta_lpdf(weight1 | 1, 1);
  target += beta_lpdf(weight2 | 1, 1);
  
  target += bernoulli_logit_lpmf(y[1] | bias + weight1 * l_Source1[1]);
  
  for (n in 2:N){  
    target += bernoulli_logit_lpmf(y[n] | l_Source2[n]);
  }
}

generated quantities{
  array[N] real log_lik;
  real bias_prior;
  real w1_prior;
  real w2_prior;
  bias_prior = normal_rng(0, 1);
  // weight ~ beta(1, 1) is uniform on 0-1, i.e. w is uniform on 0.5-1
  w1_prior = uniform_rng(0.5, 1);
  w2_prior = uniform_rng(0.5, 1);
  // n = 1 uses the same expression as in the model block (l_Source2[1] is fixed at 0)
  log_lik[1] = bernoulli_logit_lpmf(y[1] | bias + weight1 * l_Source1[1]);
  for (n in 2:N)
    log_lik[n] = bernoulli_logit_lpmf(y[n] | l_Source2[n]);
}
"

write_stan_file(
  stan_TB_model,
  dir = "stan/",
  basename = "W9_TB.stan")
## [1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/W9_TB.stan"
file <- file.path("stan/W9_TB.stan")
mod_tb <- cmdstan_model(file, cpp_options = list(stan_threads = TRUE),
                     stanc_options = list("O1"))


db1 <- db %>% subset(w1 == 0.7 & w2 == 0.9)

data_TB <- list(
  N = nrow(db1),
  y = db1$binary,
  Source1 = db1$Source1
)

samples_TB <- mod_tb$sample(
  data = data_TB,
  seed = 123,
  chains = 2,
  parallel_chains = 2,
  threads_per_chain = 2,
  iter_warmup = 2000,
  iter_sampling = 3000,
  refresh = 500,
  adapt_delta = 0.99,
  max_treedepth = 20
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
## 
## Chain 1 Iteration:    1 / 5000 [  0%]  (Warmup) 
## Chain 1 Iteration:  500 / 5000 [ 10%]  (Warmup) 
## Chain 2 Iteration:    1 / 5000 [  0%]  (Warmup) 
## Chain 2 Iteration:  500 / 5000 [ 10%]  (Warmup) 
## Chain 1 Iteration: 1000 / 5000 [ 20%]  (Warmup) 
## Chain 1 Iteration: 1500 / 5000 [ 30%]  (Warmup) 
## Chain 2 Iteration: 1000 / 5000 [ 20%]  (Warmup) 
## Chain 2 Iteration: 1500 / 5000 [ 30%]  (Warmup) 
## Chain 1 Iteration: 2000 / 5000 [ 40%]  (Warmup) 
## Chain 1 Iteration: 2001 / 5000 [ 40%]  (Sampling) 
## Chain 2 Iteration: 2000 / 5000 [ 40%]  (Warmup) 
## Chain 2 Iteration: 2001 / 5000 [ 40%]  (Sampling) 
## Chain 1 Iteration: 2500 / 5000 [ 50%]  (Sampling) 
## Chain 1 Iteration: 3000 / 5000 [ 60%]  (Sampling) 
## Chain 1 Iteration: 3500 / 5000 [ 70%]  (Sampling) 
## Chain 2 Iteration: 2500 / 5000 [ 50%]  (Sampling) 
## Chain 1 Iteration: 4000 / 5000 [ 80%]  (Sampling) 
## Chain 1 Iteration: 4500 / 5000 [ 90%]  (Sampling) 
## Chain 2 Iteration: 3000 / 5000 [ 60%]  (Sampling) 
## Chain 1 Iteration: 5000 / 5000 [100%]  (Sampling) 
## Chain 1 finished in 1.9 seconds.
## Chain 2 Iteration: 3500 / 5000 [ 70%]  (Sampling) 
## Chain 2 Iteration: 4000 / 5000 [ 80%]  (Sampling) 
## Chain 2 Iteration: 4500 / 5000 [ 90%]  (Sampling) 
## Chain 2 Iteration: 5000 / 5000 [100%]  (Sampling) 
## Chain 2 finished in 2.9 seconds.
## 
## Both chains finished successfully.
## Mean chain execution time: 2.4 seconds.
## Total execution time: 3.2 seconds.

9.18 Evaluate

samples_TB$cmdstan_diagnose()
## Processing csv files: /var/folders/lt/zspkqnxd5yg92kybm5f433_cfjr0d6/T/RtmpP3dlfk/W9_TB-202402150651-1-26c994.csv, /var/folders/lt/zspkqnxd5yg92kybm5f433_cfjr0d6/T/RtmpP3dlfk/W9_TB-202402150651-2-26c994.csv
## 
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
## 
## Checking sampler transitions for divergences.
## No divergent transitions found.
## 
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
## 
## Effective sample size satisfactory.
## 
## Split R-hat values satisfactory all parameters.
## 
## Processing complete, no problems detected.
samples_TB$summary()
## # A tibble: 189 × 10
##    variable         mean    median    sd    mad      q5     q95  rhat ess_bulk ess_tail
##    <chr>           <dbl>     <dbl> <dbl>  <dbl>   <dbl>   <dbl> <dbl>    <dbl>    <dbl>
##  1 lp__         -50.5    -50.0     1.77  1.53   -54.0   -48.5    1.00     976.    1188.
##  2 bias          -0.0234  -0.00914 0.117 0.0931  -0.234   0.141  1.00    1880.    1819.
##  3 w1             0.756    0.744   0.111 0.128    0.595   0.956  1.00     754.    1188.
##  4 w2             0.809    0.825   0.101 0.108    0.620   0.944  1.00     792.    1150.
##  5 weight1        0.511    0.488   0.222 0.256    0.189   0.912  1.00     754.    1188.
##  6 weight2        0.617    0.649   0.202 0.215    0.240   0.888  1.00     792.    1150.
##  7 l_Source2[1]   0        0       0     0        0       0     NA         NA       NA 
##  8 l_Source2[2]  -1.15    -1.09    0.527 0.596   -2.09   -0.387  1.00     855.    1362.
##  9 l_Source2[3]  -1.76    -1.76    0.654 0.722   -2.86   -0.726  1.00     956.    1412.
## 10 l_Source2[4]  -2.14    -2.15    0.685 0.721   -3.29   -1.02   1.00    1083.    1496.
## # ℹ 179 more rows
draws_df <- as_draws_df(samples_TB$draws())

p1 <- ggplot(draws_df, aes(.iteration, bias, group = .chain, color = .chain)) +
  geom_line(alpha = 0.5) +
  theme_classic()

p2 <- ggplot(draws_df, aes(.iteration, w1, group = .chain, color = .chain)) +
  geom_line(alpha = 0.5) +
  theme_classic()

p3 <- ggplot(draws_df, aes(.iteration, w2, group = .chain, color = .chain)) +
  geom_line(alpha = 0.5) +
  theme_classic()

p1 + p2 + p3

p1 <- ggplot(draws_df) +
  geom_density(aes(bias), alpha = 0.6, fill = "lightblue") +
  geom_density(aes(bias_prior), alpha = 0.6, fill = "pink") +
  geom_vline(xintercept = db1$bias[1]) +
  theme_bw()

p2 <- ggplot(draws_df) +
  geom_density(aes(w1), alpha = 0.6, fill = "lightblue") +
  geom_density(aes(w1_prior), alpha = 0.6, fill = "pink") +
  geom_vline(xintercept = db1$w1[1]) +
  theme_bw()

p3 <- ggplot(draws_df) +
  geom_density(aes(w2), alpha = 0.6, fill = "lightblue") +
  geom_density(aes(w2_prior), alpha = 0.6, fill = "pink") +
  geom_vline(xintercept = db1$w2[1]) +
  theme_bw()

p1 + p2 + p3

p1 <- ggplot(draws_df) +
  geom_point(aes(w1, w2), alpha = 0.3) +
  theme_bw()
p2 <- ggplot(draws_df) +
  geom_point(aes(bias, w1), alpha = 0.3) +
  theme_bw()
p3 <- ggplot(draws_df) +
  geom_point(aes(bias, w2), alpha = 0.3) +
  theme_bw()

p1 + p2 + p3

The model seems good at recovering the generating parameters. However, the pairs plots show high correlations between the weights and funnel-like shapes against the bias, so we would probably be safer reparameterizing the model.

9.19 Multilevel version of the simple bayes model

stan_simpleBayes_ml_model <- "
data {
  int<lower=0> N;  // n of trials
  int<lower=0> S;  // n of participants
  array[N, S] int y;
  array[N, S] real<lower=0, upper = 1> Source1;
  array[N, S] real<lower=0, upper = 1> Source2;
}

transformed data{
  array[N, S] real l_Source1;
  array[N, S] real l_Source2;
  l_Source1 = logit(Source1);
  l_Source2 = logit(Source2);
}

parameters {
  real biasM;            // population mean bias
  real<lower=0> biasSD;  // between-participant SD (bounded, matching the half-normal prior)
  array[S] real z_bias;  // standardized participant-level deviations
}

transformed parameters {
  // non-centered parameterization: bias[s] = biasM + biasSD * z_bias[s]
  vector[S] biasC;
  vector[S] bias;
  biasC = biasSD * to_vector(z_bias);
  bias = biasM + biasC;
}

model {
  target +=  normal_lpdf(biasM | 0, 1);
  target +=  normal_lpdf(biasSD | 0, 1)  -
    normal_lccdf(0 | 0, 1);
  target += std_normal_lpdf(to_vector(z_bias));
  
  for (s in 1:S){
  target +=  bernoulli_logit_lpmf(y[,s] | bias[s] + 
                                          to_vector(l_Source1[,s]) + 
                                          to_vector(l_Source2[,s]));
  }
}



"

write_stan_file(
  stan_simpleBayes_ml_model,
  dir = "stan/",
  basename = "W9_SimpleBayes_ml.stan")
## [1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/W9_SimpleBayes_ml.stan"
file <- file.path("stan/W9_SimpleBayes_ml.stan")
mod_simpleBayes <- cmdstan_model(file, cpp_options = list(stan_threads = TRUE),
                     stanc_options = list("O1"))

9.20 Multilevel version of the weighted bayes model

stan_WB_ml_model <- "
data {
  int<lower=0> N;  // n of trials
  int<lower=0> S;  // n of participants
  array[N, S] int y;
  array[N, S] real<lower=0, upper = 1> Source1;
  array[N, S] real<lower=0, upper = 1> Source2; 
}

transformed data {
  array[N,S] real l_Source1;
  array[N,S] real l_Source2;
  l_Source1 = logit(Source1);
  l_Source2 = logit(Source2);
}

parameters {
  real biasM;
  
  real w1_M; 
  real w2_M;
  
  vector<lower = 0>[3] tau;
  matrix[3, S] z_IDs;
  cholesky_factor_corr[3] L_u;
  
}

transformed parameters{
  matrix[S,3] IDs;
  IDs = (diag_pre_multiply(tau, L_u) * z_IDs)';
}

model {

  target += normal_lpdf(biasM | 0, 1);
  target +=  normal_lpdf(tau[1] | 0, 1)  -
    normal_lccdf(0 | 0, 1);
  target += normal_lpdf(w1_M | 0, 1);
  target +=  normal_lpdf(tau[2] | 0, 1)  -
    normal_lccdf(0 | 0, 1);
  target += normal_lpdf(w2_M | 0, 1);
  target +=  normal_lpdf(tau[3] | 0, 1)  -
    normal_lccdf(0 | 0, 1);
  target += lkj_corr_cholesky_lpdf(L_u | 3);
  
  target += std_normal_lpdf(to_vector(z_IDs));
    
  for (s in 1:S){
  for (n in 1:N){
    target += bernoulli_logit_lpmf(y[n,s] | biasM + IDs[s, 1] + 
          (w1_M + IDs[s, 2]) * l_Source1[n,s] + 
          (w2_M + IDs[s, 3]) * l_Source2[n,s]);
  }}
}

"

write_stan_file(
  stan_WB_ml_model,
  dir = "stan/",
  basename = "W9_WB_ml.stan")
## [1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/W9_WB_ml.stan"
file <- file.path("stan/W9_WB_ml.stan")
mod_wb <- cmdstan_model(file, cpp_options = list(stan_threads = TRUE),
                     stanc_options = list("O1"))
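Both multilevel models expect trial-by-participant matrices (array[N, S] in Stan). A minimal sketch of simulating such data for the simple model and packing it for Stan; the number of participants, trials, and the between-participant SD are arbitrary choices:

S <- 10  # participants (arbitrary)
N <- 80  # trials per participant (arbitrary)
Source1 <- matrix(sample(seq(0.1, 0.9, 0.1), N * S, replace = TRUE), nrow = N)
Source2 <- matrix(sample(seq(0.1, 0.9, 0.1), N * S, replace = TRUE), nrow = N)

bias_i <- rnorm(S, mean = 0, sd = 0.5)  # participant-level biases (arbitrary SD)
y <- matrix(NA, nrow = N, ncol = S)
for (s in seq(S)) {
  belief <- SimpleBayes_f(bias_i[s], Source1[, s], Source2[, s])
  y[, s] <- rbinom(N, 1, belief)
}

data_ml <- list(N = N, S = S, y = y,
                Source1 = Source1, Source2 = Source2)
# samples_ml <- mod_simpleBayes$sample(data = data_ml, chains = 2,
#                                      parallel_chains = 2, threads_per_chain = 2)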

[MISSING: Fitting on real data and model comparison]
