From 1a14e6c60f2d6b1813ec36cc7264e7a8dcd16ec4 Mon Sep 17 00:00:00 2001 From: Rob Moss <robm.dev@gmail.com> Date: Wed, 18 Sep 2024 09:36:14 +1000 Subject: [PATCH] Detect invalid metadata when loading observations --- tests/test_obs_rename_columns.py | 42 ++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/test_obs_rename_columns.py diff --git a/tests/test_obs_rename_columns.py b/tests/test_obs_rename_columns.py new file mode 100644 index 0000000..9ecacdd --- /dev/null +++ b/tests/test_obs_rename_columns.py @@ -0,0 +1,42 @@ +import h5py +import pypfilt +import scipy.stats + + +class FixedNormalObs(pypfilt.obs.Univariate): + def distribution(self, ctx, snapshot): + return scipy.stats.norm(loc=0, scale=1) + + +def test_obs_rename_columns(tmp_path): + """ + Ensure that table metadata is correct when loading observations from files + whose columns names are not ``'time'`` and ``'value'``. + """ + obs_unit = 'observations' + settings = {} + obs_model = FixedNormalObs(obs_unit, settings) + + obs_file = tmp_path / 'test_obs.ssv' + time_scale = pypfilt.Datetime() + with open(obs_file, 'w') as f: + f.write('When What\n') + f.write('2024-09-18 4\n') + f.write('2024-09-19 5\n') + f.write('2024-09-20 6\n') + + table = obs_model.from_file( + obs_file, time_scale, time_col='When', value_col='What' + ) + obs_file.unlink() + + # Verify that the observations table has correct metadata. + assert 'string_columns' in table.dtype.metadata + assert table.dtype.metadata['string_columns'] == [] + assert 'time_columns' in table.dtype.metadata + assert table.dtype.metadata['time_columns'] == ['time'] + + # Verify that we can save the observations table to disk. + out_file = tmp_path / 'test_obs.hdf5' + with h5py.File(out_file, 'w') as f: + pypfilt.io.save_dataset(time_scale, f, 'test_obs', table) -- GitLab