Integration with MCMCChains.jl

MCMC packages like Turing often produce results in the form of an MCMCChains.Chains object. PairPlots.jl has built-in support for plotting these chains.

Note

The integration between PairPlots and MCMCChains only works on Julia 1.9 and above. On earlier versions, you can work around this by converting the chain to a DataFrame first: pairplot(DataFrame(chn)).
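A minimal sketch of that workaround, assuming DataFrames.jl is installed. The column selection is an extra precaution rather than part of the original note: the table conversion may add bookkeeping columns (such as iteration and chain indices) that you probably do not want in the corner plot.

```julia
# Workaround sketch for Julia versions older than 1.9: convert the chain
# to a DataFrame and plot that instead of the Chains object itself.
using MCMCChains, DataFrames, CairoMakie, PairPlots

# Small stand-in chain: 1000 draws of two parameters in a single chain.
chn = Chains(randn(1000, 2, 1), [:a, :b])

df = DataFrame(chn)

# Keep only the parameter columns, dropping any bookkeeping columns
# (e.g. iteration or chain index) that the conversion may have added.
params = df[:, [:a, :b]]

fig = pairplot(params)
```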

Plotting chains

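For this example, we'll use the following code to generate a Chain of random draws: each of the five parameters is scaled by a different factor (1 to 5), and each of the three chains by a further factor (1, 2 or 3). In real code, you would likely receive a chain as the result of sampling from a model.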

using MCMCChains
chn1 = Chains(randn(10000, 5, 3) .* [1 2 3 4 5] .* [1;;;2;;;3], [:a, :b, :c, :d, :e])
Chains MCMC chain (10000×5×3 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 3
Samples per chain = 10000
parameters        = a, b, c, d, e

Summary Statistics
  parameters      mean       std      mcse     ess_bulk    ess_tail      rhat  ⋯
      Symbol   Float64   Float64   Float64      Float64     Float64   Float64  ⋯

           a    0.0043    2.1531    0.0125   29550.7875   1208.2502    1.1322  ⋯
           b   -0.0101    4.2826    0.0249   29586.1705   1286.6040    1.1226  ⋯
           c    0.0295    6.4946    0.0385   28454.8904   1233.1812    1.1280  ⋯
           d    0.0804    8.5783    0.0494   30242.1330   1221.9348    1.1255  ⋯
           e   -0.0268   10.7756    0.0625   29744.9271   1301.6755    1.1239  ⋯
                                                                1 column omitted

Quantiles
  parameters       2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol    Float64   Float64   Float64   Float64   Float64 

           a    -4.5481   -1.1607    0.0074    1.1660    4.5881
           b    -9.0943   -2.3255   -0.0114    2.3283    9.0090
           c   -13.6051   -3.5136    0.0160    3.4791   13.9290
           d   -18.1835   -4.5789    0.0602    4.6961   18.2168
           e   -23.0803   -5.8993   -0.0269    5.8061   22.7915
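The chn1 construction relies on broadcasting: [1 2 3 4 5] is a 1×5 row that scales each parameter, and [1;;;2;;;3] is a 1×1×3 array that scales each chain. A minimal sketch with plain arrays and the Statistics standard library shows what that does to the spread. Pooled over all chains, a parameter with base scale s has a standard deviation of about s·√((1² + 2² + 3²)/3) ≈ 2.16·s, in line with the std column above (2.15 for a, 10.78 for e), while within each chain the spread grows as 1 : 2 : 3.

```julia
using Statistics

param_scale = [1 2 3 4 5]   # 1×5 row: one factor per parameter (a to e)
chain_scale = [1;;;2;;;3]   # 1×1×3 array: one factor per chain (needs Julia ≥ 1.7)

# Broadcasting applies both factors to a 10000×5×3 array of standard normal
# draws, the same construction as the raw data behind chn1.
draws = randn(10000, 5, 3) .* param_scale .* chain_scale

# Pooled over all three chains, the std of a parameter with base scale s
# should be close to s * sqrt((1^2 + 2^2 + 3^2) / 3) ≈ 2.16 * s.
expected_factor = sqrt((1^2 + 2^2 + 3^2) / 3)
pooled_std = [std(vec(draws[:, p, :])) for p in 1:5]

# Within each chain the spread grows as 1 : 2 : 3, so the ratio of
# chain 3's std to chain 1's std should be about 3 for every parameter.
per_chain_std = std(draws; dims=1)   # 1×5×3: one value per (parameter, chain)
ratios = per_chain_std[1, :, 3] ./ per_chain_std[1, :, 1]

println("pooled std: ", round.(pooled_std; digits=2))
println("expected:   ", round.(expected_factor .* (1:5); digits=2))
println("chain 3 / chain 1: ", round.(ratios; digits=2))
```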

You can plot the results from all chains in the Chains object:

using CairoMakie, PairPlots

pairplot(chn1)

The labels are taken from the column names of the chains. You can modify them by passing in a dictionary mapping column names to strings, LaTeX strings, or Makie rich text objects.
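A sketch of relabeling, assuming the dictionary is passed through a labels keyword of pairplot (check the PairPlots.jl documentation for your version). L"..." comes from LaTeXStrings and is re-exported by Makie; rich and subscript are Makie's rich-text helpers. The label texts themselves are illustrative.

```julia
using MCMCChains, CairoMakie, PairPlots

# Same example chain as above.
chn1 = Chains(randn(10000, 5, 3) .* [1 2 3 4 5] .* [1;;;2;;;3], [:a, :b, :c, :d, :e])

# Map parameter names (Symbols) to display labels.
labels = Dict(
    :a => "parameter a",               # plain string
    :b => L"\beta",                    # LaTeX string (L"..." re-exported by Makie)
    :c => rich("c", subscript("0")),   # Makie rich text
    :d => L"\delta",
    :e => "e",
)

# Assumption: the dictionary is accepted through a `labels` keyword.
fig = pairplot(chn1; labels = labels)
```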

Plotting individual chains separately

If you have multiple parallel chains and want to plot them in different colors, you can pass each one to pairplot as a separate argument:

pairplot(chn1[:,:,1], chn1[:,:,2], chn1[:,:,3])

You can also label and color each series independently:

# First three colors of Makie's Wong palette, at 50% opacity
c1 = Makie.wong_colors(0.5)[1]
c2 = Makie.wong_colors(0.5)[2]
c3 = Makie.wong_colors(0.5)[3]

pairplot(
    PairPlots.Series(chn1[:,:,1], label="chain 1", color=c1, strokecolor=c1),
    PairPlots.Series(chn1[:,:,2], label="chain 2", color=c2, strokecolor=c2),
    PairPlots.Series(chn1[:,:,3], label="chain 3", color=c3, strokecolor=c3),
)

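If your chains are well converged, the different series should overlap closely. Here they do not: the code that builds chn1 scales the second and third chains by factors of 2 and 3, which is also why rhat is about 1.13 rather than close to 1.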
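Writing one Series per chain by hand gets tedious when there are many chains. A sketch that builds them in a comprehension and splats them into pairplot, assuming there are no more chains than colors in the Wong palette (seven); the "chain $i" labels are illustrative.

```julia
using MCMCChains, CairoMakie, PairPlots

# Same example chain as above: 3 chains of 5 parameters.
chn1 = Chains(randn(10000, 5, 3) .* [1 2 3 4 5] .* [1;;;2;;;3], [:a, :b, :c, :d, :e])

colors = Makie.wong_colors(0.5)   # Wong palette (7 colors) at 50% opacity
n_chains = size(chn1)[3]          # third dimension of a Chains object indexes chains

# One labeled Series per chain, with matching fill and stroke colors.
series = [
    PairPlots.Series(chn1[:, :, i], label = "chain $i", color = colors[i], strokecolor = colors[i])
    for i in 1:n_chains
]

fig = pairplot(series...)
```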

Comparing the results of two simulations

You may want to compare the results of two simulations. Consider the following chains:

chn2 = Chains(randn(10000, 5, 1) .* [1 2 3 4 5], [:a, :b, :c, :d, :e])  # one chain, five parameters
chn3 = Chains(randn(10000, 4, 1) .* [5 4 2 1], [:a, :b, :d, :e])         # one chain, no c, different scales
Chains MCMC chain (10000×4×1 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 1
Samples per chain = 10000
parameters        = a, b, d, e

Summary Statistics
  parameters      mean       std      mcse     ess_bulk     ess_tail      rhat ⋯
      Symbol   Float64   Float64   Float64      Float64      Float64   Float64 ⋯

           a    0.0349    4.9962    0.0503    9855.1061   10004.1962    1.0000 ⋯
           b    0.0918    4.0185    0.0407    9712.4930    9464.4058    0.9999 ⋯
           d   -0.0086    2.0229    0.0200   10248.0100    9590.9378    1.0001 ⋯
           e    0.0005    1.0036    0.0103    9561.1911    9717.5356    1.0000 ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           a   -9.7727   -3.3430    0.1216    3.3540   10.0076
           b   -7.7217   -2.6732    0.1085    2.9130    7.9032
           d   -3.9749   -1.3598    0.0006    1.3415    3.9466
           e   -1.9746   -0.6730    0.0079    0.6892    1.9618

Just pass them all to pairplot:

pairplot(chn2, chn3)

Note that the parameters of the chains do not have to match exactly: here, chn2 has an additional parameter, c, that is not present in chn3.
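A sketch for checking which parameters two runs share and for naming each simulation in the legend with the PairPlots.Series wrapper used earlier. It assumes names(chn) returns a chain's parameter names as Symbols; the "simulation 1" and "simulation 2" labels are illustrative.

```julia
using MCMCChains, CairoMakie, PairPlots

chn2 = Chains(randn(10000, 5, 1) .* [1 2 3 4 5], [:a, :b, :c, :d, :e])
chn3 = Chains(randn(10000, 4, 1) .* [5 4 2 1], [:a, :b, :d, :e])

# Parameters present in both runs, and those only chn2 has.
shared = intersect(names(chn2), names(chn3))
only_in_chn2 = setdiff(names(chn2), names(chn3))
println("shared: ", shared, "; only in chn2: ", only_in_chn2)

# Label each simulation so the legend distinguishes them.
fig = pairplot(
    PairPlots.Series(chn2, label = "simulation 1"),
    PairPlots.Series(chn3, label = "simulation 2"),
)
```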