Chapter 9 Bayesian models of cognition
9.2 Design the model
Here we build our first Bayesian model of cognition: a simple Bayes model of how an agent combines two sources of evidence about a binary state of the world. The two sources could, for instance, be one's own perceptual evidence and the choice of a group in a social conformity experiment, or one's own first impression and a peer's rating in a trustworthiness task. Each source is expressed as a probability (between 0 and 1) that the state is 1. The model converts each source to log-odds, adds them together with a bias term (a general tendency towards one of the responses), and converts the sum back to a probability:
belief = inv_logit(bias + logit(Source1) + logit(Source2))
This belief can then be mapped onto different outcome formats: a binary choice (a Bernoulli draw with probability equal to the belief), a continuous rating, or a discrete rating (e.g., on a 0-9 scale), as in the simulation below.
Note the implicit weights on the sources. In the simple Bayes model each source enters with a weight of 1 on the log-odds scale: evidence is accumulated, so two moderately positive sources produce a belief more extreme than either source alone. Alternatively, one could use weights that sum up to 1, which corresponds to integrating (averaging) the evidence: the resulting belief then lies between the two sources and never becomes more extreme than the strongest of them. The weighted Bayes model developed later in the chapter lets us estimate from data how much weight an agent actually gives each source.
9.3 Create the data
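The simulation relies on a SimpleBayes_f() function that is not echoed in this section. A minimal sketch consistent with the Stan model in 9.6 (assuming logit_scaled() and inv_logit_scaled() from brms, as used throughout the chapter):
SimpleBayes_f <- function(bias, Source1, Source2){
  outcome <- inv_logit_scaled(bias + logit_scaled(Source1) + logit_scaled(Source2))
  return(outcome)
}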
bias <- 0
trials <- seq(10)
Source1 <- seq(0.1,0.9, 0.1)
Source2 <- seq(0.1,0.9, 0.1)
db <- expand.grid(bias = bias, trials = trials, Source1 = Source1, Source2 = Source2)
for (n in seq(nrow(db))) {
  db$belief[n] <- SimpleBayes_f(db$bias[n], db$Source1[n], db$Source2[n])
  db$choice[n] <- rbinom(1, 1, db$belief[n])    # binary choice
  db$continuous[n] <- db$belief[n] * 9          # continuous rating on a 0-9 scale
  db$discrete[n] <- round(db$belief[n] * 9, 0)  # discrete rating on a 0-9 scale
}
9.4 Visualize
To get a feel for the model, we plot the simulated probability of choosing 1 as a function of Source1, with a separate curve for each level of Source2. The curves show the accumulation at work: when both sources point in the same direction, the resulting belief is more extreme than either source alone.
ggplot(db, aes(Source1, choice, color = Source2, group = Source2)) +
geom_smooth(se = FALSE) +
theme_bw()
## `geom_smooth()` using method = 'loess' and formula = 'y ~ x'
9.6 Create the Stan model
stan_simpleBayes_model <- "
data {
  int<lower=0> N;
  array[N] int y;
  array[N] real<lower=0, upper=1> Source1;
  array[N] real<lower=0, upper=1> Source2;
}
transformed data {
  // convert the sources to log-odds once, before sampling
  array[N] real l_Source1;
  array[N] real l_Source2;
  l_Source1 = logit(Source1);
  l_Source2 = logit(Source2);
}
parameters {
  real bias;
}
model {
  target += normal_lpdf(bias | 0, 1);
  // simple Bayes: both sources enter with an implicit weight of 1
  target += bernoulli_logit_lpmf(y | bias + to_vector(l_Source1) + to_vector(l_Source2));
}
generated quantities {
  real bias_prior;
  array[N] real log_lik;
  bias_prior = normal_rng(0, 1);
  for (n in 1:N) {
    log_lik[n] = bernoulli_logit_lpmf(y[n] | bias + l_Source1[n] + l_Source2[n]);
  }
}
"
write_stan_file(
stan_simpleBayes_model,
dir = "stan/",
basename = "W9_SimpleBayes.stan")
## [1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/W9_SimpleBayes.stan"
9.7 Fitting the model
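Before sampling we need a compiled model and a data list; these steps are not echoed in the original chunk. A minimal sketch mirroring the weighted-Bayes code further below:
file <- file.path("stan/W9_SimpleBayes.stan")
mod_simpleBayes <- cmdstan_model(file, cpp_options = list(stan_threads = TRUE),
                                 stanc_options = list("O1"))
data_simpleBayes <- list(
  N = nrow(db),
  y = db$choice,
  Source1 = db$Source1,
  Source2 = db$Source2
)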
samples_simple <- mod_simpleBayes$sample(
data = data_simpleBayes,
seed = 123,
chains = 2,
parallel_chains = 2,
threads_per_chain = 2,
iter_warmup = 1500,
iter_sampling = 3000,
refresh = 500
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
##
## Chain 1 Iteration: 1 / 4500 [ 0%] (Warmup)
## Chain 1 Iteration: 500 / 4500 [ 11%] (Warmup)
## Chain 1 Iteration: 1000 / 4500 [ 22%] (Warmup)
## Chain 1 Iteration: 1500 / 4500 [ 33%] (Warmup)
## Chain 1 Iteration: 1501 / 4500 [ 33%] (Sampling)
## Chain 2 Iteration: 1 / 4500 [ 0%] (Warmup)
## Chain 2 Iteration: 500 / 4500 [ 11%] (Warmup)
## Chain 2 Iteration: 1000 / 4500 [ 22%] (Warmup)
## Chain 2 Iteration: 1500 / 4500 [ 33%] (Warmup)
## Chain 2 Iteration: 1501 / 4500 [ 33%] (Sampling)
## Chain 1 Iteration: 2000 / 4500 [ 44%] (Sampling)
## Chain 2 Iteration: 2000 / 4500 [ 44%] (Sampling)
## Chain 1 Iteration: 2500 / 4500 [ 55%] (Sampling)
## Chain 2 Iteration: 2500 / 4500 [ 55%] (Sampling)
## Chain 1 Iteration: 3000 / 4500 [ 66%] (Sampling)
## Chain 2 Iteration: 3000 / 4500 [ 66%] (Sampling)
## Chain 1 Iteration: 3500 / 4500 [ 77%] (Sampling)
## Chain 2 Iteration: 3500 / 4500 [ 77%] (Sampling)
## Chain 1 Iteration: 4000 / 4500 [ 88%] (Sampling)
## Chain 2 Iteration: 4000 / 4500 [ 88%] (Sampling)
## Chain 1 Iteration: 4500 / 4500 [100%] (Sampling)
## Chain 1 finished in 1.0 seconds.
## Chain 2 Iteration: 4500 / 4500 [100%] (Sampling)
## Chain 2 finished in 1.0 seconds.
##
## Both chains finished successfully.
## Mean chain execution time: 1.0 seconds.
## Total execution time: 1.3 seconds.
9.8 Basic evaluation
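The diagnostics, summary, and LOO output below come from chunks that are not echoed in the text; a sketch of the cmdstanr calls that would produce them (the same pattern applies to the evaluations in 9.12 and 9.18):
samples_simple$cmdstan_diagnose()
samples_simple$summary()
Loo_simple <- samples_simple$loo(save_psis = TRUE, cores = 4)
Loo_simple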
## Processing csv files: /var/folders/lt/zspkqnxd5yg92kybm5f433_cfjr0d6/T/RtmpP3dlfk/W9_SimpleBayes-202402150650-1-1b8ab4.csv, /var/folders/lt/zspkqnxd5yg92kybm5f433_cfjr0d6/T/RtmpP3dlfk/W9_SimpleBayes-202402150650-2-1b8ab4.csv
##
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
##
## Checking sampler transitions for divergences.
## No divergent transitions found.
##
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
##
## Effective sample size satisfactory.
##
## Split R-hat values satisfactory all parameters.
##
## Processing complete, no problems detected.
## # A tibble: 813 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -3.99e+2 -3.99e+2 6.97e-1 3.05e-1 -4.01e+2 -3.99e+2 1.00 3053. 3376.
## 2 bias -1.56e-1 -1.55e-1 8.87e-2 8.84e-2 -3.05e-1 -9.42e-3 1.00 2150. 3351.
## 3 bias_pri… -2.66e-2 -2.85e-2 1.01e+0 9.92e-1 -1.71e+0 1.65e+0 1.00 5342. 5967.
## 4 log_lik[… -1.05e-2 -1.05e-2 9.30e-4 9.27e-4 -1.22e-2 -9.06e-3 1.00 2150. 3351.
## 5 log_lik[… -1.05e-2 -1.05e-2 9.30e-4 9.27e-4 -1.22e-2 -9.06e-3 1.00 2150. 3351.
## 6 log_lik[… -1.05e-2 -1.05e-2 9.30e-4 9.27e-4 -1.22e-2 -9.06e-3 1.00 2150. 3351.
## 7 log_lik[… -1.05e-2 -1.05e-2 9.30e-4 9.27e-4 -1.22e-2 -9.06e-3 1.00 2150. 3351.
## 8 log_lik[… -1.05e-2 -1.05e-2 9.30e-4 9.27e-4 -1.22e-2 -9.06e-3 1.00 2150. 3351.
## 9 log_lik[… -1.05e-2 -1.05e-2 9.30e-4 9.27e-4 -1.22e-2 -9.06e-3 1.00 2150. 3351.
## 10 log_lik[… -1.05e-2 -1.05e-2 9.30e-4 9.27e-4 -1.22e-2 -9.06e-3 1.00 2150. 3351.
## # ℹ 803 more rows
##
## Computed from 6000 by 810 log-likelihood matrix
##
## Estimate SE
## elpd_loo -398.8 15.8
## p_loo 1.0 0.0
## looic 797.7 31.7
## ------
## Monte Carlo SE of elpd_loo is 0.0.
##
## All Pareto k estimates are good (k < 0.5).
## See help('pareto-k-diagnostic') for details.
draws_df <- as_draws_df(samples_simple$draws())
ggplot(draws_df, aes(.iteration, bias, group = .chain, color = .chain)) +
geom_line(alpha = 0.5) +
theme_classic()
ggplot(draws_df) +
geom_density(aes(bias), alpha = 0.6, fill = "lightblue") +
geom_density(aes(bias_prior), alpha = 0.6, fill = "pink") +
geom_vline(xintercept = db$bias[1]) +
theme_bw()
A systematic parameter recovery analysis is still missing here. As a starting point, the sketch below (hypothetical helper code, not from the original chapter) simulates data at several true bias values, refits the model to each dataset, and plots estimated against true values.
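## hypothetical parameter-recovery sketch
recovery <- NULL
for (biasTrue in seq(-1, 1, 0.5)) {
  dbr <- expand.grid(bias = biasTrue, trials = trials,
                     Source1 = Source1, Source2 = Source2)
  dbr$belief <- SimpleBayes_f(dbr$bias, dbr$Source1, dbr$Source2)
  dbr$choice <- rbinom(nrow(dbr), 1, dbr$belief)
  fit <- mod_simpleBayes$sample(
    data = list(N = nrow(dbr), y = dbr$choice,
                Source1 = dbr$Source1, Source2 = dbr$Source2),
    seed = 123, chains = 2, parallel_chains = 2, threads_per_chain = 2,
    iter_warmup = 1500, iter_sampling = 3000, refresh = 0
  )
  recovery <- rbind(recovery,
                    tibble(biasTrue = biasTrue,
                           biasEst = mean(as_draws_df(fit$draws())$bias)))
}
ggplot(recovery, aes(biasTrue, biasEst)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  theme_bw()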
9.9 Weighted Bayes
The weighted Bayes model relaxes the assumption that each source is taken at face value: every source gets its own weight, capturing how much the agent relies on it. Pay close attention to the scale of the weights. On the probability scale, a weight runs from 0.5 (the source is ignored) to 1 (the source is taken at face value). Rescaled as (w - 0.5) * 2, the same weight runs from 0 to 1 on the log-odds scale, where it simply multiplies the log-odds of the source. We implement two versions below: a simple one that rescales the weights and multiplies them onto the log-odds, and one using the weighting function from Taking others into account (in the syllabus).
WeightedBayes_f <- function(bias, Source1, Source2, w1, w2){
  # rescale the weights from the 0.5-1 probability scale to the 0-1 log-odds scale
  w1 <- (w1 - 0.5) * 2
  w2 <- (w2 - 0.5) * 2
  outcome <- inv_logit_scaled(bias + w1 * logit_scaled(Source1) + w2 * logit_scaled(Source2))
  return(outcome)
}
## This version of the model is from Taking others into account (in the syllabus)
## It takes two sources of information and weights them, then adds a bias
## to generate a posterior on a 0-1 scale
WeightedBayes_f1 <- function(bias, Source1, Source2, w1, w2){
  outcome <- inv_logit_scaled(bias + weight_f(logit_scaled(Source1), w1) +
                                weight_f(logit_scaled(Source2), w2))
  return(outcome)
}
## The weight_f formula comes from https://www.nature.com/articles/ncomms14218
## and ensures that even though we work on a log-odds scale, the weights keep
## their interpretation. It accepts all values of L (-Inf to +Inf). Theoretically,
## the only meaningful values for w run from 0.5 (no consideration of the evidence)
## to 1 (taking the evidence at face value). In practice the function also accepts
## 0-0.5 (inverting the evidence: fully at 0, decreasingly as w grows towards 0.5)
## and values slightly above 1 (overweighing the evidence), but the latter is very
## unstable and quickly produces NaNs.
weight_f <- function(L, w){
return(log((w * exp(L) + 1 - w) /
((1 - w) * exp(L) + w)))
}
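A quick sanity check of the endpoints (not in the original text): with w = 1 the function returns the log-odds unchanged, and with w = 0.5 it returns 0, i.e., the evidence is ignored.
L <- logit_scaled(0.8)
weight_f(L, 1)    # equals L: evidence taken at face value
weight_f(L, 0.5)  # equals 0: evidence ignored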
bias <- 0
trials <- seq(10)
Source1 <- seq(0.1,0.9, 0.1)
Source2 <- seq(0.1,0.9, 0.1)
w1 <- seq(0.5, 1, 0.1)
w2 <- seq(0.5, 1, 0.1)
db <- expand.grid(bias = bias, trials = trials, Source1 = Source1, Source2 = Source2, w1 = w1, w2 = w2)
for (n in seq(nrow(db))) {
  db$belief[n] <- WeightedBayes_f(db$bias[n], db$Source1[n], db$Source2[n], db$w1[n], db$w2[n])
  db$belief1[n] <- WeightedBayes_f1(db$bias[n], db$Source1[n], db$Source2[n], db$w1[n], db$w2[n])
  db$binary[n] <- rbinom(1, 1, db$belief[n])
  db$binary1[n] <- rbinom(1, 1, db$belief1[n])
  db$continuous[n] <- db$belief[n] * 9
  db$continuous1[n] <- db$belief1[n] * 9
  db$discrete[n] <- round(db$belief[n] * 9, 0)
  db$discrete1[n] <- round(db$belief1[n] * 9, 0)
}
9.10 Visualize
9.11 Build the Weighted Bayes Stan model (simple formula)
stan_WB_model <- "
data {
  int<lower=0> N;
  array[N] int y;
  array[N] real<lower=0, upper=1> Source1;
  array[N] real<lower=0, upper=1> Source2;
}
transformed data {
  array[N] real l_Source1;
  array[N] real l_Source2;
  l_Source1 = logit(Source1);
  l_Source2 = logit(Source2);
}
parameters {
  real bias;
  // meaningful weights are btw 0.5 and 1 (theory reasons)
  real<lower=0.5, upper=1> w1;
  real<lower=0.5, upper=1> w2;
}
transformed parameters {
  real<lower=0, upper=1> weight1;
  real<lower=0, upper=1> weight2;
  // weight parameters are rescaled to a 0-1 scale (0 -> no effect; 1 -> face value)
  weight1 = (w1 - 0.5) * 2;
  weight2 = (w2 - 0.5) * 2;
}
model {
  target += normal_lpdf(bias | 0, 1);
  target += beta_lpdf(weight1 | 1, 1);
  target += beta_lpdf(weight2 | 1, 1);
  for (n in 1:N)
    target += bernoulli_logit_lpmf(y[n] | bias + weight1 * l_Source1[n] + weight2 * l_Source2[n]);
}
generated quantities {
  array[N] real log_lik;
  real bias_prior;
  real w1_prior;
  real w2_prior;
  bias_prior = normal_rng(0, 1);
  // sample from the declared priors: beta(1,1) on the 0-1 scale, rescaled to 0.5-1
  w1_prior = 0.5 + beta_rng(1, 1) / 2;
  w2_prior = 0.5 + beta_rng(1, 1) / 2;
  for (n in 1:N)
    log_lik[n] = bernoulli_logit_lpmf(y[n] | bias + weight1 * l_Source1[n] + weight2 * l_Source2[n]);
}
"
write_stan_file(
stan_WB_model,
dir = "stan/",
basename = "W9_WB.stan")
## [1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/W9_WB.stan"
file <- file.path("stan/W9_WB.stan")
mod_wb <- cmdstan_model(file, cpp_options = list(stan_threads = TRUE),
stanc_options = list("O1"))
## Subset the simulated data to one combination of true weights
db1 <- db %>% subset(w1 == 0.7 & w2 == 0.9)
p3 <- ggplot(db1, aes(Source1, belief, color = Source2, group = Source2)) +
geom_line() +
theme_bw()
p3
data_weightedBayes <- list(
N = nrow(db1),
y = db1$binary,
Source1 = db1$Source1,
Source2 = db1$Source2
)
samples_weighted <- mod_wb$sample(
data = data_weightedBayes,
seed = 123,
chains = 2,
parallel_chains = 2,
threads_per_chain = 2,
iter_warmup = 1500,
iter_sampling = 3000,
refresh = 500
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
##
## Chain 1 Iteration: 1 / 4500 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 4500 [ 0%] (Warmup)
## Chain 1 Iteration: 500 / 4500 [ 11%] (Warmup)
## Chain 2 Iteration: 500 / 4500 [ 11%] (Warmup)
## Chain 1 Iteration: 1000 / 4500 [ 22%] (Warmup)
## Chain 2 Iteration: 1000 / 4500 [ 22%] (Warmup)
## Chain 1 Iteration: 1500 / 4500 [ 33%] (Warmup)
## Chain 1 Iteration: 1501 / 4500 [ 33%] (Sampling)
## Chain 2 Iteration: 1500 / 4500 [ 33%] (Warmup)
## Chain 2 Iteration: 1501 / 4500 [ 33%] (Sampling)
## Chain 1 Iteration: 2000 / 4500 [ 44%] (Sampling)
## Chain 2 Iteration: 2000 / 4500 [ 44%] (Sampling)
## Chain 1 Iteration: 2500 / 4500 [ 55%] (Sampling)
## Chain 2 Iteration: 2500 / 4500 [ 55%] (Sampling)
## Chain 1 Iteration: 3000 / 4500 [ 66%] (Sampling)
## Chain 2 Iteration: 3000 / 4500 [ 66%] (Sampling)
## Chain 1 Iteration: 3500 / 4500 [ 77%] (Sampling)
## Chain 2 Iteration: 3500 / 4500 [ 77%] (Sampling)
## Chain 1 Iteration: 4000 / 4500 [ 88%] (Sampling)
## Chain 2 Iteration: 4000 / 4500 [ 88%] (Sampling)
## Chain 1 Iteration: 4500 / 4500 [100%] (Sampling)
## Chain 1 finished in 3.6 seconds.
## Chain 2 Iteration: 4500 / 4500 [100%] (Sampling)
## Chain 2 finished in 3.7 seconds.
##
## Both chains finished successfully.
## Mean chain execution time: 3.7 seconds.
## Total execution time: 4.1 seconds.
9.12 Model evaluation
## Processing csv files: /var/folders/lt/zspkqnxd5yg92kybm5f433_cfjr0d6/T/RtmpP3dlfk/W9_WB-202402150650-1-470abd.csv, /var/folders/lt/zspkqnxd5yg92kybm5f433_cfjr0d6/T/RtmpP3dlfk/W9_WB-202402150650-2-470abd.csv
##
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
##
## Checking sampler transitions for divergences.
## No divergent transitions found.
##
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
##
## Effective sample size satisfactory.
##
## Split R-hat values satisfactory all parameters.
##
## Processing complete, no problems detected.
## # A tibble: 819 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -4.71e+2 -4.71e+2 1.36 1.10 -474. -4.70e+2 1.00 2628. 3124.
## 2 bias -6.41e-3 -6.56e-3 0.0794 0.0785 -0.135 1.25e-1 1.00 5023. 3770.
## 3 w1 7.03e-1 7.03e-1 0.0331 0.0331 0.649 7.57e-1 1.00 4493. 3661.
## 4 w2 8.95e-1 8.94e-1 0.0353 0.0364 0.838 9.54e-1 1.00 4204. 2727.
## 5 weight1 4.06e-1 4.05e-1 0.0662 0.0662 0.298 5.14e-1 1.00 4493. 3661.
## 6 weight2 7.89e-1 7.87e-1 0.0705 0.0727 0.675 9.07e-1 1.00 4204. 2727.
## 7 log_lik[1] -7.14e-2 -6.99e-2 0.0169 0.0168 -0.101 -4.66e-2 1.00 3913. 3436.
## 8 log_lik[2] -7.14e-2 -6.99e-2 0.0169 0.0168 -0.101 -4.66e-2 1.00 3913. 3436.
## 9 log_lik[3] -7.14e-2 -6.99e-2 0.0169 0.0168 -0.101 -4.66e-2 1.00 3913. 3436.
## 10 log_lik[4] -2.70e+0 -2.70e+0 0.229 0.233 -3.09 -2.34e+0 1.00 3913. 3436.
## # ℹ 809 more rows
##
## Computed from 6000 by 810 log-likelihood matrix
##
## Estimate SE
## elpd_loo -467.0 12.8
## p_loo 3.1 0.2
## looic 933.9 25.5
## ------
## Monte Carlo SE of elpd_loo is 0.0.
##
## All Pareto k estimates are good (k < 0.5).
## See help('pareto-k-diagnostic') for details.
draws_df <- as_draws_df(samples_weighted$draws())
ggplot(draws_df, aes(.iteration, bias, group = .chain, color = .chain)) +
geom_line(alpha = 0.5) +
theme_classic()
ggplot(draws_df, aes(.iteration, w1, group = .chain, color = .chain)) +
geom_line(alpha = 0.5) +
theme_classic()
ggplot(draws_df, aes(.iteration, w2, group = .chain, color = .chain)) +
geom_line(alpha = 0.5) +
theme_classic()
p1 <- ggplot(draws_df) +
geom_density(aes(bias), alpha = 0.6, fill = "lightblue") +
geom_density(aes(bias_prior), alpha = 0.6, fill = "pink") +
geom_vline(xintercept = db1$bias[1]) +
theme_bw()
p2 <- ggplot(draws_df) +
geom_density(aes(w1), alpha = 0.6, fill = "lightblue") +
geom_density(aes(w1_prior), alpha = 0.6, fill = "pink") +
geom_vline(xintercept = db1$w1[1]) +
theme_bw()
p3 <- ggplot(draws_df) +
geom_density(aes(w2), alpha = 0.6, fill = "lightblue") +
geom_density(aes(w2_prior), alpha = 0.6, fill = "pink") +
geom_vline(xintercept = db1$w2[1]) +
theme_bw()
p1 + p2 + p3
9.13 Model recovery
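In model recovery we ask whether the models can be told apart at all: we simulate data from each model (here with true weights of 0.7 and 0.9 for the weighted model) and fit both models to both datasets, expecting model comparison to favor the generating model in each case.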
bias <- 0
trials <- seq(10)
Source1 <- seq(0.1,0.9, 0.1)
Source2 <- seq(0.1,0.9, 0.1)
w1 <- 0.7
w2 <- 0.9
db <- expand.grid(bias = bias, trials = trials, Source1 = Source1, Source2 = Source2, w1 = w1, w2 = w2)
for (n in seq(nrow(db))) {
  db$simple_belief[n] <- SimpleBayes_f(db$bias[n], db$Source1[n], db$Source2[n])
  db$weighted_belief[n] <- WeightedBayes_f(db$bias[n], db$Source1[n], db$Source2[n], db$w1[n], db$w2[n])
  db$simple_binary[n] <- rbinom(1, 1, db$simple_belief[n])
  db$weighted_binary[n] <- rbinom(1, 1, db$weighted_belief[n])
}
data_SB <- list(
N = nrow(db),
y = db$simple_binary,
Source1 = db$Source1,
Source2 = db$Source2
)
data_WB <- list(
N = nrow(db),
y = db$weighted_binary,
Source1 = db$Source1,
Source2 = db$Source2
)
## On the simple data
simple2simple <- mod_simpleBayes$sample(
data = data_SB,
seed = 123,
chains = 2,
parallel_chains = 2,
threads_per_chain = 2,
iter_warmup = 1500,
iter_sampling = 3000,
refresh = 500
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
##
## Chain 1 Iteration: 1 / 4500 [ 0%] (Warmup)
## Chain 1 Iteration: 500 / 4500 [ 11%] (Warmup)
## Chain 1 Iteration: 1000 / 4500 [ 22%] (Warmup)
## Chain 1 Iteration: 1500 / 4500 [ 33%] (Warmup)
## Chain 1 Iteration: 1501 / 4500 [ 33%] (Sampling)
## Chain 1 Iteration: 2000 / 4500 [ 44%] (Sampling)
## Chain 2 Iteration: 1 / 4500 [ 0%] (Warmup)
## Chain 2 Iteration: 500 / 4500 [ 11%] (Warmup)
## Chain 2 Iteration: 1000 / 4500 [ 22%] (Warmup)
## Chain 2 Iteration: 1500 / 4500 [ 33%] (Warmup)
## Chain 2 Iteration: 1501 / 4500 [ 33%] (Sampling)
## Chain 2 Iteration: 2000 / 4500 [ 44%] (Sampling)
## Chain 1 Iteration: 2500 / 4500 [ 55%] (Sampling)
## Chain 1 Iteration: 3000 / 4500 [ 66%] (Sampling)
## Chain 2 Iteration: 2500 / 4500 [ 55%] (Sampling)
## Chain 2 Iteration: 3000 / 4500 [ 66%] (Sampling)
## Chain 1 Iteration: 3500 / 4500 [ 77%] (Sampling)
## Chain 1 Iteration: 4000 / 4500 [ 88%] (Sampling)
## Chain 2 Iteration: 3500 / 4500 [ 77%] (Sampling)
## Chain 1 Iteration: 4500 / 4500 [100%] (Sampling)
## Chain 2 Iteration: 4000 / 4500 [ 88%] (Sampling)
## Chain 1 finished in 1.0 seconds.
## Chain 2 Iteration: 4500 / 4500 [100%] (Sampling)
## Chain 2 finished in 1.1 seconds.
##
## Both chains finished successfully.
## Mean chain execution time: 1.1 seconds.
## Total execution time: 1.4 seconds.
weighted2simple <- mod_wb$sample(
data = data_SB,
seed = 123,
chains = 2,
parallel_chains = 2,
threads_per_chain = 2,
iter_warmup = 1500,
iter_sampling = 3000,
refresh = 500
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
##
## Chain 1 Iteration: 1 / 4500 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 4500 [ 0%] (Warmup)
## Chain 1 Iteration: 500 / 4500 [ 11%] (Warmup)
## Chain 2 Iteration: 500 / 4500 [ 11%] (Warmup)
## Chain 1 Iteration: 1000 / 4500 [ 22%] (Warmup)
## Chain 2 Iteration: 1000 / 4500 [ 22%] (Warmup)
## Chain 1 Iteration: 1500 / 4500 [ 33%] (Warmup)
## Chain 1 Iteration: 1501 / 4500 [ 33%] (Sampling)
## Chain 2 Iteration: 1500 / 4500 [ 33%] (Warmup)
## Chain 2 Iteration: 1501 / 4500 [ 33%] (Sampling)
## Chain 1 Iteration: 2000 / 4500 [ 44%] (Sampling)
## Chain 2 Iteration: 2000 / 4500 [ 44%] (Sampling)
## Chain 1 Iteration: 2500 / 4500 [ 55%] (Sampling)
## Chain 2 Iteration: 2500 / 4500 [ 55%] (Sampling)
## Chain 2 Iteration: 3000 / 4500 [ 66%] (Sampling)
## Chain 1 Iteration: 3000 / 4500 [ 66%] (Sampling)
## Chain 2 Iteration: 3500 / 4500 [ 77%] (Sampling)
## Chain 1 Iteration: 3500 / 4500 [ 77%] (Sampling)
## Chain 2 Iteration: 4000 / 4500 [ 88%] (Sampling)
## Chain 1 Iteration: 4000 / 4500 [ 88%] (Sampling)
## Chain 2 Iteration: 4500 / 4500 [100%] (Sampling)
## Chain 2 finished in 4.5 seconds.
## Chain 1 Iteration: 4500 / 4500 [100%] (Sampling)
## Chain 1 finished in 4.8 seconds.
##
## Both chains finished successfully.
## Mean chain execution time: 4.7 seconds.
## Total execution time: 5.0 seconds.
simple2weighted <- mod_simpleBayes$sample(
data = data_WB,
seed = 123,
chains = 2,
parallel_chains = 2,
threads_per_chain = 2,
iter_warmup = 1500,
iter_sampling = 3000,
refresh = 500
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
##
## Chain 1 Iteration: 1 / 4500 [ 0%] (Warmup)
## Chain 1 Iteration: 500 / 4500 [ 11%] (Warmup)
## Chain 1 Iteration: 1000 / 4500 [ 22%] (Warmup)
## Chain 1 Iteration: 1500 / 4500 [ 33%] (Warmup)
## Chain 1 Iteration: 1501 / 4500 [ 33%] (Sampling)
## Chain 2 Iteration: 1 / 4500 [ 0%] (Warmup)
## Chain 2 Iteration: 500 / 4500 [ 11%] (Warmup)
## Chain 2 Iteration: 1000 / 4500 [ 22%] (Warmup)
## Chain 2 Iteration: 1500 / 4500 [ 33%] (Warmup)
## Chain 2 Iteration: 1501 / 4500 [ 33%] (Sampling)
## Chain 1 Iteration: 2000 / 4500 [ 44%] (Sampling)
## Chain 1 Iteration: 2500 / 4500 [ 55%] (Sampling)
## Chain 2 Iteration: 2000 / 4500 [ 44%] (Sampling)
## Chain 2 Iteration: 2500 / 4500 [ 55%] (Sampling)
## Chain 1 Iteration: 3000 / 4500 [ 66%] (Sampling)
## Chain 1 Iteration: 3500 / 4500 [ 77%] (Sampling)
## Chain 2 Iteration: 3000 / 4500 [ 66%] (Sampling)
## Chain 2 Iteration: 3500 / 4500 [ 77%] (Sampling)
## Chain 1 Iteration: 4000 / 4500 [ 88%] (Sampling)
## Chain 1 Iteration: 4500 / 4500 [100%] (Sampling)
## Chain 2 Iteration: 4000 / 4500 [ 88%] (Sampling)
## Chain 1 finished in 1.0 seconds.
## Chain 2 Iteration: 4500 / 4500 [100%] (Sampling)
## Chain 2 finished in 1.0 seconds.
##
## Both chains finished successfully.
## Mean chain execution time: 1.0 seconds.
## Total execution time: 1.3 seconds.
weighted2weighted <- mod_wb$sample(
data = data_WB,
seed = 123,
chains = 2,
parallel_chains = 2,
threads_per_chain = 2,
iter_warmup = 1500,
iter_sampling = 3000,
refresh = 500
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
##
## Chain 1 Iteration: 1 / 4500 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 4500 [ 0%] (Warmup)
## Chain 1 Iteration: 500 / 4500 [ 11%] (Warmup)
## Chain 2 Iteration: 500 / 4500 [ 11%] (Warmup)
## Chain 1 Iteration: 1000 / 4500 [ 22%] (Warmup)
## Chain 2 Iteration: 1000 / 4500 [ 22%] (Warmup)
## Chain 1 Iteration: 1500 / 4500 [ 33%] (Warmup)
## Chain 1 Iteration: 1501 / 4500 [ 33%] (Sampling)
## Chain 2 Iteration: 1500 / 4500 [ 33%] (Warmup)
## Chain 2 Iteration: 1501 / 4500 [ 33%] (Sampling)
## Chain 1 Iteration: 2000 / 4500 [ 44%] (Sampling)
## Chain 2 Iteration: 2000 / 4500 [ 44%] (Sampling)
## Chain 1 Iteration: 2500 / 4500 [ 55%] (Sampling)
## Chain 2 Iteration: 2500 / 4500 [ 55%] (Sampling)
## Chain 1 Iteration: 3000 / 4500 [ 66%] (Sampling)
## Chain 2 Iteration: 3000 / 4500 [ 66%] (Sampling)
## Chain 1 Iteration: 3500 / 4500 [ 77%] (Sampling)
## Chain 2 Iteration: 3500 / 4500 [ 77%] (Sampling)
## Chain 1 Iteration: 4000 / 4500 [ 88%] (Sampling)
## Chain 2 Iteration: 4000 / 4500 [ 88%] (Sampling)
## Chain 1 Iteration: 4500 / 4500 [100%] (Sampling)
## Chain 1 finished in 3.6 seconds.
## Chain 2 Iteration: 4500 / 4500 [100%] (Sampling)
## Chain 2 finished in 3.7 seconds.
##
## Both chains finished successfully.
## Mean chain execution time: 3.7 seconds.
## Total execution time: 4.0 seconds.
Loo_simple2simple <- simple2simple$loo(save_psis = TRUE, cores = 4)
p1 <- plot(Loo_simple2simple)
Loo_weighted2simple <- weighted2simple$loo(save_psis = TRUE, cores = 4)
p2 <- plot(Loo_weighted2simple)
Loo_simple2weighted <- simple2weighted$loo(save_psis = TRUE, cores = 4)
p3 <- plot(Loo_simple2weighted)
Loo_weighted2weighted <- weighted2weighted$loo(save_psis = TRUE, cores = 4)
p4 <- plot(Loo_weighted2weighted)
elpd <- tibble(
n = seq(810),
simple_diff_elpd =
Loo_simple2simple$pointwise[, "elpd_loo"] -
Loo_weighted2simple$pointwise[, "elpd_loo"],
weighted_diff_elpd =
Loo_weighted2weighted$pointwise[, "elpd_loo"] -
Loo_simple2weighted$pointwise[, "elpd_loo"])
p1 <- ggplot(elpd, aes(x = n, y = simple_diff_elpd)) +
  geom_point(alpha = .1) +
  geom_hline(yintercept = 0, color = "red", linetype = "dashed") +
  theme_bw()
p2 <- ggplot(elpd, aes(x = n, y = weighted_diff_elpd)) +
  geom_point(alpha = .1) +
  geom_hline(yintercept = 0, color = "red", linetype = "dashed") +
  theme_bw()
library(patchwork)
p1 + p2
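The comparisons printed below come from chunks that are not echoed in the text; a sketch of the loo calls that would produce them (the model ordering is an assumption):
loo_compare(Loo_simple2simple, Loo_weighted2simple)
loo_compare(Loo_weighted2weighted, Loo_simple2weighted)
loo_model_weights(list(Loo_simple2simple, Loo_weighted2simple))
loo_model_weights(list(Loo_weighted2weighted, Loo_simple2weighted))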
## elpd_diff se_diff
## model1 0.0 0.0
## model2 -1.3 1.0
## elpd_diff se_diff
## model1 0.0 0.0
## model2 -36.9 9.4
## Method: stacking
## ------
## weight
## model1 1.000
## model2 0.000
## Method: stacking
## ------
## weight
## model1 1.000
## model2 0.000
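Assuming model1 is the generating model in each comparison: on data generated by the simple model the elpd difference (-1.3, SE 1.0) is far too small to distinguish the models, which makes sense since the weighted model contains the simple model as a special case (all weights at face value). On data generated by the weighted model, the weighted model is clearly favored (-36.9, SE 9.4), and stacking accordingly puts all the weight on the generating model in both cases.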
9.15 Bonus: Build the Weighted Bayes Stan model (fancier formula)
stan_WB1_model <- "
functions {
  real weight_f(real L_raw, real w_raw) {
    real L;
    real w;
    L = exp(L_raw);
    w = 0.5 + inv_logit(w_raw) / 2;
    return log((w * L + 1 - w) / ((1 - w) * L + w));
  }
}
data {
  int<lower=0> N;
  array[N] int y;
  vector[N] Source1;
  vector[N] Source2;
}
parameters {
  real bias;
  real weight1;
  real weight2;
}
model {
  target += normal_lpdf(bias | 0, 1);
  target += normal_lpdf(weight1 | 0, 1.5);
  target += normal_lpdf(weight2 | 0, 1.5);
  for (n in 1:N) {
    target += bernoulli_logit_lpmf(y[n] | bias + weight_f(Source1[n], weight1) +
                                          weight_f(Source2[n], weight2));
  }
}
generated quantities {
  array[N] real log_lik;
  real bias_prior;
  real w1_prior;
  real w2_prior;
  real w1;
  real w2;
  bias_prior = normal_rng(0, 1);
  // prior samples on the 0.5-1 scale, matching the normal(0, 1.5) priors on the raw weights
  w1_prior = 0.5 + inv_logit(normal_rng(0, 1.5)) / 2;
  w2_prior = 0.5 + inv_logit(normal_rng(0, 1.5)) / 2;
  w1 = 0.5 + inv_logit(weight1) / 2;
  w2 = 0.5 + inv_logit(weight2) / 2;
  for (n in 1:N) {
    log_lik[n] = bernoulli_logit_lpmf(y[n] | bias + weight_f(Source1[n], weight1) +
                                             weight_f(Source2[n], weight2));
  }
}
"
write_stan_file(
stan_WB1_model,
dir = "stan/",
basename = "W9_WB1.stan")
## [1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/W9_WB1.stan"
file <- file.path("stan/W9_WB1.stan")
mod_wb1 <- cmdstan_model(file, cpp_options = list(stan_threads = TRUE),
stanc_options = list("O1"))
## NB: this subset assumes db is still the full grid simulated in 9.9 (with the
## belief1 and binary1 columns). The model-recovery chunk above overwrote db with
## a grid where w2 only takes the value 0.9, so as rendered this subset is empty:
## note the 0.0-second chains below and the missing true-value lines in the plots.
## Re-run the 9.9 simulation before this step.
db1 <- db %>% subset(w1 == 0.7 & w2 == 1)
data_weightedBayes1 <- list(
N = nrow(db1),
y = db1$binary1,
Source1 = logit_scaled(db1$Source1),
Source2 = logit_scaled(db1$Source2)
)
samples_weighted1 <- mod_wb1$sample(
data = data_weightedBayes1,
seed = 123,
chains = 2,
parallel_chains = 2,
threads_per_chain = 2,
iter_warmup = 1500,
iter_sampling = 3000,
refresh = 500
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
##
## Chain 1 Iteration: 1 / 4500 [ 0%] (Warmup)
## Chain 1 Iteration: 500 / 4500 [ 11%] (Warmup)
## Chain 1 Iteration: 1000 / 4500 [ 22%] (Warmup)
## Chain 1 Iteration: 1500 / 4500 [ 33%] (Warmup)
## Chain 1 Iteration: 1501 / 4500 [ 33%] (Sampling)
## Chain 1 Iteration: 2000 / 4500 [ 44%] (Sampling)
## Chain 1 Iteration: 2500 / 4500 [ 55%] (Sampling)
## Chain 1 Iteration: 3000 / 4500 [ 66%] (Sampling)
## Chain 1 Iteration: 3500 / 4500 [ 77%] (Sampling)
## Chain 1 Iteration: 4000 / 4500 [ 88%] (Sampling)
## Chain 1 Iteration: 4500 / 4500 [100%] (Sampling)
## Chain 2 Iteration: 1 / 4500 [ 0%] (Warmup)
## Chain 2 Iteration: 500 / 4500 [ 11%] (Warmup)
## Chain 2 Iteration: 1000 / 4500 [ 22%] (Warmup)
## Chain 2 Iteration: 1500 / 4500 [ 33%] (Warmup)
## Chain 2 Iteration: 1501 / 4500 [ 33%] (Sampling)
## Chain 2 Iteration: 2000 / 4500 [ 44%] (Sampling)
## Chain 2 Iteration: 2500 / 4500 [ 55%] (Sampling)
## Chain 2 Iteration: 3000 / 4500 [ 66%] (Sampling)
## Chain 2 Iteration: 3500 / 4500 [ 77%] (Sampling)
## Chain 2 Iteration: 4000 / 4500 [ 88%] (Sampling)
## Chain 2 Iteration: 4500 / 4500 [100%] (Sampling)
## Chain 1 finished in 0.0 seconds.
## Chain 2 finished in 0.0 seconds.
##
## Both chains finished successfully.
## Mean chain execution time: 0.0 seconds.
## Total execution time: 0.4 seconds.
9.16 Evaluation
draws_df <- as_draws_df(samples_weighted1$draws())
ggplot(draws_df, aes(.iteration, bias, group = .chain, color = .chain)) +
geom_line(alpha = 0.5) +
theme_classic()
ggplot(draws_df, aes(.iteration, w1, group = .chain, color = .chain)) +
geom_line(alpha = 0.5) +
theme_classic()
ggplot(draws_df, aes(.iteration, w2, group = .chain, color = .chain)) +
geom_line(alpha = 0.5) +
theme_classic()
ggplot(draws_df) +
geom_histogram(aes(bias), alpha = 0.6, fill = "lightblue") +
geom_histogram(aes(bias_prior), alpha = 0.6, fill = "pink") +
geom_vline(xintercept = db1$bias[1]) +
theme_bw()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 1 rows containing missing values (`geom_vline()`).
ggplot(draws_df) +
geom_histogram(aes(w1), alpha = 0.6, fill = "lightblue") +
geom_histogram(aes(w1_prior), alpha = 0.6, fill = "pink") +
geom_vline(xintercept = db1$w1[1]) +
theme_bw()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 1 rows containing missing values (`geom_vline()`).
ggplot(draws_df) +
geom_density(aes(w2), alpha = 0.6, fill = "lightblue") +
geom_density(aes(w2_prior), alpha = 0.6, fill = "pink") +
geom_vline(xintercept = db1$w2[1]) +
theme_bw()
## Warning: Removed 1 rows containing missing values (`geom_vline()`).
9.17 Build a temporal Bayes (storing current belief as prior for next turn)
WeightedTimeBayes_f <- function(bias, Source1, Source2, w1, w2){
  # same formula as WeightedBayes_f; the temporal element comes from the
  # simulation loop below, where Source2 carries the previous trial's belief
  w1 <- (w1 - 0.5) * 2
  w2 <- (w2 - 0.5) * 2
  outcome <- inv_logit_scaled(bias + w1 * logit_scaled(Source1) + w2 * logit_scaled(Source2))
  return(outcome)
}
bias <- 0
trials <- seq(10)
Source1 <- seq(0.1,0.9, 0.1)
w1 <- seq(0.5, 1, 0.1)
w2 <- seq(0.5, 1, 0.1)
db <- expand.grid(bias = bias, trials = trials, Source1 = Source1, w1 = w1, w2 = w2) %>%
  mutate(Source2 = NA, belief = NA, binary = NA)
db$Source2[1] <- 0.5  # the first trial starts from an uninformative belief
for (n in seq(nrow(db))) {
  db$belief[n] <- WeightedTimeBayes_f(db$bias[n], db$Source1[n], db$Source2[n], db$w1[n], db$w2[n])
  db$binary[n] <- rbinom(1, 1, db$belief[n])
  # the current belief becomes the second source on the next trial
  if (n < nrow(db)) {db$Source2[n + 1] <- db$belief[n]}
}
stan_TB_model <- "
data {
  int<lower=0> N;
  array[N] int y;
  array[N] real<lower=0, upper=1> Source1;
}
transformed data {
  array[N] real l_Source1;
  l_Source1 = logit(Source1);
}
parameters {
  real bias;
  // meaningful weights are btw 0.5 and 1 (theory reasons)
  real<lower=0.5, upper=1> w1;
  real<lower=0.5, upper=1> w2;
}
transformed parameters {
  real<lower=0, upper=1> weight1;
  real<lower=0, upper=1> weight2;
  array[N] real l_Source2;
  // weight parameters are rescaled to a 0-1 scale (0 -> no effect; 1 -> face value)
  weight1 = (w1 - 0.5) * 2;
  weight2 = (w2 - 0.5) * 2;
  // the running belief on log-odds; the first entry is the initial 0.5 belief (0 on log-odds)
  l_Source2[1] = 0;
  for (n in 2:N) {
    l_Source2[n] = bias + weight1 * l_Source1[n] + weight2 * l_Source2[n-1];
  }
}
model {
  target += normal_lpdf(bias | 0, 1);
  target += beta_lpdf(weight1 | 1, 1);
  target += beta_lpdf(weight2 | 1, 1);
  target += bernoulli_logit_lpmf(y[1] | bias + weight1 * l_Source1[1]);
  for (n in 2:N) {
    target += bernoulli_logit_lpmf(y[n] | l_Source2[n]);
  }
}
generated quantities {
  array[N] real log_lik;
  real bias_prior;
  real w1_prior;
  real w2_prior;
  bias_prior = normal_rng(0, 1);
  // sample from the declared priors: beta(1,1) on the 0-1 scale, rescaled to 0.5-1
  w1_prior = 0.5 + beta_rng(1, 1) / 2;
  w2_prior = 0.5 + beta_rng(1, 1) / 2;
  // match the model block: trial 1 is predicted from bias and Source1 alone
  log_lik[1] = bernoulli_logit_lpmf(y[1] | bias + weight1 * l_Source1[1]);
  for (n in 2:N) {
    log_lik[n] = bernoulli_logit_lpmf(y[n] | l_Source2[n]);
  }
}
"
write_stan_file(
stan_TB_model,
dir = "stan/",
basename = "W9_TB.stan")
## [1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/W9_TB.stan"
file <- file.path("stan/W9_TB.stan")
mod_tb <- cmdstan_model(file, cpp_options = list(stan_threads = TRUE),
stanc_options = list("O1"))
db1 <- db %>% subset(w1 == 0.7 & w2 == 0.9)
data_TB <- list(
N = nrow(db1),
y = db1$binary,
Source1 = db1$Source1
)
samples_TB <- mod_tb$sample(
data = data_TB,
seed = 123,
chains = 2,
parallel_chains = 2,
threads_per_chain = 2,
iter_warmup = 2000,
iter_sampling = 3000,
refresh = 500,
adapt_delta = 0.99,
max_treedepth = 20
)
## Running MCMC with 2 parallel chains, with 2 thread(s) per chain...
##
## Chain 1 Iteration: 1 / 5000 [ 0%] (Warmup)
## Chain 1 Iteration: 500 / 5000 [ 10%] (Warmup)
## Chain 2 Iteration: 1 / 5000 [ 0%] (Warmup)
## Chain 2 Iteration: 500 / 5000 [ 10%] (Warmup)
## Chain 1 Iteration: 1000 / 5000 [ 20%] (Warmup)
## Chain 1 Iteration: 1500 / 5000 [ 30%] (Warmup)
## Chain 2 Iteration: 1000 / 5000 [ 20%] (Warmup)
## Chain 2 Iteration: 1500 / 5000 [ 30%] (Warmup)
## Chain 1 Iteration: 2000 / 5000 [ 40%] (Warmup)
## Chain 1 Iteration: 2001 / 5000 [ 40%] (Sampling)
## Chain 2 Iteration: 2000 / 5000 [ 40%] (Warmup)
## Chain 2 Iteration: 2001 / 5000 [ 40%] (Sampling)
## Chain 1 Iteration: 2500 / 5000 [ 50%] (Sampling)
## Chain 1 Iteration: 3000 / 5000 [ 60%] (Sampling)
## Chain 1 Iteration: 3500 / 5000 [ 70%] (Sampling)
## Chain 2 Iteration: 2500 / 5000 [ 50%] (Sampling)
## Chain 1 Iteration: 4000 / 5000 [ 80%] (Sampling)
## Chain 1 Iteration: 4500 / 5000 [ 90%] (Sampling)
## Chain 2 Iteration: 3000 / 5000 [ 60%] (Sampling)
## Chain 1 Iteration: 5000 / 5000 [100%] (Sampling)
## Chain 1 finished in 1.9 seconds.
## Chain 2 Iteration: 3500 / 5000 [ 70%] (Sampling)
## Chain 2 Iteration: 4000 / 5000 [ 80%] (Sampling)
## Chain 2 Iteration: 4500 / 5000 [ 90%] (Sampling)
## Chain 2 Iteration: 5000 / 5000 [100%] (Sampling)
## Chain 2 finished in 2.9 seconds.
##
## Both chains finished successfully.
## Mean chain execution time: 2.4 seconds.
## Total execution time: 3.2 seconds.
9.18 Evaluate
## Processing csv files: /var/folders/lt/zspkqnxd5yg92kybm5f433_cfjr0d6/T/RtmpP3dlfk/W9_TB-202402150651-1-26c994.csv, /var/folders/lt/zspkqnxd5yg92kybm5f433_cfjr0d6/T/RtmpP3dlfk/W9_TB-202402150651-2-26c994.csv
##
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
##
## Checking sampler transitions for divergences.
## No divergent transitions found.
##
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
##
## Effective sample size satisfactory.
##
## Split R-hat values satisfactory all parameters.
##
## Processing complete, no problems detected.
## # A tibble: 189 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -50.5 -50.0 1.77 1.53 -54.0 -48.5 1.00 976. 1188.
## 2 bias -0.0234 -0.00914 0.117 0.0931 -0.234 0.141 1.00 1880. 1819.
## 3 w1 0.756 0.744 0.111 0.128 0.595 0.956 1.00 754. 1188.
## 4 w2 0.809 0.825 0.101 0.108 0.620 0.944 1.00 792. 1150.
## 5 weight1 0.511 0.488 0.222 0.256 0.189 0.912 1.00 754. 1188.
## 6 weight2 0.617 0.649 0.202 0.215 0.240 0.888 1.00 792. 1150.
## 7 l_Source2[1] 0 0 0 0 0 0 NA NA NA
## 8 l_Source2[2] -1.15 -1.09 0.527 0.596 -2.09 -0.387 1.00 855. 1362.
## 9 l_Source2[3] -1.76 -1.76 0.654 0.722 -2.86 -0.726 1.00 956. 1412.
## 10 l_Source2[4] -2.14 -2.15 0.685 0.721 -3.29 -1.02 1.00 1083. 1496.
## # ℹ 179 more rows
draws_df <- as_draws_df(samples_TB$draws())
p1 <- ggplot(draws_df, aes(.iteration, bias, group = .chain, color = .chain)) +
geom_line(alpha = 0.5) +
theme_classic()
p2 <- ggplot(draws_df, aes(.iteration, w1, group = .chain, color = .chain)) +
geom_line(alpha = 0.5) +
theme_classic()
p3 <- ggplot(draws_df, aes(.iteration, w2, group = .chain, color = .chain)) +
geom_line(alpha = 0.5) +
theme_classic()
p1 + p2 + p3
p1 <- ggplot(draws_df) +
geom_density(aes(bias), alpha = 0.6, fill = "lightblue") +
geom_density(aes(bias_prior), alpha = 0.6, fill = "pink") +
geom_vline(xintercept = db1$bias[1]) +
theme_bw()
p2 <- ggplot(draws_df) +
geom_density(aes(w1), alpha = 0.6, fill = "lightblue") +
geom_density(aes(w1_prior), alpha = 0.6, fill = "pink") +
geom_vline(xintercept = db1$w1[1]) +
theme_bw()
p3 <- ggplot(draws_df) +
geom_density(aes(w2), alpha = 0.6, fill = "lightblue") +
geom_density(aes(w2_prior), alpha = 0.6, fill = "pink") +
geom_vline(xintercept = db1$w2[1]) +
theme_bw()
p1 + p2 + p3
p1 <- ggplot(draws_df) +
geom_point(aes(w1, w2), alpha = 0.3) +
theme_bw()
p2 <- ggplot(draws_df) +
geom_point(aes(bias, w1), alpha = 0.3) +
theme_bw()
p3 <- ggplot(draws_df) +
geom_point(aes(bias, w2), alpha = 0.3) +
theme_bw()
p1 + p2 + p3
The model seems good at recovering the generating parameters. However, the pairwise plots above show high correlations between the weights as well as funnel-like shapes, so we would probably be safer reparameterizing the model.
9.19 Multilevel version of the simple Bayes model
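This version lets the bias vary by participant: we estimate a population-level mean (biasM) and standard deviation (biasSD) and sample the participant-level deviations in a non-centered parameterization (standardized z_bias values scaled by biasSD), which avoids the funnel geometries noted above.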
stan_simpleBayes_ml_model <- "
data {
  int<lower=0> N; // n of trials
  int<lower=0> S; // n of participants
  array[N, S] int y;
  array[N, S] real<lower=0, upper=1> Source1;
  array[N, S] real<lower=0, upper=1> Source2;
}
transformed data {
  array[N, S] real l_Source1;
  array[N, S] real l_Source2;
  l_Source1 = logit(Source1);
  l_Source2 = logit(Source2);
}
parameters {
  real biasM;
  real<lower=0> biasSD; // the sd must be constrained positive (half-normal prior below)
  array[S] real z_bias;
}
transformed parameters {
  vector[S] biasC;
  vector[S] bias;
  // non-centered parameterization: scale standardized deviations by the sd
  biasC = biasSD * to_vector(z_bias);
  bias = biasM + biasC;
}
model {
  target += normal_lpdf(biasM | 0, 1);
  target += normal_lpdf(biasSD | 0, 1) -
    normal_lccdf(0 | 0, 1);
  target += std_normal_lpdf(to_vector(z_bias));
  for (s in 1:S) {
    target += bernoulli_logit_lpmf(y[, s] | bias[s] +
                                            to_vector(l_Source1[, s]) +
                                            to_vector(l_Source2[, s]));
  }
}
"
write_stan_file(
stan_simpleBayes_ml_model,
dir = "stan/",
basename = "W9_SimpleBayes_ml.stan")
## [1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/W9_SimpleBayes_ml.stan"
9.20 Multilevel version of the weighted Bayes model
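Here the bias and both weights vary by participant, with the three correlated participant-level deviations modelled through a Cholesky-factored correlation matrix (again non-centered: z_IDs is scaled by the standard deviations tau and the Cholesky factor L_u).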
stan_WB_ml_model <- "
data {
  int<lower=0> N; // n of trials
  int<lower=0> S; // n of participants
  array[N, S] int y;
  array[N, S] real<lower=0, upper=1> Source1;
  array[N, S] real<lower=0, upper=1> Source2;
}
transformed data {
  array[N, S] real l_Source1;
  array[N, S] real l_Source2;
  l_Source1 = logit(Source1);
  l_Source2 = logit(Source2);
}
parameters {
  real biasM;
  real w1_M;
  real w2_M;
  vector<lower=0>[3] tau;
  matrix[3, S] z_IDs;
  cholesky_factor_corr[3] L_u;
}
transformed parameters {
  matrix[S, 3] IDs;
  IDs = (diag_pre_multiply(tau, L_u) * z_IDs)';
}
model {
  // population-level means
  target += normal_lpdf(biasM | 0, 1);
  target += normal_lpdf(w1_M | 0, 1);
  target += normal_lpdf(w2_M | 0, 1);
  // half-normal priors on the participant-level sds
  for (i in 1:3)
    target += normal_lpdf(tau[i] | 0, 1) -
      normal_lccdf(0 | 0, 1);
  target += lkj_corr_cholesky_lpdf(L_u | 3);
  target += std_normal_lpdf(to_vector(z_IDs));
  for (s in 1:S) {
    for (n in 1:N) {
      target += bernoulli_logit_lpmf(y[n, s] | biasM + IDs[s, 1] +
                                               (w1_M + IDs[s, 2]) * l_Source1[n, s] +
                                               (w2_M + IDs[s, 3]) * l_Source2[n, s]);
    }
  }
}
"
write_stan_file(
stan_WB_ml_model,
dir = "stan/",
basename = "W9_WB_ml.stan")
## [1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/W9_WB_ml.stan"
file <- file.path("stan/W9_WB_ml.stan")
mod_wb_ml <- cmdstan_model(file, cpp_options = list(stan_threads = TRUE),
                           stanc_options = list("O1"))
Fitting these multilevel models on real data and comparing them is still missing from this chapter; as a stopgap, the sketch below fits the multilevel weighted model to simulated data.
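Everything here is hypothetical helper code, not from the original chapter: the participant parameters, data shapes, and the mod_wb_ml name are assumptions.
S <- 10
grid <- expand.grid(Source1 = seq(0.1, 0.9, 0.1), Source2 = seq(0.1, 0.9, 0.1))
N <- nrow(grid)
y <- matrix(NA, nrow = N, ncol = S)
for (s in seq(S)) {
  # per-participant bias around 0; fixed weights (0.7, 0.9) for simplicity
  belief <- WeightedBayes_f(rnorm(1, 0, 0.1), grid$Source1, grid$Source2, 0.7, 0.9)
  y[, s] <- rbinom(N, 1, belief)
}
data_ml <- list(
  N = N, S = S, y = y,
  Source1 = matrix(grid$Source1, nrow = N, ncol = S),
  Source2 = matrix(grid$Source2, nrow = N, ncol = S)
)
samples_ml <- mod_wb_ml$sample(
  data = data_ml,
  seed = 123,
  chains = 2,
  parallel_chains = 2,
  threads_per_chain = 2,
  iter_warmup = 1500,
  iter_sampling = 3000,
  refresh = 500
)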