diff --git a/src/pypfilt/state.py b/src/pypfilt/state.py
index b22f0bf55135a9cbcfd1e0feb22e1583fb6cf852..abd240b8b5f34d9ce35fbb2844e43717807cfcb9 100644
--- a/src/pypfilt/state.py
+++ b/src/pypfilt/state.py
@@ -168,10 +168,15 @@ def earlier_states(hist, ix, steps):
     """
     logger = logging.getLogger(__name__)
 
+    # Return the current particles when looking back zero time-steps.
+    if steps == 0:
+        logger.debug('Looking back zero steps')
+        return hist[ix - steps]
+
     # Don't go too far back (negative indices jump into the future).
     if steps > ix:
         msg_fmt = 'Cannot look back {} time-steps, will look back {}'
-        logger.warn(msg_fmt.format(steps, ix))
+        logger.debug(msg_fmt.format(steps, ix))
     steps = min(steps, ix)
 
     # Start with the parent indices for the current particles, which allow us
@@ -179,10 +184,10 @@ def earlier_states(hist, ix, steps):
     parent_ixs = np.copy(hist['prev_ix'][ix])
 
     # Continue looking back one time-step, and only update the parent indices
-    # at time-step T if the particles were resampled at time-step T+1.
+    # if the particles were resampled.
     for i in range(1, steps):
         step_ix = ix - i
-        if hist['resampled'][step_ix + 1, 0]:
+        if hist['resampled'][step_ix, 0]:
             parent_ixs = hist['prev_ix'][step_ix, parent_ixs]
 
     return hist[ix - steps, parent_ixs]
diff --git a/tests/test_earlier_state.py b/tests/test_earlier_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..327016b1dee6c123122dab8d0947415323ec1da8
--- /dev/null
+++ b/tests/test_earlier_state.py
@@ -0,0 +1,47 @@
+"""
+Test the behaviour of ``pypfilt.state.earlier_state()``.
+"""
+
+import numpy as np
+import pypfilt.state
+
+
+def test_earlier_state():
+    dtype = [('prev_ix', np.int_), ('resampled', np.bool_),
+             ('id', np.float_)]
+    num_steps = 6
+    num_px = 4
+    same_order = np.arange(num_px)
+    reverse_order = same_order[::-1]
+    hist = np.zeros(shape=(num_steps, num_px), dtype=dtype)
+
+    # No resampling, particles remain in the same order.
+    hist['prev_ix'][:, :] = same_order[None, :]
+    hist['id'][:, :] = same_order[None, :]
+
+    for ix in [1, 3, 5]:
+        for steps_back in [0, 1, 2, 3]:
+            prev_state = pypfilt.state.earlier_states(hist, ix, steps_back)
+            assert np.array_equal(prev_state['id'], same_order)
+
+    # Reverse the particle ordering at step 2.
+    resample_ix = 2
+    hist['resampled'][resample_ix, :] = True
+    hist['prev_ix'][resample_ix, :] = reverse_order
+    hist['id'][resample_ix:, :] = reverse_order[None, :]
+
+    # Starting at all time-steps prior to the resampling event, the particles
+    # should be returned in their original order.
+    for ix in range(resample_ix):
+        for steps_back in range(num_steps):
+            state = pypfilt.state.earlier_states(hist, ix, steps_back)
+            ids = state['id']
+            assert np.array_equal(ids, same_order)
+
+    # Starting at all time-steps from the resampling event, the particles
+    # should be returned in reverse order.
+    for ix in range(2, num_steps):
+        for steps_back in range(num_steps):
+            state = pypfilt.state.earlier_states(hist, ix, steps_back)
+            ids = state['id']
+            assert np.array_equal(ids, reverse_order)