Source code for polaris.tasks.ocean.inertial_gravity_wave.init

import numpy as np
import xarray as xr
from mpas_tools.io import write_netcdf
from mpas_tools.mesh.conversion import convert, cull
from mpas_tools.planar_hex import make_planar_hex_mesh

from polaris import Step
from polaris.mesh.planar import compute_planar_hex_nx_ny
from polaris.ocean.vertical import init_vertical_coord
from polaris.resolution import resolution_to_string
from polaris.tasks.ocean.inertial_gravity_wave.exact_solution import (
    ExactSolution,
)


[docs] class Init(Step): """ A step for creating a mesh and initial condition for the inertial gravity wave test cases Attributes ---------- resolution : float The resolution of the test case in km """
[docs] def __init__(self, component, resolution, subdir): """ Create the step Parameters ---------- component : polaris.Component The component the step belongs to resolution : float The resolution of the test case in km subdir : str The subdirectory that the task belongs to """ mesh_name = resolution_to_string(resolution) super().__init__( component=component, name=f'init_{mesh_name}', subdir=subdir ) self.resolution = resolution for filename in [ 'culled_mesh.nc', 'initial_state.nc', 'culled_graph.info', ]: self.add_output_file(filename=filename)
[docs] def run(self): """ Run this step of the test case """ logger = self.logger config = self.config section = config['inertial_gravity_wave'] resolution = self.resolution lx = section.getfloat('lx') ly = np.sqrt(3.0) / 2.0 * lx f0 = section.getfloat('coriolis_parameter') nx, ny = compute_planar_hex_nx_ny(lx, ly, resolution) dc = 1e3 * resolution ds_mesh = make_planar_hex_mesh( nx=nx, ny=ny, dc=dc, nonperiodic_x=False, nonperiodic_y=False ) write_netcdf(ds_mesh, 'base_mesh.nc') ds_mesh = cull(ds_mesh, logger=logger) ds_mesh = convert( ds_mesh, graphInfoFileName='culled_graph.info', logger=logger ) write_netcdf(ds_mesh, 'culled_mesh.nc') bottom_depth = config.getfloat('vertical_grid', 'bottom_depth') ds = ds_mesh.copy() ds['ssh'] = xr.zeros_like(ds_mesh.xCell) ds['bottomDepth'] = bottom_depth * xr.ones_like(ds_mesh.xCell) init_vertical_coord(config, ds) ds['fCell'] = f0 * xr.ones_like(ds_mesh.xCell) ds['fEdge'] = f0 * xr.ones_like(ds_mesh.xEdge) ds['fVertex'] = f0 * xr.ones_like(ds_mesh.xVertex) exact_solution = ExactSolution(ds, config) ssh = exact_solution.ssh(0.0) ssh = ssh.expand_dims(dim='Time', axis=0) ds['ssh'] = ssh layerThickness = ssh + bottom_depth layerThickness, _ = xr.broadcast(layerThickness, ds.refBottomDepth) layerThickness = layerThickness.transpose( 'Time', 'nCells', 'nVertLevels' ) ds['layerThickness'] = layerThickness normal_velocity = exact_solution.normal_velocity(0.0) normal_velocity, _ = xr.broadcast(normal_velocity, ds.refBottomDepth) normal_velocity = normal_velocity.transpose('nEdges', 'nVertLevels') normal_velocity = normal_velocity.expand_dims(dim='Time', axis=0) ds['normalVelocity'] = normal_velocity write_netcdf(ds, 'initial_state.nc')