In a [previous article](https://github.com/WinVector/vtreat/blob/main/Examples/CustomLevelCoding/CustomLevelCoding.md), we showed the use of partial pooling, or hierarchical/multilevel models, for level coding high-cardinality categorical variables in [`vtreat`](https://winvector.github.io/vtreat/). In this article, we will discuss a little more about the how and why of partial pooling in [`R`](https://www.r-project.org).
We will use the `lme4` package to fit the hierarchical models. The acronym "lme" stands for "linear mixed-effects" models: models that combine so-called "fixed effects" and "random effects" in a single (generalized) linear model. The `lme4` documentation uses the random/fixed effects terminology, but we are going to follow Gelman and Hill, and avoid the use of the terms "fixed" and "random" effects.
> The varying coefficients [corresponding to the levels of a categorical variable] in a multilevel model are sometimes called *random effects*, a term that refers to the randomness in the probability model for the group-level coefficients....
> The term *fixed effects* is used in contrast to random effects -- but not in a consistent way! ... Because of the conflicting definitions and advice, we will avoid the terms "fixed" and "random" entirely, and focus on the description of the model itself...
-- Gelman and Hill 2007, Chapter 11.4
We will also restrict ourselves to the case that `vtreat` considers: partially pooled estimates of conditional group expectations, with no other predictors considered.
## The Data
Let's assume that the data is generated from a mixture of $M$ populations; each population is normally distributed with its own (unknown) mean $\mu_{gp}$, and all populations share the same (unknown) standard deviation $\sigma_w$:
$$
y_{gp} \sim N(\mu_{gp}, {\sigma_{w}}^2)
$$
The population means themselves are normally distributed, with unknown mean $\mu_0$ and unknown standard deviation $\sigma_b$:
$$
\mu_{gp} \sim N(\mu_0, {\sigma_{b}}^2)
$$
(The subscripts *w* and *b* stand for "within-group" and "between-group" standard deviations, respectively.)
```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE)
```
```{r dat, echo=FALSE, message=FALSE, warning=FALSE}
library("lme4")
library("dplyr")
library("tidyr")
library("ggplot2")
library("viridis")
library("kableExtra")
library("cdata")
set.seed(46356)
# source generate_data to create and save the data
# source("generate_data.R")
# read in the synthetic data and the true population means
df = readRDS("synthdata.rds")
true_mu = readRDS("truemu.rds")
```
We can generate a synthetic data set according to these assumptions, with distributions similar to the distributions observed in the radon data set that we used in our earlier post: 85 groups, sampled unevenly. We'll use $\mu_0 = 0, \sigma_w = 0.7, \sigma_b = 0.5$. Here, we take a peek at our data, `df`.
```{r}
head(df)
```
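The script that generated `df` is not shown here. As a rough illustration of the process described above, a data set of the same general shape could be simulated as follows. This is a minimal sketch, not the actual generating code: the `sim_` names, the seed, and the geometric distribution used for the group sizes are all assumptions made for this example.

```r
# Hypothetical sketch of the generating process; not the article's generate_data.R.
set.seed(2024)                       # arbitrary seed, chosen for this example
sim_ngroups = 85
sim_mu0 = 0; sim_sigma_b = 0.5; sim_sigma_w = 0.7

# draw the true group means: mu_gp ~ N(mu_0, sigma_b^2)
sim_true_mu = data.frame(gp = sprintf("gp%02d", seq_len(sim_ngroups)),
                         mu_gp = rnorm(sim_ngroups, mean = sim_mu0, sd = sim_sigma_b))

# uneven group sizes (assumed): mostly small, occasionally large, never zero
sim_counts = 1 + rgeom(sim_ngroups, prob = 0.15)

# draw the observations: y ~ N(mu_gp, sigma_w^2)
sim_df = data.frame(gp = rep(sim_true_mu$gp, times = sim_counts))
sim_df$y = rnorm(nrow(sim_df),
                 mean = rep(sim_true_mu$mu_gp, times = sim_counts),
                 sd = sim_sigma_w)
head(sim_df)
```

Any skewed distribution of group sizes would do here; the geometric draw is just a convenient way to get many small groups and a few large ones.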
```{r echo=FALSE}
rawdata = df %>% group_by(gp) %>%
summarize(gmean = mean(y),
gsd = sd(y),
count = n(),
grandmean = mean(df$y)) %>%
mutate(stderr = ifelse(count<=1, 1, gsd/sqrt(count))) %>%
mutate(errlo = gmean - stderr,
errhi = gmean + stderr) %>%
right_join(true_mu, by="gp")
# add a jitter to the count so we can plot by number of observations without overlap
set.seed(34255)
countnoise = 0.8*runif(nrow(rawdata))
rawdata$countr = rawdata$count + countnoise
df = inner_join(df, select(rawdata, gp, countr), by="gp")
# distribution of population sizes
rawdata %>% group_by(count) %>%
summarize(nobs = n()) %>%
ggplot(aes(x=count, y=nobs)) +
geom_pointrange(aes(ymin=0, ymax=nobs), fatten=2) +
ggtitle("Distribution of group sample sizes")
```
As the graph shows, some groups were heavily sampled, but most groups have only a handful of samples in the data set. Since this is synthetic data, we know the true population means (shown in red in the graph below), and we can compare them to the observed means $\bar{y}_i$ of each group $i$ (shown in black, with standard errors). The gray points are the actual observations. We've sorted the groups by the number of observations.
```{r echo=FALSE}
# meanorder = levels(reorder(rawdata$gp, rawdata$mu_gp))
countorder = levels(reorder(rawdata$gp, rawdata$count))
# by count
gplevels = countorder
rawdata %>% mutate(gp = factor(gp, levels=gplevels)) %>%
ggplot(aes(x=gp)) +
geom_point(data=df, aes(x = factor(gp, levels=gplevels), y=y), color="#bdbdbd",
position=position_jitter(width=0.05, height=0)) +
geom_line(aes(y=mu_gp, group=1), color="#993404") +
geom_point(aes(y=mu_gp), color="#993404") +
geom_pointrange(aes(y=gmean, ymin=errlo, ymax=errhi), fatten=3) +
geom_hline(aes(yintercept=grandmean), color="darkblue") +
scale_x_discrete(breaks=NULL) +
ggtitle("Raw group mean estimates",
subtitle="Groups ordered by observation count")
```
For groups with many observations, the observed group mean is near the true mean. For groups with few observations, the estimates are uncertain, and the observed group mean can be far from the true population mean.
Can we get better estimates of the conditional mean for groups with only a few observations?
## Partial Pooling
If the data is generated by the process described above, and if we knew $\sigma_w$ and $\sigma_b$, then a good estimate $\hat{y}_i$ for the mean of group $i$ is the weighted average of the grand mean over all the data, $\bar{y}$, and the observed mean of all the observations in group $i$, $\bar{y}_i$.
$$
\large
\hat{y_i} \approx \frac{\frac{n_i} {\sigma_w^2} \cdot \bar{y}_i + \frac{1}{\sigma_b^2} \cdot \bar{y}}
{\frac{n_i} {\sigma_w^2} + \frac{1}{\sigma_b^2}}
$$
where $n_i$ is the number of observations for group $i$. In other words, for groups where you have a lot of observations, use an estimate close to the observed group mean. For groups where you have only a few observations, fall back to an estimate close to the grand mean.
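To make the weighting concrete, the formula can be written as a small function. This is only a sketch of the approximation above, with the standard deviations assumed known; `partial_pool` and its argument names are made up for this example and are not part of `lme4` or `vtreat`.

```r
# Precision-weighted average of a group's observed mean and the grand mean.
# Hypothetical helper: assumes sigma_w and sigma_b are known.
partial_pool = function(ybar_i, n_i, ybar, sigma_w, sigma_b) {
  w_group = n_i / sigma_w^2   # weight on the observed group mean
  w_grand = 1 / sigma_b^2     # weight on the grand mean
  (w_group * ybar_i + w_grand * ybar) / (w_group + w_grand)
}

# A group whose observed mean is 1, when the grand mean is 0,
# seen with 1, 10, and 100 observations:
est = partial_pool(ybar_i = 1, n_i = c(1, 10, 100), ybar = 0,
                   sigma_w = 0.7, sigma_b = 0.5)
round(est, 3)
# roughly 0.338, 0.836, 0.981
```

With a single observation the estimate stays about two thirds of the way back toward the grand mean; with a hundred observations it is nearly the observed group mean.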
Gelman and Hill call the grand mean the *complete-pooling* estimate, because the data from all the groups is pooled to create the estimate $\hat{y_i}$ (which is the same for all $i$). The "raw" observed means are the *no-pooling* estimate, because no pooling occurs; only observations from group $i$ contribute to $\hat{y_i}$. The weighted sum of the complete-pooling and the no-pooling estimate is hence the *partial-pooling* estimate.
Of course, in practice we don't know $\sigma_w$ and $\sigma_b$. The `lmer` function essentially solves for the restricted maximum likelihood (REML) estimates of the appropriate parameters in order to estimate $\hat{y_i}$. You can express multilevel models in `lme4` using the notation `| gp` in formulas to designate that `gp` is the grouping variable that you want conditional estimates for. The model that we are interested in is the simplest: outcome as a function of the grouping variable, with no other predictors.
```{r shrinkage}
poolmod = lmer(y ~ (1 | gp), data=df)
```
See section 2.2 of [this `lmer` vignette](https://cran.r-project.org/web/packages/lme4/vignettes/lmer.pdf) for more discussion on writing formulas for models with additional predictors. Printing `poolmod` displays the REML estimates of the grand mean (the intercept), $\sigma_b$ (the standard deviation of `gp`), and $\sigma_w$ (the residual standard deviation).
```{r}
poolmod
```
To pull these values out explicitly:
```{r}
# the estimated grand mean
(grandmean_est = fixef(poolmod))
# get the estimated between-group standard deviation
(sigma_b = as.data.frame(VarCorr(poolmod)) %>%
filter(grp=="gp") %>%
pull(sdcor))
# get the estimated within-group standard deviation
(sigma_w = as.data.frame(VarCorr(poolmod)) %>%
filter(grp=="Residual") %>%
pull(sdcor))
```
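As a sanity check, the estimated grand mean and standard deviations can be plugged back into the weighted-average formula and compared with what `predict()` returns. The sketch below does this on a small simulated data set so that it runs on its own; the `toy` data and all variable names in it are assumptions for the example, not the article's `df`. For a model with only a group intercept, the two should agree up to numerical precision.

```r
library("lme4")

# hypothetical toy data: 30 groups of uneven size
set.seed(99)
toy_n = sample(1:20, 30, replace = TRUE)
toy_mu = rnorm(30, mean = 0, sd = 0.5)
toy = data.frame(gp = rep(sprintf("g%02d", 1:30), times = toy_n))
toy$y = rnorm(nrow(toy), mean = rep(toy_mu, times = toy_n), sd = 0.7)

toy_mod = lmer(y ~ (1 | gp), data = toy)

# pull out the estimated parameters
vc = as.data.frame(VarCorr(toy_mod))
s_b = vc$sdcor[vc$grp == "gp"]        # between-group sd
s_w = vc$sdcor[vc$grp == "Residual"]  # within-group sd
mu_hat = as.numeric(fixef(toy_mod))   # estimated grand mean

# per-group counts and raw means (both sorted by group name)
groups = names(table(toy$gp))
n_i = as.numeric(table(toy$gp))
ybar_i = as.numeric(tapply(toy$y, toy$gp, mean))

# weighted average, written so it stays finite even if s_b is 0
w_i = n_i * s_b^2 / (n_i * s_b^2 + s_w^2)
by_hand = w_i * ybar_i + (1 - w_i) * mu_hat
by_lmer = as.numeric(predict(toy_mod, newdata = data.frame(gp = groups)))

max(abs(by_hand - by_lmer))  # expected to be very close to zero
```

Note that the check uses the *estimated* grand mean from `fixef()`, not the observed mean of `y`; substituting the observed mean would give only an approximate match.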
`predict(poolmod)` will return the partial pooling estimates of the group means. Below, we compare the partial pooling estimates to the raw group mean estimates. The gray lines represent the true group means, the dark blue horizontal line is the grand mean (the observed grand mean in the raw panel, the estimated grand mean in the partial pooling panel), and the black dots are the estimates. We have again sorted the groups by number of observations, and laid them out (with a slight jitter) on a log10 scale.
```{r echo=FALSE}
pooldata = rawdata %>%
select(gp, count, countr) %>%
mutate(gmean = as.numeric(predict(poolmod, newdata=rawdata)),
grandmean=grandmean_est) %>%
right_join(true_mu, by="gp")
alldata = select(rawdata, gp, count, countr, grandmean, gmean, mu_gp) %>%
mutate(estimate_type="raw") %>%
bind_rows(mutate(pooldata, estimate_type="partial pooling")) %>%
mutate(estimate_type=factor(estimate_type, levels=c("raw", "partial pooling")))
alldata %>%
ggplot(aes(x=countr)) +
geom_line(aes(y=mu_gp), color="darkgray") +
geom_point(aes(y=mu_gp), color="darkgray") +
geom_point(aes(y=gmean)) +
geom_hline(aes(yintercept=grandmean), color="darkblue") +
facet_wrap(~estimate_type, ncol=1) +
scale_x_log10("Number of observations (log10 scale)") +
ggtitle("Group mean estimates",
subtitle="Compared to observed and estimated grand means, respectively")
```
For groups with only a few observations, the partial pooling "shrinks" the estimates towards the grand mean[^1], which often results in a better estimate of the true conditional population means. We can see the relationship between shrinkage (the raw estimate minus the partial pooling estimate) and the groups, ordered by sample size.
```{r, echo=FALSE}
alldata %>% select(gp, countr, gmean, estimate_type) %>%
mutate(estimate_type = as.character(estimate_type)) %>%
mutate(estimate_type = ifelse(estimate_type=="partial pooling", "partial_pooling", estimate_type)) %>%
pivot_to_rowrecs(columnToTakeKeysFrom="estimate_type",
columnToTakeValuesFrom="gmean",
rowKeyColumns=c("gp", "countr")) %>%
ggplot(aes(x=countr, y=raw-partial_pooling)) +
geom_point() + geom_hline(yintercept=0, color="darkblue") +
ylab("shrinkage (raw - partial pooling estimate)") +
scale_x_log10("Number of observations (log10 scale)") +
ggtitle("Estimate shrinkage")
```
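Under the weighted-average approximation given earlier, the shrinkage has a simple closed form. Writing $w_i$ for the weight placed on the observed group mean (a symbol introduced here only for convenience):

$$
w_i = \frac{n_i/\sigma_w^2}{n_i/\sigma_w^2 + 1/\sigma_b^2} = \frac{n_i}{n_i + \sigma_w^2/\sigma_b^2},
\qquad
\bar{y}_i - \hat{y_i} \approx (1 - w_i)\,(\bar{y}_i - \bar{y})
$$

So the shrinkage is proportional to how far the raw group mean sits from the grand mean, and the factor $1 - w_i$ falls toward zero as $n_i$ grows. As a rough guide, with the values used to generate this data ($\sigma_w = 0.7$, $\sigma_b = 0.5$) the ratio $\sigma_w^2/\sigma_b^2$ is $1.96$, so a group with two observations would keep only about half of its raw deviation from the grand mean ($w_i = 2/3.96 \approx 0.51$). The fitted model uses estimated rather than true standard deviations, so the actual weights will differ somewhat.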
For this data set, the partial pooling estimates are on average closer to the true means than the raw estimates; we can see this by comparing the root mean squared errors of the two estimates.
```{r echo=FALSE}
# RMSE table on group mean estimates
alldata %>%
group_by(estimate_type) %>%
mutate(sqrerr = (gmean-mu_gp)^2) %>%
summarize(rmse = sqrt(mean(sqrerr))) %>%
knitr::kable(format="html") %>%
kableExtra::kable_styling(full_width=F)
```
[^1]: To be precise, partial pooling shrinks estimates toward the *estimated* grand mean `r format(grandmean_est, digits=3)`, not to the *observed* grand mean `r format(mean(df$y), digits=3)`.
### The Discrete Case
For discrete (binary) outcomes or classification, use the function `glmer()` to fit multilevel logistic regression models. Suppose we want to predict $\mbox{P}(y > 0 \,|\, gp)$, the conditional probability that the outcome $y$ is positive, as a function of $gp$.
```{r}
df$ispos = df$y > 0
# fit a logistic regression model
mod_glm = glm(ispos ~ gp, data=df, family=binomial)
```
Again, the conditional probability estimates will be highly uncertain for groups with only a few observations. We can fit a multilevel model with `glmer` and compare the distributions of the resulting predictions in link space.
```{r}
mod_glmer = glmer(ispos ~ (1|gp), data=df, family=binomial)
```
```{r echo=FALSE}
sigmoid = function (x) {
1/(1 + exp(-x))
}
# grand means in link (log-odds) space, to match the link-space predictions plotted below
meanests = data.frame(estimate_type = c("pred_glm", "pred_glmer"),
pred = c(qlogis(mean(df$ispos)), fixef(mod_glmer)))
global_prob = sigmoid(fixef(mod_glmer))
predframe = group_by(df, gp) %>% summarize(count=n())
# links, not probabilities.
# Use predict(, type="response") for probabilities
predframe %>% mutate(pred_glm = predict(mod_glm, newdata=predframe),
pred_glmer = predict(mod_glmer, newdata=predframe)) -> predframe
gather(predframe, key = estimate_type, value=pred, pred_glm, pred_glmer) %>%
ggplot(aes(x=pred)) + geom_density(adjust=0.5) +
geom_vline(data=meanests, aes(xintercept=pred), color="darkblue") +
facet_wrap(~estimate_type, ncol=1) +
ggtitle("Distribution of link values",
subtitle="Compared to observed and estimated grand means, respectively")
```
Note that the distribution of predictions for the standard logistic regression model is trimodal, and that for some groups, the logistic regression model predicts probabilities very close to 0 or to 1. In most cases, these predictions will correspond to groups with few observations, and are unlikely to be good estimates of the true conditional probability. The partial pooling model avoids making *unjustified* predictions near 0 or 1, instead “shrinking” the estimates to the estimated global probability that $y > 0$, which in this case is about `r format(global_prob, digits=2)`.
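The effect is easy to reproduce on a small example. The sketch below is self-contained and uses simulated data; the `toy` data set, its group names, and the seed are assumptions for illustration only. It adds a group with just two observations, both positive, and compares the two models' predicted probabilities for that group. It also shows `predict()` with `allow.new.levels=TRUE`, which for a level not present in the training data falls back to the population-level estimate.

```r
library("lme4")

# hypothetical toy data: 40 groups with varying true rates and sizes
set.seed(101)
toy_sizes = sample(2:30, 40, replace = TRUE)
toy_rates = plogis(rnorm(40, mean = 0, sd = 1))
toy = data.frame(gp = rep(sprintf("g%02d", 1:40), times = toy_sizes))
toy$ispos = runif(nrow(toy)) < rep(toy_rates, times = toy_sizes)
# add a tiny group whose two observations are both positive
toy = rbind(toy, data.frame(gp = "tiny", ispos = c(TRUE, TRUE)))

# glm warns about fitted probabilities of 0 or 1; that is the point here
toy_glm = suppressWarnings(glm(ispos ~ gp, data = toy, family = binomial))
toy_glmer = glmer(ispos ~ (1 | gp), data = toy, family = binomial)

tiny = data.frame(gp = "tiny")
p_glm = predict(toy_glm, newdata = tiny, type = "response")
p_glmer = predict(toy_glmer, newdata = tiny, type = "response")

# a level never seen in training gets the population-level estimate
p_unseen = predict(toy_glmer, newdata = data.frame(gp = "unseen"),
                   type = "response", allow.new.levels = TRUE)

c(glm = unname(p_glm), glmer = unname(p_glmer), unseen = unname(p_unseen))
```

The plain logistic regression should report a probability of essentially 1 for the tiny group, while the multilevel model should give a more moderate value, pulled toward the overall rate.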
We can see how the number of observations corresponds to the shrinkage (the difference between the logistic regression and the partial pooling estimates) in the graph below (this time in probability space). Points in orange correspond to groups where the logistic regression estimated probabilities of 0 or 1 (the two outer lobes of the response distribution). Multimodal densities are often symptoms of model flaws such as omitted variables or un-modeled mixtures, so it is exciting to see the partially pooled estimator avoid the "wings" seen in the simpler logistic regression estimator.
```{r echo=FALSE}
# probabilities this time
# predframe = group_by(df, gp) %>% summarize(count=n())
predframe = rawdata %>% select(gp, count, countr)
predframe %>%
mutate(pred_glm = predict(mod_glm, newdata=predframe, type="response"),
pred_glmer = predict(mod_glmer, newdata=predframe, type="response")) %>%
mutate(shrinkage = pred_glm-pred_glmer,
pred0or1 = ifelse(pred_glm >= 0.99 | pred_glm < 1e-6, TRUE, FALSE),
gp=factor(gp, gplevels)) -> predframe
predframe %>%
ggplot( aes(x=countr, y=shrinkage)) +
geom_point(aes(color=pred0or1)) +
geom_hline(yintercept=0, color="darkblue") +
scale_x_log10("Number of observations (log10 scale)") +
ylab("shrinkage (pred_glm - pred_glmer)") +
# scale_color_manual(values=c("black", "maroon")) +
scale_color_brewer(palette="Dark2") +
theme(legend.position="none") +
ggtitle("Shrinkage of probability estimates", subtitle="Groups ordered by observation count")
```
## Partial Pooling Degrades Gracefully
When there is enough data for each population to get a good estimate of the population means -- for example, when the distribution of groups is fairly uniform, or at least not too skewed -- the partial pooling estimates will converge to the raw (no-pooling) estimates. When the variation between population means is very low, the partial pooling estimates will converge to the complete pooling estimate (the grand mean).
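Both limits can be read off the weight that the weighted-average formula from the Partial Pooling section places on the observed group mean. The sketch below evaluates that weight directly; `pool_weight` is a hypothetical helper written for this example, and the parameter values are illustrative.

```r
# Weight on the observed group mean in the weighted-average formula.
# Hypothetical helper; algebraically equal to
# (n_i/sigma_w^2) / (n_i/sigma_w^2 + 1/sigma_b^2), but finite when sigma_b is 0.
pool_weight = function(n_i, sigma_w, sigma_b) {
  n_i * sigma_b^2 / (n_i * sigma_b^2 + sigma_w^2)
}

# more data per group: the weight approaches 1 (the no-pooling estimate)
w_by_n = pool_weight(n_i = c(1, 10, 1000), sigma_w = 0.7, sigma_b = 0.5)

# less variation between groups: the weight approaches 0 (complete pooling)
w_by_sd = pool_weight(n_i = 10, sigma_w = 0.7, sigma_b = c(0.5, 0.05, 0))

round(w_by_n, 3)
round(w_by_sd, 3)
```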
When there are only a few levels (Gelman and Hill suggest fewer than about five), there will generally not be enough information to make a good estimate of $\sigma_b$, so the partially pooled estimates likely won't be much better than the raw estimates.
So partial pooling will be of the most potential value when the number of groups is large, and there are many rare levels. With respect to `vtreat`, this is exactly the situation when level coding is most useful!
Multilevel modeling assumes the data was generated from the mixture process above: each population is normally distributed, with the same standard deviation, and the population means are also normally distributed. Obviously, this may not be the case, but as Gelman and Hill argue, the additional inductive bias can be useful for those populations where you have little information.
Thanks to [Geoffrey Simmons](https://www.linkedin.com/in/geoffrey-simmons-bb675242/), Principal Data Scientist at Echo Global Logistics, for suggesting partial pooling based level coding for `vtreat`, introducing us to the references, and reviewing our articles.
## References
Gelman, Andrew and Jennifer Hill. *Data Analysis Using Regression and Multilevel/Hierarchical Models*. Cambridge University Press, 2007.