Skip to content
Snippets Groups Projects
Commit b9479673 authored by Rob Moss's avatar Rob Moss
Browse files

Display correlation coefficients as a bar plot

parent 9067c1d9
Branches docs/scenario-modelling
Tags
No related merge requests found
Pipeline #115594 passed
......@@ -284,6 +284,20 @@ Finally, we can plot the :ref:`correlations between prevented infections and pat
Spearman rank correlation coefficients between (a) the number of prevented infections (relative to the baseline scenario); and (b) model parameters that characterise the pathogen.
We can also display these correlations :ref:`as a bar plot <fig-prevented-corrs-bars>`, which may be more useful for visual comparisons between individual parameters and/or scenarios.
.. admonition:: Function that plots rank correlation bars.
:class: dropdown seealso
.. literalinclude:: ../../tests/test_sirv.py
:pyobject: plot_prevented_infection_correlations_bars
.. _fig-prevented-corrs-bars:
.. figure:: scenario-sirv-correlations-bars.png
:width: 100%
Spearman rank correlation coefficients between (a) the number of prevented infections (relative to the baseline scenario); and (b) model parameters that characterise the pathogen.
Further details
^^^^^^^^^^^^^^^
......
doc/how-to/scenario-sirv-correlations-bars.png

67.9 KiB

"""Test cases for the SIRV model in ``pypfilt.examples.sirv``."""
import io
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
......@@ -36,6 +37,9 @@ def test_sirv_scenarios():
plot_prevented_infection_correlations(
contexts, results, output_dir / 'scenario-sirv-correlations.png'
)
plot_prevented_infection_correlations_bars(
contexts, results, output_dir / 'scenario-sirv-correlations-bars.png'
)
def run_sirv_scenarios():
......@@ -306,6 +310,82 @@ def plot_prevented_infection_correlations(contexts, results, png_file):
fig.savefig(png_file, dpi=300, **png_kwargs())
def plot_prevented_infection_correlations_bars(contexts, results, png_file):
"""
Plot Spearman correlation coefficients between the numbers of prevented
infections in each scenario, and model parameters that characterise the
pathogen (i.e., parameters that are not related to the interventions).
"""
prevented_infs = get_prevented_infections(results, relative_to='Baseline')
samples = {
'R0': contexts['Baseline'].data['prior']['R0'],
'gamma': contexts['Baseline'].data['prior']['gamma'],
}
samples['beta'] = samples['R0'] * samples['gamma']
corrs = np.zeros((len(samples), len(prevented_infs)))
for param_ix, values in enumerate(samples.values()):
for scenario_ix, num_prevented in enumerate(prevented_infs.values()):
coeff = scipy.stats.spearmanr(num_prevented, values).statistic
corrs[param_ix, scenario_ix] = coeff
palette = mpl.colormaps['Pastel2']
bar_width = 0.3
fig, ax = plt.subplots(layout='constrained', figsize=[4.8, 4.8])
for ix, scenario in enumerate(prevented_infs):
# Plot the correlation coefficients for this scenario as bars.
corrcoefs = corrs[:, ix]
bars = ax.barh(
np.arange(len(samples)) + ix * bar_width,
width=corrcoefs,
height=bar_width,
color=palette(ix),
label=get_scenario_label(scenario),
)
# NOTE: we position labels differently for small and large values.
# Labels for large values are placed within the bars, while labels for
# small values are placed adjacent to the bars.
large_corrcoefs = [
f'{corrcoef:0.3f}' if abs(corrcoef) > 0.2 else ''
for corrcoef in corrcoefs
]
ax.bar_label(
bars,
labels=large_corrcoefs,
padding=-40,
)
small_corrcoefs = [
f'{corrcoef:0.3f}' if abs(corrcoef) < 0.2 else ''
for corrcoef in corrcoefs
]
ax.bar_label(
bars,
labels=small_corrcoefs,
padding=8,
)
# Ensure that the x-axis spans [-1, 1] by plotting invisible points (this
# adds some padding to the axis limits).
ax.scatter(x=[-1, 1], y=[0, 0], alpha=0)
ax.set_xticks([-1, -0.5, 0, 0.5, 1.0])
# Display parameter names on the left-hand side of the plot.
ax.set_yticks(
np.arange(len(samples)) + 0.5 * bar_width,
labels=['$R_0$', r'$\gamma$', r'$\beta$'],
fontsize='large',
)
axis_linewidth = mpl.rcParams['axes.linewidth']
ax.axvline(x=0, color='black', linewidth=axis_linewidth)
ax.legend(loc='best')
ax.set_xlabel('Rank correlation with infections prevented')
fig.savefig(png_file, dpi=300, **png_kwargs())
def png_kwargs():
"""
Return a dictionary of keyword arguments for ``Figure.savefig`` that avoid
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment