Skip to content
Snippets Groups Projects
Commit 38bebb9a authored by Rob Moss's avatar Rob Moss
Browse files

Document and test pypfilt.state.repack

parent 18a04591
No related branches found
No related tags found
No related merge requests found
Pipeline #6810 passed
......@@ -14,3 +14,5 @@ pypfilt.state
.. autofunction:: is_state_vec_matrix
.. autofunction:: require_state_vec_matrix
.. autofunction:: repack
......@@ -3,7 +3,6 @@
import logging
import numpy as np
import numpy.lib.recfunctions
import warnings
def is_history_matrix(ctx, arr):
......@@ -171,25 +170,38 @@ def earlier_states(hist, ix, steps):
return hist[ix - steps, parent_ixs]
def repack(svec):
def repack(svec, astype=float):
"""
Return a copy of the array ``svec`` where the fields are contiguous and
viewed as a regular Numpy array of floats.
:raises UserWarning: if ``svec`` contains any non-float fields, each of
which will be converted to floats.
"""
svec = numpy.lib.recfunctions.repack_fields(svec)
# Convert any non-float fields (e.g., integers) into floats.
cols = [(name, svec.dtype.fields[name][0]) for name in svec.dtype.names]
all_floats = all(np.issubdtype(col[1], float) for col in cols)
if not all_floats:
warnings.warn('Repacking array with non-float fields',
stacklevel=2)
new_dtype = [
(col[0], col[1] if np.issubdtype(col[1], float) else float,
col[1].shape)
for col in cols]
svec = np.array(svec, dtype=new_dtype)
svec = svec.view((float, len(svec.dtype.names)))
return svec
viewed as a regular Numpy array of ``astype``.
:raises ValueError: if ``svec`` contains any fields that are incompatible
with ``astype``.
:Examples:
>>> import numpy as np
>>> from pypfilt.state import repack
>>> xs = np.array([(1.2, (2.2, 3.2)), (4.2, (5.2, 6.2))],
... dtype=[('x', float), ('y', float, 2)])
>>> ys = repack(xs)
>>> assert np.array_equal(ys, np.array([[1.2, 2.2, 3.2],
... [4.2, 5.2, 6.2]]))
"""
def is_compat(dt):
if dt.subdtype is None:
return np.issubdtype(dt, astype)
else:
return np.issubdtype(dt.subdtype[0], astype)
field_dtypes = {name: info[0] for name, info in svec.dtype.fields.items()}
incompat = [name for name, dt in field_dtypes.items()
if not is_compat(dt)]
if incompat:
msg = 'Fields {} are not compatible with type {}'
raise ValueError(msg.format(', '.join(incompat), astype))
out = numpy.lib.recfunctions.repack_fields(svec).view(astype)
new_shape = (*svec.shape, -1)
return np.squeeze(out.reshape(new_shape))
import numpy as np
import pytest
from pypfilt.state import repack
def test_repack_flat_float():
xs = np.array([(1.2, 2.2, 3.2), (4.2, 5.2, 6.2)],
dtype=[('x', float), ('y', float), ('z', float)])
ys = repack(xs)
assert ys.shape == (2, 3)
assert ys.dtype == np.dtype(float)
assert np.array_equal(ys, np.array([[1.2, 2.2, 3.2],
[4.2, 5.2, 6.2]]))
def test_repack_flat_int():
xs = np.array([(1, 2, 3), (4, 5, 6)],
dtype=[('x', int), ('y', int), ('z', int)])
ys = repack(xs, astype=int)
assert ys.shape == (2, 3)
assert ys.dtype == np.dtype(int)
assert np.array_equal(ys, np.array([[1, 2, 3],
[4, 5, 6]],
dtype=int))
def test_repack_flat_int_as_float():
xs = np.array([(1, 2, 3), (4, 5, 6)],
dtype=[('x', int), ('y', int), ('z', int)])
# NOTE: view integers as floats will cause trouble.
with pytest.raises(ValueError):
repack(xs, astype=float)
def test_repack_nested_float():
xs = np.array([(1.2, (2.2, 3.2)), (4.2, (5.2, 6.2))],
dtype=[('x', float), ('y', float, 2)])
ys = repack(xs)
assert ys.shape == (2, 3)
assert ys.dtype == np.dtype(float)
assert np.array_equal(ys, np.array([[1.2, 2.2, 3.2],
[4.2, 5.2, 6.2]]))
def test_repack_nested_int():
xs = np.array([(1, (2, 3)), (4, (5, 6))],
dtype=[('x', int), ('y', int, 2)])
ys = repack(xs, astype=int)
assert ys.shape == (2, 3)
assert ys.dtype == np.dtype(int)
assert np.array_equal(ys, np.array([[1, 2, 3],
[4, 5, 6]],
dtype=int))
def test_repack_nested_int_as_float():
xs = np.array([(1, (2, 3)), (4, (5, 6))],
dtype=[('x', int), ('y', int, 2)])
with pytest.raises(ValueError):
repack(xs, astype=float)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment