Source code for polaris.ocean.viz.transect.vert

import numpy as np
import xarray as xr

from polaris.ocean.viz.transect.horiz import (
    find_planar_transect_cells_and_weights,
    find_spherical_transect_cells_and_weights,
    make_triangle_tree,
    mesh_to_triangles,
)


[docs] def compute_transect( x, y, ds_horiz_mesh, layer_thickness, bottom_depth, min_level_cell, max_level_cell, spherical=False, ): """ build a sequence of quads showing the transect intersecting mpas cells. This can be used to plot transects of fields with dimensions ``nCells`` and ``nVertLevels`` using :py:func:`polaris.ocean.viz.plot_transect()` Parameters ---------- x : xarray.DataArray The x or longitude coordinate of the transect y : xarray.DataArray The y or latitude coordinate of the transect ds_horiz_mesh : xarray.Dataset The horizontal MPAS mesh to use for plotting layer_thickness : xarray.DataArray The layer thickness at a particular instant in time. `layerThickness.isel(Time=tidx)` to select a particular time index `tidx` if the original data array contains `Time`. bottom_depth : xarray.DataArray the (positive down) depth of the seafloor on the MPAS mesh min_level_cell : xarray.DataArray the vertical zero-based index of the sea surface on the MPAS mesh max_level_cell : xarray.DataArray the vertical zero-based index of the bathymetry on the MPAS mesh spherical : bool, optional Whether the x and y coordinates are latitude and longitude in degrees Returns ------- ds_transect : xarray.Dataset The transect dataset, see :py:func:`polaris.ocean.viz.transect.vert.find_transect_levels_and_weights()` for details """ # noqa: E501 ds_tris = mesh_to_triangles(ds_horiz_mesh) triangle_tree = make_triangle_tree(ds_tris) if spherical: ds_horiz_transect = find_spherical_transect_cells_and_weights( x, y, ds_tris, ds_horiz_mesh, triangle_tree, degrees=True ) else: ds_horiz_transect = find_planar_transect_cells_and_weights( x, y, ds_tris, ds_horiz_mesh, triangle_tree ) # mask horizontal transect to valid cells (max_level_cell >= 0) cell_indices = ds_horiz_transect.horizCellIndices seg_mask = max_level_cell.isel(nCells=cell_indices).values >= 0 node_mask = np.zeros(ds_horiz_transect.sizes['nNodes'], dtype=bool) node_mask[0:-1] = seg_mask node_mask[1:] = np.logical_or(node_mask[1:], seg_mask) ds_horiz_transect = ds_horiz_transect.isel( nSegments=seg_mask, nNodes=node_mask ) ds_transect = find_transect_levels_and_weights( ds_horiz_transect=ds_horiz_transect, layer_thickness=layer_thickness, bottom_depth=bottom_depth, min_level_cell=min_level_cell, max_level_cell=max_level_cell, ) ds_transect.compute() return ds_transect
[docs] def find_transect_levels_and_weights( ds_horiz_transect, layer_thickness, bottom_depth, min_level_cell, max_level_cell, ): """ Construct a vertical coordinate for a transect produced by :py:func:`polaris.ocean.viz.transect.horiz.find_spherical_transect_cells_and_weights()` or :py:func:`polaris.ocean.viz.transect.horiz.find_planar_transect_cells_and_weights()`. Also, compute interpolation weights such that observations at points on the original transect and with vertical coordinate ``transectZ`` can be bilinearly interpolated to the nodes of the transect. Parameters ---------- ds_horiz_transect : xarray.Dataset A dataset that defines nodes of the transect layer_thickness : xarray.DataArray layer thicknesses on the MPAS mesh bottom_depth : xarray.DataArray the (positive down) depth of the seafloor on the MPAS mesh min_level_cell : xarray.DataArray the vertical zero-based index of the sea surface on the MPAS mesh max_level_cell : xarray.DataArray the vertical zero-based index of the bathymetry on the MPAS mesh Returns ------- ds_transect : xarray.Dataset A dataset that contains nodes and cells that make up a 2D transect. There are ``nSegments`` horizontal and ``nHalfLevels`` vertical transect cells (quadrilaterals), bounded by ``nHorizNodes`` horizontal and ``nVertNodes`` vertical nodes (corners). In addition to the variables and coordinates in the input ``ds_transect``, the output dataset contains: - ``validCells``, ``validNodes``: which transect cells and nodes are valid (above the bathymetry and below the sea surface) - zTransectNode: the vertical height of each triangle node - ssh, zSeaFloor: the sea-surface height and sea-floor height at each node of each transect segment - ``cellIndices``: the MPAS-Ocean cell of a given transect segment - ``levelIndices``: the MPAS-Ocean vertical level of a given transect level - ``interpCellIndices``, ``interpLevelIndices``: the MPAS-Ocean cells and levels from which the value at a given transect cell is interpolated. This can involve up to ``nHorizWeights * nVertWeights = 12`` different cells and levels. - interpCellWeights: the weight to multiply each field value by to perform interpolation to a transect cell. - ``dInterfaceSegment``, ``zInterfaceSegment`` - segments that can be used to plot the interfaces between MPAS-Ocean layers - ``dCellBoundary``, ``zCellBoundary`` - segments that can be used to plot the vertical boundaries between MPAS-Ocean cells Interpolation of a DataArray from MPAS cells and levels to transect cells can be performed with :py:func:`polaris.ocean.viz.transect.vert.interp_mpas_to_transect_cells()`. Similarly, interpolation to transect nodes can be performed with :py:func:`polaris.ocean.viz.transect.vert.interp_mpas_to_transect_nodes()`. """ # noqa: E501 if 'Time' in layer_thickness.dims: raise ValueError( 'Please select a single time level in layer thickness.' ) ds_transect_cells = ds_horiz_transect.rename({'nNodes': 'nHorizNodes'}) ( z_half_interface, ssh, z_seafloor, interp_cell_indices, interp_cell_weights, valid_transect_cells, level_indices, ) = _get_vertical_coordinate( ds_transect_cells, layer_thickness, bottom_depth, min_level_cell, max_level_cell, ) ds_transect_cells['zTransectNode'] = z_half_interface ds_transect_cells['ssh'] = ssh ds_transect_cells['zSeafloor'] = z_seafloor ds_transect_cells['cellIndices'] = ds_transect_cells.horizCellIndices ds_transect_cells['levelIndices'] = level_indices ds_transect_cells['validCells'] = valid_transect_cells d_interface_seg, z_interface_seg = _get_interface_segments( z_half_interface, ds_transect_cells.dNode, valid_transect_cells ) ds_transect_cells['dInterfaceSegment'] = d_interface_seg ds_transect_cells['zInterfaceSegment'] = z_interface_seg d_cell_boundary, z_cell_boundary = _get_cell_boundary_segments( ssh, z_seafloor, ds_transect_cells.dNode, ds_transect_cells.horizCellIndices, ) ds_transect_cells['dCellBoundary'] = d_cell_boundary ds_transect_cells['zCellBoundary'] = z_cell_boundary interp_level_indices, interp_cell_weights, valid_nodes = ( _get_interp_indices_and_weights( layer_thickness, interp_cell_indices, interp_cell_weights, level_indices, valid_transect_cells, ) ) ds_transect_cells['interpCellIndices'] = interp_cell_indices ds_transect_cells['interpLevelIndices'] = interp_level_indices ds_transect_cells['interpCellWeights'] = interp_cell_weights ds_transect_cells['validNodes'] = valid_nodes dims = [ 'nSegments', 'nHalfLevels', 'nHorizNodes', 'nVertNodes', 'nInterfaceSegments', 'nCellBoundaries', 'nHorizBounds', 'nVertBounds', 'nHorizWeights', 'nVertWeights', ] for dim in ds_transect_cells.dims: if dim not in dims: dims.insert(0, dim) ds_transect_cells = ds_transect_cells.transpose(*dims) return ds_transect_cells
[docs] def interp_mpas_to_transect_cells(ds_transect, da): """ Interpolate an MPAS-Ocean DataArray with dimensions ``nCells`` by ``nVertLevels`` to transect cells Parameters ---------- ds_transect : xarray.Dataset A dataset that defines an MPAS-Ocean transect, the results of calling ``find_transect_levels_and_weights()`` da : xarray.DataArray An MPAS-Ocean field with dimensions `nCells`` and ``nVertLevels`` (possibly among others) Returns ------- da_cells : xarray.DataArray The data array interpolated to transect cells with dimensions ``nSegments`` and ``nHalfLevels`` (in addition to whatever dimensions were in ``da`` besides ``nCells`` and ``nVertLevels``) """ cell_indices = ds_transect.cellIndices level_indices = ds_transect.levelIndices da_cells = da.isel(nCells=cell_indices, nVertLevels=level_indices) da_cells = da_cells.where(ds_transect.validCells) return da_cells
[docs] def interp_mpas_to_transect_nodes(ds_transect, da): """ Interpolate an MPAS-Ocean DataArray with dimensions ``nCells`` by ``nVertLevels`` to transect nodes, linearly interpolating fields between the closest neighboring cells Parameters ---------- ds_transect : xarray.Dataset A dataset that defines an MPAS-Ocean transect, the results of calling ``find_transect_levels_and_weights()`` da : xarray.DataArray An MPAS-Ocean field with dimensions `nCells`` and ``nVertLevels`` (possibly among others) Returns ------- da_nodes : xarray.DataArray The data array interpolated to transect nodes with dimensions ``nHorizNodes`` and ``nVertNodes`` (in addition to whatever dimensions were in ``da`` besides ``nCells`` and ``nVertLevels``) """ interp_cell_indices = ds_transect.interpCellIndices interp_level_indices = ds_transect.interpLevelIndices interp_cell_weights = ds_transect.interpCellWeights da = da.isel(nCells=interp_cell_indices, nVertLevels=interp_level_indices) da_nodes = (da * interp_cell_weights).sum( dim=('nHorizWeights', 'nVertWeights') ) da_nodes = da_nodes.where(ds_transect.validNodes) return da_nodes
def _get_vertical_coordinate( ds_transect, layer_thickness, bottom_depth, min_level_cell, max_level_cell ): n_horiz_nodes = ds_transect.sizes['nHorizNodes'] n_segments = ds_transect.sizes['nSegments'] n_vert_levels = layer_thickness.sizes['nVertLevels'] # we assume below that there is a segment (whether valid or invalid) # connecting each pair of adjacent nodes assert n_horiz_nodes == n_segments + 1 interp_horiz_cell_indices = ds_transect.interpHorizCellIndices interp_horiz_cell_weights = ds_transect.interpHorizCellWeights bottom_depth_interp = bottom_depth.isel(nCells=interp_horiz_cell_indices) layer_thickness_interp = layer_thickness.isel( nCells=interp_horiz_cell_indices ) cell_mask_interp = _get_cell_mask( interp_horiz_cell_indices, min_level_cell, max_level_cell, n_vert_levels, ) layer_thickness_interp = layer_thickness_interp.where( cell_mask_interp, 0.0 ) ssh_interp = -bottom_depth_interp + layer_thickness_interp.sum( dim='nVertLevels' ) interp_mask = np.logical_and( interp_horiz_cell_indices > 0, cell_mask_interp ) interp_cell_weights = interp_mask * interp_horiz_cell_weights weight_sum = interp_cell_weights.sum(dim='nHorizWeights') cell_indices = ds_transect.horizCellIndices valid_cells = _get_cell_mask( cell_indices, min_level_cell, max_level_cell, n_vert_levels ) valid_cells = valid_cells.transpose('nSegments', 'nVertLevels').values valid_nodes = np.zeros((n_horiz_nodes, n_vert_levels), dtype=bool) valid_nodes[0:-1, :] = valid_cells valid_nodes[1:, :] = np.logical_or(valid_nodes[1:, :], valid_cells) valid_nodes = xr.DataArray( dims=('nHorizNodes', 'nVertLevels'), data=valid_nodes ) valid_weights = valid_nodes.broadcast_like(interp_cell_weights) interp_cell_weights = (interp_cell_weights / weight_sum).where( valid_weights ) layer_thickness_transect = ( layer_thickness_interp * interp_cell_weights ).sum(dim='nHorizWeights') interp_mask = max_level_cell.isel(nCells=interp_horiz_cell_indices) >= 0 interp_horiz_cell_weights = interp_mask * interp_horiz_cell_weights weight_sum = interp_horiz_cell_weights.sum(dim='nHorizWeights') interp_horiz_cell_weights = (interp_horiz_cell_weights / weight_sum).where( interp_mask ) ssh_transect = (ssh_interp * interp_horiz_cell_weights).sum( dim='nHorizWeights' ) z_bot = ssh_transect - layer_thickness_transect.cumsum(dim='nVertLevels') z_mid = z_bot + 0.5 * layer_thickness_transect z_half_interfaces = [ssh_transect] for z_index in range(n_vert_levels): z_half_interfaces.extend( [z_mid.isel(nVertLevels=z_index), z_bot.isel(nVertLevels=z_index)] ) z_half_interface = xr.concat(z_half_interfaces, dim='nVertNodes') z_half_interface = z_half_interface.transpose('nHorizNodes', 'nVertNodes') z_seafloor = ssh_transect - layer_thickness_transect.sum(dim='nVertLevels') valid_transect_cells = np.zeros( (n_segments, 2 * n_vert_levels), dtype=bool ) valid_transect_cells[:, 0::2] = valid_cells valid_transect_cells[:, 1::2] = valid_cells valid_transect_cells = xr.DataArray( dims=('nSegments', 'nHalfLevels'), data=valid_transect_cells ) level_indices = np.zeros(2 * n_vert_levels, dtype=int) level_indices[0::2] = np.arange(n_vert_levels) level_indices[1::2] = np.arange(n_vert_levels) level_indices = xr.DataArray(dims=('nHalfLevels',), data=level_indices) return ( z_half_interface, ssh_transect, z_seafloor, interp_horiz_cell_indices, interp_cell_weights, valid_transect_cells, level_indices, ) def _get_cell_mask( cell_indices, min_level_cell, max_level_cell, n_vert_levels ): level_indices = xr.DataArray( data=np.arange(n_vert_levels), dims='nVertLevels' ) min_level_cell = min_level_cell.isel(nCells=cell_indices) max_level_cell = max_level_cell.isel(nCells=cell_indices) cell_mask = np.logical_and( level_indices >= min_level_cell, level_indices <= max_level_cell ) cell_mask = np.logical_and(cell_mask, cell_indices >= 0) return cell_mask def _get_interface_segments(z_half_interface, d_node, valid_transect_cells): d = d_node.broadcast_like(z_half_interface) z_interface = z_half_interface.values[:, 0::2] d = d.values[:, 0::2] n_segments = valid_transect_cells.sizes['nSegments'] n_half_levels = valid_transect_cells.sizes['nHalfLevels'] n_vert_levels = n_half_levels // 2 valid_segs = np.zeros((n_segments, n_vert_levels + 1), dtype=bool) valid_segs[:, 0:-1] = valid_transect_cells.values[:, 1::2] valid_segs[:, 1:] = np.logical_or( valid_segs[:, 1:], valid_transect_cells.values[:, 0::2] ) n_interface_segs = np.count_nonzero(valid_segs) d_seg = np.zeros((n_interface_segs, 2)) z_seg = np.zeros((n_interface_segs, 2)) d_seg[:, 0] = d[0:-1, :][valid_segs] d_seg[:, 1] = d[1:, :][valid_segs] z_seg[:, 0] = z_interface[0:-1, :][valid_segs] z_seg[:, 1] = z_interface[1:, :][valid_segs] d_seg = xr.DataArray( dims=('nInterfaceSegments', 'nHorizBounds'), data=d_seg ) z_seg = xr.DataArray( dims=('nInterfaceSegments', 'nHorizBounds'), data=z_seg ) return d_seg, z_seg def _get_cell_boundary_segments(ssh, z_seafloor, d_node, cell_indices): n_horiz_nodes = d_node.sizes['nHorizNodes'] cell_boundary = np.ones(n_horiz_nodes, dtype=bool) cell_boundary[1:-1] = cell_indices.values[0:-1] != cell_indices.values[1:] n_cell_boundaries = np.count_nonzero(cell_boundary) d_seg = np.zeros((n_cell_boundaries, 2)) z_seg = np.zeros((n_cell_boundaries, 2)) d_seg[:, 0] = d_node.values[cell_boundary] d_seg[:, 1] = d_seg[:, 0] z_seg[:, 0] = ssh[cell_boundary] z_seg[:, 1] = z_seafloor[cell_boundary] d_seg = xr.DataArray(dims=('nCellBoundaries', 'nVertBounds'), data=d_seg) z_seg = xr.DataArray(dims=('nCellBoundaries', 'nVertBounds'), data=z_seg) return d_seg, z_seg def _get_interp_indices_and_weights( layer_thickness, interp_cell_indices, interp_cell_weights, level_indices, valid_transect_cells, ): n_horiz_nodes = interp_cell_indices.sizes['nHorizNodes'] n_vert_levels = layer_thickness.sizes['nVertLevels'] n_vert_nodes = 2 * n_vert_levels + 1 n_vert_weights = 2 interp_level_indices = -1 * np.ones( (n_vert_nodes, n_vert_weights), dtype=int ) interp_level_indices[1:, 0] = level_indices.values interp_level_indices[0:-1, 1] = level_indices.values interp_level_indices = xr.DataArray( dims=('nVertNodes', 'nVertWeights'), data=interp_level_indices ) half_level_thickness = 0.5 * layer_thickness.isel( nCells=interp_cell_indices, nVertLevels=interp_level_indices ) half_level_thickness = half_level_thickness.where( interp_level_indices >= 0, other=0.0 ) # vertical weights are proportional to the half-level thickness interp_cell_weights = half_level_thickness * interp_cell_weights.isel( nVertLevels=interp_level_indices ) valid_nodes = np.zeros((n_horiz_nodes, n_vert_nodes), dtype=bool) valid_nodes[0:-1, 0:-1] = valid_transect_cells valid_nodes[1:, 0:-1] = np.logical_or( valid_nodes[1:, 0:-1], valid_transect_cells ) valid_nodes[0:-1, 1:] = np.logical_or( valid_nodes[0:-1, 1:], valid_transect_cells ) valid_nodes[1:, 1:] = np.logical_or( valid_nodes[1:, 1:], valid_transect_cells ) valid_nodes = xr.DataArray( dims=('nHorizNodes', 'nVertNodes'), data=valid_nodes ) weight_sum = interp_cell_weights.sum(dim=('nHorizWeights', 'nVertWeights')) out_mask = (weight_sum > 0.0).broadcast_like(interp_cell_weights) interp_cell_weights = (interp_cell_weights / weight_sum).where(out_mask) return interp_level_indices, interp_cell_weights, valid_nodes