This post is from Bob
Andrew and I were talking the other day about generalizing R-hat convergence monitoring to the situation where we have multiple asynchronous threads running chains and we needed ragged input. This is because I’m coding (with Steve Bronder and Brian Ward’s help) a parallel auto-stopping version of Stan combining the step-size adaptivity of WALNUTS and the warmup of Nutpie—stay tuned (or follow it or join in and help on the WALNUTS GitHub).
The usual R-hat assumes each chain is the same length. Andrew suggested it would be good to go back to the model to think about how to generalize. I had never thought of R-hat that way, but it turns out it’s really simple, so I can take you all along for the ride.
**A Stan model for Bayesian R-hat**
Here’s the Stan model. The input is an M by N matrix of draws theta—the output includes the posterior for R and an indicator of whether it is below 1.01. The mean of the former is the Bayesian estimate of R, i.e., R-hat. The mean of the latter is the estimated probability that R is below 1.01, using 1.01 as a conventional, if somewhat arbitrary, cutoff.
**rhat.stan**
data {
  int<lower=1> M;
  int<lower=2> N;
  matrix[M, N] theta;
}
parameters {
  real mu;
  real<lower=0> sigma;
  real<lower=0> tau;
  vector<multiplier=tau>[M] alpha;
}
model {
  mu ~ normal(0, 5);
  sigma ~ normal(0, 5);
  tau ~ normal(0, 5);
  alpha ~ normal(0, tau);
  for (m in 1:M) {
    theta[m] ~ normal(mu + alpha[m], sigma);
  }
}
generated quantities {
  real<lower=1> R = sqrt(1 + square(tau) / square(sigma));
  int<lower=0, upper=1> R_ok = R < 1.01;
}
The draws in chain m (i.e., theta[m] == theta[m, 1:N]) are given a normal(mu + alpha[m], sigma) distribution, so that mu is the global mean across chains and alpha[m] is the difference between the mean of chain m and the global mean. The variable sigma represents the within-chain standard deviation of the draws. The variable tau is the standard deviation of the alpha[1:M] variables, so it represents the scale of variation among the chain means. The rest of the priors are weakly informative, assuming the values are on roughly a unit scale or slightly larger.
The multiplier=tau converts alpha to a non-centered parameterization given its prior distribution normal(0, tau). The non-centered parameterization is more efficient here, though it may not be with large M and large N.
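To see what the affine transform is doing, here is a NumPy sketch (not part of the post’s code; the variable names are mine) of the reparameterization behind multiplier=tau—Stan samples a unit-scale variable and rescales it, so the rescaled variable ends up with the normal(0, tau) prior:

```python
import numpy as np

# Illustration of the non-centered parameterization: sample on the unit
# scale, then multiply by tau. The result has standard deviation tau,
# matching the normal(0, tau) prior on alpha.
rng = np.random.default_rng(1234)
tau = 2.5
alpha_raw = rng.normal(0.0, 1.0, size=100_000)  # unit-scale draws
alpha = tau * alpha_raw                          # non-centered alpha

print(alpha.std())  # close to tau
```

The sampler moves on the well-conditioned unit scale while the model still sees alpha on its natural scale.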
The generated quantities define R as the square root of 1 + (tau / sigma)^2. This makes it easy to see that it’s the between-chain variation, represented by tau, that must go to zero for R to converge to 1. The sigma is just providing the proper scaling to make the R statistic unit free.
Stan will propagate uncertainty through tau and sigma to R. Then the indicator R_ok will be 1 if R is less than 1.01, so its posterior mean is the probability that R is less than 1.01. This is how you code event probability estimators in Stan.
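In NumPy terms, the generated quantities amount to computing R per posterior draw and taking the mean of an indicator. Here is a sketch using synthetic stand-ins for the posterior draws of tau and sigma (the actual draws would come from fitting rhat.stan):

```python
import numpy as np

# Stand-ins for posterior draws of tau and sigma; in practice these
# come out of the rhat.stan fit.
rng = np.random.default_rng(42)
sigma_draws = np.abs(rng.normal(1.0, 0.1, size=4000))
tau_draws = np.abs(rng.normal(0.05, 0.02, size=4000))

# R computed per draw, propagating the uncertainty in tau and sigma.
R = np.sqrt(1 + np.square(tau_draws) / np.square(sigma_draws))

R_hat = R.mean()            # Bayesian estimate of R
pr_ok = (R < 1.01).mean()   # estimated Pr[R < 1.01]
```

The event probability is just the posterior mean of the indicator, exactly as with R_ok in the Stan program.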
The Bayesian estimate of R, which we can call the “Bayesian R-hat”, is the posterior mean of R. The “hat” etymology derives from statisticians liking to decorate random variables with hats to indicate estimates of the random variable. So if X is a random variable, X-hat is an estimate of that variable. We could’ve taken a posterior median—that’s a different estimator with slightly different properties.
**A test model**
We need a simple model to generate draws that we can test. I’ll choose one that’s not trivial, so HMC will have to work a bit and we can see some R-hat values that aren’t all 1 after rounding.
**eval.stan**
parameters {
  vector<lower=0>[5] theta;
}
model {
  theta ~ lognormal(0, 1);
}
A Student-t with 3 degrees of freedom also works.
**A simulator using CmdStanPy**
Let’s kick the tires, as they say in my hometown of Detroit. We’re going to generate some data using the test model.
**fit.py**
import numpy as np
import cmdstanpy as csp

csp.disable_logging()

model = csp.CmdStanModel(stan_file='eval.stan')
fit = model.sample(iter_warmup=10, iter_sampling=200, chains=8,
                   max_treedepth=1, show_progress=False)
print(fit.summary(sig_figs=3))

draws = fit.stan_variable("theta")
rhat_model = csp.CmdStanModel(stan_file='rhat.stan')
for d in range(5):
    theta = draws[:, d]
    num_chains = fit.runset.chains
    draws_per_chain = fit.num_draws_sampling
    theta_matrix = theta.reshape(num_chains, draws_per_chain)
    rhat_data = {
        "M": num_chains,
        "N": draws_per_chain,
        "theta": theta_matrix
    }
    rhat_fit = rhat_model.sample(data=rhat_data, show_progress=False)
    R_hat = np.mean(rhat_fit.stan_variable("R"))
    R_ok = np.mean(rhat_fit.stan_variable("R_ok"))
    print(f"{d=}; R-hat: {R_hat:0.3f}; Pr[R < 1.01]: {R_ok:0.3f}")
I find Python very readable, but let me walk you through it. We import and turn off some of the annoying warning messages. Then we compile the evaluation model, sample it, and print the summary. We set tree depth to 1 to force NUTS to act like Langevin and thus sample inefficiently. And we only go 10 warmup iterations and 200 sampling iterations over eight chains.
Then we extract the draws from the posterior as a 2D array draws with one row per draw and one column per parameter, where the M chains of N draws each are concatenated along the rows, giving shape ((M · N) x D) with D the number of parameters.
Then we loop over all the variables in the model, with d going from 0 to 4 (Python indexes from 0, unlike R). For each variable we extract its draws as theta. We grab the number of chains and draws per chain, reshape theta into an (M x N) matrix, and put everything into a data dictionary (like an R list) called rhat_data.
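The reshape step can be seen in isolation with a tiny synthetic example (2 chains of 3 draws each, flattened with the chains concatenated, matching the layout the loop above assumes):

```python
import numpy as np

# Flat vector of draws with chains concatenated: chain 1's draws
# first, then chain 2's.
M, N = 2, 3
flat = np.array([10., 11., 12.,   # chain 1 draws
                 20., 21., 22.])  # chain 2 draws

# Reshape to the (M, N) matrix that rhat.stan expects: one row per chain.
theta_matrix = flat.reshape(M, N)
print(theta_matrix[1])  # second chain's row: [20. 21. 22.]
```

NumPy’s default row-major reshape fills each row in turn, so concatenated chains land one per row, which is exactly the theta matrix the Stan program indexes as theta[m].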
Then we fit the R-hat model with this data, extract our Bayesian estimate of R-hat as the posterior mean of the variable R, and do the same for R_ok. Then we report.
Here’s the output with columns trimmed for readability.
(stanenv) rhat$ python3 fit.py
Mean MCSE StdDev R_hat
theta[1] 1.63 0.1070 1.75 1.04
theta[2] 1.46 0.0915 1.61 1.03
theta[3] 1.63 0.1890 2.06 1.07
theta[4] 1.50 0.0995 1.78 1.03
theta[5] 1.68 0.1110 1.98 1.03
d=0; R-hat: 1.046; Pr[R < 1.01]: 0.016
d=1; R-hat: 1.019; Pr[R < 1.01]: 0.306
d=2; R-hat: 1.058; Pr[R < 1.01]: 0.002
d=3; R-hat: 1.011; Pr[R < 1.01]: 0.616
d=4; R-hat: 1.014; Pr[R < 1.01]: 0.525
At first glance, the Bayesian estimates look reasonable compared to the R_hat values reported by CmdStanPy.
Because it’s a hierarchical model, we can’t directly calculate a maximum likelihood estimate—we’d have to marginalize and use max marginal likelihood. Marginalizing out the intermediate normal (alpha) would also be a better way to code rhat.stan. I imagine fitting the model in brms would also be more robust than my simple example code here—they have better priors. I’ll leave all that as an exercise for the reader.
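For reference, the classical statistic that this model generalizes can be computed directly. Here is a minimal sketch of the basic (non-split) Gelman–Rubin R-hat—note this is not CmdStan’s rank-normalized split-R-hat, just the textbook between/within variance ratio:

```python
import numpy as np

def classic_rhat(theta: np.ndarray) -> float:
    """Basic (non-split) Gelman-Rubin R-hat for an (M, N) array of draws:
    M chains, N draws per chain."""
    M, N = theta.shape
    chain_means = theta.mean(axis=1)
    W = theta.var(axis=1, ddof=1).mean()   # mean within-chain variance
    B = N * chain_means.var(ddof=1)        # between-chain variance
    var_plus = (N - 1) / N * W + B / N     # pooled variance estimate
    return float(np.sqrt(var_plus / W))

# Well-mixed chains drawn from the same distribution should give
# an R-hat close to 1.
rng = np.random.default_rng(0)
theta = rng.normal(0.0, 1.0, size=(8, 200))
print(classic_rhat(theta))
```

Comparing sqrt(var_plus / W) with the model’s sqrt(1 + tau^2 / sigma^2) shows the correspondence: W plays the role of sigma^2 and the between-chain term plays the role of tau^2.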
**Citation**
I went back to the original for the first time ever to verify that this is indeed the right model for what R-hat is estimating.
- Gelman, A. and Rubin, D. B. 1992. Inference from iterative simulation using multiple sequences. *Statistical Science* 7(4):457–472. (The original R-hat paper.)