diff --git a/tests/test_obs_rename_columns.py b/tests/test_obs_rename_columns.py new file mode 100644 index 0000000000000000000000000000000000000000..9ecacdd079a3b6f3cca662f2b53eb22c4ebe1aed --- /dev/null +++ b/tests/test_obs_rename_columns.py @@ -0,0 +1,42 @@ +import h5py +import pypfilt +import scipy.stats + + +class FixedNormalObs(pypfilt.obs.Univariate): + def distribution(self, ctx, snapshot): + return scipy.stats.norm(loc=0, scale=1) + + +def test_obs_rename_columns(tmp_path): + """ + Ensure that table metadata is correct when loading observations from files + whose columns names are not ``'time'`` and ``'value'``. + """ + obs_unit = 'observations' + settings = {} + obs_model = FixedNormalObs(obs_unit, settings) + + obs_file = tmp_path / 'test_obs.ssv' + time_scale = pypfilt.Datetime() + with open(obs_file, 'w') as f: + f.write('When What\n') + f.write('2024-09-18 4\n') + f.write('2024-09-19 5\n') + f.write('2024-09-20 6\n') + + table = obs_model.from_file( + obs_file, time_scale, time_col='When', value_col='What' + ) + obs_file.unlink() + + # Verify that the observations table has correct metadata. + assert 'string_columns' in table.dtype.metadata + assert table.dtype.metadata['string_columns'] == [] + assert 'time_columns' in table.dtype.metadata + assert table.dtype.metadata['time_columns'] == ['time'] + + # Verify that we can save the observations table to disk. + out_file = tmp_path / 'test_obs.hdf5' + with h5py.File(out_file, 'w') as f: + pypfilt.io.save_dataset(time_scale, f, 'test_obs', table)