Source code for polaris.ocean.tasks.sphere_transport.mixing_analysis

from math import ceil

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from matplotlib.lines import Line2D

from polaris import Step
from polaris.mpas import time_index_from_xtime
from polaris.ocean.convergence import get_resolution_for_task
from polaris.ocean.resolution import resolution_to_subdir
from polaris.viz import use_mplstyle


class MixingAnalysis(Step):
    """
    A step for analyzing the output from sphere transport test cases by
    plotting tracer3 against tracer2 at a configured evaluation time, one
    panel per mesh

    Attributes
    ----------
    refinement_factors : list of float
        The refinement factors of the meshes that have been run

    refinement : str
        The kind of refinement, passed to ``get_resolution_for_task`` to
        convert each refinement factor into a mesh resolution

    icosahedral : bool
        Whether to use icosahedral, as opposed to less regular, JIGSAW
        meshes

    case_name : str
        The name of the test case, also used as the config section from
        which ``mixing_evaluation_time`` is read
    """

    def __init__(
        self,
        component,
        refinement_factors,
        icosahedral,
        subdir,
        case_name,
        dependencies,
        refinement='both',
    ):
        """
        Create the step

        Parameters
        ----------
        component : polaris.Component
            The component the step belongs to

        refinement_factors : list of float
            The refinement factors of the meshes that have been run

        icosahedral : bool
            Whether to use icosahedral, as opposed to less regular, JIGSAW
            meshes

        subdir : str
            The subdirectory that the step resides in

        case_name : str
            The name of the test case

        dependencies : dict of dict of polaris.Step
            The dependencies of this step, with keys ``'mesh'``, ``'init'``
            and ``'forward'``, each mapping a refinement factor to the step
            whose ``base_mesh.nc``, ``initial_state.nc`` or ``output.nc``
            (respectively) is linked into this step

        refinement : str, optional
            The kind of refinement, passed through to
            ``get_resolution_for_task``
        """
        super().__init__(
            component=component, name='mixing_analysis', subdir=subdir
        )
        self.refinement_factors = refinement_factors
        self.refinement = refinement
        self.icosahedral = icosahedral
        self.case_name = case_name

        for refinement_factor in refinement_factors:
            base_mesh = dependencies['mesh'][refinement_factor]
            init = dependencies['init'][refinement_factor]
            forward = dependencies['forward'][refinement_factor]
            self.add_input_file(
                filename=f'mesh_r{refinement_factor:02g}.nc',
                work_dir_target=f'{base_mesh.path}/base_mesh.nc',
            )
            self.add_input_file(
                filename=f'init_r{refinement_factor:02g}.nc',
                work_dir_target=f'{init.path}/initial_state.nc',
            )
            self.add_input_file(
                filename=f'output_r{refinement_factor:02g}.nc',
                work_dir_target=f'{forward.path}/output.nc',
            )
        self.add_output_file('triplots.png')

    def run(self):
        """
        Run this step of the test case

        For each mesh, plot tracer3 against tracer2 on one vertical level at
        ``mixing_evaluation_time`` (in days) and save all panels, two per
        row, to ``triplots.png``
        """
        plt.switch_backend('Agg')
        config = self.config
        section = config[self.case_name]
        eval_time = section.getfloat('mixing_evaluation_time')
        s_per_day = 86400.0
        # index along nVertLevels of the level that is plotted
        zidx = 1

        resolutions = [
            get_resolution_for_task(config, refinement_factor, self.refinement)
            for refinement_factor in self.refinement_factors
        ]

        npanels = len(self.refinement_factors)
        nrows = int(ceil(npanels / 2))
        use_mplstyle()
        # squeeze=False keeps ``axes`` 2-D even when there is only one row
        # (1 or 2 meshes), so ``axes[row, col]`` indexing always works
        fig, axes = plt.subplots(
            nrows=nrows,
            ncols=2,
            sharex=True,
            sharey=True,
            figsize=(5.5, 7),
            squeeze=False,
        )

        for i, refinement_factor in enumerate(self.refinement_factors):
            row, col = divmod(i, 2)
            ax = axes[row, col]
            _init_triplot_axes(ax)
            mesh_name = resolution_to_subdir(resolutions[i])
            ax.set(title=mesh_name)
            if col == 0:
                ax.set_ylabel('tracer3')
            if row == nrows - 1:
                ax.set_xlabel('tracer2')

            filename = f'output_r{refinement_factor:02g}.nc'
            # the context manager releases the file handle once the tracer
            # values have been loaded into memory
            with xr.open_dataset(filename) as ds:
                tidx = time_index_from_xtime(
                    ds.xtime.values, eval_time * s_per_day
                )
                ds_eval = ds.isel(Time=tidx, nVertLevels=zidx)
                tracer2 = ds_eval['tracer2'].values
                tracer3 = ds_eval['tracer3'].values

            ax.plot(tracer2, tracer3, '.', markersize=1)
            ax.set_aspect('equal')

        # with an odd number of meshes, the right-hand panel of the last
        # row is unused, so hide its axes
        if npanels % 2 == 1:
            axes[nrows - 1, 1].set_axis_off()

        plt.subplots_adjust(wspace=0.1, hspace=0.1)
        fig.suptitle('Correlated tracers 2-d')
        fig.savefig('triplots.png', bbox_inches='tight')
        plt.close(fig)
def _init_triplot_axes(ax):
    """
    Decorate an axes with the reference diagram for a tracer-correlation
    scatter plot

    Draws a rectangle spanning x in [0.1, 1] and y in [0.1, 0.9], a straight
    line from (0.1, 0.9) to (1, 0.1), and the curve
    ``y = -0.8 * x**2 + 0.9`` for x in [0.1, 1]. Also adds region labels,
    fixed ticks at multiples of 0.2, axis limits and a light-gray grid.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes to decorate
    """
    line_width = 0.4
    label_size = 8

    # (x data, y data) for each straight reference line, in the order the
    # lines are attached to the axes
    segments = [
        ([0.1, 1.0], [0.9, 0.9]),  # top edge of the rectangle
        ([0.1, 1.0], [0.9, 0.1]),  # diagonal
        ([0.1, 1.0], [0.1, 0.1]),  # bottom edge
        ([1, 1], [0.1, 0.9]),  # right edge
        ([0.1, 0.1], [0.1, 0.9]),  # left edge
    ]
    reference_lines = [
        Line2D(xs, ys, color='k', linestyle='-', linewidth=line_width)
        for xs, ys in segments
    ]

    curve_x = np.linspace(0.1, 1)
    curve_y = -0.8 * np.square(curve_x) + 0.9
    tick_values = np.arange(6) / 5

    ax.plot(curve_x, curve_y, 'k-', linewidth=1.25 * line_width)
    ax.set_xticks(tick_values)
    ax.set_yticks(tick_values)
    for line in reference_lines:
        ax.add_artist(line)
    ax.set_xlim([0, 1.1])
    ax.set_ylim([0, 1.0])

    ax.text(
        0.98,
        0.87,
        'Range-preserving\n unmixing',
        fontsize=label_size,
        horizontalalignment='right',
        verticalalignment='top',
    )
    ax.text(
        0.12,
        0.12,
        'Range-preserving\n unmixing',
        fontsize=label_size,
        horizontalalignment='left',
        verticalalignment='bottom',
    )
    # NOTE(review): the diagonal passes through y ~= 0.54 at x = 0.5, so this
    # label is placed below it; confirm this is the intended region
    ax.text(0.5, 0.27, 'Real mixing', rotation=-40.0, fontsize=label_size)
    ax.text(0.02, 0.1, 'Overshooting', rotation=90.0, fontsize=label_size)
    ax.grid(color='lightgray')