Skip to content
Snippets Groups Projects
Commit 5ab05586 authored by Rob Moss's avatar Rob Moss
Browse files

Show how to calculate CPRS skill scores

This was an important omission from the Getting Started tutorial, as
identified by Karangupta1994 in the pypfilt JOSS article review process.
parent 61f6f191
Branches
Tags
No related merge requests found
......@@ -9,14 +9,14 @@ In order to fit the simulation model to observations and generate forecasts, we
#. Define an input file for each observation model; and
#. Record summary statistics such as :class:`predictive credible intervals <pypfilt.summary.PredictiveCIs>` for each observation model.
#. Record summary statistics such as :class:`predictive credible intervals <pypfilt.summary.PredictiveCIs>` for each observation model, and :class:`simulated observations <pypfilt.summary.SimulatedObs>` for :math:`z(t)`.
These changes are indicated by the highlighted lines in the following scenario definition:
.. literalinclude:: lorenz63_forecast.toml
:language: toml
:linenos:
:emphasize-lines: 19-21, 25, 29, 33, 35-37
:emphasize-lines: 19-21, 25, 29, 33, 35-39
:name: lorenz63-forecast-toml
:caption: An example scenario for generating forecasts for the Lorenz-63 system.
......
......@@ -30,5 +30,6 @@ This guide shows how to build a simulation model, generate simulated observation
forecasts
plotting
regularisation
scoring
multiple
conclusion
......@@ -47,6 +47,8 @@ observations.y.file = "lorenz63-y.ssv"
observations.z.file = "lorenz63-z.ssv"
summary.tables.forecasts.component = "pypfilt.summary.PredictiveCIs"
summary.tables.forecasts.credible_intervals = [50, 60, 70, 80, 90, 95]
summary.tables.sim_z.component = "pypfilt.summary.SimulatedObs"
summary.tables.sim_z.observation_unit = "z"
[scenario.forecast_regularised]
prior.x = { name = "uniform", args.loc = -5, args.scale = 10 }
......@@ -57,6 +59,8 @@ observations.y.file = "lorenz63-y.ssv"
observations.z.file = "lorenz63-z.ssv"
summary.tables.forecasts.component = "pypfilt.summary.PredictiveCIs"
summary.tables.forecasts.credible_intervals = [50, 60, 70, 80, 90, 95]
summary.tables.sim_z.component = "pypfilt.summary.SimulatedObs"
summary.tables.sim_z.observation_unit = "z"
filter.regularisation.enabled = true
filter.regularisation.bounds.x = { min = -50, max = 50 }
filter.regularisation.bounds.y = { min = -50, max = 50 }
......
doc/getting-started/lorenz63_crps_comparison.png

124 KiB

......@@ -35,6 +35,8 @@ file = "lorenz63-z.ssv"
[summary.tables]
forecasts.component = "pypfilt.summary.PredictiveCIs"
forecasts.credible_intervals = [50, 60, 70, 80, 90, 95]
sim_z.component = "pypfilt.summary.SimulatedObs"
sim_z.observation_unit = "z"
[filter]
particles = 500
......
......@@ -23,6 +23,8 @@ z = { name = "uniform", args.loc = -5, args.scale = 10 }
[summary.tables]
forecasts.component = "pypfilt.summary.PredictiveCIs"
forecasts.credible_intervals = [50, 60, 70, 80, 90, 95]
sim_z.component = "pypfilt.summary.SimulatedObs"
sim_z.observation_unit = "z"
[observations.x]
model = "pypfilt.examples.lorenz.ObsLorenz63"
......
......@@ -39,7 +39,7 @@ In order to use the post-regularised particle filter, we need to:
.. literalinclude:: lorenz63_forecast_regularised.toml
:language: toml
:linenos:
:emphasize-lines: 44, 46-49
:emphasize-lines: 46, 48-51
:name: lorenz63-forecast-regularised-toml
:caption: An example scenario for the Lorenz-63 system that uses the post-regularisation particle filter (**see highlighted lines**).
......
.. _lorenz63-crps:
Forecast performance
====================
It was visually evident from the previous figures that the post-regularised particle filter produced much better forecasts than the bootstrap particle filter.
However, if there wasn't such a clear difference between the forecasts, or if we were interested in evaluating forecast performance over a large number of forecasts, visually inspecting the results would not be a suitable approach.
Instead, we can use a **proper scoring rule** such as the `Continuous Ranked Probability Score <https://otexts.com/fpp3/distaccuracy.html>`__ to evaluate each forecast distribution against the true observations.
We can then measure how much post-regularisation improves the forecast performance by calculating a CRPS skill score **relative to the original forecast**:
.. math::
\mathrm{Skill} = \frac{\operatorname{CRPS}_{\mathrm{Original}} - \operatorname{CRPS}_{\mathrm{Regularised}}}{\operatorname{CRPS}_{\mathrm{Original}}}
.. note::
Depending on the nature of the data and your model, it may be useful to **transform the data** before calculating CRPS values (e.g., computing scores on the log scale).
See `Scoring epidemiological forecasts on transformed scales (Bosse et al., 2023) <https://doi.org/10.1371/journal.pcbi.1011393>`__ for further details.
Shown below are the CPRS values for each :math:`z(t)` forecast.
As displayed in the figure legend, the forecast with post-regularisation is **76.7% better** than the original forecast.
.. figure:: lorenz63_crps_comparison.png
:width: 100%
Comparison of CRPS values for the original :math:`z(t)` forecasts, and for the :math:`z(t)` forecasts with regularisation.
We can calculate CPRS values by taking the following steps:
1. Record simulated :math:`z(t)` observations for each particle with the :class:`~pypfilt.summary.SimulatedObs` summary table;
2. Save the forecast results to HDF5 files;
3. Load the simulated :math:`z(t)` observations with :func:`~pypfilt.io.load_summary_table`;
4. Load the true future :math:`z(t)` observations with :func:`~pypfilt.io.read_table`; and
5. Calculate CRPS values with :func:`~pypfilt.crps.simulated_obs_crps`.
.. literalinclude:: ../../tests/test_lorenz.py
:pyobject: score_lorenz63_forecasts
......@@ -183,6 +183,8 @@ def lorenz63_forecast_toml():
[summary.tables]
forecasts.component = "pypfilt.summary.PredictiveCIs"
forecasts.credible_intervals = [50, 60, 70, 80, 90, 95]
sim_z.component = "pypfilt.summary.SimulatedObs"
sim_z.observation_unit = "z"
[filter]
particles = 500
......@@ -233,6 +235,8 @@ def lorenz63_forecast_regularised_toml():
[summary.tables]
forecasts.component = "pypfilt.summary.PredictiveCIs"
forecasts.credible_intervals = [50, 60, 70, 80, 90, 95]
sim_z.component = "pypfilt.summary.SimulatedObs"
sim_z.observation_unit = "z"
[observations.x]
model = "pypfilt.examples.lorenz.ObsLorenz63"
......@@ -324,6 +328,8 @@ def lorenz63_all_scenarios_toml():
observations.z.file = "lorenz63-z.ssv"
summary.tables.forecasts.component = "pypfilt.summary.PredictiveCIs"
summary.tables.forecasts.credible_intervals = [50, 60, 70, 80, 90, 95]
summary.tables.sim_z.component = "pypfilt.summary.SimulatedObs"
summary.tables.sim_z.observation_unit = "z"
[scenario.forecast_regularised]
prior.x = { name = "uniform", args.loc = -5, args.scale = 10 }
......@@ -334,6 +340,8 @@ def lorenz63_all_scenarios_toml():
observations.z.file = "lorenz63-z.ssv"
summary.tables.forecasts.component = "pypfilt.summary.PredictiveCIs"
summary.tables.forecasts.credible_intervals = [50, 60, 70, 80, 90, 95]
summary.tables.sim_z.component = "pypfilt.summary.SimulatedObs"
summary.tables.sim_z.observation_unit = "z"
filter.regularisation.enabled = true
filter.regularisation.bounds.x = { min = -50, max = 50 }
filter.regularisation.bounds.y = { min = -50, max = 50 }
......
......@@ -6,6 +6,7 @@ import numpy as np
import os
from pathlib import Path
import pypfilt
import pypfilt.crps
import pypfilt.examples.lorenz
......@@ -147,7 +148,7 @@ def simulate_lorenz63_observations():
return obs_tables
def run_lorenz63_forecast():
def run_lorenz63_forecast(filename=None):
scenario_file = 'lorenz63_forecast.toml'
instances = list(pypfilt.load_instances(scenario_file))
instance = instances[0]
......@@ -155,10 +156,10 @@ def run_lorenz63_forecast():
# Run a forecast from t = 20.
forecast_time = 20
context = instance.build_context()
return pypfilt.forecast(context, [forecast_time], filename=None)
return pypfilt.forecast(context, [forecast_time], filename=filename)
def run_lorenz63_forecast_regularised():
def run_lorenz63_forecast_regularised(filename=None):
scenario_file = 'lorenz63_forecast_regularised.toml'
instances = list(pypfilt.load_instances(scenario_file))
instance = instances[0]
......@@ -166,7 +167,7 @@ def run_lorenz63_forecast_regularised():
# Run a forecast from t = 20.
forecast_time = 20
context = instance.build_context()
return pypfilt.forecast(context, [forecast_time], filename=None)
return pypfilt.forecast(context, [forecast_time], filename=filename)
def run_all_lorenz63_scenarios():
......@@ -197,6 +198,78 @@ def run_all_lorenz63_scenarios():
return (obs_tables, forecast_results, regularised_results)
def plot_crps_comparison(crps_fs, crps_reg, png_file):
"""
Plot the CRPS values for the original and regularised forecasts.
"""
import matplotlib.pyplot as plt
import pypfilt.plot
matplotlib.use('Agg')
# Calculate the skill score for the regularised forecasts, relative to the
# original forecasts.
mean_fs = np.mean(crps_fs['score'])
mean_reg = np.mean(crps_reg['score'])
skill = (mean_fs - mean_reg) / mean_fs
with pypfilt.plot.apply_style():
fig, ax = plt.subplots(layout='constrained')
ax.set_xlabel('Time')
ax.set_ylabel('CRPS (lower is better)')
ax.plot('time', 'score', data=crps_fs, label='Original forecast')
ax.plot(
'time',
'score',
data=crps_reg,
label=f'With regularisation\n(skill score = {skill:0.3f})',
)
ax.legend(loc='upper left')
fig.set_figwidth(6)
fig.set_figheight(3)
fig.savefig(
png_file,
format='png',
dpi=300,
transparent=True,
metadata={'Software': None},
)
def score_lorenz63_forecasts():
"""Calculate CRPS values for the simulated `z(t)` observations."""
# Load the true observations that occur after the forecasting time.
columns = [('time', float), ('value', float)]
z_true = pypfilt.io.read_table('lorenz63-z.ssv', columns)
z_true = z_true[z_true['time'] > 20]
# Run the original forecasts.
fs_file = 'lorenz63_forecast.hdf5'
fs = run_lorenz63_forecast(filename=fs_file)
# Run the forecasts with regularisation.
reg_file = 'lorenz63_regularised.hdf5'
fs_reg = run_lorenz63_forecast_regularised(filename=reg_file)
# Extract the simulated z(t) observations for each forecast.
time = pypfilt.Scalar()
z_table = '/tables/sim_z'
z_fs = pypfilt.io.load_summary_table(time, fs_file, z_table)
z_reg = pypfilt.io.load_summary_table(time, reg_file, z_table)
# Calculate CRPS values for each forecast.
crps_fs = pypfilt.crps.simulated_obs_crps(z_true, z_fs)
crps_reg = pypfilt.crps.simulated_obs_crps(z_true, z_reg)
# Check that regularisation improved the forecast performance.
assert np.mean(crps_reg['score']) < np.mean(crps_fs['score'])
# Compare the CRPS values for each forecast.
plot_crps_comparison(crps_fs, crps_reg, 'lorenz63_crps_comparison.png')
return (fs, fs_reg)
def test_lorenz63():
output_dir = Path('doc').resolve() / 'getting-started'
......@@ -204,13 +277,16 @@ def test_lorenz63():
pypfilt.examples.lorenz.save_lorenz63_scenario_files()
obs_tables = simulate_lorenz63_observations()
fs = run_lorenz63_forecast()
fs, fs_reg = score_lorenz63_forecasts()
plot_lorenz63_forecast(fs, obs_tables, 'lorenz63_forecast.png')
fs_reg = run_lorenz63_forecast_regularised()
plot_lorenz63_forecast(
fs_reg, obs_tables, 'lorenz63_forecast_regularised.png'
)
# Remove the output HDF5 files.
for hdf5_file in Path().glob('lorenz63_*.hdf5'):
hdf5_file.unlink()
# Check that the all-scenarios file produces identical results.
(new_obs_tables, new_fs, new_fs_reg) = run_all_lorenz63_scenarios()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment