import datetime
import cmocean # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from polaris import Step
from polaris.ocean.convergence import (
get_resolution_for_task,
get_timestep_for_task,
)
from polaris.tasks.ocean.inertial_gravity_wave.exact_solution import (
ExactSolution,
)
from polaris.viz import plot_horiz_field, use_mplstyle
[docs]
class Viz(Step):
"""
A step for visualizing the output from the inertial gravity wave
test case
Attributes
----------
dependencies_dict : dict of dict of polaris.Steps
The dependencies of this step must be given as separate keys in the
dict:
mesh : dict of polaris.Steps
Keys of the dict correspond to `refinement_factors`
Values of the dict are polaris.Steps, which must have the
attribute `path`, the path to `base_mesh.nc` of that
resolution
init : dict of polaris.Steps
Keys of the dict correspond to `refinement_factors`
Values of the dict are polaris.Steps, which must have the
attribute `path`, the path to `initial_state.nc` of that
resolution
forward : dict of polaris.Steps
Keys of the dict correspond to `refinement_factors`
Values of the dict are polaris.Steps, which must have the
attribute `path`, the path to `forward.nc` of that
resolution
refinement : str, optional
Refinement type. One of 'space', 'time' or 'both' indicating both
space and time
"""
[docs]
def __init__(self, component, subdir, dependencies, refinement='both'):
"""
Create the step
Parameters
----------
component : polaris.Component
The component the step belongs to
subdir : str
The subdirectory for this step in the component's work directory
dependencies : dict of dict of polaris.Steps
The dependencies of this step must be given as separate keys in the
dict:
mesh : dict of polaris.Steps
Keys of the dict correspond to `refinement_factors`
Values of the dict are polaris.Steps, which must have the
attribute `path`, the path to `base_mesh.nc` of that
resolution
init : dict of polaris.Steps
Keys of the dict correspond to `refinement_factors`
Values of the dict are polaris.Steps, which must have the
attribute `path`, the path to `initial_state.nc` of that
resolution
forward : dict of polaris.Steps
Keys of the dict correspond to `refinement_factors`
Values of the dict are polaris.Steps, which must have the
attribute `path`, the path to `forward.nc` of that
resolution
refinement : str, optional
Refinement type. One of 'space', 'time' or 'both' indicating both
space and time
"""
super().__init__(component=component, name='viz', subdir=subdir)
self.dependencies_dict = dependencies
self.refinement = refinement
self.add_output_file('comparison.png')
def setup(self):
"""
Add input files based on resolutions, which may have been changed by
user config options
"""
super().setup()
config = self.config
dependencies = self.dependencies_dict
if self.refinement == 'time':
option = 'refinement_factors_time'
else:
option = 'refinement_factors_space'
refinement_factors = config.getlist('convergence', option, dtype=float)
for refinement_factor in refinement_factors:
base_mesh = dependencies['mesh'][refinement_factor]
init = dependencies['init'][refinement_factor]
forward = dependencies['forward'][refinement_factor]
self.add_input_file(
filename=f'mesh_r{refinement_factor:02g}.nc',
work_dir_target=f'{base_mesh.path}/base_mesh.nc',
)
self.add_input_file(
filename=f'init_r{refinement_factor:02g}.nc',
work_dir_target=f'{init.path}/initial_state.nc',
)
self.add_input_file(
filename=f'output_r{refinement_factor:02g}.nc',
work_dir_target=f'{forward.path}/output.nc',
)
[docs]
def run(self):
"""
Run this step of the test case
"""
plt.switch_backend('Agg')
config = self.config
if self.refinement == 'time':
option = 'refinement_factors_time'
else:
option = 'refinement_factors_space'
refinement_factors = config.getlist('convergence', option, dtype=float)
nres = len(refinement_factors)
section = config['inertial_gravity_wave']
eta0 = section.getfloat('ssh_amplitude')
use_mplstyle()
fig, axes = plt.subplots(nrows=nres, ncols=3, figsize=(12, 2 * nres))
rmse = []
error_range = None
for i, refinement_factor in enumerate(refinement_factors):
resolution = get_resolution_for_task(
config, refinement_factor, refinement=self.refinement
)
ds_mesh = xr.open_dataset(f'mesh_r{refinement_factor:02g}.nc')
ds_init = xr.open_dataset(f'init_r{refinement_factor:02g}.nc')
ds = xr.open_dataset(f'output_r{refinement_factor:02g}.nc')
exact = ExactSolution(ds_init, config)
t0 = datetime.datetime.strptime(
ds.xtime.values[0].decode(), '%Y-%m-%d_%H:%M:%S'
)
tf = datetime.datetime.strptime(
ds.xtime.values[-1].decode(), '%Y-%m-%d_%H:%M:%S'
)
t = (tf - t0).total_seconds()
ssh_model = ds.ssh.values[-1, :]
rmse.append(
np.sqrt(np.mean((ssh_model - exact.ssh(t).values) ** 2))
)
# Comparison plots
ds['ssh_exact'] = exact.ssh(t)
ds['ssh_error'] = ssh_model - exact.ssh(t)
if error_range is None:
error_range = np.max(np.abs(ds.ssh_error.values))
cell_mask = ds_init.maxLevelCell >= 1
descriptor = plot_horiz_field(
ds_mesh,
ds['ssh'],
ax=axes[i, 0],
cmap='cmo.balance',
t_index=ds.sizes['Time'] - 1,
vmin=-eta0,
vmax=eta0,
cmap_title='SSH (m)',
field_mask=cell_mask,
)
plot_horiz_field(
ds_mesh,
ds['ssh_exact'],
ax=axes[i, 1],
cmap='cmo.balance',
vmin=-eta0,
vmax=eta0,
cmap_title='SSH (m)',
descriptor=descriptor,
)
plot_horiz_field(
ds_mesh,
ds['ssh_error'],
ax=axes[i, 2],
cmap='cmo.balance',
cmap_title=r'$\Delta$ SSH (m)',
vmin=-error_range,
vmax=error_range,
descriptor=descriptor,
)
axes[0, 0].set_title('Numerical solution')
axes[0, 1].set_title('Analytical solution')
axes[0, 2].set_title('Error (Numerical - Analytical)')
pad = 5
for ax, refinement_factor in zip(
axes[:, 0], refinement_factors, strict=False
):
timestep, _ = get_timestep_for_task(
config, refinement_factor, refinement=self.refinement
)
resolution = get_resolution_for_task(
config, refinement_factor, refinement=self.refinement
)
ax.annotate(
f'{resolution}km\n{timestep}s',
xy=(0, 0.5),
xytext=(-ax.yaxis.labelpad - pad, 0),
xycoords=ax.yaxis.label,
textcoords='offset points',
size='large',
ha='right',
va='center',
)
fig.savefig('comparison.png', bbox_inches='tight', pad_inches=0.1)