import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable
from polaris.step import Step
class VizCombinedStep(Step):
    """
    A step for visualizing the combined topography dataset

    Attributes
    ----------
    combine_step : polaris.tasks.e3sm.init.topo.combine.CombineStep
        The combine step to use for visualization
    """

    def __init__(self, component, config, combine_step):
        """
        Create a new step

        Parameters
        ----------
        component : polaris.Component
            The component the step belongs to

        config : polaris.config.PolarisConfigParser
            The config options for the step

        combine_step : polaris.tasks.e3sm.init.topo.combine.CombineStep
            The combine step to use for visualization
        """
        super().__init__(
            component=component,
            name='viz_combine_topo',
            subdir=os.path.join(combine_step.subdir, 'viz'),
            cpus_per_task=128,
            min_cpus_per_task=1,
        )
        self.combine_step = combine_step
        self.set_shared_config(config, link='combine_topo.cfg')

    def setup(self):
        """
        Set up the step in the work directory, including linking input files

        The combined topography file and the cubed-sphere Exodus mesh from
        the combine step are linked as ``topography.nc`` and
        ``cubed_sphere.g``, respectively.
        """
        combine_step = self.combine_step
        topo_filename = combine_step.combined_filename
        exodus_filename = combine_step.exodus_filename
        self.add_input_file(
            filename='topography.nc',
            work_dir_target=os.path.join(combine_step.path, topo_filename),
        )
        self.add_input_file(
            filename='cubed_sphere.g',
            work_dir_target=os.path.join(combine_step.path, exodus_filename),
        )

    def run(self):
        """
        Run this step

        Each field in ``colormaps`` is rasterized on the cubed-sphere mesh
        and saved as ``<field>.png``.
        """
        # NOTE(review): matplotlib only recognizes the 'cmo.*' colormap
        # names after ``cmocean`` has been imported; this module does not
        # import it -- confirm it is imported elsewhere before plotting.
        colormaps = {
            'bathymetry': 'cmo.deep_r',
            'thickness': 'cmo.ice_r',
            'ice_draft': 'cmo.deep_r',
            'ice_mask': 'cmo.amp_r',
            'bathymetry_mask': 'cmo.amp_r',
            'grounded_mask': 'cmo.amp_r',
            'ocean_mask': 'cmo.amp_r',
            'water_column': 'cmo.deep',
        }

        with xr.open_dataset('topography.nc') as ds_data:
            # Use one field to define the valid mask (they all share
            # indexing with the mesh elements)
            valid_mask = np.isfinite(ds_data['bathymetry'].values)

            # Build mesh only once
            vertices, tris = self._load_trimesh_geometry(
                'cubed_sphere.g', valid_mask
            )

            # Plot each field
            for field, colormap in colormaps.items():
                self.logger.info(f'Plotting field: {field}')
                data = ds_data[field].values[valid_mask]
                self._plot_field(vertices, tris, data, field, colormap)

    @staticmethod
    def _load_trimesh_geometry(exodus_path, valid_mask):
        """
        Build a lon/lat triangle mesh from the quads of an Exodus mesh

        Parameters
        ----------
        exodus_path : str
            Path to an Exodus file with ``coord`` and ``connect1`` variables

        valid_mask : numpy.ndarray
            Boolean mask selecting which elements (quads) to keep

        Returns
        -------
        vertices : pandas.DataFrame
            One row per vertex with ``lon`` and ``lat`` columns in degrees

        tris : pandas.DataFrame
            One row per triangle with vertex indices ``v0``, ``v1``, ``v2``
            and ``cell``, the index of the kept element (in the order of
            the masked field values) that the triangle comes from
        """
        with xr.open_dataset(exodus_path, decode_coords=False) as ds_mesh:
            coords = ds_mesh['coord'].values
            # Exodus connectivity is 1-based; convert to 0-based
            conn = ds_mesh['connect1'].values - 1

        x, y, z = coords[0], coords[1], coords[2]
        # arctan2 is robust to round-off, unlike arcsin(z / r), which gives
        # NaN if |z / r| slightly exceeds 1
        lat_nodes = np.degrees(np.arctan2(z, np.hypot(x, y)))
        lon_nodes = np.degrees(np.arctan2(y, x))

        # Apply element mask to connectivity; shape (n_cells, 4)
        conn_valid = conn[valid_mask]
        n_cells = conn_valid.shape[0]

        # Split each quad into 2 triangles: [0, 1, 2] and [0, 2, 3]
        tri_nodes = np.empty((2 * n_cells, 3), dtype=np.int64)
        # lower triangle
        tri_nodes[0::2] = conn_valid[:, [0, 1, 2]]
        # upper triangle
        tri_nodes[1::2] = conn_valid[:, [0, 2, 3]]
        # the (masked) element each triangle belongs to
        cells = np.repeat(np.arange(n_cells), 2)

        lon_nodes, lat_nodes, tri_nodes, cells = (
            VizCombinedStep._duplicate_antimeridian_triangles(
                lon_nodes, lat_nodes, tri_nodes, cells
            )
        )

        # Convert to DataFrame: each row is a triangle with 3 vertex indices
        tris = pd.DataFrame(tri_nodes, columns=['v0', 'v1', 'v2'])
        tris['cell'] = cells
        # Convert vertices to DataFrame: each row is a vertex with lon/lat
        vertices = pd.DataFrame({'lon': lon_nodes, 'lat': lat_nodes})
        return vertices, tris

    @staticmethod
    def _duplicate_antimeridian_triangles(lon_nodes, lat_nodes, tris, cells):
        """
        Replace triangles straddling the antimeridian with shifted copies

        With longitudes in (-180, 180], a triangle that crosses the
        antimeridian appears to span more than 180 degrees of longitude (with
        reversed winding), so it is not rasterized where it belongs. Each
        such triangle is replaced by an east copy (negative longitudes
        shifted by +360) and a west copy (positive longitudes shifted by
        -360); each copy lies partly outside [-180, 180], so together they
        fill both edges of the map. Copies get their own new vertices,
        appended after the existing ones.

        Parameters
        ----------
        lon_nodes, lat_nodes : numpy.ndarray
            Vertex longitudes and latitudes in degrees

        tris : numpy.ndarray
            Vertex indices of each triangle, shape (n_tris, 3)

        cells : numpy.ndarray
            Element index of each triangle, shape (n_tris,)

        Returns
        -------
        lon_nodes, lat_nodes, tris, cells : numpy.ndarray
            The updated arrays (unchanged if no triangle wraps)
        """
        tri_lon = lon_nodes[tris]
        wraps = (tri_lon.max(axis=1) - tri_lon.min(axis=1)) > 180.0
        if not np.any(wraps):
            return lon_nodes, lat_nodes, tris, cells

        wrap_lon = tri_lon[wraps]
        wrap_lat = lat_nodes[tris[wraps]]
        east_lon = np.where(wrap_lon < 0.0, wrap_lon + 360.0, wrap_lon)
        west_lon = np.where(wrap_lon > 0.0, wrap_lon - 360.0, wrap_lon)

        n_wrap = wrap_lon.shape[0]
        # rows [0, n_wrap) are the east copies, [n_wrap, 2 * n_wrap) the
        # west copies, each pointing at 3 freshly appended vertices
        new_tris = lon_nodes.size + np.arange(6 * n_wrap).reshape(-1, 3)

        lon_nodes = np.concatenate(
            [lon_nodes, east_lon.ravel(), west_lon.ravel()]
        )
        lat_nodes = np.concatenate(
            [lat_nodes, wrap_lat.ravel(), wrap_lat.ravel()]
        )

        # Keep the regular triangles first: datashader's mesh helper infers
        # the winding from the first simplex
        keep = ~wraps
        tris = np.concatenate([tris[keep], new_tris])
        cells = np.concatenate([cells[keep], cells[wraps], cells[wraps]])
        return lon_nodes, lat_nodes, tris, cells

    def _plot_field(self, vertices, tris, field_data, field_name, cmap):
        """
        Rasterize and save a trisurf-style field image using Datashader.

        Parameters
        ----------
        vertices : pandas.DataFrame
            Vertex ``lon``/``lat`` table from ``_load_trimesh_geometry()``

        tris : pandas.DataFrame
            Triangle table (``v0``, ``v1``, ``v2``, ``cell``) from
            ``_load_trimesh_geometry()``; it is not modified

        field_data : numpy.ndarray
            One value per masked element

        field_name : str
            Name of the field, used for the colorbar label and file name

        cmap : str
            Name of the matplotlib colormap
        """
        try:
            import datashader
        except ImportError as err:
            raise ImportError(
                'the datashader package is not installed. '
                'Please install in your conda environment so you can run '
                'the topography visualization step.'
            ) from err
        import numba

        # numba raises a ValueError if asked for more threads than it
        # launched with (NUMBA_NUM_THREADS)
        numba.set_num_threads(
            min(self.cpus_per_task, numba.config.NUMBA_NUM_THREADS)
        )

        image_filename = f'{field_name}.png'

        # Look up each triangle's value from its element. Datashader's mesh
        # helper reads the per-triangle weight from the 4th column, so
        # 'value' must directly follow the vertex indices; building a new
        # frame also leaves the shared ``tris`` untouched.
        simplices = tris[['v0', 'v1', 'v2']].assign(
            value=np.asarray(field_data)[tris['cell'].to_numpy()]
        )

        canvas = datashader.Canvas(
            plot_width=2000,
            plot_height=1000,
            x_range=(-180, 180),
            y_range=(-90, 90),
        )
        agg = canvas.trimesh(
            simplices=simplices,
            vertices=vertices,
            agg=datashader.mean('value'),
        )
        self._plot_with_colorbar(
            agg, cmap=cmap, field_name=field_name, filename=image_filename
        )

    def _plot_with_colorbar(
        self, agg, cmap, field_name, vmin=None, vmax=None, filename=None
    ):
        """
        Render a datashader aggregate with matplotlib colorbar.

        Parameters
        ----------
        agg : xarray.DataArray
            The datashader aggregate; NaN pixels are left blank

        cmap : str
            Name of the matplotlib colormap

        field_name : str
            Colorbar label

        vmin, vmax : float, optional
            Color limits; default to the range of the valid pixels

        filename : str, optional
            Output image file; defaults to ``<field_name>.png``
        """
        if filename is None:
            filename = f'{field_name}.png'

        # Convert the aggregate (xarray) to numpy and mask NaNs (background)
        img_data = np.ma.masked_invalid(agg.data).astype('float32')

        # Normalize range
        norm = Normalize(
            vmin=np.nanmin(img_data) if vmin is None else vmin,
            vmax=np.nanmax(img_data) if vmax is None else vmax,
        )

        # Plot; datashader aggregates have y increasing upward, hence
        # origin='lower'
        fig, ax = plt.subplots(figsize=(22, 10))
        im = ax.imshow(
            img_data, cmap=cmap, norm=norm, origin='lower', aspect='equal'
        )
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='2%', pad=0.05)
        fig.colorbar(im, cax=cax, label=field_name)
        ax.axis('off')
        fig.savefig(filename, dpi=150, bbox_inches='tight')
        self.logger.info(f' Plot with colorbar saved to {filename}')
        plt.close(fig)