Chapter 10 Win-Stay-Lose-Shift: A Heuristic Decision Strategy

10.1 Introduction

Win-Stay-Lose-Shift (WSLS) represents one of the simplest yet most fascinating decision-making strategies observed in both human and animal behavior. The core principle is intuitive: if an action leads to success, repeat it; if it leads to failure, try something else. Despite its simplicity, this strategy can produce sophisticated behavioral patterns and proves surprisingly effective in many scenarios. In this chapter, we’ll explore WSLS through computational modeling, building on our previous work with random agents. We’ll see how this apparently simple strategy can capture important aspects of learning and adaptation. Through careful implementation and testing, we’ll develop insights into both the strengths and limitations of WSLS as a model of decision-making.

Our exploration will follow several key steps:

  • Implementing the basic WSLS strategy in code

  • Testing it against different opponents

  • Scaling up to multiple agents

  • Analyzing patterns in the resulting data

The WSLS strategy differs significantly from our previous random agent models. Rather than making choices based on fixed probabilities, a WSLS agent:

  • Remembers its previous choice

  • Tracks whether that choice was successful

  • Uses this information to determine its next move

This creates an interesting form of path dependence: the agent’s choices are shaped by its history of interactions. In the implementation, this means the outcome of each trial must be recorded on that trial and then applied to the choice on the following trial.
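As a minimal sketch of the rule itself, the deterministic core of WSLS fits in a single function. The name `wsls_next` and the 0/1 coding of choices and outcomes are illustrative assumptions; the probabilistic version used in this chapter is defined in the next section.

```r
# Deterministic WSLS: repeat the previous choice after a win, switch after a loss.
# Choices are coded 0/1; prev_feedback is 1 for a win and 0 for a loss.
wsls_next <- function(prev_choice, prev_feedback) {
  ifelse(prev_feedback == 1, prev_choice, 1 - prev_choice)
}

# All four combinations of previous choice and previous outcome
cases <- expand.grid(prev_choice = c(0, 1), prev_feedback = c(0, 1))
cases$next_choice <- wsls_next(cases$prev_choice, cases$prev_feedback)
print(cases)
```

The next choice depends only on the immediately preceding trial, which is why the bookkeeping of "previous choice" and "previous outcome" matters so much in the simulation code.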

10.1.1 Loading Required Packages

Let’s begin by loading the packages we’ll need for our analysis:

# Load necessary packages for simulation and analysis
pacman::p_load(tidyverse,   # For data manipulation and visualization
               here,        # For file path management
               posterior,   # For working with posterior samples
               cmdstanr,    # For interfacing with Stan
               brms,        # For Bayesian regression models (provides inv_logit_scaled)
               tidybayes,   # For working with Bayesian samples
               bayesplot,   # For MCMC diagnostic plots (mcmc_trace, mcmc_rank_hist)
               loo,         # For model comparison
               patchwork)   # For combining multiple plots

10.2 Generating Data: Simulating WSLS Agents

Now we’ll set up a simulation environment where WSLS agents interact with random agents. The parameters we define here will shape our simulation. We’ll create 100 agents who each play 120 trials, allowing us to observe both individual behavior and broader patterns across many interactions.

Our WSLS agent implementation uses a parameterized approach where:

  • alpha represents a baseline bias toward choosing one option over another

  • betaWin represents the strength of the “stay” response after a win

  • betaLose represents the strength of the “shift” response after a loss

  • noise allows for occasional random deviations from the strategy

This parameterization lets us explore variants of the WSLS strategy with different sensitivities to wins and losses.
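To make the parameterization concrete, here is a small sketch of how the three parameters combine into a choice probability. It assumes the signed coding used later in the chapter's code (`win` is +1 after winning with choice 1 and -1 after winning with choice 0; `lose` is +1 after losing with choice 0 and -1 after losing with choice 1). The helper name `p_choose_1` is hypothetical, and base R's `plogis()` stands in for `inv_logit_scaled()`, to which it is equivalent with default bounds.

```r
# Map the WSLS parameters and the previous outcome to P(choice = 1).
p_choose_1 <- function(alpha, betaWin, betaLose, win, lose) {
  plogis(alpha + betaWin * win + betaLose * lose)
}

# Illustrative values: the population means used in the simulation
alpha <- 0; betaWin <- 1.5; betaLose <- 1.5

p_after_win_1  <- p_choose_1(alpha, betaWin, betaLose, win =  1, lose =  0)  # stay on 1
p_after_win_0  <- p_choose_1(alpha, betaWin, betaLose, win = -1, lose =  0)  # stay on 0
p_after_lose_0 <- p_choose_1(alpha, betaWin, betaLose, win =  0, lose =  1)  # shift to 1
p_after_lose_1 <- p_choose_1(alpha, betaWin, betaLose, win =  0, lose = -1)  # shift to 0

round(c(p_after_win_1, p_after_win_0, p_after_lose_0, p_after_lose_1), 3)
```

With these values the agent follows the WSLS rule on roughly 82% of trials: a beta of 1.5 on the log-odds scale is a strong but not deterministic tendency.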

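# Define simulation parameters
agents <- 100    # Number of agents to simulate
trials <- 120    # Number of trials per agent

# Flag controlling whether cached simulations and model fits are recomputed.
# Assumed to be set earlier in the book; defined here only if it is missing.
if (!exists("regenerate_simulations")) regenerate_simulations <- FALSE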

# Noise parameter (probability of random choice regardless of strategy)
noise <- 0

# Parameters for random agent (on log-odds scale)
rateM <- 1.4     # Population mean bias for random agent
rateSD <- 0.3    # Population SD of bias

# Parameters for WSLS agent (on log-odds scale)
alphaM <- 0      # Population mean baseline bias
alphaSD <- 0.1   # Population SD of baseline bias

betaWinM <- 1.5  # Population mean win-stay parameter
betaWinSD <- 0.3 # Population SD of win-stay parameter

betaLoseM <- 1.5 # Population mean lose-shift parameter 
betaLoseSD <- 0.3 # Population SD of lose-shift parameter

Next, we’ll define functions to implement our agent strategies:

# Random agent function: makes choices based on bias parameter (on log-odds scale)
# with possibility of random noise
RandomAgentNoise_f <- function(rate, noise) {
  # Generate choice based on agent's bias parameter
  choice <- rbinom(1, 1, inv_logit_scaled(rate))
  
  # With probability 'noise', override choice with random 50/50 selection
  if (rbinom(1, 1, noise) == 1) {
    choice <- rbinom(1, 1, 0.5)
  }
  
  return(choice)
}

# Win-Stay-Lose-Shift agent function:
# - alpha: baseline bias parameter
# - betaWin: parameter controlling the tendency to repeat winning choices
# - betaLose: parameter controlling the tendency to switch after losing
# - win: indicator of previous win (+1 if won with choice 1, -1 if won with choice 0, 0 if lost)
# - lose: indicator of previous loss (+1 if lost with choice 0, -1 if lost with choice 1, 0 if won)
# - noise: probability of making a random choice
WSLSAgentNoise_f <- function(alpha, betaWin, betaLose, win, lose, noise) {
  # Calculate choice probability based on WSLS parameters and previous outcomes
  rate <- alpha + betaWin * win + betaLose * lose
  
  # Generate choice based on calculated probability
  choice <- rbinom(1, 1, inv_logit_scaled(rate))
  
  # With probability 'noise', override choice with random 50/50 selection
  if (rbinom(1, 1, noise) == 1) {
    choice <- rbinom(1, 1, 0.5)
  }
  
  return(choice)
}
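The signed `win` and `lose` inputs are easy to get wrong, so a small sketch of the encoding may help. The helper `encode_outcome` is hypothetical (the chapter's simulation computes the same quantities inline with nested `ifelse` calls); it only illustrates the coding described in the comments above.

```r
# Encode the outcome of a trial as the two signed signals used by WSLSAgentNoise_f.
# choice: 0/1 choice on the trial; feedback: 1 = win, 0 = loss.
encode_outcome <- function(choice, feedback) {
  direction <- ifelse(choice == 1, 1, -1)          # +1 for choice 1, -1 for choice 0
  list(
    win  = ifelse(feedback == 1,  direction, 0),   # points toward the winning choice
    lose = ifelse(feedback == 0, -direction, 0)    # points away from the losing choice
  )
}

signals <- encode_outcome(choice = c(1, 0, 1, 0), feedback = c(1, 1, 0, 0))
signals$win   # 1 -1  0  0
signals$lose  # 0  0 -1  1

# Exactly one of the two signals is non-zero on every trial
stopifnot(all(abs(signals$win) + abs(signals$lose) == 1))
```

Because both signals are signed so that positive values favour choice 1, positive values of `betaWin` and `betaLose` both push the agent toward the option the WSLS rule prescribes.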

Now let’s generate the simulation data. We’ll simulate each agent playing against a random opponent, and track their choices, wins, and losses:

# File path for saved simulation results
sim_data_file <- "simdata/W9_WSLS_data.RDS"

# Check if we need to regenerate the simulation data
if (regenerate_simulations || !file.exists(sim_data_file)) {
  # Initialize dataframe to store results
  d <- NULL
  
  # Loop through all agents
  for (agent in 1:agents) {
    # Sample individual agent parameters from population distributions
    rate <- rnorm(1, rateM, rateSD)         # Individual bias for random agent
    alpha <- rnorm(1, alphaM, alphaSD)      # Individual baseline bias for WSLS agent
    betaWin <- rnorm(1, betaWinM, betaWinSD)  # Individual win-stay parameter
    betaLose <- rnorm(1, betaLoseM, betaLoseSD) # Individual lose-shift parameter
    
    # Initialize vectors to store data for this agent
    randomChoice <- rep(NA, trials)  # Choices of random agent
    wslsChoice <- rep(NA, trials)    # Choices of WSLS agent
    win <- rep(NA, trials)           # Win indicator for WSLS agent
    lose <- rep(NA, trials)          # Lose indicator for WSLS agent
    feedback <- rep(NA, trials)      # Whether WSLS agent won (1) or lost (0)
    
    # Generate choices for each trial
    for (trial in 1:trials) {
      # Random agent makes choice
      randomChoice[trial] <- RandomAgentNoise_f(rate, noise)
      
      # WSLS agent makes first choice randomly, then follows strategy
      if (trial == 1) {
        wslsChoice[trial] <- rbinom(1, 1, 0.5)  # First choice is random
      } else {
        # Use WSLS strategy based on previous outcome
        wslsChoice[trial] <- WSLSAgentNoise_f(
          alpha, betaWin, betaLose, win[trial - 1], lose[trial - 1], noise
        )
      }
      
      # Determine outcome (1 = WSLS agent wins, 0 = WSLS agent loses)
      feedback[trial] <- ifelse(wslsChoice[trial] == randomChoice[trial], 1, 0)
      
      # Encode win/lose signals for WSLS strategy:
      # win: +1 if won with choice 1, -1 if won with choice 0, 0 if lost
      win[trial] <- ifelse(feedback[trial] == 1, 
                          ifelse(wslsChoice[trial] == 1, 1, -1), 
                          0)
      
      # lose: +1 if lost with choice 0, -1 if lost with choice 1, 0 if won
      lose[trial] <- ifelse(feedback[trial] == 0, 
                           ifelse(wslsChoice[trial] == 1, -1, 1), 
                           0)
    }
    
    # Create data frames for this agent
    tempRandom <- tibble(
      agent, 
      trial = seq(trials), 
      choice = randomChoice, 
      rate, 
      noise, 
      rateM, 
      rateSD, 
      alpha, 
      alphaM, 
      alphaSD, 
      betaWin, 
      betaWinM, 
      betaWinSD, 
      betaLose, 
      betaLoseM, 
      betaLoseSD, 
      win, 
      lose, 
      feedback, 
      strategy = "Random"
    )
    
    tempWSLS <- tibble(
      agent, 
      trial = seq(trials), 
      choice = wslsChoice, 
      rate, 
      noise, 
      rateM, 
      rateSD, 
      alpha, 
      alphaM, 
      alphaSD, 
      betaWin, 
      betaWinM, 
      betaWinSD, 
      betaLose, 
      betaLoseM, 
      betaLoseSD, 
      win, 
      lose, 
      feedback, 
      strategy = "WSLS"
    )
    
    # Combine data for both agent types
    temp <- rbind(tempRandom, tempWSLS)
    
    # Append to main dataframe
    if (agent > 1) {
      d <- rbind(d, temp)
    } else {
      d <- temp
    }
  }
  
  # Save the simulation results
  saveRDS(d, sim_data_file)
  cat("Generated new simulation data and saved to", sim_data_file, "\n")
} else {
  # Load existing simulation results
  d <- readRDS(sim_data_file)
  cat("Loaded existing simulation data from", sim_data_file, "\n")
}
## Loaded existing simulation data from simdata/W9_WSLS_data.RDS

10.3 Data Processing and Initial Visualization

Let’s process our data to create useful variables for analysis and visualization:

# Process data: add lead/lag variables and calculate cumulative statistics
d <- d %>% 
  group_by(agent, strategy) %>% 
  mutate(
    # Next and previous choices/outcomes
    nextChoice = lead(choice),            # Next choice (for analysis)
    prevWin = lag(win),                   # Previous win indicator
    prevLose = lag(lose),                 # Previous lose indicator
    
    # Performance metrics
    cumulativerate = cumsum(choice) / seq_along(choice),  # Running proportion of "right" choices
    performance = cumsum(feedback) / seq_along(feedback)  # Running proportion of wins
  ) %>%
  mutate(
    # identify win and lose as 0 during first trial
    prevWin = ifelse(is.na(prevWin), 0, prevWin),
    prevLose = ifelse(is.na(prevLose), 0, prevLose)
  )
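The lagging step is where trial alignment can silently go wrong, so here is a base R sketch of the same idea on toy data. The helper name `lag_with_zero` and the toy vectors are assumptions for illustration; the chapter's code uses `dplyr::lag()` followed by replacing the missing first value with 0.

```r
# Shift a per-trial signal by one trial so that row t holds the outcome of trial t - 1.
# The first trial has no history, so it gets a neutral 0.
lag_with_zero <- function(x) c(0, head(x, -1))

win  <- c(1, 0, -1, 0, 1)   # toy win signals for five trials
lose <- c(0, -1, 0, 1, 0)   # toy lose signals for the same trials

prevWin  <- lag_with_zero(win)
prevLose <- lag_with_zero(lose)

data.frame(trial = seq_along(win), win, lose, prevWin, prevLose)
```

Reading across a row, `prevWin` and `prevLose` describe what happened one trial earlier, which is the information the agent had available when making the choice on that row.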

Now let’s visualize the choice patterns of our agents to get a sense of their behavior:

# First visualization: Proportion of "right" choices over time for both strategies
p1 <- ggplot(d, aes(trial, cumulativerate, group = agent, color = agent)) + 
  geom_line(alpha = 0.3) + 
  geom_hline(yintercept = 0.5, linetype = "dashed") + 
  ylim(0, 1) + 
  labs(
    title = "Choice Patterns by Strategy",
    subtitle = "Each line represents one agent's running proportion of 'right' choices",
    x = "Trial",
    y = "Proportion of 'Right' Choices"
  ) +
  theme_classic() + 
  facet_wrap(~strategy) +
  theme(legend.position = "none")  # Hide individual agent legend for clarity

# Display the plot
p1

The visualization above shows how the choice patterns evolve over time for both the Random and WSLS agents. The random agents show an approximately stable choice pattern (though with individual biases), while the WSLS agents show more varied patterns as they adapt to their opponents.

Let’s also examine how each strategy performs against its opponent:

# Create visualization for performance against opponent.
# `performance` is the WSLS agent's running win rate, so 1 - performance
# is the running win rate of the Random agent it is playing against.
p2a <- ggplot(subset(d, strategy == "Random"), 
             aes(trial, 1 - performance, group = agent, color = agent)) + 
  geom_line(alpha = 0.3) + 
  geom_hline(yintercept = 0.5, linetype = "dashed") + 
  ylim(0, 1) + 
  labs(
    title = "Performance of Random Agents",
    subtitle = "Higher values indicate better performance against WSLS opponents",
    x = "Trial",
    y = "Proportion of Wins"
  ) +
  theme_classic() +
  theme(legend.position = "none")

p2b <- ggplot(subset(d, strategy == "WSLS"), 
             aes(trial, performance, group = agent, color = agent)) + 
  geom_line(alpha = 0.3) + 
  geom_hline(yintercept = 0.5, linetype = "dashed") + 
  ylim(0, 1) + 
  labs(
    title = "Performance of WSLS Agents",
    subtitle = "Higher values indicate better performance against Random opponents",
    x = "Trial",
    y = "Proportion of Wins"
  ) +
  theme_classic() +
  theme(legend.position = "none")

# Combine plots using patchwork
p2a / p2b

The performance plots reveal an interesting pattern. The WSLS agents generally maintain a win rate above 0.5 (the dashed line), indicating that they can exploit the biases of the random agents. This demonstrates a key strength of the WSLS strategy: it can adapt to and take advantage of predictable patterns in opponent behavior.
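A short calculation suggests why this happens. In this matching game, both the win signal and the lose signal point toward the opponent's previous choice, so a WSLS agent with no baseline bias and equal betas effectively copies the opponent's last move with probability `plogis(b)`. It then wins whenever a copied move matches an opponent who repeats itself. The sketch below works under those simplifying assumptions (alpha = 0, betaWin = betaLose = b, an opponent with a fixed rate p); the function name `expected_win_rate` is hypothetical.

```r
# Expected win rate of a symmetric WSLS agent (alpha = 0, betaWin = betaLose = b)
# against an opponent that chooses 1 with fixed probability p on every trial.
expected_win_rate <- function(p, b) {
  s <- plogis(b)                       # probability of following the WSLS rule
  repeat_prob <- p^2 + (1 - p)^2       # opponent makes the same choice twice in a row
  s * repeat_prob + (1 - s) * (1 - repeat_prob)
}

# Population means from the simulation: opponent bias 1.4 and beta 1.5 (log-odds)
analytic <- expected_win_rate(p = plogis(1.4), b = 1.5)

# Quick Monte Carlo check of the same quantity
set.seed(1)
n <- 100000
opp <- rbinom(n, 1, plogis(1.4))                    # opponent choices
follow <- rbinom(n - 1, 1, plogis(1.5))             # 1 = agent follows the rule
agent <- ifelse(follow == 1, opp[-n], 1 - opp[-n])  # rule = copy opponent's last move
simulated <- mean(agent == opp[-1])

round(c(analytic = analytic, simulated = simulated), 3)
```

Under these assumptions the expected win rate is about 0.62, and it falls back to 0.5 when the opponent is unbiased (p = 0.5). Individual agents in the simulation differ from this figure because their parameters are sampled and alpha is not exactly zero.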

10.4 Verifying Model Properties and Assumptions

Let’s check some key properties of our simulation to ensure it’s working as expected:

# Check that win and lose indicators are orthogonal as expected
p3 <- ggplot(d, aes(win, lose)) + 
  geom_jitter(alpha = 0.1, width = 0.1, height = 0.1) + 
  labs(
    title = "Orthogonality of Win and Lose Indicators",
    subtitle = "Confirms that win and lose signals are mutually exclusive",
    x = "Win Indicator",
    y = "Lose Indicator"
  ) +
  theme_bw()

p3

The plot confirms that our win and lose indicators are mutually exclusive: when the win signal is active (±1), the lose signal is 0, and vice versa. This is important for the proper functioning of our WSLS model.

Now let’s check that the WSLS agents are indeed responding to win and lose signals as expected:

# Check that WSLS agents respond to win/lose signals appropriately
p4 <- d %>% 
  subset(strategy == "WSLS") %>% 
  mutate(nextChoice = lead(choice)) %>%
  filter(!is.na(nextChoice)) %>%   # The last trial of each agent has no next choice
  group_by(agent, win, lose) %>%
  summarize(heads = mean(nextChoice), .groups = "drop") %>%
  ggplot(aes(win, heads)) +
  geom_point(alpha = 0.5) +
  labs(
    title = "WSLS Strategy Implementation Check",
    subtitle = "Shows how next choice probability depends on win/lose signals",
    x = "Win Signal (-1 = won with 'left', +1 = won with 'right', 0 = lost)",
    y = "Probability of Choosing 'Right' on Next Trial"
  ) +
  theme_bw() +
  facet_wrap(~lose, labeller = labeller(
    lose = c("-1" = "Lose = -1\n(lost with 'right')", 
             "0" = "Lose = 0\n(won)", 
             "1" = "Lose = 1\n(lost with 'left')")
  ))

p4

This plot confirms that our WSLS agents are behaving as expected:

  • When win = +1 (won with ‘right’), the agents tend to choose ‘right’ again
  • When win = -1 (won with ‘left’), the agents tend to choose ‘left’ again
  • When lose = +1 (lost with ‘left’), the agents tend to switch to ‘right’
  • When lose = -1 (lost with ‘right’), the agents tend to switch to ‘left’

The variation in probabilities reflects the individual differences in our agents’ parameters.

10.5 Modeling a Single WSLS Agent

Let’s first build a model to infer the parameters of a single WSLS agent. This will help us understand the basic mechanics before scaling up to the multilevel model:

# Select one agent's data for single-agent model
d_a <- d %>% subset(
  strategy == "WSLS" & agent == 2  # Arbitrarily select agent 2; trial 1 is kept, with prevWin and prevLose set to 0
)

# Prepare data for Stan model
data_wsls_simple <- list(
  trials = trials,           # Number of trials
  h = d_a$choice,            # Choices (0/1)
  win = d_a$prevWin,         # Previous win signal
  lose = d_a$prevLose        # Previous lose signal
)

Now let’s define the Stan model for a single WSLS agent:

# Stan model code for single WSLS agent
stan_wsls_model <- "
functions{
  // Helper function for generating truncated normal random numbers
  real normal_lb_rng(real mu, real sigma, real lb) {
    real p = normal_cdf(lb | mu, sigma);  // cdf for bounds
    real u = uniform_rng(p, 1);
    return (sigma * inv_Phi(u)) + mu;  // inverse cdf for value
  }
}

data {
 int<lower = 1> trials;        // Number of trials
 array[trials] int h;          // Choices (0/1)
 vector[trials] win;           // Win signal for each trial
 vector[trials] lose;          // Lose signal for each trial
}

parameters {
  real alpha;                  // Baseline bias parameter
  real winB;                   // Win-stay parameter
  real loseB;                  // Lose-shift parameter
}

model {
  // Priors
  target += normal_lpdf(alpha | 0, .3);   // Prior for baseline bias
  target += normal_lpdf(winB | 1, 1);     // Prior for win-stay parameter
  target += normal_lpdf(loseB | 1, 1);    // Prior for lose-shift parameter
  
  // Likelihood: WSLS choice model
  // On the first trial win and lose are both 0, so the choice depends only on
  // alpha: there is no previous outcome to guide it. (The simulation drew the
  // first choice from a fair coin, i.e. a logit of exactly 0.)
  target += bernoulli_logit_lpmf(h | alpha + winB * win + loseB * lose);

}

generated quantities{
   // Prior predictive samples
   real alpha_prior;
   real winB_prior;
   real loseB_prior;
   
   // Posterior and prior predictions
   array[trials] int prior_preds;
   array[trials] int posterior_preds;
   
   // Log likelihood for model comparison
   vector[trials] log_lik;

   // Generate prior samples (these must mirror the priors in the model block)
   alpha_prior = normal_rng(0, 0.3);
   winB_prior = normal_rng(1, 1);
   loseB_prior = normal_rng(1, 1);
   
   // Prior predictive simulations
   for (t in 1:trials) {
     prior_preds[t] = bernoulli_logit_rng(alpha_prior + winB_prior * win[t] + loseB_prior * lose[t]);
   }
   
   // Posterior predictive simulations
   for (t in 1:trials) {
     posterior_preds[t] = bernoulli_logit_rng(alpha + winB * win[t] + loseB * lose[t]);
   }
   
   // Calculate log likelihood for each observation
   for (t in 1:trials){
     log_lik[t] = bernoulli_logit_lpmf(h[t] | alpha + winB * win[t] + loseB * lose[t]);
   }
}
"

# Write the model to a file
write_stan_file(
  stan_wsls_model,
  dir = "stan/",
  basename = "W9_WSLS.stan")
## [1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/W9_WSLS.stan"
# Compile the model
file <- file.path("stan/W9_WSLS.stan")
mod_wsls_simple <- cmdstan_model(
  file, 
  cpp_options = list(stan_threads = TRUE),
  stanc_options = list("O1"), 
  pedantic = TRUE
)
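The helper `normal_lb_rng` in the Stan `functions` block draws from a normal distribution truncated below at `lb` using inverse-CDF sampling. The base R sketch below mirrors those three steps so they can be checked outside Stan; the function name `normal_lb_r` and the example values are assumptions for illustration.

```r
# Inverse-CDF sampling from a normal distribution truncated below at lb,
# following the same steps as the Stan helper normal_lb_rng().
normal_lb_r <- function(n, mu, sigma, lb) {
  p <- pnorm(lb, mu, sigma)        # probability mass below the bound
  u <- runif(n, min = p, max = 1)  # uniform draws restricted to the allowed region
  mu + sigma * qnorm(u)            # map back through the inverse CDF
}

set.seed(1)
draws <- normal_lb_r(10000, mu = 0, sigma = 0.3, lb = 0)

range(draws)                               # all draws lie above the bound
c(sample_mean = mean(draws),
  half_normal_mean = 0.3 * sqrt(2 / pi))   # theoretical mean of a half-normal(0, 0.3)
```

With `mu = 0` and `lb = 0` this is a half-normal distribution, which is a common prior shape for standard deviations.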

Now let’s fit the model to our single agent’s data:

# File path for saved model results
single_model_file <- "simmodels/W9_WSLS_simple.RDS"

# Check if we need to rerun the model
if (regenerate_simulations || !file.exists(single_model_file)) {
  # Fit the model
  samples_wsls_simple <- mod_wsls_simple$sample(
    data = data_wsls_simple,
    seed = 123,
    chains = 2,
    parallel_chains = 2,
    threads_per_chain = 1,
    iter_warmup = 2000,
    iter_sampling = 2000,
    refresh = 1000,
    max_treedepth = 20,
    adapt_delta = 0.99
  )
  
  # Save the results
  samples_wsls_simple$save_object(file = single_model_file)
  cat("Generated new model fit and saved to", single_model_file, "\n")
} else {
  # Load existing results
  samples_wsls_simple <- readRDS(single_model_file)
  cat("Loaded existing model fit from", single_model_file, "\n")
}
## Loaded existing model fit from simmodels/W9_WSLS_simple.RDS

Let’s examine the parameter estimates and convergence for the single agent model:

# Display summary statistics for the parameters
summary_stats <- samples_wsls_simple$summary() 
print(summary_stats)
## # A tibble: 367 × 10
##    variable           mean   median    sd   mad      q5     q95  rhat ess_bulk ess_tail
##    <chr>             <dbl>    <dbl> <dbl> <dbl>   <dbl>   <dbl> <dbl>    <dbl>    <dbl>
##  1 lp__           -58.4    -58.1    1.20  0.997 -60.7   -57.1    1.00    1606.    2161.
##  2 alpha            0.172    0.174  0.224 0.228  -0.190   0.532  1.00    1950.    2217.
##  3 winB             1.16     1.15   0.332 0.337   0.628   1.70   1.00    1845.    2075.
##  4 loseB            1.90     1.90   0.413 0.413   1.24    2.60   1.00    2039.    1774.
##  5 alpha_prior     -0.0314  -0.0233 0.993 0.978  -1.69    1.60   1.00    3888.    3857.
##  6 winB_prior       0.0165   0.0208 1.00  1.01   -1.62    1.66   1.00    3845.    3867.
##  7 loseB_prior     -0.0254  -0.0333 1.02  1.01   -1.71    1.65   1.00    3926.    3728.
##  8 prior_preds[1]   0.484    0      0.500 0       0       1      1.00    3927.      NA 
##  9 prior_preds[2]   0.490    0      0.500 0       0       1      1.00    3597.      NA 
## 10 prior_preds[3]   0.491    0      0.500 0       0       1      1.00    3847.      NA 
## # ℹ 357 more rows
# Extract posterior samples and include sampling of the prior
draws_df <- as_draws_df(samples_wsls_simple$draws())

# Create plot grid for parameter recovery
p5a <- ggplot(draws_df) +
  geom_histogram(aes(alpha), fill = "blue", alpha = 0.3) +
  geom_histogram(aes(alpha_prior), fill = "red", alpha = 0.3) +
  geom_vline(xintercept = d_a$alpha[1], linetype = "dashed", linewidth = 1) +
  labs(
    title = "Baseline Bias Parameter (Alpha)",
    subtitle = "Blue = posterior, Red = prior, Dashed line = true value",
    x = "Parameter Value",
    y = "Count"
  ) +
  theme_classic()

p5b <- ggplot(draws_df) +
  geom_histogram(aes(winB), fill = "blue", alpha = 0.3) +
  geom_histogram(aes(winB_prior), fill = "red", alpha = 0.3) +
  geom_vline(xintercept = d_a$betaWin[1], linetype = "dashed", linewidth = 1) +
  labs(
    title = "Win-Stay Parameter (Beta Win)",
    subtitle = "Blue = posterior, Red = prior, Dashed line = true value",
    x = "Parameter Value",
    y = "Count"
  ) +
  theme_classic()

p5c <- ggplot(draws_df) +
  geom_histogram(aes(loseB), fill = "blue", alpha = 0.3) +
  geom_histogram(aes(loseB_prior), fill = "red", alpha = 0.3) +
  geom_vline(xintercept = d_a$betaLose[1], linetype = "dashed", linewidth = 1) +
  labs(
    title = "Lose-Shift Parameter (Beta Lose)",
    subtitle = "Blue = posterior, Red = prior, Dashed line = true value",
    x = "Parameter Value",
    y = "Count"
  ) +
  theme_classic()

# Combine plots
(p5a | p5b) / p5c

Parameter recovery for this single agent looks good. The posterior distributions (blue) are centered close to the true parameter values (dashed lines), which suggests that the model can recover the parameters that generated the agent’s behavior. The posteriors are also narrower than the priors (red), indicating that the data are informative. A single agent is only one check, however; it does not establish recovery across the range of plausible parameter values.

[MISSING A FULL PARAMETER RECOVERY]
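One rough way to look at recovery beyond a single agent is sketched below under simplifying assumptions. It is not the Stan-based analysis this chapter calls for: it swaps in maximum-likelihood logistic regression (`glm`), which ignores the priors, and it uses a hypothetical helper `simulate_wsls`, 500 trials per agent, uniformly spread true values, and an opponent rate of 0.8, all chosen for illustration.

```r
# Recovery sketch: simulate WSLS agents with known parameters and re-estimate
# them with glm() as a quick stand-in for the Stan model.
set.seed(123)

simulate_wsls <- function(trials, alpha, betaWin, betaLose, opp_rate) {
  choice <- prevWin <- prevLose <- numeric(trials)
  win <- lose <- 0                                   # no history before trial 1
  for (t in seq_len(trials)) {
    prevWin[t] <- win
    prevLose[t] <- lose
    choice[t] <- rbinom(1, 1, plogis(alpha + betaWin * win + betaLose * lose))
    feedback <- as.numeric(choice[t] == rbinom(1, 1, opp_rate))  # 1 = match = win
    direction <- ifelse(choice[t] == 1, 1, -1)
    win  <- ifelse(feedback == 1,  direction, 0)
    lose <- ifelse(feedback == 0, -direction, 0)
  }
  data.frame(choice, prevWin, prevLose)
}

n_agents <- 30
truth <- data.frame(betaWin = runif(n_agents, 0.5, 2.5),
                    betaLose = runif(n_agents, 0.5, 2.5))

# Fit one logistic regression per simulated agent and collect the coefficients
estimates <- t(sapply(seq_len(n_agents), function(i) {
  sim <- simulate_wsls(trials = 500, alpha = 0,
                       betaWin = truth$betaWin[i], betaLose = truth$betaLose[i],
                       opp_rate = 0.8)
  coef(glm(choice ~ prevWin + prevLose, family = binomial, data = sim))
}))

recovery <- c(win = cor(truth$betaWin, estimates[, "prevWin"]),
              lose = cor(truth$betaLose, estimates[, "prevLose"]))
round(recovery, 2)
```

The same loop structure carries over to a Bayesian recovery study: replace the `glm` call with a fit of the Stan model and compare posterior summaries against `truth`.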

10.6 Multilevel WSLS Model

Now let’s scale up to model all agents simultaneously with a multilevel (hierarchical) model. This allows us to estimate both population-level parameters and individual differences:

## Prepare data for multilevel model
# Transform data into matrices where columns are agents, rows are trials.
# The trial number is used as the row identifier, so the reshaping does not
# depend on how the data frame happens to be grouped.
wsls_long <- d %>% 
  ungroup() %>% 
  filter(strategy == "WSLS")

d_wsls1 <- wsls_long %>% 
  select(agent, trial, choice) %>% 
  pivot_wider(names_from = agent, values_from = choice)

d_wsls2 <- wsls_long %>% 
  select(agent, trial, prevWin) %>% 
  pivot_wider(names_from = agent, values_from = prevWin)

d_wsls3 <- wsls_long %>% 
  select(agent, trial, prevLose) %>% 
  pivot_wider(names_from = agent, values_from = prevLose)

# Create the data list for Stan (column 1 of each table is the trial number)
data_wsls <- list(
  trials = trials,
  agents = agents,
  h = as.matrix(d_wsls1[, 2:(agents + 1)]),    # Choice matrix
  win = as.matrix(d_wsls2[, 2:(agents + 1)]),  # Win signal matrix
  lose = as.matrix(d_wsls3[, 2:(agents + 1)])  # Lose signal matrix
)

Now let’s define the multilevel Stan model:

# Stan code for multilevel WSLS model
stan_wsls_ml_model <- "
functions{
  real normal_lb_rng(real mu, real sigma, real lb) {
    real p = normal_cdf(lb | mu, sigma);  // cdf for bounds
    real u = uniform_rng(p, 1);
    return (sigma * inv_Phi(u)) + mu;  // inverse cdf for value
  }
}

// The input (data) for the model. 
data {
 int<lower = 1> trials;           // Number of trials
 int<lower = 1> agents;           // Number of agents
 array[trials, agents] int h;      // Choice data (0/1)
 array[trials, agents] real win;   // Win signals
 array[trials, agents] real lose;  // Lose signals
}

parameters {
  // Population-level parameters
  real winM;                    // Population mean for win-stay parameter
  real loseM;                   // Population mean for lose-shift parameter
  
  // Population standard deviations
  vector<lower = 0>[2] tau;     // SDs for [win, lose] parameters
  
  // Individual z-scores (non-centered parameterization)
  matrix[2, agents] z_IDs;
  
  // Correlation matrix
  cholesky_factor_corr[2] L_u;
}

transformed parameters {
  // Individual parameters (constructed from non-centered parameterization)
  matrix[agents, 2] IDs;
  IDs = (diag_pre_multiply(tau, L_u) * z_IDs)';
}

model {
  // Population-level priors
  target += normal_lpdf(winM | 0, 1);     // Prior for win-stay mean
  target += normal_lpdf(tau[1] | 0, .3) - normal_lccdf(0 | 0, .3);  // Half-normal for SD
  
  target += normal_lpdf(loseM | 0, 1);    // Prior for lose-shift mean (matches loseM_prior below)
  target += normal_lpdf(tau[2] | 0, .3) - normal_lccdf(0 | 0, .3);  // Half-normal for SD
  
  // Prior for correlation matrix
  target += lkj_corr_cholesky_lpdf(L_u | 2);
  
  // Prior for individual z-scores
  target += std_normal_lpdf(to_vector(z_IDs));
  
  // Likelihood
  for (i in 1:agents)
    target += bernoulli_logit_lpmf(h[,i] | to_vector(win[,i]) * (winM + IDs[i,1]) + 
                                            to_vector(lose[,i]) * (loseM + IDs[i,2]));
}

generated quantities{
   // Prior predictive samples
   real winM_prior;
   real<lower=0> winSD_prior;
   real loseM_prior;
   real<lower=0> loseSD_prior;
   
   real win_prior;
   real lose_prior;
   
   // Prior and posterior predictive samples
   array[trials, agents] int prior_preds;
   array[trials, agents] int posterior_preds;
   
   // Log likelihood for model comparison
   array[trials, agents] real log_lik;

   // Generate prior samples
   winM_prior = normal_rng(0, 1);
   winSD_prior = normal_lb_rng(0, 0.3, 0);
   loseM_prior = normal_rng(0, 1);
   loseSD_prior = normal_lb_rng(0, 0.3, 0);
   
   win_prior = normal_rng(winM_prior, winSD_prior);
   lose_prior = normal_rng(loseM_prior, loseSD_prior);
   
   // Generate predictions
   for (i in 1:agents){
      // Prior predictive simulations
      for (t in 1:trials) {
        prior_preds[t, i] = bernoulli_logit_rng(
          win[t, i] * win_prior + lose[t, i] * lose_prior
        );
      }
      
      // Posterior predictive simulations
      for (t in 1:trials) {
        posterior_preds[t, i] = bernoulli_logit_rng(
          win[t, i] * (winM + IDs[i, 1]) + lose[t, i] * (loseM + IDs[i, 2])
        );
      }
      
      // Calculate log likelihood for each observation
      for (t in 1:trials){
        log_lik[t, i] = bernoulli_logit_lpmf(
          h[t, i] | win[t, i] * (winM + IDs[i, 1]) + lose[t, i] * (loseM + IDs[i, 2])
        );
      }
   }
}
"

# Write the model to a file
write_stan_file(
  stan_wsls_ml_model,
  dir = "stan/",
  basename = "W9_WSLS_ml.stan")
## [1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/W9_WSLS_ml.stan"
# Compile the model
file <- file.path("stan/W9_WSLS_ml.stan")
mod_wsls_ml <- cmdstan_model(
  file, 
  cpp_options = list(stan_threads = TRUE),
  stanc_options = list("O1"), 
  pedantic = TRUE
)
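The `transformed parameters` block builds each agent's offsets from independent z-scores, the population SDs `tau`, and the Cholesky factor `L_u` of the correlation matrix. The base R sketch below reproduces that construction to show that it yields offsets with the intended covariance. The correlation of 0.5, the SDs, and the large number of simulated agents are illustrative assumptions, not values from the fitted model.

```r
# Non-centred construction of correlated individual offsets, mirroring
# IDs = (diag_pre_multiply(tau, L_u) * z_IDs)' in the Stan model.
set.seed(1)
n_agents <- 100000                     # many agents so the check below is stable
tau <- c(0.3, 0.3)                     # population SDs for [win, lose]
R <- matrix(c(1, 0.5, 0.5, 1), 2, 2)   # illustrative correlation matrix
L_u <- t(chol(R))                      # lower-triangular Cholesky factor of R

z_IDs <- matrix(rnorm(2 * n_agents), nrow = 2)   # independent standard normal z-scores
IDs <- t(diag(tau) %*% L_u %*% z_IDs)            # n_agents x 2 matrix of offsets

implied_cov <- diag(tau) %*% R %*% diag(tau)     # covariance the construction should give
round(cov(IDs), 3)
round(implied_cov, 3)
```

Sampling the z-scores rather than the offsets themselves tends to help the sampler when individual-level data are limited, because the z-scores have a fixed unit scale regardless of `tau`.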

Now let’s fit the multilevel model to all agents:

# File path for saved multilevel model results
multilevel_model_file <- "simmodels/W9_WSLS_multilevel.RDS"

# Check if we need to rerun the model
if (regenerate_simulations || !file.exists(multilevel_model_file)) {
  # Fit the multilevel model
  samples_wsls_ml <- mod_wsls_ml$sample(
    data = data_wsls,
    seed = 123,
    chains = 2,
    parallel_chains = 2,
    threads_per_chain = 1,
    iter_warmup = 2000,
    iter_sampling = 2000,
    refresh = 500,  # Reduced refresh rate for cleaner output
    max_treedepth = 20,
    adapt_delta = 0.99
  )
  
  # Save the results
  samples_wsls_ml$save_object(file = multilevel_model_file)
  cat("Generated new multilevel model fit and saved to", multilevel_model_file, "\n")
} else {
  # Load existing results
  samples_wsls_ml <- readRDS(multilevel_model_file)
  cat("Loaded existing multilevel model fit from", multilevel_model_file, "\n")
}
## Loaded existing multilevel model fit from simmodels/W9_WSLS_multilevel.RDS

10.7 Quality Checks and Parameter Recovery

10.7.1 Convergence Diagnostics

Let’s examine convergence diagnostics to ensure our model has estimated the parameters reliably:

draws_ml <- as_draws_df(samples_wsls_ml$draws())

# Trace plots for convergence checking
mcmc_trace(draws_ml, pars = c("winM", "loseM", "tau[1]", "tau[2]")) +
  labs(
    title = "MCMC Trace Plots for Key Parameters",
    subtitle = "Good mixing indicates model convergence"
  )

# Rank plots for assessing convergence
mcmc_rank_hist(draws_ml, pars = c("winM", "loseM", "tau[1]", "tau[2]")) +
  labs(
    title = "Rank Plots for Key Parameters",
    subtitle = "Uniform distributions indicate good mixing"
  )

The trace plots show the parameter values across iterations. Good mixing (chains overlapping without patterns) indicates convergence. Rank histograms near uniform also suggest good convergence, while U-shaped or inverted-U histograms would indicate poor mixing.

10.7.2 Prior-Posterior Update Visualization

Now let’s visualize how our knowledge about the parameters has been updated by the data:

# Extract prior samples
prior_samples <- tibble(
  winM_prior = draws_ml$winM_prior,
  loseM_prior = draws_ml$loseM_prior,
  winSD_prior = draws_ml$winSD_prior,
  loseSD_prior = draws_ml$loseSD_prior
)

# Create dataframe for prior-posterior comparison
update_df <- tibble(
  Parameter = rep(c("Win-Stay Mean", "Lose-Shift Mean", "Win-Stay SD", "Lose-Shift SD"), 
                 each = nrow(draws_ml), times = 2),
  Value = c(draws_ml$winM, draws_ml$loseM, draws_ml$`tau[1]`, draws_ml$`tau[2]`,
            prior_samples$winM_prior, prior_samples$loseM_prior, 
            prior_samples$winSD_prior, prior_samples$loseSD_prior),
  Distribution = rep(c("Posterior", "Prior"), each = 4 * nrow(draws_ml)),
  True_Value = rep(c(betaWinM, betaLoseM, betaWinSD, betaLoseSD), 
                   each = nrow(draws_ml), times = 2)
)

# Visualize prior-posterior update
ggplot(update_df, aes(x = Value, fill = Distribution)) +
  geom_histogram(alpha = 0.6, position = "identity") +  # Overlay prior and posterior instead of stacking them
  geom_vline(aes(xintercept = True_Value), color = "black", linetype = "dashed") +
  facet_wrap(~ Parameter, scales = "free") +
  scale_fill_manual(values = c("Prior" = "lightpink", "Posterior" = "steelblue")) +
  labs(
    title = "Prior vs. Posterior Distributions",
    subtitle = "Prior (pink) vs. posterior (blue) distributions with true values (dashed lines)",
    x = "Parameter Value",
    y = "Count"
  ) +
  theme_minimal()

This visualization shows how our knowledge about the parameters has been updated by the data. The prior distributions (pink) represent our knowledge before seeing the data, while the posterior distributions (blue) show what we learned after fitting the model to the data. Narrower posteriors centered near the true values indicate that our model effectively learned from the data.

10.7.3 Individual-Level Parameter Recovery

One of the key advantages of multilevel modeling is the ability to estimate parameters for individual agents. Let’s extract individual parameters and assess recovery:

# Extract individual agent parameters
agent_params <- tibble()

for (i in 1:agents) {
  # Extract parameters for this agent
  win_param <- mean(draws_ml$winM) + mean(draws_ml[[paste0("IDs[", i, ",1]")]])
  lose_param <- mean(draws_ml$loseM) + mean(draws_ml[[paste0("IDs[", i, ",2]")]])
  
  # Get true parameters
  true_data <- filter(d, strategy == "WSLS", agent == i, trial == 1)
  true_win <- first(true_data$betaWin)
  true_lose <- first(true_data$betaLose)
  
  # Add to dataframe
  agent_params <- bind_rows(
    agent_params,
    tibble(
      agent = i,
      win_estimated = win_param,
      lose_estimated = lose_param,
      win_true = true_win,
      lose_true = true_lose
    )
  )
}

# Create comparison plots
p1 <- ggplot(agent_params, aes(win_true, win_estimated)) +
  geom_point(alpha = 0.7) +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  geom_smooth(method = "lm", color = "red", se = TRUE) +
  labs(
    title = "Individual Win-Stay Parameter Recovery",
    x = "True Win-Stay Parameter",
    y = "Estimated Win-Stay Parameter"
  ) +
  theme_minimal()

p2 <- ggplot(agent_params, aes(lose_true, lose_estimated)) +
  geom_point(alpha = 0.7) +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  geom_smooth(method = "lm", color = "red", se = TRUE) +
  labs(
    title = "Individual Lose-Shift Parameter Recovery",
    x = "True Lose-Shift Parameter",
    y = "Estimated Lose-Shift Parameter"
  ) +
  theme_minimal()

# Calculate recovery statistics
win_corr <- cor(agent_params$win_true, agent_params$win_estimated)
lose_corr <- cor(agent_params$lose_true, agent_params$lose_estimated)

# Display correlation statistics
cat("Correlation between true and estimated Win-Stay parameters:", round(win_corr, 3), "\n")
## Correlation between true and estimated Win-Stay parameters: 0.746
cat("Correlation between true and estimated Lose-Shift parameters:", round(lose_corr, 3), "\n")
## Correlation between true and estimated Lose-Shift parameters: 0.666
# Create combined plot
p1 | p2

These scatter plots show how well our model recovers individual-level parameters. Points near the dashed diagonal indicate accurate recovery, while the red regression line shows the overall relationship between true and estimated values. The correlations reported above (0.75 for Win-Stay and 0.67 for Lose-Shift) indicate that individual differences are recovered moderately well, with the Lose-Shift parameter somewhat harder to recover.
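A correlation alone does not show whether estimates are biased or pulled toward the group mean. As a rough sketch, the recovery check could be extended with bias, root mean squared error, and the regression slope (a slope below 1 is what shrinkage toward the population mean would look like). The data here are simulated stand-ins for `agent_params`, and `recovery_demo` and `recovery_summary()` are hypothetical names.

```r
# Simulated stand-in for agent_params: true values plus shrunken, noisy estimates
set.seed(42)
n_agents <- 100
recovery_demo <- data.frame(true = rnorm(n_agents, mean = 1.5, sd = 0.5))
recovery_demo$estimated <- 1.5 + 0.6 * (recovery_demo$true - 1.5) +
  rnorm(n_agents, sd = 0.2)

# Hypothetical helper summarizing recovery quality
recovery_summary <- function(true, estimated) {
  fit <- lm(estimated ~ true)
  c(
    correlation = cor(true, estimated),
    bias        = mean(estimated - true),            # systematic over/underestimation
    rmse        = sqrt(mean((estimated - true)^2)),  # typical size of the error
    slope       = unname(coef(fit)["true"])          # below 1 suggests shrinkage
  )
}

round(recovery_summary(recovery_demo$true, recovery_demo$estimated), 3)
```

With the real results, the same helper could be called as `recovery_summary(agent_params$win_true, agent_params$win_estimated)`.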

10.7.4 Posterior Predictive Checks

Posterior predictive checks help us assess whether our model can generate data that resembles the observed data. For efficiency, the chunk below saves the computed results and reloads them if they already exist, unless the regenerate_simulations flag is set:

# Posterior predictive checks with regeneration flag

# File path for saved predictive check data
pred_checks_file <- "simdata/W9_WSLS_predictive_checks.RDS"
pred_checks_uncertainty_file <- "simdata/W9_WSLS_predictive_uncertainty.RDS"

# Check if we need to recompute the predictive checks
if (regenerate_simulations || !file.exists(pred_checks_file) || !file.exists(pred_checks_uncertainty_file)) {
  # Extract samples for a few selected agents
  selected_agents <- c(1, 25, 50, 75)
  
  # Initialize data structures
  observed_data <- matrix(NA, nrow = trials-1, ncol = length(selected_agents))
  prior_means <- matrix(NA, nrow = trials-1, ncol = length(selected_agents))
  posterior_means <- matrix(NA, nrow = trials-1, ncol = length(selected_agents))
  
  # Also track quantiles for uncertainty bands
  prior_lower <- matrix(NA, nrow = trials-1, ncol = length(selected_agents))
  prior_upper <- matrix(NA, nrow = trials-1, ncol = length(selected_agents))
  posterior_lower <- matrix(NA, nrow = trials-1, ncol = length(selected_agents))
  posterior_upper <- matrix(NA, nrow = trials-1, ncol = length(selected_agents))
  
  # Fill observed data 
  for (i in seq_along(selected_agents)) {
    agent_idx <- selected_agents[i]
    observed_data[, i] <- data_wsls$h[1:(trials-1), agent_idx]
  }
  
  # Process the predictions trial by trial to reduce memory usage
  for (t in 1:(trials-1)) {
    cat("Processing trial", t, "of", trials-1, "\n")
    
    # Extract variable names for this trial
    prior_vars <- paste0("prior_preds[", t, ",", selected_agents, "]")
    post_vars <- paste0("posterior_preds[", t, ",", selected_agents, "]")
    
    # Extract predictions for all agents for this trial
    for (i in seq_along(selected_agents)) {
      prior_preds <- as.numeric(draws_ml[[prior_vars[i]]])
      post_preds <- as.numeric(draws_ml[[post_vars[i]]])
      
      # Calculate means
      prior_means[t, i] <- mean(prior_preds)
      posterior_means[t, i] <- mean(post_preds)
      
      # Calculate 95% credible intervals
      prior_lower[t, i] <- quantile(prior_preds, 0.025)
      prior_upper[t, i] <- quantile(prior_preds, 0.975)
      posterior_lower[t, i] <- quantile(post_preds, 0.025)
      posterior_upper[t, i] <- quantile(post_preds, 0.975)
    }
  }
  
  # Convert to long format for plotting
  pred_check_data <- tibble(
    trial = rep(2:trials, length(selected_agents)),  # Trial numbers (starting from 2)
    agent = rep(selected_agents, each = trials-1),
    observed = c(observed_data),
    prior_mean = c(prior_means),
    posterior_mean = c(posterior_means),
    agent_label = factor(paste("Agent", rep(selected_agents, each = trials-1)))
  )
  
  # Uncertainty data
  uncertainty_data <- tibble(
    trial = rep(2:trials, length(selected_agents)),
    agent = rep(selected_agents, each = trials-1),
    prior_lower = c(prior_lower),
    prior_upper = c(prior_upper),
    posterior_lower = c(posterior_lower),
    posterior_upper = c(posterior_upper),
    agent_label = factor(paste("Agent", rep(selected_agents, each = trials-1)))
  )
  
  # Calculate cumulative choice proportions
  pred_check_cumulative <- pred_check_data %>%
    group_by(agent_label) %>%
    mutate(
      trial_idx = trial - 1,  # Index starting from 1
      obs_cumulative = cumsum(observed) / seq_along(observed),
      prior_cumulative = cumsum(prior_mean) / seq_along(prior_mean),
      posterior_cumulative = cumsum(posterior_mean) / seq_along(posterior_mean)
    )
  
  # Save the results
  saveRDS(pred_check_cumulative, pred_checks_file)
  saveRDS(uncertainty_data, pred_checks_uncertainty_file)
  cat("Generated new predictive checks and saved to", pred_checks_file, "\n")
} else {
  # Load existing results
  pred_check_cumulative <- readRDS(pred_checks_file)
  uncertainty_data <- readRDS(pred_checks_uncertainty_file)
  cat("Loaded existing predictive checks from saved files\n")
}
## Loaded existing predictive checks from saved files
# Plot cumulative choice proportions - Posterior predictive check with uncertainty
p1 <- ggplot(pred_check_cumulative, aes(x = trial)) +
  # Add point-wise uncertainty intervals
  geom_ribbon(data = uncertainty_data, 
             aes(ymin = posterior_lower, ymax = posterior_upper),
             fill = "blue", alpha = 0.2) +
  # Add observed and predicted lines
  # linewidth replaces size for lines (size is deprecated since ggplot2 3.4.0)
  geom_line(aes(y = obs_cumulative, color = "Observed"), linewidth = 1) +
  geom_line(aes(y = posterior_cumulative, color = "Posterior Predicted"), 
            linewidth = 1, linetype = "dashed") +
  facet_wrap(~ agent_label, ncol = 2) +
  labs(
    title = "Posterior Predictive Check with Uncertainty",
    subtitle = "Observed (solid) vs. Posterior Predicted (dashed) with 95% CI bands",
    x = "Trial",
    y = "Proportion of 'Right' Choices",
    color = "Data Source"
  ) +
  scale_color_manual(values = c("Observed" = "black", "Posterior Predicted" = "blue")) +
  theme_minimal() +
  theme(legend.position = "bottom")

# Plot cumulative choice proportions - Prior predictive check with uncertainty
p2 <- ggplot(pred_check_cumulative, aes(x = trial)) +
  # Add point-wise uncertainty intervals
  geom_ribbon(data = uncertainty_data, 
             aes(ymin = prior_lower, ymax = prior_upper),
             fill = "red", alpha = 0.2) +
  # Add observed and predicted lines
  # linewidth replaces size for lines (size is deprecated since ggplot2 3.4.0)
  geom_line(aes(y = obs_cumulative, color = "Observed"), linewidth = 1) +
  geom_line(aes(y = prior_cumulative, color = "Prior Predicted"), 
            linewidth = 1, linetype = "dashed") +
  facet_wrap(~ agent_label, ncol = 2) +
  labs(
    title = "Prior Predictive Check with Uncertainty",
    subtitle = "Observed (solid) vs. Prior Predicted (dashed) with 95% CI bands",
    x = "Trial",
    y = "Proportion of 'Right' Choices",
    color = "Data Source"
  ) +
  scale_color_manual(values = c("Observed" = "black", "Prior Predicted" = "red")) +
  theme_minimal() +
  theme(legend.position = "bottom")

# Display both plots using patchwork
p1 / p2

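These predictive checks compare the observed cumulative choice proportions (solid lines) with those predicted by our model (dashed lines) for four selected agents, using the posterior in the top plot and the prior in the bottom plot. Close alignment between the observed and posterior-predicted lines indicates that the fitted model captures the key patterns in the data. Note that the shaded bands are pointwise 95% intervals of the per-trial predictions, whereas the lines are cumulative proportions, so the two are not on the same scale.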
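If the uncertainty band is meant to describe the cumulative proportion itself, one option is to compute the cumulative proportion within each draw first and only then summarize across draws. The following is a minimal sketch on simulated 0/1 predictions for a single agent; `pred_draws`, `rate_draws`, and `cum_summary` are hypothetical names, and the draws-by-trials layout is an assumption about how the predictions would be arranged after extraction.

```r
# Simulated stand-in for one agent's predictions: a draws x trials matrix of
# 0/1 choices (hypothetical; real values would come from posterior_preds)
set.seed(1)
n_draws  <- 1000
n_trials <- 119
rate_draws <- rbeta(n_draws, 30, 20)  # each draw has its own choice rate
pred_draws <- matrix(rbinom(n_draws * n_trials, 1, rep(rate_draws, n_trials)),
                     nrow = n_draws, ncol = n_trials)

# Cumulative proportion of 'right' choices within each draw (rows stay draws)
cum_prop <- t(apply(pred_draws, 1, function(x) cumsum(x) / seq_along(x)))

# Summarize across draws for every trial: mean and 95% interval
cum_summary <- data.frame(
  trial = seq_len(n_trials) + 1,  # predictions start at trial 2
  mean  = colMeans(cum_prop),
  lower = apply(cum_prop, 2, quantile, probs = 0.025),
  upper = apply(cum_prop, 2, quantile, probs = 0.975)
)
head(cum_summary)
```

The `lower` and `upper` columns could then feed `geom_ribbon()` directly, so that the band and the dashed line describe the same quantity.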

10.7.5 Model Comparison with LOO-CV

Finally, let’s compute Leave-One-Out Cross-Validation (LOO-CV) to assess our model’s predictive performance. We fit only the WSLS model here, but the same quantities are what we would use to compare it with a simpler alternative:

# Compute LOO for the multilevel WSLS model
loo_wsls <- samples_wsls_ml$loo()

# Print LOO results
print(loo_wsls)
## 
## Computed from 4000 by 12000 log-likelihood matrix.
## 
##          Estimate    SE
## elpd_loo  -5687.6  62.8
## p_loo        99.2   1.5
## looic     11375.1 125.6
## ------
## MCSE of elpd_loo is NA.
## MCSE and ESS estimates assume MCMC draws (r_eff in [0.5, 1.6]).
## 
## Pareto k diagnostic values:
##                          Count Pct.    Min. ESS
## (-Inf, 0.7]   (good)     11900 99.2%   2900    
##    (0.7, 1]   (bad)          0  0.0%   <NA>    
##    (1, Inf)   (very bad)   100  0.8%   <NA>    
## See help('pareto-k-diagnostic') for details.
# Check Pareto k diagnostics
# sum() returns 0 when no value exceeds the threshold (table()["TRUE"] would give NA)
n_high_k <- sum(loo_wsls$diagnostics$pareto_k > 0.7)
cat("Number of observations with Pareto k > 0.7:", n_high_k, "\n")
## Number of observations with Pareto k > 0.7: 100

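The LOO-CV computation provides an estimate of the model’s expected out-of-sample predictive accuracy (elpd_loo). In this run, 100 observations (0.8%) have Pareto k values above 1, which is why the MCSE of elpd_loo is reported as NA; the estimates for these observations are unreliable and should be investigated before the result is used. In a full analysis, we would compare this estimate with those of alternative models to determine which provides the best balance of fit and generalizability.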
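To make the comparison step concrete, here is a small sketch using the loo package on simulated data. It does not use the chapter's Stan models: the two "models" are stand-ins whose draws are generated directly in R, and `theta_a`, `theta_b`, and `log_lik_matrix()` are hypothetical names. The only loo functions used are `loo()` on a log-likelihood matrix and `loo_compare()`.

```r
library(loo)

set.seed(2024)
n_obs   <- 200
n_draws <- 1000
y <- rbinom(n_obs, 1, 0.8)  # simulated choices with a strong bias

# Model A: bias estimated from the data (Beta posterior under a flat prior)
theta_a <- rbeta(n_draws, 1 + sum(y), 1 + n_obs - sum(y))
# Model B: stand-in draws for a model whose bias is kept close to 0.5
theta_b <- rbeta(n_draws, 500, 500)

# Pointwise log-likelihood: one row per draw, one column per observation
log_lik_matrix <- function(theta, y) {
  sapply(y, function(y_i) dbinom(y_i, size = 1, prob = theta, log = TRUE))
}

# r_eff is set to 1 because these draws are independent, not MCMC output
loo_a <- loo(log_lik_matrix(theta_a, y), r_eff = rep(1, n_obs))
loo_b <- loo(log_lik_matrix(theta_b, y), r_eff = rep(1, n_obs))

# Models are sorted by elpd; elpd_diff is relative to the best model
comparison <- loo_compare(list(estimated_bias = loo_a, fixed_bias = loo_b))
print(comparison)
```

When reading the comparison, the difference in elpd should be judged against its standard error (se_diff): a difference that is small relative to se_diff gives little reason to prefer one model over the other.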

# Visualize distribution of pointwise elpd values
elpd_data <- tibble(
  observation = seq_len(nrow(loo_wsls$pointwise)),
  elpd = loo_wsls$pointwise[,"elpd_loo"]
)

# Create histogram of elpd values
ggplot(elpd_data, aes(x = elpd)) +
  geom_histogram(bins = 30, fill = "steelblue", color = "black", alpha = 0.7) +
  geom_vline(aes(xintercept = mean(elpd)), color = "darkred", linetype = "dashed", linewidth = 1) +
  labs(
    title = "Distribution of Pointwise Expected Log Predictive Density (ELPD)",
    subtitle = "Higher values indicate better prediction for individual observations",
    x = "ELPD",
    y = "Count"
  ) +
  theme_minimal()

This visualization shows the distribution of pointwise expected log predictive density (ELPD) values, with the mean indicated by the dashed line. Observations with higher ELPD values are better predicted by our model. A long left tail would suggest that some observations are particularly difficult for the model to predict. In a real project, we should identify those observations and explore why the model struggles with them.
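As a starting point for that exploration, the pointwise values can be linked back to agents and trials to see whether the poorly predicted observations cluster anywhere. The sketch below uses simulated values in place of `loo_wsls$pointwise[, "elpd_loo"]`; `elpd_index` and `worst` are hypothetical names, and the mapping from observation index to agent and trial is an assumption that has to be checked against how the log-likelihood is stored in the Stan model.

```r
# Simulated stand-in for loo_wsls$pointwise[, "elpd_loo"] (hypothetical values)
set.seed(7)
n_agents <- 10
n_trials <- 20
pointwise_elpd <- -rexp(n_agents * n_trials, rate = 2)

# ASSUMPTION: observations are ordered trial by trial within each agent.
# The real ordering depends on how log_lik is declared and filled in Stan,
# so verify it before relying on this mapping.
elpd_index <- data.frame(
  agent = rep(seq_len(n_agents), each = n_trials),
  trial = rep(seq_len(n_trials), times = n_agents),
  elpd  = pointwise_elpd
)

# The ten worst-predicted observations
worst <- elpd_index[order(elpd_index$elpd), ][1:10, ]
print(worst)

# Do the poorly predicted observations cluster in particular agents or trials?
table(worst$agent)
table(worst$trial)
```

The same table could be built for observations with high Pareto k values, which would show, for example, whether the flagged observations all fall on one specific trial of every agent.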