Skip to content
Snippets Groups Projects
Commit f19ed9be authored by Rob Moss's avatar Rob Moss
Browse files

Add support for custom parameter samplers

This allows for greater flexibility in constructing the model prior
distributions and, in particular, enables drop-in support for Latin
hypercube sampling, as used in our pandemic scenario modelling.
parent a31bcf31
No related branches found
No related tags found
No related merge requests found
Pipeline #6797 passed
......@@ -32,3 +32,5 @@ Forecast scenario data types
.. autodata:: ParamsFn
.. autodata:: SummaryFn
.. autoclass:: _Settings
......@@ -62,6 +62,7 @@ table), while others are likely of no use outside of pypfilt_ (see the
particle weights
:mod:`pypfilt.resample` Implements particle resampling and
post-regularisation
:mod:`pypfilt.sampler` Construct sampling functions for each model parameter
:mod:`pypfilt.state` Creates the state history matrix
:mod:`pypfilt.stats` Calculates weighted quantiles, credible intervals,
etc
......@@ -91,5 +92,6 @@ table), while others are likely of no use outside of pypfilt_ (see the
context
pfilter
resample
sampler
state
stats
pypfilt.sampler
===============
.. py:module:: pypfilt.sampler
The :mod:`pypfilt.sampler` module provides a sampler that draws samples from each model parameter's prior distribution:
.. autoclass:: Sampler
.. note::
This sampler is used by default, but you can explicitly specify it in your scenario files:
.. code-block:: toml
[components]
sampler = "pypfilt.Sampler"
Defining a custom sampler
-------------------------
A custom sampler should derive from the following base class:
.. autoclass:: Base
:members:
......@@ -11,6 +11,7 @@ from . import pfilter
from . import model
from . import obs
from . import params
from . import sampler
from . import summary
from . import sweep
from . import time
......@@ -30,6 +31,7 @@ Config = config.Config
Context = context.Context
Obs = obs.Obs
Model = model.Model
Sampler = sampler.Sampler
Monitor = summary.Monitor
Table = summary.Table
Datetime = time.Datetime
......
......@@ -96,10 +96,22 @@ class _Settings(NamedTuple):
"""
The configuration settings that apply in a particular context, such as
global settings or scenario-specific settings.
:param time: The simulation time scale component.
:param model: The simulation model component.
:param summary_fn: The simulation summary component constructor.
:param sampler: The model parameter sampler component.
:param params: The scenario parameters table.
:param priors: The ``model.priors`` section of the scenario parameters.
:param bounds: The ``model.bounds`` section of the scenario parameters.
:param summary_args: The constructor arguments for the summary component.
:param summary_monitors: The summary monitor components, indexed by name.
:param summary_tables: The summary table components, indexed by name.
"""
time: Any
model: Any
summary_fn: Any
sampler: Any
params: Dict[str, Any]
priors: Dict[str, Any]
bounds: Dict[str, Any]
......@@ -147,8 +159,13 @@ def make_settings(cfg_data: Dict[str, Any]) -> _Settings:
model = instantiate(__get(components, 'model', 'Components'))
time = instantiate(__get(components, 'time', 'Components'))
summary_fn = lookup(__get(components, 'summary', 'Components'))
if 'sampler' in components:
sampler = instantiate(components['sampler'])
else:
from .sampler import Sampler
sampler = Sampler()
for key in components:
if key not in ['model', 'time', 'summary']:
if key not in ['model', 'time', 'summary', 'sampler']:
logger.warning('unrecognised component "%s"', key)
model_priors = {}
......@@ -198,6 +215,7 @@ def make_settings(cfg_data: Dict[str, Any]) -> _Settings:
time=time,
model=model,
summary_fn=summary_fn,
sampler=sampler,
params=param_data,
priors=model_priors,
bounds=model_bounds,
......@@ -251,31 +269,6 @@ def override_dict(defaults, overrides):
return defaults
def validate_priors(settings: _Settings) -> None:
    """
    Check that every prior is described by a function name (a string) and a
    dictionary of arguments.

    Prior names are deliberately **not** checked against the model's known
    parameters, because doing so would prevent us from supporting priors
    that are expressed in terms of **transformed** parameters (such as
    reciprocals of rate parameters).

    Any keys other than ``'function'`` and ``'args'`` are reported with a
    logged warning, but are otherwise accepted.

    :param settings: The scenario settings; only ``settings.priors`` is read.
    :raises ValueError: If a prior has no ``'function'`` or ``'args'`` entry,
        if the function is not a string, or if the arguments are not a
        dictionary.
    """
    logger = logging.getLogger(__name__)
    expected_keys = ('function', 'args')
    for name, info in settings.priors.items():
        if 'function' not in info:
            raise ValueError('Missing prior function for {}'.format(name))
        if not isinstance(info['function'], str):
            raise ValueError('Invalid prior function for {}'.format(name))
        if 'args' not in info:
            raise ValueError('Missing prior arguments for {}'.format(name))
        if not isinstance(info['args'], dict):
            raise ValueError('Invalid prior arguments for {}'.format(name))
        if len(info) == 2:
            # Both required keys are present, so there is nothing extra.
            continue
        extra_keys = [key for key in info if key not in expected_keys]
        logger.warning('Extra prior keys for %s: %s', name, extra_keys)
def validate_bounds(settings: _Settings) -> None:
"""Ensure bounds only refer to known parameters, or raise a ValueError."""
descr = set(settings.model.field_names())
......@@ -376,42 +369,6 @@ def define_summary_tables(settings, params):
params['component']['summary_table'][name] = table
def make_prior_fn(fn_name, args):
    """
    Build a sampling function for a single prior distribution.

    The returned function has the signature ``fn(r, size=None)``, where ``r``
    is the random number generator to draw from.

    :param fn_name: The name of a ``numpy.random.Generator`` method used to
        generate samples.
    :param args: A dictionary of keyword arguments.

    As a special case, ``fn_name`` may be set to ``'inverse_uniform'`` to
    sample from a uniform distribution and then take the reciprocal:

    .. math:: X \\sim \\frac{1}{\\mathcal{U}(a, b)}

    The bounds ``a`` and ``b`` may be specified by the following keyword
    arguments:

    + :code:`a = args['low']` **or** :code:`a = 1 / args['inv_low']`
    + :code:`b = args['high']` **or** :code:`b = 1 / args['inv_high']`
    """
    if fn_name != 'inverse_uniform':
        def sample(r, size=None):
            # Look up the generator method by name on every call.
            return getattr(r, fn_name)(**args, size=size)

        return sample

    # Resolve the bounds once, when the sampling function is constructed.
    low = args['low'] if 'low' in args else 1 / args['inv_low']
    high = args['high'] if 'high' in args else 1 / args['inv_high']

    def sample_reciprocal(r, size=None):
        return 1 / r.uniform(low=low, high=high, size=size)

    return sample_reciprocal
def make_params_fn(scen_data, settings: _Settings,
scen_id, scen_name) -> Callable[[], dict]:
"""
......@@ -430,6 +387,8 @@ def make_params_fn(scen_data, settings: _Settings,
max_days=max_days,
px_count=px_count,
prng_seed=prng_seed)
# Replace the default sampler with the chosen sampler.
params['component']['sampler'] = settings.sampler
params = override_dict(params, overrides)
# NOTE: ensure the start and end of the simulation period are
......@@ -456,11 +415,8 @@ def make_params_fn(scen_data, settings: _Settings,
params['model']['param_bounds'][name] = (info['min'],
info['max'])
# Allow priors to specify an RNG sampling function and arguments.
if settings.priors:
for (name, prior) in settings.priors.items():
params['model']['prior'][name] = make_prior_fn(
prior['function'], prior['args'])
# Ask the sampler to prepare the prior distributions.
settings.sampler.update_params(scen_data, settings, params)
create_lookup_tables(scen_data, settings, params)
define_summary_tables(settings, params)
......@@ -485,7 +441,7 @@ def make_scenario(scen_id: str, scen_data: dict, defaults: dict) -> Scenario:
cfg_data = override_dict(defaults, scen_data)
settings = make_settings(cfg_data)
obs_models = make_obs_models(cfg_data, scen_name)
validate_priors(settings)
settings.sampler.prepare_scenario(cfg_data, settings)
validate_bounds(settings)
params_fn = make_params_fn(cfg_data, settings, scen_id, scen_name)
......
import numpy as np
import tempfile
......@@ -16,6 +15,9 @@ def default_params(model, time_scale, max_days, px_count, prng_seed):
:param px_count: The number of particles.
:param prng_seed: The seed for the pseudo-random number generators.
"""
# NOTE: avoid circular imports by importing this module here.
from .sampler import Sampler
bounds = model.bounds()
params = {
'resample': {
......@@ -54,6 +56,7 @@ def default_params(model, time_scale, max_days, px_count, prng_seed):
'component': {
'time': time_scale,
'model': model,
'sampler': Sampler(),
'random': {},
'lookup': {},
'obs': {},
......
"""Construct sampling functions for each model parameter."""
import abc
import logging
from .config import _Settings
class Base(abc.ABC):
    """
    The abstract interface that every parameter sampler must implement.

    Subclasses must define both :meth:`prepare_scenario` and
    :meth:`update_params`.
    """

    @abc.abstractmethod
    def prepare_scenario(self, scen_data, settings):
        """
        Carry out any validation, or other preparation, that is required
        for the given scenario.

        :param scen_data: The configuration dictionary for this scenario.
        :param settings: The collected settings for this scenario
            (:class:`~pypfilt.config._Settings`).
        """

    @abc.abstractmethod
    def update_params(self, scen_data, settings, params):
        """
        Define the sampling function for each model parameter, storing each
        function in ``params['model']['prior'][param_name]``.

        :param scen_data: The configuration dictionary for this scenario.
        :param settings: The collected settings for this scenario
            (:class:`~pypfilt.config._Settings`).
        :param params: The simulation parameters dictionary.
        """
def make_prior_fn(fn_name, args):
    """
    Build a sampling function for a single prior distribution.

    The returned function has the signature ``fn(r, size=None)``, where ``r``
    is the random number generator to draw from.

    :param fn_name: The name of a ``numpy.random.Generator`` method used to
        generate samples.
    :param args: A dictionary of keyword arguments.

    As a special case, ``fn_name`` may be set to ``'inverse_uniform'`` to
    sample from a uniform distribution and then take the reciprocal:

    .. math:: X \\sim \\frac{1}{\\mathcal{U}(a, b)}

    The bounds ``a`` and ``b`` may be specified by the following keyword
    arguments:

    + :code:`a = args['low']` **or** :code:`a = 1 / args['inv_low']`
    + :code:`b = args['high']` **or** :code:`b = 1 / args['inv_high']`
    """
    if fn_name != 'inverse_uniform':
        def sample(r, size=None):
            # Look up the generator method by name on every call.
            return getattr(r, fn_name)(**args, size=size)

        return sample

    # Resolve the bounds once, when the sampling function is constructed.
    low = args['low'] if 'low' in args else 1 / args['inv_low']
    high = args['high'] if 'high' in args else 1 / args['inv_high']

    def sample_reciprocal(r, size=None):
        return 1 / r.uniform(low=low, high=high, size=size)

    return sample_reciprocal
def validate_priors(settings: _Settings) -> None:
    """
    Check that every prior is described by a function name (a string) and a
    dictionary of arguments.

    Prior names are deliberately **not** checked against the model's known
    parameters, because doing so would prevent us from supporting priors
    that are expressed in terms of **transformed** parameters (such as
    reciprocals of rate parameters).

    Any keys other than ``'function'`` and ``'args'`` are reported with a
    logged warning, but are otherwise accepted.

    :param settings: The scenario settings; only ``settings.priors`` is read.
    :raises ValueError: If a prior has no ``'function'`` or ``'args'`` entry,
        if the function is not a string, or if the arguments are not a
        dictionary.
    """
    logger = logging.getLogger(__name__)
    expected_keys = ('function', 'args')
    for name, info in settings.priors.items():
        if 'function' not in info:
            raise ValueError('Missing prior function for {}'.format(name))
        if not isinstance(info['function'], str):
            raise ValueError('Invalid prior function for {}'.format(name))
        if 'args' not in info:
            raise ValueError('Missing prior arguments for {}'.format(name))
        if not isinstance(info['args'], dict):
            raise ValueError('Invalid prior arguments for {}'.format(name))
        if len(info) == 2:
            # Both required keys are present, so there is nothing extra.
            continue
        extra_keys = [key for key in info if key not in expected_keys]
        logger.warning('Extra prior keys for %s: %s', name, extra_keys)
class Sampler(Base):
    """
    The default sampler, which draws independent samples for each model
    parameter.
    """

    def prepare_scenario(self, scen_data, settings):
        """
        Validate the prior definitions in ``settings.priors``.

        :param scen_data: The configuration dictionary for this scenario
            (not used by this sampler).
        :param settings: The collected settings for this scenario.
        """
        validate_priors(settings)

    def update_params(self, scen_data, settings, params):
        """
        Store a sampling function for each entry of ``settings.priors`` in
        ``params['model']['prior']``, keyed by the prior's name.

        :param scen_data: The configuration dictionary for this scenario
            (not used by this sampler).
        :param settings: The collected settings for this scenario.
        :param params: The simulation parameters dictionary.
        """
        if not settings.priors:
            # No priors were configured; leave the parameters untouched.
            return
        prior_table = params['model']['prior']
        for name, prior in settings.priors.items():
            prior_table[name] = make_prior_fn(prior['function'],
                                              prior['args'])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment