9. Mixture models and label switching with MCMC

Data set download


[2]:
import os

import numpy as np
import scipy.stats as st
import polars as pl
import polars.selectors as cs

import cmdstanpy
import arviz as az

import iqplot
import bebi103

import bokeh.io
import bokeh.plotting
bokeh.io.output_notebook()

We continue with our analysis of the smFISH data, but this time for the Rex1 gene. Here, we saw clear bimodality in the data.

[3]:
# Load DataFrame and get counts
df = pl.read_csv(os.path.join(data_path, "singer_transcript_counts.csv"), comment_prefix="#")
n = df["Rex1"].to_numpy()

bokeh.io.show(
    iqplot.ecdf(n, x_axis_label="mRNA count", frame_height=150, frame_width=200)
)

Mixture models

Since the Negative Binomial distribution is unimodal, what could be the story here? It is quite possible we are seeing two different states of cells, one with one level of bursty expression of the gene of interest, and another with a higher level of bursty expression. This could mean the cells are differentiating. So, we would expect the number of mRNA transcripts to be distributed according to a linear combination of two negative binomial distributions. We can write out the PMF as

\begin{align} f(n\mid \alpha_1, \alpha_2, \beta_1, \beta_2, w) &= w\,\frac{\Gamma(n + \alpha_1)}{n!\,\Gamma(\alpha_1)}\,\left(\frac{\beta_1}{1+\beta_1}\right)^{\alpha_1}\left(\frac{1}{1+\beta_1}\right)^{n} \\[1em] &\;\;\;\;+ (1-w) \,\frac{\Gamma(n + \alpha_2)}{n!\,\Gamma(\alpha_2)}\,\left(\frac{\beta_2}{1+\beta_2}\right)^{\alpha_2}\left(\frac{1}{1+\beta_2}\right)^{n} , \end{align}

where \(w\) is the probability that the burst size and frequency are determined by \(1/\beta_1\) and \(\alpha_1\). Such a model, in which a variable is distributed as a linear combination of distributions, is called a mixture model.

For this case, we can write this likelihood more concisely as

\begin{align} n_i \sim w \, \text{NegBinom}(\alpha_1, \beta_1) + (1-w)\,\text{NegBinom}(\alpha_2, \beta_2)\;\;\forall i. \end{align}
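As a quick check of this parameterization, here is a minimal sketch (with made-up parameter values) of evaluating the mixture PMF with SciPy, whose nbinom distribution is parameterized by \(\alpha\) and \(p = \beta/(1+\beta)\).

import numpy as np
import scipy.stats as st

# Made-up parameter values, for illustration only
alpha1, beta1 = 3.0, 0.2    # lower-expression population
alpha2, beta2 = 5.0, 0.03   # higher-expression population
w = 0.2                     # weight of the first population

n = np.arange(1000)

# Mixture PMF: weighted sum of two Negative Binomial PMFs,
# with SciPy's p parameter given by beta / (1 + beta)
pmf = w * st.nbinom.pmf(n, alpha1, beta1 / (1 + beta1)) + (1 - w) * st.nbinom.pmf(
    n, alpha2, beta2 / (1 + beta2)
)

# Sums to one, up to truncation of the support at n = 999
print(pmf.sum())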

We have to specify priors on \(\alpha_1\), \(\beta_1\), \(\alpha_2\), \(\beta_2\), and \(w\). We can retain the same priors from the previous lesson for the \(\alpha\)’s and \(\beta\)’s, and we will assume a Uniform prior for \(w\). We then have the following model.

\begin{align} &\log_{10} \alpha_i \sim \text{Norm}(0,1) \text{ for } i \in [1, 2] \\[1em] &\log_{10} b_i \sim \text{Norm}(2, 1) \text{ for } i \in [1, 2], \\[1em] &\beta_i = 1/b_i,\\[1em] &w \sim \text{Beta}(1, 1), \\[1em] &n_i \sim w \, \text{NegBinom}(\alpha_1, \beta_1) + (1-w)\,\text{NegBinom}(\alpha_2, \beta_2)\;\;\forall i. \end{align}

Note that since the prior for \(w\) is Uniform (Beta(1, 1)), we do not strictly need to state it: declaring \(w\) with bounds between zero and one enforces the constraint, and the Uniform prior only adds a constant to the log posterior. We nonetheless include the w ~ beta(1.0, 1.0) statement in the Stan code below for clarity.

Coding up a mixture model

There are a few considerations for coding up a mixture model that also introduce Stan syntax. Importantly, under the hood, Stan uses the log posterior, as do almost all samplers, when sampling out of the posterior. Dealing with a mixture model presents a unique challenge for computing the log likelihood (which is one of the summands of the log posterior). Consider the log likelihood in the present example.

\begin{align} \ln f(n\mid \alpha_1, \alpha_2, \beta_1, \beta_2, w) = \ln(w\,a_1 + (1-w)a_2), \end{align}

where

\begin{align} a_i = \frac{\Gamma(n + \alpha_i)}{n!\,\Gamma(\alpha_i)}\,\left(\frac{\beta_i}{1+\beta_i}\right)^{\alpha_i}\left(\frac{1}{1+\beta_i}\right)^{n}. \end{align}

While the logarithm of a product conveniently splits into a sum, the logarithm of a sum does not split. If we compute the sum directly, we will get serious underflow errors for parameters for which the terms \(a_1\) or \(a_2\) are small. To compute this in a numerically stable way, we need the log-sum-exp trick. Fortunately, Stan has a built-in function to compute the contribution of a mixture to the log posterior, the log_mix function.

To include the result of log_mix in the log posterior, we add it to target. In Stan, the keyword target is a special variable that holds the running sum of the contributions to the log posterior. When you make statements like theta ~ normal(0.0, 1.0), Stan adds the appropriate terms to target under the hood. For mixture models, we need to add to target explicitly. More generally, you can add any terms to target, and Stan treats them as part of the log posterior. The Stan code below implements this for the mixture model.
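Before looking at the Stan code, it may help to see concretely what log_mix computes. Below is a minimal Python sketch (not part of the model; the parameter values are made up and the count is deliberately extreme) showing that the naive computation underflows, while the log-sum-exp computation, which is what log_mix performs, does not.

import numpy as np
import scipy.special
import scipy.stats as st

# Made-up parameter values; the count is chosen to be extreme so that
# the naive computation underflows
w = 0.2
alpha = np.array([3.0, 5.0])
beta_ = np.array([0.2, 0.03])
n_val = 30_000

# Log PMFs of the two mixture components
lp1 = st.nbinom.logpmf(n_val, alpha[0], beta_[0] / (1 + beta_[0]))
lp2 = st.nbinom.logpmf(n_val, alpha[1], beta_[1] / (1 + beta_[1]))

# Naive computation: exponentiate, add, take the log. Both exponentials
# underflow to zero here, so NumPy warns and returns -inf.
naive = np.log(w * np.exp(lp1) + (1 - w) * np.exp(lp2))

# Log-sum-exp trick, which is what Stan's log_mix computes:
# log_mix(w, lp1, lp2) = log_sum_exp(log(w) + lp1, log(1 - w) + lp2)
stable = scipy.special.logsumexp([np.log(w) + lp1, np.log(1 - w) + lp2])

print(naive, stable)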

data {
  int<lower=0> N;
  array[N] int<lower=0> n;
}


parameters {
  vector[2] log10_alpha;
  vector[2] log10_b;
  real<lower=0, upper=1> w;
}


transformed parameters {
  vector[2] alpha = 10 .^ log10_alpha;
  vector[2] b = 10 .^ log10_b;
  vector[2] beta_ = 1.0 ./ b;
}


model {
  // Priors
  log10_alpha ~ normal(0, 1);
  log10_b ~ normal(2, 1);
  w ~ beta(1.0, 1.0);

  // Likelihood
  for (n_val in n) {
    target += log_mix(
      w,
      neg_binomial_lupmf(n_val | alpha[1], beta_[1]),
      neg_binomial_lupmf(n_val | alpha[2], beta_[2])
    );
  }
}

In addition to the log_mix function, there is some more new syntax.

  • To add the contribution of the Negative Binomial PMF to the log posterior, we use neg_binomial_lupmf. Every discrete Stan distribution has a function <distribution>_lpmf that gives the value of the log PMF, and every continuous distribution has a corresponding <distribution>_lpdf. They also have functions <distribution>_lupmf and <distribution>_lupdf, which are the unnormalized versions, dropping additive constants that do not depend on the parameters. (Which terms count as constant depends on which variables are declared as parameters and which as data; Stan handles this automatically.) The arguments of the function are written as you would on paper, with a bar (|) signifying conditioning on the parameters.

  • In the above, I have specified alpha, b, and beta_ as vectors. A vector is an array that behaves like a column vector, which means you can do matrix operations with it. It also allows you to do element-wise operations. Notice that I have written expressions like

vector[2] b = 10 .^ log10_b;

and

vector[2] beta_ = 1.0 ./ b;

The latter means that beta_ is a 2-vector and that each element is given by the inverse of the corresponding element in b. The ./ operator accomplishes this (note the dot in front of the slash). In general, preceding an operator with a . indicates that the operation should be done elementwise, as is the case for the .^ operator above.

  • Note also that Stan is smart enough to know that when I give log10_alpha a prior, it applies that prior independently to each element of the vector log10_alpha. The same is of course true for log10_b.

  • To specify a lower and upper bound, we use the <lower=0, upper=1> syntax.

  • Note the for loop construction. for (n_val in n) means iterate over each entry in n, yielding its value as n_val within the loop. The contents of the loop are enclosed in braces. Note that we could equivalently have written the for loop as

// Likelihood
for (i in 1:N) {
  target += log_mix(
    w,
    neg_binomial_lupmf(n[i] | alpha[1], beta_[1]),
    neg_binomial_lupmf(n[i] | alpha[2], beta_[2])
  );
}

Here, we are looping over integers starting at 1 and ending at \(N\), inclusive. You can think of the 1:N syntax as being kind of like Python's range() function. However, unlike Python's range(), Stan's range is inclusive of the end value, as the short Python comparison below illustrates.
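Here is that comparison in plain Python/NumPy (not Stan). The array expressions mirror Stan's elementwise .^ and ./ operators, and the range() calls show that Stan's 1:N corresponds to range(1, N + 1).

import numpy as np

# NumPy arrays do elementwise arithmetic, like Stan's dotted operators
log10_b = np.array([0.8, 1.5])
b = 10 ** log10_b       # analogous to Stan's 10 .^ log10_b
beta_ = 1.0 / b         # analogous to Stan's 1.0 ./ b

# Python's range() excludes the stop value, so Stan's inclusive 1:N
# corresponds to range(1, N + 1)
N = 5
print(list(range(1, N)))      # [1, 2, 3, 4]
print(list(range(1, N + 1)))  # [1, 2, 3, 4, 5]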

Now that we have our model set up, let’s compile and sample from it! I will use the seed kwarg to set the seed for the random number generator to ensure that I always get the same result for illustrative purposes.

[4]:
data = {'n': n, 'N': len(n)}

with bebi103.stan.disable_logging():
    sm = cmdstanpy.CmdStanModel(stan_file='mixture.stan')
    samples = sm.sample(
        data=data,
        seed=3252,
        chains=4,
        iter_sampling=1000,
    )

samples = az.from_cmdstanpy(posterior=samples)

Parsing the output

Let’s look at the results.

[5]:
samples
[5]:
arviz.InferenceData
    • <xarray.Dataset> Size: 360kB
      Dimensions:            (chain: 4, draw: 1000, log10_alpha_dim_0: 2,
                              log10_b_dim_0: 2, alpha_dim_0: 2, b_dim_0: 2,
                              beta__dim_0: 2)
      Coordinates:
        * chain              (chain) int64 32B 0 1 2 3
        * draw               (draw) int64 8kB 0 1 2 3 4 5 ... 994 995 996 997 998 999
        * log10_alpha_dim_0  (log10_alpha_dim_0) int64 16B 0 1
        * log10_b_dim_0      (log10_b_dim_0) int64 16B 0 1
        * alpha_dim_0        (alpha_dim_0) int64 16B 0 1
        * b_dim_0            (b_dim_0) int64 16B 0 1
        * beta__dim_0        (beta__dim_0) int64 16B 0 1
      Data variables:
          log10_alpha        (chain, draw, log10_alpha_dim_0) float64 64kB 0.3836 ....
          log10_b            (chain, draw, log10_b_dim_0) float64 64kB 0.8141 ... 0...
          w                  (chain, draw) float64 32kB 0.2084 0.1713 ... 0.8397
          alpha              (chain, draw, alpha_dim_0) float64 64kB 2.419 ... 1.769
          b                  (chain, draw, b_dim_0) float64 64kB 6.517 29.36 ... 8.078
          beta_              (chain, draw, beta__dim_0) float64 64kB 0.1534 ... 0.1238
      Attributes:
          created_at:                 2025-02-01T23:13:06.481192+00:00
          arviz_version:              0.20.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.5

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          lp               (chain, draw) float64 32kB -1.597e+03 ... -1.598e+03
          acceptance_rate  (chain, draw) float64 32kB 0.9272 0.9948 ... 0.9495 0.9129
          step_size        (chain, draw) float64 32kB 0.09361 0.09361 ... 0.1234
          tree_depth       (chain, draw) int64 32kB 3 4 5 5 5 4 2 5 ... 5 4 3 4 5 4 5
          n_steps          (chain, draw) int64 32kB 15 23 43 31 63 ... 7 23 31 23 51
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 1.598e+03 ... 1.601e+03
      Attributes:
          created_at:                 2025-02-01T23:13:06.483368+00:00
          arviz_version:              0.20.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.5

The arviz.InferenceData object has two xarray Datasets, posterior and sample_stats. We will work extensively with sample_stats in future lessons (and find that it is crucial for checking your sampling!), but for now we will focus on the posterior Dataset, which holds the samples. You can look at it by expanding the posterior view above.

Note now that the posterior output is a bit more complex. This is because the parameters log10_alpha, log10_b, alpha, b, and beta_ are vector-valued. The entries in each vector are indexed by the corresponding dimension coordinate, e.g., alpha_dim_0 for alpha. The samples of these vector-valued parameters are therefore three-dimensional arrays. The first two dimensions are the chain and draw, as we have already seen, and the third specifies the element of the vector.

Samples are selected using xarray data selection, which you can read about in the xarray docs. As an example, to access sample number 478 from chain 2, do the following.

[6]:
samples.posterior.loc[dict(chain=2, draw=478)]
[6]:
<xarray.Dataset> Size: 184B
Dimensions:            (log10_alpha_dim_0: 2, log10_b_dim_0: 2, alpha_dim_0: 2,
                        b_dim_0: 2, beta__dim_0: 2)
Coordinates:
    chain              int64 8B 2
    draw               int64 8B 478
  * log10_alpha_dim_0  (log10_alpha_dim_0) int64 16B 0 1
  * log10_b_dim_0      (log10_b_dim_0) int64 16B 0 1
  * alpha_dim_0        (alpha_dim_0) int64 16B 0 1
  * b_dim_0            (b_dim_0) int64 16B 0 1
  * beta__dim_0        (beta__dim_0) int64 16B 0 1
Data variables:
    log10_alpha        (log10_alpha_dim_0) float64 16B 0.4777 0.6959
    log10_b            (log10_b_dim_0) float64 16B 0.6956 1.518
    w                  float64 8B 0.1779
    alpha              (alpha_dim_0) float64 16B 3.004 4.964
    b                  (b_dim_0) float64 16B 4.962 33.0
    beta_              (beta__dim_0) float64 16B 0.2015 0.0303
Attributes:
    created_at:                 2025-02-01T23:13:06.481192+00:00
    arviz_version:              0.20.0
    inference_library:          cmdstanpy
    inference_library_version:  1.2.5

If you wanted only the value of \(\alpha_2\) from that sample, you would do (noting the zero-based indexing of xarrays):

[7]:
samples.posterior["alpha"].loc[dict(chain=2, draw=478, alpha_dim_0=1)]
[7]:
<xarray.DataArray 'alpha' ()> Size: 8B
array(4.96436)
Coordinates:
    chain        int64 8B 2
    draw         int64 8B 478
    alpha_dim_0  int64 8B 1

If you wanted it as a scalar, you would use the values attribute and convert the resulting zero-dimension Numpy array to a float.

[8]:
float(samples.posterior["alpha"].loc[dict(chain=2, draw=478, alpha_dim_0=1)].values)
[8]:
4.96436

While xarray data types are quite powerful, it is often more convenient to work with the more familiar data frames. The .to_dataframe() method of xarrays is quite useful for this purpose. It converts the xarray to a Pandas data frame. This, in turn, can be converted to a Polars data frame using pl.from_pandas().

[9]:
pl.from_pandas(samples.posterior.to_dataframe(), include_index=True).head()
[9]:
shape: (5, 13)
chain  draw  log10_alpha_dim_0  log10_b_dim_0  alpha_dim_0  b_dim_0  beta__dim_0  log10_alpha  log10_b   w         alpha    b        beta_
i64    i64   i64                i64            i64          i64      i64          f64          f64       f64       f64      f64      f64
0      0     0                  0              0            0        0            0.383558     0.814061  0.208442  2.41857  6.5172   0.15344
0      0     0                  0              0            0        1            0.383558     0.814061  0.208442  2.41857  6.5172   0.0340593
0      0     0                  0              0            1        0            0.383558     0.814061  0.208442  2.41857  29.3605  0.15344
0      0     0                  0              0            1        1            0.383558     0.814061  0.208442  2.41857  29.3605  0.0340593
0      0     0                  0              1            0        0            0.383558     0.814061  0.208442  5.75639  6.5172   0.15344

As we might expect, the indexes of the xarray become indexes of the data frame. This can be cumbersome to work with when plotting because we typically wish to plot marginal distributions, e.g., of element 1 of alpha. A more convenient form might be a data frame where each element of a vector- (or matrix-) valued parameter is a column. This is accomplished using the bebi103.stan.arviz_to_dataframe() function.

[11]:
df_mcmc = bebi103.stan.arviz_to_dataframe(samples)

df_mcmc.head()
[11]:
shape: (5, 14)
log10_alpha[0]  log10_alpha[1]  log10_b[0]  log10_b[1]  w         alpha[0]  alpha[1]  b[0]     b[1]     beta_[0]  beta_[1]   chain__  draw__  diverging__
f64             f64             f64         f64         f64       f64       f64       f64      f64      f64       f64        i64      i64     bool
0.383558        0.76015         0.814061    1.46776     0.208442  2.41857   5.75639   6.5172   29.3605  0.15344   0.0340593  0        0       false
0.318055        0.776362        0.91696     1.45494     0.17129   2.07996   5.97533   8.25962  28.5065  0.121071  0.0350797  0        1       false
0.632581        0.698278        0.487664    1.5125      0.162889  4.29122   4.99203   3.07372  32.5459  0.325339  0.0307258  0        2       false
0.455749        0.708154        0.687093    1.49974     0.132403  2.85594   5.10686   4.86511  31.6035  0.205545  0.0316421  0        3       false
0.301013        0.712702        0.869132    1.49752     0.13734   1.99992   5.16062   7.3983   31.4425  0.135166  0.0318041  0        4       false

This data frame is easier to work with for plotting. Note that there are added columns for the chain number ('chain__') and draw ('draw__'). These have trailing double underscores in their names to avoid clashes with parameter names. The function also includes some diagnostics from the sample_stats group of the ArviZ InferenceData, in this case the diverging__ column; we will discuss their meaning in future lessons.

Plotting the samples

To plot our samples, it is convenient to make scatter plots of samples for each pair of parameters. This means we plot all marginalized posteriors that contain two parameters. We can also plot marginalized posteriors where we only have one variable. This is conveniently done using a corner plot, as implemented in bebi103.viz.corner().

[12]:
# Parameters we want to plot with pretty names for display
pars = [("alpha[0]", "α₁"), ("alpha[1]", "α₂"), ("b[0]", "b₁"), ("b[1]", "b₂"), "w"]

bokeh.io.show(bebi103.viz.corner(samples, parameters=pars))

This looks peculiar. We see a strong bimodality in the posterior. The cause of this can be revealed if we color the glyphs by the chain ID, accomplished using the color_by_chain kwarg. We only need to show one plot, so I will show \(w\) versus \(\alpha_2\).

[13]:
bokeh.io.show(
    bebi103.viz.corner(
        samples, parameters=[("alpha[1]", "α₂"), "w"], color_by_chain=True
    )
)

We see that three of the chains (colored blue, red, and green) are centered around w ≈ 0.8, while the other (colored orange) is around w ≈ 0.2. Note that these two values of w sum to one. We have just uncovered a nonidentifiable model. A nonidentifiable model is a model for which we cannot unambiguously determine the parameter values. That is, two or more parameter sets are observationally equivalent.
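We can check this directly from the data frame of samples by computing the mean of \(w\) within each chain; one chain sits near 0.2 and the others near 0.8.

# Mean of w for each chain
df_mcmc.group_by("chain__").agg(pl.col("w").mean().alias("mean_w")).sort("chain__")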

Label switching

There are many reasons why a model may be nonidentifiable. In this case, we are seeing a manifestation of label switching (see the Stan Manual). Before launching into label switching in this particular case, I emphasize that this is certainly not the only way models can be nonidentifiable, and I am presenting this as a case study on how to deal with a particular kind of nonidentifiability. If you are seeing nonidentifiability in your model, do not automatically assume that it is because of label switching. I am also going through this case study to demonstrate that even in fairly simple models (in this case, two cell populations, each characterized by their burst size and frequency), devilish problems can arise when trying to do inference. You need to be vigilant.

In this mixture model, it is arbitrary which \((\alpha, b)\) pair we label as \((\alpha_1, b_1)\) or \((\alpha_2, b_2)\). We can switch the labels, and also change \(w\) to \(1-w\), and we have exactly the same posterior probability. To demonstrate that this is the case, I will generate the same grid of plots as above, but switch the appropriate labels and convert \(w\) to \(1-w\) for every \(w < 0.5\). (Note that this will not in general work, especially if the different modes from label switching overlap, and is not a good idea for analysis; I’m just doing it here to illustrate how label switching leads to nonidentifiability.)

[14]:
# Perform the label switch
params = ["alpha[0]", "alpha[1]", "b[0]", "b[1]", "w"]
switch = df_mcmc.filter(pl.col("w") > 0.5)[params]
switch = switch.rename(
    {
        "b[0]": "b[1]",
        "b[1]": "b[0]",
        "alpha[1]": "alpha[0]",
        "alpha[0]": "alpha[1]",
    }
)
switch = switch.with_columns((1 - pl.col("w")).alias('w'))

df_switch = pl.concat([df_mcmc.filter(pl.col("w") < 0.5)[params], switch], how='diagonal')

# Make corner plot
bokeh.io.show(bebi103.viz.corner(df_switch, parameters=pars))

We see that if we fix the label switching, the posterior is indeed unimodal. So, making an identifiable model in this case means that we have to deal with the label switching problem. There are many approaches to doing this, and you can see a very detailed discussion about dealing with label switching and other problems associated with mixture models in this blog post by Michael Betancourt.

In looking at the corner plot, we see a very strong, non-Normal correlation between \(\alpha_1\) and \(b_1\) and also between \(\alpha_2\) and \(b_2\). There is clearly some structure in the posterior that would be impossible to discover using the MAP alone.

Initializing walkers

We will now take a strategy suggested in earlier versions of the Stan manual. We noted earlier that the chains tended to stay on a single mode. We will therefore initialize the chains to start near only one mode. This, like any of the fixes for mixture models, does not guarantee good sampling (and we will discuss more diagnostics for good sampling in a future lesson), but in practice it can work, so we will try it.

To determine where to start the chains, we will pick one chain and start all of the chains at the mean of its samples. First, we need to compute the parameter means for chain 1.

[15]:
# Compute mean of parameters for chain 1
param_means = (
    df_mcmc
    .filter(pl.col('chain__') == 1)
    .select(cs.exclude('chain__', 'draw__', 'diverging__'))
    .mean()
)

# Take a look
param_means
[15]:
shape: (1, 11)
log10_alpha[0]  log10_alpha[1]  log10_b[0]  log10_b[1]  w         alpha[0]  alpha[1]  b[0]       b[1]      beta_[0]  beta_[1]
f64             f64             f64         f64         f64       f64       f64       f64        f64       f64       f64
0.7273          0.450045        1.485876    0.756378    0.827043  5.389933  3.047778  30.869653  6.634703  0.032947  0.204229

Now that we have the means for chain 1, we can use them as starting points for Stan's sampler. An easy way to do that is to pass a dictionary of starting points using the inits kwarg of sm.sample(). Note that vector-valued parameters (such as log10_alpha and log10_b) need to be specified as vectors, which means you can use a Python list.

Before constructing this, I pause to note that by default Stan chooses starting values for the chains by drawing random numbers uniformly on the interval [-2, 2]. For constrained parameters, the draw is made on the unconstrained (transformed) scale and then transformed back to satisfy the constraint. If the posterior has very little probability mass in the region this covers, warmup may take longer, so it is sometimes advisable to provide your own starting points. The method of initializing the chains we are using here is meant to deal with the label switching problem, but I am also using it to demonstrate how to provide starting points for chains.
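As an aside, here is a rough sketch of what that default means for a parameter like \(w\) constrained to the interval (0, 1): Stan draws the unconstrained value uniformly on [-2, 2] and maps it back through the inverse logit, so default initial values of \(w\) land roughly between 0.12 and 0.88.

import scipy.special

# Default Stan inits are uniform on [-2, 2] on the unconstrained scale.
# For a (0, 1)-constrained parameter, the constrained value is the
# inverse logit of the unconstrained draw.
print(scipy.special.expit(-2.0), scipy.special.expit(2.0))  # ≈ 0.119, 0.881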

[16]:
inits = {
    "log10_alpha": [param_means["log10_alpha[0]"].item(), param_means["log10_alpha[1]"].item()],
    "log10_b": [param_means["log10_b[0]"].item(), param_means["log10_b[1]"].item()],
    "w": param_means["w"].item(),
}

with bebi103.stan.disable_logging():
    samples = sm.sample(data=data, inits=inits, seed=3252)

samples = az.from_cmdstanpy(posterior=samples)

Let’s look at the corner plot again to see how this worked.

[17]:
bokeh.io.show(bebi103.viz.corner(samples, parameters=pars))

In this case, the walker initialization solved the label switching problem. We can do our ad hoc model assessment by plotting the theoretical CDFs.

[18]:
# Make ECDF
p = iqplot.ecdf(data=n, x_axis_label="mRNA count")

# x-values and samples to use in plot
x = np.arange(int(1.05 * n.max()))
alpha0s = samples.posterior["alpha"].values[:, :, 0].flatten()[::40]
alpha1s = samples.posterior["alpha"].values[:, :, 1].flatten()[::40]
beta0s = samples.posterior["beta_"].values[:, :, 0].flatten()[::40]
beta1s = samples.posterior["beta_"].values[:, :, 1].flatten()[::40]
ws = samples.posterior["w"].values.flatten()[::40]

for alpha0, alpha1, beta0, beta1, w in zip(alpha0s, alpha1s, beta0s, beta1s, ws):
    y = w * st.nbinom.cdf(x, alpha0, beta0 / (1 + beta0))
    y += (1 - w) * st.nbinom.cdf(x, alpha1, beta1 / (1 + beta1))
    x_plot, y_plot = bebi103.viz.cdf_to_staircase(x, y)
    p.line(x_plot, y_plot, line_width=0.5, color="orange", level="underlay", alpha=0.2)

bokeh.io.show(p)

The mixture model performs much better than the single Negative-Binomial model.

Conclusions

You have learned how to use Stan to sample out of a posterior distribution with MCMC. I hope it is evident how convenient and powerful this is. I also hope you have an understanding of how fragile statistical modeling can be, as you saw with a label switching-based nonidentifiability.

We have looked at some visualizations of MCMC results in this lesson, and in coming lessons, we will take a closer look at how to visualize and report MCMC results.

[19]:
bebi103.stan.clean_cmdstan()

Computing environment

[20]:
%load_ext watermark
%watermark -v -p numpy,scipy,polars,cmdstanpy,arviz,bokeh,iqplot,bebi103,jupyterlab
print("cmdstan   :", bebi103.stan.cmdstan_version())
Python implementation: CPython
Python version       : 3.12.8
IPython version      : 8.27.0

numpy     : 1.26.4
scipy     : 1.14.1
polars    : 1.17.1
cmdstanpy : 1.2.5
arviz     : 0.20.0
bokeh     : 3.6.0
iqplot    : 0.3.7
bebi103   : 0.1.26
jupyterlab: 4.2.6

cmdstan   : 2.36.0