Source code for polaris.ocean.tasks.sphere_transport.filament_analysis

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

from polaris import Step
from polaris.mpas import time_index_from_xtime
from polaris.ocean.convergence import get_resolution_for_task
from polaris.ocean.resolution import resolution_to_subdir
from polaris.viz import use_mplstyle


class FilamentAnalysis(Step):
    """
    A step for analyzing the output from sphere transport test cases

    The step computes a filament preservation diagnostic. For each
    threshold ``tau``, it divides the area of cells whose tracer is
    >= ``tau`` at the evaluation time by that area at the initial time.

    Attributes
    ----------
    refinement_factors : list of float
        The refinement factors of the meshes that have been run

    refinement : str
        The type of refinement, used together with each refinement factor
        to determine the resolution of each mesh

    icosahedral : bool
        Whether to use icosahedral, as opposed to less regular, JIGSAW
        meshes

    case_name : str
        The name of the test case
    """

    def __init__(self, component, refinement_factors, icosahedral, subdir,
                 case_name, dependencies, refinement='both'):
        """
        Create the step

        Parameters
        ----------
        component : polaris.Component
            The component the step belongs to

        refinement_factors : list of float
            The refinement factors of the meshes that have been run

        icosahedral : bool
            Whether to use icosahedral, as opposed to less regular, JIGSAW
            meshes

        subdir : str
            The subdirectory that the step resides in

        case_name: str
            The name of the test case

        dependencies : dict of dict of polaris.Steps
            The dependencies of this step, keyed first by ``'mesh'``,
            ``'init'`` and ``'forward'`` and then by refinement factor

        refinement : str, optional
            The type of refinement used to determine the resolution of
            each mesh
        """
        super().__init__(component=component, name='filament_analysis',
                         subdir=subdir)
        self.refinement_factors = refinement_factors
        self.refinement = refinement
        self.icosahedral = icosahedral
        self.case_name = case_name
        for refinement_factor in refinement_factors:
            base_mesh = dependencies['mesh'][refinement_factor]
            init = dependencies['init'][refinement_factor]
            forward = dependencies['forward'][refinement_factor]
            # mesh and initial state are linked in as well, although run()
            # only reads the forward output
            self.add_input_file(
                filename=f'mesh_r{refinement_factor:02g}.nc',
                work_dir_target=f'{base_mesh.path}/base_mesh.nc')
            self.add_input_file(
                filename=f'init_r{refinement_factor:02g}.nc',
                work_dir_target=f'{init.path}/initial_state.nc')
            self.add_input_file(
                filename=f'output_r{refinement_factor:02g}.nc',
                work_dir_target=f'{forward.path}/output.nc')
        self.add_output_file('filament.png')
        # run() also writes the diagnostic values to a CSV file
        self.add_output_file('filament.csv')

    def run(self):
        """
        Run this step of the test case

        Computes the filament preservation diagnostic for ``tracer2`` on
        each mesh. It plots the results to ``filament.png`` and writes the
        values to ``filament.csv``, with one row per resolution and one
        column per ``tau``.
        """
        plt.switch_backend('Agg')
        config = self.config
        resolutions = [
            get_resolution_for_task(config, refinement_factor,
                                    self.refinement)
            for refinement_factor in self.refinement_factors]

        section = config[self.case_name]
        # the evaluation time from the config is in days (converted to
        # seconds below)
        eval_time = section.getfloat('filament_evaluation_time')
        s_per_day = 86400.0
        # index of the vertical level at which the tracer is evaluated
        zidx = 1
        variable_name = 'tracer2'
        num_tau = 21
        filament_tau = np.linspace(0, 1, num_tau)
        filament_norm = np.zeros((len(resolutions), num_tau))

        use_mplstyle()
        fig, ax = plt.subplots()
        for i, (refinement_factor, resolution) in enumerate(
                zip(self.refinement_factors, resolutions)):
            mesh_name = resolution_to_subdir(resolution)
            filename = f'output_r{refinement_factor:02g}.nc'
            # the context manager releases the file handle once the
            # needed slices have been read
            with xr.open_dataset(filename) as ds:
                tidx = time_index_from_xtime(ds.xtime.values,
                                             eval_time * s_per_day)
                tracer = ds[variable_name]
                # load the two slices once, instead of re-indexing the
                # lazily loaded DataArray for every tau
                tracer_init = tracer[0, :, zidx].values
                tracer_final = tracer[tidx, :, zidx].values
                area_cell = ds['areaCell'].values
            filament_norm[i, :] = self._filament_norm(
                tracer_init, tracer_final, area_cell, filament_tau)
            ax.plot(filament_tau, filament_norm[i, :], '.-', label=mesh_name)

        # reference line at l_f = 1 (area above tau unchanged from the
        # initial time)
        ax.plot([filament_tau[0], filament_tau[-1]], [1., 1.], 'k--')
        ax.set_xlim([filament_tau[0], filament_tau[-1]])
        ax.set_xlabel(r'$\tau$')
        ax.set_ylabel(r'$l_f$')
        ax.set_title(f'Filament preservation diagnostic for {variable_name}')
        ax.legend()
        fig.savefig('filament.png', bbox_inches='tight')
        # release the figure so it is not kept alive by pyplot
        plt.close(fig)

        data = np.column_stack((np.array(resolutions, dtype=float),
                                filament_norm))
        col_headers = ['resolution'] + [f'{tau:g}' for tau in filament_tau]
        df = pd.DataFrame(data, columns=col_headers)
        df.to_csv('filament.csv', index=False)

    @staticmethod
    def _filament_norm(tracer_init, tracer_final, area_cell, filament_tau):
        """
        Compute the filament diagnostic for each threshold ``tau``

        Parameters
        ----------
        tracer_init : numpy.ndarray
            Tracer values on cells at the initial time

        tracer_final : numpy.ndarray
            Tracer values on cells at the evaluation time

        area_cell : numpy.ndarray
            The area of each cell

        filament_tau : numpy.ndarray
            The thresholds at which to evaluate the diagnostic

        Returns
        -------
        norm : numpy.ndarray
            For each ``tau``, the total area of cells with
            ``tracer_final >= tau`` divided by the total area of cells with
            ``tracer_init >= tau``. The value is NaN where the latter is
            zero.
        """
        norm = np.full(len(filament_tau), np.nan)
        for j, tau in enumerate(filament_tau):
            area_init = np.sum(area_cell[tracer_init >= tau])
            if area_init == 0.:
                # no initial area above this threshold: leave as NaN
                continue
            norm[j] = np.sum(area_cell[tracer_final >= tau]) / area_init
        return norm