Skip to content
Snippets Groups Projects
Commit cb269565 authored by Rob Moss's avatar Rob Moss
Browse files

Add a simple post-regularisation test case

parent 1ee9a139
No related branches found
No related tags found
No related merge requests found
Pipeline #7119 passed
......@@ -3,7 +3,7 @@
import numpy as np
import pytest
from pypfilt.resample import resample, resample_ixs
from pypfilt.resample import post_regularise, resample, resample_ixs
from pypfilt.context import Scaffold
......@@ -58,3 +58,83 @@ def test_resample_ixs(method):
else:
pytest.fail("Failure after {} loops with {} resampling".format(
n_tries, method))
def dummy_model(fields, smooth_fields):
"""
Construct a dummy model that only defines its field names, field types,
and which fields can be smoothed.
:param fields: The list of field names, all of which are treated as
floating-point values.
:param smooth_fields: The list of field names that can be smoothed.
"""
class DummyModel:
def field_names(self):
return fields
def field_types(self, ctx):
return [(f, float) for f in fields]
def can_smooth(self):
return smooth_fields
return DummyModel()
def test_post_regularise():
"""
Test post-regularisation using an ensemble of 10 particles for a model of
three variables.
"""
model = dummy_model(['a', 'b', 'c'], ['a', 'b', 'c'])
comps = {
'random': {
'resample': np.random.default_rng(seed=202101),
},
'model': model,
}
params = {
'resample': {
'reg_toln': 1e-8,
'regularise_or_fail': False,
},
'model': {
'param_bounds': {
'a': (-100, 100),
'b': (-100, 100),
},
},
'hist': {
'lookup_cols': {},
}
}
ctx = Scaffold(component=comps, params=params)
# NOTE: at a minimum, we need 'weight', 'prev_ix', 'state_vec'.
px = np.array([
(0.1, 0, (1.0, 1.0, 1.0)),
(0.1, 0, (1.0, 2.0, 1.0)),
(0.1, 0, (1.0, 3.0, 1.0)),
(0.1, 0, (1.0, 4.0, 1.0)),
(0.1, 0, (1.0, 5.0, 1.0)),
(0.1, 1, (2.0, 6.0, 1.0)),
(0.1, 1, (2.0, 7.0, 1.0)),
(0.1, 1, (2.0, 8.0, 1.0)),
(0.1, 1, (2.0, 9.0, 1.0)),
(0.1, 1, (2.0, 10.0, 1.0)),
], dtype=[('weight', np.float_),
('prev_ix', np.int_),
('state_vec', model.field_types(ctx))])
new_px = np.copy(px)
post_regularise(ctx, px, new_px)
assert(np.array_equal(px['weight'], new_px['weight']))
assert(np.array_equal(px['prev_ix'], new_px['prev_ix']))
assert(not np.array_equal(px['state_vec'], new_px['state_vec']))
assert(np.array_equal(px['state_vec']['c'], new_px['state_vec']['c']))
abs_diff_a = np.abs(new_px['state_vec']['a'] - px['state_vec']['a'])
assert(np.mean(abs_diff_a) < 0.2)
assert(np.mean(abs_diff_a) > 0.01)
abs_diff_b = np.abs(new_px['state_vec']['b'] - px['state_vec']['b'])
assert(np.mean(abs_diff_b) < 1.0)
assert(np.mean(abs_diff_b) > 0.2)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment