diff --git a/src/pypfilt/obs.py b/src/pypfilt/obs.py
index aa44140bfac017a5c66df6aae7b40cae885f6ae6..043fc1445d2f2259183ccfb69fc297c5a5eb3431 100644
--- a/src/pypfilt/obs.py
+++ b/src/pypfilt/obs.py
@@ -30,16 +30,13 @@ def expect(ctx, time, unit, period, prev, curr):
         raise ValueError("Unknown observation type '{}'".format(unit))
 
 
-def log_llhd_of(ctx, hist, hist_ix, obs, max_back=0):
+def log_llhd_of(ctx, hist, hist_ix, obs):
     """Return the log-likelihood of obtaining observations from each particle.
 
     :param ctx: The simulation context.
     :param hist: The particle history matrix.
     :param hist_ix: The index of the current time-step in the history matrix.
     :param obs: The observation(s) that have been made.
-    :param max_back: The number of time-steps into the past when the most
-        recent resampling occurred (i.e., how far back the current particle
-        ordering is guaranteed to persist; default is 0, no guarantee).
 
     :returns: An array containing the log-likelihood for each particle.
     """
@@ -58,16 +55,8 @@ def log_llhd_of(ctx, hist, hist_ix, obs, max_back=0):
     def hist_for(period):
         """Return past state vectors in the appropriate order."""
         steps_back = steps_per_unit * period
-        same_ixs = max_back >= steps_back
-        if same_ixs:
-            if steps_back > hist_ix:
-                # If the observation period starts before the beginning of the
-                # the simulation period, the initial state should be returned.
-                return hist[0]
-            else:
-                return hist[hist_ix - steps_back]
-        else:
-            return state.earlier_states(hist, hist_ix, steps_back)
+        return state.earlier_states(hist, hist_ix, steps_back)
+
     period_hists = {period: hist_for(period) for period in periods}
 
     # Calculate the log-likelihood of obtaining the given observation, for
diff --git a/src/pypfilt/pfilter.py b/src/pypfilt/pfilter.py
index 9ab17af4ff2d47498bde18b7e883400a6428ccdd..c7dc6799dde7026f98963e9976c9ce8ca3bae625 100644
--- a/src/pypfilt/pfilter.py
+++ b/src/pypfilt/pfilter.py
@@ -8,16 +8,13 @@ from . import obs as obs_mod
 from . import state as state_mod
 
 
-def reweight(ctx, hist, hist_ix, obs, max_back):
+def reweight(ctx, hist, hist_ix, obs):
     """Adjust particle weights in response to some observation(s).
 
     :param params: The simulation parameters.
     :param hist: The particle history matrix.
     :param hist_ix: The index of the current time-step in the history matrix.
     :param obs: The observation(s) that have been made.
-    :param max_back: The number of time-steps into the past when the most
-        recent resampling occurred (i.e., how far back the current particle
-        ordering is guaranteed to persist).
 
     :returns: A tuple; the first element (*bool*) indicates whether resampling
         is required, the second element (*float*) is the **effective** number
@@ -25,7 +22,7 @@ def reweight(ctx, hist, hist_ix, obs, max_back):
     """
     # Calculate the log-likelihood of obtaining the given observation, for
     # each particle.
-    logs = obs_mod.log_llhd_of(ctx, hist, hist_ix, obs, max_back)
+    logs = obs_mod.log_llhd_of(ctx, hist, hist_ix, obs)
 
     # Scale the log-likelihoods so that the maximum is 0 (i.e., has a
     # likelihood of 1) to increase the chance of smaller likelihoods
@@ -80,7 +77,7 @@ def __log_step(ctx, when, do_resample, num_eff=None):
             ctx.component['time'].to_unicode(when), resp[do_resample]))
 
 
-def step(ctx, hist, hist_ix, step_num, when, step_obs, max_back, is_fs):
+def step(ctx, hist, hist_ix, step_num, when, step_obs, is_fs):
     """Perform a single time-step for every particle.
 
     :param params: The simulation parameters.
@@ -89,9 +86,6 @@ def step(ctx, hist, hist_ix, step_num, when, step_obs, max_back, is_fs):
     :param step_num: The time-step number.
     :param when: The current simulation time.
     :param step_obs: The list of observations for this time-step.
-    :param max_back: The number of time-steps into the past when the most
-        recent resampling occurred; must be either a positive integer or
-        ``None`` (no limit).
     :param is_fs: Indicate whether this is a forecasting simulation (i.e., no
         observations).
         For deterministic models it is useful to add some random noise when
@@ -132,8 +126,7 @@ def step(ctx, hist, hist_ix, step_num, when, step_obs, max_back, is_fs):
     num_eff = None
     do_resample = False
     if step_obs:
-        do_resample, num_eff = reweight(ctx, hist, hist_ix, step_obs,
-                                        max_back)
+        do_resample, num_eff = reweight(ctx, hist, hist_ix, step_obs)
 
     __log_step(ctx, when, do_resample, num_eff)
 
@@ -192,13 +185,7 @@ def run(ctx, start, end, streams, state=None,
     win_start = start
     # The time of the previous time-step (if any).
     most_recent = None
-    # The time-step number of the most recent resampling (if any).
     # NOTE: the first time-step is number 1 and updates hist[1] given hist[0].
-    # So we set this value to zero to indicate that we can step back as far as
-    # the beginning of this simulation; this is important when resuming from
-    # cached states where we cannot assume anything about the history prior to
-    # the first time-step.
-    last_rs = 0
     # The index of the current time-step in the state matrix.
     hist_ix = None
 
@@ -236,12 +223,10 @@ def run(ctx, start, end, streams, state=None,
             if hist.dtype.names is not None and 'lookup' in hist.dtype.names:
                 hist['lookup'][hist_ix] = 0
 
-        # Determine how many time-steps back the most recent resampling was.
-        max_back = (step_num - last_rs)
-
         # Simulate the current time-step.
-        resampled = step(ctx, hist, hist_ix, step_num, when, obs,
-                         max_back, is_fs)
+        resampled = step(ctx, hist, hist_ix, step_num, when, obs, is_fs)
+        # Record whether the particles were resampled at this time-step.
+        hist['resampled'][hist_ix] = resampled
 
         # Check whether to save the particle history matrix to disk.
         # NOTE: the summary object may not have summarised the model state
@@ -272,8 +257,6 @@ def run(ctx, start, end, streams, state=None,
 
         # Finally, update loop variables.
         most_recent = when
-        if resampled:
-            last_rs = step_num
 
     if hist_ix is None:
         # There were no time-steps.
diff --git a/src/pypfilt/state.py b/src/pypfilt/state.py
index 9bf1d768131d43749ae8aeb7e68884cd321b351d..47265a401ad9a95ec96f086b85d43ab4b9e3272f 100644
--- a/src/pypfilt/state.py
+++ b/src/pypfilt/state.py
@@ -90,7 +90,8 @@ def history_matrix_dtype(ctx):
     * lookup: sample indices for lookup tables (if required).
     """
     # We always need to record the particle weight and parent index.
-    base_dtype = [('weight', np.float_), ('prev_ix', np.int_)]
+    base_dtype = [('weight', np.float_), ('prev_ix', np.int_),
+                  ('resampled', np.bool_)]
 
     # Determine the structure of the model state vector.
     svec_dtype = state_vec_dtype(ctx)
@@ -165,12 +166,24 @@ def earlier_states(hist, ix, steps):
     :param ix: The current time-step index.
     :param steps: The number of steps back in time.
     """
+    logger = logging.getLogger(__name__)
+
     # Don't go too far back (negative indices jump into the future).
+    if steps > ix:
+        msg_fmt = 'Cannot look back {} time-steps, will look back {}'
+        logger.warn(msg_ftm.format(steps, ix))
     steps = min(steps, ix)
-    parent_ixs = np.arange(hist.shape[1])
 
-    for i in range(steps):
-        parent_ixs = hist['prev_ix'][ix - i, parent_ixs]
+    # Start with the parent indices for the current particles, which allow us
+    # to look back one time-step.
+    parent_ixs = np.copy(hist['prev_ix'][ix])
+
+    # Continue looking back one time-step, and only update the parent indices
+    # at time-step T if the particles were resampled at time-step T+1.
+    for i in range(1, steps):
+        step_ix = ix - i
+        if hist['resampled'][step_ix + 1, 0]:
+            parent_ixs = hist['prev_ix'][step_ix, parent_ixs]
 
     return hist[ix - steps, parent_ixs]
 
diff --git a/tests/test_resample.py b/tests/test_resample.py
index 139b89d2ea40c569ce09249b434e731cb5c82e07..1bd5659efe53189050983cf5c0eb9a95aa295e8f 100644
--- a/tests/test_resample.py
+++ b/tests/test_resample.py
@@ -25,10 +25,12 @@ def test_resample(method):
 
     weights = np.array([0.50, 0.25, 0.1, 0.1, 0.02, 0.02, 0.01])
     ixs = np.zeros(weights.shape)
-    dtype = [('weight', np.float_), ('prev_ix', np.int_)]
+    resampled = np.zeros(weights.shape, dtype=np.bool_)
+    dtype = [('weight', np.float_), ('prev_ix', np.int_),
+             ('resampled', np.bool_)]
     n_tries = 10
     for i in range(n_tries):
-        x = np.array(list(zip(weights, ixs)), dtype=dtype)
+        x = np.array(list(zip(weights, ixs, resampled)), dtype=dtype)
         resample(ctx, x)
         prev_ix = x['prev_ix']
         if not all(x['weight'] == x['weight'][0]):
@@ -113,18 +115,19 @@ def test_post_regularise():
     ctx = Scaffold(component=comps, params=params)
     # NOTE: at a minimum, we need 'weight', 'prev_ix', 'state_vec'.
     px = np.array([
-        (0.1, 0, (1.0, 1.0, 1.0)),
-        (0.1, 0, (1.0, 2.0, 1.0)),
-        (0.1, 0, (1.0, 3.0, 1.0)),
-        (0.1, 0, (1.0, 4.0, 1.0)),
-        (0.1, 0, (1.0, 5.0, 1.0)),
-        (0.1, 1, (2.0, 6.0, 1.0)),
-        (0.1, 1, (2.0, 7.0, 1.0)),
-        (0.1, 1, (2.0, 8.0, 1.0)),
-        (0.1, 1, (2.0, 9.0, 1.0)),
-        (0.1, 1, (2.0, 10.0, 1.0)),
+        (0.1, 0, False, (1.0, 1.0, 1.0)),
+        (0.1, 0, False, (1.0, 2.0, 1.0)),
+        (0.1, 0, False, (1.0, 3.0, 1.0)),
+        (0.1, 0, False, (1.0, 4.0, 1.0)),
+        (0.1, 0, False, (1.0, 5.0, 1.0)),
+        (0.1, 1, False, (2.0, 6.0, 1.0)),
+        (0.1, 1, False, (2.0, 7.0, 1.0)),
+        (0.1, 1, False, (2.0, 8.0, 1.0)),
+        (0.1, 1, False, (2.0, 9.0, 1.0)),
+        (0.1, 1, False, (2.0, 10.0, 1.0)),
     ], dtype=[('weight', np.float_),
               ('prev_ix', np.int_),
+              ('resampled', np.bool_),
               ('state_vec', model.field_types(ctx))])
     new_px = np.copy(px)
     post_regularise(ctx, px, new_px)