Skip to content
Snippets Groups Projects
Commit 692bf982 authored by Rob Moss's avatar Rob Moss
Browse files

Store parameter bounds by name, not index

parent d09f02e6
No related branches found
No related tags found
No related merge requests found
Pipeline #6515 failed
......@@ -425,9 +425,6 @@ def make_params_fn(scen_data, settings: _Settings,
overrides = {k: v for (k, v) in settings.params.items()
if k not in ignore}
descr = {field[0]: ix for (ix, field) in
enumerate(settings.model.state_fields())}
def params_fn():
params = default_params(settings.model,
settings.time,
......@@ -456,10 +453,9 @@ def make_params_fn(scen_data, settings: _Settings,
if settings.bounds:
for (name, info) in settings.bounds.items():
# NOTE: apply parameter bounds.
param_ix = descr[name]
params['model']['param_min'][param_ix] = info['min']
params['model']['param_max'][param_ix] = info['max']
# NOTE: apply parameter bounds by name.
params['model']['param_bounds'][name] = (info['min'],
info['max'])
# Allow priors to specify an RNG sampling function and arguments.
if settings.priors:
......
......@@ -16,10 +16,7 @@ def default_params(model, time_scale, max_days, px_count, prng_seed):
:param px_count: The number of particles.
:param prng_seed: The seed for the pseudo-random number generators.
"""
field_names = [field[0] for field in model.state_fields()]
bounds = model.bounds()
p_min = [bounds[name][0] for name in field_names]
p_max = [bounds[name][1] for name in field_names]
params = {
'resample': {
# Resample when the effective number of particles is 25%.
......@@ -80,10 +77,8 @@ def default_params(model, time_scale, max_days, px_count, prng_seed):
'until': None,
},
'model': {
# The lower bounds for each model parameter.
'param_min': np.array(p_min),
# The upper bounds for each model parameter.
'param_max': np.array(p_max),
# The lower and upper bounds for each model parameter.
'param_bounds': bounds,
# The model prior distributions.
'prior': {},
},
......
......@@ -111,12 +111,8 @@ def post_regularise(ctx, px, new_px):
scaled_samples = np.transpose(np.dot(a_mat, h * std_samples))
# Add the sampled noise and clip to respect parameter bounds.
smooth_ixs = [ix for (ix, name) in enumerate(field_names)
if name in smooth_fields]
for (ix, name) in enumerate(smooth_fields):
param_ix = smooth_ixs[ix]
min_val = ctx.params['model']['param_min'][param_ix]
max_val = ctx.params['model']['param_max'][param_ix]
(min_val, max_val) = ctx.params['model']['param_bounds'][name]
new_px['state_vec'][name] = np.clip(
new_px['state_vec'][name] + scaled_samples[:, ix],
min_val, max_val)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment