Skip to content
Snippets Groups Projects
Commit 7688ca2b authored by Rob Moss's avatar Rob Moss
Browse files

Add a test for the effect of caching PRNG states

Note that PRNG states are not yet saved to, or restored from, the cache
and so this test currently checks that outputs differ when the same
forecast is generated two times.
parent 4a91d117
Branches
Tags
No related merge requests found
Pipeline #17215 passed
......@@ -32,6 +32,12 @@ class LotkaVolterra(pypfilt.Model):
vec['gamma'] = ctx.params['model']['prior']['gamma'](rnd, size)
vec['delta'] = ctx.params['model']['prior']['delta'](rnd, size)
def resume_from_cache(self, ctx):
# NOTE: must ensure that self.dtype is defined when starting from a
# cached model state and LotkaVolterra.init() hasn't previously been
# called.
self.dtype = np.dtype(self.field_types(ctx))
def field_types(self, ctx):
return [(name, np.dtype(float))
for name in self.field_names()]
......
"Test that pypfile.cache correctly saves and restores PRNG states."
import logging
import numpy as np
import os
import pypfilt
import pypfilt.cache
import pypfilt.context
import pypfilt.examples
import pypfilt.examples.predation
import scipy.stats
def test_cached_prng_states(caplog):
"""Test that caching PRNG states has the desired effects."""
# NOTE: The caplog fixture captures logging; see the pytest documentation:
# https://docs.pytest.org/en/stable/logging.html
caplog.set_level(logging.INFO)
time_scale = pypfilt.Scalar()
t0 = 0.0
t1 = 15.0
t2 = 30.0
(state_1, state_2) = run_two_forecasts(time_scale, t0, t1, t2)
# Extract the relevant summary tables for each forecast.
cints_1 = state_1[t1]['summary']['model_cints']
cints_2 = state_2[t1]['summary']['model_cints']
sim_obs_1 = state_1[t1]['summary']['sim_obs']
sim_obs_2 = state_2[t1]['summary']['sim_obs']
# NOTE: PRNG states are not currently cached, so the model credible
# intervals and the simulated observations should differ between these two
# forecasts.
assert not np.array_equal(cints_1, cints_2)
assert not np.array_equal(sim_obs_1, sim_obs_2)
# Show the logging messages produced by the forecast runs.
print(caplog.text)
def make_params(px_count=20, seed=42, obs_sdev=0.2):
"""Construct the simulation parameters for StochModel."""
model = StochModel()
time_scale = pypfilt.Scalar()
params = pypfilt.default_params(model, time_scale, max_days=15,
px_count=px_count, prng_seed=seed)
params['steps_per_unit'] = 1
params['summary']['from_first_day'] = True
params['model']['prior'] = {
'x': lambda r, size=None: r.uniform(10.0, 20.0, size=size)
}
params['obs'] = {
'x': {'sdev': obs_sdev},
}
params['component']['obs'] = {
'x': StochObs(obs_unit='x', obs_period=0),
}
# Write output to the working directory.
params['out_dir'] = '.'
params['tmp_dir'] = '.'
return params
def run_forecast(time_scale, t0, t1, t2, cache_file, x_tbl, obs):
"""Generate a StochModel forecast."""
params = make_params()
params['time']['start'] = t0
params['time']['until'] = t2
params['hist']['cache_file'] = cache_file
params['data']['obs']['x'] = x_tbl
summary = pypfilt.summary.HDF5(params, obs)
params['component']['summary'] = summary
params['component']['summary_table'] = {
'model_cints': pypfilt.summary.ModelCIs(),
'sim_obs': pypfilt.summary.SimulatedObs(),
}
return pypfilt.forecast(params, [obs], [t1], filename=None)
def run_two_forecasts(time_scale, t0, t1, t2):
"""
Run the same forecast twice; the first time with an empty cache, and the
second resuming from a cached state.
"""
# Ensure we start without a cache file.
cache_file = 'test_cached_prng_states.hdf5'
try:
os.remove(cache_file)
except FileNotFoundError:
pass
# Simulate observations from a single particle.
sim_obs = simulate_from_model(time_scale, t0, t1, px_count=1)
# Extract the observations from the summary table.
obs = [{
'date': time_scale.from_dtype(row['date']),
'period': 0,
'unit': row['unit'],
'value': row['value'],
'source': 'test_cache',
} for row in sim_obs]
# Run a forecast to populate the cache, and then check that rerunning the
# forecast from the cached state produces identical model dynamics.
assert not os.path.isfile(cache_file)
state_1 = run_forecast(time_scale, t0, t1, t2, cache_file, sim_obs, obs)
assert os.path.isfile(cache_file)
state_2 = run_forecast(time_scale, t0, t1, t2, cache_file, sim_obs, obs)
os.remove(cache_file)
return (state_1, state_2)
def simulate_from_model(time_scale, t0, t1, px_count):
"""Simulate observations from a single StochModel particle."""
params = make_params()
params['time']['start'] = t0
params['time']['until'] = t1
params['model']['prior'] = {
'x': lambda r, size=None: 15.0 * np.ones(size),
}
return pypfilt.simulate_from_model(params)
class StochModel(pypfilt.Model):
"""A very simple stochastic model."""
def init(self, ctx, vec):
rnd = ctx.component['random']['model']
size = vec.shape
vec['x'] = ctx.params['model']['prior']['x'](rnd, size)
def field_types(self, ctx):
return [('x', np.dtype(float))]
def field_names(self):
return ['x']
def update(self, ctx, t, dt, is_fs, prev, curr):
"""Perform a single time-step."""
rnd = ctx.component['random']['model']
curr['x'] = prev['x'] + rnd.uniform(size=curr.shape)
def can_smooth(self):
return {}
def bounds(self):
return {}
class StochObs(pypfilt.Obs):
"""A Gaussian observation model for StochModel."""
def __init__(self, obs_unit, obs_period):
self.unit = obs_unit
self.period = obs_period
def log_llhd(self, params, op, time, obs, curr, hist):
# NOTE: the expected observations are x(t) and y(t).
# Calculate the log-likelihood of each observation in turn.
x_t = curr['state_vec']['x']
x_dist = scipy.stats.norm(loc=x_t, scale=op['sdev'])
return x_dist.logpdf(obs['value'])
def simulate(self, params, op, time, period, expect, rng=None):
if rng is None:
return scipy.stats.norm(loc=expect, scale=op['sdev']).rvs()
else:
return rng.normal(loc=expect, scale=op['sdev'])
def expect(self, ctx, op, time, period, prev, curr):
return curr['state_vec']['x']
def quantiles(self, params, op, time, mu, wt, probs):
# The minimum interval width before we decide that a value is
# sufficiently accurate.
tolerance = 0.00001
scale = op['sdev']
normal = scipy.stats.norm(loc=mu, scale=scale)
def cdf(y):
"""Calculate the CDF of the weighted sum over all particles."""
return np.dot(wt, normal.cdf(y))
def bisect(a, b):
"""
Return the midpoint of the interval [a, b], or ``None`` if the
minimum tolerance has been reached.
"""
if b > a + tolerance:
return (a + b) / 2
else:
return None
# Find appropriate lower and upper bounds.
pr_min = np.min(probs)
pr_max = np.max(probs)
y0_lower = scipy.stats.norm(loc=np.min(mu), scale=scale).ppf(pr_min)
y0_upper = scipy.stats.norm(loc=np.max(mu), scale=scale).ppf(pr_max)
return pypfilt.obs.bisect_cdf(probs, cdf, bisect, y0_lower, y0_upper)
def from_file(self, filename, time_scale):
raise NotImplementedError()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment