Monitor Model Runs
Models may be complex, built from many processes, and may take a while to run. xarray-simlab provides functionality to help monitor model runs.
This section demonstrates how to use the built-in progress bar. It also shows how to create your own custom monitoring.
Let’s use the following setup for the examples below. It is based on the advect_model created in Section Create and Modify Models.
In [1]: import numpy as np
...: import xsimlab as xs
In [2]: in_ds = xs.create_setup(
...: model=advect_model,
...: clocks={
...: 'time': np.linspace(0., 1., 6),
...: },
...: input_vars={
...: 'grid': {'length': 1.5, 'spacing': 0.01},
...: 'init': {'loc': 0.3, 'scale': 0.1},
...: 'advect__v': 1.
...: },
...: output_vars={'profile__u': 'time'}
...: )
...:
Progress bar
ProgressBar is based on the Tqdm package and lets you track the progress of simulation runs in xarray-simlab. It can be used as a context manager around simulation calls:
In [3]: from xsimlab.monitoring import ProgressBar
In [4]: with ProgressBar():
...:     out_ds = in_ds.xsimlab.run(model=advect_model)
...:
########## 100% | Simulation finished in 00:00
Alternatively, you can pass the progress bar via the hooks argument of Dataset.xsimlab.run(), or register it globally with its register method (for more information, refer to Section Custom runtime hooks).
ProgressBar and the underlying Tqdm tool are built to work with different Python front-ends. Use the optional argument frontend to select the one that matches your environment:
auto: automatically selects the front-end (default)
console: renders the progress bar as text
gui: rich graphical rendering (experimental), which requires matplotlib to be installed
notebook: for use within IPython/Jupyter notebooks, which requires ipywidgets to be installed
Additionally, you can customize the built-in progress bar by passing Tqdm keyword arguments to ProgressBar, e.g.:
In [5]: with ProgressBar(bar_format="{desc}|{bar}{r_bar}"):
...:     out_ds = in_ds.xsimlab.run(model=advect_model)
...:
Simulation finished in 00:00|##########| 7/7 [00:00<00:00, 921.39it/s]
For a full list of customization options, refer to the Tqdm documentation.
Note
Currently, this progress bar doesn’t support tracking the progress of batches of simulations. If those batches are run in parallel, you can use Dask’s diagnostics instead.
Custom runtime hooks
Custom monitoring can be implemented using simulation runtime hooks. The runtime_hook() decorator allows a function to be called once or repeatedly at specific times during a simulation. The simple example below prints the step number as the simulation proceeds:
In [6]: @xs.runtime_hook("run_step", "model", "pre")
...: def print_step_start(model, context, state):
...:     print(f"Starting execution of step {context['step']}")
...:
In [7]: out_ds = in_ds.xsimlab.run(model=advect_model, hooks=[print_step_start])
Starting execution of step 0
Starting execution of step 1
Starting execution of step 2
Starting execution of step 3
Starting execution of step 4
Runtime hook functions are always called with the following 3 arguments:
model: the instance of Model that is running
context: a read-only dictionary that contains information about the simulation runtime (see runtime() for a list of available keys)
state: a read-only dictionary that contains the simulation state, where keys are tuples in the form ('process_name', 'variable_name').
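Because a hook is an ordinary Python function, its logic can be exercised outside a simulation. The sketch below calls a plain (undecorated) function with hand-built stand-ins for the three arguments. The function name summarize_step, the mock dictionaries and their values are hypothetical; during a real run, xarray-simlab supplies these arguments itself.

```python
# Sketch: a hook-style function called by hand with mock arguments.
# `mock_context` and `mock_state` are hypothetical stand-ins for what
# xarray-simlab passes to a hook during a real run.

def summarize_step(model, context, state):
    """Return a short summary of the current step (no decorator applied)."""
    u = state[("profile", "u")]  # keys are ('process_name', 'variable_name')
    return f"step {context['step']}: max(u) = {max(u):.3f}"


mock_context = {"step": 2}
mock_state = {
    ("profile", "u"): [0.1, 0.8, 0.3],
    ("grid", "x"): [0.0, 0.01, 0.02],
}

print(summarize_step(None, mock_context, mock_state))  # step 2: max(u) = 0.800
```

Applying the xs.runtime_hook decorator to such a function, as in the examples above, is what makes it run at a given stage of a simulation.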
An alternative to passing decorated hook functions directly is the RuntimeHook class. You can create new instances that bundle any number of hook functions, e.g.:
In [8]: @xs.runtime_hook("run_step", "model", "post")
...: def print_step_end(model, context, state):
...:     print(f"Finished execution of step {context['step']}")
...:
In [9]: print_steps = xs.RuntimeHook(print_step_start, print_step_end)
In [10]: out_ds = in_ds.xsimlab.run(model=advect_model, hooks=[print_steps])
Starting execution of step 0
Finished execution of step 0
Starting execution of step 1
Finished execution of step 1
Starting execution of step 2
Finished execution of step 2
Starting execution of step 3
Finished execution of step 3
Starting execution of step 4
Finished execution of step 4
An advantage over using hook functions directly is that an instance of RuntimeHook can also be used either as a context manager around a model run call, or globally via its register method:
In [11]: with print_steps:
....:     out_ds = in_ds.xsimlab.run(model=advect_model)
....:
Starting execution of step 0
Finished execution of step 0
Starting execution of step 1
Finished execution of step 1
Starting execution of step 2
Finished execution of step 2
Starting execution of step 3
Finished execution of step 3
Starting execution of step 4
Finished execution of step 4
In [12]: print_steps.register()
In [13]: out_ds = in_ds.xsimlab.run(model=advect_model)
Starting execution of step 0
Finished execution of step 0
Starting execution of step 1
Finished execution of step 1
Starting execution of step 2
Finished execution of step 2
Starting execution of step 3
Finished execution of step 3
Starting execution of step 4
Finished execution of step 4
In [14]: print_steps.unregister()
In [15]: out_ds = in_ds.xsimlab.run(model=advect_model) # no print
Another advantage is that you can subclass RuntimeHook and add decorated methods that may share some state:
import time

class PrintStepTime(xs.RuntimeHook):

    @xs.runtime_hook("run_step", "model", "pre")
    def start_step(self, model, context, state):
        self._start_time = time.time()

    @xs.runtime_hook("run_step", "model", "post")
    def finish_step(self, model, context, state):
        step_time = time.time() - self._start_time
        print(f"Step {context['step']} took {step_time:.2e} seconds")
In [16]: with PrintStepTime():
....:     in_ds.xsimlab.run(model=advect_model)
....:
Step 0 took 1.56e-04 seconds
Step 1 took 1.30e-04 seconds
Step 2 took 1.21e-04 seconds
Step 3 took 1.19e-04 seconds
Step 4 took 1.17e-04 seconds
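The shared-state pattern itself does not depend on xarray-simlab. As a stdlib-only sketch, the hypothetical StepTimer class below keeps the start time on the instance between a "pre" and a "post" callback and collects durations in a dictionary instead of printing them. Here the two methods are called by hand to stand in for the runtime hooks. It uses time.perf_counter(), which is monotonic and generally better suited than time.time() to measuring short intervals.

```python
import time


class StepTimer:
    """Hypothetical stand-in for a RuntimeHook subclass that shares state."""

    def __init__(self):
        self._start_time = None
        self.durations = {}  # step number -> elapsed seconds

    def start_step(self, model, context, state):
        # Plays the role of the "pre" hook of the run_step stage.
        self._start_time = time.perf_counter()

    def finish_step(self, model, context, state):
        # Plays the role of the "post" hook; reads what start_step stored.
        self.durations[context["step"]] = time.perf_counter() - self._start_time


timer = StepTimer()
for step in range(3):  # stand-in for three simulation steps
    ctx = {"step": step}
    timer.start_step(None, ctx, {})
    sum(range(10_000))  # placeholder for the work done in one step
    timer.finish_step(None, ctx, {})

print(sorted(timer.durations))  # [0, 1, 2]
```

In a RuntimeHook subclass, the same two methods would carry the xs.runtime_hook decorators shown above; storing durations rather than printing them leaves the timings available for inspection after the run.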
Control simulation runtime
Runtime hook functions may return a RuntimeSignal to control the simulation workflow (e.g., skip the current stage or process, or break out of the simulation time-step loop) based on some condition or computed value.
In the example below, the simulation stops as soon as the Gaussian pulse (peak value) has been advected past x = 0.4.
In [17]: @xs.runtime_hook("run_step", "model", "post")
....: def maybe_stop(model, context, state):
....:     peak_idx = np.argmax(state[('profile', 'u')])
....:     peak_x = state[('grid', 'x')][peak_idx]
....:
....:     if peak_x > 0.4:
....:         print("Peak crossed x=0.4, stop simulation!")
....:         return xs.RuntimeSignal.BREAK
....:
In [18]: out_ds = in_ds.xsimlab.run(
....: model=advect_model,
....: hooks=[print_step_start, maybe_stop]
....: )
....:
Starting execution of step 0
Starting execution of step 1
Starting execution of step 2
Peak crossed x=0.4, stop simulation!
Even when a simulation stops early, as in the example above, the resulting xarray Dataset still contains all time steps defined in the input Dataset. Output variables have fill (masked) values for the time steps that were not run, as shown below by the nan values of profile__u (fill values are not physically stored in the Zarr output store).
In [19]: out_ds
Out[19]:
<xarray.Dataset>
Dimensions: (time: 6, x: 150)
Coordinates:
* time (time) float64 0.0 0.2 0.4 0.6 0.8 1.0
* x (x) float64 0.0 0.01 0.02 0.03 0.04 ... 1.46 1.47 1.48 1.49
Data variables:
advect__v float64 1.0
grid__length float64 1.5
grid__spacing float64 0.01
init__loc float64 0.3
init__scale float64 0.1
profile__u (time, x) float64 0.0001234 0.0002226 0.0003937 ... nan nan
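To make the interplay of "pre"/"post" hooks, a BREAK signal and fill values more concrete, here is a toy time loop in plain Python. It is not xarray-simlab's implementation: the Signal enum, run_toy_model, stop_past and the update rule are all hypothetical, and float("nan") plays the role of the fill value for snapshots that were never computed. As in the setup above, 5 steps give 6 clock values.

```python
import math
from enum import Enum


class Signal(Enum):
    """Hypothetical counterpart of a runtime signal."""
    NONE = 0
    BREAK = 1


def run_toy_model(n_steps, pre_hooks=(), post_hooks=()):
    """Advance a scalar position by 0.1 per step and record a snapshot.

    Returns one snapshot per clock value (n_steps + 1); snapshots for
    steps that never ran keep NaN, which stands in for a fill value.
    """
    snapshots = [math.nan] * (n_steps + 1)
    position = 0.3
    snapshots[0] = position
    for step in range(n_steps):
        context = {"step": step}
        pre = [hook(None, context, {"position": position}) for hook in pre_hooks]
        if Signal.BREAK in pre:
            break  # stop before this step's update
        position += 0.1  # the toy model update
        snapshots[step + 1] = position
        post = [hook(None, context, {"position": position}) for hook in post_hooks]
        if Signal.BREAK in post:
            break  # remaining snapshots keep their NaN fill value
    return snapshots


def stop_past(threshold):
    """Build a post hook that requests a break once position > threshold."""
    def hook(model, context, state):
        return Signal.BREAK if state["position"] > threshold else Signal.NONE
    return hook


def print_start(model, context, state):
    print(f"Starting execution of step {context['step']}")


result = run_toy_model(5, pre_hooks=[print_start], post_hooks=[stop_past(0.45)])
print(result)  # [0.3, 0.4, 0.5, nan, nan, nan]
```

Preallocating one slot per clock value is one simple way to keep every time step in the output even after an early stop; the real library handles this through its own output store.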