# Source code for polaris.tasks.ocean.customizable_viz.viz_transect

import os

import cmocean  # noqa: F401
import numpy as np
import xarray as xr
from mpas_tools.io import write_netcdf
from mpas_tools.ocean.viz.transect import compute_transect, plot_transect

from polaris.ocean.model import OceanIOStep as OceanIOStep
from polaris.viz import (
    determine_time_variable,
    get_viz_defaults,
)


class VizTransect(OceanIOStep):
    """
    A step for visualizing MPAS vertical transects

    File paths are read from the ``customizable_viz`` config section; the
    variables to plot, the transect end points and the color-scale options
    are read from the ``customizable_viz_transect`` section.
    """

    def __init__(self, component, name, indir):
        """
        Create the step

        Parameters
        ----------
        component
            Passed unchanged to ``OceanIOStep.__init__``

        name : str
            The name of the step, passed to ``OceanIOStep.__init__``

        indir
            Passed unchanged to ``OceanIOStep.__init__``
        """
        super().__init__(component=component, name=name, indir=indir)

    def runtime_setup(self):
        """
        Read the file paths and the list of variables from the config

        Raises
        ------
        ValueError
            If the ``variables`` option of the ``customizable_viz_transect``
            section is empty
        """
        section = self.config['customizable_viz']
        self.mesh_file = section.get('mesh_file')
        self.input_file = section.get('input_file')
        self.transect_file = section.get('transect_file')

        section_name = 'customizable_viz_transect'
        self.variables = self.config.getlist(
            section_name, 'variables', dtype=str
        )
        if not self.variables:
            raise ValueError(
                f'No variables specified in the {section_name} section of '
                'the config file.'
            )

    def run(self):
        """
        Plot a vertical transect image for each requested variable

        The transect is loaded from ``self.transect_file`` if it exists;
        otherwise it is computed from the mesh and the layer thickness of
        the first time slice, then written to that file. One image,
        ``<variable>_transect<time stamp>.png``, is written per variable.
        As a side effect, the color limits used for each plot are stored in
        the ``norm_args`` option of the ``customizable_viz_transect`` config
        section.
        """
        section_name = 'customizable_viz_transect'
        section = self.config[section_name]
        layer_interface_color = section.get('layer_interface_color')
        x_start = section.getfloat('x_start')
        x_end = section.getfloat('x_end')
        y_start = section.getfloat('y_start')
        y_end = section.getfloat('y_end')
        x = xr.DataArray(data=[x_start, x_end])
        y = xr.DataArray(data=[y_start, y_end])

        ds_mesh = self.open_model_dataset(self.mesh_file)
        ds = self.open_model_dataset(self.input_file, decode_timedelta=False)

        # TODO support time selection from config file
        if 'Time' in ds.dims:
            t_index = 0
            ds = ds.isel(Time=t_index)

        prefix, time_variable = determine_time_variable(ds)
        time_stamp = self._get_time_stamp(ds, time_variable)

        ds_transect = self._get_transect(x, y, ds_mesh, ds, prefix)

        # restrict the input fields to the cells the transect passes through
        cell_indices = ds_transect.cellIndices
        ds_data = ds.isel(nCells=cell_indices)

        viz_dict = get_viz_defaults()
        if self.config.has_option(section_name, 'colormap_range_percent'):
            colormap_range_percent = self.config.getfloat(
                section_name, 'colormap_range_percent'
            )
        else:
            colormap_range_percent = 0.0

        for var_name in self.variables:
            if 'accumulated' in var_name:
                full_var_name = var_name
            else:
                full_var_name = f'{prefix}{var_name}'
            if full_var_name not in ds.keys():
                if f'{prefix}activeTracers_{var_name}' in ds.keys():
                    full_var_name = f'{prefix}activeTracers_{var_name}'
                elif var_name == 'columnThickness':
                    # bottomDepth lives in the mesh file (it is read from
                    # there for the transect too); ssh follows the same
                    # prefix convention as the other model fields
                    column_thickness = ds_mesh.bottomDepth + ds[f'{prefix}ssh']
                    ds[full_var_name] = column_thickness
                    # ds_data was subset before this field existed, so it
                    # must be added there as well
                    ds_data[full_var_name] = column_thickness.isel(
                        nCells=cell_indices
                    )
                else:
                    self.logger.info(
                        f'Skipping {full_var_name}, '
                        f'not found in {self.input_file}'
                    )
                    continue

            self.logger.info(f'Plotting {full_var_name}')
            mpas_field = ds[full_var_name]
            data = ds_data[full_var_name]

            if var_name in viz_dict.keys():
                cmap = viz_dict[var_name]['colormap']
                units = viz_dict[var_name]['units']
            else:
                cmap = viz_dict['default']['colormap']
                units = viz_dict['default']['units']

            vmin, vmax = self._get_color_limits(
                data, ds_transect, colormap_range_percent
            )

            if self.config.has_option(
                section_name, 'vmin'
            ) and self.config.has_option(section_name, 'vmax'):
                vmin = section.getfloat('vmin')
                vmax = section.getfloat('vmax')
            elif (
                cmap == 'cmo.balance'
                or 'vertVelocityTop' in var_name
                or 'Tendency' in var_name
                or 'Flux' in var_name
            ):
                # symmetric limits so the color scale is centered on zero
                vmax = max(abs(vmin), abs(vmax))
                vmin = -vmax

            norm_args = f'"vmin": {vmin}, "vmax": {vmax}'
            self.config.set(
                section_name,
                'norm_args',
                value='{' + norm_args + '}',
            )

            plot_transect(
                ds_transect=ds_transect,
                mpas_field=mpas_field,
                title=f'{var_name}',
                out_filename=f'{var_name}_transect{time_stamp}.png',
                interface_color=layer_interface_color,
                vmin=vmin,
                vmax=vmax,
                cmap=cmap,
                colorbar_label=units,
                color_start_and_end=True,
            )

    @staticmethod
    def _get_time_stamp(ds, time_variable):
        """
        Build the ``_<date>`` suffix used in output file names

        Returns an empty string when ``time_variable`` is None. Otherwise
        returns the first value of ``ds[time_variable]``, decoded to text if
        it is bytes, truncated at its first underscore and prefixed with an
        underscore.
        """
        if time_variable is None:
            return ''
        start_time = ds[time_variable].values
        if isinstance(start_time, np.ndarray):
            start_time = (
                start_time.item() if start_time.size == 1 else start_time[0]
            )
        if isinstance(start_time, (bytes, bytearray, np.bytes_)):
            start_time = start_time.decode()
        else:
            start_time = str(start_time)
        return f'_{start_time.split("_")[0]}'

    def _get_transect(self, x, y, ds_mesh, ds, prefix):
        """
        Load the transect from ``self.transect_file``, or compute and save it

        ``x`` and ``y`` hold the start and end points of the transect.
        ``ds[f'{prefix}layerThickness']`` is only read when the transect
        has to be computed.
        """
        if os.path.exists(self.transect_file):
            ds_transect = xr.open_dataset(self.transect_file)
            self.logger.info(f'loading transect from {self.transect_file}')
            return ds_transect

        ds_transect = compute_transect(
            x=x,
            y=y,
            ds_horiz_mesh=ds_mesh,
            layer_thickness=ds[f'{prefix}layerThickness'],
            bottom_depth=ds_mesh.bottomDepth,
            # the "- 1" converts the mesh's level indices to zero-based
            min_level_cell=ds_mesh.minLevelCell - 1,
            max_level_cell=ds_mesh.maxLevelCell - 1,
            spherical=True,
        )
        self.logger.info(f'saving transect to {self.transect_file}')
        write_netcdf(ds_transect, self.transect_file)
        return ds_transect

    @staticmethod
    def _get_color_limits(data, ds_transect, colormap_range_percent):
        """
        Compute default color-scale limits for one field along the transect

        If ``colormap_range_percent`` is positive, the limits are that
        lower and upper percentile of the non-NaN values of ``data``.
        Otherwise they are the min and max of ``data`` at the transect's
        level indices, masked by ``ds_transect.validCells``.
        """
        if colormap_range_percent > 0.0:
            values = data.values
            # NaN compares unequal to itself, so this drops NaN entries
            values = values[values == values]
            vmin = np.percentile(values, colormap_range_percent)
            vmax = np.percentile(values, 100.0 - colormap_range_percent)
            return vmin, vmax

        # pick the vertical dimension this field actually has; testing the
        # whole dataset would select nVertLevelsP1 for layer-centered fields
        # whenever any other field is on interfaces, and isel would raise
        if 'nVertLevelsP1' in data.dims:
            plot_data = data.isel(nVertLevelsP1=ds_transect.levelIndices)
        elif 'nVertLevels' in data.dims:
            plot_data = data.isel(nVertLevels=ds_transect.levelIndices)
        else:
            plot_data = data
        plot_data = plot_data.where(ds_transect.validCells)
        return plot_data.min().values, plot_data.max().values