Source code for polaris.tasks.ocean.sphere_transport.init

import numpy as np
import xarray as xr

from polaris.ocean.model import OceanIOStep
from polaris.ocean.vertical import init_vertical_coord
from polaris.tasks.ocean.sphere_transport.resources.flow_types import (
    flow_divergent,
    flow_nondivergent,
    flow_rotation,
    normal_velocity_from_zonal_meridional,
)
from polaris.tasks.ocean.sphere_transport.resources.tracer_distributions import (  # noqa: E501
    correlation_fn,
    cosine_bells,
    slotted_cylinders,
    xyztrig,
)


class Init(OceanIOStep):
    """
    A step for creating the initial condition for the sphere transport test
    cases (``rotation_2d``, ``divergent_2d``, ``nondivergent_2d`` and
    ``correlated_tracers_2d``)

    Attributes
    ----------
    case_name : str
        The name of the test case, which selects the velocity field and the
        distribution used for ``tracer3``
    """

    def __init__(self, component, name, subdir, base_mesh, case_name):
        """
        Create the step

        Parameters
        ----------
        component : polaris.Component
            The component the step belongs to

        name : str
            The name of the step

        subdir : str
            The subdirectory for the step

        base_mesh : polaris.Step
            The base mesh step

        case_name : str
            The name of the test case
        """
        super().__init__(component=component, name=name, subdir=subdir)
        self.case_name = case_name
        # the base mesh and its graph file come from the base mesh step's
        # work directory
        self.add_input_file(
            filename='mesh.nc',
            work_dir_target=f'{base_mesh.path}/base_mesh.nc',
        )
        self.add_input_file(
            filename='graph.info',
            work_dir_target=f'{base_mesh.path}/graph.info',
        )
        self.add_output_file(filename='initial_state.nc')

    def run(self):
        """
        Run this step of the task

        Builds the initial state on the base mesh and writes it to
        ``initial_state.nc``.

        Raises
        ------
        ValueError
            If ``case_name`` is not one of the supported test cases
        """
        # Open the mesh in a context manager so the file handle is always
        # released.  The write must happen inside the ``with`` block because
        # ``ds`` is a shallow copy of ``ds_mesh`` whose variables may not
        # have been read from disk yet.
        with xr.open_dataset('mesh.nc') as ds_mesh:
            ds = self._build_initial_state(ds_mesh)
            self.write_model_dataset(ds, 'initial_state.nc')

    def _build_initial_state(self, ds_mesh):
        """
        Return a copy of ``ds_mesh`` with the vertical coordinate,
        temperature, salinity, passive tracers, normal velocity and the
        ``fCell``/``fEdge``/``fVertex`` fields added

        Parameters
        ----------
        ds_mesh : xarray.Dataset
            The base mesh

        Returns
        -------
        ds : xarray.Dataset
            The initial state
        """
        config = self.config

        section = config['sphere_transport']
        temperature = section.getfloat('temperature')
        salinity = section.getfloat('salinity')
        # time (hours) for bell to transit equator once
        vel_pd = section.getfloat('vel_pd')

        section = config['vertical_grid']
        bottom_depth = section.getfloat('bottom_depth')

        latCell = ds_mesh.latCell
        latEdge = ds_mesh.latEdge
        lonCell = ds_mesh.lonCell
        lonEdge = ds_mesh.lonEdge
        sphere_radius = ds_mesh.sphere_radius

        ds = ds_mesh.copy()

        ds['bottomDepth'] = bottom_depth * xr.ones_like(latCell)
        ds['ssh'] = xr.zeros_like(latCell)

        init_vertical_coord(config, ds)

        # uniform temperature and salinity on every cell and level
        temperature_array = temperature * xr.ones_like(latCell)
        temperature_array, _ = xr.broadcast(temperature_array, ds.refZMid)
        ds['temperature'] = temperature_array.expand_dims(dim='Time', axis=0)
        ds['salinity'] = salinity * xr.ones_like(ds.temperature)

        # Tile each cell-only tracer field across all vertical levels and add
        # a leading Time dimension.  meshgrid(x, y) returns arrays of shape
        # (len(y), len(x)); the second one repeats the tracer values along
        # the vertical axis.
        ref_z_mid = ds.refZMid.values
        tracers = self._compute_tracers(lonCell, latCell, sphere_radius)
        for tracer_name, values in tracers.items():
            _, tiled = np.meshgrid(ref_z_mid, values)
            ds[tracer_name] = (('nCells', 'nVertLevels'), tiled)
            ds[tracer_name] = ds[tracer_name].expand_dims(dim='Time', axis=0)

        u, v = self._compute_zonal_meridional_velocity(
            lonEdge, latEdge, sphere_radius, vel_pd
        )
        normalVelocity = normal_velocity_from_zonal_meridional(
            ds, u, v, recompute_angle_edge=False
        )
        normalVelocity, _ = xr.broadcast(normalVelocity, ds.refZMid)
        ds['normalVelocity'] = normalVelocity.expand_dims(dim='Time', axis=0)

        # presumably the Coriolis parameter at cells/edges/vertices -- it is
        # set to zero everywhere here (confirm)
        ds['fCell'] = xr.zeros_like(ds_mesh.xCell)
        ds['fEdge'] = xr.zeros_like(ds_mesh.xEdge)
        ds['fVertex'] = xr.zeros_like(ds_mesh.xVertex)

        return ds

    def _compute_tracers(self, lonCell, latCell, sphere_radius):
        """
        Return the cell-only fields for ``tracer1``, ``tracer2`` and
        ``tracer3`` as a dict keyed by tracer name (in that order)

        ``tracer3`` is a correlated function of ``tracer2`` for the
        ``correlated_tracers_2d`` case and slotted cylinders otherwise.
        """
        config = self.config
        case_name = self.case_name
        section = config['sphere_transport']

        # tracer1
        tracer1 = xyztrig(lonCell, latCell, sphere_radius)

        # tracer2
        tracer2 = cosine_bells(
            lonCell,
            latCell,
            section.getfloat('cosine_bells_radius'),
            section.getfloat('cosine_bells_background'),
            section.getfloat('cosine_bells_amplitude'),
            sphere_radius,
        )

        # tracer3
        if case_name == 'correlated_tracers_2d':
            coeff = config.getlist(
                case_name, 'correlation_coefficients', dtype=float
            )
            tracer3 = correlation_fn(tracer2, coeff[0], coeff[1], coeff[2])
        else:
            tracer3 = slotted_cylinders(
                lonCell,
                latCell,
                section.getfloat('slotted_cylinders_radius'),
                section.getfloat('slotted_cylinders_background'),
                section.getfloat('slotted_cylinders_amplitude'),
                sphere_radius,
            )

        return {'tracer1': tracer1, 'tracer2': tracer2, 'tracer3': tracer3}

    def _compute_zonal_meridional_velocity(
        self, lonEdge, latEdge, sphere_radius, vel_pd
    ):
        """
        Return the zonal and meridional velocity ``(u, v)`` at edges for the
        flow selected by ``case_name``

        Parameters
        ----------
        lonEdge, latEdge : xarray.DataArray
            Edge coordinates from the base mesh

        sphere_radius : float
            The radius of the sphere

        vel_pd : float
            The flow period in hours (converted to seconds here)

        Raises
        ------
        ValueError
            If ``case_name`` is not one of the supported test cases
        """
        config = self.config
        case_name = self.case_name

        s_per_hour = 3600.0
        period = vel_pd * s_per_hour

        if case_name == 'rotation_2d':
            rotation_vector = config.getlist(
                case_name, 'rotation_vector', dtype=float
            )
            vector = np.array(rotation_vector)
            return flow_rotation(
                lonEdge, latEdge, vector, period, sphere_radius
            )

        if case_name == 'divergent_2d':
            vel_amp = config[case_name].getfloat('vel_amp')
            # the leading 0.0 is presumably the evaluation time (confirm)
            return flow_divergent(0.0, lonEdge, latEdge, vel_amp, period)

        if case_name in ('nondivergent_2d', 'correlated_tracers_2d'):
            vel_amp = config[case_name].getfloat('vel_amp')
            return flow_nondivergent(0.0, lonEdge, latEdge, vel_amp, period)

        raise ValueError(f'Unexpected test case name {case_name}')