
Partial Pooling





Much of data science is merely taking averages to estimate expected values.

One can ask: are there better estimates for expected values than the obvious average of observations? One might think, in particular given the Gauss–Markov theorem, that the answer would have to be "no." Oddly enough, the answer is in fact "yes." One method to do this is called partial pooling, and we will describe this method here. There are other methods as well, such as the James–Stein estimator (which we do not go into here).

This note works an example of partial pooling from equation 12.1 of Gelman, Hill, "Data Analysis Using Regression and Multilevel/Hierarchical Models", Cambridge 2007. This is a bit of a magic trick, but we will show how it is done. I am not claiming this will revolutionize your data science projects, but it can be an improvement when you don't have a lot of data.

Introduction

The topic of formula 12.1 is: how to partially pool data by hand, though this task is often handled by a hierarchical model solver.

Take as our example problem: we want to estimate the ideal expected value of observations from each of a number of locations. The optimal linear unbiased estimator is to just take the average of the observed values at each location, and hope this is close to the ideal unobserved generating process.

Dr. Nina Zumel has two excellent articles on exactly this problem.

These show how to use the right tools to perform and use the partial pooling analysis, largely in the context of the vtreat data preparation system (R version, Python version).

For this note, let's work on a formula and a synthetic example directly.

We will use Python and the sympy symbolic math package to work our examples. This is part of my "data for engineers" teaching effort, where we share code so people can try the methods. First we import our packages.

# import packages
import sympy
from sympy.stats import E, Normal
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

Our problem set-up is: we have a number of locations, each with its own unobserved expected value drawn from a common distribution, and at each location we see a small number of noisy observations. Our goal is to estimate each location's unobserved expected value.

To execute this idea we need to define a great number of variables and their relations as follows.

# define initial variables
TrueEffectCenter = sympy.Symbol("TrueEffectCenter")  # center of distribution generating locations, unobserved
BetweenLocationsSD = sympy.Symbol("BetweenLocationsSD")  # how individual locations vary from each other in expected behavior, unobserved
EstimatedEffectCenter = sympy.Symbol("EstimatedEffectCenter")  # estimate of TrueEffectCenter, observed, though with some subtlety in how to calculate it
LocationValue_j = sympy.Symbol("LocationValue_j")  # actual expected behavior of a given location, unobserved and the goal to estimate
LocationDistFactor_j = Normal("LocationDistFactor_j", mean=0, std=BetweenLocationsSD)  # how locations differ from each other, unobserved
def_LocationValue_j = TrueEffectCenter + LocationDistFactor_j  # generative definition of LocationValue_j
PerObservationSD = sympy.Symbol("PerObservationSD", positive=True)  # sd of distribution generating observations per location, unobserved
n_j = sympy.Symbol("n_j", positive=True)  # number of observations at the j-th location, observed
LocationMean_j = sympy.Symbol("LocationMean_j")  # mean of all observations at location j, observed
LocationCenterNoise_ji = Normal("LocationCenterNoise_ji", mean=0, std = PerObservationSD / sympy.sqrt(n_j))  # distribution of average seen at location j, unobserved
def_LocationMean_j = LocationValue_j + LocationCenterNoise_ji  # generative definition of LocationMean_j
w = sympy.Symbol("w", positive=True)  # our weighting term picking how to pool specific and general observations, to solve for
estimate_j = sympy.Symbol("estimate_j")  # our estimate of the behavior of the j-th location, to solve for

These are a lot of terms, but most of them are merely placeholders for our calculations. The trickiest term is LocationCenterNoise_ji, whose standard deviation PerObservationSD / sqrt(n_j) is correct for independent observations. LocationCenterNoise_ji is founded on the fact that variance is additive for independent observations. The Normal() could be relaxed a bit, but it appears we must specify some specific distribution to use sympy's expectation operator E(). If we couldn't use E() we could simulate it with more guided substitution steps. Throughout this analysis we assume our random variables are independent.
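
As a small added check (not part of the original derivation), we can confirm the additivity fact behind LocationCenterNoise_ji for a concrete case: the mean of n independent observations, each with standard deviation PerObservationSD, has standard deviation PerObservationSD / sqrt(n). The symbol obs_sd_check and the draws y_0, y_1, y_2 below are illustrative additions.

# added check: the mean of n independent observations with a common standard deviation
# has variance equal to that standard deviation squared, divided by n
from sympy.stats import variance

obs_sd_check = sympy.Symbol("obs_sd_check", positive=True)  # hypothetical per-observation sd
n_check = 3  # a concrete, hypothetical number of observations
obs_draws = [Normal(f"y_{i}", 0, obs_sd_check) for i in range(n_check)]
mean_of_draws = sum(obs_draws) / n_check
assert sympy.simplify(variance(mean_of_draws) - obs_sd_check**2 / n_check) == 0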

Our goal is to pick an estimate_j to minimize the expected square of the following expression.

expected_error_term = LocationValue_j - estimate_j  # error of our estimate, to minimize square of

expected_error_term

LocationValue_j - estimate_j

Our plan is to re-factor this expression into independent, expectation-zero terms A, B and then use the fact that for such terms E[(A - B)**2] = E[A**2] + E[B**2]. The goal is to derive relation 12.1 from Gelman and Hill:

estimate_j ≈ ((n_j / PerObservationSD**2) * LocationMean_j + (1 / BetweenLocationsSD**2) * EstimatedEffectCenter) / (n_j / PerObservationSD**2 + 1 / BetweenLocationsSD**2)   (equation 12.1, written in our variable names)

This is a big equation with a "pooling strength" coefficient expanded into the expression. If we write the pooling coefficient as w, we can then re-write the above as the following.

def_estimate_j = w * LocationMean_j + (1-w) * EstimatedEffectCenter  # definition of our estimate

def_estimate_j

EstimatedEffectCenter*(1 - w) + LocationMean_j*w

That is: our estimate is w times the observed per-location mean (LocationMean_j, the obvious estimate) plus (1 - w) times the average expected behavior over all locations. Setting w = 1 gives us the traditional "use only observations from the chosen location" solution. Setting w = 0 gives us complete pooling, assuming there is no difference between locations. The trick is to find a w between 0 and 1 that might have lower expected square error.
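
As a quick added sanity check (not in the original note), we can confirm these two extremes directly from the symbolic definition.

# added sanity checks on the two extremes of the pooling weight:
# w = 1 uses only the per-location mean, w = 0 uses only the overall center
assert def_estimate_j.subs(w, 1) == LocationMean_j
assert def_estimate_j.subs(w, 0) == EstimatedEffectCenter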

The 12.1 estimate is a solution that picks w as the following.

neat_soln_approx = 1 / (1 + PerObservationSD**2 / (n_j * BetweenLocationsSD**2))

neat_soln_approx

1 / (1 + PerObservationSD**2 / (BetweenLocationsSD**2 * n_j))

Some algebra shows this is the same as the Gelman and Hill citation (if we adopt the convention that w = 0 when n_j = 0). We confirm that as follows.

GH_soln = (n_j / PerObservationSD**2) / (n_j / PerObservationSD**2 + 1 / BetweenLocationsSD**2)
assert (neat_soln_approx - GH_soln).together() == 0

This solution for w has some nice properties: it always lies between 0 and 1; it moves toward 1 (trusting the per-location average) as n_j grows or as BetweenLocationsSD grows relative to PerObservationSD; and it moves toward 0 (pooling toward the overall center) when per-observation noise dominates the between-location variation.
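
These limiting behaviors are easy to confirm symbolically; the following checks are an addition to the note (assuming sympy's limit treats the remaining symbols as generic positive constants).

# added checks of the limiting behavior of the pooling weight
assert sympy.limit(neat_soln_approx, n_j, sympy.oo) == 1  # lots of data at the location: no pooling
assert sympy.limit(neat_soln_approx, BetweenLocationsSD, 0) == 0  # identical locations: complete pooling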

Our goal is to derive this solution.

Reproducing the Gelman and Hill citation

We can match the Gelman and Hill citation by, during the solving, treating our observed EstimatedEffectCenter as if it were the unobserved TrueEffectCenter (ignoring an assumed small difference between them).

When we solve in that manner we get the Gelman and Hill cited w as follows. Let's work for a bit on the error term whose expected square we want to minimize.

error_term_approx = (
    expected_error_term
        .subs(estimate_j, def_estimate_j)  # definition of estimate_j
        .subs(EstimatedEffectCenter, TrueEffectCenter)  # this step is an approximation, using the unobserved TrueEffectCenter as if it is the observed EstimatedEffectCenter
        .subs(LocationMean_j, def_LocationMean_j)  # generative definition of LocationMean_j
        .subs(LocationValue_j, def_LocationValue_j)   # generative definition of LocationValue_j
).expand().simplify()

error_term_approx

-LocationCenterNoise_ji*w - LocationDistFactor_j*w + LocationDistFactor_j

The simplifying substitution of TrueEffectCenter for EstimatedEffectCenter is an approximation, and is why the total number of observations is not present in this estimate. This substitution ignores the uncertainty in estimating the grand mean, which we assume is somewhat easier than our overall task of estimating the conditional means. A different view would be: if we felt we knew the grand mean, this is how we would perform the estimate. For example, in re-encoding variables (such as in vtreat) we know the overall grand mean is in fact zero by the structure of the application set-up.

There is in fact an issue in how to estimate EstimatedEffectCenter when different locations have very different numbers of observations. In this case an average where each observation contributes equally is sensitive to the between-location variance, and an average where each location contributes equally is sensitive to the per-observation variances. What to put in for EstimatedEffectCenter is essentially the entire estimation problem started over again, so picking the data weighting cleverly probably is of some benefit.

Assuming we were happy with what we plugged in for EstimatedEffectCenter we can continue our calculation as follows.

We claim this error term is equal to A - B, for the following A, B.

A = (1-w) * LocationDistFactor_j
B = w * LocationCenterNoise_ji

And we can confirm this claim as follows.

assert (error_term_approx - (A - B)).simplify() == 0
assert E(A).expand() == 0
assert E(B).expand() == 0

We can expand the expected square error E[(A - B)**2] as E[A**2] + E[B**2] (using the independence of A and B, and that they are both expected value zero).
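
For completeness, we can also ask sympy to confirm that the cross term vanishes; this check is an addition to the original argument.

# added check: the cross term E[A*B] is zero for these independent, mean-zero terms,
# so E[(A - B)**2] = E[A**2] - 2*E[A*B] + E[B**2] = E[A**2] + E[B**2]
assert sympy.simplify(E(A * B)) == 0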

We expect the minimum to occur at either the bounds of our allowed values for w (0 and 1) or at a place where the derivative of our expected square error term with respect to w is zero. That solution is found as follows.

soln_approx = sympy.solve(sympy.diff(E(A**2) + E(B**2), w), w)[0]

soln_approx

BetweenLocationsSD**2*n_j / (BetweenLocationsSD**2*n_j + PerObservationSD**2)

And we confirm this is (for non-zero quantities) equivalent to the earlier claimed solution.

assert (soln_approx - neat_soln_approx).together().expand() == 0

And this, as promised, matches the text book.

I am experimenting more and more with using sympy in place of hand-worked algebra in these sorts of write-ups. My hope is that it allows me to spend more time on intent, and less time explaining steps (which often introduces yet more steps).

Working an Example

Let's see the partial pooling inference in action.

We will generate the synthetic example data according to the above hierarchical design. We will then see how a pooled estimate, itself using estimated parameters, can in fact improve estimation results. The advantages of synthetic data include: we can see how close the estimates come to the correct answers, and we can generate a favorable situation for demonstration.

We generate the not directly observable portion of our example as follows.

# set state of pseudo random number generator for repeatability
rng = np.random.default_rng(2023)  
# set parameters to specific values
example_between_locations_sd = 3.0
example_per_observations_sd = 10.0
n_locations = 10
n_obs_per_location = 3
def generate_example_centers() -> pd.DataFrame:
    """generate the unobserved location values"""
    example_location_value_mean = rng.normal(loc=0, scale=15, size=1)
    example_locations = pd.DataFrame({
        "location_id": range(n_locations),
        "effect": rng.normal(
        loc=example_location_value_mean, 
        scale=example_between_locations_sd, 
        size=n_locations)
    })
    return example_locations
example_locations = generate_example_centers()
example_locations

location_id effect
0 0 12.480676
1 1 4.947432
2 2 9.691985
3 3 6.698217
4 4 11.451937
5 5 8.429935
6 6 4.289738
7 7 7.137952
8 8 7.862549
9 9 9.176378

And we can then generate observations as follows.

def generate_observations(example_locations: pd.DataFrame)-> pd.DataFrame:
    """
    generate observed data

    :param example_locations: the (unobserved) ground truth to generate from
    :return: observed data
    """
    assert isinstance(example_locations, pd.DataFrame)
    observations = []
    for j in range(example_locations.shape[0]):
        obs_j = pd.DataFrame({
            "location_id": j,
            "observation": rng.normal(
            loc=example_locations.effect[j], 
            scale=example_per_observations_sd, 
            size=n_obs_per_location),
        })
        observations.append(obs_j)
    observations = pd.concat(observations, ignore_index=True)
    return observations
observations = generate_observations(example_locations)
observations

location_id observation
0 0 3.410240
1 0 13.802057
2 0 26.529701
3 1 8.988453
4 1 -5.426702
5 1 -2.470594
6 2 22.326980
7 2 2.798745
8 2 16.772133
9 3 6.347742
10 3 -1.679948
11 3 -0.626963
12 4 21.925762
13 4 -2.786075
14 4 24.992347
15 5 4.479009
16 5 -0.862222
17 5 12.870562
18 6 5.711914
19 6 17.977145
20 6 -0.697886
21 7 6.395265
22 7 16.531671
23 7 0.876524
24 8 -2.427039
25 8 -13.467843
26 8 16.906428
27 9 -11.667661
28 9 12.001735
29 9 5.930353

Now we define the standard estimator, and code to evaluate how well it estimates the first group's expected value.

# define the standard estimator
def standard_effect_estimate(observations: pd.DataFrame) -> pd.DataFrame:
    """
    Get the standard estimate of the effect at locations from observations.

    :param observations: the observations data frame
    :return: standard estimate of effect or mean by location
    """
    assert isinstance(observations, pd.DataFrame)
    means = (
        observations.loc[:, ["location_id", "observation"]]
            .reset_index(drop=True, inplace=False)
            .groupby(["location_id"])
            .mean()
            .reset_index(drop=False, inplace=False)
    )
    means.sort_values(["location_id"], inplace=True, ignore_index=True)
    means.rename(columns={"observation": "estimate"}, inplace=True)
    means["grand_mean"] = np.mean(observations["observation"])
    means["impact"] = means["estimate"] - means["grand_mean"]
    means.sort_values(["location_id"], inplace=True, ignore_index=True)
    return means

The standard estimate works as follows.

standard_effect_estimate(observations)

location_id estimate grand_mean impact
0 0 14.580666 6.848728 7.731938
1 1 0.363719 6.848728 -6.485009
2 2 13.965953 6.848728 7.117225
3 3 1.346944 6.848728 -5.501784
4 4 14.710678 6.848728 7.861950
5 5 5.495783 6.848728 -1.352945
6 6 7.663724 6.848728 0.814996
7 7 7.934487 6.848728 1.085759
8 8 0.337182 6.848728 -6.511546
9 9 2.088142 6.848728 -4.760585

# estimate the quality of the standard estimator
def get_sq_error_loss_std_estimate(
        *, 
        location_id: int, example_locations: pd.DataFrame, observations: pd.DataFrame) -> float:
    """
    Show the square-error of estimate of the effect at location location_id from observations

    :param location_id: which location to calculate for
    :param example_locations: the (unobserved) ground truth to compare to
    :param observations: the observations data frame
    :return: estimated square error of this estimation procedure
    """
    # get the standard estimates
    observed_means = standard_effect_estimate(observations=observations)
    # get the unobservable true effect for comparison
    true_effect = example_locations.loc[example_locations["location_id"] == location_id, "effect"].values[0]
    # calculate the square error of these estimates
    estimated_effect = observed_means.loc[observed_means["location_id"] == location_id, "estimate"].values[0]
    square_error = (estimated_effect - true_effect)**2
    return square_error

And we define the pooled estimator, and its evaluation procedure.

# define the pooled estimator
def pooled_effect_estimate(observations: pd.DataFrame) -> pd.DataFrame:
    """
    Get the pooled estimate of effect

    :param observations: the observations data frame, with columns location_id and observation
    :return: pooled estimates
    """
    assert isinstance(observations, pd.DataFrame)
    observations = observations.loc[:, ["location_id", "observation"]].reset_index(inplace=False, drop=True)
    # get the standard estimates
    estimated_centers = standard_effect_estimate(observations=observations)
    if estimated_centers.shape[0] <= 1:
        # no pooling possible
        return estimated_centers
    # get counts per group
    obs_count_frame = (
        pd.DataFrame({"location_id": observations["location_id"], "count": 1})
            .groupby(["location_id"])
            .sum()
            .reset_index(drop=False, inplace=False)
            .sort_values(["location_id"], inplace=False, ignore_index=True)
    )
    n_j = obs_count_frame["count"].values
    # get the pooled within-location variance of the observations
    combined = (
        observations
            .merge(
                estimated_centers,
                on=["location_id"],
                how="left",
                )
            .merge(
                obs_count_frame,
                on=["location_id"],
                how="left",
                )
    )
    combined.sort_values(["location_id"], inplace=True, ignore_index=True)
    per_location_observation_var = (
        np.sum((combined["observation"] - combined["estimate"])**2) 
        / (combined.shape[0] - len(set(combined["location_id"])))
    )
    # get the observed variance between locations
    between_location_var = np.var(estimated_centers["estimate"], ddof=1)
    # get v, the pooling coefficient
    if between_location_var <= 0:
        v = 0
    elif per_location_observation_var <= 0:
        v = 1
    else:
        # as between_location_var > 0 and per_location_observation_var > 0 here
        # v will be in the range 0 to 1
        v = 1 / (1 + per_location_observation_var / (n_j * between_location_var))
    # our estimate of the overall shared effect
    # note we are using the mixing proportions suggested by the variance reduction ideas
    # simpler weightings include:
    #   combined["obs_weight"] = 1   # weights all observations equally
    #   combined["obs_weight"] = 1 / combined["count"]  # weights all locations equally
    # below, weights larger observation groups more, but with a diminishing return
    # this is an ad-hoc heuristic to try to reduce square error when the number of 
    # observations per location has a lot of variation
    combined["obs_weight"] = 1
    if (between_location_var > 0) and (per_location_observation_var > 0):
        combined["obs_weight"] = 1 / (1 + per_location_observation_var / (combined["count"] * between_location_var))
    # this quantity can be improved using knowledge of the variances
    grand_mean = np.sum(combined["observation"] * combined["obs_weight"]) / np.sum(combined["obs_weight"])
    # build the pooled estimate
    pooled_estimate = v * estimated_centers["estimate"] + (1 - v) * grand_mean
    return pd.DataFrame({
        "location_id": estimated_centers["location_id"],
        "estimate": pooled_estimate,
        "grand_mean": grand_mean,
        "impact": pooled_estimate - grand_mean,
    })

The pooled estimate works as follows.

# show the estimate in action
pooled_effect_estimate(observations)

location_id estimate grand_mean impact
0 0 10.557640 6.848728 3.708913
1 1 3.737951 6.848728 -3.110776
2 2 10.262770 6.848728 3.414043
3 3 4.209592 6.848728 -2.639136
4 4 10.620005 6.848728 3.771278
5 5 6.199737 6.848728 -0.648991
6 6 7.239671 6.848728 0.390943
7 7 7.369552 6.848728 0.520825
8 8 3.725222 6.848728 -3.123506
9 9 4.565135 6.848728 -2.283592

# evaluate the pooled estimator quality
def get_sq_error_loss_pooled_estimate(
        *, 
        location_id: int, example_locations: pd.DataFrame, observations: pd.DataFrame) -> float:
    """
    Show the square error of the partially pooled estimate of the effect at location location_id from observations

    :param location_id: which location to calculate for
    :param example_locations: the (unobserved) ground truth to compare to
    :param observations: the observations data frame
    :return: estimated square error of this estimation procedure
    """
    # get the estimates
    pooled_estimates = pooled_effect_estimate(observations=observations)
    # get the unobservable true effect for comparison
    true_effect = example_locations.loc[example_locations["location_id"] == location_id, "effect"].values[0]
    # calculate the square error of these estimates
    estimated_effect = pooled_estimates.loc[pooled_estimates["location_id"] == location_id, "estimate"].values[0]
    square_error = (estimated_effect - true_effect)**2
    return square_error

Now, let's try the estimators. In each case the evaluation returns square error, so smaller is better.

# run the experiment for the standard estimator
std_est_loss = get_sq_error_loss_std_estimate(
    location_id=0, 
    example_locations=example_locations, 
    observations=observations)

std_est_loss

4.409957134811661

# run the experiment for the pooled estimator
pooled_est_loss = get_sq_error_loss_pooled_estimate(
    location_id=0, 
    example_locations=example_locations, 
    observations=observations)

pooled_est_loss

3.698066838111626

assert pooled_est_loss < std_est_loss

We see the pooled estimate has smaller square error loss than the standard estimate. That is: we have evidence the pooled estimate would be closer to new observations for a given location. We want to further confirm this by repeating the experiment more times.

# wrap our experiment in a repeatable function
def run_experiment() -> pd.DataFrame:
    example_locations = generate_example_centers()
    observations = generate_observations(example_locations)
    std_est_loss = get_sq_error_loss_std_estimate(
        location_id=0, 
        example_locations=example_locations, 
        observations=observations)
    pooled_est_loss = get_sq_error_loss_pooled_estimate(
        location_id=0, 
        example_locations=example_locations, 
        observations=observations)
    return pd.DataFrame({
        "std_loss": [std_est_loss],
        "pooled_loss": [pooled_est_loss],
    })
# run the experiments
experiments = pd.concat([run_experiment() for i in range(1000)], ignore_index=True)
# show the average loss of each method
experiments.mean()

std_loss 31.652005
pooled_loss 14.987383
dtype: float64

# show the median loss of each method
medians = experiments.median()
medians

std_loss 13.859971
pooled_loss 5.906482
dtype: float64

# show how each method's loss varies (standard deviation)
experiments.std()

std_loss 46.143738
pooled_loss 24.747940
dtype: float64

Notice the pooled method, in this example, outperforms the standard method in all three of the above measurements. This is because the example is deliberately generating closely related data.

We can also plot the distribution losses to get a bit more detail.

# plot the experiments
plt_frame = pd.concat([
    pd.DataFrame({"estimate": "standard", "loss": experiments["std_loss"]}),
    pd.DataFrame({"estimate": "pooled", "loss": experiments["pooled_loss"]}),
    ], ignore_index=True)
palette = ["#d95f02", "#1b9e77"]
sns.kdeplot(
    data=plt_frame,
    x="loss",
    hue="estimate",
    hue_order=["standard", "pooled"],
    palette=palette,
    fill=True,
)
plt.title("distribution of square error losses by estimation method (lower better)\ndashed bars at medians")
plt.xlim(0, np.quantile(plt_frame["loss"], 0.95))
plt.axvline(x=medians["std_loss"], color=palette[0], linestyle="dashed")
plt.axvline(x=medians["pooled_loss"], color=palette[1], linestyle="dashed")
plt.show()

[Figure: distribution of square error losses by estimation method (lower is better), with dashed vertical lines at the medians]

The method has been seen to work, even with estimated variances and grand means (instead of using the unobserved true generating parameters). This is evidence the method can be used in practice.

Alternatives

One alternative method is to re-write the mixing estimate so it depends on the data only through n_j and a single unknown constant X, for example w = n_j / (n_j + X) (the form derived above, with X standing in for PerObservationSD**2 / BetweenLocationsSD**2). We can then estimate an effective X by cross-validation, instead of using the theoretical value. That is: we use the functional form from the above theory, and pick X directly from data instead of estimating the variances.
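
Below is a rough sketch of that cross-validation idea, added to this note for illustration. The helper names pooled_estimate_with_X and cross_validated_X are hypothetical; the sketch re-uses standard_effect_estimate and the location_id / observation columns from above.

# added sketch: pick the pooling constant X by cross-validation instead of from estimated variances
def pooled_estimate_with_X(observations: pd.DataFrame, X: float) -> pd.DataFrame:
    """pool each location's mean toward the grand mean with weight w_j = n_j / (n_j + X)"""
    means = standard_effect_estimate(observations)
    counts = observations.groupby("location_id").size().reindex(means["location_id"]).values
    w_j = counts / (counts + X)
    means["estimate"] = w_j * means["estimate"] + (1 - w_j) * means["grand_mean"]
    means["impact"] = means["estimate"] - means["grand_mean"]
    return means

def cross_validated_X(observations: pd.DataFrame, candidate_Xs) -> float:
    """pick X by leave-one-observation-out square error against held-out observations"""
    best_X, best_loss = None, np.inf
    for X in candidate_Xs:
        losses = []
        for i in observations.index:
            held_out = observations.loc[i]
            est = pooled_estimate_with_X(observations.drop(index=i), X)
            match = est.loc[est["location_id"] == held_out["location_id"], "estimate"]
            if len(match) > 0:  # location can vanish if it had only one observation
                losses.append((held_out["observation"] - match.values[0]) ** 2)
        mean_loss = np.mean(losses)
        if mean_loss < best_loss:
            best_X, best_loss = X, mean_loss
    return best_X

# example use, over a hypothetical grid of candidate X values
best_X = cross_validated_X(observations, candidate_Xs=[0.5, 1, 2, 5, 10, 20])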

Or, one can delegate the entire task to a hierarchical modeling tool. This allows specification of more detailed hierarchies that can encode useful modeling assumptions.
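
As a minimal sketch of that delegation (an addition to this note, assuming the statsmodels package is available), a random-intercept model recovers partially pooled per-location estimates much like the ones computed by hand above.

# added sketch: delegate the pooling to an off-the-shelf mixed-effects solver
import statsmodels.formula.api as smf

# random-intercept model: observation ~ shared center + per-location offset + noise
md = smf.mixedlm("observation ~ 1", observations, groups=observations["location_id"])
fit = md.fit()
center = fit.fe_params["Intercept"]  # estimate of the shared center
# partially pooled per-location estimates: shared center plus estimated per-location offset
location_estimates = {
    location_id: center + offsets.iloc[0]
    for location_id, offsets in fit.random_effects.items()
}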

A note on hierarchical models

We've not really talked about the meaning of the term "hierarchical model."

Hierarchical models are the clear way to explain how different data generation assumptions affect our modeling strategy. In this case, knowing how strongly different data are related allows us to implement a lower variance (though biased) estimate. This note is a direct re-implementation of a simple hierarchical model: locations come from a common distribution, and then each location produces observations.

From Gelman and Hill section 11.3 we take the point that hierarchical models are an alternative to attempts to define fixed effect versus random effect modeling. We understand this is from Bayesian-friendly authors who don't particularly like the Frequentist terms fixed effect and random effect, but their advice seems sound:

Our advice (elaborated upon in the rest of this book) is to always use multilevel modeling ("random effects"). Because of conflicting definitions and advice, we avoid the terms "fixed" and "random" entirely, and focus on the description of the model itself [...] with the understanding that batches of coefficients [...] will themselves be modeled.

(Ellipses meant to avoid some not-yet-introduced detail, and not to alter the meaning, context, or scope of the original statement. In this context "modeled" essentially means subject to distributional assumptions, be these assumptions priors or conditions relating groups of variables. This is quite normal for Bayesians, but needs to be mystified a bit when used by Frequentists, as this is something Frequentists claim not to do.)

Conclusion

The partial pooling improvement for estimating an unseen value from noisy observations depends on a single parameter PerObservationSD**2 / (n_j * BetweenLocationsSD**2). This parameter compares the uncertainty in the mean of the observations (PerObservationSD**2 / n_j, or standard error squared) for a single location, to the uncertainty between locations (BetweenLocationsSD**2, a standard deviation squared). When this ratio is small, we don't pool data, we just estimate the average value using data from one location. When this ratio is large, pooling is likely a useful variance reducing procedure. Of course, these values are not directly observed; and thus have to be themselves estimated.
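
As a small added numeric illustration, using the parameters of the simulation above (PerObservationSD = 10, BetweenLocationsSD = 3, and n_j = 3 observations per location):

# added numeric illustration with the simulation's parameter values
ratio = 10.0**2 / (3 * 3.0**2)   # PerObservationSD**2 / (n_j * BetweenLocationsSD**2), about 3.7
w_example = 1.0 / (1.0 + ratio)  # about 0.21, so the estimate leans heavily toward the grand mean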

Essentially, the partial pooling method is trading away variance (due to having very few samples) for bias (due to including samples that don't match the target location). If the different locations are sufficiently related, this can be a favorable trade-off in terms of expected square error.

In practice, all of the above is typically ignored when there is a lot of data. And, the above inference is made inside a hierarchical model solver, not directly by a practitioner. However, it is good to see the actual form of the pooling strategy.