
Integration with MCMCChains.jl

MCMC packages like Turing often produce results in the form of an MCMCChains.Chain. There is special support in PairPlots.jl for plotting these chains.

Note

The integration between PairPlots and MCMCChains only works on Julia 1.9 and above. On earlier versions, you can work around this by converting the chain to a DataFrame first: pairplot(DataFrame(chn)).
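Here is a minimal sketch of that workaround, assuming DataFrames.jl is installed. Converting a chain with DataFrame generally produces bookkeeping columns (such as iteration and chain) alongside the parameters, so this sketch keeps only the parameter columns before plotting. The column selection is an illustrative choice, not something PairPlots requires.

```julia
using MCMCChains, DataFrames, CairoMakie, PairPlots

# A small illustrative chain: 1000 iterations, two parameters, one chain.
chn = Chains(randn(1000, 2, 1), [:a, :b])

# Convert through the Tables.jl interface; besides the parameters, the
# result may contain bookkeeping columns such as :iteration and :chain.
df = DataFrame(chn)

# Keep only the parameter columns so that only they are plotted.
df_params = select(df, [:a, :b])

fig = pairplot(df_params)
```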

Plotting chains

For this example, we'll use the following code to generate a Chain. In real code, you would more likely obtain a chain by sampling from a model.

julia
using MCMCChains

chn1 = Chains(randn(10000, 5, 3) .* [1 2 3 4 5] .* [1;;;2;;;3], [:a, :b, :c, :d, :e])
Chains MCMC chain (10000×5×3 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 3
Samples per chain = 10000
parameters        = a, b, c, d, e

Summary Statistics
  parameters      mean       std      mcse     ess_bulk    ess_tail      rhat  ⋯
      Symbol   Float64   Float64   Float64      Float64     Float64   Float64  ⋯

           a    0.0008    2.1638    0.0126   29557.1146   1156.5243    1.1294  ⋯
           b    0.0057    4.3014    0.0249   29950.0994   1237.6512    1.1271  ⋯
           c   -0.0695    6.4569    0.0373   30027.1500   1183.1833    1.1287  ⋯
           d    0.0686    8.6427    0.0497   30248.2697   1208.3892    1.1305  ⋯
           e    0.0643   10.7898    0.0631   29249.3091   1218.1662    1.1302  ⋯
                                                                1 column omitted

Quantiles
  parameters       2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol    Float64   Float64   Float64   Float64   Float64

           a    -4.5493   -1.1685   -0.0176    1.1714    4.5982
           b    -9.0758   -2.3262   -0.0027    2.3355    9.1067
           c   -13.6681   -3.6198   -0.0354    3.4200   13.5601
           d   -18.0243   -4.6067    0.0593    4.6890   18.2898
           e   -22.5568   -5.7641    0.0302    5.9473   22.5938
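The two broadcast factors in the Chains call above control the spread of each parameter ([1 2 3 4 5] is a 1×5 row, one factor per parameter) and of each chain ([1;;;2;;;3] is a 1×1×3 array, one factor per chain). All chains have mean zero, so the pooled variance of the k-th parameter is the average of the per-chain variances, k²(1² + 2² + 3²)/3. This gives a standard deviation of about 2.16k, consistent with the std column above (2.16 for a, 10.79 for e). The sketch below uses only Base and the Statistics standard library to check this numerically:

```julia
using Statistics

param_scale = [1 2 3 4 5]   # 1×5: one factor per parameter (column)
chain_scale = [1;;;2;;;3]   # 1×1×3: one factor per chain (Julia ≥ 1.7 literal syntax)

# Same construction as the chain above: iterations × parameters × chains.
x = randn(10000, 5, 3) .* param_scale .* chain_scale

# Pooled standard deviation of each parameter across all three chains.
pooled_std = [std(vec(x[:, j, :])) for j in 1:5]

# Analytic value: k * sqrt((1^2 + 2^2 + 3^2) / 3) ≈ 2.16k for parameter k.
expected = [k * sqrt((1 + 4 + 9) / 3) for k in 1:5]
```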

You can plot the results from all chains in the Chains object:

julia
using CairoMakie, PairPlots

pairplot(chn1)
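To keep the plot, one option is to capture the Figure that pairplot returns and write it to disk with save, which is available after loading CairoMakie. The file name below is hypothetical.

```julia
using MCMCChains, CairoMakie, PairPlots

chn1 = Chains(randn(10000, 5, 3) .* [1 2 3 4 5] .* [1;;;2;;;3], [:a, :b, :c, :d, :e])

# pairplot returns a Makie Figure, which can be saved like any other figure.
fig = pairplot(chn1)

# Hypothetical output path; the file format is inferred from the extension.
save("chn1_pairplot.png", fig)
```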

The labels are taken from the parameter names of the chain. You can override them by passing the labels keyword argument a dictionary that maps parameter names (as Symbols) to strings, LaTeX strings, or Makie rich text objects.
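Here is a minimal sketch of custom labels, assuming the LaTeXStrings.jl package is available for the L"..." string macro. The label texts are illustrative. Parameters that do not appear in the dictionary are expected to keep their default labels.

```julia
using MCMCChains, CairoMakie, PairPlots
using LaTeXStrings  # provides the L"..." string macro

chn1 = Chains(randn(10000, 5, 3) .* [1 2 3 4 5] .* [1;;;2;;;3], [:a, :b, :c, :d, :e])

# Map parameter names (Symbols) to display labels of different kinds.
labels = Dict(
    :a => "amplitude",                 # plain string
    :b => L"\beta",                    # LaTeX string
    :c => rich("c", subscript("0")),   # Makie rich text
)

fig = pairplot(chn1; labels = labels)
```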

Plotting individual chains separately

If you have multiple parallel chains and want to plot them in different colors, you can pass each one to pairplot:

julia
pairplot(chn1[:,:,1], chn1[:,:,2], chn1[:,:,3])
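If the number of chains is not fixed in advance, you can slice out every chain with a comprehension and splat the results into pairplot. This sketch assumes that size(chn1, 3) returns the number of chains, which is the third dimension of the underlying iterations × parameters × chains array.

```julia
using MCMCChains, CairoMakie, PairPlots

chn1 = Chains(randn(10000, 5, 3) .* [1 2 3 4 5] .* [1;;;2;;;3], [:a, :b, :c, :d, :e])

# One single-chain Chains object per parallel chain.
per_chain = [chn1[:, :, i] for i in 1:size(chn1, 3)]

# Splat them so that each chain becomes its own series, as in the explicit call above.
fig = pairplot(per_chain...)
```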

You can also label and color each series independently:

julia
c1 = Makie.wong_colors(0.5)[1]
c2 = Makie.wong_colors(0.5)[2]
c3 = Makie.wong_colors(0.5)[3]

pairplot(
    PairPlots.Series(chn1[:,:,1], label="chain 1", color=c1, strokecolor=c1),
    PairPlots.Series(chn1[:,:,2], label="chain 2", color=c2, strokecolor=c2),
    PairPlots.Series(chn1[:,:,3], label="chain 3", color=c3, strokecolor=c3),
)
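The three nearly identical Series calls can also be generated in a loop. In this minimal sketch, zip pairs each chain index with a color from Makie.wong_colors(0.5) and stops after the shorter of the two iterables, which here is the three chains.

```julia
using MCMCChains, CairoMakie, PairPlots

chn1 = Chains(randn(10000, 5, 3) .* [1 2 3 4 5] .* [1;;;2;;;3], [:a, :b, :c, :d, :e])

colors = Makie.wong_colors(0.5)  # Wong palette with 50% opacity

# Build one labelled, colored series per chain.
series = [
    PairPlots.Series(chn1[:, :, i], label = "chain $i", color = c, strokecolor = c)
    for (i, c) in zip(1:size(chn1, 3), colors)
]

fig = pairplot(series...)
```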

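If your chains are well converged, the different series should look essentially the same. In this example they do not, because the chains were deliberately generated with different scales (the .* [1;;;2;;;3] factor). This is also why the rhat values in the summary above are well above 1.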

Comparing the results of two simulations

You may want to compare the results of two simulations. Consider the following chains:

julia
chn2 = Chains(randn(10000, 5, 1) .* [1 2 3 4 5], [:a, :b, :c, :d, :e])
chn3 = Chains(randn(10000, 4, 1) .* [5 4 2 1], [:a, :b, :d, :e]);
Chains MCMC chain (10000×4×1 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 1
Samples per chain = 10000
parameters        = a, b, d, e

Summary Statistics
  parameters      mean       std      mcse     ess_bulk    ess_tail      rhat  ⋯
      Symbol   Float64   Float64   Float64      Float64     Float64   Float64  ⋯

           a   -0.0248    4.9761    0.0497   10016.3180   9180.5072    1.0001  ⋯
           b    0.0143    3.9940    0.0396   10205.0321   9320.5488    0.9999  ⋯
           d   -0.0246    2.0101    0.0198   10304.7748   9487.4974    0.9999  ⋯
           e   -0.0176    0.9944    0.0099   10100.6717   9469.0326    0.9999  ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           a   -9.7614   -3.4194   -0.0086    3.3789    9.5334
           b   -7.9027   -2.6637   -0.0022    2.7271    7.8388
           d   -3.9502   -1.4025   -0.0551    1.3290    3.8781
           e   -1.9244   -0.7014   -0.0142    0.6524    1.9375

Just pass them all to pairplot:

julia
pairplot(chn2, chn3)
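To tell the two simulations apart in the legend, you can wrap each one in PairPlots.Series with its own label, just as with the per-chain series above. The label strings here are only illustrative.

```julia
using MCMCChains, CairoMakie, PairPlots

chn2 = Chains(randn(10000, 5, 1) .* [1 2 3 4 5], [:a, :b, :c, :d, :e])
chn3 = Chains(randn(10000, 4, 1) .* [5 4 2 1], [:a, :b, :d, :e])

# Each simulation becomes one labelled series; parameter :c only exists in chn2.
fig = pairplot(
    PairPlots.Series(chn2, label = "simulation 1"),
    PairPlots.Series(chn3, label = "simulation 2"),
)
```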

Note that the parameters of the two chains do not have to match exactly. Here, chn2 has an additional parameter, c, that is not present in chn3.