diff --git a/src/pypfilt/obs.py b/src/pypfilt/obs.py index aa44140bfac017a5c66df6aae7b40cae885f6ae6..043fc1445d2f2259183ccfb69fc297c5a5eb3431 100644 --- a/src/pypfilt/obs.py +++ b/src/pypfilt/obs.py @@ -30,16 +30,13 @@ def expect(ctx, time, unit, period, prev, curr): raise ValueError("Unknown observation type '{}'".format(unit)) -def log_llhd_of(ctx, hist, hist_ix, obs, max_back=0): +def log_llhd_of(ctx, hist, hist_ix, obs): """Return the log-likelihood of obtaining observations from each particle. :param ctx: The simulation context. :param hist: The particle history matrix. :param hist_ix: The index of the current time-step in the history matrix. :param obs: The observation(s) that have been made. - :param max_back: The number of time-steps into the past when the most - recent resampling occurred (i.e., how far back the current particle - ordering is guaranteed to persist; default is 0, no guarantee). :returns: An array containing the log-likelihood for each particle. """ @@ -58,16 +55,8 @@ def log_llhd_of(ctx, hist, hist_ix, obs, max_back=0): def hist_for(period): """Return past state vectors in the appropriate order.""" steps_back = steps_per_unit * period - same_ixs = max_back >= steps_back - if same_ixs: - if steps_back > hist_ix: - # If the observation period starts before the beginning of the - # the simulation period, the initial state should be returned. - return hist[0] - else: - return hist[hist_ix - steps_back] - else: - return state.earlier_states(hist, hist_ix, steps_back) + return state.earlier_states(hist, hist_ix, steps_back) + period_hists = {period: hist_for(period) for period in periods} # Calculate the log-likelihood of obtaining the given observation, for diff --git a/src/pypfilt/pfilter.py b/src/pypfilt/pfilter.py index 9ab17af4ff2d47498bde18b7e883400a6428ccdd..c7dc6799dde7026f98963e9976c9ce8ca3bae625 100644 --- a/src/pypfilt/pfilter.py +++ b/src/pypfilt/pfilter.py @@ -8,16 +8,13 @@ from . import obs as obs_mod from . import state as state_mod -def reweight(ctx, hist, hist_ix, obs, max_back): +def reweight(ctx, hist, hist_ix, obs): """Adjust particle weights in response to some observation(s). :param params: The simulation parameters. :param hist: The particle history matrix. :param hist_ix: The index of the current time-step in the history matrix. :param obs: The observation(s) that have been made. - :param max_back: The number of time-steps into the past when the most - recent resampling occurred (i.e., how far back the current particle - ordering is guaranteed to persist). :returns: A tuple; the first element (*bool*) indicates whether resampling is required, the second element (*float*) is the **effective** number @@ -25,7 +22,7 @@ def reweight(ctx, hist, hist_ix, obs, max_back): """ # Calculate the log-likelihood of obtaining the given observation, for # each particle. - logs = obs_mod.log_llhd_of(ctx, hist, hist_ix, obs, max_back) + logs = obs_mod.log_llhd_of(ctx, hist, hist_ix, obs) # Scale the log-likelihoods so that the maximum is 0 (i.e., has a # likelihood of 1) to increase the chance of smaller likelihoods @@ -80,7 +77,7 @@ def __log_step(ctx, when, do_resample, num_eff=None): ctx.component['time'].to_unicode(when), resp[do_resample])) -def step(ctx, hist, hist_ix, step_num, when, step_obs, max_back, is_fs): +def step(ctx, hist, hist_ix, step_num, when, step_obs, is_fs): """Perform a single time-step for every particle. :param params: The simulation parameters. @@ -89,9 +86,6 @@ def step(ctx, hist, hist_ix, step_num, when, step_obs, max_back, is_fs): :param step_num: The time-step number. :param when: The current simulation time. :param step_obs: The list of observations for this time-step. - :param max_back: The number of time-steps into the past when the most - recent resampling occurred; must be either a positive integer or - ``None`` (no limit). :param is_fs: Indicate whether this is a forecasting simulation (i.e., no observations). For deterministic models it is useful to add some random noise when @@ -132,8 +126,7 @@ def step(ctx, hist, hist_ix, step_num, when, step_obs, max_back, is_fs): num_eff = None do_resample = False if step_obs: - do_resample, num_eff = reweight(ctx, hist, hist_ix, step_obs, - max_back) + do_resample, num_eff = reweight(ctx, hist, hist_ix, step_obs) __log_step(ctx, when, do_resample, num_eff) @@ -192,13 +185,7 @@ def run(ctx, start, end, streams, state=None, win_start = start # The time of the previous time-step (if any). most_recent = None - # The time-step number of the most recent resampling (if any). # NOTE: the first time-step is number 1 and updates hist[1] given hist[0]. - # So we set this value to zero to indicate that we can step back as far as - # the beginning of this simulation; this is important when resuming from - # cached states where we cannot assume anything about the history prior to - # the first time-step. - last_rs = 0 # The index of the current time-step in the state matrix. hist_ix = None @@ -236,12 +223,10 @@ def run(ctx, start, end, streams, state=None, if hist.dtype.names is not None and 'lookup' in hist.dtype.names: hist['lookup'][hist_ix] = 0 - # Determine how many time-steps back the most recent resampling was. - max_back = (step_num - last_rs) - # Simulate the current time-step. - resampled = step(ctx, hist, hist_ix, step_num, when, obs, - max_back, is_fs) + resampled = step(ctx, hist, hist_ix, step_num, when, obs, is_fs) + # Record whether the particles were resampled at this time-step. + hist['resampled'][hist_ix] = resampled # Check whether to save the particle history matrix to disk. # NOTE: the summary object may not have summarised the model state @@ -272,8 +257,6 @@ def run(ctx, start, end, streams, state=None, # Finally, update loop variables. most_recent = when - if resampled: - last_rs = step_num if hist_ix is None: # There were no time-steps. diff --git a/src/pypfilt/state.py b/src/pypfilt/state.py index 9bf1d768131d43749ae8aeb7e68884cd321b351d..47265a401ad9a95ec96f086b85d43ab4b9e3272f 100644 --- a/src/pypfilt/state.py +++ b/src/pypfilt/state.py @@ -90,7 +90,8 @@ def history_matrix_dtype(ctx): * lookup: sample indices for lookup tables (if required). """ # We always need to record the particle weight and parent index. - base_dtype = [('weight', np.float_), ('prev_ix', np.int_)] + base_dtype = [('weight', np.float_), ('prev_ix', np.int_), + ('resampled', np.bool_)] # Determine the structure of the model state vector. svec_dtype = state_vec_dtype(ctx) @@ -165,12 +166,24 @@ def earlier_states(hist, ix, steps): :param ix: The current time-step index. :param steps: The number of steps back in time. """ + logger = logging.getLogger(__name__) + # Don't go too far back (negative indices jump into the future). + if steps > ix: + msg_fmt = 'Cannot look back {} time-steps, will look back {}' + logger.warn(msg_ftm.format(steps, ix)) steps = min(steps, ix) - parent_ixs = np.arange(hist.shape[1]) - for i in range(steps): - parent_ixs = hist['prev_ix'][ix - i, parent_ixs] + # Start with the parent indices for the current particles, which allow us + # to look back one time-step. + parent_ixs = np.copy(hist['prev_ix'][ix]) + + # Continue looking back one time-step, and only update the parent indices + # at time-step T if the particles were resampled at time-step T+1. + for i in range(1, steps): + step_ix = ix - i + if hist['resampled'][step_ix + 1, 0]: + parent_ixs = hist['prev_ix'][step_ix, parent_ixs] return hist[ix - steps, parent_ixs] diff --git a/tests/test_resample.py b/tests/test_resample.py index 139b89d2ea40c569ce09249b434e731cb5c82e07..1bd5659efe53189050983cf5c0eb9a95aa295e8f 100644 --- a/tests/test_resample.py +++ b/tests/test_resample.py @@ -25,10 +25,12 @@ def test_resample(method): weights = np.array([0.50, 0.25, 0.1, 0.1, 0.02, 0.02, 0.01]) ixs = np.zeros(weights.shape) - dtype = [('weight', np.float_), ('prev_ix', np.int_)] + resampled = np.zeros(weights.shape, dtype=np.bool_) + dtype = [('weight', np.float_), ('prev_ix', np.int_), + ('resampled', np.bool_)] n_tries = 10 for i in range(n_tries): - x = np.array(list(zip(weights, ixs)), dtype=dtype) + x = np.array(list(zip(weights, ixs, resampled)), dtype=dtype) resample(ctx, x) prev_ix = x['prev_ix'] if not all(x['weight'] == x['weight'][0]): @@ -113,18 +115,19 @@ def test_post_regularise(): ctx = Scaffold(component=comps, params=params) # NOTE: at a minimum, we need 'weight', 'prev_ix', 'state_vec'. px = np.array([ - (0.1, 0, (1.0, 1.0, 1.0)), - (0.1, 0, (1.0, 2.0, 1.0)), - (0.1, 0, (1.0, 3.0, 1.0)), - (0.1, 0, (1.0, 4.0, 1.0)), - (0.1, 0, (1.0, 5.0, 1.0)), - (0.1, 1, (2.0, 6.0, 1.0)), - (0.1, 1, (2.0, 7.0, 1.0)), - (0.1, 1, (2.0, 8.0, 1.0)), - (0.1, 1, (2.0, 9.0, 1.0)), - (0.1, 1, (2.0, 10.0, 1.0)), + (0.1, 0, False, (1.0, 1.0, 1.0)), + (0.1, 0, False, (1.0, 2.0, 1.0)), + (0.1, 0, False, (1.0, 3.0, 1.0)), + (0.1, 0, False, (1.0, 4.0, 1.0)), + (0.1, 0, False, (1.0, 5.0, 1.0)), + (0.1, 1, False, (2.0, 6.0, 1.0)), + (0.1, 1, False, (2.0, 7.0, 1.0)), + (0.1, 1, False, (2.0, 8.0, 1.0)), + (0.1, 1, False, (2.0, 9.0, 1.0)), + (0.1, 1, False, (2.0, 10.0, 1.0)), ], dtype=[('weight', np.float_), ('prev_ix', np.int_), + ('resampled', np.bool_), ('state_vec', model.field_types(ctx))]) new_px = np.copy(px) post_regularise(ctx, px, new_px)