15  Categorization Models: Prototypes

📍 Where we are in the Bayesian modeling workflow: Chapter 13 built the GCM — an exemplar model that stores every encountered stimulus and decides by summed similarity to memory. This chapter introduces a structurally different account: the prototype model, which abstracts each category to a single running estimate (its mean and uncertainty) rather than storing individual instances. We implement prototype learning as a multivariate Kalman filter and then add a Stan parameter-estimation model for the prototype architecture.

The full validation pipeline from Ch. 5 — prior predictive check, Pathfinder initialization, mandatory MCMC diagnostic battery, posterior predictive check, prior–posterior update plot, randomised LOO-PIT, LFO-CV, precision analysis, prior sensitivity, and SBC — is applied to this new model. Crucially, the LFO-CV machinery implemented in Chapter 13 (psis_lfo_gcm) is generalised here as psis_lfo_kalman and applied to the prototype model. The chapter then extends to a multilevel prototype model with population-level priors on log_r, per-subject NCP offsets, and the full multilevel SBC battery (population and individual calibration).

15.1 Rethinking Category Representation: From Examples to Averages

In the previous chapter, we explored the Generalized Context Model (GCM), an exemplar model where every encountered example is stored in memory. This approach is powerful but raises questions about cognitive efficiency. Do we really store every single instance we encounter?

Consider learning about common categories like “birds” or “chairs”. While specific examples matter, we also seem to develop a general sense of what constitutes a typical member. Prototype theory offers an alternative perspective: instead of storing individual exemplars, the mind abstracts a summary representation — a prototype — for each category. This prototype often represents the central tendency or “average” member.

15.2 Core Ideas of Prototype Models

| Feature         | Prototype Models                     | Exemplar Models (GCM)                  |
|-----------------|--------------------------------------|----------------------------------------|
| Representation  | Abstract summary (e.g., average)     | Collection of individual examples      |
| Memory Cost     | Low (one prototype per category)     | High (potentially all examples)        |
| Key Information | Central tendency & variance          | Specific instances & their labels      |
| Sensitivity     | Less sensitive to specific examples  | Highly sensitive to specific examples  |
Note: Historical Context: From Typicality Effects to Dynamic Prototypes

The prototype approach to categorization did not emerge fully formed. It arose as the solution to a specific failure mode of the dominant view (rules) and evolved in a debate with alternative approaches (exemplars).

15.2.0.1 The Classical View and Its Discontents

Through the 1950s and into the 1960s, both philosophers and psychologists took it for granted that concepts had defining features: necessary and sufficient conditions that every member must satisfy. A bachelor is an unmarried adult male — full stop. Categorization, on this view, was logical: check the features, apply the rule. The model is appealingly crisp, but Wittgenstein had already noted the trouble in Philosophical Investigations (1953): most ordinary categories — game, furniture, bird — resist such definitions. Family resemblance, not shared essence, seemed to be the operative structure.

The empirical broadside came from Eleanor Rosch (E. H. Rosch 1973; E. Rosch 1975; E. Rosch et al. 1976). In a series of influential experiments she documented typicality effects: robins are judged faster and more confidently as birds than penguins are; desk chairs are more chair-like than beanbags. Categories showed graded membership rather than all-or-nothing inclusion. Rosch also identified a privileged basic level — the level of abstraction (dog, not animal and not poodle) at which objects are most efficiently perceived, named, and remembered. Neither observation is compatible with the classical view, which predicts uniform membership within a category and no principled distinction across levels.

15.2.0.2 The Prototype as Psychological Mechanism

Rosch’s work raised the question of representation: if not a rule, what mental structure supports graded membership? The answer on offer was a prototype — a summary representation encoding the central tendency of encountered category members. New objects are categorized by their distance from stored prototypes; items close to the prototype are typical, items distant are atypical.

Formal prototype models predated Rosch’s program. Posner and Keele (1968) showed that subjects who had studied distorted dot patterns without ever seeing the prototype nonetheless recognized the prototype at the end as highly familiar — abstraction was occurring implicitly. Reed (1972) implemented an explicit prototype-matching classifier and showed it outperformed several alternatives on face-classification data. These early models computed a static average from training examples: the prototype was fixed after learning and then used as a lookup.

15.2.0.3 Static Prototypes and the Exemplar Challenge

The clean elegance of static prototype models made them obvious targets for empirical attack. Medin and Schaffer (1978) showed that, on certain category structures, the Context Model (a pure exemplar model) fit human data significantly better than a prototype model could. The critical cases were exception items: an atypical member of one category that closely resembles the other category’s prototype. Exemplar models accommodate exceptions naturally — the deviant item is stored individually. A static prototype, by definition, assimilates the deviant item into the average, smoothing away the very information that drives behavior.

Through the 1980s and early 1990s, the prototype-versus-exemplar debate generated substantial empirical literature. The consensus, codified in Smith and Minda (1998), was that neither account was universally superior: prototype-like effects dominated in some conditions (large training sets, brief study times, transfer to novel stimuli) while exemplar-like effects dominated in others (small training sets, extended learning, highly distinctive items). The field was left with a theoretical draw that demanded either a hybrid architecture or a shift in the question being asked.

15.2.0.4 From Static Averaging to Bayesian Updating

We do not take a stance here on the prototype-versus-exemplar debate. The previous chapter implemented exemplar models in some detail; this chapter sticks with prototype representations precisely so we can understand them on their own terms: what changes when the summary representation is allowed to update incrementally, how its parameters are recovered, and how it behaves under different learning environments.

Within that scope, the focus shifts from what is stored to how the stored summary is updated. A static average treats all observed members the same regardless of when they were encountered, and provides no mechanism for tracking categories that drift over time: the clothing you learn is “fashionable” today may not be fashionable in five years. Several incremental prototype schemes have been proposed in response. The simplest is a delta rule: \(\mu_{t} = \mu_{t-1} + \alpha\,(x_t - \mu_{t-1})\), with a fixed learning rate \(\alpha\) that yields exponentially weighted recency (Gluck and Bower 1988; Estes 1994). You should recognize this rule from reinforcement learning :-) Anderson (1991)’s rational model takes a fully Bayesian route, but over a Dirichlet-process partition rather than a parametric prototype. Each captures part of the picture: the delta rule is principled in form but its \(\alpha\) is a free knob disconnected from the learner’s uncertainty as it faces the stimuli; Anderson’s model is Bayesian but solves a different problem (discovering the partition) and is harder to fit at the trial level.

The Kalman filter (Kalman 1960), developed originally for aerospace navigation, sits [I think, nobody else has used it for this!] in a useful intermediate position for prototype modelling. It is the minimum-variance unbiased estimator for a linear-Gaussian dynamical system: at each step it computes exactly how much weight to place on the new observation versus the accumulated prior, using the relative magnitudes of measurement noise \(R\) and process noise \(Q\). The learning rate (the Kalman gain) is not a free parameter but is derived at each trial from current uncertainty — recovering a delta-rule-like update whose effective \(\alpha\) shrinks as evidence accumulates and grows again when the world is assumed to drift. When \(Q = 0\) the prototype is assumed static and the gain decays monotonically — recovering the classical static average as a special case. When \(Q > 0\) the prototype can drift and the steady-state gain remains positive, allowing perpetual sensitivity to new observations.
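That last special case is easy to verify numerically. The following minimal sketch (not part of the chapter's codebase; toy values) runs the scalar update with \(Q = 0\) and a diffuse prior, and the estimate ends up at the ordinary sample mean:

```r
# Sanity check: with Q = 0 and a diffuse prior, the scalar Kalman
# estimate converges to the classical static average of the data.
x  <- c(2.1, 3.3, 2.7, 3.0, 2.4)   # toy observations
mu <- 0; sigma2 <- 1e6; r <- 1     # diffuse prior, Q = 0 (no predict step)
for (xi in x) {
  K      <- sigma2 / (sigma2 + r)  # gain: near 1 on the first trial, then 1/n
  mu     <- mu + K * (xi - mu)     # nudge estimate toward the observation
  sigma2 <- (1 - K) * sigma2       # uncertainty shrinks monotonically
}
# mu is now (numerically) mean(x)
```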

Applications of Kalman-style updating to human learning appear across multiple domains: classical conditioning (Dayan, Kakade, and Montague 2000), interval timing, and reward prediction. The two-parameter family \((r, q)\) instantiates a spectrum from steep, responsive updating (\(r\) small, gain high) to heavily smoothed abstraction (\(r\) large, gain low, slow forgetting), with \(q\) controlling whether the steady-state gain decays to zero or saturates above it.

15.2.0.5 The Bayesian Connection

Importantly for us [well, let’s be honest, for me, but I’ll take you along for the ride], the Kalman filter is a conceptually Bayesian method: it is the exact Bayesian posterior for a Gaussian state-space model. The prototype at trial \(t\) is the posterior mean \(\mathbb{E}[\mu_t \mid x_{1:t}]\); the covariance \(\Sigma_t\) is the posterior variance. Recognising this identity connects the prototype model to the broader programme of rational analysis (Anderson 1991) and Bayesian models of cognition (Tenenbaum et al. 2011): the learner is performing optimal inference about a latent category structure given noisy, sequentially arriving evidence. The parameters \(r\) and \(q\) then carry clear psychological interpretations — observation noise (how variable are category members around the prototype?) and process noise (how rapidly does the category itself change?) — rather than being purely phenomenological fitting coefficients.

This framing also clarifies the key difference from the GCM: where the GCM stores the full history and retrieves it at decision time, the Kalman prototype model compresses the history into sufficient statistics (mean and covariance) that are updated recursively. The compression is lossless when the generative model is Gaussian and linear — which is why recovery and SBC are cleaner here than in most cognitive models of comparable scope.

15.3 Modeling Dynamic Prototypes: The Kalman Filter

Early prototype models often calculated a static average. However, human learning is dynamic; our understanding evolves with experience. How can we model a prototype that updates incrementally as new examples are encountered?

The Kalman filter provides a suitable mathematical framework. Originally developed for tracking physical systems amidst noise, it’s well-suited for modeling prototype learning because it allows us to:

  • Track an Estimate: The category prototype (its average feature values).
  • Represent Uncertainty: Maintain not just the prototype’s location but also our uncertainty about it.
  • Update Incrementally: Refine the prototype with each new example.
  • Balance Old and New: Optimally combine the current prototype (prior belief) with the new example (evidence).
  • Implement an Adaptive Learning Rate: Automatically adjust how much influence a new example has (the “Kalman gain”) based on current uncertainty. Learning is faster when uncertainty is high and slows as confidence grows.

15.3.1 The Kalman Filter Logic: Updating an Estimate

Imagine tracking a prototype based on a single feature (e.g., average height). The Kalman filter maintains your Current State. At any moment, you have two key pieces of information:

  • Your Best Guess (\(\mu\)): This is your current estimate of the prototype’s average height.
  • Your Confidence (\(\sigma^2\)): This isn’t just whether you’re confident, but how much. It’s the variance around your guess. A small variance means you’re quite sure; a large variance means you’re very unsure.

When a new category member (observation \(x\)) arrives, you measure its height.

  • The Catch: Your measurement tool (or the example itself) isn’t perfect. There’s some inherent “noise” or variability in how examples relate to the true prototype. We represent the uncertainty of this measurement process with another variance, often called \(R\).

The Core Problem: You have your guess (with \(\sigma^2_{prev}\)) and a new measurement (with \(R\)). How do you combine them to get a new, better guess and an updated level of uncertainty?

Enter the Kalman Gain (\(K\)): The “Trust” Knob

The Kalman filter calculates the Kalman Gain \(K\), which tells you how much to trust the new measurement compared to your current guess:

\[K = \frac{\sigma^2_{prev}}{\sigma^2_{prev} + R}\]

If \(\sigma^2_{prev}\) is large relative to \(R\), then \(K \to 1\): trust the measurement almost completely. If \(\sigma^2_{prev}\) is small relative to \(R\), then \(K \to 0\): stick closer to your previous guess.

Updating Your Guess:

You nudge your old guess towards the new measurement. The size of the nudge is determined by the Kalman Gain \(K\):

\[\mu_{new} = \mu_{prev} + K \cdot (x - \mu_{prev})\]

The term \((x - \mu_{prev})\) is the “prediction error” or “innovation” — how much the new data differs from what you expected.

Updating Your Confidence (Uncertainty):

After incorporating new information, uncertainty decreases:

\[\sigma^2_{new} = (1 - K) \cdot \sigma^2_{prev}\]

Since \(K \in (0, 1)\), multiplying by \((1 - K)\) always shrinks uncertainty. The more you trusted the measurement (higher \(K\)), the more your uncertainty shrinks.

Extending to drifting targets. The scalar Kalman filter above assumes a static target: uncertainty monotonically decreases because there is no mechanism for it to grow back. The prototype model in this chapter adds a process noise term \(Q\). Before each measurement update, a prediction step is applied:

\[\sigma^2_{k|k-1} = \sigma^2_{k-1|k-1} + Q\]

This increases uncertainty between observations, reflecting the possibility that the true prototype has shifted since the last trial. The subsequent measurement update then decreases uncertainty via the Kalman gain, exactly as before. At steady state, the two forces balance and uncertainty converges to a fixed value rather than shrinking to zero. When \(Q = 0\) the model reduces to the static-target filter. We treat the scalar \(q\) (the diagonal of \(Q = q \cdot I\)) as a free parameter to be estimated from data, alongside the observation noise \(r\).
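The full predict/update cycle fits in a few lines of R. The sketch below (a toy illustration, not the chapter's agent code; starting values are assumptions) shows the two regimes described above: with \(q = 0\) the gain decays toward zero, while with \(q > 0\) it saturates at a positive steady-state value.

```r
# Scalar Kalman filter with a predict step (process noise q).
scalar_kalman <- function(x, r, q, mu0 = 0, sigma2_0 = 10) {
  mu <- mu0; sigma2 <- sigma2_0
  gains <- numeric(length(x))
  for (i in seq_along(x)) {
    sigma2   <- sigma2 + q             # predict: the target may have drifted
    K        <- sigma2 / (sigma2 + r)  # Kalman gain from current uncertainty
    mu       <- mu + K * (x[i] - mu)   # measurement update of the mean
    sigma2   <- (1 - K) * sigma2       # measurement update of the variance
    gains[i] <- K
  }
  list(mu = mu, sigma2 = sigma2, gains = gains)
}

set.seed(1)
x <- rnorm(200, mean = 3, sd = 1)
static   <- scalar_kalman(x, r = 2, q = 0)    # gain decays toward zero
drifting <- scalar_kalman(x, r = 2, q = 0.5)  # gain saturates above zero
```

With these values the drifting filter's gain settles near 0.39 (the steady-state solution of the two balancing forces), while the static filter's gain has fallen below 0.01 by trial 200.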

15.3.2 Extending to Multiple Features: Multivariate Kalman Filter

Our stimuli have two features (height and position). We need a multivariate version:

  • Your Current State (Vectors and Matrices):
    • Your Best Guess (\(\vec{\mu}\)): Now a vector containing your best guess for each feature.
    • Your Confidence (\(\Sigma\)): Now a covariance matrix. The diagonal elements give the variance for each feature individually; the off-diagonal elements give their covariance — for example, whether taller prototypes are also typically further to the right.
  • A New Hint Arrives: Your new observation (\(\vec{x}\)) is also a vector.
    • Measurement Noise (\(R\)): Also a covariance matrix. Often simplified to a diagonal matrix, assuming measurement errors on different features are independent.
  • The Logic Stays the Same, the Math Uses Matrices:
    • Kalman Gain (\(K\)): A matrix: \(K = \Sigma_{prev} (\Sigma_{prev} + R)^{-1}\)
    • Updating Your Guess (\(\vec{\mu}\)): \(\vec{\mu}_{new} = \vec{\mu}_{prev} + K (\vec{x} - \vec{\mu}_{prev})\)
    • Updating Your Confidence (\(\Sigma\)): The intuition is the same as the scalar case — uncertainty shrinks after each observation — but now in matrix form. The textbook update is \(\Sigma_{new} = (I - K)\Sigma_{prev}\), which can be read as: “keep the share of the prior uncertainty that the new observation did not explain away.” In code, however, we use the algebraically equivalent Joseph form, \(\Sigma_{new} = (I - K)\Sigma_{prev}(I - K)^T + KRK^T\). This longer expression makes the two sources of remaining uncertainty explicit: the leftover prior uncertainty (first term) plus the measurement noise that leaks in through the Kalman gain (second term). The reason for using it in practice is purely numerical: a covariance matrix must be symmetric and have non-negative variances, and the Joseph form preserves both properties even when small rounding errors accumulate over many trials, whereas the shorter form can drift into invalid (asymmetric or negative-variance) territory.

15.3.3 In a Nutshell

The Kalman filter tracks a prototype as a running estimate \(\vec{\mu}\) and a covariance \(\Sigma\) that quantifies how confident the learner is. Each trial has two phases: a predict step that inflates \(\Sigma\) by \(Q\) to allow the prototype to drift between observations, and an update step that pulls \(\vec{\mu}\) toward the new exemplar by an amount set by the Kalman gain \(K\) — high when the learner is uncertain or measurements are precise, low when the learner is already confident or measurements are noisy. The gain also determines how much \(\Sigma\) shrinks. At steady state, drift (\(Q\)) and shrinkage balance, so uncertainty stops collapsing to zero and the model keeps learning indefinitely.

15.3.4 Implementing the Multivariate Update in R

Code
# One update step of a multivariate Kalman filter.
# Returns updated mean vector, covariance matrix, and Kalman gain.
multivariate_kalman_update <- function(mu_prev,      # previous mean vector
                                       sigma_prev,    # previous covariance matrix
                                       observation,   # observed feature vector
                                       r_matrix) {    # observation noise matrix
  mu_prev     <- as.numeric(mu_prev)
  observation <- as.numeric(observation)
  sigma_prev  <- as.matrix(sigma_prev)
  r_matrix    <- as.matrix(r_matrix)

  n_dim <- length(mu_prev)
  I     <- diag(n_dim)

  # Kalman gain: K = Sigma_prev * (Sigma_prev + R)^{-1}
  S     <- sigma_prev + r_matrix
  S_inv <- tryCatch(
    solve(S),
    error = function(e) {
      warning("Matrix inversion failed; using pseudo-inverse.", call. = FALSE)
      MASS::ginv(S)
    }
  )
  K <- sigma_prev %*% S_inv

  # Update mean
  innovation <- observation - mu_prev
  mu_new     <- as.numeric(mu_prev + K %*% innovation)

  # Update covariance — Joseph form for numerical stability
  IK        <- I - K
  sigma_new <- IK %*% sigma_prev %*% t(IK) + K %*% r_matrix %*% t(K)

  # Enforce symmetry (floating-point can introduce tiny asymmetry)
  sigma_new <- (sigma_new + t(sigma_new)) / 2

  list(mu = mu_new, sigma = sigma_new, k = K)
}

This function performs a single update for one category’s prototype when it observes a new member.
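As a quick numerical check, here is one 2-D update worked directly in base R matrix algebra with illustrative values (mirroring the Joseph-form algebra above, without calling the function): the mean moves most of the way toward the observation because prior uncertainty (10) dwarfs measurement noise (1), and the posterior covariance shrinks accordingly.

```r
# One multivariate update with toy values: diffuse prior, precise measurement.
mu_prev    <- c(2.5, 2.5)
sigma_prev <- diag(10, 2)          # large prior uncertainty
x          <- c(4.0, 1.0)          # new observation
R          <- diag(1, 2)           # measurement noise

K         <- sigma_prev %*% solve(sigma_prev + R)   # = (10/11) * I here
mu_new    <- as.numeric(mu_prev + K %*% (x - mu_prev))
IK        <- diag(2) - K
sigma_new <- IK %*% sigma_prev %*% t(IK) + K %*% R %*% t(K)  # Joseph form
```

Because everything is diagonal, the Joseph form and the short form \((I - K)\Sigma_{prev}\) agree exactly: both give posterior variances of \(10/11 \approx 0.91\) per dimension.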


We are fitting a two-parameter model: \(r\) controls both the Kalman gain and decisional precision; \(q\) controls prototype drift speed and steady-state uncertainty. Both are inferred from data on the log scale. There is no category bias; the initial means and covariances are fixed rather than estimated; and the noise structures for both \(R\) and \(Q\) are diagonal. The Stan implementation in §“The Prototype Model in Stan” implements exactly this.

Code
\usetikzlibrary{bayesnet}
\begin{tikzpicture}
  % ── Hyperparameters (const) ─────────────────────────────
  \node[const] (hyp_r) at (0, 0) {$\mu_r,\,\sigma_r$};
  \node[const] (hyp_q) at (6, 0) {$\mu_q,\,\sigma_q$};

  % ── Latent parameters ───────────────────────────────────
  \node[latent] (logr) at (0, -2) {$\log r$};
  \node[latent] (logq) at (6, -2) {$\log q$};

  % ── Trial-level nodes (inside plate) ────────────────────
  \node[obs] (xi)    at (1.5, -4.5) {$\mathbf{x}_i$};
  \node[obs] (ci)    at (3.5, -4.5) {$c_i$};
  \node[det] (state) at (5.5, -4.5) {$\boldsymbol{\mu}_i,\Sigma_i$};
  \node[det] (pi)    at (3.5, -6.0) {$p_i$};
  \node[obs] (yi)    at (3.5, -7.5) {$y_i$};

  % ── Edges ───────────────────────────────────────────────
  \edge {hyp_r} {logr};
  \edge {hyp_q} {logq};
  \edge {logq}  {state};
  \edge {ci}    {state};
  \edge {state} {pi};
  \edge {xi}    {pi};
  \edge {logr}  {pi};
  \edge {pi}    {yi};

  % ── Self-loop on state (sequential update) ─────────────
  \draw[->, >=latex] (state.north east) .. controls +(1.2,0.6) and +(1.2,-0.6) .. (state.south east);

  % ── Plate ───────────────────────────────────────────────
  \plate {trial} {(xi)(ci)(state)(pi)(yi)} {$N$ trials};
\end{tikzpicture}

DAG for the single-subject prototype (Kalman) model. Hyperparameters (dots) feed into the two latent parameters (circles): \(\log r\) (observation noise) and \(\log q\) (process noise). Inside the trial plate, the stimulus \(\mathbf{x}_i\) and feedback \(c_i\) combine with the carried-forward state \((\boldsymbol{\mu}_i, \Sigma_i)\) — which is updated across trials via the Kalman predict/update steps — to yield the deterministic choice probability \(p_i\) (double circle), from which the response \(y_i\) is drawn. The self-loop on the state node represents the sequential prototype update.

15.4 Building the Full Categorization Model

We now integrate this update mechanism into a full agent that learns prototypes for two categories and makes categorization decisions.

15.4.1 Initialization

  • We need separate prototypes (\(\mu_0, \Sigma_0\) and \(\mu_1, \Sigma_1\)) for Category 0 and Category 1.
  • We start with high uncertainty (large initial \(\Sigma\)) and uninformative means at the center of the feature space.
  • We define the observation noise matrix \(R = r_\text{val} \cdot I\). The parameter \(r_\text{val}\) represents how much variability we assume exists within categories or in our perception of stimuli.
  • We define the process noise matrix \(Q = q_\text{val} \cdot I\). The parameter \(q_\text{val}\) controls how much the category prototype may drift between successive observations. Small \(q\) means near-static categories; large \(q\) means rapidly shifting ones.

15.4.2 Categorization Decision

How does the model decide which category a new stimulus \(\vec{x}\) belongs to? The decision has three steps: compute a distance from the stimulus to each prototype, convert distance to similarity, then apply a choice rule.

Step 1 — Distance. The distance between stimulus \(\vec{x}\) and category \(C\)’s prototype is the squared Mahalanobis distance:

\[d_C = (\vec{x} - \vec{\mu}_C)^T (\Sigma_C + R)^{-1} (\vec{x} - \vec{\mu}_C)\]

Breaking this down for two features (height and position), with diagonal covariance matrices as in our model:

\[d_C = \frac{(x_\text{height} - \mu_{C,\text{height}})^2}{\sigma_{C,\text{height}}^2 + r} + \frac{(x_\text{position} - \mu_{C,\text{position}})^2}{\sigma_{C,\text{position}}^2 + r}\]

Each feature contributes its squared difference divided by total uncertainty on that dimension. The three quantities that enter are:

  • \(\vec{\mu}_C\): the prototype mean — the Kalman filter’s current best estimate of the category centre.
  • \(\Sigma_C\): the prototype uncertainty — a matrix tracking how confident the filter is about that estimate. It starts large (high uncertainty) and shrinks as more examples are seen. High uncertainty on dimension \(k\) means differences on that dimension count for less in the distance.
  • \(R = r \cdot I\): observation noise — a fixed matrix reflecting how variable category members are around the prototype, or how noisy perception is. Together with \(\Sigma_C\), it forms the total uncertainty \(\Sigma_C + R\) that normalises each squared difference.

Step 2 — Similarity. Distance is converted to similarity by an exponential decay:

\[\eta_C = \exp\!\left(-\tfrac{1}{2}\, d_C\right)\]

This is the multivariate normal density evaluated at \(\vec{x}\), up to its normalising constant. Similarity is 1 when the stimulus sits exactly on the prototype mean (\(d_C = 0\)) and falls toward 0 as distance grows. The parameter \(r\) controls how steeply: large \(r\) inflates the denominators in \(d_C\), shrinks the distance, and produces a flat bell curve — the model generalises broadly. Small \(r\) produces a sharp bell curve — only stimuli very close to the prototype receive high similarity. This is the role that the sensitivity parameter \(c\) plays in the GCM’s \(\exp(-c \cdot d)\); here it is embedded inside the distance rather than sitting outside it.

Step 3 — Choice rule. The probability of choosing Category 1 is the proportion of total similarity belonging to that category (Luce choice rule):

\[P(\text{Choose Cat 1} \mid \vec{x}) = \frac{\eta_1}{\eta_0 + \eta_1}\]

We work with log-probabilities for numerical stability. There is no separate softmax temperature: \(r\) already controls decisional sharpness through the similarity step.
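The three steps can be traced by hand for a toy stimulus (illustrative values; diagonal uncertainty, as in our \(R = r \cdot I\) setup, using the simplified similarity above rather than the full density):

```r
# Worked toy example of the three-step decision rule.
r   <- 1.0
x   <- c(2.0, 2.0)                      # stimulus (height, position)
mu0 <- c(1.5, 1.5); s0 <- c(0.5, 0.5)   # Cat 0 prototype mean, diag(Sigma_0)
mu1 <- c(3.5, 3.5); s1 <- c(0.5, 0.5)   # Cat 1 prototype mean, diag(Sigma_1)

# Step 1: squared Mahalanobis distance, normalised by total uncertainty
d0 <- sum((x - mu0)^2 / (s0 + r))
d1 <- sum((x - mu1)^2 / (s1 + r))

# Step 2: exponential-decay similarity
eta0 <- exp(-d0 / 2); eta1 <- exp(-d1 / 2)

# Step 3: Luce choice rule
p_cat1 <- eta1 / (eta0 + eta1)          # stimulus is nearer Cat 0, so p < .5
```

Here \(d_0 = 1/3\) and \(d_1 = 3\), giving \(P(\text{Cat 1}) \approx 0.21\): the model leans toward Category 0 but, with this much total uncertainty, far from deterministically.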

15.4.3 The Learning Loop

The agent processes trials sequentially:

  1. Predict: Add \(Q\) to both category covariances — \(\Sigma_{0} \mathrel{+}= Q\), \(\Sigma_{1} \mathrel{+}= Q\) — reflecting that each prototype may have drifted since the last observation.
  2. Decide: Calculate \(P(\text{Choose Cat 1} \mid \vec{x}_i)\) based on the post-prediction \(\mu_0, \Sigma_0, \mu_1, \Sigma_1\).
  3. Respond: Generate a response.
  4. Receive feedback: Observe true category \(y_i\).
  5. Update: Apply the Kalman measurement update to the correct category only.

15.4.4 R Implementation of the Prototype Agent

Each agent’s learning trajectory depends on the order in which stimuli are encountered, and different randomizations produce genuinely different trajectories.

Code
# Prototype model agent using Kalman filter with process noise for categorization.
# r_value   : observation noise variance (scalar > 0)
# q_value   : process noise variance / prototype drift rate (scalar >= 0)
# obs       : matrix of observations (trials x features)
# cat_one   : vector of true category labels (0 or 1) for feedback
# Returns:  tibble with prob_cat1 and sim_response
prototype_kalman <- function(r_value,
                             q_value,
                             obs,
                             cat_one,
                             initial_mu         = NULL,
                             initial_sigma_diag = 10.0,
                             quiet              = TRUE) {
  n_trials   <- nrow(obs)
  n_features <- ncol(obs)

  if (is.null(initial_mu)) {
    mu0_init <- rep(2.5, n_features)
    mu1_init <- rep(2.5, n_features)
  } else {
    mu0_init <- initial_mu[[1]]
    mu1_init <- initial_mu[[2]]
  }

  prototype_cat_0 <- list(mu = mu0_init, sigma = diag(initial_sigma_diag, n_features))
  prototype_cat_1 <- list(mu = mu1_init, sigma = diag(initial_sigma_diag, n_features))
  r_matrix        <- diag(r_value, n_features)
  q_matrix        <- diag(q_value, n_features)

  response_probs <- numeric(n_trials)

  log_sum_exp <- function(v) {
    m <- max(v)
    m + log(sum(exp(v - m)))
  }

  for (i in seq_len(n_trials)) {
    if (!quiet && i %% 20 == 0) cat("Trial", i, "\n")

    current_obs <- as.numeric(obs[i, ])

    # ── Prediction step: add process noise to both prototypes ──────────────
    # Reflects potential drift in category location since last trial.
    # When q_value = 0 this reduces to the static-target filter.
    prototype_cat_0$sigma <- prototype_cat_0$sigma + q_matrix
    prototype_cat_1$sigma <- prototype_cat_1$sigma + q_matrix

    # ── Decision ─────────────────────────────────────────────────────────────
    cov_cat_0 <- prototype_cat_0$sigma + r_matrix
    cov_cat_1 <- prototype_cat_1$sigma + r_matrix

    log_prob_0 <- tryCatch(
      mvtnorm::dmvnorm(current_obs, mean = prototype_cat_0$mu,
                       sigma = cov_cat_0, log = TRUE),
      error = function(e) -Inf
    )
    log_prob_1 <- tryCatch(
      mvtnorm::dmvnorm(current_obs, mean = prototype_cat_1$mu,
                       sigma = cov_cat_1, log = TRUE),
      error = function(e) -Inf
    )

    if (!is.finite(log_prob_0) && !is.finite(log_prob_1)) {
      prob_cat_1 <- 0.5
    } else if (!is.finite(log_prob_0)) {
      prob_cat_1 <- 1.0
    } else if (!is.finite(log_prob_1)) {
      prob_cat_1 <- 0.0
    } else {
      prob_cat_1 <- exp(log_prob_1 - log_sum_exp(c(log_prob_0, log_prob_1)))
    }
    response_probs[i] <- pmax(1e-9, pmin(1 - 1e-9, prob_cat_1))

    # ── Update (measurement update for the correct category only) ────────────
    true_cat <- cat_one[i]
    if (true_cat == 1) {
      upd <- multivariate_kalman_update(prototype_cat_1$mu, prototype_cat_1$sigma,
                                        current_obs, r_matrix)
      prototype_cat_1$mu    <- upd$mu
      prototype_cat_1$sigma <- upd$sigma
    } else {
      upd <- multivariate_kalman_update(prototype_cat_0$mu, prototype_cat_0$sigma,
                                        current_obs, r_matrix)
      prototype_cat_0$mu    <- upd$mu
      prototype_cat_0$sigma <- upd$sigma
    }
  }

  tibble(
    prob_cat1    = response_probs,
    sim_response = rbinom(n_trials, 1, response_probs)
  )
}

# Wrapper: generates per-agent schedule then calls prototype_kalman
simulate_prototype_agent <- function(agent_id, r_value, q_value,
                                     stimulus_info, n_blocks, subject_seed) {
  schedule <- make_subject_schedule(stimulus_info, n_blocks, seed = subject_seed)
  obs      <- as.matrix(schedule[, c("height", "position")])
  cat_one  <- schedule$category_feedback

  result <- prototype_kalman(
    r_value            = r_value,
    q_value            = q_value,
    obs                = obs,
    cat_one            = cat_one,
    initial_mu         = list(rep(2.5, 2), rep(2.5, 2)),
    initial_sigma_diag = 10.0,
    quiet              = TRUE
  )

  schedule |>
    mutate(
      agent_id     = agent_id,
      r_value_true = r_value,
      q_value_true = q_value,
      log_r_true   = log(r_value),
      log_q_true   = log(q_value),
      prob_cat1    = result$prob_cat1,
      sim_response = result$sim_response,
      correct      = as.integer(category_feedback == sim_response)
    ) |>
    group_by(agent_id) |>
    mutate(performance = cumsum(correct) / row_number()) |>
    ungroup()
}

15.4.5 Simulating Categorization Behavior

Let’s simulate behavior on the Kruschke (1993) task. The key parameter here is r_value, the observation noise. This parameter reflects the model’s assumption about how much variability exists within categories or arises from perceptual noise.

  • Low r_value: The model assumes observations are precise representations of the category; each new example exerts strong influence and prototypes update rapidly.
  • High r_value: The model assumes observations are noisy; learning is slower as each example has less influence, but the model generalizes more broadly.
Code
param_df <- expand_grid(
  agent_id  = 1:5,
  r_value   = c(0.5, 2.0),
  q_value   = c(0.001, 0.1, 0.5)
) |>
  mutate(subject_seed = agent_id)

sim_file <- here("simdata", "ch13_prototype_simulated_responses.csv")
if (regenerate_simulations || !file.exists(sim_file)) {
  cat("Regenerating prototype simulations...\n")
  prototype_responses <- future_pmap_dfr(
    list(
      agent_id     = param_df$agent_id,
      r_value      = param_df$r_value,
      q_value      = param_df$q_value,
      subject_seed = param_df$subject_seed
    ),
    simulate_prototype_agent,
    stimulus_info = stimulus_info,
    n_blocks      = n_blocks,
    .options = furrr_options(seed = TRUE)
  )
  write_csv(prototype_responses, sim_file)
  cat("Simulations saved.\n")
} else {
  prototype_responses <- read_csv(sim_file, show_col_types = FALSE)
  cat("Simulations loaded.\n")
}
Simulations loaded.
Code
ggplot(prototype_responses,
       aes(x = trial_within_subject, y = performance, color = factor(r_value_true))) +
  stat_summary(fun = mean, geom = "line", linewidth = 1) +
  stat_summary(fun.data = mean_se, geom = "ribbon", alpha = 0.15,
               aes(fill = factor(r_value_true))) +
  scale_color_viridis_d(option = "plasma", name = "r") +
  scale_fill_viridis_d(option = "plasma", name = "r") +
  facet_wrap(~paste0("q = ", q_value_true), ncol = 3) +
  labs(
    title    = "Prototype Model Learning Performance",
    subtitle = "Columns: process noise q; colors: observation noise r; mean ± SE across agents",
    x = "Trial", y = "Cumulative Accuracy"
  ) +
  ylim(0.3, 1.0)

Interpretation: The panel structure separates the effects of \(r\) (colors) and \(q\) (columns). Within a panel (fixed \(q\)), higher \(r\) slows learning and lowers asymptotic accuracy. Across panels (fixed \(r\)), higher \(q\) introduces persistent uncertainty — the model cannot converge to a stable prototype, so accuracy may plateau below ceiling or even oscillate slightly. When \(q = 0\) the model reduces to a static-target Kalman filter (i.e., it assumes categories never drift). The key pedagogical point: \(q\) and \(r\) are not redundant — they affect different phases of the learning curve and, in non-stationary environments, should become separately identifiable. But the previous chapter taught us not to take identifiability on faith: let's see whether it holds in practice.
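The reduction at \(q = 0\) can be checked directly with a scalar sketch (illustrative, base R): with no process noise and a diffuse prior, the Kalman-filtered mean is just the running average of the observations, i.e. the classic static prototype.

```r
# Scalar sketch (illustrative): with q = 0 and a diffuse prior, the filtered
# mean converges to the sample mean of the observations.
run_filter_1d <- function(xs, mu = 0, sigma2 = 1e6, r = 1, q = 0) {
  for (x in xs) {
    sigma2 <- sigma2 + q             # prediction step (adds process noise)
    k      <- sigma2 / (sigma2 + r)  # Kalman gain
    mu     <- mu + k * (x - mu)      # measurement update
    sigma2 <- (1 - k) * sigma2
  }
  mu
}

xs <- c(3.1, 2.7, 3.4, 2.9)
c(run_filter_1d(xs), mean(xs))  # agree when q = 0 and the prior is diffuse
```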

15.4.6 Visualizing Prototype Evolution

Where do the prototypes end up? Let’s track their movement and uncertainty over trials.

Code
# Track prototype means and covariance matrices over the course of learning.
# Returns a tibble with one row per (trial × category).
track_prototypes <- function(r_value, obs, cat_one,
                             q_value = 0,
                             initial_mu = NULL, initial_sigma_diag = 10.0) {
  n_trials  <- nrow(obs)
  n_features <- ncol(obs)

  if (is.null(initial_mu)) {
    mu0_init <- rep(2.5, n_features)
    mu1_init <- rep(2.5, n_features)
  } else {
    mu0_init <- initial_mu[[1]]
    mu1_init <- initial_mu[[2]]
  }
  prototype_cat_0 <- list(mu = mu0_init, sigma = diag(initial_sigma_diag, n_features))
  prototype_cat_1 <- list(mu = mu1_init, sigma = diag(initial_sigma_diag, n_features))
  r_matrix        <- diag(r_value, n_features)

  history <- list()

  for (i in seq_len(n_trials + 1)) {
    history[[length(history) + 1]] <- tibble(
      trial      = i, category = 0,
      feature1_mean = prototype_cat_0$mu[1],
      feature2_mean = prototype_cat_0$mu[2],
      cov_matrix    = list(prototype_cat_0$sigma)
    )
    history[[length(history) + 1]] <- tibble(
      trial      = i, category = 1,
      feature1_mean = prototype_cat_1$mu[1],
      feature2_mean = prototype_cat_1$mu[2],
      cov_matrix    = list(prototype_cat_1$sigma)
    )

    if (i <= n_trials) {
      # Prediction step
      prototype_cat_0$sigma <- prototype_cat_0$sigma + diag(q_value, n_features)
      prototype_cat_1$sigma <- prototype_cat_1$sigma + diag(q_value, n_features)

      current_obs <- as.numeric(obs[i, ])
      true_cat    <- cat_one[i]
      if (true_cat == 1) {
        upd <- multivariate_kalman_update(prototype_cat_1$mu, prototype_cat_1$sigma,
                                          current_obs, r_matrix)
        prototype_cat_1$mu    <- upd$mu
        prototype_cat_1$sigma <- upd$sigma
      } else {
        upd <- multivariate_kalman_update(prototype_cat_0$mu, prototype_cat_0$sigma,
                                          current_obs, r_matrix)
        prototype_cat_0$mu    <- upd$mu
        prototype_cat_0$sigma <- upd$sigma
      }
    }
  }
  bind_rows(history)
}

# Ellipse helper for uncertainty visualization
get_ellipse <- function(mu, sigma, level = 0.68) {
  mu    <- as.numeric(mu)
  sigma <- as.matrix(sigma)
  if (length(mu) != 2 || !all(dim(sigma) == c(2, 2))) return(NULL)
  ev <- eigen(sigma, symmetric = TRUE, only.values = TRUE)$values
  if (any(ev <= 1e-6)) sigma <- sigma + diag(ncol(sigma)) * 1e-6
  pts <- tryCatch(
    ellipse::ellipse(sigma, centre = mu, level = level),
    error = function(e) NULL
  )
  if (is.null(pts)) return(NULL)
  as_tibble(pts) |> setNames(c("feature1_mean", "feature2_mean"))
}

# Use a fixed seed-1 schedule for the trajectory visualization
traj_schedule <- make_subject_schedule(stimulus_info, n_blocks, seed = 1)

prototype_trajectory <- track_prototypes(
  r_value            = 1.0,
  q_value            = 0.05,
  obs                = as.matrix(traj_schedule[, c("height", "position")]),
  cat_one            = traj_schedule$category_feedback,
  initial_mu         = list(rep(2.5, 2), rep(2.5, 2)),
  initial_sigma_diag = 10.0
)

final_prototypes    <- prototype_trajectory |> filter(trial == max(trial))

ellipse_data <- final_prototypes |>
  rowwise() |>
  mutate(ell = list(get_ellipse(c(feature1_mean, feature2_mean), cov_matrix[[1]]))) |>
  ungroup() |>
  filter(!map_lgl(ell, is.null)) |>
  unnest(ell)

ggplot() +
  geom_point(data = stimulus_info,
             aes(x = position, y = height, color = factor(category_true),
                 shape = factor(category_true)),
             size = 3, alpha = 0.5) +
  geom_path(
    data = prototype_trajectory |> filter(trial <= total_trials),
    aes(x = feature2_mean, y = feature1_mean,
        group = category, color = factor(category)),
    linetype = "dashed",
    arrow = arrow(type = "closed", length = unit(0.08, "inches"), angle = 20)
  ) +
  geom_point(data = final_prototypes,
             aes(x = feature2_mean, y = feature1_mean, color = factor(category)),
             size = 4, shape = 18) +
  geom_path(data = ellipse_data,
            aes(x = feature2_mean, y = feature1_mean,
                group = category, color = factor(category)),
            alpha = 0.7, linewidth = 0.8) +
  scale_color_manual(values = c("0" = "#0072B2", "1" = "#D55E00"), name = "Category") +
  scale_shape_manual(values = c("0" = 16, "1" = 17), name = "Category") +
  labs(
    title    = "Prototype Learning Trajectory (Kalman Filter, r = 1.0)",
    subtitle = "Prototypes start at (2.5, 2.5). Dashed lines show path; ellipses show final 68% uncertainty.",
    x = "Position Feature", y = "Height Feature"
  ) +
  coord_cartesian(
    xlim = range(c(stimulus_info$position, prototype_trajectory$feature2_mean)) + c(-0.5, 0.5),
    ylim = range(c(stimulus_info$height,   prototype_trajectory$feature1_mean)) + c(-0.5, 0.5)
  )

Interpretation: The dashed lines show the path the prototypes took during learning (starting from (2.5, 2.5), the center of the feature space). The final prototype locations are marked with diamonds, and the ellipses represent the model’s final uncertainty (68% confidence) about the true prototype location for each category. We see the prototypes moving towards the center of their respective category members and uncertainty shrinking over time — exactly the incremental, variance-reducing Bayesian update the Kalman filter implements. The shrinking ellipses indicate that once the agent has seen enough exemplars, the covariances are small enough that further learning is negligible, and the prototypes are effectively frozen.
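The "freezing" claim, and why \(q > 0\) would prevent it, can be made concrete with a standalone scalar sketch (illustrative numbers): with \(q = 0\) the posterior variance decays toward zero, while with \(q > 0\) it settles at the positive fixed point of the scalar Riccati recursion.

```r
# Standalone scalar sketch (illustrative): posterior variance across trials.
# q = 0: variance -> 0, the prototype freezes. q > 0: variance settles at the
# Riccati fixed point (-q + sqrt(q^2 + 4 * r * q)) / 2, so learning never stops.
variance_path <- function(n, sigma2 = 10, r = 1, q = 0) {
  out <- numeric(n)
  for (i in seq_len(n)) {
    sigma2 <- sigma2 + q                 # prediction step
    sigma2 <- sigma2 * r / (sigma2 + r)  # variance after measurement update
    out[i] <- sigma2
  }
  out
}

v_static <- variance_path(50, q = 0)
v_drift  <- variance_path(50, q = 0.05)
c(tail(v_static, 1), tail(v_drift, 1))  # ~0.02 versus the fixed point 0.2
```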


15.5 Prior Predictive Check

Before fitting any data, we verify that the priors on log_r and log_q generate only scientifically plausible learning trajectories. We draw \(S = 500\) values from \(\log r \sim \text{Normal}(0, 1)\) — implying \(r \in (0.13, 7.4)\) at \(\pm 2\) SD — and \(\log q \sim \text{Normal}(-2, 1)\), simulate one agent per draw, and plot the envelope.

Code
n_ppc_samples <- 500

ppc_log_r  <- rnorm(n_ppc_samples, LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD)
ppc_log_q  <- rnorm(n_ppc_samples, LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD)
ppc_r      <- exp(ppc_log_r)
ppc_q      <- exp(ppc_log_q)
ppc_sched  <- make_subject_schedule(stimulus_info, n_blocks, seed = 999)

prior_pred_curves <- map_dfr(seq_len(n_ppc_samples), function(s) {
  res <- prototype_kalman(
    r_value   = ppc_r[s],
    q_value   = ppc_q[s],
    obs       = as.matrix(ppc_sched[, c("height", "position")]),
    cat_one   = ppc_sched$category_feedback,
    initial_mu = list(rep(2.5, 2), rep(2.5, 2))
  )
  tibble(
    sample_id = s,
    trial     = seq_len(nrow(ppc_sched)),
    correct   = as.integer(res$sim_response == ppc_sched$category_feedback)
  ) |>
    mutate(cum_acc = cumsum(correct) / row_number())
})

ppc_summary <- prior_pred_curves |>
  group_by(trial) |>
  summarise(
    q05 = quantile(cum_acc, 0.05), q25 = quantile(cum_acc, 0.25),
    q50 = median(cum_acc),
    q75 = quantile(cum_acc, 0.75), q95 = quantile(cum_acc, 0.95),
    .groups = "drop"
  )

ggplot(ppc_summary, aes(x = trial)) +
  geom_ribbon(aes(ymin = q05, ymax = q95), fill = "#009E73", alpha = 0.20) +
  geom_ribbon(aes(ymin = q25, ymax = q75), fill = "#009E73", alpha = 0.40) +
  geom_line(aes(y = q50), color = "#006D4E", linewidth = 1) +
  geom_hline(yintercept = 0.5, linetype = "dashed", color = "grey40") +
  scale_y_continuous(limits = c(0, 1)) +
  labs(
    title    = "Prior Predictive Check: Kalman Filter Prototype Model",
    subtitle = "Ribbons: 50% and 90% prior predictive intervals\nPriors: log(r) ~ Normal(0, 1), log(q) ~ Normal(-2, 1)",
    x = "Trial", y = "Cumulative Accuracy"
  )

The prior predictive envelope spans chance performance (uninformative \(r\) draws) through near-ceiling accuracy (informative \(r\) draws), without implying impossible behavior such as sustained below-chance performance. Not a hard bar to pass, but a good start.
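As a quick numeric check of the implied scale (self-contained, same prior as in the text), we can look at the quantiles of \(r = \exp(\log r)\) under the prior:

```r
# Self-contained check of what log(r) ~ Normal(0, 1) implies on the r scale.
set.seed(1)
r_draws <- exp(rnorm(1e5, mean = 0, sd = 1))
round(quantile(r_draws, c(0.025, 0.5, 0.975)), 2)  # roughly 0.14, 1, 7.1
```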


15.6 The Prototype Model in Stan

Simulations help us understand the model, but we want to fit it to experimental data to estimate its parameters. The free parameters are the observation noise variance r_value and the process noise q_value. We implement the Kalman filter prototype model in Stan to estimate both (on the log scale) from data.

15.6.1 Some implementation notes

15.6.2 Kalman filter in the transformed parameters

The prototype model has two parameters — log_r and log_q. Given both and the observed data (stimuli and feedback), the entire Kalman filter state is deterministic: prototype means, prototype covariances, and the per-trial choice probability \(p_i\) all follow algebraically from the parameters and the trial sequence. Whenever a quantity depends on a parameter but is otherwise deterministic, it belongs in transformed parameters (evaluated once per leapfrog step) rather than transformed data (which is for parameter-free preprocessing) or model{} (which would force a recomputation in generated quantities). We follow that rule here, and log_lik[i] in generated quantities becomes a one-line lookup of p[i].

A note on architectural symmetry with the chapter on exemplars. The previous chapter's GCM choice probabilities were deterministic given the parameters and the data, and therefore belonged in transformed parameters. The same is true for the Kalman filter model: p[i] is a transformed parameter, the model block reduces to a single vectorised bernoulli_lpmf(y | p), and log_lik[i] is a one-line lookup.

15.6.2.1 Path dependence requires LFO-CV

The same path-dependence concern from the previous chapter applies: trial \(t\)’s choice probability depends on all previous stimuli through the Kalman state, so standard PSIS-LOO is invalid. Rather than leaving that as a repeated caveat, we generalise the previous chapter’s LFO-CV machinery to the Kalman case in §“Leave-Future-Out Cross-Validation” below.

reduce_sum and parallelization: As with the GCM, the Kalman prototype model cannot use reduce_sum because successive Kalman steps are causally chained — trial \(t\)’s contribution to the likelihood depends on the filter state from trials \(1, \ldots, t-1\). The entire filter must be computed sequentially. With two parameters (log_r, log_q) the posterior geometry is 2D, and Pathfinder initialization is still reliable because both parameters are unconstrained scalars with prior-dominated tails.

15.6.3 Stan Model

Code
prototype_single_stan <- "
// Kalman Filter Prototype Model — Single Subject, with Process Noise
// Key design:
//   log_r and log_q are the two parameters (both unconstrained).
//   r_value = exp(log_r): observation noise variance
//   q_value = exp(log_q): process noise / prototype drift rate
//   Kalman loop structure: Predict (add Q) → Decide → Update.
//   The full filter is in transformed parameters so log_lik in generated
//   quantities reads p[i] directly without re-running the filter.

data {
  int<lower=1> ntrials;
  int<lower=1> nfeatures;
  array[ntrials] int<lower=0, upper=1> cat_one;
  array[ntrials] int<lower=0, upper=1> y;
  array[ntrials, nfeatures] real obs;

  vector[nfeatures] initial_mu_cat0;
  vector[nfeatures] initial_mu_cat1;
  real<lower=0> initial_sigma_diag;

  real prior_logr_mean;
  real<lower=0> prior_logr_sd;
  real prior_logq_mean;
  real<lower=0> prior_logq_sd;
}

parameters {
  real log_r;
  real log_q;
}

transformed parameters {
  real<lower=0> r_value = exp(log_r);
  real<lower=0> q_value = exp(log_q);

  array[ntrials] real<lower=1e-9, upper=1-1e-9> p;

  {
    vector[nfeatures] mu_cat0 = initial_mu_cat0;
    vector[nfeatures] mu_cat1 = initial_mu_cat1;
    matrix[nfeatures, nfeatures] sigma_cat0 =
      diag_matrix(rep_vector(initial_sigma_diag, nfeatures));
    matrix[nfeatures, nfeatures] sigma_cat1 =
      diag_matrix(rep_vector(initial_sigma_diag, nfeatures));
    matrix[nfeatures, nfeatures] r_matrix =
      diag_matrix(rep_vector(r_value, nfeatures));
    matrix[nfeatures, nfeatures] q_matrix =
      diag_matrix(rep_vector(q_value, nfeatures));
    matrix[nfeatures, nfeatures] I_mat =
      diag_matrix(rep_vector(1.0, nfeatures));

    for (i in 1:ntrials) {
      vector[nfeatures] x = to_vector(obs[i]);

      // ── Prediction step: add process noise to both categories ──────────
      sigma_cat0 = sigma_cat0 + q_matrix;
      sigma_cat1 = sigma_cat1 + q_matrix;

      // ── Decision ────────────────────────────────────────────────────────
      matrix[nfeatures, nfeatures] cov0 = sigma_cat0 + r_matrix;
      matrix[nfeatures, nfeatures] cov1 = sigma_cat1 + r_matrix;

      real log_p0 = multi_normal_lpdf(x | mu_cat0, cov0);
      real log_p1 = multi_normal_lpdf(x | mu_cat1, cov1);
      real prob1  = exp(log_p1 - log_sum_exp(log_p0, log_p1));

      p[i] = fmax(1e-9, fmin(1 - 1e-9, prob1));

      // ── Update (measurement update for the correct category only) ────────
      if (cat_one[i] == 1) {
        vector[nfeatures] innov = x - mu_cat1;
        matrix[nfeatures, nfeatures] S  = sigma_cat1 + r_matrix;
        matrix[nfeatures, nfeatures] K  = mdivide_right_spd(sigma_cat1, S);
        matrix[nfeatures, nfeatures] IK = I_mat - K;
        mu_cat1    = mu_cat1 + K * innov;
        sigma_cat1 = IK * sigma_cat1 * IK' + K * r_matrix * K';
        sigma_cat1 = 0.5 * (sigma_cat1 + sigma_cat1');
      } else {
        vector[nfeatures] innov = x - mu_cat0;
        matrix[nfeatures, nfeatures] S  = sigma_cat0 + r_matrix;
        matrix[nfeatures, nfeatures] K  = mdivide_right_spd(sigma_cat0, S);
        matrix[nfeatures, nfeatures] IK = I_mat - K;
        mu_cat0    = mu_cat0 + K * innov;
        sigma_cat0 = IK * sigma_cat0 * IK' + K * r_matrix * K';
        sigma_cat0 = 0.5 * (sigma_cat0 + sigma_cat0');
      }
    }
  }
}

model {
  target += normal_lpdf(log_r | prior_logr_mean, prior_logr_sd);
  target += normal_lpdf(log_q | prior_logq_mean, prior_logq_sd);
  target += bernoulli_lpmf(y | p);
}

generated quantities {
  vector[ntrials] log_lik;
  real lprior;
  for (i in 1:ntrials)
    log_lik[i] = bernoulli_lpmf(y[i] | p[i]);
  lprior = normal_lpdf(log_r | prior_logr_mean, prior_logr_sd) +
           normal_lpdf(log_q | prior_logq_mean, prior_logq_sd);
}
"

stan_file_proto_single <- here("stan", "ch13_prototype_single.stan")

write_stan_file(prototype_single_stan, dir = here("stan"),
                basename = "ch13_prototype_single.stan")
[1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/ch13_prototype_single.stan"
Code
mod_prototype_single <- cmdstan_model(stan_file_proto_single)

Key aspects of the Stan implementation:

  • Parameters block: Two scalars log_r and log_q. Positivity constraints are handled by exp() in transformed parameters.
  • Prediction step: sigma_cat0 and sigma_cat1 each have q_matrix added before the decision. With q_value = 0 this recovers the static-target filter.
  • Transformed parameters: Runs the full Kalman filter trial-by-trial, updating the two prototype distributions and recording p[i] — the probability of choosing Category 1 on trial \(i\) — before each trial’s update. The update happens on all trials (including the last), consistent with the R simulation.
  • Model block: Priors on log_r and log_q plus a vectorised Bernoulli likelihood using the pre-computed p vector. One line.
  • Generated quantities: log_lik[i] reads directly from p[i] — no re-computation of the Kalman filter required.

15.7 Fitting the Model to Simulated Data

We use Pathfinder initialization before HMC sampling. In theory, this should speed up warmup. I haven’t benchmarked that for this model, shamefully, but it doesn’t hurt to do it.

Code
# Select one simulated agent with r_value = 2.0, q_value = 0.1
agent_to_fit <- prototype_responses |>
  filter(r_value_true == 2.0, q_value_true == 0.1, agent_id == 1)

stopifnot(nrow(agent_to_fit) == total_trials)

proto_data_single <- list(
  ntrials            = nrow(agent_to_fit),
  nfeatures          = 2L,
  cat_one            = agent_to_fit$category_feedback,
  y                  = agent_to_fit$sim_response,
  obs                = as.matrix(agent_to_fit[, c("height", "position")]),
  initial_mu_cat0    = c(2.5, 2.5),
  initial_mu_cat1    = c(2.5, 2.5),
  initial_sigma_diag = 10.0,
  prior_logr_mean    = LOG_R_PRIOR_MEAN,
  prior_logr_sd      = LOG_R_PRIOR_SD,
  prior_logq_mean    = LOG_Q_PRIOR_MEAN,
  prior_logq_sd      = LOG_Q_PRIOR_SD
)

fit_filepath_single <- here("simmodels", "ch13_proto_single_fit.rds")
if (regenerate_simulations || !file.exists(fit_filepath_single)) {
  pf_single <- mod_prototype_single$pathfinder(
    data = proto_data_single, seed = 123, num_paths = 4, refresh = 0
  )

  fit_proto_single <- mod_prototype_single$sample(
    data            = proto_data_single,
    init            = pf_single,
    seed            = 123,
    chains          = 4,
    parallel_chains = min(4, availableCores()),
    iter_warmup     = 1000,
    iter_sampling   = 1500,
    refresh         = 300,
    adapt_delta     = 0.9
  )
  fit_proto_single$save_object(fit_filepath_single)
  cat("Single-agent prototype fit computed and saved.\n")
} else {
  fit_proto_single <- readRDS(fit_filepath_single)
  cat("Loaded existing single-agent prototype fit.\n")
}
Loaded existing single-agent prototype fit.
Code
if (!is.null(fit_proto_single)) {
  param_summary <- fit_proto_single$summary(variables = c("log_r", "r_value", "log_q", "q_value"))
  print(param_summary)

  true_r     <- agent_to_fit$r_value_true[1]
  true_log_r <- log(true_r)
  true_q     <- agent_to_fit$q_value_true[1]
  true_log_q <- log(true_q)
  cat("\nTrue log_r =", round(true_log_r, 3), "  True r =", true_r, "\n")
  cat("True log_q =", round(true_log_q, 3), "  True q =", true_q, "\n")
}
# A tibble: 4 × 10
  variable    mean  median    sd   mad      q5     q95  rhat ess_bulk ess_tail
  <chr>      <dbl>   <dbl> <dbl> <dbl>   <dbl>   <dbl> <dbl>    <dbl>    <dbl>
1 log_r     0.0948  0.0971 0.548 0.527 -0.808   0.991   1.00    1841.    1906.
2 r_value   1.28    1.10   0.749 0.565  0.446   2.69    1.00    1841.    1906.
3 log_q    -1.22   -1.12   0.817 0.776 -2.70   -0.0641  1.00    2026.    1926.
4 q_value   0.390   0.328  0.280 0.244  0.0670  0.938   1.00    2026.    1926.

True log_r = 0.693   True r = 2 
True log_q = -2.303   True q = 0.1 

15.7.1 MCMC Diagnostic Battery

Before reading anything off the posterior, we need to check our diagnostic table.

Code
if (!is.null(fit_proto_single)) {
  diag_tbl <- diagnostic_summary_table(
    fit_proto_single,
    params = c("log_r", "r_value", "log_q", "q_value")
  )
  print(diag_tbl)
}
# A tibble: 6 × 4
  metric                           value threshold pass 
  <chr>                            <dbl> <chr>     <lgl>
1 Divergences (zero tolerance)    0      == 0      TRUE 
2 Max rank-normalised R-hat       1      < 1.01    TRUE 
3 Min bulk ESS                 1841      > 400     TRUE 
4 Min tail ESS                 1906      > 400     TRUE 
5 Min E-BFMI                      0.977  > 0.2     TRUE 
6 Max MCSE / posterior SD         0.0233 < 0.05    TRUE 

The trace plot grid plus rank histograms remain the primary MCMC diagnostics; we save the joint geometry of log_r and log_q for after the prior–posterior overlay, where it directly informs the interpretation.

Code
if (!is.null(fit_proto_single)) {
  p_trace <- bayesplot::mcmc_trace(
    fit_proto_single$draws(c("log_r", "log_q")),
    facet_args = list(ncol = 2)
  ) +
    ggtitle("Trace plots — should look like hairy caterpillars")

  p_rank <- bayesplot::mcmc_rank_overlay(
    fit_proto_single$draws(c("log_r", "log_q")),
    facet_args = list(ncol = 2)
  ) +
    ggtitle("Rank histograms — should be uniform across chains")

  print(p_trace / p_rank)
}

15.7.2 Posterior Predictive Check

The posterior predictive check asks: does the fitted model, when run forward with parameters drawn from the posterior, produce data that look like the observed data? Since p is already in transformed parameters, we can read posterior draws of the choice probability directly and use them to generate posterior predictive responses.

Code
if (!is.null(fit_proto_single)) {
  # Posterior draws of p[i] (one row per draw, one column per trial)
  p_draws <- fit_proto_single$draws("p", format = "matrix")
  n_draws <- nrow(p_draws)
  n_obs   <- ncol(p_draws)

  # Generate posterior predictive responses
  yrep <- matrix(rbinom(n_draws * n_obs, 1, as.vector(p_draws)),
                 nrow = n_draws, ncol = n_obs)

  # Cumulative-accuracy ribbons vs. observed
  obs_cum_acc <- cumsum(agent_to_fit$sim_response == agent_to_fit$category_feedback) /
                 seq_len(n_obs)
                 
  # Use "==" directly and add 0 to coerce to numeric without losing matrix dims
  yrep_correct <- sweep(yrep, 2, as.numeric(agent_to_fit$category_feedback), "==") + 0
  
  yrep_cum_acc <- t(apply(yrep_correct, 1, function(r) cumsum(r) / seq_along(r)))

  cum_acc_summary <- tibble(
    trial = seq_len(n_obs),
    q05   = apply(yrep_cum_acc, 2, quantile, 0.05),
    q25   = apply(yrep_cum_acc, 2, quantile, 0.25),
    q50   = apply(yrep_cum_acc, 2, quantile, 0.50),
    q75   = apply(yrep_cum_acc, 2, quantile, 0.75),
    q95   = apply(yrep_cum_acc, 2, quantile, 0.95),
    obs   = obs_cum_acc
  )

  ggplot(cum_acc_summary, aes(x = trial)) +
    geom_ribbon(aes(ymin = q05, ymax = q95), fill = "#56B4E9", alpha = 0.20) +
    geom_ribbon(aes(ymin = q25, ymax = q75), fill = "#56B4E9", alpha = 0.40) +
    geom_line(aes(y = q50), color = "#0072B2", linewidth = 1) +
    geom_line(aes(y = obs), color = "#D55E00", linewidth = 1) +
    geom_hline(yintercept = 0.5, linetype = "dashed", color = "grey50") +
    labs(
      title    = "Posterior Predictive Check: Cumulative Accuracy",
      subtitle = "Blue: posterior predictive (50% and 90% intervals + median)\nOrange: observed",
      x = "Trial", y = "Cumulative accuracy"
    ) +
    coord_cartesian(ylim = c(0.4, 1))
}

The observed cumulative-accuracy curve should sit comfortably inside the posterior predictive band. If it drifts to the edge, the model is missing structure that the data carry — the most likely candidate being a separate decisional-noise term that the single \(r\) parameter cannot capture.
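A complementary scalar summary is a posterior predictive p-value. The sketch below uses hypothetical stand-ins (`p_hat` and `obs_acc` are invented here; in the chapter's pipeline, yrep and the observed responses would take their place):

```r
# Standalone sketch with hypothetical numbers: posterior predictive p-value
# for final accuracy. Values near 0 or 1 would flag misfit; mid-range is fine.
set.seed(3)
p_hat    <- 0.80                                  # stand-in posterior choice accuracy
yrep_acc <- rbinom(4000, size = 64, p_hat) / 64   # replicated final accuracies
obs_acc  <- 0.78                                  # stand-in observed accuracy
mean(yrep_acc >= obs_acc)                         # tail probability
```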

15.7.3 Prior–Posterior Update Plot

Did the data teach the model anything? We overlay prior and posterior densities on both the computational scale (log_r, log_q, where the priors were defined) and the interpretable scale (r, q, where the cognitive interpretation lives). Heavy overlap means the data were uninformative; a narrow posterior far from the truth means the model is confidently wrong.

Code
if (!is.null(fit_proto_single)) {
  draws_df <- as_draws_df(
    fit_proto_single$draws(c("log_r", "r_value", "log_q", "q_value"))
  )

  set.seed(7)
  n_prior    <- nrow(draws_df)
  prior_logr <- rnorm(n_prior, LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD)
  prior_r    <- exp(prior_logr)
  prior_logq <- rnorm(n_prior, LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD)
  prior_q    <- exp(prior_logq)

  prior_post_df <- bind_rows(
    tibble(scale = "log_r", value = prior_logr,    source = "prior"),
    tibble(scale = "log_r", value = draws_df$log_r, source = "posterior"),
    tibble(scale = "r",     value = prior_r,        source = "prior"),
    tibble(scale = "r",     value = draws_df$r_value, source = "posterior"),
    tibble(scale = "log_q", value = prior_logq,    source = "prior"),
    tibble(scale = "log_q", value = draws_df$log_q, source = "posterior"),
    tibble(scale = "q",     value = prior_q,        source = "prior"),
    tibble(scale = "q",     value = draws_df$q_value, source = "posterior")
  ) |>
    mutate(scale = factor(scale, levels = c("log_r", "r", "log_q", "q")))

  true_lines <- tibble(
    scale = factor(c("log_r", "r", "log_q", "q"),
                   levels = c("log_r", "r", "log_q", "q")),
    value = c(true_log_r, true_r, true_log_q, true_q)
  )

  ggplot(prior_post_df, aes(x = value, fill = source, color = source)) +
    geom_density(alpha = 0.4, linewidth = 0.6) +
    geom_vline(data = true_lines,
               aes(xintercept = value),
               color = "red", linetype = "dashed") +
    facet_wrap(~ scale, scales = "free", ncol = 2) +
    scale_fill_manual(values = c(prior = "#999999", posterior = "#009E73")) +
    scale_color_manual(values = c(prior = "#666666", posterior = "#006D4E")) +
    labs(
      title    = "Prior → Posterior Update: Kalman Prototype Model (r and q)",
      subtitle = "Red dashed line = true generating value",
      x = NULL, y = "Density"
    )
}

The two parameters update very differently. On log_r, the posterior is clearly tighter than the prior and the true value sits well inside the posterior mass — modest but real information gain, with no sign of confident error. On the r scale the prior is so heavy-tailed that the contraction is visually compressed into the bulk near zero; the posterior is mostly a sharper version of the prior’s left mode, which is the right behaviour given a single subject’s worth of trials.

log_q is the more interesting panel. The posterior has shifted noticeably to the right of the prior, with its mode near \(-1\) while the true value sits around \(-2.3\) in the posterior’s left tail. The data were informative — the posterior is not just the prior — but they pulled \(q\) toward larger values than truly generated the sequence. This is the kind of mild, prior-dominated bias one expects when a single trajectory under-determines the process-noise scale: many \((r, q)\) combinations explain the choices about equally well, and the posterior settles in the high-likelihood ridge rather than at the truth. The pairs plot below is the direct way to assess that likely ridge; SBC across many simulated agents is what will tell us whether the offset is systematic.

Code
if (!is.null(fit_proto_single)) {
  bayesplot::mcmc_pairs(
    fit_proto_single$draws(c("log_r", "log_q")),
    diag_fun     = "dens",
    off_diag_fun = "scatter",
    np           = nuts_params(fit_proto_single)
  )
}

The off-diagonal scatter shows the ridge plainly: log_r and log_q are clearly negatively correlated, with a fan that opens toward small log_q (low process noise paired with high observation noise). Along that ridge, decreasing \(q\) and increasing \(r\) leave the predicted choice probabilities largely unchanged — a smaller process step combined with noisier readout produces the same trial-by-trial behaviour as a larger process step with cleaner readout. That is exactly why the marginal posterior on log_q can sit above the true value while still fitting the choices well: the truth lies somewhere on this ridge, and a single subject’s data are not enough to pin down where. What should we do? We could try hierarchical pooling across subjects, or (more likely to work) change the experimental design: more dynamic categories and more trials would help pull the two parameters apart.
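The ridge has a back-of-the-envelope form. In the scalar steady state, what the decision rule "sees" is the predictive variance \(S = \sigma^* + r\), where \(\sigma^*\) is the Riccati fixed point; different \((r, q)\) pairs can produce nearly identical \(S\). A sketch (scalar and illustrative, not the full multivariate model):

```r
# Illustrative scalar sketch of the r-q trade-off: long-run predictive
# variance S = sigma_star + r, with sigma_star the scalar Riccati fixed point.
# Distinct (r, q) pairs can give near-identical S, hence the posterior ridge.
steady_pred_var <- function(r, q) {
  sigma_star <- (-q + sqrt(q^2 + 4 * r * q)) / 2
  sigma_star + r
}

c(steady_pred_var(r = 1.0, q = 0.20),
  steady_pred_var(r = 1.1, q = 0.09))  # within ~1% of each other
```

Only the steady state is matched here; early-trial (transient) behaviour still differs between such pairs, which is exactly why longer or more dynamic designs could break the ridge.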

15.7.4 Randomised LOO-PIT Calibration

For continuous outcomes the preferred calibration check is LOO-PIT. For binary outcomes the equivalent is the randomised LOO-PIT (Säilynoja et al. 2022), which avoids the degenerate two-bin distribution that plain LOO-PIT produces on Bernoulli data. Chapter 13 implements this for the GCM; we apply the same construction here.

Code
if (!is.null(fit_proto_single)) {
  # PSIS-LOO weights from log_lik
  log_lik_mat <- fit_proto_single$draws("log_lik", format = "matrix")
  
  # save_psis = TRUE retains the PSIS weight matrix we need below
  loo_obj     <- loo::loo(log_lik_mat, cores = 4, save_psis = TRUE)
  log_w       <- loo::weights.importance_sampling(loo_obj$psis_object, 
                                                  normalize = TRUE, 
                                                  log = TRUE)

  # Posterior draws of choice probability
  p_draws <- fit_proto_single$draws("p", format = "matrix")

  # Per-observation LOO predictive probability of cat 1
  loo_p_cat1 <- vapply(seq_len(ncol(p_draws)), function(i) {
    sum(exp(log_w[, i]) * p_draws[, i])
  }, numeric(1))

  y_obs <- agent_to_fit$sim_response

  # Randomised PIT for binary y:
  #   if y == 1: U ~ Uniform(1 - p, 1)
  #   if y == 0: U ~ Uniform(0, 1 - p)
  set.seed(2026)
  rloo_pit <- vapply(seq_along(y_obs), function(i) {
    p_i <- loo_p_cat1[i]
    if (y_obs[i] == 1) runif(1, 1 - p_i, 1) else runif(1, 0, 1 - p_i)
  }, numeric(1))

  p_pit <- ggplot(tibble(rloo_pit = rloo_pit), aes(x = rloo_pit)) +
    geom_histogram(aes(y = after_stat(density)), bins = 20,
                   fill = "#56B4E9", color = "white") +
    geom_hline(yintercept = 1, linetype = "dashed", color = "red") +
    labs(
      title    = "Randomised LOO-PIT: Kalman Prototype Model",
      subtitle = "If the model is calibrated, the histogram should look uniform on [0, 1]",
      x = "Randomised PIT value", y = "Density"
    ) +
    coord_cartesian(ylim = c(0, 2.5))
    
  print(p_pit)
}

A roughly uniform histogram is the calibration target. The histogram here is bumpy but not pathological: bars hover around the uniform reference (red dashed line at density 1), with no systematic U-shape (which would flag overconfidence), no central hump (underconfidence), and no monotone skew (a directional bias the model has not captured). The two visible spikes — one near \(0.1\)–\(0.25\) and a taller one near \(0.7\) — are the kind of small-sample noise expected from 64 binary trials; with only a few dozen randomised PIT values, individual bins fluctuate considerably even under perfect calibration. The mild excess on the lower-middle and upper-middle and the slight under-density in the tails (\(>0.85\)) are worth keeping in mind, but with so little data they do not constitute evidence of miscalibration. The check passes, more or less; the meaningful calibration question is whether this pattern persists across many simulated agents, which is what SBC will answer.
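The visual judgement can be complemented by a formal uniformity test. The sketch below uses simulated stand-in values (the chapter's rloo_pit vector would be plugged in directly); with only ~64 values the test's power is low, matching the caution above.

```r
# Standalone sketch: Kolmogorov-Smirnov test of PIT values against Uniform(0,1).
# pit_demo is a simulated stand-in for the chapter's rloo_pit vector.
set.seed(42)
pit_demo <- runif(64)
ks <- ks.test(pit_demo, "punif")
ks$p.value  # small values would indicate miscalibration; here the null is true
```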

15.7.5 LOO with Pareto-k as a Diagnostic, Not a Predictive Score

Code
if (!is.null(fit_proto_single)) {
  loo_proto <- loo::loo(log_lik_mat, cores = 4)
  print(loo_proto)
  plot(loo_proto, label_points = TRUE,
       main = "PSIS-LOO Pareto-k: Kalman Prototype Model")
}

Computed from 6000 by 64 log-likelihood matrix.

         Estimate  SE
elpd_loo    -38.8 3.5
p_loo         1.1 0.2
looic        77.6 7.0
------
MCSE of elpd_loo is NA.
MCSE and ESS estimates assume independent draws (r_eff=1).

Pareto k diagnostic values:
                         Count Pct.    Min. ESS
(-Inf, 0.7]   (good)     63    98.4%   4987    
   (0.7, 1]   (bad)       0     0.0%   <NA>    
   (1, Inf)   (very bad)  1     1.6%   <NA>    
See help('pareto-k-diagnostic') for details.

⚠️ Reframing PSIS-LOO for sequential models. The Kalman filter is path-dependent: trial \(t\)’s choice probability depends on the entire filter state from trials \(1, \ldots, t-1\). Standard PSIS-LOO is not a valid measure of out-of-sample predictive accuracy here — leaving out trial \(t\) implicitly conditions on a counterfactual filter state that the model never actually saw. We compute it anyway, for two reasons: (1) the Pareto-\(\hat{k}\) values are still informative as a localisation diagnostic — high \(\hat{k}\) at specific trials flags observations the model cannot easily reconstruct from the rest, which can reveal where the structural assumptions are stressed; (2) compatibility with downstream tooling. The actual predictive evaluation is done by LFO-CV in the next subsection. The numeric ELPD reported by loo() should not be taken as a generalisation estimate.

What we see here. The Pareto-\(\hat{k}\) scatter is essentially flat: 63 of 64 points sit between roughly \(-0.2\) and \(+0.35\), with \(p_\text{loo} = 1.1\) (close to the parameter count, as expected for a well-behaved fit) and a minimum ESS of \(\approx 5000\) across observations. Read as a localisation diagnostic, this is the good case — no individual trial is a structural outlier that the rest of the data cannot reconstruct. The single “very bad” entry counted in the diagnostic table (1.6%) is the first trial, which has no preceding history and so is not a meaningful PSIS-LOO target; the plot above clips it. The headline \(\text{elpd}_\text{loo} = -38.8\) is reported for completeness but, per the reframing above, should not be interpreted as out-of-sample predictive accuracy — that is what LFO-CV in the next subsection actually measures.

15.7.6 Leave-Future-Out Cross-Validation: Inheriting and Generalising the GCM Implementation

For sequential models, LOO is invalid and LFO-CV is the correct alternative (Bürkner, Gabry, & Vehtari 2020). Chapter 13 implements PSIS-LFO for the GCM as psis_lfo_gcm(). The Kalman prototype model has the same path-dependence structure but a much cheaper per-trial cost, so LFO-CV is even more tractable here. Below is the same algorithm specialised to the prototype model and saved as psis_lfo_kalman(). The two functions share an interface so later chapters with empirical analyses can call either through a common front-end.

The algorithm (Bürkner et al. 2020 Algorithm 1):

  1. Choose a minimum training window \(L\) (we use \(L = 16\), two blocks of the Kruschke task).
  2. Fit the model once on \(y_{1:L}\). This is the reference fit.
  3. For each \(t = L+1, \ldots, T\):
    1. Use the reference fit’s posterior draws \(\theta^{(s)}\) to compute the importance ratio for re-weighting \(\theta\) from \(p(\theta \mid y_{1:t-1})\) relative to the reference.
    2. Smooth the ratios with PSIS, check Pareto \(\hat{k}\).
    3. If \(\hat{k} \leq 0.7\), accept the ratios and use them to evaluate \(\log p(y_t \mid y_{1:t-1})\).
    4. If \(\hat{k} > 0.7\), refit the model on \(y_{1:t-1}\), update the reference, and continue.
  4. Return pointwise ELPD, the refit history, and the \(\hat{k}\) trajectory.
Code
# psis_lfo_kalman(): Leave-future-out CV for the Kalman prototype model.
# Mirrors psis_lfo_gcm() from Chapter 13. Returns pointwise ELPD across the
# evaluation window, the indices at which a refit was triggered, and the
# Pareto-k trajectory.
psis_lfo_kalman <- function(stan_data,
                            stan_model,
                            L           = 16,    # min training window
                            k_thresh    = 0.7,   # PSIS-k refit threshold
                            iter_warmup = 800,
                            iter_sample = 1000,
                            chains      = 2,
                            seed        = 1) {
  T_total <- stan_data$ntrials

  # Helper: build a stan_data list restricted to trials 1:t
  truncate_data <- function(d, t) {
    out <- d
    out$ntrials <- t
    out$cat_one <- d$cat_one[1:t]
    out$y       <- d$y[1:t]
    out$obs     <- d$obs[1:t, , drop = FALSE]
    out
  }

  # Helper: fit on y_{1:t}
  refit <- function(t) {
    d_t <- truncate_data(stan_data, t)
    pf  <- tryCatch(
      stan_model$pathfinder(data = d_t, num_paths = 4, refresh = 0,
                            seed = seed + t),
      error = function(e) NULL
    )
    init_val <- if (!is.null(pf)) pf else 0.5
    stan_model$sample(
      data            = d_t,
      init            = init_val,
      seed            = seed + t,
      chains          = chains,
      parallel_chains = chains,
      iter_warmup     = iter_warmup,
      iter_sampling   = iter_sample,
      refresh         = 0,
      adapt_delta     = 0.9
    )
  }

  # Helper: given a fitted model, evaluate log_lik at trial t
  # using cmdstanr's $generate_quantities() so we don't re-run MCMC.
  eval_log_lik_at <- function(fit_ref, t) {
    # Build the "extended" data that includes trial t in the likelihood
    d_ext <- truncate_data(stan_data, t)

    # generate_quantities() is a method of the model object; the reference
    # fit supplies the posterior draws via fitted_params
    gq <- stan_model$generate_quantities(
      fitted_params = fit_ref,
      data          = d_ext,
      seed          = seed + t,
      parallel_chains = chains
    )
    
    # log_lik[t] for each posterior draw
    ll_mat <- gq$draws("log_lik", format = "matrix")
    ll_mat[, t]
  }

  # ── Initial fit on the minimum window ─────────────────────────────────────
  cat("Initial LFO fit on trials 1 to", L, "...\n")
  fit_ref <- refit(L)
  ref_window <- L

  # Storage
  pointwise_elpd <- rep(NA_real_, T_total)
  k_hat_traj     <- rep(NA_real_, T_total)
  refit_at       <- integer(0)

  # ── Sequential evaluation ─────────────────────────────────────────────────
  log_ratios_accum <- numeric(0)  # log p(y_{ref+1:t-1} | theta) accumulation
  for (t in (L + 1):T_total) {
    # Build importance ratios for theta_ref vs theta_{1:t-1}
    if (length(log_ratios_accum) == 0) {
      log_w <- rep(0, posterior::ndraws(fit_ref$draws()))
    } else {
      log_w <- log_ratios_accum  # raw log ratios; psis() handles normalization
    }

    # PSIS smoothing
    psis_obj <- tryCatch(
      loo::psis(log_ratios = matrix(log_w, ncol = 1), r_eff = 1),
      error = function(e) NULL
    )
    k_hat <- if (!is.null(psis_obj)) loo::pareto_k_values(psis_obj) else Inf
    k_hat_traj[t] <- k_hat

    if (!is.finite(k_hat) || k_hat > k_thresh) {
      # Refit on y_{1:t-1}
      cat("  trial", t, ": k_hat =", round(k_hat, 2), "→ refitting\n")
      fit_ref <- refit(t - 1)
      ref_window <- t - 1
      refit_at <- c(refit_at, t)
      log_ratios_accum <- numeric(0)
      log_w <- rep(0, posterior::ndraws(fit_ref$draws()))
      psis_obj <- loo::psis(log_ratios = matrix(log_w, ncol = 1), r_eff = 1)
    }

    # Evaluate log p(y_t | y_{1:t-1}, theta) for each draw
    ll_t <- eval_log_lik_at(fit_ref, t)

    # Importance-weighted estimate of log p(y_t | y_{1:t-1})
    log_w_norm <- as.vector(weights(psis_obj, normalize = TRUE, log = TRUE))
    pointwise_elpd[t] <- matrixStats::logSumExp(log_w_norm + ll_t)

    # Accumulate log p(y_t | theta) for the next iteration's importance ratio
    if (length(log_ratios_accum) == 0) {
      log_ratios_accum <- ll_t
    } else {
      log_ratios_accum <- log_ratios_accum + ll_t
    }
  }

  list(
    pointwise_elpd = pointwise_elpd,
    refit_at       = refit_at,
    k_hat_traj     = k_hat_traj,
    L              = L,
    elpd_lfo       = sum(pointwise_elpd, na.rm = TRUE),
    n_refits       = length(refit_at)
  )
}

Note on the importance-ratio accumulation. The trick that makes PSIS-LFO efficient is that the importance ratio for \(p(\theta \mid y_{1:t-1})\) vs the reference \(p(\theta \mid y_{1:L})\) is a running product: for each new trial \(t\) we just multiply the existing per-draw weight by \(p(y_{t-1} \mid \theta)\). We never need to refit so long as the smoothed Pareto-\(\hat{k}\) stays below the threshold. When it crosses 0.7, we refit and reset the running product. The accumulation in log_ratios_accum above implements this — every time we evaluate log_lik[t], we add it to the cumulative log-ratio for the next iteration.
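The additivity that makes this work can be checked in isolation. A minimal base-R sketch, with hypothetical per-draw log-likelihoods standing in for the `ll_t` values above:

```r
# The log importance ratio for p(theta | y_{1:t-1}) relative to the
# reference p(theta | y_{1:L}) is, up to a constant,
#   sum_{j = L+1}^{t-1} log p(y_j | theta)
# per draw -- so a running sum reproduces it exactly.
set.seed(2)
S <- 200; L <- 4; T_total <- 10
ll <- matrix(rnorm(S * T_total, -0.7, 0.3), nrow = S)  # ll[s, j] ~ log p(y_j | theta_s)

log_ratios_accum <- numeric(S)
for (t in (L + 1):T_total) {
  # at this point log_ratios_accum covers trials (L+1):(t-1);
  # psis() would smooth it here before trial t is evaluated
  log_ratios_accum <- log_ratios_accum + ll[, t]  # accumulate for the next iteration
}

# identical to computing the block sum in one shot:
all.equal(log_ratios_accum, rowSums(ll[, (L + 1):T_total]))
```

This is why `psis_lfo_kalman()` never recomputes old log-likelihoods: each trial contributes one additive term, and a refit simply resets the sum to zero.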

We now run this on the same single-agent fit, and validate it against exact one-step-ahead refits at every 8th trial. If PSIS-LFO is doing its job, the two ELPD curves should agree closely.

Code
lfo_cache_file <- here("simdata", "ch13_lfo_kalman_results.rds")
if (regenerate_simulations || !file.exists(lfo_cache_file)) {
  cat("Running PSIS-LFO for the prototype model...\n")
  lfo_psis <- psis_lfo_kalman(
    stan_data  = proto_data_single,
    stan_model = mod_prototype_single,
    L          = 16,
    k_thresh   = 0.7
  )

  # ── Exact one-step-ahead refits at every 8th trial ────────────────────────
  exact_t <- seq(24, total_trials, by = 8)
  exact_elpd <- numeric(length(exact_t))
  for (k in seq_along(exact_t)) {
    t <- exact_t[k]
    cat("  exact refit at t =", t, "\n")
    d_t <- proto_data_single
    d_t$ntrials <- t - 1
    d_t$cat_one <- proto_data_single$cat_one[1:(t - 1)]
    d_t$y       <- proto_data_single$y[1:(t - 1)]
    d_t$obs     <- proto_data_single$obs[1:(t - 1), , drop = FALSE]

    pf  <- tryCatch(
      mod_prototype_single$pathfinder(data = d_t, num_paths = 4, refresh = 0),
      error = function(e) NULL
    )
    init_val <- if (!is.null(pf)) pf else 0.5
    fit_t <- mod_prototype_single$sample(
      data = d_t, init = init_val, chains = 2, parallel_chains = 2,
      iter_warmup = 800, iter_sampling = 1000, refresh = 0, adapt_delta = 0.9
    )

    # Evaluate log_lik at trial t
    d_ext <- proto_data_single
    d_ext$ntrials <- t
    d_ext$cat_one <- proto_data_single$cat_one[1:t]
    d_ext$y       <- proto_data_single$y[1:t]
    d_ext$obs     <- proto_data_single$obs[1:t, , drop = FALSE]
    
    # generate_quantities() is a method of the model object; fit_t supplies the draws
    gq <- mod_prototype_single$generate_quantities(fitted_params = fit_t, data = d_ext)
    
    ll_mat <- gq$draws("log_lik", format = "matrix")
    exact_elpd[k] <- matrixStats::logSumExp(ll_mat[, t]) - log(nrow(ll_mat))
  }

  saveRDS(list(lfo_psis = lfo_psis,
               exact_t = exact_t, exact_elpd = exact_elpd),
          lfo_cache_file)
  cat("LFO results computed and saved.\n")
} else {
  cached <- readRDS(lfo_cache_file)
  lfo_psis <- cached$lfo_psis
  exact_t <- cached$exact_t
  exact_elpd <- cached$exact_elpd
  cat("Loaded existing LFO results.\n")
}
Loaded existing LFO results.
Code
if (exists("lfo_psis")) {
  cat("PSIS-LFO sum ELPD:", round(lfo_psis$elpd_lfo, 2), "\n")
  cat("Number of refits triggered:", lfo_psis$n_refits,
      "out of", total_trials - lfo_psis$L, "trials\n")

  # Plot: PSIS-LFO pointwise vs exact refits
  comp_df <- tibble(
    trial    = seq_along(lfo_psis$pointwise_elpd),
    psis_elpd = lfo_psis$pointwise_elpd
  )
  exact_df <- tibble(trial = exact_t, exact_elpd = exact_elpd)

  p_lfo <- ggplot(comp_df, aes(x = trial, y = psis_elpd)) +
    geom_line(color = "#0072B2") +
    geom_point(data = exact_df,
               aes(x = trial, y = exact_elpd),
               color = "#D55E00", size = 2) +
    labs(
      title    = "PSIS-LFO vs. Exact One-Step-Ahead Refits",
      subtitle = "Blue line: PSIS-LFO pointwise ELPD; orange points: exact refits at every 8th trial",
      x = "Trial", y = "log p(y_t | y_{1:t-1})"
    )

  # Plot: Pareto-k trajectory and refit triggers
  k_df <- tibble(trial = seq_along(lfo_psis$k_hat_traj),
                 k_hat = lfo_psis$k_hat_traj)
  refit_df <- tibble(trial = lfo_psis$refit_at)

  p_k <- ggplot(k_df, aes(x = trial, y = k_hat)) +
    geom_line(color = "#009E73") +
    geom_hline(yintercept = 0.7, linetype = "dashed", color = "red") +
    geom_vline(data = refit_df, aes(xintercept = trial),
               linetype = "dotted", color = "red", alpha = 0.5) +
    labs(
      title    = "Pareto-k trajectory and refit triggers",
      subtitle = "Dashed red: k = 0.7 threshold; dotted vertical lines = refits",
      x = "Trial", y = "Pareto k"
    )

  print(p_lfo / p_k)
}
PSIS-LFO sum ELPD: -32.25 
Number of refits triggered: 1 out of 48 trials

Reading the output. The blue line is the PSIS-LFO pointwise ELPD over the evaluation window (\(t = L+1, \ldots, T\) with \(L = 16\)); the orange points are exact one-step-ahead refits at every 8th trial. The two sit on top of each other within Monte Carlo noise: the orange points fall on the blue trace at every trial they appear, so the importance-sampling shortcut is reproducing the exact one-step-ahead predictive density rather than drifting away from it.

The Pareto-\(\hat{k}\) panel tells the same story more directly. There is one initial spike at \(t = L+1 = 17\) that crosses the \(0.7\) threshold and triggers the first (and only) refit — the dotted vertical line — which makes sense: extending the reference fit by one trial after only \(L\) trials of training data is the regime where importance ratios are most volatile. After that single refit, \(\hat{k}\) collapses to negative values and stays well below the threshold for the rest of the sequence, drifting further down (toward \(-0.8\)) as more data accumulates and the posterior becomes increasingly stable. No further refits are needed, so for this fit PSIS-LFO is essentially free after the first step.

The pointwise ELPD itself shows the expected post-warmup pattern: very negative around \(t \approx 17\) (the model has only just begun learning), rapidly rising to a band roughly between \(-1.5\) and \(-0.3\) as the prototype filter converges, with trial-to-trial variation reflecting harder vs. easier stimuli rather than systematic drift. Summing across the window gives the LFO ELPD that should be used for model comparison — not the elpd_loo from the previous subsection. If, on a different fit or scenario, the two curves diverged systematically or \(\hat{k}\) stayed above \(0.7\) for many consecutive trials, the exact refit pipeline (slow but unambiguous) is the fallback.

15.7.7 Prior Sensitivity Analysis

Code
if (!is.null(fit_proto_single)) {
  ps <- priorsense::powerscale_sensitivity(fit_proto_single)
  print(ps)
  priorsense::powerscale_plot_dens(fit_proto_single, variables = c("log_r", "r_value", "log_q", "q_value"))
}
Sensitivity based on cjs_dist
Prior selection: all priors
Likelihood selection: all data

 variable prior likelihood                                diagnosis
    log_r 0.079      0.096            potential prior-data conflict
    log_q 0.081      0.137            potential prior-data conflict
  r_value 0.066      0.197            potential prior-data conflict
  q_value 0.208      0.126            potential prior-data conflict
     p[1]   NaN        NaN                                     <NA>
     p[2] 0.052      0.103            potential prior-data conflict
     p[3] 0.077      0.128            potential prior-data conflict
     p[4] 0.090      0.134            potential prior-data conflict
     p[5] 0.065      0.071            potential prior-data conflict
     p[6] 0.116      0.079            potential prior-data conflict
     p[7] 0.029      0.123                                        -
     p[8] 0.028      0.129                                        -
     p[9] 0.038      0.121                                        -
    p[10] 0.034      0.115                                        -
    p[11] 0.059      0.118            potential prior-data conflict
    p[12] 0.059      0.085            potential prior-data conflict
    p[13] 0.034      0.144                                        -
    p[14] 0.035      0.100                                        -
    p[15] 0.083      0.104            potential prior-data conflict
    p[16] 0.073      0.030 potential strong prior / weak likelihood
    p[17] 0.032      0.110                                        -
    p[18] 0.047      0.117                                        -
    p[19] 0.061      0.096            potential prior-data conflict
    p[20] 0.071      0.058            potential prior-data conflict
    p[21] 0.062      0.095            potential prior-data conflict
    p[22] 0.090      0.091            potential prior-data conflict
    p[23] 0.083      0.054            potential prior-data conflict
    p[24] 0.068      0.135            potential prior-data conflict
    p[25] 0.038      0.097                                        -
    p[26] 0.065      0.062            potential prior-data conflict
    p[27] 0.089      0.084            potential prior-data conflict
    p[28] 0.029      0.140                                        -
    p[29] 0.066      0.066            potential prior-data conflict
    p[30] 0.065      0.084            potential prior-data conflict
    p[31] 0.052      0.106            potential prior-data conflict
    p[32] 0.096      0.120            potential prior-data conflict
    p[33] 0.038      0.086                                        -
    p[34] 0.055      0.087            potential prior-data conflict
    p[35] 0.059      0.065            potential prior-data conflict
    p[36] 0.130      0.090            potential prior-data conflict
    p[37] 0.065      0.103            potential prior-data conflict
    p[38] 0.098      0.076            potential prior-data conflict
    p[39] 0.040      0.118                                        -
    p[40] 0.090      0.130            potential prior-data conflict
    p[41] 0.110      0.051            potential prior-data conflict
    p[42] 0.029      0.141                                        -
    p[43] 0.025      0.106                                        -
    p[44] 0.057      0.075            potential prior-data conflict
    p[45] 0.037      0.112                                        -
    p[46] 0.037      0.124                                        -
    p[47] 0.094      0.097            potential prior-data conflict
    p[48] 0.184      0.034 potential strong prior / weak likelihood
    p[49] 0.025      0.115                                        -
    p[50] 0.068      0.074            potential prior-data conflict
    p[51] 0.085      0.122            potential prior-data conflict
    p[52] 0.032      0.129                                        -
    p[53] 0.030      0.091                                        -
    p[54] 0.082      0.092            potential prior-data conflict
    p[55] 0.063      0.035 potential strong prior / weak likelihood
    p[56] 0.102      0.085            potential prior-data conflict
    p[57] 0.040      0.105                                        -
    p[58] 0.026      0.109                                        -
    p[59] 0.076      0.084            potential prior-data conflict
    p[60] 0.033      0.139                                        -
    p[61] 0.036      0.046                                        -
    p[62] 0.133      0.084            potential prior-data conflict
    p[63] 0.067      0.110            potential prior-data conflict
    p[64] 0.090      0.098            potential prior-data conflict

powerscale_sensitivity() perturbs the prior and the likelihood analytically (without re-running MCMC) and reports, for each parameter, a sensitivity index — the cumulative Jensen–Shannon distance between the unperturbed posterior and the power-scaled one. Values above the default threshold of \(0.05\) count as non-trivial sensitivity. The diagnosis column combines the two indices: when both the prior and the likelihood sensitivity exceed the threshold, the row is flagged as potential prior–data conflict; when the prior sensitivity exceeds it but the likelihood sensitivity does not, the flag is strong prior / weak likelihood.
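What power-scaling does can be seen exactly in a conjugate toy model (an assumed normal–normal setup, not the Stan fit): raising the prior to a power \(\alpha\) rescales its precision, and the posterior shift as \(\alpha\) varies is what the sensitivity index summarises.

```r
# Conjugate normal-normal toy: prior N(0, 1), n observations with known
# sd 1 and sample mean ybar. Power-scaling the prior by alpha turns its
# precision from 1 into alpha, so the power-scaled posterior has a
# closed form. (Illustrative helper, not part of priorsense.)
powerscaled_posterior <- function(alpha, ybar, n) {
  prec <- alpha + n                 # alpha-scaled prior precision + data precision
  c(mean = n * ybar / prec, sd = 1 / sqrt(prec))
}

# With little data the posterior moves visibly as alpha varies
# (large prior sensitivity); with lots of data it barely moves:
powerscaled_posterior(0.8,  ybar = 1, n = 4)
powerscaled_posterior(1.25, ybar = 1, n = 4)
powerscaled_posterior(0.8,  ybar = 1, n = 400)
powerscaled_posterior(1.25, ybar = 1, n = 400)
```

The single-subject q_value row above is the \(n = 4\) regime in miniature: the prior term is still a material share of the posterior precision.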

For this single-agent fit the numbers tell a coherent story. log_r (prior \(\psi = 0.08\), likelihood \(\psi = 0.10\)) and log_q (prior \(0.08\), likelihood \(0.14\)) both sit just at or above the soft threshold, with the likelihood doing slightly more of the work — the data are informative, but the prior is not negligible. On the natural scale the picture sharpens: r_value is dominated by the likelihood (prior \(0.07\), likelihood \(0.20\)), while q_value flips the other way (prior \(0.21\), likelihood \(0.13\)) — the prior on \(q\) is materially shaping the posterior, exactly as the prior–posterior overlay and the log_r–log_q ridge already suggested. Almost every per-trial \(p[i]\) is flagged as well, but those flags are induced by the parameter-level fragility flowing through the deterministic Kalman recursion, not by independent likelihood issues at each trial.

Practically, this is not a model failure — it is a small-data diagnosis. With one subject and 64 trials the prior on log_q cannot be treated as a passive default; the conclusion about process noise depends on it. The two routes out are the standard ones: tighter priors backed by an explicit justification, or hierarchical pooling across subjects so that log_q is informed by the population rather than by a single trajectory along the ridge. The density plot overlays posteriors at \(\alpha \in \{0.8, 1, 1.25\}\) for both prior- and likelihood-scaling and confirms the table visually: visible (but not dramatic) shifts on log_q and q_value, near-overlapping curves on log_r and r_value.


15.8 Simulation-Based Calibration Checks (SBC)

Parameter recovery at specific true values is a necessary but not sufficient validation. SBC checks that the full posterior is correctly calibrated across the entire prior distribution — that if the true log_r and log_q were drawn from their priors, the posterior would on average contain each true value in the appropriate fraction of its probability mass.

The prototype model’s two scalar parameters make the resulting ECDF-difference plot particularly legible: the log_r curve should show tight calibration (the data are informative), while the log_q curve may be wider under the static paradigm (partial confound with log_r).
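The statistic SBC builds on is simple: draw the parameter from its prior, simulate data, fit, and record the rank of the true value among posterior draws. In a conjugate toy model where the posterior is exact, the uniformity of these ranks can be verified directly (base R sketch, not the Stan pipeline below):

```r
# SBC rank statistic in a conjugate normal model: prior N(0, 1), five
# observations with known sd 1, exact posterior N(5 * ybar / 6, 1/6).
set.seed(3)
ranks <- replicate(1000, {
  theta <- rnorm(1)                               # draw from the prior
  ybar  <- mean(rnorm(5, theta, 1))               # simulate data
  draws <- rnorm(999, 5 * ybar / 6, sqrt(1 / 6))  # exact posterior draws
  sum(draws < theta)                              # rank in 0..999
})
# A correctly calibrated posterior makes these ranks uniform on {0, ..., 999};
# hist(ranks) is flat up to sampling noise.
```

Deviations from flatness are the diagnostic signal: a U-shape means the posterior is too narrow, a central hump means it is too wide, and a skew means systematic bias.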

Code
sbc_proto_filepath <- here("simdata", "ch13_sbc_proto_results.rds")
if (regenerate_simulations || !file.exists(sbc_proto_filepath)) {
  n_sbc_iterations_proto <- 500   # use >= 1000 for publication-quality SBC

  proto_sbc_generator <- SBC_generator_function(
    function() {
      # 1a. Draw true parameters from their priors
      log_r_true <- rnorm(1, LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD)
      r_true     <- exp(log_r_true)
      log_q_true <- rnorm(1, LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD)
      q_true     <- exp(log_q_true)

      # 1b. Simulate one dataset with its own randomized schedule
      sbc_sched <- make_subject_schedule(stimulus_info, n_blocks,
                                         seed = sample.int(1e8, 1))
      sim       <- prototype_kalman(
        r_value    = r_true,
        q_value    = q_true,
        obs        = as.matrix(sbc_sched[, c("height", "position")]),
        cat_one    = sbc_sched$category_feedback,
        initial_mu = list(rep(2.5, 2), rep(2.5, 2))
      )

      list(
        variables = list(log_r = log_r_true, log_q = log_q_true),
        generated = list(
          ntrials            = nrow(sbc_sched),
          nfeatures          = 2L,
          cat_one            = as.integer(sbc_sched$category_feedback),
          y                  = as.integer(sim$sim_response),
          obs                = as.matrix(sbc_sched[, c("height", "position")]),
          initial_mu_cat0    = c(2.5, 2.5),
          initial_mu_cat1    = c(2.5, 2.5),
          initial_sigma_diag = 10.0,
          prior_logr_mean    = LOG_R_PRIOR_MEAN,
          prior_logr_sd      = LOG_R_PRIOR_SD,
          prior_logq_mean    = LOG_Q_PRIOR_MEAN,
          prior_logq_sd      = LOG_Q_PRIOR_SD
        )
      )
    }
  )

  proto_sbc_backend <- SBC_backend_cmdstan_sample(
    mod_prototype_single,
    iter_warmup   = 1000,
    iter_sampling = 1000,
    chains        = 2,
    adapt_delta   = 0.9,
    refresh       = 0
  )

  proto_datasets <- generate_datasets(
    proto_sbc_generator,
    n_sims = n_sbc_iterations_proto
  )

  sbc_results_proto <- compute_SBC(
    proto_datasets,
    proto_sbc_backend,
    cache_mode     = "results",
    cache_location = here("simdata", "ch13_sbc_proto_cache"),
    keep_fits      = FALSE
  )
  saveRDS(sbc_results_proto, sbc_proto_filepath)
  cat("SBC prototype results computed and saved.\n")
} else {
  sbc_results_proto <- readRDS(sbc_proto_filepath)
  cat("Loaded existing SBC prototype results.\n")
}
Loaded existing SBC prototype results.
Code
plot_ecdf_diff(sbc_results_proto)

Code
plot_rank_hist(sbc_results_proto)

Code
# ── Diagnostic issue investigation ───────────────────────────────────────────
backend_diags  <- sbc_results_proto$backend_diagnostics
default_diags  <- sbc_results_proto$default_diagnostics |>
  dplyr::select(sim_id, max_rhat, min_ess_to_rank)
true_params_wide <- sbc_results_proto$stats |>
  dplyr::select(sim_id, variable, simulated_value) |>
  tidyr::pivot_wider(names_from = variable, values_from = simulated_value)

diag_df <- backend_diags |>
  dplyr::left_join(default_diags,    by = "sim_id") |>
  dplyr::left_join(true_params_wide, by = "sim_id") |>
  dplyr::mutate(has_issue = n_divergent > 0 | n_max_treedepth > 0 | max_rhat > 1.01)

p_diag_r_q <- ggplot(diag_df, aes(x = log_r, y = log_q, color = has_issue)) +
  geom_point(alpha = 0.7, size = 2) +
  scale_color_manual(values = c("FALSE" = "#0072B2", "TRUE" = "#D55E00")) +
  labs(title    = "Diagnostic Issues in Parameter Space",
       subtitle = "Orange points: divergence, max treedepth, or Rhat > 1.01",
       x = "True log(r)", y = "True log(q)") +
  theme(legend.position = "bottom")

print(p_diag_r_q)

Code
# ── Parameter recovery on SBC simulations ────────────────────────────────────
recovery_df <- sbc_results_proto$stats

p_sbc_recovery <- ggplot(recovery_df, aes(x = simulated_value, y = mean)) +
  geom_point(alpha = 0.4, color = "#0072B2") +
  geom_abline(intercept = 0, slope = 1, color = "red", linetype = "dashed") +
  facet_wrap(~variable, scales = "free") +
  labs(
    title    = "Parameter Recovery Across SBC Simulations",
    subtitle = "Posterior mean vs True generated value",
    x = "True Value (simulated_value)",
    y = "Estimated Value (Posterior Mean)"
  ) +
  theme_bw()

print(p_sbc_recovery)


15.9 Non-Stationary Environments: When Does \(q\) Become Identifiable?

The SBC result above hints at a structural tension. Under the static Kruschke paradigm, \(\log r\) shows mild systematic overestimation while \(\log q\) exhibits a prior-dominated, overdispersed posterior — ranks pile up in the middle of the distribution rather than spreading uniformly, indicating the posterior is conservatively wide rather than overconfident. Crucially, both curves remain within the simultaneous bands, so the model is calibrated: the posterior honestly represents its own uncertainty. But the patterns reveal that \(\log q\) is weakly identified when nothing drifts — many \((r, q)\) combinations explain static choices about equally well, and the posterior spreads across a wide range of process-noise values. This is not a coding bug or a sampler problem: it is a genuine identifiability limitation caused by a mismatch between the parameter (process noise controlling prototype drift) and the task (fixed category centres with no drift to detect).

Three generative processes stress this identifiability question differently. We implement all three and run the full validation battery on each.

As established in the canonical scenario framework introduced in the previous chapter, all three categorization chapters use the same three scenarios in the same order (static → contingent shift → drift), reflecting increasing assumption match between the scenario’s data-generating process and the model’s structural commitments.

  1. Static Kruschke (baseline). Fixed category labels. The Kalman filter’s process-noise prior assumes smooth Gaussian drift; under stationary data, no amount of \(q\) improves fit once transient learning is over. We expect \(\log q\) to remain weakly identified — this is the failure case we need to diagnose clearly.

  2. Performance-contingent abrupt shifts. After the agent reaches a streak of \(k\) consecutive correct responses, all category labels flip. Post-flip recovery is fast for agents whose \(q\) lets the covariance re-inflate after surprise, and slow for agents whose \(q\) is tiny — so \(q\) acquires a direct behavioural signature. But the true process has abrupt jumps, which the filter’s Gaussian-drift prior does not literally contain; this is a misspecified-but-informative scenario. Note that these shifts are endogenous to the agent’s own accuracy — a noisier agent triggers fewer flips and accumulates less information about \(q\) (see previous chapter for discussion).

  3. Smooth drift. Category centres in feature space follow an independent random walk; labels are assigned trial-by-trial by proximity to the current drifted centres. This is the assumption-matched scenario: the Kalman filter’s generative story (prototypes drift as Gaussian random walks) is now literally true of the environment. We expect \(\log q\) to be cleanly identified and SBC to pass.

The three scenarios span the spectrum of assumption match — none → misspecified → exact — and the SBC signatures should sort in that order.

15.9.1 Scenario-Aware Simulator

Code
simulate_prototype_scenario <- function(r_value, q_value, schedule,
                                        scenario = c("static",
                                                     "contingent_shift",
                                                     "drift"),
                                        streak_target      = 8,
                                        drift_traj         = NULL,
                                        initial_mu_cat0    = c(2.5, 2.5),
                                        initial_mu_cat1    = c(2.5, 2.5),
                                        initial_sigma_diag = 10.0) {
  scenario <- match.arg(scenario)
  ntrials  <- nrow(schedule)
  obs_mat  <- as.matrix(schedule[, c("height", "position")])
  base_lab <- schedule$category_feedback
  nfeat    <- 2L

  mu0    <- initial_mu_cat0
  mu1    <- initial_mu_cat1
  sigma0 <- diag(initial_sigma_diag, nfeat)
  sigma1 <- diag(initial_sigma_diag, nfeat)
  R_mat  <- diag(r_value, nfeat)
  Q_mat  <- diag(q_value, nfeat)
  I_mat  <- diag(nfeat)

  label_flip <- FALSE
  streak     <- 0L

  prob_cat1         <- numeric(ntrials)
  sim_response      <- integer(ntrials)
  observed_feedback <- integer(ntrials)
  label_flip_trace  <- logical(ntrials)

  for (i in seq_len(ntrials)) {
    x <- as.numeric(obs_mat[i, ])

    # ── Prediction step ──────────────────────────────────────────────
    sigma0 <- sigma0 + Q_mat
    sigma1 <- sigma1 + Q_mat

    # ── Decision ─────────────────────────────────────────────────────
    cov0 <- sigma0 + R_mat
    cov1 <- sigma1 + R_mat
    lp0  <- tryCatch(mvtnorm::dmvnorm(x, mean = mu0, sigma = cov0, log = TRUE),
                     error = function(e) -Inf)
    lp1  <- tryCatch(mvtnorm::dmvnorm(x, mean = mu1, sigma = cov1, log = TRUE),
                     error = function(e) -Inf)
    m    <- max(lp0, lp1)
    p    <- exp(lp1 - (m + log(exp(lp0 - m) + exp(lp1 - m))))
    p    <- pmax(1e-9, pmin(1 - 1e-9, p))
    prob_cat1[i]    <- p
    sim_response[i] <- rbinom(1, 1, p)

    # ── Observed feedback depends on the scenario ────────────────────
    fb <- switch(scenario,
      static           = base_lab[i],
      contingent_shift = if (label_flip) 1L - base_lab[i] else base_lab[i],
      drift            = {
        d0 <- sum(abs(x - drift_traj$mu0[i, ]))
        d1 <- sum(abs(x - drift_traj$mu1[i, ]))
        as.integer(d1 < d0)
      }
    )
    observed_feedback[i] <- fb
    label_flip_trace[i]  <- label_flip

    # Update the streak/flip state for the contingent scenario
    if (scenario == "contingent_shift") {
      streak <- if (sim_response[i] == fb) streak + 1L else 0L
      if (streak >= streak_target) { label_flip <- !label_flip; streak <- 0L }
    }

    # ── Measurement update for the observed-feedback category ────────
    if (fb == 1) {
      innov  <- x - mu1
      S      <- sigma1 + R_mat
      K      <- sigma1 %*% solve(S)
      mu1    <- as.numeric(mu1 + K %*% innov)
      IK     <- I_mat - K
      sigma1 <- IK %*% sigma1 %*% t(IK) + K %*% R_mat %*% t(K)
      sigma1 <- (sigma1 + t(sigma1)) / 2
    } else {
      innov  <- x - mu0
      S      <- sigma0 + R_mat
      K      <- sigma0 %*% solve(S)
      mu0    <- as.numeric(mu0 + K %*% innov)
      IK     <- I_mat - K
      sigma0 <- IK %*% sigma0 %*% t(IK) + K %*% R_mat %*% t(K)
      sigma0 <- (sigma0 + t(sigma0)) / 2
    }
  }

  schedule |>
    mutate(
      prob_cat1         = prob_cat1,
      sim_response      = sim_response,
      observed_feedback = observed_feedback,
      label_flip        = label_flip_trace,
      correct           = as.integer(sim_response == observed_feedback)
    )
}

15.9.2 Drift Trajectory Generator

The drift scenario pre-computes a trajectory for both category centres as independent bivariate Gaussian random walks.

Code
make_drift_trajectory <- function(ntrials, drift_sigma = 0.05, seed = 1) {
  set.seed(seed)
  mu0_init <- c(1.75, 3.25)   # Kruschke cat-0 mean (height, position)
  mu1_init <- c(3.25, 1.75)   # Kruschke cat-1 mean
  d0 <- apply(matrix(rnorm(ntrials * 2, 0, drift_sigma), ncol = 2), 2, cumsum)
  d1 <- apply(matrix(rnorm(ntrials * 2, 0, drift_sigma), ncol = 2), 2, cumsum)
  list(
    mu0 = sweep(d0, 2, mu0_init, "+"),
    mu1 = sweep(d1, 2, mu1_init, "+")
  )
}

15.9.3 Scenario Agent Wrapper

Code
simulate_prototype_agent_scenario <- function(agent_id, scenario,
                                              r_value, q_value,
                                              subject_seed,
                                              streak_target = 8,
                                              drift_sigma   = 0.05) {
  schedule <- make_subject_schedule(stimulus_info, n_blocks, seed = subject_seed)
  drift_tr <- if (scenario == "drift")
    make_drift_trajectory(nrow(schedule), drift_sigma, seed = subject_seed)
  else NULL

  out <- simulate_prototype_scenario(
    r_value       = r_value,
    q_value       = q_value,
    schedule      = schedule,
    scenario      = scenario,
    streak_target = streak_target,
    drift_traj    = drift_tr
  )

  out |>
    mutate(
      agent_id     = agent_id,
      scenario     = scenario,
      r_value_true = r_value,
      q_value_true = q_value,
      log_r_true   = log(r_value),
      log_q_true   = log(q_value)
    ) |>
    group_by(agent_id) |>
    mutate(cumulative_accuracy = cumsum(correct) / row_number()) |>
    ungroup()
}

15.9.4 Visualising the Three Scenarios

We simulate 10 agents per cell of a 3 × 3 design (three scenarios crossed with three \(q\) values, at fixed \(r = 1.5\)) and plot cumulative accuracy.

Code
scenario_viz_file <- here("simdata", "ch13_scenario_viz.csv")

if (regenerate_simulations || !file.exists(scenario_viz_file)) {
  viz_grid <- expand_grid(
    scenario    = c("static", "contingent_shift", "drift"),
    q_value     = c(0.02, 0.15, 0.6),
    agent_id    = 1:10
  ) |>
    mutate(subject_seed = 30000 + row_number())

  viz_sim <- pmap_dfr(
    list(agent_id     = viz_grid$agent_id,
         scenario     = viz_grid$scenario,
         q_value      = viz_grid$q_value,
         subject_seed = viz_grid$subject_seed),
    function(agent_id, scenario, q_value, subject_seed) {
      simulate_prototype_agent_scenario(
        agent_id     = agent_id,
        scenario     = scenario,
        r_value      = 1.5,
        q_value      = q_value,
        subject_seed = subject_seed
      )
    }
  )
  write_csv(viz_sim, scenario_viz_file)
} else {
  viz_sim <- read_csv(scenario_viz_file, show_col_types = FALSE)
}

ggplot(viz_sim |>
         dplyr::mutate(scenario = factor(scenario,
                                         levels = c("static",
                                                    "contingent_shift",
                                                    "drift"))),
       aes(x = trial_within_subject, y = cumulative_accuracy,
           color = factor(q_value_true))) +
  stat_summary(fun = mean, geom = "line",
               aes(group = factor(q_value_true)), linewidth = 1.1) +
  stat_summary(fun.data = mean_se, geom = "ribbon", alpha = 0.15,
               aes(fill = factor(q_value_true), group = factor(q_value_true)),
               color = NA) +
  facet_wrap(~scenario, ncol = 3,
             labeller = as_labeller(c(
               static           = "Static Kruschke",
               contingent_shift = "Performance-contingent flips",
               drift            = "Smooth drift"))) +
  scale_color_viridis_d(option = "plasma", name = "q") +
  scale_fill_viridis_d(option  = "plasma", name = "q") +
  geom_hline(yintercept = 0.5, linetype = "dashed", color = "grey50") +
  labs(
    title    = "Prototype Learning Curves Across Scenarios",
    subtitle = "Mean ± SE across 10 agents per cell; r = 1.5",
    x = "Trial", y = "Cumulative accuracy"
  ) +
  theme(legend.position = "bottom")

Reading the panels. The static panel (leftmost) shows a clear ordering by \(q\) at asymptote: low \(q\) (\(= 0.02\), dark blue) eventually achieves the highest cumulative accuracy (\(\approx 0.72\)), but with a notable early dip around trials 5–10 before recovering. This dip occurs because a very small process-noise budget forces the Kalman covariance to shrink aggressively — the model becomes overconfident in an initial (wrong) prototype — and it takes several corrective trials to escape. Medium \(q\) (\(= 0.15\), pink) climbs smoothly to a similar asymptote. High \(q\) (\(= 0.6\), yellow) plateaus around 0.60–0.62: perpetually high process noise keeps the covariance inflated and the Kalman gain elevated, so the model keeps updating but also keeps being pulled off-centre by noise, confirming that assuming drift when none exists is costly.
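The overconfidence dip and the perpetually elevated gain can both be seen in a one-dimensional sketch (hypothetical, not the chapter's simulator): iterate the scalar Kalman prediction/update recursion and watch where the gain settles for the three \(q\) values used above, at fixed \(r = 1.5\).

```r
# Scalar Kalman gain trajectory for a random-walk state model (hypothetical sketch).
kalman_gain_path <- function(q, r = 1.5, p0 = 10, ntrials = 100) {
  p <- p0
  gains <- numeric(ntrials)
  for (t in seq_len(ntrials)) {
    p_pred   <- p + q                    # prediction: process noise inflates variance
    gains[t] <- p_pred / (p_pred + r)    # scalar Kalman gain
    p        <- (1 - gains[t]) * p_pred  # measurement update shrinks variance
  }
  gains
}

# Late-trial gain is ordered by q: small q -> small gain (an overconfident,
# nearly frozen prototype), large q -> persistently large gain (a prototype
# that keeps being dragged around by noise).
sapply(c(0.02, 0.15, 0.6), function(q) tail(kalman_gain_path(q), 1))
```

The early dip at \(q = 0.02\) corresponds to the gain collapsing within the first handful of trials, before the prototype estimate is trustworthy.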

The performance-contingent flips panel (middle) shows that medium \(q\) achieves the highest asymptote, with low \(q\) close behind, and high \(q\) underperforming by about 5–7 percentage points. After each label flip the model must re-learn the reassigned prototypes; a moderate \(q\) allows the covariance to re-inflate enough to absorb the shock, while high \(q\) creates too much perpetual uncertainty to track the post-flip regime efficiently. Note that cumulative accuracy smooths out the within-flip saw-teeth; the underlying trial-level accuracy does dip at each flip for all conditions.

The smooth drift panel (rightmost) yields a counter-intuitive ordering: low \(q\) (\(= 0.02\)) reaches the highest asymptote (\(\approx 0.72\)), while high \(q\) (\(= 0.6\)) is the worst (\(\approx 0.66\)), and medium \(q\) sits between. Under continuous Gaussian drift, the Kalman gain never fully collapses even at \(q = 0.02\) — the drift itself keeps re-inflating the covariance — so the expected advantage of higher \(q\) does not materialise with only 65 trials. All three curves show wider early ribbons because drift-induced label assignments are noisy at the very start of each block. The first few trials also drive the spike visible around trial 2–3, where cumulative accuracy briefly exceeds 0.8 for some conditions due to lucky initial assignments before settling.

15.9.5 Fitting One Agent per Scenario (Original Parameterisation)

We pick one representative agent per scenario with \(r = 1.5\) and \(q = 0.15\), and fit the Kalman prototype model in its original (log_r, log_q) parameterisation to each.

Code
scenario_fit_file <- here("simmodels", "ch13_prototype_scenario_fits.rds")

scenario_single_fit_data <- list(
  static           = simulate_prototype_agent_scenario(
                        1, "static",           r_value = 1.5, q_value = 0.15,
                        subject_seed = 40001),
  contingent_shift = simulate_prototype_agent_scenario(
                        1, "contingent_shift", r_value = 1.5, q_value = 0.15,
                        subject_seed = 40002),
  drift            = simulate_prototype_agent_scenario(
                        1, "drift",            r_value = 1.5, q_value = 0.15,
                        subject_seed = 40003)
)

stan_data_from_scenario_agent <- function(agent_df) {
  list(
    ntrials            = nrow(agent_df),
    nfeatures          = 2L,
    cat_one            = as.integer(agent_df$observed_feedback),
    y                  = as.integer(agent_df$sim_response),
    obs                = as.matrix(agent_df[, c("height", "position")]),
    initial_mu_cat0    = c(2.5, 2.5),
    initial_mu_cat1    = c(2.5, 2.5),
    initial_sigma_diag = 10.0,
    prior_logr_mean    = LOG_R_PRIOR_MEAN,
    prior_logr_sd      = LOG_R_PRIOR_SD,
    prior_logq_mean    = LOG_Q_PRIOR_MEAN,
    prior_logq_sd      = LOG_Q_PRIOR_SD
  )
}

if (regenerate_simulations || !file.exists(scenario_fit_file)) {
  scenario_fits <- purrr::imap(scenario_single_fit_data, function(agent_df, scen) {
    dat <- stan_data_from_scenario_agent(agent_df)
    pf  <- tryCatch(
      mod_prototype_single$pathfinder(data = dat, num_paths = 4,
                                      refresh = 0, seed = 500),
      error = function(e) NULL
    )
    fit <- mod_prototype_single$sample(
      data            = dat,
      init            = if (!is.null(pf)) pf else 0.5,
      seed            = 500,
      chains          = 4,
      parallel_chains = min(4, availableCores()),
      iter_warmup     = 1000,
      iter_sampling   = 1500,
      refresh         = 0,
      adapt_delta     = 0.9
    )
    fit$save_object(here("simmodels",
                         paste0("ch13_prototype_scenario_", scen, ".rds")))
    fit
  })
  saveRDS(lapply(scenario_fits, function(f) f$output_files()),
          scenario_fit_file)
  cat("Scenario fits computed and saved.\n")
} else {
  scenario_fits <- setNames(
    lapply(names(scenario_single_fit_data), function(scen) {
      readRDS(here("simmodels",
                   paste0("ch13_prototype_scenario_", scen, ".rds")))
    }),
    names(scenario_single_fit_data)
  )
  cat("Loaded existing scenario fits.\n")
}
Loaded existing scenario fits.

15.9.6 MCMC Diagnostic Battery Across Scenarios

Code
scenario_diag_tbls <- purrr::imap_dfr(scenario_fits, function(fit, scen) {
  tbl <- diagnostic_summary_table(fit, params = c("log_r", "log_q",
                                                  "r_value", "q_value"))
  tbl |> mutate(scenario = scen)
})
print(scenario_diag_tbls |>
        dplyr::select(scenario, metric, value, threshold, pass))
# A tibble: 18 × 5
   scenario         metric                           value threshold pass 
   <chr>            <chr>                            <dbl> <chr>     <lgl>
 1 static           Divergences (zero tolerance)    0      == 0      TRUE 
 2 static           Max rank-normalised R-hat       1      < 1.01    TRUE 
 3 static           Min bulk ESS                 2578      > 400     TRUE 
 4 static           Min tail ESS                 2361      > 400     TRUE 
 5 static           Min E-BFMI                      0.919  > 0.2     TRUE 
 6 static           Max MCSE / posterior SD         0.0197 < 0.05    TRUE 
 7 contingent_shift Divergences (zero tolerance)    0      == 0      TRUE 
 8 contingent_shift Max rank-normalised R-hat       1      < 1.01    TRUE 
 9 contingent_shift Min bulk ESS                 2840      > 400     TRUE 
10 contingent_shift Min tail ESS                 2272      > 400     TRUE 
11 contingent_shift Min E-BFMI                      0.94   > 0.2     TRUE 
12 contingent_shift Max MCSE / posterior SD         0.0188 < 0.05    TRUE 
13 drift            Divergences (zero tolerance)    0      == 0      TRUE 
14 drift            Max rank-normalised R-hat       1      < 1.01    TRUE 
15 drift            Min bulk ESS                 2892      > 400     TRUE 
16 drift            Min tail ESS                 2847      > 400     TRUE 
17 drift            Min E-BFMI                      0.949  > 0.2     TRUE 
18 drift            Max MCSE / posterior SD         0.0186 < 0.05    TRUE 
Code
walk(names(scenario_fits), function(scen) {
  cat("=== Pair plot:", scen, "===\n")
  print(bayesplot::mcmc_pairs(
    scenario_fits[[scen]]$draws(c("log_r", "log_q")),
    diag_fun     = "dens",
    off_diag_fun = "hex",
    np           = nuts_params(scenario_fits[[scen]])
  ))
})
=== Pair plot: static ===

=== Pair plot: contingent_shift ===

=== Pair plot: drift ===

MCMC diagnostics. The table shows clean sampling across all three scenarios: zero divergences, maximum rank-normalised R-hat of 1.00 (to reported precision), bulk ESS in the 2578–2892 range, tail ESS in the 2272–2847 range, E-BFMI comfortably above 0.2, and MCSE/posterior SD well below 0.05. HMC is traversing the posterior geometry efficiently in all conditions.

What the pair plots reveal. The three scenarios produce qualitatively different posterior geometries that directly reflect how much information each environment provides about each parameter.

In the static scenario, log_r is reasonably well-identified — a compact, bell-shaped marginal centred around \(\log r \approx 0.7{-}0.8\). The log_q marginal is much broader and fat-tailed, spanning roughly \(-5\) to \(0\), reflecting the familiar weak identifiability of process noise when nothing drifts. The joint distribution shows a moderate negative correlation: higher observation noise paired with lower process noise produces similar behaviour to lower observation noise paired with higher process noise. Importantly, this is a diffuse cloud rather than a knife-edge ridge — the posterior has mass across a wide region, not along a thin diagonal line, which is why the sampler has no trouble and ESS remains high.
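The negative r–q correlation can be made concrete with a scalar steady-state calculation (a hypothetical sketch, not drawn from the chapter's code). For a one-dimensional random-walk Kalman filter, the steady-state predicted variance solves the Riccati fixed point \(P_{\text{pred}} = (q + \sqrt{q^2 + 4rq})/2\), and the predictive variance entering the decision rule is \(S = P_{\text{pred}} + r\):

```r
# Steady-state predictive variance of a scalar random-walk Kalman filter
# (hypothetical sketch): S = P_pred + r, with P_pred the Riccati fixed point.
predictive_var <- function(r, q) (q + sqrt(q^2 + 4 * r * q)) / 2 + r

# Two quite different (r, q) pairs yield nearly identical predictive variance,
# hence near-identical choice behaviour -- the source of the negative
# posterior correlation between log_r and log_q. (The q = 0.303 value was
# chosen by hand to match the first pair.)
c(high_r_low_q = predictive_var(r = 2.0, q = 0.05),
  low_r_high_q = predictive_var(r = 1.5, q = 0.303))
```

Because many points along this trade-off curve produce similar likelihoods, the joint posterior is a diffuse negatively-sloped cloud rather than a point mass.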

In the contingent-shift scenario, both marginals are noticeably more concentrated. log_r centres around \(\log r \approx 0.1{-}0.2\) and log_q is a tighter, near-Gaussian distribution centred around \(\log q \approx -2.5\). The joint cloud is more compact and less elongated than static, consistent with the SBC finding that log_q is well-calibrated in this condition. Label-flip dynamics provide specific leverage on \(q\): post-flip recovery speed depends on how quickly the covariance can re-inflate, and that speed scales with \(q\).

The smooth drift scenario is the most surprising. Despite being the assumption-matched condition, log_q has the widest and heaviest-tailed marginal of the three — extending from roughly \(-5\) to \(+1.5\) and with a left tail reaching below \(-5\) in some draws. log_r is well-identified (compact, centred around \(\log r \approx 0.3{-}0.4\)), but the joint cloud is visibly more diffuse than static, especially in the log_q direction. The explanation is a floor effect: once \(q\) exceeds some minimum threshold, the filter tracks continuous drift adequately regardless of the exact value — further increases in \(q\) provide diminishing marginal improvement in tracking accuracy. A single agent’s 65-trial trajectory cannot distinguish \(q = 0.05\) from \(q = 0.5\) when both keep the Kalman gain sufficiently elevated. This is not a sampler failure (ESS = 2892) but a genuine likelihood plateau.

15.9.7 Prior → Posterior Updates Across Scenarios

Code
scenario_pp_df <- purrr::imap_dfr(scenario_fits, function(fit, scen) {
  d <- as_draws_df(fit$draws(c("log_r", "log_q")))
  bind_rows(
    tibble(scenario = scen, parameter = "log_r", value = d$log_r, source = "posterior"),
    tibble(scenario = scen, parameter = "log_q", value = d$log_q, source = "posterior")
  )
})

set.seed(21)
n_prior <- 4000
prior_ref <- bind_rows(
  tibble(parameter = "log_r", value = rnorm(n_prior, LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD),
         source = "prior"),
  tibble(parameter = "log_q", value = rnorm(n_prior, LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD),
         source = "prior")
) |>
  tidyr::crossing(scenario = c("static", "contingent_shift", "drift"))

truth_lines <- tibble(
  parameter = c("log_r", "log_q"),
  value     = c(log(1.5), log(0.15))
)

ggplot(bind_rows(prior_ref, scenario_pp_df) |>
         dplyr::mutate(scenario = factor(scenario,
                                          levels = c("static",
                                                     "contingent_shift",
                                                     "drift"))),
       aes(x = value, fill = source, color = source)) +
  geom_density(alpha = 0.4, linewidth = 0.5) +
  geom_vline(data = truth_lines, aes(xintercept = value),
             color = "red", linetype = "dashed") +
  facet_grid(scenario ~ parameter, scales = "free") +
  scale_fill_manual(values  = c(prior = "grey70", posterior = "#009E73")) +
  scale_color_manual(values = c(prior = "grey50", posterior = "#006D4E")) +
  labs(
    title    = "Prior → Posterior Update: Prototype Model Across Scenarios",
    subtitle = "Red dashed line = true generating value used for the single-agent fit",
    x = NULL, y = "Density"
  )

Reading the prior–posterior updates. The two parameters tell very different stories.

log_r (right column) shows equally strong and consistent updating in all three scenarios. The posterior (green) is sharply peaked close to the true generating value (\(\log(1.5) \approx 0.4\), red dashed line) in every row, with the prior (grey) clearly wider. The peak posterior density exceeds 1.0 in all three conditions. This confirms that observation noise is well-identified regardless of whether categories are static, intermittently flipped, or continuously drifting — the sharpness of within-category decisions provides leverage on \(r\) in any environment.

log_q (left column) is more nuanced. All three scenarios show genuine, if moderate, updating: the posterior is narrower than the prior and shifted leftward relative to it. However, in all three cases the posterior mode sits to the left of the true value (\(\log(0.15) \approx -2\), red dashed line) — roughly at \(-2.5\) to \(-3\) — meaning the model systematically underestimates the true process-noise level for this particular agent. Among the three, the contingent-shift scenario produces the most concentrated log_q posterior (the sharpest peak, with peak density near 0.5), while static and drift are comparable and somewhat flatter. The expected gradient — drift best, static worst — does not materialise: the assumption-matched drift scenario does not unlock tighter identification of \(q\) than the misspecified contingent-shift scenario, at least for this single agent’s trajectory. The prior is still visibly contributing to all three posteriors, consistent with the power-scaling analysis and the upward-bias pattern confirmed by the SBC below.

15.9.8 One-Shot Parameter Recovery Across Scenarios

We sweep a 3 × 3 grid of \((r, q)\) truths with 5 iterations per cell per scenario (135 fits) and plot recovery on the log scale.

Code
scenario_recovery_file <- here("simdata", "ch13_scenario_recovery_results.rds")

run_proto_scenario_recovery_iter <- function(scenario, r_true, q_true, iteration) {
  subject_seed <- iteration * 10000 +
                   round(r_true * 100) +
                   round(q_true * 1000) +
                   switch(scenario, static = 1, contingent_shift = 2, drift = 3)

  agent_df <- simulate_prototype_agent_scenario(
    agent_id     = iteration,
    scenario     = scenario,
    r_value      = r_true,
    q_value      = q_true,
    subject_seed = subject_seed
  )
  dat <- stan_data_from_scenario_agent(agent_df)
  pf  <- tryCatch(
    mod_prototype_single$pathfinder(data = dat, num_paths = 4,
                                    refresh = 0, seed = subject_seed),
    error = function(e) NULL
  )
  fit <- tryCatch(
    mod_prototype_single$sample(
      data            = dat,
      init            = if (!is.null(pf)) pf else 0.5,
      seed            = subject_seed,
      chains          = 2,
      parallel_chains = 2,
      iter_warmup     = 800,
      iter_sampling   = 1000,
      refresh         = 0,
      adapt_delta     = 0.9
    ),
    error = function(e) { message("Fit error: ", e$message); NULL }
  )
  if (is.null(fit)) return(NULL)

  s <- fit$summary(variables = c("log_r", "log_q"))
  diag <- fit$diagnostic_summary(quiet = TRUE)
  tibble(
    scenario       = scenario,
    r_true         = r_true,
    q_true         = q_true,
    iteration      = iteration,
    log_r_mean     = s$mean[s$variable == "log_r"],
    log_q_mean     = s$mean[s$variable == "log_q"],
    log_r_q5       = s$q5[s$variable == "log_r"],
    log_r_q95      = s$q95[s$variable == "log_r"],
    log_q_q5       = s$q5[s$variable == "log_q"],
    log_q_q95      = s$q95[s$variable == "log_q"],
    had_divergence = any(diag$num_divergent > 0)
  )
}

if (regenerate_simulations || !file.exists(scenario_recovery_file)) {
  recovery_plan <- expand_grid(
    scenario  = c("static", "contingent_shift", "drift"),
    r_true    = c(0.5, 1.5, 3.0),
    q_true    = c(0.02, 0.15, 0.6),
    iteration = 1:5
  )

  plan(multisession, workers = max(1, availableCores() - 1))
  scenario_recovery <- future_pmap_dfr(
    list(scenario  = recovery_plan$scenario,
         r_true    = recovery_plan$r_true,
         q_true    = recovery_plan$q_true,
         iteration = recovery_plan$iteration),
    run_proto_scenario_recovery_iter,
    .options = furrr_options(seed = TRUE, scheduling = 1)
  )
  plan(sequential)
  saveRDS(scenario_recovery, scenario_recovery_file)
  cat("Scenario recovery sweep complete.\n")
} else {
  scenario_recovery <- readRDS(scenario_recovery_file)
  cat("Loaded existing scenario recovery sweep.\n")
}
Loaded existing scenario recovery sweep.
Code
scenario_recovery_long <- scenario_recovery |>
  dplyr::select(scenario, r_true, q_true, iteration, had_divergence,
                log_r_mean, log_q_mean) |>
  pivot_longer(cols = c(log_r_mean, log_q_mean),
               names_to = "variable", values_to = "estimated") |>
  mutate(
    parameter = if_else(variable == "log_r_mean", "log_r", "log_q"),
    truth     = if_else(parameter == "log_r", log(r_true), log(q_true))
  )

ggplot(scenario_recovery_long |>
         dplyr::mutate(scenario = factor(scenario,
                                          levels = c("static",
                                                     "contingent_shift",
                                                     "drift"))),
       aes(x = truth, y = estimated, color = had_divergence)) +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed", color = "grey50") +
  geom_point(alpha = 0.7, size = 2) +
  facet_grid(parameter ~ scenario, scales = "free") +
  scale_color_manual(values = c("FALSE" = "#0072B2", "TRUE" = "#D55E00"),
                     name = "Had divergence") +
  labs(
    title    = "One-Shot Parameter Recovery for the Kalman Prototype Model",
    subtitle = "Posterior mean vs true generating value; 5 iterations per (r, q) cell, 3 scenarios",
    x = "True value (log scale)", y = "Estimated value (posterior mean)"
  ) +
  theme(legend.position = "bottom")

Interpretation. All 135 fits completed without divergences (every point is the same blue), confirming clean sampling geometry across the full \((r, q)\) grid.

log_r (bottom row) is recovered reasonably well in all three scenarios: points cluster near the dashed diagonal, with scatter attributable to single-agent noise rather than systematic bias. All three panels look similar — within-category decision sharpness informs \(r\) regardless of whether the environment is static, intermittently flipped, or continuously drifting.

log_q (top row) tells a different and more sobering story. Across all three scenarios, recovery is good only in the middle of the prior range (true \(\log q \approx -2\)), where points cluster near the diagonal. At both extremes the recovery compresses toward that same middle region. For very low true \(\log q\) (around \(-4\)), the posterior mean is pulled upward (toward \(-2\) to \(-3\)) — the model cannot confidently identify very small process noise. For high true \(\log q\) (above \(-1\)), the posterior mean is pulled downward toward \(-2\) — the model saturates and cannot distinguish moderate from large drift rates. The result is that the posterior mean for \(\log q\) is effectively anchored near the prior centre across a wide range of true values.

Contrary to what one might expect, the three scenarios do not sort cleanly by how well \(q\) is recovered. All three show the same compression pattern; static is perhaps marginally more compressed at high true values, but the difference is small. The contingent-shift scenario offers no clear advantage over drift for \(\log q\) recovery at this sample size — a result consistent with the prior–posterior overlays and SBC patterns discussed above. The takeaway is not that recovery fails, but that a single agent’s data can only identify \(q\) within a fairly narrow effective range, regardless of the environment’s dynamics.
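This anchoring toward the prior centre is the familiar conjugate shrinkage effect: when the likelihood is nearly flat in a parameter, the posterior mean recovers only a fraction of the distance from the prior mean to the data-preferred value. A normal–normal sketch with hypothetical numbers (the effective likelihood sd of 3 is illustrative, not estimated from the fits):

```r
# Conjugate normal-normal shrinkage (hypothetical numbers): the posterior mean
# is a precision-weighted average of the prior mean and the likelihood peak.
posterior_mean <- function(prior_mean, prior_sd, mle, like_sd) {
  w <- (1 / prior_sd^2) / (1 / prior_sd^2 + 1 / like_sd^2)  # prior precision weight
  w * prior_mean + (1 - w) * mle
}

# A flat likelihood around log q (effective sd ~ 3) recovers only a tenth of
# the distance from the prior mean (-2) to the likelihood peak (-4).
posterior_mean(prior_mean = -2, prior_sd = 1, mle = -4, like_sd = 3)  # -2.2
```

With 65 trials per agent the effective likelihood width for \(\log q\) is large at both extremes of the grid, which is exactly the compression pattern the recovery plot shows.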

15.9.9 Simulation-Based Calibration Checks Across Scenarios

SBC is the formal calibration test. Each scenario gets its own generator drawing \((\log r, \log q)\) from their priors.

Code
make_proto_sbc_generator <- function(scenario) {
  SBC_generator_function(
    function() {
      log_r_true <- rnorm(1, LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD)
      log_q_true <- rnorm(1, LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD)
      r_true     <- exp(log_r_true)
      q_true     <- exp(log_q_true)

      subject_seed <- sample.int(1e8, 1)
      agent_df <- simulate_prototype_agent_scenario(
        agent_id     = 1,
        scenario     = scenario,
        r_value      = r_true,
        q_value      = q_true,
        subject_seed = subject_seed
      )

      list(
        variables = list(log_r = log_r_true, log_q = log_q_true),
        generated = stan_data_from_scenario_agent(agent_df)
      )
    }
  )
}

run_proto_scenario_sbc <- function(scenario, n_sims = 200) {
  cache_file <- here("simdata", paste0("ch13_proto_sbc_", scenario, ".rds"))
  if (!regenerate_simulations && file.exists(cache_file)) {
    cat("Loaded cached prototype SBC for scenario:", scenario, "\n")
    return(readRDS(cache_file))
  }
  gen <- make_proto_sbc_generator(scenario)
  backend <- SBC_backend_cmdstan_sample(
    mod_prototype_single,
    iter_warmup   = 800,
    iter_sampling = 1000,
    chains        = 2,
    adapt_delta   = 0.9,
    refresh       = 0
  )
  datasets <- generate_datasets(gen, n_sims = n_sims)
  res <- compute_SBC(
    datasets,
    backend,
    cache_mode     = "results",
    cache_location = here("simdata", paste0("ch13_proto_sbc_cache_", scenario)),
    keep_fits      = FALSE
  )
  saveRDS(res, cache_file)
  res
}

proto_sbc_static <- run_proto_scenario_sbc("static",           n_sims = 200)
Loaded cached prototype SBC for scenario: static 
Code
proto_sbc_shift  <- run_proto_scenario_sbc("contingent_shift", n_sims = 200)
Loaded cached prototype SBC for scenario: contingent_shift 
Code
proto_sbc_drift  <- run_proto_scenario_sbc("drift",            n_sims = 200)
Loaded cached prototype SBC for scenario: drift 
Code
print(plot_ecdf_diff(proto_sbc_static) + ggtitle("SBC: static"))

Code
print(plot_ecdf_diff(proto_sbc_shift)  + ggtitle("SBC: performance-contingent shifts"))

Code
print(plot_ecdf_diff(proto_sbc_drift)  + ggtitle("SBC: smooth drift"))

Code
print(plot_rank_hist(proto_sbc_static) + ggtitle("Rank histograms — static"))

Code
print(plot_rank_hist(proto_sbc_shift)  + ggtitle("Rank histograms — contingent shifts"))

Code
print(plot_rank_hist(proto_sbc_drift)  + ggtitle("Rank histograms — drift"))

Code
# Summarise divergences and other sampler issues across the three scenarios
extract_diag <- function(sbc_obj, scenario_label) {
  backend_diags <- sbc_obj$backend_diagnostics
  default_diags <- sbc_obj$default_diagnostics |>
    dplyr::select(sim_id, max_rhat, min_ess_to_rank)
  true_params <- sbc_obj$stats |>
    dplyr::select(sim_id, variable, simulated_value) |>
    tidyr::pivot_wider(names_from = variable, values_from = simulated_value)

  backend_diags |>
    dplyr::left_join(default_diags, by = "sim_id") |>
    dplyr::left_join(true_params,   by = "sim_id") |>
    dplyr::mutate(
      scenario  = scenario_label,
      has_issue = n_divergent > 0 | n_max_treedepth > 0 | max_rhat > 1.01
    )
}

diag_all <- bind_rows(
  extract_diag(proto_sbc_static, "static"),
  extract_diag(proto_sbc_shift,  "contingent_shift"),
  extract_diag(proto_sbc_drift,  "drift")
)

# Summary table: divergence rates per scenario
diag_all |>
  dplyr::group_by(scenario) |>
  dplyr::summarise(
    n_sims         = dplyr::n(),
    n_divergent    = sum(n_divergent > 0),
    pct_divergent  = round(100 * mean(n_divergent > 0), 1),
    n_treedepth    = sum(n_max_treedepth > 0),
    n_high_rhat    = sum(max_rhat > 1.01, na.rm = TRUE),
    mean_div_count = round(mean(n_divergent), 2),
    .groups = "drop"
  ) |>
  print()
# A tibble: 3 × 7
  scenario         n_sims n_divergent pct_divergent n_treedepth n_high_rhat
  <chr>             <int>       <int>         <dbl>       <int>       <int>
1 contingent_shift    200           0             0           0           1
2 drift               200           0             0           0           2
3 static              200           0             0           0           4
# ℹ 1 more variable: mean_div_count <dbl>
Code
# Where in parameter space do divergences cluster?
ggplot(diag_all |>
         dplyr::mutate(scenario = factor(scenario,
                                          levels = c("static",
                                                     "contingent_shift",
                                                     "drift"))),
       aes(x = log_r, y = log_q, color = has_issue)) +
  geom_point(alpha = 0.6, size = 1.8) +
  scale_color_manual(values = c("FALSE" = "#0072B2", "TRUE" = "#D55E00"),
                     labels = c("FALSE" = "OK", "TRUE" = "Issue")) +
  facet_wrap(~scenario) +
  labs(
    title    = "Diagnostic Issues in Parameter Space — Scenario SBC",
    subtitle = "Orange: divergence, max treedepth, or Rhat > 1.01",
    x = "True log(r)", y = "True log(q)", color = NULL
  ) +
  theme(legend.position = "bottom")

Code
combine_proto_sbc_stats <- function(sbc_obj, scenario) {
  sbc_obj$stats |> mutate(scenario = scenario)
}

proto_sbc_stats <- bind_rows(
  combine_proto_sbc_stats(proto_sbc_static, "static"),
  combine_proto_sbc_stats(proto_sbc_shift,  "contingent_shift"),
  combine_proto_sbc_stats(proto_sbc_drift,  "drift")
) |>
  dplyr::filter(variable %in% c("log_r", "log_q"))

proto_sbc_stats_ordered <- proto_sbc_stats |>
  dplyr::mutate(scenario = factor(scenario,
                                   levels = c("static",
                                              "contingent_shift",
                                              "drift")))

ggplot(proto_sbc_stats_ordered, aes(x = simulated_value, y = mean)) +
  geom_point(alpha = 0.35, color = "#0072B2") +
  geom_abline(intercept = 0, slope = 1, color = "red", linetype = "dashed") +
  facet_grid(variable ~ scenario, scales = "free") +
  labs(
    title    = "SBC-Level Parameter Recovery: Kalman Prototype",
    subtitle = "Posterior mean vs true generating value, pooled across SBC simulations",
    x = "True value", y = "Posterior mean"
  )

Code
# Natural-scale recovery (r and q, back-transformed from log scale)
proto_sbc_stats_natural <- proto_sbc_stats_ordered |>
  dplyr::mutate(
    variable_natural  = dplyr::case_when(
      variable == "log_r" ~ "r (observation noise)",
      variable == "log_q" ~ "q (process noise)"
    ),
    sim_natural  = exp(simulated_value),
    mean_natural = exp(mean)
  )

ggplot(proto_sbc_stats_natural, aes(x = sim_natural, y = mean_natural)) +
  geom_point(alpha = 0.35, color = "#0072B2") +
  geom_abline(intercept = 0, slope = 1, color = "red", linetype = "dashed") +
  facet_grid(variable_natural ~ scenario, scales = "free") +
  labs(
    title    = "SBC-Level Parameter Recovery: Natural Scale",
    subtitle = "Posterior mean vs true generating value (r = exp(log_r), q = exp(log_q))",
    x = "True value", y = "Posterior mean"
  )

Code
proto_sbc_stats |>
  group_by(scenario, variable) |>
  summarise(
    sbc_mean_bias = mean(mean - simulated_value, na.rm = TRUE),
    sbc_rmse      = sqrt(mean((mean - simulated_value)^2, na.rm = TRUE)),
    sbc_corr      = cor(mean, simulated_value,
                        use = "complete.obs", method = "pearson"),
    n_sims        = dplyr::n(),
    .groups       = "drop"
  ) |>
  print(n = Inf)
# A tibble: 6 × 6
  scenario         variable sbc_mean_bias sbc_rmse sbc_corr n_sims
  <chr>            <chr>            <dbl>    <dbl>    <dbl>  <int>
1 contingent_shift log_q           0.0410    0.889    0.522    200
2 contingent_shift log_r          -0.0372    0.426    0.910    200
3 drift            log_q           0.0649    0.769    0.641    200
4 drift            log_r          -0.0276    0.453    0.893    200
5 static           log_q           0.118     0.828    0.529    200
6 static           log_r          -0.0262    0.411    0.917    200

Reading the three SBC results. All three scenarios pass the formal SBC test — every ECDF-difference curve remains within the simultaneous bands — but the two parameters show opposite patterns, and neither sorts cleanly along the expected static → shift → drift gradient.

log_q (left panel in each plot) shows a broad positive hump in the static and smooth drift scenarios, and near-zero flat fluctuations in the performance-contingent shifts scenario. A positive hump (ECDF above the theoretical CDF in the lower-to-middle quantile range) means ranks are systematically low: the posterior for log_q tends to sit above the true generating value — a systematic upward bias. In the static case this bias is largest (hump peaks around \(+0.08\) over \(u \approx 0.1\)–\(0.6\)): when categories never drift, the filter finds no evidence against large process noise, so the posterior for \(q\) is pulled toward higher values than those that actually generated the data. The same upward pull appears in the smooth drift scenario, though slightly attenuated. In the contingent-shift condition the bias disappears — label-flip dynamics provide specific leverage that anchors the posterior near the truth.

log_r (right panel in each plot) shows the opposite signature: a pronounced negative central dip in all three scenarios, deepest in the contingent-shift condition (reaching \(\approx{-}0.07\) around \(u = 0.5\)–\(0.7\)). A negative dip means too few ranks land in the middle quantile range — the true value tends to fall in the tails of the posterior rather than near the centre. This is the mark of an overconfident (too narrow) posterior: the model sometimes picks a position on the \(r\)–\(q\) likelihood ridge and concentrates its mass there, but the position it picks is not always the truth. The contingent-shift scenario is the worst case for log_r precisely because the label flips tightly constrain \(q\), which pins down one end of the ridge and forces the entire posterior mass onto a narrower, potentially displaced location for \(\log r\).
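These two ECDF signatures are easy to reproduce in a toy simulation (illustrative only — these are not the chapter's SBC objects): a posterior shifted above the truth piles ranks low, while a too-narrow posterior piles ranks into both tails.

```r
# Toy SBC ranks: rank = number of posterior draws below the true value.
set.seed(1)
n_sims  <- 2000
n_draws <- 100

# Calibrated: truth and posterior draws come from the same distribution,
# so ranks are uniform on 0..n_draws.
rank_calibrated <- replicate(n_sims, sum(rnorm(n_draws) < rnorm(1)))

# Upward-biased posterior (mean shifted above the truth): ranks pile up
# low -> positive ECDF hump, the log_q signature in the static scenario.
rank_biased <- replicate(n_sims, sum(rnorm(n_draws, mean = 0.5) < rnorm(1)))

# Overconfident posterior (too narrow): truth lands in the tails ->
# negative central ECDF dip, the log_r signature.
rank_narrow <- replicate(n_sims, sum(rnorm(n_draws, sd = 0.5) < rnorm(1)))

mean(rank_biased < n_draws / 2)       # well above 0.5: systematically low ranks
mean(rank_narrow %in% c(0, n_draws))  # inflated tail mass vs calibrated
```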

The divergence investigation (table and scatter above) shows the proportion and parameter-space location of any sampling issues: if issues cluster at extreme log_q values or low log_r, that would flag a geometry problem rather than a conceptual one. Clean sampling across both these diagnostics means the patterns are genuine posterior properties, not sampler artefacts.

The parameter recovery panels (both log-scale and natural-scale) confirm good overall alignment with the diagonal across all three scenarios and both parameters. On the natural scale, the upward bias in \(q\) is amplified by the exponential transform — a modest positive bias in \(\log q\) translates into a larger multiplicative error in \(q\) itself. The formal summary is that all three scenarios achieve calibration (correct uncertainty), but log_q carries a directional bias toward overestimation and log_r occasionally overconcentrates; both effects are within the simultaneous bands and shrink with more data or stronger experimental designs.
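The amplification is simple arithmetic: an additive bias on the log scale becomes a multiplicative factor after back-transformation. Using the static-scenario log_q bias from the summary table above:

```r
# exp(log_q_true + bias) = exp(bias) * q_true, so a log-scale bias of
# 0.118 (static-scenario sbc_mean_bias for log_q, table above) implies
# q is overestimated by a multiplicative factor of exp(0.118).
log_q_bias <- 0.118
exp(log_q_bias)  # ~1.125: q overestimated by roughly 12% on average
```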


15.10 Discussion: When Does \(q\) Become Identifiable?

The three SBC results give the cleanest summary of the identifiability argument.

  • Static. log_q shows a systematic upward bias (positive ECDF hump): when categories never drift, the filter finds no evidence against large process noise and overestimates \(q\). log_r shows mild overconfidence (negative central dip). Both stay within the simultaneous bands — the model is calibrated but not very informative about \(q\). Parameter recovery for \(q\) compresses toward the prior centre except at middling values.
  • Performance-contingent shifts. The log_q bias disappears: post-flip recovery speed provides direct leverage on process noise, anchoring the posterior near the truth. log_r inherits a deeper overconfidence dip — constraining \(q\) via flip dynamics pins one end of the \(r\)–\(q\) likelihood ridge and can displace the posterior for \(r\). A caveat applies: the filter’s smooth-drift assumption is technically violated by the abrupt jumps, so posterior predictive checks may reveal systematic underperformance right after each flip.
  • Smooth drift. log_q shows the same upward bias as in the static scenario — even in this assumption-matched setting, a single agent’s 65-trial trajectory does not always pin down the exact drift rate. The log_r overconfidence dip is also present. The posterior for log_q is the widest and heaviest-tailed of the three scenarios (floor effect: any \(q\) above a minimum threshold tracks drift adequately, making very different \(q\) values hard to distinguish). All checks pass.

The ratio \(\rho = q/r\) is a natural dimensionless quantity that directly governs the Kalman filter’s steady-state gain — it is the signal-to-noise ratio of drift relative to observation variability. When this ratio is far from 1 in either direction, the gain saturates and further changes to \(q\) or \(r\) have diminishing marginal effect on the filter’s behaviour, which is why the posterior for \(q\) becomes diffuse at extremes regardless of the scenario. Practical implication: if the research question centres on \(q\), design tasks with moderate \(r\) and enough trials that the transient dynamics (where the gain is changing) are well-sampled.
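The gain-saturation claim can be checked directly for the scalar random-walk filter (a sketch with a hypothetical helper `steady_state_gain`; the chapter's filter is multivariate, but its diagonal noise matrices make the scalar case representative):

```r
# Steady-state Kalman gain for the scalar random-walk model.
# Prediction: M = P + q; gain: K = M / (M + r); update: P = (1 - K) * M.
# At steady state, the predicted variance M solves M^2 - q*M - q*r = 0
# (positive root).
steady_state_gain <- function(q, r) {
  M <- (q + sqrt(q^2 + 4 * q * r)) / 2
  M / (M + r)
}

rho   <- 10^seq(-3, 3)                  # q / r over six orders of magnitude
gains <- steady_state_gain(q = rho, r = 1)
round(gains, 3)
# The gain flattens toward 0 and 1 at the extremes: further changes in q
# barely move the filter's behaviour there, which is why the posterior
# for q becomes diffuse when rho is far from 1.
```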

Paralleling Chapter 13. The GCM-with-decay analysis in Chapter 13 showed that the decay parameter \(\lambda\) becomes identifiable when the task exercises forgetting — whether through reversals or drift. The mechanism there is general-purpose (“downweight old stuff”), so reversals and drift work comparably well. The Kalman filter’s \(q\) is assumption-specific (it encodes a Gaussian drift prior), so the assumption-matched smooth-drift task gives the cleanest result; reversals work but with a mild misspecification signature. Both chapters make the same underlying point: identifiability is a joint property of the model and the task, not of the model alone.


15.11 Multilevel Prototype Model: Population Inference Across Subjects

Real experiments involve many participants, each bringing their own observation-noise tolerance and drift rate. Some learners trust individual examples heavily (low \(r\)); others discount them (high \(r\)). Some learners have more stable category representations (low \(q\)); others show faster adaptation (high \(q\)). The multilevel prototype model captures this variation by placing population distributions over both log_r and log_q and partially pooling information across subjects. Because the single-agent analyses above already established that \(q\) is identifiable only under dynamic environments, we fit the multilevel model across all three scenarios from the outset rather than starting with the static case.

The structure is the minimal multilevel extension of the two-parameter single-subject model:

\[\text{pop\_log\_r\_mean} \sim \mathcal{N}(0, 1), \quad \text{pop\_log\_r\_sd} \sim \text{Exponential}(2)\] \[z_{r,j} \sim \mathcal{N}(0, 1), \quad \log r_j = \text{pop\_log\_r\_mean} + z_{r,j} \cdot \text{pop\_log\_r\_sd}\] \[\text{pop\_log\_q\_mean} \sim \mathcal{N}(-2, 1), \quad \text{pop\_log\_q\_sd} \sim \text{Exponential}(2)\] \[z_{q,j} \sim \mathcal{N}(0, 1), \quad \log q_j = \text{pop\_log\_q\_mean} + z_{q,j} \cdot \text{pop\_log\_q\_sd}\]

The non-centred parameterisation (NCP) for the per-subject offsets avoids the funnel geometry that centred parameterisations produce when the population SD is small. With two per-subject parameters, the SBC rank histograms (below) show six curves: two population means, two population SDs, and two individual-level curves (subj_log_r[1], subj_log_q[1]).
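The equivalence behind the NCP is purely distributional and easy to verify by simulation (illustrative values, not the chapter's priors):

```r
# mu + sd * z with z ~ Normal(0, 1) has exactly the N(mu, sd) distribution
# used by the centred parameterisation; only the sampler geometry differs,
# not the implied prior on the per-subject parameter.
set.seed(42)
mu_pop <- -2; sd_pop <- 0.5; n <- 1e5
non_centred <- mu_pop + sd_pop * rnorm(n)  # NCP construction
c(mean(non_centred), sd(non_centred))      # close to (-2, 0.5)
```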

15.11.1 Multilevel Forward Simulation

Code
# Simulate a full multilevel dataset from specified population parameters.
simulate_multilevel_prototype <- function(
    n_subjects,
    stimulus_info,
    n_blocks,
    base_seed,
    pop_log_r_mean,
    pop_log_r_sd,
    pop_log_q_mean = LOG_Q_PRIOR_MEAN,
    pop_log_q_sd   = 0.5) {

  set.seed(base_seed)
  z_log_r     <- rnorm(n_subjects, 0, 1)
  subj_log_r  <- pop_log_r_mean + pop_log_r_sd * z_log_r
  subj_r      <- exp(subj_log_r)

  z_log_q     <- rnorm(n_subjects, 0, 1)
  subj_log_q  <- pop_log_q_mean + pop_log_q_sd * z_log_q
  subj_q      <- exp(subj_log_q)

  true_params <- tibble(
    agent_id   = seq_len(n_subjects),
    log_r_true = subj_log_r,
    r_true     = subj_r,
    log_q_true = subj_log_q,
    q_true     = subj_q
  )

  sim_data <- map_dfr(seq_len(n_subjects), function(j) {
    sched <- make_subject_schedule(stimulus_info, n_blocks,
                                   seed = base_seed * 100 + j)
    obs   <- as.matrix(sched[, c("height", "position")])
    res   <- prototype_kalman(
      r_value            = subj_r[j],
      q_value            = subj_q[j],
      obs                = obs,
      cat_one            = sched$category_feedback,
      initial_mu         = list(rep(2.5, 2), rep(2.5, 2)),
      initial_sigma_diag = 10.0
    )
    sched |>
      mutate(
        agent_id     = j,
        sim_response = res$sim_response,
        correct      = as.integer(category_feedback == sim_response)
      ) |>
      group_by(agent_id) |>
      mutate(performance = cumsum(correct) / row_number()) |>
      ungroup()
  })

  list(data = sim_data, true_params = true_params)
}
Code
# Forward simulation uses the scenario-aware function defined below.
# Population parameters are set in the call to simulate_ml_proto_scenario.

15.11.2 Implementing the Multilevel Prototype Model in Stan

The multilevel Stan model extends the single-subject architecture by:

  1. Adding population hyperparameters pop_log_r_mean, pop_log_r_sd, pop_log_q_mean, and pop_log_q_sd to the parameters block.
  2. Adding per-subject NCP offsets z_log_r[j] ~ Normal(0, 1) and z_log_q[j] ~ Normal(0, 1).
  3. Looping the Kalman filter (with prediction step) over subjects in transformed parameters, exactly as the GCM does in Chapter 13.

The resulting model has \(4 + 2J\) parameters (\(J\) is the number of subjects) — still relatively compact compared to exemplar models — which makes posterior geometry and SBC diagnostics particularly legible.

Code
prototype_ml_stan <- "
// Kalman Filter Prototype Model — Multilevel (Non-Centred Parameterisation)
//
// New relative to the single-subject model:
//   1. Population hyperparameters for log_r AND log_q.
//   2. Per-subject NCP offsets z_log_r[j] and z_log_q[j] ~ Normal(0, 1).
//   3. Subject-indexed Kalman filter loop with prediction step in transformed parameters.
//   4. ALL prior hyperparameters passed through the data block.

data {
  int<lower=1> N_total;
  int<lower=1> N_subjects;
  int<lower=1> N_features;

  array[N_subjects] int<lower=1, upper=N_total> subj_start;
  array[N_subjects] int<lower=1, upper=N_total> subj_end;

  array[N_total] int<lower=0, upper=1> y;
  array[N_total] int<lower=0, upper=1> cat_one;
  array[N_total, N_features] real obs;

  // Fixed structural values (same as single-subject; not inferred)
  vector[N_features] initial_mu_cat0;
  vector[N_features] initial_mu_cat1;
  real<lower=0> initial_sigma_diag;

  // Prior hyperparameters for r
  real pop_log_r_mean_prior_mean;
  real<lower=0> pop_log_r_mean_prior_sd;
  real<lower=0> pop_log_r_sd_prior_rate;

  // Prior hyperparameters for q
  real pop_log_q_mean_prior_mean;
  real<lower=0> pop_log_q_mean_prior_sd;
  real<lower=0> pop_log_q_sd_prior_rate;
}

parameters {
  real pop_log_r_mean;
  real<lower=0> pop_log_r_sd;
  vector[N_subjects] z_log_r;

  real pop_log_q_mean;
  real<lower=0> pop_log_q_sd;
  vector[N_subjects] z_log_q;
}

transformed parameters {
  vector[N_subjects] subj_log_r;
  vector<lower=0>[N_subjects] subj_r;
  vector[N_subjects] subj_log_q;
  vector<lower=0>[N_subjects] subj_q;
  vector<lower=1e-9, upper=1-1e-9>[N_total] prob_cat1;

  for (j in 1:N_subjects) {
    subj_log_r[j] = pop_log_r_mean + z_log_r[j] * pop_log_r_sd;
    subj_r[j]     = exp(subj_log_r[j]);
    subj_log_q[j] = pop_log_q_mean + z_log_q[j] * pop_log_q_sd;
    subj_q[j]     = exp(subj_log_q[j]);
  }

  {
    matrix[N_features, N_features] I_mat =
      diag_matrix(rep_vector(1.0, N_features));

    for (j in 1:N_subjects) {
      vector[N_features] mu_cat0 = initial_mu_cat0;
      vector[N_features] mu_cat1 = initial_mu_cat1;
      matrix[N_features, N_features] sigma_cat0 =
        diag_matrix(rep_vector(initial_sigma_diag, N_features));
      matrix[N_features, N_features] sigma_cat1 =
        diag_matrix(rep_vector(initial_sigma_diag, N_features));
      matrix[N_features, N_features] r_matrix =
        diag_matrix(rep_vector(subj_r[j], N_features));
      matrix[N_features, N_features] q_matrix =
        diag_matrix(rep_vector(subj_q[j], N_features));

      for (i in subj_start[j]:subj_end[j]) {
        vector[N_features] x = to_vector(obs[i]);

        // ── Prediction step: add process noise to both categories ──────────
        sigma_cat0 = sigma_cat0 + q_matrix;
        sigma_cat1 = sigma_cat1 + q_matrix;

        // ── Decision ────────────────────────────────────────────────────────
        matrix[N_features, N_features] cov0 = sigma_cat0 + r_matrix;
        matrix[N_features, N_features] cov1 = sigma_cat1 + r_matrix;
        real log_p0 = multi_normal_lpdf(x | mu_cat0, cov0);
        real log_p1 = multi_normal_lpdf(x | mu_cat1, cov1);
        real p_i    = exp(log_p1 - log_sum_exp(log_p0, log_p1));
        prob_cat1[i] = fmax(1e-9, fmin(1 - 1e-9, p_i));

        // ── Update ───────────────────────────────────────────────────────────
        if (cat_one[i] == 1) {
          vector[N_features] innov = x - mu_cat1;
          matrix[N_features, N_features] S  = sigma_cat1 + r_matrix;
          matrix[N_features, N_features] K  = mdivide_right_spd(sigma_cat1, S);
          matrix[N_features, N_features] IK = I_mat - K;
          mu_cat1    = mu_cat1 + K * innov;
          sigma_cat1 = IK * sigma_cat1 * IK' + K * r_matrix * K';
          sigma_cat1 = 0.5 * (sigma_cat1 + sigma_cat1');
        } else {
          vector[N_features] innov = x - mu_cat0;
          matrix[N_features, N_features] S  = sigma_cat0 + r_matrix;
          matrix[N_features, N_features] K  = mdivide_right_spd(sigma_cat0, S);
          matrix[N_features, N_features] IK = I_mat - K;
          mu_cat0    = mu_cat0 + K * innov;
          sigma_cat0 = IK * sigma_cat0 * IK' + K * r_matrix * K';
          sigma_cat0 = 0.5 * (sigma_cat0 + sigma_cat0');
        }
      }
    }
  }
}

model {
  target += normal_lpdf(pop_log_r_mean | pop_log_r_mean_prior_mean,
                                         pop_log_r_mean_prior_sd);
  target += exponential_lpdf(pop_log_r_sd | pop_log_r_sd_prior_rate);
  target += std_normal_lpdf(z_log_r);

  target += normal_lpdf(pop_log_q_mean | pop_log_q_mean_prior_mean,
                                         pop_log_q_mean_prior_sd);
  target += exponential_lpdf(pop_log_q_sd | pop_log_q_sd_prior_rate);
  target += std_normal_lpdf(z_log_q);

  target += bernoulli_lpmf(y | prob_cat1);
}

generated quantities {
  vector[N_total] log_lik;
  real lprior;
  for (i in 1:N_total)
    log_lik[i] = bernoulli_lpmf(y[i] | prob_cat1[i]);
  lprior = normal_lpdf(pop_log_r_mean | pop_log_r_mean_prior_mean, pop_log_r_mean_prior_sd) +
           exponential_lpdf(pop_log_r_sd | pop_log_r_sd_prior_rate) +
           std_normal_lpdf(z_log_r) +
           normal_lpdf(pop_log_q_mean | pop_log_q_mean_prior_mean, pop_log_q_mean_prior_sd) +
           exponential_lpdf(pop_log_q_sd | pop_log_q_sd_prior_rate) +
           std_normal_lpdf(z_log_q);
}
"

stan_file_proto_ml <- here("stan", "ch13_prototype_ml.stan")
write_stan_file(prototype_ml_stan, dir = here("stan"),
                basename = "ch13_prototype_ml.stan")
[1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/ch13_prototype_ml.stan"
Code
mod_prototype_ml <- cmdstan_model(stan_file_proto_ml)
cat("Multilevel prototype Stan model compiled successfully.\n")
Multilevel prototype Stan model compiled successfully.

15.11.3 Forward Simulation and Fitting Across Scenarios

The scenario-aware simulation and fitting are handled by the functions below. Population parameters are shared across scenarios so that any differences in posterior recovery are attributable to the environment, not to different data-generating values.

Code
simulate_ml_proto_scenario <- function(
    n_subjects, stimulus_info, n_blocks, scenario, base_seed,
    pop_log_r_mean, pop_log_r_sd,
    pop_log_q_mean = LOG_Q_PRIOR_MEAN,
    pop_log_q_sd   = 0.5) {

  set.seed(base_seed)
  z_log_r    <- rnorm(n_subjects, 0, 1)
  subj_log_r <- pop_log_r_mean + pop_log_r_sd * z_log_r
  subj_r     <- exp(subj_log_r)

  z_log_q    <- rnorm(n_subjects, 0, 1)
  subj_log_q <- pop_log_q_mean + pop_log_q_sd * z_log_q
  subj_q     <- exp(subj_log_q)

  true_params <- tibble(
    agent_id   = seq_len(n_subjects),
    scenario   = scenario,
    log_r_true = subj_log_r, r_true = subj_r,
    log_q_true = subj_log_q, q_true = subj_q
  )

  sim_data <- map_dfr(seq_len(n_subjects), function(j) {
    simulate_prototype_agent_scenario(
      agent_id     = j, scenario = scenario,
      r_value      = subj_r[j], q_value = subj_q[j],
      subject_seed = base_seed * 100 + j
    )
  })

  list(data = sim_data, true_params = true_params)
}
Code
n_subjects_ml_scen <- 10
pop_params_ml_scen <- list(
  pop_log_r_mean = log(1.5),    # population median r = 1.5
  pop_log_r_sd   = 0.5,
  pop_log_q_mean = log(0.15),   # population median q = 0.15
  pop_log_q_sd   = 0.5
)
ml_proto_scenarios <- c("static", "contingent_shift", "drift")

ml_proto_scen_files <- map(ml_proto_scenarios, function(scen) list(
  data  = here("simdata", paste0("ch13_proto_ml_scen_data_",  scen, ".csv")),
  truth = here("simdata", paste0("ch13_proto_ml_scen_truth_", scen, ".csv"))
)) |> setNames(ml_proto_scenarios)

ml_scen_files_exist <- all(
  file.exists(unlist(map(ml_proto_scen_files, ~ c(.x$data, .x$truth))))
)

if (regenerate_simulations || !ml_scen_files_exist) {
  ml_proto_scen_sims <- map(ml_proto_scenarios, function(scen) {
    out <- simulate_ml_proto_scenario(
      n_subjects     = n_subjects_ml_scen,
      stimulus_info  = stimulus_info,
      n_blocks       = n_blocks,
      scenario       = scen,
      base_seed      = match(scen, ml_proto_scenarios) * 1000 + 77,
      pop_log_r_mean = pop_params_ml_scen$pop_log_r_mean,
      pop_log_r_sd   = pop_params_ml_scen$pop_log_r_sd,
      pop_log_q_mean = pop_params_ml_scen$pop_log_q_mean,
      pop_log_q_sd   = pop_params_ml_scen$pop_log_q_sd
    )
    write_csv(out$data,        ml_proto_scen_files[[scen]]$data)
    write_csv(out$true_params, ml_proto_scen_files[[scen]]$truth)
    out
  }) |> setNames(ml_proto_scenarios)
} else {
  ml_proto_scen_sims <- map(ml_proto_scenarios, function(scen) list(
    data        = read_csv(ml_proto_scen_files[[scen]]$data,  show_col_types = FALSE),
    true_params = read_csv(ml_proto_scen_files[[scen]]$truth, show_col_types = FALSE)
  )) |> setNames(ml_proto_scenarios)
}

ml_proto_scen_all <- map_dfr(ml_proto_scenarios, function(scen)
  ml_proto_scen_sims[[scen]]$data |> mutate(scenario = scen)
) |>
  mutate(scenario = factor(scenario, levels = ml_proto_scenarios,
                           labels = c("Static", "Contingent shift", "Smooth drift")))

ggplot(ml_proto_scen_all,
       aes(x = trial_within_subject, y = performance, group = agent_id)) +
  geom_line(alpha = 0.25, color = "grey40") +
  stat_summary(fun = mean, geom = "line", aes(group = 1),
               color = "#009E73", linewidth = 1.4) +
  facet_wrap(~scenario, ncol = 3) +
  geom_hline(yintercept = 0.5, linetype = "dashed", color = "grey50") +
  labs(
    title    = "Multilevel Prototype Model: learning curves across scenarios",
    subtitle = paste0(n_subjects_ml_scen,
                      " agents per scenario; grey = individuals, green = population mean"),
    x = "Trial", y = "Cumulative accuracy"
  )

Code
build_ml_proto_stan_data <- function(ml_data) {
  sorted <- ml_data |>
    mutate(subj_id_stan = as.integer(factor(agent_id))) |>
    arrange(subj_id_stan, trial_within_subject)

  bounds <- sorted |>
    mutate(row_idx = row_number()) |>
    group_by(subj_id_stan) |>
    summarise(subj_start = min(row_idx), subj_end = max(row_idx),
              .groups = "drop") |>
    arrange(subj_id_stan)

  list(
    N_total                   = nrow(sorted),
    N_subjects                = max(sorted$subj_id_stan),
    N_features                = 2L,
    subj_start                = bounds$subj_start,
    subj_end                  = bounds$subj_end,
    y                         = as.integer(sorted$sim_response),
    cat_one                   = as.integer(sorted$category_feedback),
    obs                       = as.matrix(sorted[, c("height", "position")]),
    initial_mu_cat0           = c(2.5, 2.5),
    initial_mu_cat1           = c(2.5, 2.5),
    initial_sigma_diag        = 10.0,
    pop_log_r_mean_prior_mean = LOG_R_PRIOR_MEAN,
    pop_log_r_mean_prior_sd   = LOG_R_PRIOR_SD,
    pop_log_r_sd_prior_rate   = 2.0,
    pop_log_q_mean_prior_mean = LOG_Q_PRIOR_MEAN,
    pop_log_q_mean_prior_sd   = LOG_Q_PRIOR_SD,
    pop_log_q_sd_prior_rate   = 2.0
  )
}

ml_proto_scen_fit_files <- map_chr(ml_proto_scenarios, function(scen)
  here("simmodels", paste0("ch13_proto_ml_scen_fit_", scen, ".rds"))
) |> setNames(ml_proto_scenarios)

if (regenerate_simulations || !all(file.exists(ml_proto_scen_fit_files))) {
  ml_proto_scen_fits <- map(ml_proto_scenarios, function(scen) {
    cat("Fitting multilevel prototype — scenario:", scen, "\n")
    dat <- build_ml_proto_stan_data(ml_proto_scen_sims[[scen]]$data)
    fit <- mod_prototype_ml$sample(
      data            = dat,
      chains          = 4,
      parallel_chains = min(4, availableCores()),
      iter_warmup     = 1500,
      iter_sampling   = 1500,
      refresh         = 200,
      adapt_delta     = 0.95,
      max_treedepth   = 12
    )
    fit$save_object(ml_proto_scen_fit_files[[scen]])
    fit
  }) |> setNames(ml_proto_scenarios)
} else {
  ml_proto_scen_fits <- map(ml_proto_scen_fit_files, readRDS) |>
    setNames(ml_proto_scenarios)
}
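The Stan loop over subj_start[j]:subj_end[j] silently assumes that build_ml_proto_stan_data produces bounds that tile 1..N_total with no gaps or overlaps. A small invariant check (hypothetical helper, demonstrated on a toy list) makes that assumption explicit:

```r
# Verify the ragged-indexing invariant the Stan model depends on: bounds
# start at 1, end at N_total, and each subject starts exactly one row
# after the previous subject ends.
check_subject_bounds <- function(stan_data) {
  with(stan_data, stopifnot(
    subj_start[1] == 1,
    subj_end[N_subjects] == N_total,
    all(subj_start[-1] == subj_end[-N_subjects] + 1)
  ))
  invisible(TRUE)
}

check_subject_bounds(list(N_total = 6, N_subjects = 2,
                          subj_start = c(1, 4), subj_end = c(3, 6)))
```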

15.11.4 Mandatory MCMC Diagnostic Battery (Multilevel, All Scenarios)

Code
pop_vars_ml <- c("pop_log_r_mean", "pop_log_r_sd",
                 "pop_log_q_mean", "pop_log_q_sd")

ml_diag_all <- purrr::imap_dfr(ml_proto_scen_fits, function(fit, scen) {
  diagnostic_summary_table(fit, params = pop_vars_ml) |>
    dplyr::mutate(scenario = scen)
})
print(ml_diag_all |> dplyr::select(scenario, metric, value, threshold, pass))
# A tibble: 18 × 5
   scenario         metric                           value threshold pass 
   <chr>            <chr>                            <dbl> <chr>     <lgl>
 1 static           Divergences (zero tolerance)    0      == 0      TRUE 
 2 static           Max rank-normalised R-hat       1      < 1.01    TRUE 
 3 static           Min bulk ESS                 1879      > 400     TRUE 
 4 static           Min tail ESS                 1647      > 400     TRUE 
 5 static           Min E-BFMI                      0.786  > 0.2     TRUE 
 6 static           Max MCSE / posterior SD         0.0231 < 0.05    TRUE 
 7 contingent_shift Divergences (zero tolerance)    0      == 0      TRUE 
 8 contingent_shift Max rank-normalised R-hat       1      < 1.01    TRUE 
 9 contingent_shift Min bulk ESS                 1647      > 400     TRUE 
10 contingent_shift Min tail ESS                 1943      > 400     TRUE 
11 contingent_shift Min E-BFMI                      0.716  > 0.2     TRUE 
12 contingent_shift Max MCSE / posterior SD         0.0246 < 0.05    TRUE 
13 drift            Divergences (zero tolerance)    0      == 0      TRUE 
14 drift            Max rank-normalised R-hat       1      < 1.01    TRUE 
15 drift            Min bulk ESS                 1830      > 400     TRUE 
16 drift            Min tail ESS                 1537      > 400     TRUE 
17 drift            Min E-BFMI                      0.722  > 0.2     TRUE 
18 drift            Max MCSE / posterior SD         0.0234 < 0.05    TRUE 
Code
walk(names(ml_proto_scen_fits), function(scen) {
  cat("=== Pair plot (population params):", scen, "===\n")
  print(bayesplot::mcmc_pairs(
    ml_proto_scen_fits[[scen]]$draws(pop_vars_ml),
    diag_fun     = "dens",
    off_diag_fun = "hex"
  ))
})
=== Pair plot (population params): static ===

=== Pair plot (population params): contingent_shift ===

=== Pair plot (population params): drift ===

15.11.5 Population Prior-Posterior Updates Across Scenarios

Code
set.seed(9)
n_prior_ml <- 4000
prior_ml <- tibble(
  pop_log_r_mean = rnorm(n_prior_ml, LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD),
  pop_log_r_sd   = rexp(n_prior_ml, 2.0),
  pop_log_q_mean = rnorm(n_prior_ml, LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD),
  pop_log_q_sd   = rexp(n_prior_ml, 2.0),
  source         = "prior",
  scenario       = "prior"
)

# Build tidy prior/posterior draws per scenario for plotting
ml_pp_tidy <- purrr::imap_dfr(ml_proto_scen_fits, function(fit, scen) {
  post <- as_draws_df(fit$draws(pop_vars_ml)) |>
    dplyr::select(all_of(pop_vars_ml)) |>
    dplyr::mutate(source = "posterior", scenario = scen)
  prior_scen <- prior_ml |>
    dplyr::select(all_of(pop_vars_ml)) |>
    dplyr::mutate(source = "prior", scenario = scen)
  bind_rows(prior_scen, post)
}) |>
  tidyr::pivot_longer(cols = all_of(pop_vars_ml),
                      names_to = "parameter", values_to = "value") |>
  dplyr::mutate(scenario = factor(scenario,
                                   levels = c("static",
                                              "contingent_shift",
                                              "drift")))

true_pop <- tibble(
  parameter = pop_vars_ml,
  true_val  = c(pop_params_ml_scen$pop_log_r_mean,
                pop_params_ml_scen$pop_log_r_sd,
                pop_params_ml_scen$pop_log_q_mean,
                pop_params_ml_scen$pop_log_q_sd)
)

ggplot(ml_pp_tidy, aes(x = value, fill = source, color = source)) +
  geom_density(alpha = 0.4, linewidth = 0.5) +
  geom_vline(data = true_pop, aes(xintercept = true_val),
             color = "red", linetype = "dashed", inherit.aes = FALSE) +
  facet_grid(scenario ~ parameter, scales = "free") +
  scale_fill_manual(values  = c(prior = "grey70", posterior = "#009E73")) +
  scale_color_manual(values = c(prior = "grey50", posterior = "#006D4E")) +
  labs(
    title    = "Population Prior → Posterior: Multilevel Prototype Model",
    subtitle = "Red dashed = true generating value; rows = scenario",
    x = NULL, y = "Density"
  )

15.11.6 Population Parameter Recovery

Code
pop_vars_scen <- c("pop_log_r_mean", "pop_log_r_sd",
                   "pop_log_q_mean", "pop_log_q_sd")

true_vals_scen <- tibble(
  variable = pop_vars_scen,
  true_val = c(
    pop_params_ml_scen$pop_log_r_mean,
    pop_params_ml_scen$pop_log_r_sd,
    pop_params_ml_scen$pop_log_q_mean,
    pop_params_ml_scen$pop_log_q_sd
  )
)

# Prior 90% CIs:
#   pop_log_r_mean ~ N(LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD)
#   pop_log_q_mean ~ N(LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD)
#   pop_log_r_sd, pop_log_q_sd ~ Exponential(rate = 2)
prior_ref_ml <- tibble(
  variable  = pop_vars_scen,
  prior_lo  = c(qnorm(0.05, LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD),
                qexp(0.05, 2),
                qnorm(0.05, LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD),
                qexp(0.05, 2)),
  prior_hi  = c(qnorm(0.95, LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD),
                qexp(0.95, 2),
                qnorm(0.95, LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD),
                qexp(0.95, 2)),
  prior_med = c(LOG_R_PRIOR_MEAN,
                log(2) / 2,
                LOG_Q_PRIOR_MEAN,
                log(2) / 2)
)

ml_proto_scen_recovery <- map_dfr(ml_proto_scenarios, function(scen) {
  ml_proto_scen_fits[[scen]]$summary(variables = pop_vars_scen) |>
    dplyr::select(variable, mean, q5, q95) |>
    mutate(scenario = scen)
}) |>
  left_join(true_vals_scen, by = "variable") |>
  mutate(scenario = factor(scenario,
                           levels = c("static", "contingent_shift", "drift"),
                           labels = c("Static", "Contingent shift", "Smooth drift")))

# ── Log-scale plot with prior bands ─────────────────────────────────────────
ggplot(ml_proto_scen_recovery,
       aes(x = scenario, y = mean, ymin = q5, ymax = q95, color = scenario)) +
  geom_rect(data = prior_ref_ml,
            aes(xmin = -Inf, xmax = Inf, ymin = prior_lo, ymax = prior_hi),
            inherit.aes = FALSE,
            fill = "steelblue", alpha = 0.12) +
  geom_hline(data = prior_ref_ml,
             aes(yintercept = prior_med),
             inherit.aes = FALSE,
             linetype = "dotted", color = "steelblue", linewidth = 0.7) +
  geom_hline(aes(yintercept = true_val), linetype = "dashed", color = "red") +
  geom_pointrange() +
  facet_wrap(~variable, scales = "free_y", ncol = 2) +
  scale_color_manual(values = c("#999999", "#E69F00", "#009E73")) +
  labs(
    title    = "Multilevel Population Recovery — log scale",
    subtitle = "Points = posterior mean; bars = 90% CI; red dashes = true values; blue band = prior 90% CI",
    x = NULL, y = "Posterior estimate (log scale)"
  ) +
  theme(legend.position = "none",
        axis.text.x = element_text(angle = 30, hjust = 1))

Code
# ── Natural-scale plot for the mean parameters only ──────────────────────────
ml_proto_scen_recovery_nat <- ml_proto_scen_recovery |>
  dplyr::filter(variable %in% c("pop_log_r_mean", "pop_log_q_mean")) |>
  dplyr::mutate(
    mean_nat    = exp(mean),
    q5_nat      = exp(q5),
    q95_nat     = exp(q95),
    true_nat    = exp(true_val),
    param_label = dplyr::case_when(
      variable == "pop_log_r_mean" ~ "pop r median\n(observation noise)",
      variable == "pop_log_q_mean" ~ "pop q median\n(process noise)"
    )
  )

prior_ref_nat <- tibble(
  param_label = c("pop r median\n(observation noise)",
                  "pop q median\n(process noise)"),
  prior_lo  = c(exp(qnorm(0.05, LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD)),
                exp(qnorm(0.05, LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD))),
  prior_hi  = c(exp(qnorm(0.95, LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD)),
                exp(qnorm(0.95, LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD))),
  prior_med = c(exp(LOG_R_PRIOR_MEAN),
                exp(LOG_Q_PRIOR_MEAN))
)

ggplot(ml_proto_scen_recovery_nat,
       aes(x = scenario, y = mean_nat, ymin = q5_nat, ymax = q95_nat,
           color = scenario)) +
  geom_rect(data = prior_ref_nat,
            aes(xmin = -Inf, xmax = Inf, ymin = prior_lo, ymax = prior_hi),
            inherit.aes = FALSE,
            fill = "steelblue", alpha = 0.12) +
  geom_hline(data = prior_ref_nat,
             aes(yintercept = prior_med),
             inherit.aes = FALSE,
             linetype = "dotted", color = "steelblue", linewidth = 0.7) +
  geom_hline(aes(yintercept = true_nat), linetype = "dashed", color = "red") +
  geom_pointrange() +
  facet_wrap(~param_label, scales = "free_y", ncol = 2) +
  scale_color_manual(values = c("#999999", "#E69F00", "#009E73")) +
  labs(
    title    = "Multilevel Population Recovery — natural scale",
    subtitle = "Points = posterior median; bars = 90% CI; red dashes = true values; blue band = prior 90% CI",
    x = NULL, y = "Posterior estimate (natural scale)"
  ) +
  theme(legend.position = "none",
        axis.text.x = element_text(angle = 30, hjust = 1))

Cross-scenario population recovery. This plot shows the 90% CI from fitting the multilevel model to a single representative dataset per scenario — a point estimate of recovery quality, not a systematic calibration check (the multilevel SBC in §“Multilevel SBC” below provides that). pop_log_r_mean appears well-recovered in all three scenarios on this single dataset: observation noise directly shapes every per-trial likelihood, so the data always constrain \(r\) at the population level. pop_log_q_mean tells a different story. On static data the 90% CI is wide and centred near the prior mean — partial pooling across 10 subjects cannot manufacture information that no individual trial provides. On contingent-shift data the CI narrows as reversal dynamics provide a collective signal; subjects with more reversals lend strength to those with fewer. On drift data both population means appear cleanly recovered, with the 90% CI sitting close to the true value. However, the multilevel SBC reveals that pop_log_r_mean carries a systematic upward bias in the drift scenario that is invisible in a single-dataset recovery plot — a reminder that one-shot recovery and formal calibration are complementary, not interchangeable, diagnostics.

Comparison with Chapter 13. The decay GCM’s cross-scenario multilevel analysis showed that pop_log_lambda_mean was most clearly recovered on contingent-shift data. Here pop_log_q_mean is most clearly recovered on drift data. The contrast reflects the models’ different generative stories: the GCM’s general-purpose decay kernel captures contingent shifts as well as drift (forgetting is useful whenever old evidence is stale, regardless of why it became stale); the Kalman filter’s smooth-drift prior is most consistent with the scenario where prototypes literally undergo Gaussian random-walk drift.


15.12 Multilevel SBC

Multilevel SBC checks that the full posterior — including both population hyperparameters and individual-level parameters — is correctly calibrated across the entire joint prior. The prototype model now has two per-subject parameters (\(\log r\) and \(\log q\)), giving six ECDF-difference curves in the SBC output.

What the key curves test:

  • pop_log_r_mean / pop_log_q_mean — Are the population-mean estimators centred? A monotone drift indicates directional bias.
  • pop_log_r_sd / pop_log_q_sd — Are the population-SD estimators well-calibrated? U-shaped rank histograms indicate the posterior is overconfident; humps indicate underconfidence. The NCP should prevent reverse-funnel issues.
  • subj_log_r[1] / subj_log_q[1] — Do partial pooling and NCP produce valid individual-level credible intervals for both parameters?
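The rank statistic underlying all six curves can be illustrated in miniature. For each simulated dataset, SBC records the rank of the true parameter among the posterior draws; under correct calibration this rank is uniform, so the ECDF of ranks hugs the diagonal. A minimal sketch using a conjugate Normal model (where the exact posterior is available, so calibration holds by construction; all numbers are illustrative, and nothing here uses the SBC package itself):

```r
# Toy SBC: a Normal(0, 1) prior with Normal(theta, 1) likelihood has an
# exact conjugate posterior, so the rank of the truth among posterior
# draws must be uniform -- the benchmark the ECDF-difference plots test.
set.seed(1)
n_sims  <- 500
n_draws <- 200
ranks <- vapply(seq_len(n_sims), function(i) {
  theta_true <- rnorm(1, 0, 1)             # draw truth from the prior
  y          <- rnorm(20, theta_true, 1)   # simulate data under the truth
  post_prec  <- 1 + 20                     # prior precision + n * data precision
  post_mean  <- sum(y) / post_prec         # prior mean is 0
  post_sd    <- sqrt(1 / post_prec)
  draws      <- rnorm(n_draws, post_mean, post_sd)
  as.numeric(sum(draws < theta_true))      # rank of truth in 0..n_draws
}, numeric(1))
hist(ranks, breaks = 20, main = "SBC ranks (should be uniform)")
```

A biased or miscalibrated posterior breaks this uniformity, which is exactly what the drift-scenario results below exploit diagnostically.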
Code
sbc_ml_proto_filepath <- here("simdata", "ch13_sbc_proto_ml_results.rds")

if (regenerate_simulations || !file.exists(sbc_ml_proto_filepath)) {
  n_sbc_iterations_ml_proto <- 200   # use >= 1000 for publication-quality SBC
  n_subjects_sbc_proto      <- 10

  proto_ml_sbc_generator <- SBC_generator_function(
    function() {
      pop_log_r_mean <- rnorm(1, LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD)
      pop_log_r_sd   <- rexp(1, 2.0)
      pop_log_q_mean <- rnorm(1, LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD)
      pop_log_q_sd   <- rexp(1, 2.0)

      z_log_r    <- rnorm(n_subjects_sbc_proto, 0, 1)
      subj_log_r <- pop_log_r_mean + pop_log_r_sd * z_log_r
      subj_r     <- exp(subj_log_r)

      z_log_q    <- rnorm(n_subjects_sbc_proto, 0, 1)
      subj_log_q <- pop_log_q_mean + pop_log_q_sd * z_log_q
      subj_q     <- exp(subj_log_q)

      all_dat <- map_dfr(seq_len(n_subjects_sbc_proto), function(j) {
        sched <- make_subject_schedule(stimulus_info, n_blocks,
                                       seed = sample.int(1e8, 1))
        res   <- prototype_kalman(
          r_value            = subj_r[j],
          q_value            = subj_q[j],
          obs                = as.matrix(sched[, c("height", "position")]),
          cat_one            = sched$category_feedback,
          initial_mu         = list(rep(2.5, 2), rep(2.5, 2)),
          initial_sigma_diag = 10.0
        )
        sched |> mutate(agent_id = j, sim_response = res$sim_response)
      })

      sorted <- all_dat |>
        mutate(subj_id_stan = as.integer(factor(agent_id))) |>
        arrange(subj_id_stan, trial_within_subject)

      bounds <- sorted |>
        mutate(row_idx = row_number()) |>
        group_by(subj_id_stan) |>
        summarise(subj_start = min(row_idx), subj_end = max(row_idx),
                  .groups = "drop") |>
        arrange(subj_id_stan)

      list(
        variables = list(
          pop_log_r_mean    = pop_log_r_mean,
          pop_log_r_sd      = pop_log_r_sd,
          `subj_log_r[1]`   = subj_log_r[1],
          pop_log_q_mean    = pop_log_q_mean,
          pop_log_q_sd      = pop_log_q_sd,
          `subj_log_q[1]`   = subj_log_q[1]
        ),
        generated = list(
          N_total                   = nrow(sorted),
          N_subjects                = n_subjects_sbc_proto,
          N_features                = 2L,
          subj_start                = bounds$subj_start,
          subj_end                  = bounds$subj_end,
          y                         = as.integer(sorted$sim_response),
          cat_one                   = as.integer(sorted$category_feedback),
          obs                       = as.matrix(sorted[, c("height", "position")]),
          initial_mu_cat0           = c(2.5, 2.5),
          initial_mu_cat1           = c(2.5, 2.5),
          initial_sigma_diag        = 10.0,
          pop_log_r_mean_prior_mean = LOG_R_PRIOR_MEAN,
          pop_log_r_mean_prior_sd   = LOG_R_PRIOR_SD,
          pop_log_r_sd_prior_rate   = 2.0,
          pop_log_q_mean_prior_mean = LOG_Q_PRIOR_MEAN,
          pop_log_q_mean_prior_sd   = LOG_Q_PRIOR_SD,
          pop_log_q_sd_prior_rate   = 2.0
        )
      )
    }
  )

  proto_ml_sbc_backend <- SBC_backend_cmdstan_sample(
    mod_prototype_ml,
    iter_warmup   = 1000,
    iter_sampling = 1000,
    chains        = 2,
    adapt_delta   = 0.95,
    refresh       = 0
  )

  proto_ml_datasets <- generate_datasets(
    proto_ml_sbc_generator,
    n_sims = n_sbc_iterations_ml_proto
  )

  sbc_results_ml_proto <- compute_SBC(
    proto_ml_datasets,
    proto_ml_sbc_backend,
    cache_mode     = "results",
    cache_location = here("simdata", "ch13_sbc_proto_ml_cache"),
    keep_fits      = FALSE
  )
  saveRDS(sbc_results_ml_proto, sbc_ml_proto_filepath)
  cat("Multilevel prototype SBC results computed and saved.\n")
} else {
  sbc_results_ml_proto <- readRDS(sbc_ml_proto_filepath)
  cat("Loaded existing multilevel prototype SBC results.\n")
}
Loaded existing multilevel prototype SBC results.
Code
plot_ecdf_diff(sbc_results_ml_proto)

Code
plot_rank_hist(sbc_results_ml_proto)

Code
# ── Population-level parameter recovery across SBC simulations ───────────────
recovery_df_pop_proto <- sbc_results_ml_proto$stats |>
  filter(variable %in% c("pop_log_r_mean", "pop_log_r_sd",
                          "pop_log_q_mean", "pop_log_q_sd"))

p_sbc_pop_proto <- ggplot(recovery_df_pop_proto,
                           aes(x = simulated_value, y = mean)) +
  geom_point(alpha = 0.5, color = "#009E73") +
  geom_abline(intercept = 0, slope = 1, color = "red", linetype = "dashed") +
  facet_wrap(~variable, scales = "free") +
  labs(
    title    = "Population Parameter Recovery Across SBC Simulations",
    subtitle = "Posterior mean vs. true generated value",
    x = "True Value", y = "Estimated Value (Posterior Mean)"
  ) +
  theme_bw()
print(p_sbc_pop_proto)

Code
# ── Population recovery — natural scale (mean parameters only) ───────────────
recovery_df_pop_proto_nat <- recovery_df_pop_proto |>
  dplyr::filter(variable %in% c("pop_log_r_mean", "pop_log_q_mean")) |>
  dplyr::mutate(
    sim_nat  = exp(simulated_value),
    mean_nat = exp(mean),
    label    = dplyr::case_when(
      variable == "pop_log_r_mean" ~ "pop r median (obs. noise)",
      variable == "pop_log_q_mean" ~ "pop q median (process noise)"
    )
  )
ggplot(recovery_df_pop_proto_nat, aes(x = sim_nat, y = mean_nat)) +
  geom_point(alpha = 0.5, color = "#009E73") +
  geom_abline(intercept = 0, slope = 1, color = "red", linetype = "dashed") +
  facet_wrap(~label, scales = "free", ncol = 2) +
  labs(
    title    = "Population Recovery — Static Multilevel SBC (natural scale)",
    subtitle = "Posterior median vs. true generated value",
    x = "True value", y = "Estimated value (posterior median)"
  ) +
  theme_bw()

Code
# ── Individual-level parameter recovery across SBC simulations ──────────────
recovery_df_indiv_proto <- sbc_results_ml_proto$stats |>
  filter(variable %in% c("subj_log_r[1]", "subj_log_q[1]"))

p_sbc_indiv_proto <- ggplot(recovery_df_indiv_proto,
                             aes(x = simulated_value, y = mean)) +
  geom_point(alpha = 0.5, color = "#D55E00") +
  geom_abline(intercept = 0, slope = 1,
              color = "grey30", linetype = "dashed") +
  facet_wrap(~variable, scales = "free") +
  labs(
    title    = "Individual Parameter Recovery (Subject 1) Across SBC",
    subtitle = "Posterior mean vs. true generated value",
    x = "True value [Subject 1]",
    y = "Estimated value (Posterior Mean)"
  ) +
  theme_bw()
print(p_sbc_indiv_proto)

Code
# ── Individual recovery — natural scale ──────────────────────────────────────
recovery_df_indiv_proto_nat <- recovery_df_indiv_proto |>
  dplyr::mutate(
    sim_nat  = exp(simulated_value),
    mean_nat = exp(mean),
    label    = dplyr::case_when(
      variable == "subj_log_r[1]" ~ "subj r[1] (obs. noise)",
      variable == "subj_log_q[1]" ~ "subj q[1] (process noise)"
    )
  )
ggplot(recovery_df_indiv_proto_nat, aes(x = sim_nat, y = mean_nat)) +
  geom_point(alpha = 0.5, color = "#D55E00") +
  geom_abline(intercept = 0, slope = 1, color = "grey30", linetype = "dashed") +
  facet_wrap(~label, scales = "free", ncol = 2) +
  labs(
    title    = "Individual Recovery (Subject 1) — Static Multilevel SBC (natural scale)",
    subtitle = "Posterior median vs. true generated value",
    x = "True value [Subject 1]", y = "Estimated value (posterior median)"
  ) +
  theme_bw()

Code
# ── Divergence diagnostics — static multilevel SBC ───────────────────────────
ml_diags_static     <- sbc_results_ml_proto$backend_diagnostics
ml_fail_rate_static <- mean(ml_diags_static$n_divergent > 0) * 100

cat(sprintf(
  "Static multilevel SBC: %.0f%% of universes produced at least one divergence.\n",
  ml_fail_rate_static
))
Static multilevel SBC: 4% of universes produced at least one divergence.
Code
ml_div_profile_static <- sbc_results_ml_proto$stats |>
  dplyr::filter(variable %in% c("pop_log_r_sd", "pop_log_q_sd")) |>
  dplyr::select(sim_id, variable, simulated_value) |>
  tidyr::pivot_wider(names_from = variable, values_from = simulated_value) |>
  dplyr::left_join(
    dplyr::select(ml_diags_static, sim_id, n_divergent),
    by = "sim_id"
  ) |>
  dplyr::mutate(had_divergence = n_divergent > 0)

ggplot(ml_div_profile_static,
       aes(x = pop_log_r_sd, y = pop_log_q_sd, color = had_divergence)) +
  geom_point(alpha = 0.7, size = 2) +
  scale_color_manual(
    name   = "Sampler",
    values = c("FALSE" = "steelblue", "TRUE" = "darkred"),
    labels = c("FALSE" = "Clean", "TRUE" = "Divergent")
  ) +
  labs(
    title    = "Divergence Profile — Static Multilevel SBC",
    subtitle = "Divergences concentrate where population SDs are small (funnel geometry)",
    x        = expression("True " * sigma[log~r]),
    y        = expression("True " * sigma[log~q])
  ) +
  theme_bw()

Static multilevel SBC. All six ECDF-difference curves remain within the simultaneous bands and the rank histograms are approximately uniform — the multilevel model passes calibration on static data. This is a stronger result than one might expect given the single-subject SBC finding that log_q is individually biased: partial pooling across 10 subjects keeps the population-level posteriors adequately calibrated even when no individual’s data alone identifies \(q\).

Two patterns in the scatter plots are worth noting. First, pop_log_q_sd and pop_log_r_sd both show floor effects: estimated values pile up below 0.5–0.6 when true values exceed 1, because partial pooling forces individual estimates toward the population mean when variability is high and the data are non-informative. Second, subj_log_q[1] shows a systematic upward bias — individual \(q\) estimates sit above the identity line across the full range of true values. This is the individual-level signature of the single-subject upward bias: the posterior for \(\log q\) is pulled above the truth because static data provide no evidence against large process noise. The natural-scale plots amplify this: the bias in \(q\) becomes a multiplicative overestimation on the original scale.
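The floor effect is ordinary normal-normal shrinkage: when per-subject data are weak, each subject's posterior mean is pulled toward the population mean, compressing the apparent spread of subject-level estimates. A toy sketch with known variances and a prior mean of zero (all numbers illustrative):

```r
# Shrinkage toy: noisy per-subject estimates pulled toward the population
# mean compress the between-subject spread, producing the floor effect
# seen in the pop_log_r_sd / pop_log_q_sd recovery panels.
set.seed(7)
true_sd   <- 1.2                                  # true population SD (> 1)
subj_true <- rnorm(50, 0, true_sd)                # true subject parameters
obs_sd    <- 2.0                                  # weak per-subject information
subj_obs  <- rnorm(50, subj_true, obs_sd)         # noisy per-subject estimates
w         <- true_sd^2 / (true_sd^2 + obs_sd^2)   # normal-normal shrinkage weight
shrunken  <- w * subj_obs                         # pulled toward the prior mean 0
c(sd_true = sd(subj_true), sd_shrunken = sd(shrunken))
```

The compressed spread of the shrunken estimates is what caps the estimated population SD when the data are uninformative.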

The divergence profile confirms that the NCP handles the static geometry well. Divergences concentrate at small pop_log_q_sd values (below ~0.5), the region where the funnel geometry of hierarchical models creates high posterior curvature. The divergence rate is low and the pattern is bounded, consistent with the adapt_delta = 0.95 setting being adequate.
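The NCP at issue is the shift-and-scale trick used throughout the SBC generators above: sample unit-scale offsets and transform, rather than sampling subject parameters directly around the population mean. A standalone sketch of the two equivalent parameterizations (values illustrative):

```r
# Centered vs non-centered parameterization (NCP): identical implied
# distribution for the subject parameters, different sampler geometry.
set.seed(42)
pop_mean <- -2.0     # illustrative population mean of log q
pop_sd   <- 0.3      # illustrative population SD
n_subj   <- 10

# Centered: subject parameters drawn directly around the population mean
centered <- rnorm(n_subj, pop_mean, pop_sd)

# Non-centered: standard-normal offsets z, then shift-and-scale.
# In Stan, declaring z as the parameter and building subj_log_q in
# transformed parameters removes the prior coupling between pop_sd and
# the sampled coordinates -- which is what defuses the funnel.
z            <- rnorm(n_subj, 0, 1)
non_centered <- pop_mean + pop_sd * z
```

When pop_sd is small, the centered posterior concentrates the subject parameters in a narrow funnel neck; the non-centered offsets stay on a unit scale regardless.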

15.12.1 Multilevel SBC on Contingent-Shift Data

The performance-contingent shift scenario provides a middle ground between static and drift. Category labels flip once an agent achieves a streak of correct responses, creating abrupt changes that give the model indirect leverage on \(q\) — faster post-flip re-learning is only possible with higher process noise. Unlike drift, the label jumps violate the filter’s Gaussian-drift assumption, making this a deliberately misspecified but informative scenario.
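The flip rule can be sketched as a few lines of bookkeeping. The full scenario lives in simulate_prototype_scenario earlier in the chapter; this standalone fragment is a simplified stand-in, and the streak length of 5 is an illustrative assumption, not the chapter's actual criterion:

```r
# Sketch of a performance-contingent flip rule: the label mapping flips
# once the agent produces `streak_needed` consecutive correct responses.
# Simplified, standalone illustration -- not the chapter's implementation.
contingent_flip <- function(correct, streak_needed = 5) {
  flipped <- logical(length(correct))   # TRUE where labels are flipped
  streak  <- 0
  state   <- FALSE                      # current label mapping
  for (t in seq_along(correct)) {
    flipped[t] <- state
    streak <- if (correct[t]) streak + 1 else 0
    if (streak >= streak_needed) {      # criterion reached: flip and reset
      state  <- !state
      streak <- 0
    }
  }
  flipped
}

# A run of 6 correct trials triggers one flip after trial 5:
contingent_flip(rep(TRUE, 6))
# → FALSE FALSE FALSE FALSE FALSE TRUE
```

The abruptness of these flips is what violates the filter's Gaussian-drift assumption while still rewarding higher process noise after each reversal.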

Code
sbc_ml_shift_filepath <- here("simdata", "ch13_sbc_proto_ml_shift_results.rds")

if (regenerate_simulations || !file.exists(sbc_ml_shift_filepath)) {
  n_sbc_iterations_ml_shift <- 200
  n_subjects_sbc_shift      <- 10

  proto_ml_shift_generator <- SBC_generator_function(
    function() {
      pop_log_r_mean <- rnorm(1, LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD)
      pop_log_r_sd   <- rexp(1, 2.0)
      pop_log_q_mean <- rnorm(1, LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD)
      pop_log_q_sd   <- rexp(1, 2.0)

      z_log_r    <- rnorm(n_subjects_sbc_shift, 0, 1)
      subj_log_r <- pop_log_r_mean + pop_log_r_sd * z_log_r
      subj_r     <- exp(subj_log_r)

      z_log_q    <- rnorm(n_subjects_sbc_shift, 0, 1)
      subj_log_q <- pop_log_q_mean + pop_log_q_sd * z_log_q
      subj_q     <- exp(subj_log_q)

      all_dat <- map_dfr(seq_len(n_subjects_sbc_shift), function(j) {
        sched <- make_subject_schedule(stimulus_info, n_blocks,
                                       seed = sample.int(1e8, 1))
        simulate_prototype_scenario(
          r_value  = subj_r[j],
          q_value  = subj_q[j],
          schedule = sched,
          scenario = "contingent_shift"
        ) |> mutate(agent_id = j)
      })

      sorted <- all_dat |>
        mutate(subj_id_stan = as.integer(factor(agent_id))) |>
        arrange(subj_id_stan, trial_within_subject)

      bounds <- sorted |>
        mutate(row_idx = row_number()) |>
        group_by(subj_id_stan) |>
        summarise(subj_start = min(row_idx), subj_end = max(row_idx),
                  .groups = "drop") |>
        arrange(subj_id_stan)

      list(
        variables = list(
          pop_log_r_mean    = pop_log_r_mean,
          pop_log_r_sd      = pop_log_r_sd,
          `subj_log_r[1]`   = subj_log_r[1],
          pop_log_q_mean    = pop_log_q_mean,
          pop_log_q_sd      = pop_log_q_sd,
          `subj_log_q[1]`   = subj_log_q[1]
        ),
        generated = list(
          N_total                   = nrow(sorted),
          N_subjects                = n_subjects_sbc_shift,
          N_features                = 2L,
          subj_start                = bounds$subj_start,
          subj_end                  = bounds$subj_end,
          y                         = as.integer(sorted$sim_response),
          cat_one                   = as.integer(sorted$observed_feedback),
          obs                       = as.matrix(sorted[, c("height", "position")]),
          initial_mu_cat0           = c(2.5, 2.5),
          initial_mu_cat1           = c(2.5, 2.5),
          initial_sigma_diag        = 10.0,
          pop_log_r_mean_prior_mean = LOG_R_PRIOR_MEAN,
          pop_log_r_mean_prior_sd   = LOG_R_PRIOR_SD,
          pop_log_r_sd_prior_rate   = 2.0,
          pop_log_q_mean_prior_mean = LOG_Q_PRIOR_MEAN,
          pop_log_q_mean_prior_sd   = LOG_Q_PRIOR_SD,
          pop_log_q_sd_prior_rate   = 2.0
        )
      )
    }
  )

  proto_ml_shift_backend <- SBC_backend_cmdstan_sample(
    mod_prototype_ml,
    iter_warmup   = 1000,
    iter_sampling = 1000,
    chains        = 2,
    adapt_delta   = 0.95,
    refresh       = 0
  )

  proto_ml_shift_datasets <- generate_datasets(
    proto_ml_shift_generator,
    n_sims = n_sbc_iterations_ml_shift
  )

  sbc_results_ml_shift <- compute_SBC(
    proto_ml_shift_datasets,
    proto_ml_shift_backend,
    cache_mode     = "results",
    cache_location = here("simdata", "ch13_sbc_proto_ml_shift_cache"),
    keep_fits      = FALSE
  )
  saveRDS(sbc_results_ml_shift, sbc_ml_shift_filepath)
  cat("Multilevel prototype SBC (contingent shift) results computed and saved.\n")
} else {
  sbc_results_ml_shift <- readRDS(sbc_ml_shift_filepath)
  cat("Loaded existing multilevel prototype SBC (contingent shift) results.\n")
}
Loaded existing multilevel prototype SBC (contingent shift) results.
Code
plot_ecdf_diff(sbc_results_ml_shift) +
  ggtitle("Multilevel SBC: contingent-shift scenario")

Code
plot_rank_hist(sbc_results_ml_shift) +
  ggtitle("Rank histograms — multilevel contingent shift")

Code
# ── Population recovery — log scale ──────────────────────────────────────────
recovery_df_pop_shift <- sbc_results_ml_shift$stats |>
  dplyr::filter(variable %in% c("pop_log_r_mean", "pop_log_r_sd",
                                  "pop_log_q_mean", "pop_log_q_sd"))

ggplot(recovery_df_pop_shift, aes(x = simulated_value, y = mean)) +
  geom_point(alpha = 0.5, color = "#009E73") +
  geom_abline(intercept = 0, slope = 1, color = "red", linetype = "dashed") +
  facet_wrap(~variable, scales = "free") +
  labs(
    title    = "Population Parameter Recovery — Contingent-Shift Multilevel SBC",
    subtitle = "Posterior mean vs. true generated value",
    x = "True Value", y = "Estimated Value (Posterior Mean)"
  ) +
  theme_bw()

Code
# ── Population recovery — natural scale ──────────────────────────────────────
recovery_df_pop_shift_nat <- recovery_df_pop_shift |>
  dplyr::filter(variable %in% c("pop_log_r_mean", "pop_log_q_mean")) |>
  dplyr::mutate(
    sim_nat  = exp(simulated_value),
    mean_nat = exp(mean),
    label    = dplyr::case_when(
      variable == "pop_log_r_mean" ~ "pop r median (obs. noise)",
      variable == "pop_log_q_mean" ~ "pop q median (process noise)"
    )
  )
ggplot(recovery_df_pop_shift_nat, aes(x = sim_nat, y = mean_nat)) +
  geom_point(alpha = 0.5, color = "#009E73") +
  geom_abline(intercept = 0, slope = 1, color = "red", linetype = "dashed") +
  facet_wrap(~label, scales = "free", ncol = 2) +
  labs(
    title    = "Population Recovery — Contingent-Shift Multilevel SBC (natural scale)",
    subtitle = "Posterior median vs. true generated value",
    x = "True value", y = "Estimated value (posterior median)"
  ) +
  theme_bw()

Code
# ── Individual recovery — log scale ──────────────────────────────────────────
recovery_df_indiv_shift <- sbc_results_ml_shift$stats |>
  dplyr::filter(variable %in% c("subj_log_r[1]", "subj_log_q[1]"))

ggplot(recovery_df_indiv_shift, aes(x = simulated_value, y = mean)) +
  geom_point(alpha = 0.5, color = "#D55E00") +
  geom_abline(intercept = 0, slope = 1, color = "grey30", linetype = "dashed") +
  facet_wrap(~variable, scales = "free") +
  labs(
    title    = "Individual Parameter Recovery (Subject 1) — Contingent-Shift Multilevel SBC",
    subtitle = "Posterior mean vs. true generated value",
    x = "True value [Subject 1]", y = "Estimated value (Posterior Mean)"
  ) +
  theme_bw()

Code
# ── Individual recovery — natural scale ──────────────────────────────────────
recovery_df_indiv_shift_nat <- recovery_df_indiv_shift |>
  dplyr::mutate(
    sim_nat  = exp(simulated_value),
    mean_nat = exp(mean),
    label    = dplyr::case_when(
      variable == "subj_log_r[1]" ~ "subj r[1] (obs. noise)",
      variable == "subj_log_q[1]" ~ "subj q[1] (process noise)"
    )
  )
ggplot(recovery_df_indiv_shift_nat, aes(x = sim_nat, y = mean_nat)) +
  geom_point(alpha = 0.5, color = "#D55E00") +
  geom_abline(intercept = 0, slope = 1, color = "grey30", linetype = "dashed") +
  facet_wrap(~label, scales = "free", ncol = 2) +
  labs(
    title    = "Individual Recovery (Subject 1) — Contingent-Shift Multilevel SBC (natural scale)",
    subtitle = "Posterior median vs. true generated value",
    x = "True value [Subject 1]", y = "Estimated value (posterior median)"
  ) +
  theme_bw()

Code
# ── Divergence diagnostics — contingent-shift multilevel SBC ─────────────────
ml_diags_shift     <- sbc_results_ml_shift$backend_diagnostics
ml_fail_rate_shift <- mean(ml_diags_shift$n_divergent > 0) * 100

cat(sprintf(
  "Contingent-shift multilevel SBC: %.0f%% of universes produced at least one divergence.\n",
  ml_fail_rate_shift
))
Contingent-shift multilevel SBC: 4% of universes produced at least one divergence.
Code
ml_div_profile_shift <- sbc_results_ml_shift$stats |>
  dplyr::filter(variable %in% c("pop_log_r_sd", "pop_log_q_sd")) |>
  dplyr::select(sim_id, variable, simulated_value) |>
  tidyr::pivot_wider(names_from = variable, values_from = simulated_value) |>
  dplyr::left_join(
    dplyr::select(ml_diags_shift, sim_id, n_divergent),
    by = "sim_id"
  ) |>
  dplyr::mutate(had_divergence = n_divergent > 0)

ggplot(ml_div_profile_shift,
       aes(x = pop_log_r_sd, y = pop_log_q_sd, color = had_divergence)) +
  geom_point(alpha = 0.7, size = 2) +
  scale_color_manual(
    name   = "Sampler",
    values = c("FALSE" = "steelblue", "TRUE" = "darkred"),
    labels = c("FALSE" = "Clean", "TRUE" = "Divergent")
  ) +
  labs(
    title    = "Divergence Profile — Contingent-Shift Multilevel SBC",
    subtitle = "Divergences concentrate where population SDs are small (funnel geometry)",
    x        = expression("True " * sigma[log~r]),
    y        = expression("True " * sigma[log~q])
  ) +
  theme_bw()

Contingent-shift multilevel SBC. At the time of writing, the contingent-shift results had not yet been examined in detail; the code above produces the ECDF-difference curves, rank histograms, population and individual recovery plots (log and natural scale), and a divergence profile. Expected patterns, based on the single-subject SBC: pop_log_q_mean calibration should improve relative to the static scenario, since label-flip recovery dynamics give the filter indirect leverage on \(q\); pop_log_r_mean may show mild overconfidence (a negative central dip), as in the static case; and divergences should again concentrate at small population SDs. Once the simulations have been run, this paragraph should be updated with the observed patterns.

15.12.2 Multilevel SBC on Drift Data: The Assumption-Matched Case

The smooth-drift scenario is the assumption-matched case: the true generative process is Gaussian random-walk drift on category centres, which is exactly what the Kalman filter’s process-noise term models. If multilevel inference works as intended, population \(q\) parameters should be better recovered here than in the static scenario.
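The generative process here is cumulative Gaussian noise on each category centre. The chapter's make_drift_trajectory is the canonical version; a minimal standalone sketch for a two-feature space might look like this (the function name, arguments, and return shape are assumptions for illustration):

```r
# Minimal Gaussian random-walk drift for a 2-D category centre: each
# trial adds independent Normal(0, drift_sigma) noise per feature, and
# cumulative sums give the drift offsets applied to the centre.
# Standalone sketch -- the chapter's make_drift_trajectory() is canonical.
drift_trajectory <- function(n_trials, drift_sigma = 0.05, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  steps <- matrix(rnorm(n_trials * 2, 0, drift_sigma),
                  nrow = n_trials, ncol = 2)
  apply(steps, 2, cumsum)   # n_trials x 2 matrix of drift offsets
}

traj <- drift_trajectory(100, drift_sigma = 0.05, seed = 1)
dim(traj)   # 100 x 2
```

Because this process is exactly what the Kalman filter's process-noise term \(q\) models, drift data should identify \(q\) directly rather than through indirect re-learning dynamics.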

Code
sbc_ml_drift_filepath <- here("simdata", "ch13_sbc_proto_ml_drift_results.rds")

if (regenerate_simulations || !file.exists(sbc_ml_drift_filepath)) {
  n_sbc_iterations_ml_drift <- 200
  n_subjects_sbc_drift      <- 10

  proto_ml_drift_generator <- SBC_generator_function(
    function() {
      pop_log_r_mean <- rnorm(1, LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD)
      pop_log_r_sd   <- rexp(1, 2.0)
      pop_log_q_mean <- rnorm(1, LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD)
      pop_log_q_sd   <- rexp(1, 2.0)

      z_log_r    <- rnorm(n_subjects_sbc_drift, 0, 1)
      subj_log_r <- pop_log_r_mean + pop_log_r_sd * z_log_r
      subj_r     <- exp(subj_log_r)

      z_log_q    <- rnorm(n_subjects_sbc_drift, 0, 1)
      subj_log_q <- pop_log_q_mean + pop_log_q_sd * z_log_q
      subj_q     <- exp(subj_log_q)

      all_dat <- map_dfr(seq_len(n_subjects_sbc_drift), function(j) {
        sched    <- make_subject_schedule(stimulus_info, n_blocks,
                                          seed = sample.int(1e8, 1))
        drift_tr <- make_drift_trajectory(nrow(sched), drift_sigma = 0.05,
                                          seed = sample.int(1e8, 1))
        simulate_prototype_scenario(
          r_value    = subj_r[j],
          q_value    = subj_q[j],
          schedule   = sched,
          scenario   = "drift",
          drift_traj = drift_tr
        ) |> mutate(agent_id = j)
      })

      sorted <- all_dat |>
        mutate(subj_id_stan = as.integer(factor(agent_id))) |>
        arrange(subj_id_stan, trial_within_subject)

      bounds <- sorted |>
        mutate(row_idx = row_number()) |>
        group_by(subj_id_stan) |>
        summarise(subj_start = min(row_idx), subj_end = max(row_idx),
                  .groups = "drop") |>
        arrange(subj_id_stan)

      list(
        variables = list(
          pop_log_r_mean    = pop_log_r_mean,
          pop_log_r_sd      = pop_log_r_sd,
          `subj_log_r[1]`   = subj_log_r[1],
          pop_log_q_mean    = pop_log_q_mean,
          pop_log_q_sd      = pop_log_q_sd,
          `subj_log_q[1]`   = subj_log_q[1]
        ),
        generated = list(
          N_total                   = nrow(sorted),
          N_subjects                = n_subjects_sbc_drift,
          N_features                = 2L,
          subj_start                = bounds$subj_start,
          subj_end                  = bounds$subj_end,
          y                         = as.integer(sorted$sim_response),
          cat_one                   = as.integer(sorted$observed_feedback),
          obs                       = as.matrix(sorted[, c("height", "position")]),
          initial_mu_cat0           = c(2.5, 2.5),
          initial_mu_cat1           = c(2.5, 2.5),
          initial_sigma_diag        = 10.0,
          pop_log_r_mean_prior_mean = LOG_R_PRIOR_MEAN,
          pop_log_r_mean_prior_sd   = LOG_R_PRIOR_SD,
          pop_log_r_sd_prior_rate   = 2.0,
          pop_log_q_mean_prior_mean = LOG_Q_PRIOR_MEAN,
          pop_log_q_mean_prior_sd   = LOG_Q_PRIOR_SD,
          pop_log_q_sd_prior_rate   = 2.0
        )
      )
    }
  )

  proto_ml_drift_backend <- SBC_backend_cmdstan_sample(
    mod_prototype_ml,
    iter_warmup   = 1000,
    iter_sampling = 1000,
    chains        = 2,
    adapt_delta   = 0.95,
    refresh       = 0
  )

  proto_ml_drift_datasets <- generate_datasets(
    proto_ml_drift_generator,
    n_sims = n_sbc_iterations_ml_drift
  )

  sbc_results_ml_drift <- compute_SBC(
    proto_ml_drift_datasets,
    proto_ml_drift_backend,
    cache_mode     = "results",
    cache_location = here("simdata", "ch13_sbc_proto_ml_drift_cache"),
    keep_fits      = FALSE
  )
  saveRDS(sbc_results_ml_drift, sbc_ml_drift_filepath)
  cat("Multilevel prototype SBC (drift) results computed and saved.\n")
} else {
  sbc_results_ml_drift <- readRDS(sbc_ml_drift_filepath)
  cat("Loaded existing multilevel prototype SBC (drift) results.\n")
}
Loaded existing multilevel prototype SBC (drift) results.
Code
plot_ecdf_diff(sbc_results_ml_drift) +
  ggtitle("Multilevel SBC: smooth drift scenario")

Code
plot_rank_hist(sbc_results_ml_drift) +
  ggtitle("Rank histograms — multilevel drift")

Code
# ── Parameter recovery — drift multilevel SBC ────────────────────────────────
recovery_df_pop_drift <- sbc_results_ml_drift$stats |>
  dplyr::filter(variable %in% c("pop_log_r_mean", "pop_log_r_sd",
                                  "pop_log_q_mean", "pop_log_q_sd"))

p_sbc_pop_drift <- ggplot(recovery_df_pop_drift,
                           aes(x = simulated_value, y = mean)) +
  geom_point(alpha = 0.5, color = "#009E73") +
  geom_abline(intercept = 0, slope = 1, color = "red", linetype = "dashed") +
  facet_wrap(~variable, scales = "free") +
  labs(
    title    = "Population Parameter Recovery — Drift Multilevel SBC",
    subtitle = "Posterior mean vs. true generated value",
    x = "True Value", y = "Estimated Value (Posterior Mean)"
  ) +
  theme_bw()
print(p_sbc_pop_drift)

Code
# ── Population recovery — natural scale (mean parameters only) ───────────────
recovery_df_pop_drift_nat <- recovery_df_pop_drift |>
  dplyr::filter(variable %in% c("pop_log_r_mean", "pop_log_q_mean")) |>
  dplyr::mutate(
    sim_nat  = exp(simulated_value),
    mean_nat = exp(mean),
    label    = dplyr::case_when(
      variable == "pop_log_r_mean" ~ "pop r median (obs. noise)",
      variable == "pop_log_q_mean" ~ "pop q median (process noise)"
    )
  )
ggplot(recovery_df_pop_drift_nat, aes(x = sim_nat, y = mean_nat)) +
  geom_point(alpha = 0.5, color = "#009E73") +
  geom_abline(intercept = 0, slope = 1, color = "red", linetype = "dashed") +
  facet_wrap(~label, scales = "free", ncol = 2) +
  labs(
    title    = "Population Recovery — Drift Multilevel SBC (natural scale)",
    subtitle = "Posterior median vs. true generated value",
    x = "True value", y = "Estimated value (posterior median)"
  ) +
  theme_bw()

Code
recovery_df_indiv_drift <- sbc_results_ml_drift$stats |>
  dplyr::filter(variable %in% c("subj_log_r[1]", "subj_log_q[1]"))

p_sbc_indiv_drift <- ggplot(recovery_df_indiv_drift,
                             aes(x = simulated_value, y = mean)) +
  geom_point(alpha = 0.5, color = "#D55E00") +
  geom_abline(intercept = 0, slope = 1,
              color = "grey30", linetype = "dashed") +
  facet_wrap(~variable, scales = "free") +
  labs(
    title    = "Individual Parameter Recovery (Subject 1) — Drift Multilevel SBC",
    subtitle = "Posterior mean vs. true generated value",
    x = "True value [Subject 1]",
    y = "Estimated value (Posterior Mean)"
  ) +
  theme_bw()
print(p_sbc_indiv_drift)

Code
# ── Individual recovery — natural scale ──────────────────────────────────────
recovery_df_indiv_drift_nat <- recovery_df_indiv_drift |>
  dplyr::mutate(
    sim_nat  = exp(simulated_value),
    mean_nat = exp(mean),
    label    = dplyr::case_when(
      variable == "subj_log_r[1]" ~ "subj r[1] (obs. noise)",
      variable == "subj_log_q[1]" ~ "subj q[1] (process noise)"
    )
  )
ggplot(recovery_df_indiv_drift_nat, aes(x = sim_nat, y = mean_nat)) +
  geom_point(alpha = 0.5, color = "#D55E00") +
  geom_abline(intercept = 0, slope = 1, color = "grey30", linetype = "dashed") +
  facet_wrap(~label, scales = "free", ncol = 2) +
  labs(
    title    = "Individual Recovery (Subject 1) — Drift Multilevel SBC (natural scale)",
    subtitle = "Posterior median vs. true generated value",
    x = "True value [Subject 1]", y = "Estimated value (posterior median)"
  ) +
  theme_bw()

Code
# ── Divergence diagnostics — drift multilevel SBC ────────────────────────────
ml_diags_drift     <- sbc_results_ml_drift$backend_diagnostics
ml_fail_rate_drift <- mean(ml_diags_drift$n_divergent > 0) * 100

cat(sprintf(
  "Drift multilevel SBC: %.0f%% of universes produced at least one divergence.\n",
  ml_fail_rate_drift
))
Drift multilevel SBC: 2% of universes produced at least one divergence.
Code
ml_div_profile_drift <- sbc_results_ml_drift$stats |>
  dplyr::filter(variable %in% c("pop_log_r_sd", "pop_log_q_sd")) |>
  dplyr::select(sim_id, variable, simulated_value) |>
  tidyr::pivot_wider(names_from = variable, values_from = simulated_value) |>
  dplyr::left_join(
    dplyr::select(ml_diags_drift, sim_id, n_divergent),
    by = "sim_id"
  ) |>
  dplyr::mutate(had_divergence = n_divergent > 0)

ggplot(ml_div_profile_drift,
       aes(x = pop_log_r_sd, y = pop_log_q_sd, color = had_divergence)) +
  geom_point(alpha = 0.7, size = 2) +
  scale_color_manual(
    name   = "Sampler",
    values = c("FALSE" = "steelblue", "TRUE" = "darkred"),
    labels = c("FALSE" = "Clean", "TRUE" = "Divergent")
  ) +
  labs(
    title    = "Divergence Profile — Drift Multilevel SBC",
    subtitle = "Divergences concentrate where population SDs are small (funnel geometry)",
    x        = expression("True " * sigma[log~r]),
    y        = expression("True " * sigma[log~q])
  ) +
  theme_bw()

Drift multilevel SBC. The drift scenario reveals an unexpected asymmetry: while pop_log_q_mean and pop_log_q_sd pass calibration (curves within the simultaneous bands, uniform rank histograms), pop_log_r_mean and subj_log_r[1] clearly fail. Both ECDF curves rise steeply and exit the simultaneous band from the top, reaching a maximum ECDF difference of approximately +0.30 for pop_log_r_mean. The corresponding rank histograms confirm this: both show extreme piling of ranks near zero, meaning the true value consistently falls in the lower tail of the posterior — the posterior for \(\log r\) is systematically shifted upward relative to the truth, i.e., \(r\) is overestimated.

This failure is not present on static data and is therefore specific to the drift scenario. The mechanism is plausible: when category centres drift, the Kalman filter observes larger-than-expected distances between stimuli and the current prototype estimate. Part of this excess distance is correctly attributed to \(q\) (drift rate), but part is misattributed to \(r\) (observation noise). The model interprets stimulus variability from two sources — true sensor noise and unmodelled prototype movement — as a single inflated observation noise term, systematically overestimating \(r\) at the population level.

The population recovery scatter (log and natural scale) confirms this: the pop_log_r_mean cloud runs parallel to but above the identity line — a systematic upward shift rather than random scatter. The individual recovery for subj_log_q[1] shows a floor effect on the natural scale: individual \(q\) estimates are compressed toward the population mean when the true \(q\) is very small, an expected partial-pooling artefact.

Divergences in the drift scenario are few but spread slightly more widely in the \((\sigma_{\log r}, \sigma_{\log q})\) space than in the static case, consistent with higher posterior curvature from the drift dynamics.

The three multilevel SBC analyses together — static (r calibrated, q biased), contingent shift (pending), drift (q calibrated, r overestimated) — establish that model misspecification at the task level propagates selectively to specific parameters at the population level, and that the Kalman filter’s parameter identifiability depends jointly on the environment and the parameter in question.


15.13 Extending the Prototype Model: Selective Attention

The prototype model developed so far has two free parameters — observation noise \(r\) and process noise \(q\) — and treats all stimulus features as equally important. Every dimension contributes identically to the Mahalanobis distance between a new stimulus and the stored prototype. This is a strong assumption that is often violated in real categorization tasks: people selectively attend to features that are diagnostic of category membership and largely ignore irrelevant ones.

In the Kruschke (1993) task, for instance, height and position both vary, but one of them may be far more informative for deciding which category a stimulus belongs to. A model without attention has no way to capture this selectivity. The GCM handles it through explicit attention weights \(w_k\) that stretch or compress feature dimensions before computing distance. Here we extend the prototype model in the same spirit.

15.13.1 The Idea: Weighting the Difference Between Stimulus and Prototype

What attention needs to do is make some feature dimensions count more than others when measuring distance. Looking at the distance formula from §“Categorization Decision”, the quantity that drives everything is the per-feature difference \((x_k - \mu_{Ck})\). Attention should scale that difference: a high-attention feature should contribute more to distance, a low-attention feature less. This is exactly what the GCM does — its distance formula weights each \(|x_{ik} - x_{jk}|\) by \(w_k\) before summing.

Define attention weights \(\mathbf{w} = (w_1, w_2, \ldots, w_K)\) with \(w_k \geq 0\) and \(\sum_k w_k = 1\), and let \(W = \text{diag}(\mathbf{w})\) be the diagonal matrix with those weights on the diagonal. The attention-weighted distance is:

\[d_C = (\vec{x} - \vec{\mu}_C)^T W^T (\Sigma_C + R)^{-1} W (\vec{x} - \vec{\mu}_C)\]

which expands for two features to:

\[d_C = \frac{w_1^2\,(x_\text{height} - \mu_{C,\text{height}})^2}{\sigma_{C,\text{height}}^2 + r} + \frac{w_2^2\,(x_\text{position} - \mu_{C,\text{position}})^2}{\sigma_{C,\text{position}}^2 + r}\]

Each feature’s squared difference is now also multiplied by \(w_k^2\). A feature with \(w_k\) close to 1 contributes almost its full squared difference to the distance; a feature with \(w_k\) close to 0 contributes almost nothing, regardless of how far apart stimulus and prototype are on that dimension.
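
The weighted distance above can be checked with a minimal numeric sketch (the helper `attn_distance` and all values here are illustrative, not part of the chapter's codebase):

```r
# Two-feature attention-weighted distance: w^2 scales each squared feature
# difference (numerator); sigma_diag + r is the total uncertainty on that
# dimension (denominator).
attn_distance <- function(x, mu, sigma_diag, r, w) {
  sum(w^2 * (x - mu)^2 / (sigma_diag + r))
}

x  <- c(3.0, 1.0)   # stimulus: close to the prototype in height, far in position
mu <- c(2.5, 2.5)   # prototype mean
s  <- c(0.4, 0.4)   # prototype variance per dimension
r  <- 2.0           # observation noise

# Attending to height (where the stimulus matches the prototype) yields a
# small distance; attending to position (where it differs) a much larger one.
attn_distance(x, mu, s, r, w = c(0.9, 0.1))   # small
attn_distance(x, mu, s, r, w = c(0.1, 0.9))   # large
```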

Implementation. In practice, multiplying by \(W\) is equivalent to rescaling both the stimulus and the prototype mean by \(\mathbf{w}\) before subtraction, since \(w_k x_k - w_k \mu_{Ck} = w_k(x_k - \mu_{Ck})\). The code therefore rescales the raw observation matrix by \(\mathbf{w}\) before passing it to the Kalman filter, and rescales the initial means by the same weights. The filter itself is structurally unchanged — it simply receives already-weighted inputs.
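
A quick check of the identity \(w_k x_k - w_k \mu_{Ck} = w_k(x_k - \mu_{Ck})\), using the same `sweep()` idiom as the wrapper (values illustrative):

```r
obs <- matrix(c(3.0, 1.0,
                2.0, 4.0), nrow = 2, byrow = TRUE)   # trials x features
w   <- c(0.9, 0.1)                                   # attention weights
mu  <- c(2.5, 2.5)                                   # prototype mean

obs_scaled <- sweep(obs, 2, w, "*")   # rescale each feature column by w

# Subtracting the rescaled prototype from a rescaled observation equals
# weighting the raw stimulus-prototype difference:
all.equal(obs_scaled[1, ] - w * mu, w * (obs[1, ] - mu))
```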

15.13.2 Why Not Rescale Through the Noise Matrix Instead?

An alternative approach (call it Option 1) would be to keep the raw features and instead inflate or deflate the noise matrix \(R\) by attention: \(R = \text{diag}(r / w_1, \ldots, r / w_K)\). This is similar in spirit but creates a parameter redundancy for model fitting: multiplying all \(w_k\) by a constant \(c\) and multiplying \(r\) by the same \(c\) leaves every \(r / w_k\) — and hence \(R\) — unchanged. The likelihood depends only on the ratios \(r / w_k\), so the individual values of \(r\) and the \(w_k\) are not separately identifiable: the posterior degenerates into a ridge, SBC fails, and recovery is poor.

Rescaling the feature space (call it Option 2) avoids this because \(\mathbf{w}\) and \(r\) play genuinely distinct roles:

  • \(\mathbf{w}\) controls the shape of the decision boundary — which features matter relative to each other.
  • \(r\) controls the overall sharpness of similarity — how steeply similarity falls off with distance, across all features simultaneously.

These are not interchangeable when \(\mathbf{w}\) is constrained to the simplex (\(\sum_k w_k = 1\)), because any rescaling of \(r\) changes the overall sharpness without changing the relative weighting. The posterior geometry is well-conditioned and SBC works cleanly.
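
The redundancy in the noise-rescaling alternative can be made concrete with a toy distance computation (the helper `d_option1` and the values are illustrative, not chapter code):

```r
# Noise-rescaling alternative: attention enters the denominator as r / w_k.
d_option1 <- function(x, mu, sigma_diag, r, w) {
  sum((x - mu)^2 / (sigma_diag + r / w))
}

x  <- c(3.0, 1.0); mu <- c(2.5, 2.5); s <- c(0.4, 0.4)

# Scaling w and r by the same constant leaves every r / w_k -- and therefore
# the distance and the likelihood -- unchanged: a ridge in parameter space.
d1 <- d_option1(x, mu, s, r = 2.0,     w =     c(0.6, 0.4))
d2 <- d_option1(x, mu, s, r = 2.0 * 5, w = 5 * c(0.6, 0.4))
all.equal(d1, d2)
```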

15.13.3 Prior on Attention Weights: The Dirichlet Distribution

The natural prior for a probability vector on a simplex is the Dirichlet distribution:

\[\mathbf{w} \sim \text{Dirichlet}(\boldsymbol{\alpha})\]

With a symmetric concentration parameter \(\boldsymbol{\alpha} = (1, 1, \ldots, 1)\) this is the uniform distribution over the simplex — every allocation of attention is equally probable a priori. With \(\alpha_k > 1\) the prior pushes toward equal weights (more concentrated near the centroid); with \(\alpha_k < 1\) it pushes toward sparse solutions where attention is concentrated on one feature.

For two features, \(\text{Dirichlet}(1, 1)\) is equivalent to \(w_1 \sim \text{Uniform}(0, 1)\) with \(w_2 = 1 - w_1\). It is the minimally informative choice: it says nothing about which feature will be more diagnostic, only that attention must sum to one.
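
The effect of the concentration parameter can be explored by sampling from the Dirichlet via normalised Gamma draws (a standard construction; `rdirichlet_one` is an illustrative helper, so no extra packages are needed):

```r
# If g_k ~ Gamma(alpha_k, 1), then g / sum(g) ~ Dirichlet(alpha).
rdirichlet_one <- function(alpha) {
  g <- rgamma(length(alpha), shape = alpha, rate = 1)
  g / sum(g)
}

set.seed(1)
w_flat   <- replicate(2000, rdirichlet_one(c(1,   1  ))[1])  # uniform on (0, 1)
w_equal  <- replicate(2000, rdirichlet_one(c(10,  10 ))[1])  # pulled toward 0.5
w_sparse <- replicate(2000, rdirichlet_one(c(0.2, 0.2))[1])  # pushed toward 0 / 1

# Spread of w1 under each prior: sparse > flat > equal-attention
c(sd(w_sparse), sd(w_flat), sd(w_equal))
```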

15.13.4 R Implementation

The implementation is a thin wrapper around the existing prototype_kalman function: rescale the observations and the initial means before passing them in. The Kalman filter itself is unchanged.

Code
# Prototype model with selective attention via feature-space rescaling.
#
# The core idea: multiply each feature dimension k by attention weight w_k
# before any computation. This stretches dimensions the learner attends to
# and compresses dimensions they ignore. The Kalman filter then operates
# entirely in the rescaled space — both the prototype update and the
# categorization decision use rescaled features.
#
# w_values  : attention weights (length = nfeatures, must sum to 1)
# r_value   : observation noise variance (scalar > 0)
# q_value   : process noise / prototype drift rate (scalar >= 0)
# obs       : matrix of raw observations (trials x features)
# cat_one   : vector of true category labels (0 or 1) for feedback

prototype_kalman_attention <- function(w_values,
                                       r_value,
                                       q_value,
                                       obs,
                                       cat_one,
                                       initial_mu         = NULL,
                                       initial_sigma_diag = 10.0,
                                       quiet              = TRUE) {

  n_features <- ncol(obs)

  # Validate attention weights: they must be non-negative and sum to 1.
  # Small numerical deviations are tolerated (e.g., from optimisation).
  stopifnot(length(w_values) == n_features)
  stopifnot(all(w_values >= 0))
  w_values <- w_values / sum(w_values)   # renormalise defensively

  # Rescale all raw observations by attention weights.
  # After this point, the model never sees the original feature values.
  # This is the single change that adds attention to the model.
  obs_scaled <- sweep(obs, 2, w_values, "*")

  # Rescale the initial prototype means to match the rescaled space.
  # If we initialise at 2.5 in the original space, a feature with w_k = 0.9
  # starts at 2.25 in the rescaled space, which is consistent.
  if (is.null(initial_mu)) {
    mu0_raw <- rep(2.5, n_features)
    mu1_raw <- rep(2.5, n_features)
  } else {
    mu0_raw <- initial_mu[[1]]
    mu1_raw <- initial_mu[[2]]
  }
  mu0_scaled <- w_values * mu0_raw
  mu1_scaled <- w_values * mu1_raw

  # Everything from here on is identical to prototype_kalman().
  # The Kalman filter does not "know" that the space has been rescaled —
  # it simply receives rescaled inputs and returns rescaled prototype means.
  prototype_kalman(
    r_value            = r_value,
    q_value            = q_value,
    obs                = obs_scaled,
    cat_one            = cat_one,
    initial_mu         = list(mu0_scaled, mu1_scaled),
    initial_sigma_diag = initial_sigma_diag,
    quiet              = quiet
  )
}

# Wrapper: generates per-agent schedule then calls prototype_kalman_attention
simulate_prototype_attention_agent <- function(agent_id, r_value, q_value,
                                               w_values,
                                               stimulus_info, n_blocks,
                                               subject_seed) {
  schedule <- make_subject_schedule(stimulus_info, n_blocks, seed = subject_seed)
  obs      <- as.matrix(schedule[, c("height", "position")])
  cat_one  <- schedule$category_feedback

  result <- prototype_kalman_attention(
    w_values           = w_values,
    r_value            = r_value,
    q_value            = q_value,
    obs                = obs,
    cat_one            = cat_one,
    initial_mu         = list(rep(2.5, 2), rep(2.5, 2)),
    initial_sigma_diag = 10.0,
    quiet              = TRUE
  )

  schedule |>
    mutate(
      agent_id       = agent_id,
      r_value_true   = r_value,
      q_value_true   = q_value,
      w1_true        = w_values[1],
      w2_true        = w_values[2],
      prob_cat1      = result$prob_cat1,
      sim_response   = result$sim_response,
      correct        = as.integer(category_feedback == sim_response)
    ) |>
    group_by(agent_id) |>
    mutate(performance = cumsum(correct) / row_number()) |>
    ungroup()
}

15.13.5 What Does Attention Actually Do?

Before fitting, it is worth building intuition about what different attention allocations produce. We simulate three attention profiles (five agents each, all with the same \(r\) and \(q\)) that differ in how the two features are weighted:

  • Balanced (\(w_1 = w_2 = 0.5\)): Both features contribute equally. This is the same as the original model with half the observation scale.
  • Height-focused (\(w_1 = 0.9,\ w_2 = 0.1\)): The model is nearly ten times more sensitive to height differences than to position differences.
  • Position-focused (\(w_1 = 0.1,\ w_2 = 0.9\)): The reverse.

Code
attn_param_df <- tibble(
  agent_id     = 1:5,
  r_value      = 2.0,
  q_value      = 0.1,
  subject_seed = 1:5
) |>
  cross_join(
    tibble(
      condition = c("Balanced", "Height-focused", "Position-focused"),
      w1        = c(0.5, 0.9, 0.1),
      w2        = c(0.5, 0.1, 0.9)
    )
  )

attn_sim_file <- here("simdata", "ch13_attention_simulated_responses.csv")

if (regenerate_simulations || !file.exists(attn_sim_file)) {
  cat("Simulating attention agents...\n")
  attention_responses <- future_pmap_dfr(
    list(
      agent_id     = attn_param_df$agent_id,
      r_value      = attn_param_df$r_value,
      q_value      = attn_param_df$q_value,
      w_values     = map2(attn_param_df$w1, attn_param_df$w2, c),
      subject_seed = attn_param_df$subject_seed
    ),
    simulate_prototype_attention_agent,
    stimulus_info = stimulus_info,
    n_blocks      = n_blocks,
    .options = furrr_options(seed = TRUE)
  ) |>
    left_join(
      attn_param_df |> dplyr::select(agent_id, condition, w1, w2),
      by = c("agent_id", "w1_true" = "w1", "w2_true" = "w2")
    )
  write_csv(attention_responses, attn_sim_file)
} else {
  attention_responses <- read_csv(attn_sim_file, show_col_types = FALSE)
}

# Summarise accuracy by block and attention condition
attention_responses |>
  group_by(condition, block) |>
  summarise(mean_correct = mean(correct), .groups = "drop") |>
  ggplot(aes(x = block, y = mean_correct, colour = condition)) +
  geom_line(linewidth = 0.8) +
  geom_point(size = 1.5) +
  scale_y_continuous(limits = c(0.4, 1.0), labels = scales::percent) +
  labs(
    title    = "Effect of Attention Allocation on Learning Curves",
    subtitle = "r = 2.0, q = 0.1; five agents per condition (seeds 1-5), averaged",
    x        = "Block", y = "Proportion Correct", colour = "Attention"
  )

The height-focused and position-focused agents will diverge in accuracy because the Kruschke task is not symmetrically diagnostic in the two dimensions. The agent that attends to the more diagnostic feature should learn faster and reach higher asymptotic accuracy.

15.13.6 Distance, Similarity, and the Role of Each Parameter

It is worth being precise about where \(\mathbf{w}\) and \(r\) each enter the calculation, because they are sometimes conflated.

Distance. With \(W = \text{diag}(\mathbf{w})\), the attention-weighted Mahalanobis distance is:

\[d_C = (\vec{x} - \vec{\mu}_C)^T W^T (\Sigma_C + R)^{-1} W (\vec{x} - \vec{\mu}_C)\]

Expanded for two features:

\[d_C = \frac{w_1^2\,(x_\text{height} - \mu_{C,\text{height}})^2}{\sigma_{C,\text{height}}^2 + r} + \frac{w_2^2\,(x_\text{position} - \mu_{C,\text{position}})^2}{\sigma_{C,\text{position}}^2 + r}\]

\(\mathbf{w}\) lives entirely inside the numerators — it scales each feature’s contribution to the distance. \(r\) lives in the denominators — it scales total uncertainty on each dimension.

Similarity. Distance is converted to similarity by:

\[\eta_C = \exp\!\left(-\tfrac{1}{2}\, d_C\right)\]

\(\mathbf{w}\) does not appear here. Once \(d_C\) is computed, the conversion to similarity is fixed. Attention only influences similarity indirectly, by changing the distance that feeds into this step.

The role of \(r\). \(r\) enters through \((\Sigma_C + R)^{-1}\) in the denominator of each term. Large \(r\) makes every denominator large, shrinks \(d_C\), and produces a flat bell curve — the model generalises broadly even to distant stimuli. Small \(r\) sharpens the curve. This parallels the role of the GCM’s sensitivity parameter \(c\); the difference is that \(c\) multiplies the distance from outside (\(\exp(-c \cdot d)\)), while \(r\) divides each term from inside the distance.
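
A one-line sketch makes the sharpening role of \(r\) concrete (single feature, illustrative values; `sim_curve` is a hypothetical helper):

```r
# Similarity exp(-d/2): larger r inflates the denominator, shrinks the
# distance, and keeps similarity high even far from the prototype.
sim_curve <- function(delta, r, sigma_c = 0.5) {
  exp(-0.5 * delta^2 / (sigma_c + r))
}

sim_curve(delta = 2, r = 0.5)   # sharp tuning: low similarity 2 units away
sim_curve(delta = 2, r = 5.0)   # broad tuning: similarity stays high
```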

Why \(\mathbf{w}\) and \(r\) are not redundant. Multiplying all \(w_k\) by a constant \(\lambda\) multiplies every numerator term by \(\lambda^2\), which is equivalent to multiplying \(d_C\) by \(\lambda^2\). Multiplying \(r\) by \(\lambda^2\) scales the \(r\) term in every denominator by \(\lambda^2\), which divides \(d_C\) by approximately \(\lambda^2\) once the prototype variances \(\sigma_C^2\) have shrunk well below \(r\). So if \(\mathbf{w}\) were unconstrained, any rescaling of \(\mathbf{w}\) could be cancelled by a matching rescaling of \(r\) — the likelihood would be essentially unchanged and the posterior would degenerate into a ridge. The simplex constraint \(\sum_k w_k = 1\) fixes the total scale of \(\mathbf{w}\), leaving only its shape (relative feature weighting) free — and that cannot be absorbed into the scalar \(r\). The two parameters then have genuinely distinct roles: \(\mathbf{w}\) determines which features contribute to distance; \(r\) determines how steeply similarity falls off with distance.
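
This cancellation, and the way the simplex constraint blocks it, can be verified numerically; the cancellation is exact when \(\sigma_C^2\) is negligible relative to \(r\) (helper `d_option2` and values are illustrative):

```r
# Option 2 distance: w^2 in the numerators, r in the denominators.
d_option2 <- function(x, mu, sigma_diag, r, w) {
  sum(w^2 * (x - mu)^2 / (sigma_diag + r))
}
x <- c(3.0, 1.0); mu <- c(2.5, 2.5)

# With sigma_C = 0, scaling w by lambda and r by lambda^2 cancels exactly,
# so an unconstrained w would put the posterior on a ridge:
d_a <- d_option2(x, mu, sigma_diag = 0, r = 2.0,     w =     c(0.9, 0.1))
d_b <- d_option2(x, mu, sigma_diag = 0, r = 2.0 * 4, w = 2 * c(0.9, 0.1))
isTRUE(all.equal(d_a, d_b))

# But 2 * w no longer sums to 1 -- the simplex constraint rules the rescaled
# pair out, leaving only the shape of w free:
sum(2 * c(0.9, 0.1))
```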

15.13.7 The Attention Prototype Model in Stan

Adding attention to the Stan model requires three changes relative to the two-parameter version:

  1. Declare attention weights as a simplex parameter.
  2. Add a Dirichlet prior in the model block.
  3. Rescale the stimulus vector x by w inside the Kalman loop.

The initial means must also be rescaled to the attention-weighted space, exactly as in the R implementation. Everything else — the Kalman prediction step, the Mahalanobis decision, the Kalman update — is structurally unchanged.

Code
prototype_attention_stan <- "
// Prototype Model with Selective Attention — Feature-Space Rescaling
//
// Parameters:
//   log_r  : log observation noise (same role as in the base model)
//   log_q  : log process noise     (same role as in the base model)
//   w      : attention simplex (length = nfeatures, sums to 1)
//
// Attention mechanism:
//   Before any computation on trial i, we replace the raw stimulus x[i]
//   with its attention-rescaled version  x_tilde = w .* x[i]  (element-wise).
//   Initial prototype means are also rescaled: mu_init_tilde = w .* mu_init.
//   The Kalman filter then runs entirely in this rescaled space.
//   w and r are NOT redundant: w controls the *shape* of the decision
//   boundary (which features matter); r controls overall decisional
//   sharpness. Redundancy would arise if we had instead written
//   R = diag(r ./ w), which Option 1 would require.

data {
  int<lower=1> ntrials;
  int<lower=1> nfeatures;
  array[ntrials] int<lower=0, upper=1> cat_one;
  array[ntrials] int<lower=0, upper=1> y;
  array[ntrials, nfeatures] real obs;

  vector[nfeatures] initial_mu_cat0;
  vector[nfeatures] initial_mu_cat1;
  real<lower=0> initial_sigma_diag;

  real prior_logr_mean;
  real<lower=0> prior_logr_sd;
  real prior_logq_mean;
  real<lower=0> prior_logq_sd;

  // Dirichlet concentration parameters for the attention prior.
  // alpha = rep_vector(1, nfeatures) gives a uniform prior over the simplex:
  // every allocation of attention is equally likely a priori.
  vector<lower=0>[nfeatures] alpha_w;
}

parameters {
  real log_r;
  real log_q;
  // Stan's simplex type enforces w_k >= 0 and sum(w) = 1 automatically.
  simplex[nfeatures] w;
}

transformed parameters {
  real<lower=0> r_value = exp(log_r);
  real<lower=0> q_value = exp(log_q);

  array[ntrials] real<lower=1e-9, upper=1-1e-9> p;

  {
    // Rescale the initial means into the attention-weighted space.
    // This keeps the initial prototype location consistent with the
    // rescaled observations that the filter will receive.
    vector[nfeatures] mu_cat0 = w .* initial_mu_cat0;
    vector[nfeatures] mu_cat1 = w .* initial_mu_cat1;

    matrix[nfeatures, nfeatures] sigma_cat0 =
      diag_matrix(rep_vector(initial_sigma_diag, nfeatures));
    matrix[nfeatures, nfeatures] sigma_cat1 =
      diag_matrix(rep_vector(initial_sigma_diag, nfeatures));

    // r and q matrices are unchanged: attention acts on the data, not
    // on the noise structure (this is the key difference from Option 1).
    matrix[nfeatures, nfeatures] r_matrix =
      diag_matrix(rep_vector(r_value, nfeatures));
    matrix[nfeatures, nfeatures] q_matrix =
      diag_matrix(rep_vector(q_value, nfeatures));
    matrix[nfeatures, nfeatures] I_mat =
      diag_matrix(rep_vector(1.0, nfeatures));

    for (i in 1:ntrials) {
      // Rescale the raw stimulus by attention weights before any computation.
      // This is the only line that differs from the base model's loop.
      vector[nfeatures] x = w .* to_vector(obs[i]);

      // ── Prediction step (unchanged from base model) ────────────────────────
      sigma_cat0 = sigma_cat0 + q_matrix;
      sigma_cat1 = sigma_cat1 + q_matrix;

      // ── Decision (unchanged from base model) ──────────────────────────────
      // x is now in the rescaled space; prototypes are also in the rescaled
      // space. The Mahalanobis distance is computed in that common space.
      matrix[nfeatures, nfeatures] cov0 = sigma_cat0 + r_matrix;
      matrix[nfeatures, nfeatures] cov1 = sigma_cat1 + r_matrix;

      real log_p0 = multi_normal_lpdf(x | mu_cat0, cov0);
      real log_p1 = multi_normal_lpdf(x | mu_cat1, cov1);
      real prob1  = exp(log_p1 - log_sum_exp(log_p0, log_p1));

      p[i] = fmax(1e-9, fmin(1 - 1e-9, prob1));

      // ── Update (unchanged from base model) ────────────────────────────────
      if (cat_one[i] == 1) {
        vector[nfeatures] innov = x - mu_cat1;
        matrix[nfeatures, nfeatures] S  = sigma_cat1 + r_matrix;
        matrix[nfeatures, nfeatures] K  = mdivide_right_spd(sigma_cat1, S);
        matrix[nfeatures, nfeatures] IK = I_mat - K;
        mu_cat1    = mu_cat1 + K * innov;
        sigma_cat1 = IK * sigma_cat1 * IK' + K * r_matrix * K';
        sigma_cat1 = 0.5 * (sigma_cat1 + sigma_cat1');
      } else {
        vector[nfeatures] innov = x - mu_cat0;
        matrix[nfeatures, nfeatures] S  = sigma_cat0 + r_matrix;
        matrix[nfeatures, nfeatures] K  = mdivide_right_spd(sigma_cat0, S);
        matrix[nfeatures, nfeatures] IK = I_mat - K;
        mu_cat0    = mu_cat0 + K * innov;
        sigma_cat0 = IK * sigma_cat0 * IK' + K * r_matrix * K';
        sigma_cat0 = 0.5 * (sigma_cat0 + sigma_cat0');
      }
    }
  }
}

model {
  // Priors on noise parameters: same as base model.
  target += normal_lpdf(log_r | prior_logr_mean, prior_logr_sd);
  target += normal_lpdf(log_q | prior_logq_mean, prior_logq_sd);

  // Prior on attention weights: uniform over the simplex.
  // Dirichlet(1, 1, ...) = no preference for any feature a priori.
  // To favour equal attention (e.g., a regularisation prior), increase alpha.
  target += dirichlet_lpdf(w | alpha_w);

  target += bernoulli_lpmf(y | p);
}

generated quantities {
  vector[ntrials] log_lik;
  real lprior;
  for (i in 1:ntrials)
    log_lik[i] = bernoulli_lpmf(y[i] | p[i]);
  lprior = normal_lpdf(log_r | prior_logr_mean, prior_logr_sd) +
           normal_lpdf(log_q | prior_logq_mean, prior_logq_sd) +
           dirichlet_lpdf(w | alpha_w);
}
"

stan_file_attn <- here("stan", "ch13_prototype_attention.stan")
write_stan_file(prototype_attention_stan, dir = here("stan"),
                basename = "ch13_prototype_attention.stan")
[1] "/Users/au209589/Dropbox/Teaching/AdvancedCognitiveModeling23_book/stan/ch13_prototype_attention.stan"
Code
mod_prototype_attention <- cmdstan_model(stan_file_attn)
cat("Attention prototype Stan model compiled successfully.\n")
Attention prototype Stan model compiled successfully.

Key aspects of the Stan implementation:

  • New parameter w: Declared as simplex[nfeatures]. Stan enforces the sum-to-one constraint internally; we never need to manually project onto the simplex.
  • alpha_w passed as data: This lets the caller tune the Dirichlet prior without recompiling. Passing rep(1, nfeatures) gives a uniform prior; passing rep(2, nfeatures) gently favours equal attention.
  • One extra line in the loop: vector[nfeatures] x = w .* to_vector(obs[i]); replaces the base model’s vector[nfeatures] x = to_vector(obs[i]);. Everything else is structurally identical.
  • Initial means rescaled: mu_cat0 = w .* initial_mu_cat0 rather than initial_mu_cat0 directly. Forgetting this would create an inconsistency: the filter would receive rescaled stimuli but compare them against unrescaled prototypes on the very first trial.

15.13.8 Parameter Recovery

We fit the attention model to a simulated height-focused agent (\(w_1 = 0.9,\ w_2 = 0.1\)) to check that Stan can recover both the noise parameters and the attention weights.

Code
# Use the height-focused agent, first seed
attn_agent <- attention_responses |>
  filter(condition == "Height-focused", agent_id == 1)

stopifnot(nrow(attn_agent) == total_trials)

attn_data <- list(
  ntrials            = nrow(attn_agent),
  nfeatures          = 2L,
  cat_one            = attn_agent$category_feedback,
  y                  = attn_agent$sim_response,
  obs                = as.matrix(attn_agent[, c("height", "position")]),
  initial_mu_cat0    = c(2.5, 2.5),
  initial_mu_cat1    = c(2.5, 2.5),
  initial_sigma_diag = 10.0,
  prior_logr_mean    = LOG_R_PRIOR_MEAN,
  prior_logr_sd      = LOG_R_PRIOR_SD,
  prior_logq_mean    = LOG_Q_PRIOR_MEAN,
  prior_logq_sd      = LOG_Q_PRIOR_SD,
  # Uniform Dirichlet: no preference for either feature a priori
  alpha_w            = c(1.0, 1.0)
)

attn_fit_file <- here("simmodels", "ch13_proto_attention_fit.rds")

if (regenerate_simulations || !file.exists(attn_fit_file)) {
  pf_attn <- mod_prototype_attention$pathfinder(
    data = attn_data, seed = 123, num_paths = 4, refresh = 0
  )
  fit_attn <- mod_prototype_attention$sample(
    data            = attn_data,
    init            = pf_attn,
    seed            = 123,
    chains          = 4,
    parallel_chains = min(4, availableCores()),
    iter_warmup     = 1000,
    iter_sampling   = 1500,
    refresh         = 300,
    adapt_delta     = 0.9
  )
  fit_attn$save_object(attn_fit_file)
  cat("Attention model fit computed and saved.\n")
} else {
  fit_attn <- readRDS(attn_fit_file)
  cat("Loaded existing attention model fit.\n")
}
Loaded existing attention model fit.
Code
fit_attn$summary(variables = c("log_r", "r_value", "log_q", "q_value", "w[1]", "w[2]"))
# A tibble: 6 × 10
  variable   mean median    sd   mad      q5    q95  rhat ess_bulk ess_tail
  <chr>     <dbl>  <dbl> <dbl> <dbl>   <dbl>  <dbl> <dbl>    <dbl>    <dbl>
1 log_r    -0.313 -0.285 0.680 0.665 -1.44    0.771  1.00    2532.    2135.
2 r_value   0.923  0.752 0.838 0.470  0.237   2.16   1.00    2532.    2135.
3 log_q    -1.67  -1.58  0.810 0.785 -3.16   -0.489  1.00    3139.    3221.
4 q_value   0.250  0.205 0.186 0.156  0.0425  0.613  1.00    3139.    3221.
5 w[1]      0.824  0.836 0.112 0.123  0.629   0.983  1.00    2462.    1633.
6 w[2]      0.176  0.164 0.112 0.123  0.0166  0.371  1.00    2462.    1633.
Code
# Overlay posterior on true values for a visual parameter recovery check.
# True values: r = 2.0, q = 0.1, w[1] = 0.9, w[2] = 0.1
true_vals <- tibble(
  parameter = c("r_value", "q_value", "w[1]", "w[2]"),
  truth     = c(2.0, 0.1, 0.9, 0.1)
)

fit_attn$draws(variables = c("r_value", "q_value", "w[1]", "w[2]"),
               format = "df") |>
  pivot_longer(cols = c("r_value", "q_value", "w[1]", "w[2]"),
               names_to = "parameter", values_to = "value") |>
  left_join(true_vals, by = "parameter") |>
  ggplot(aes(x = value)) +
  geom_histogram(bins = 40, fill = "steelblue", alpha = 0.7) +
  geom_vline(aes(xintercept = truth), colour = "red", linewidth = 0.8,
             linetype = "dashed") +
  facet_wrap(~parameter, scales = "free", ncol = 2) +
  labs(
    title    = "Attention Prototype: Parameter Recovery",
    subtitle = "Red dashed line = true generating value",
    x = "Parameter value", y = "Posterior draws"
  )

A successful recovery plot shows each posterior histogram clearly enclosing its red dashed line. The attention weights \(w[1]\) and \(w[2]\) are the most demanding parameters to recover: because they live on the simplex, \(w[2] = 1 - w[1]\) for two features, so all information about attention comes from a single degree of freedom. When one feature is strongly dominant (\(w_k\) close to 0 or 1), recovery is typically sharp because the likelihood is highly sensitive to that ratio. When attention is more balanced, the posterior widens.

What if the true agent had equal attention but was fit with the attention model? The posterior on \(\mathbf{w}\) should be broad and centred near \((0.5, 0.5)\), and \(r\) should recover the same value as the base model. This is the correct Bayesian response: the model expresses uncertainty about a parameter it cannot identify from the data, rather than forcing a spurious extreme value.

15.13.9 Simulation-Based Calibration

Single-agent parameter recovery confirms that Stan can reach the right neighbourhood, but SBC tells us whether the posteriors are correctly calibrated across the entire prior. The attention model has three free parameters (in the two-feature case): log_r, log_q, and w[1] (with w[2] = 1 − w[1]). The key question is whether the Dirichlet simplex representation creates any systematic miscalibration — for instance, boundary effects near \(w = 0\) or \(w = 1\) that could inflate or deflate posterior coverage.

The generator draws \((\log r, \log q)\) from their Normal priors and \(w_1 \sim \text{Uniform}(0, 1)\) (the marginal of \(\text{Dirichlet}(1,1)\)), simulates one agent’s choices, and feeds the data to the attention Stan model.

Code
sbc_attn_filepath <- here("simdata", "ch13_sbc_attention_results.rds")

if (regenerate_simulations || !file.exists(sbc_attn_filepath)) {
  n_sbc_iterations_attn <- 500   # use >= 1000 for publication-quality SBC

  attn_sbc_generator <- SBC_generator_function(
    function() {
      # Draw true parameters from their priors
      log_r_true <- rnorm(1, LOG_R_PRIOR_MEAN, LOG_R_PRIOR_SD)
      r_true     <- exp(log_r_true)
      log_q_true <- rnorm(1, LOG_Q_PRIOR_MEAN, LOG_Q_PRIOR_SD)
      q_true     <- exp(log_q_true)
      # Dirichlet(1,1) marginal for w[1] is Uniform(0,1)
      w1_true    <- runif(1, 0, 1)
      w_true     <- c(w1_true, 1 - w1_true)

      # Simulate one agent's choices
      sbc_sched  <- make_subject_schedule(stimulus_info, n_blocks,
                                          seed = sample.int(1e8, 1))
      sim        <- prototype_kalman_attention(
        w_values   = w_true,
        r_value    = r_true,
        q_value    = q_true,
        obs        = as.matrix(sbc_sched[, c("height", "position")]),
        cat_one    = sbc_sched$category_feedback
      )

      list(
        variables = list(log_r = log_r_true, log_q = log_q_true,
                         `w[1]` = w1_true,   `w[2]` = 1 - w1_true),
        generated = list(
          ntrials            = nrow(sbc_sched),
          nfeatures          = 2L,
          cat_one            = as.integer(sbc_sched$category_feedback),
          y                  = as.integer(sim$sim_response),
          obs                = as.matrix(sbc_sched[, c("height", "position")]),
          initial_mu_cat0    = c(2.5, 2.5),
          initial_mu_cat1    = c(2.5, 2.5),
          initial_sigma_diag = 10.0,
          prior_logr_mean    = LOG_R_PRIOR_MEAN,
          prior_logr_sd      = LOG_R_PRIOR_SD,
          prior_logq_mean    = LOG_Q_PRIOR_MEAN,
          prior_logq_sd      = LOG_Q_PRIOR_SD,
          alpha_w            = c(1.0, 1.0)
        )
      )
    }
  )

  attn_sbc_backend <- SBC_backend_cmdstan_sample(
    mod_prototype_attention,
    iter_warmup   = 1000,
    iter_sampling = 1000,
    chains        = 2,
    adapt_delta   = 0.9,
    refresh       = 0
  )

  attn_sbc_datasets <- generate_datasets(
    attn_sbc_generator,
    n_sims = n_sbc_iterations_attn
  )

  sbc_results_attn <- compute_SBC(
    attn_sbc_datasets,
    attn_sbc_backend,
    cache_mode     = "results",
    cache_location = here("simdata", "ch13_sbc_attention_cache"),
    keep_fits      = FALSE
  )
  saveRDS(sbc_results_attn, sbc_attn_filepath)
  cat("SBC attention results computed and saved.\n")
} else {
  sbc_results_attn <- readRDS(sbc_attn_filepath)
  cat("Loaded existing SBC attention results.\n")
}
Loaded existing SBC attention results.
Code
plot_ecdf_diff(sbc_results_attn)

Code
plot_rank_hist(sbc_results_attn)

The ECDF-difference plot should show all four curves — log_r, log_q, w[1], w[2] — within the simultaneous confidence bands. The attention-weight curves are the main test: a U-shaped rank histogram for w[1] would indicate the posterior is overconfident about attention allocation (placing too little probability near 0.5 when the true weight happens to fall there); an inverted-U would indicate underconfidence. If the noise parameters log_r and log_q show the same mild overestimation seen in the base model, that confirms the pattern is inherent to the paradigm rather than induced by adding attention.
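
Beyond the visual checks, a crude numeric summary is a chi-squared test of rank uniformity for a single variable. The helper below is a sketch under our own binning choice (it is not part of the SBC package); in practice you would feed it the `rank` and `max_rank` columns for `w[1]` from `sbc_results_attn$stats`. The demo uses synthetic uniform ranks:

```r
# Sketch: chi-squared uniformity test for the SBC ranks of one parameter.
# `ranks` are integers in 0..max_rank; n_bins = 20 is our own choice.
sbc_rank_chisq <- function(ranks, max_rank, n_bins = 20) {
  bins <- cut(ranks,
              breaks = seq(0, max_rank + 1, length.out = n_bins + 1),
              include.lowest = TRUE, right = FALSE)
  observed <- as.vector(table(bins))
  chisq.test(observed, p = rep(1 / n_bins, n_bins))
}

# Demo on synthetic uniform ranks (the well-calibrated case)
set.seed(42)
fake_ranks <- sample(0:999, 500, replace = TRUE)
sbc_rank_chisq(fake_ranks, max_rank = 999)$p.value  # large p => no evidence of miscalibration
```

As always with SBC, a non-significant test is weak evidence of calibration at modest simulation counts; the ECDF bands remain the primary diagnostic.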

Code
# Inspect sampler diagnostics across the SBC runs
attn_backend_diags <- sbc_results_attn$backend_diagnostics
attn_default_diags <- sbc_results_attn$default_diagnostics |>
  dplyr::select(sim_id, max_rhat, min_ess_to_rank)

attn_true_params <- sbc_results_attn$stats |>
  dplyr::select(sim_id, variable, simulated_value) |>
  tidyr::pivot_wider(names_from = variable, values_from = simulated_value)

attn_diag_df <- attn_backend_diags |>
  dplyr::left_join(attn_default_diags, by = "sim_id") |>
  dplyr::left_join(attn_true_params,   by = "sim_id") |>
  dplyr::mutate(has_issue = n_divergent > 0 | n_max_treedepth > 0 | max_rhat > 1.01)

ggplot(attn_diag_df, aes(x = log_r, y = log_q, colour = has_issue)) +
  geom_point(alpha = 0.7, size = 2) +
  scale_colour_manual(values = c("FALSE" = "#0072B2", "TRUE" = "#D55E00")) +
  labs(
    title    = "Attention Model: Diagnostic Issues in Parameter Space",
    subtitle = "Orange = divergence, max treedepth, or Rhat > 1.01",
    x = "True log(r)", y = "True log(q)"
  ) +
  theme(legend.position = "bottom")

Code
# Parameter recovery scatter across SBC iterations
attn_recovery_df <- sbc_results_attn$stats

ggplot(attn_recovery_df, aes(x = simulated_value, y = mean)) +
  geom_point(alpha = 0.4, colour = "#0072B2") +
  geom_abline(intercept = 0, slope = 1, colour = "red", linetype = "dashed") +
  facet_wrap(~variable, scales = "free") +
  labs(
    title    = "Attention Prototype: Parameter Recovery Across SBC Simulations",
    subtitle = "Posterior mean vs. true generating value; red line = identity",
    x = "True value", y = "Posterior mean"
  ) +
  theme_bw()

The recovery scatter is the clearest summary of where the model succeeds and where it struggles. log_r points should cluster tightly around the diagonal. log_q points will be more diffuse (the same weak identification seen in the base model). The w[1] scatter is the new information: extreme true values (near 0 or 1) are typically well-recovered because the likelihood strongly prefers one feature; mid-range values (near 0.5) produce wider posteriors and more scatter around the diagonal, reflecting genuine uncertainty rather than model failure.
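
The "wider posteriors near 0.5" claim can be quantified by binning the SBC stats for `w[1]` by true value and averaging the posterior SDs per bin. The helper below is a sketch assuming a data frame with `simulated_value` and `sd` columns, as in `sbc_results_attn$stats` filtered to `variable == "w[1]"`; the demo frame is synthetic, constructed purely to illustrate the computation:

```r
# Sketch: mean posterior SD of w[1] as a function of the true generating value.
# Assumes a stats-style data frame with columns `simulated_value` and `sd`.
sd_by_true_value <- function(stats_df, n_bins = 5) {
  stats_df$bin <- cut(stats_df$simulated_value,
                      breaks = seq(0, 1, length.out = n_bins + 1),
                      include.lowest = TRUE)
  aggregate(sd ~ bin, data = stats_df, FUN = mean)
}

# Synthetic demo frame (illustrative only, not real SBC output)
set.seed(7)
w_true  <- runif(200)
fake_df <- data.frame(simulated_value = w_true,
                      sd = 0.05 + 0.2 * w_true * (1 - w_true) + rnorm(200, 0, 0.005))
sd_by_true_value(fake_df)
```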

15.13.10 The Trade-Off: Interpretability vs. Identifiability

Adding attention weights increases the number of free parameters from 2 to \(2 + (K - 1)\) (the \(-1\) is because the simplex constraint removes one degree of freedom). For \(K = 2\) features this is a total of 3 parameters, a modest increase. For higher-dimensional feature spaces the parameter count grows, and attention weights may become individually unidentifiable if many features are similarly diagnostic.
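
The count is easy to tabulate for a few feature-space sizes (a trivial sketch, but it makes the growth concrete):

```r
# Free-parameter count for the attention model: 2 noise parameters
# (log_r, log_q) plus a (K - 1)-dof simplex of attention weights.
attn_param_count <- function(K) 2 + (K - 1)
sapply(c(2, 4, 8), attn_param_count)   # 3 5 9
```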

The attention prototype model is most useful in tasks where:

  1. Features vary in diagnosticity — not all dimensions carry equal information about category membership.
  2. Enough trials are available — with only 64 trials (8 stimuli × 8 blocks), the posterior on \(\mathbf{w}\) will often remain broad, especially if the true \(w\) is close to \((0.5, 0.5)\).
  3. The question is comparative — fitting both the base model and the attention model and comparing their ELPDs (via LFO-CV, since path-dependence still applies) tells you whether allowing selective attention meaningfully improves predictive accuracy.
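
Once each model's per-trial LFO-CV log predictive densities are in hand (e.g. from one `psis_lfo_kalman()` run per model), the comparison itself is a paired difference with a standard error. The helper below sketches that computation; the demo vectors are hypothetical, not fitted results:

```r
# Sketch: paired ELPD difference between two models from per-trial LFO-CV
# log predictive densities (equal-length vectors over the same trials).
elpd_compare <- function(elpd_a, elpd_b) {
  stopifnot(length(elpd_a) == length(elpd_b))
  d <- elpd_b - elpd_a
  c(elpd_diff = sum(d),                      # positive favours model b
    se_diff   = sd(d) * sqrt(length(d)))     # SE of the paired difference
}

# Hypothetical per-trial values for the base and attention models (64 trials)
set.seed(3)
elpd_base <- rnorm(64, -0.60, 0.25)
elpd_attn <- elpd_base + rnorm(64, 0.02, 0.10)
elpd_compare(elpd_base, elpd_attn)
```

A difference within roughly two standard errors of zero is the usual signal that allowing selective attention buys no reliable predictive improvement.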

15.14 Cognitive Insights and Model Comparison

The Kalman filter prototype model offers valuable insights:

  • Incremental, Adaptive Learning: It formalizes how category knowledge can be built up one example at a time, with learning automatically slowing down as confidence increases.
  • Uncertainty Representation: It explicitly models the learner’s uncertainty about category prototypes — something the basic GCM does not directly represent.
  • Memory Efficiency: It achieves learning by storing only summary statistics (mean and covariance) per category, aligning with intuitions about cognitive limits.
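
The memory contrast is easy to make concrete. A sketch of the stored-value counts (our own bookkeeping, counting a symmetric covariance by its upper triangle):

```r
# Stored quantities after n trials with K features.
# Exemplar (GCM): every stimulus. Prototype: a mean vector plus a symmetric
# covariance per category (2 categories in the Kruschke task).
gcm_storage       <- function(n, K) n * K
prototype_storage <- function(K, n_categories = 2) {
  n_categories * (K + K * (K + 1) / 2)   # mean + upper-triangle covariance
}

gcm_storage(64, 2)    # 128 values, and growing with every trial
prototype_storage(2)  # 10 values, constant in n
```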

15.14.1 Comparing to Exemplar Models (GCM)

  • The GCM stores every detail; the prototype model abstracts.
  • The GCM is sensitive to specific outliers; the prototype model is less so (outliers get averaged in).
  • The GCM requires explicit attention weights (\(w\)) as free parameters; the prototype model can be extended with the same mechanism via feature-space rescaling, as developed in §“Extending the Prototype Model: Selective Attention” above.
  • The GCM’s sensitivity parameter \(c\) controls generalization globally; the Kalman filter’s \(r\) controls how rapidly the prototype absorbs new information, providing a different kind of flexibility.
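
The claim that learning "automatically slows down as confidence increases", with \(r\) setting the absorption rate, is visible in a one-dimensional Kalman update. The sketch below is our own illustration (a scalar filter with no process noise, not the chapter's multivariate implementation): the gain shrinks as the posterior variance shrinks, and a larger \(r\) keeps it lower throughout.

```r
# Sketch: scalar Kalman filter, tracking the gain over successive trials.
kalman_gains <- function(n_trials, r, sigma2_init = 10) {
  sigma2 <- sigma2_init
  gains  <- numeric(n_trials)
  for (t in seq_len(n_trials)) {
    gains[t] <- sigma2 / (sigma2 + r)      # weight placed on the new observation
    sigma2   <- (1 - gains[t]) * sigma2    # posterior variance shrinks each trial
  }
  gains
}

round(kalman_gains(5, r = 1), 3)  # gain decreases trial by trial
round(kalman_gains(5, r = 4), 3)  # larger r => smaller gains => slower learning
```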

A note of caution about ELPD comparisons in Chapter 17: now that LFO-CV is implemented for both the GCM (Chapter 13) and the Kalman prototype (this chapter), the comparison between them is on equal predictive footing. The rule mixture model in Chapter 16 is conditionally exchangeable and uses standard PSIS-LOO, which is technically valid for that model — but the comparison across architectures is still a comparison of fundamentally different generative processes, and “ELPD is higher” should not be read as “this is the right cognitive theory.” See Chapter 17 for the deeper version of this caveat (and the ABC-based alternative comparison framework that resolves it more cleanly).

15.14.2 Limitations

  • Abstraction Loses Detail: Information about specific examples and within-category variance beyond the prototype covariance is discarded. Hybrid models attempt to combine both approaches.

  • Single Prototype: Assumes categories have a single central tendency, struggling with disjunctive categories (e.g., “mammals that fly or swim”). Extensions using multiple prototypes address this.

  • Path-dependence and LFO-CV (now discharged): The Kalman filter is path-dependent — trial \(t\)’s choice probability depends on the entire filter state from trials \(1, \ldots, t-1\) — so standard PSIS-LOO is an invalid measure of out-of-sample predictive accuracy. Earlier drafts of this material left this as an unresolved caveat. This chapter implements LFO-CV via psis_lfo_kalman(), validates it against exact one-step-ahead refits, and Chapter 16 inherits the implementation. The caveat is no longer outstanding.

  • Conflation of observation noise and decisional precision: The parameter \(r\) controls both the Kalman gain (learning speed) and the multivariate normal variance at decision time (decisional precision), alongside \(q\), which controls drift speed. These roles are only partially separable on the basis of learning trajectories; adding a separate softmax temperature \(\tau\) would further disentangle decisional precision from observational noise. We chose not to add \(\tau\) here to keep the model tractable, but the conflation is real — when fitting human data, posterior estimates of \(r\) should be interpreted as an average over learning-speed and decisional-precision contributions.

  • Trial-1 indifference: The model has no category bias and no perceptual fluency prior. Trial 1 is a coin flip by construction. For tasks with strong baseline asymmetries, this is a misspecification.

  • Process noise identifiability: The model includes process noise \(Q = q \cdot I\), but \(q\) and \(r\) are partially confounded under the static Kruschke (1993) paradigm — both affect the effective likelihood width at decision time. Separate identifiability requires non-stationary data where prototypes demonstrably shift. The SBC analysis shows the extent of this confound, and the precision analysis quantifies how many trials of category drift are needed to pin down \(q\).

  • Structural priors: The initial means, initial covariances, and diagonal observation-noise structure are fixed by hand. None of these is inferred from data. A reader who finds the model behaving oddly on a new dataset should re-examine these choices before concluding that the cognitive theory is wrong.

  • Mahalanobis distance, not city-block: The GCM allows the modeller to choose between Euclidean and city-block distance. City-block is often preferred for separable dimensions — features perceived independently, such as height and position in the Kruschke task — because it better matches how people combine information across such dimensions (Shepard 1987). The prototype model has no such choice: Mahalanobis distance is not a design decision but a mathematical consequence of the Gaussian generative model. The multivariate normal log-density at \(\vec{x}\) is, up to a normalising constant, \(-\tfrac{1}{2}\) times the squared Mahalanobis distance; substituting city-block would require replacing the Gaussian likelihood with a Laplace one and abandoning the Kalman update equations. The model therefore inherits a geometric assumption — that the feature space is Euclidean — that may not hold for perceptually separable stimuli.
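
The density-distance identity is worth seeing once in code. This base-R sketch (our own numbers; any positive-definite covariance works) confirms that the multivariate normal log-density is a constant minus half the squared Mahalanobis distance:

```r
# Sketch: the MVN log-density decomposes into a normalising constant
# plus -1/2 times the squared Mahalanobis distance.
mu    <- c(2.5, 2.5)
Sigma <- matrix(c(1.0, 0.3, 0.3, 1.0), 2, 2)
x     <- c(3.1, 1.8)

d2   <- as.numeric(mahalanobis(x, center = mu, cov = Sigma))  # squared distance
logZ <- -0.5 * (length(x) * log(2 * pi) +
                as.numeric(determinant(Sigma)$modulus))        # log normaliser
log_density <- logZ - 0.5 * d2

# Cross-check against the quadratic form computed from first principles
diff_vec <- x - mu
quad     <- drop(t(diff_vec) %*% solve(Sigma) %*% diff_vec)
stopifnot(isTRUE(all.equal(d2, quad)))
```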

15.15 Conclusion

The Kalman filter prototype model provides a dynamic, computationally plausible account of how categories might be learned by abstracting central tendencies and tracking uncertainty. It stands as a compelling alternative to exemplar-based approaches, highlighting the fundamental trade-off between detailed instance memory and abstract summary representation in categorization.

As a Stan model it has a particularly clean shape: with two parameters and a fully deterministic likelihood, Pathfinder initialization is reliable, parameter recovery is tractable, the diagnostic battery is nearly trivial to satisfy, and SBC is straightforward. That cleanliness is why this chapter is the right place in the trio to introduce LFO-CV — the algorithm is most legible when there are only two parameters to re-weight and the per-trial filter is fast. Chapter 16 inherits psis_lfo_kalman() (alongside its GCM sibling psis_lfo_gcm() from Chapter 13) so that the LFO-CV caveat that earlier drafts of this material repeated across chapters is now discharged in one place.

The next chapter introduces a third perspective — rule-based models — which makes categorization decisions through explicit hypothesis testing rather than continuous similarity to either stored examples or running averages.

Anderson, John R. 1991. “The Adaptive Nature of Human Categorization.” Psychological Review 98 (3): 409.
Dayan, Peter, Sham Kakade, and P Read Montague. 2000. “Learning and Selective Attention.” Nature Neuroscience 3 (12): 1218–23.
Estes, William Kaye. 1994. Classification and Cognition. Oxford University Press.
Gluck, Mark A, and Gordon H Bower. 1988. “From Conditioning to Category Learning: An Adaptive Network Model.” Journal of Experimental Psychology: General 117 (3): 227.
Kalman, Rudolph Emil. 1960. “A New Approach to Linear Filtering and Prediction Problems.” Journal of Basic Engineering 82 (1): 35–45.
Medin, Douglas L, and Marguerite M Schaffer. 1978. “Context Theory of Classification Learning.” Psychological Review 85 (3): 207–38.
Posner, Michael I, and Steven W Keele. 1968. “On the Genesis of Abstract Ideas.” Journal of Experimental Psychology 77 (3p1): 353.
Reed, Stephen K. 1972. “Pattern Recognition and Categorization.” Psychonomic Monograph Supplements.
Rosch, Eleanor. 1975. “Cognitive Representations of Semantic Categories.” Journal of Experimental Psychology: General 104 (3): 192.
Rosch, Eleanor H. 1973. “Natural Categories.” Cognitive Psychology 4 (3): 328–50.
Rosch, Eleanor, Carolyn B Mervis, Wayne D Gray, David M Johnson, and Panni Boyes-Braem. 1976. “Basic Objects in Natural Categories.” Cognitive Psychology 8 (3): 382–439.
Shepard, Roger N. 1987. “Toward a Universal Law of Generalization for Psychological Science.” Science 237 (4820): 1317–23.
Smith, J David, and John Paul Minda. 1998. “Prototype and Exemplar Accounts of Category Learning.” Journal of Experimental Psychology: Learning, Memory, and Cognition 24 (6): 1411.
Tenenbaum, Joshua B, Charles Kemp, Thomas L Griffiths, and Noah D Goodman. 2011. “How to Grow a Mind: Statistics, Structure, and Abstraction.” Science 331 (6022): 1279–85.