Source code for polaris.streams

import importlib.resources as imp_res
from copy import deepcopy
from xml.sax.saxutils import escape

from jinja2 import Template
from lxml import etree


def read(package, streams_filename, tree=None, replacements=None):
    """
    Load a streams file from a package and merge it into a tree of streams

    Parameters
    ----------
    package : Package
        The package name or module object in which the streams file is
        looked up

    streams_filename : str
        The name of the streams file within ``package``

    tree : lxml.etree, optional
        Streams that were read previously and that the new content should
        be merged into

    replacements : dict, optional
        If given, the file contents are treated as a Jinja2 template and
        rendered with these keyword replacements before being parsed

    Returns
    -------
    tree : lxml.etree
        The result of ``update_tree()`` applied to ``tree`` and the XML
        parsed from the streams file
    """
    resource = imp_res.files(package).joinpath(streams_filename)
    contents = resource.read_text()

    if replacements is not None:
        # the file is a template, so render it before handing it to lxml
        contents = Template(contents).render(**replacements)

    parsed = etree.fromstring(contents)
    return update_tree(tree, parsed)
def _escape_attr(value):
    """Escape ``value`` for use inside a double-quoted XML attribute."""
    # escape() covers &, < and >; the double quote has to be requested
    # explicitly because every attribute below is wrapped in double quotes
    return escape(value, {'"': '&quot;'})


def write(streams, out_filename):
    """
    Write the streams XML data to a file

    All ``immutable_stream`` children of ``streams`` are written first,
    followed by all ``stream`` children.  For each ``stream``, its
    ``stream``, ``var_struct``, ``var_array`` and ``var`` children are
    written (in that order) with only their ``name`` and, if present,
    ``packages`` attributes.  Attribute values are XML-escaped so that
    values containing ``&``, ``<``, ``>`` or ``"`` still produce
    well-formed XML.

    Parameters
    ----------
    streams : lxml.etree
        The element whose ``immutable_stream`` and ``stream`` children
        should be written

    out_filename : str
        The path of the file to write

    Raises
    ------
    KeyError
        If a stream, or one of the children that is written out, has no
        ``name`` attribute
    """
    with open(out_filename, 'w') as stream_file:
        stream_file.write('<streams>\n')

        # Write out all immutable streams first
        for stream in streams.findall('immutable_stream'):
            stream_name = stream.attrib['name']

            stream_file.write('\n')
            stream_file.write(
                f'<immutable_stream name="{_escape_attr(stream_name)}"'
            )
            # Process all attributes on the stream; the name has already
            # been written as the first attribute
            for attr, val in stream.attrib.items():
                if attr.strip() != 'name':
                    stream_file.write(f'\n {attr}="{_escape_attr(val)}"')

            stream_file.write('/>\n')

        # Write out all mutable streams
        for stream in streams.findall('stream'):
            stream_name = stream.attrib['name']

            stream_file.write('\n')
            stream_file.write(f'<stream name="{_escape_attr(stream_name)}"')

            # Process all attributes
            for attr, val in stream.attrib.items():
                if attr.strip() != 'name':
                    stream_file.write(f'\n {attr}="{_escape_attr(val)}"')

            stream_file.write('>\n\n')

            # Write out all contents of the stream
            for tag in ['stream', 'var_struct', 'var_array', 'var']:
                for child in stream.findall(tag):
                    child_name = child.attrib['name']
                    if tag == 'stream' and child_name == stream_name:
                        # don't include the stream itself
                        continue
                    if 'packages' in child.attrib:
                        package_name = child.attrib['packages']
                        entry = (
                            f' <{tag} name="{_escape_attr(child_name)}" '
                            f'packages="{_escape_attr(package_name)}"/>\n'
                        )
                    else:
                        entry = (
                            f' <{tag} name="{_escape_attr(child_name)}"/>\n'
                        )
                    stream_file.write(entry)

            stream_file.write('</stream>\n')

        stream_file.write('\n')
        stream_file.write('</streams>\n')
def update_defaults(new_child, defaults):
    """
    Update a stream or its children (sub-stream, var, etc.) starting from the
    defaults or add it if it's new.

    Every child of ``defaults`` with the same ``name`` as ``new_child`` gets
    the attributes of ``new_child`` copied onto it.  If ``new_child`` has
    children of its own, they replace the children of the matching default
    rather than being merged with them.  If no child of ``defaults`` has that
    name, a deep copy of ``new_child`` is appended to ``defaults``.

    Parameters
    ----------
    new_child : lxml.etree
        The element to merge into ``defaults``; it is ignored if it has no
        ``name`` attribute

    defaults : lxml.etree
        The parent element whose children are the defaults; modified in
        place

    Raises
    ------
    ValueError
        If a child of ``defaults`` has the same name as ``new_child`` but a
        different tag
    """
    if 'name' not in new_child.attrib:
        return

    name = new_child.attrib['name']
    found = False
    for child in defaults:
        # existing children without a name (which can never match) are
        # skipped rather than raising a KeyError
        if child.attrib.get('name') != name:
            continue

        found = True
        if child.tag != new_child.tag:
            raise ValueError(
                f'Trying to update stream "{name}" with '
                f'inconsistent tags {child.tag} vs. '
                f'{new_child.tag}.'
            )

        # copy the attributes
        for attr, value in new_child.attrib.items():
            child.attrib[attr] = value

        if len(new_child) > 0:
            # we don't want default grandchildren; iterate over a snapshot
            # because removing from an element while iterating over it can
            # skip elements
            for grandchild in list(child):
                child.remove(grandchild)

        # copy or add the grandchildren's contents
        for new_grandchild in new_child:
            update_defaults(new_grandchild, child)

    if not found:
        # add a deep copy of the element
        defaults.append(deepcopy(new_child))
def update_tree(tree, new_tree):
    """
    Merge a new set of streams into an existing set of streams

    Parameters
    ----------
    tree : lxml.etree
        An existing set of streams to add to or modify, or ``None`` if
        there are no streams yet

    new_tree : lxml.etree
        A new set of streams to add or modify

    Returns
    -------
    tree : lxml.etree
        ``new_tree`` itself if ``tree`` is ``None``; otherwise ``tree``,
        with each child of the ``streams`` element of ``new_tree`` merged
        into the ``streams`` element of ``tree``
    """
    if tree is None:
        # nothing to merge into, so the new streams become the tree
        return new_tree

    existing_streams = next(tree.iter('streams'))
    incoming_streams = next(new_tree.iter('streams'))
    for incoming in incoming_streams:
        _update_element(incoming, existing_streams)

    return tree
def _update_element(new_child, elements):
    """
    Add the new child/grandchildren or add/update attributes if they exist

    Every child of ``elements`` with the same ``name`` as ``new_child`` gets
    the attributes of ``new_child`` copied onto it, and the children of
    ``new_child`` are merged into it recursively.  If no child of
    ``elements`` has that name, a deep copy of ``new_child`` is appended.

    Parameters
    ----------
    new_child : lxml.etree
        The element to merge into ``elements``; it is ignored if it has no
        ``name`` attribute

    elements : lxml.etree
        The parent element to merge into; modified in place

    Raises
    ------
    ValueError
        If a child of ``elements`` has the same name as ``new_child`` but a
        different tag
    """
    if 'name' not in new_child.attrib:
        return

    name = new_child.attrib['name']
    found = False
    for child in elements:
        # existing children without a name (which can never match) are
        # skipped rather than raising a KeyError
        if child.attrib.get('name') != name:
            continue

        found = True
        if child.tag != new_child.tag:
            raise ValueError(
                f'Trying to update stream "{name}" with '
                f'inconsistent tags {child.tag} vs. '
                f'{new_child.tag}.'
            )

        # copy the attributes
        for attr, value in new_child.attrib.items():
            child.attrib[attr] = value

        # copy or add the grandchildren's contents
        for new_grandchild in new_child:
            _update_element(new_grandchild, child)

    if not found:
        # add a deep copy of the element
        elements.append(deepcopy(new_child))


def set_default_io_type(tree, io_type='pnetcdf,cdf5'):
    """
    Set the ``io_type`` attribute on output streams that do not have one

    Only ``<stream>`` and ``<immutable_stream>`` children of the ``streams``
    element whose ``type`` attribute contains ``output`` are modified.
    Streams with no ``type`` attribute, or with a type that does not
    contain ``output`` (e.g. a pure input stream), are left untouched, as
    are streams that already have an ``io_type``.

    Parameters
    ----------
    tree : lxml.etree
        A tree containing a ``streams`` element; modified in place

    io_type : str, optional
        The value to use for ``io_type`` where it is missing
    """
    streams = next(tree.iter('streams'))
    all_streams = streams.findall('stream') + streams.findall(
        'immutable_stream'
    )
    for stream in all_streams:
        # a missing type is treated as "not an output stream"
        stream_type = stream.attrib.get('type', '')
        if 'io_type' not in stream.attrib and 'output' in stream_type:
            stream.attrib['io_type'] = io_type