From 9dcf621618fe76281a486af7bfba364e7bd26c13 Mon Sep 17 00:00:00 2001 From: Rob Moss <robm.dev@gmail.com> Date: Wed, 27 Oct 2021 13:54:05 +1100 Subject: [PATCH] Fix a bug in pypfilt.state.earlier_states This bug was introduced in commit 408b5f1 but was not initially detected due to stale tox artefacts. --- src/pypfilt/state.py | 11 ++++++--- tests/test_earlier_state.py | 47 +++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 3 deletions(-) create mode 100644 tests/test_earlier_state.py diff --git a/src/pypfilt/state.py b/src/pypfilt/state.py index b22f0bf..abd240b 100644 --- a/src/pypfilt/state.py +++ b/src/pypfilt/state.py @@ -168,10 +168,15 @@ def earlier_states(hist, ix, steps): """ logger = logging.getLogger(__name__) + # Return the current particles when looking back zero time-steps. + if steps == 0: + logger.debug('Looking back zero steps') + return hist[ix - steps] + # Don't go too far back (negative indices jump into the future). if steps > ix: msg_fmt = 'Cannot look back {} time-steps, will look back {}' - logger.warn(msg_fmt.format(steps, ix)) + logger.debug(msg_fmt.format(steps, ix)) steps = min(steps, ix) # Start with the parent indices for the current particles, which allow us @@ -179,10 +184,10 @@ def earlier_states(hist, ix, steps): parent_ixs = np.copy(hist['prev_ix'][ix]) # Continue looking back one time-step, and only update the parent indices - # at time-step T if the particles were resampled at time-step T+1. + # if the particles were resampled. for i in range(1, steps): step_ix = ix - i - if hist['resampled'][step_ix + 1, 0]: + if hist['resampled'][step_ix, 0]: parent_ixs = hist['prev_ix'][step_ix, parent_ixs] return hist[ix - steps, parent_ixs] diff --git a/tests/test_earlier_state.py b/tests/test_earlier_state.py new file mode 100644 index 0000000..327016b --- /dev/null +++ b/tests/test_earlier_state.py @@ -0,0 +1,47 @@ +""" +Test the behaviour of ``pypfilt.state.earlier_state()``. +""" + +import numpy as np +import pypfilt.state + + +def test_earlier_state(): + dtype = [('prev_ix', np.int_), ('resampled', np.bool_), + ('id', np.float_)] + num_steps = 6 + num_px = 4 + same_order = np.arange(num_px) + reverse_order = same_order[::-1] + hist = np.zeros(shape=(num_steps, num_px), dtype=dtype) + + # No resampling, particles remain in the same order. + hist['prev_ix'][:, :] = same_order[None, :] + hist['id'][:, :] = same_order[None, :] + + for ix in [1, 3, 5]: + for steps_back in [0, 1, 2, 3]: + prev_state = pypfilt.state.earlier_states(hist, ix, steps_back) + assert np.array_equal(prev_state['id'], same_order) + + # Reverse the particle ordering at step 2. + resample_ix = 2 + hist['resampled'][resample_ix, :] = True + hist['prev_ix'][resample_ix, :] = reverse_order + hist['id'][resample_ix:, :] = reverse_order[None, :] + + # Starting at all time-steps prior to the resampling event, the particles + # should be returned in their original order. + for ix in range(resample_ix): + for steps_back in range(num_steps): + state = pypfilt.state.earlier_states(hist, ix, steps_back) + ids = state['id'] + assert np.array_equal(ids, same_order) + + # Starting at all time-steps from the resampling event, the particles + # should be returned in reverse order. + for ix in range(2, num_steps): + for steps_back in range(num_steps): + state = pypfilt.state.earlier_states(hist, ix, steps_back) + ids = state['id'] + assert np.array_equal(ids, reverse_order) -- GitLab