Skip to content

File atm_field_data_provider.cpp

File List > components > emulator_comps > eatm > src > impl > atm_field_data_provider.cpp

Go to the documentation of this file

#include "atm_field_data_provider.hpp"
#include <algorithm>
#include <regex>

namespace emulator {
namespace impl {

AtmFieldDataProvider::AtmFieldDataProvider(AtmFieldManager &fields,
                                           int ncols_local)
    : m_fields(fields), m_ncols(ncols_local) {}

/// Resolve a field name to its data buffer.
///
/// Names known to the field manager are returned directly (no copy). A name
/// that identifies a stacked field (a detected basename or "<basename>_3d")
/// is served from the stacked cache, which is built on first request.
/// @return Pointer to the data, or nullptr if the name is unknown.
const std::vector<double> *
AtmFieldDataProvider::get_field(const std::string &name) const {
  if (std::vector<double> *direct = m_fields.get_field_ptr(name)) {
    return direct;
  }

  if (!is_stacked_field(name)) {
    return nullptr;
  }

  build_stacked_cache(name);
  const auto cached = m_stacked_cache.find(name);
  if (cached == m_stacked_cache.end()) {
    return nullptr;
  }
  return &cached->second;
}

/// List every field name this provider can serve.
///
/// The list is built lazily and memoized until m_field_names_cached is reset.
/// It contains the fixed import/export names that the field manager actually
/// resolves, every dynamic field, and "<basename>_3d" for each detected
/// stacked field.
/// NOTE(review): dynamic fields added after the memoized build are not picked
/// up until the cache flag is reset — confirm that is acceptable to callers.
std::vector<std::string> AtmFieldDataProvider::get_field_names() const {
  if (!m_field_names_cached) {
    // Rebuild from scratch. Without this clear(), names from an earlier build
    // (e.g. a "<basename>_3d" whose slices are no longer detected) would
    // survive cache invalidation.
    m_all_field_names.clear();

    // Fixed-name fields: only those the field manager currently resolves.
    static const char *const common_fields[] = {
        // Import fields
        "shf", "cflx", "lhf", "wsx", "wsy", "lwup", "asdir", "aldir", "asdif",
        "aldif", "ts", "sst", "snowhland", "snowhice", "tref", "qref", "u10",
        "u10withgusts", "icefrac", "ocnfrac", "lndfrac",
        // Export fields
        "zbot", "ubot", "vbot", "tbot", "ptem", "shum", "dens", "pbot", "pslv",
        "lwdn", "rainc", "rainl", "snowc", "snowl", "swndr", "swvdr", "swndf",
        "swvdf", "swnet"};

    for (const char *name : common_fields) {
      if (m_fields.get_field_ptr(name) != nullptr) {
        m_all_field_names.insert(name);
      }
    }

    // All dynamically registered fields.
    for (const auto &pair : m_fields.dynamic_fields) {
      m_all_field_names.insert(pair.first);
    }

    // Stacked fields are advertised under their "_3d" alias.
    for (const auto &pair : m_stacked_field_levels) {
      m_all_field_names.insert(pair.first + "_3d");
    }

    m_field_names_cached = true;
  }

  return std::vector<std::string>(m_all_field_names.begin(),
                                  m_all_field_names.end());
}

/// Number of vertical levels for a field.
///
/// - Stacked fields (addressed by basename or "<basename>_3d") report the
///   number of detected slice levels.
/// - Otherwise, a buffer longer than m_ncols is treated as nlevs * m_ncols
///   values and its level count is inferred by integer division.
/// - Anything else, including unknown names, is reported as 1 (2D field).
int AtmFieldDataProvider::get_field_nlevs(const std::string &name) const {
  // Stacked field addressed by its bare basename.
  auto it = m_stacked_field_levels.find(name);
  if (it != m_stacked_field_levels.end()) {
    return static_cast<int>(it->second.size());
  }

  // Stacked field addressed as "<basename>_3d".
  if (name.size() > 3 && name.substr(name.size() - 3) == "_3d") {
    std::string basename = name.substr(0, name.size() - 3);
    auto it2 = m_stacked_field_levels.find(basename);
    if (it2 != m_stacked_field_levels.end()) {
      return static_cast<int>(it2->second.size());
    }
  }

  // Flat multi-level buffer: infer nlevs from its length. Guard m_ncols so a
  // provider built with zero (or negative) local columns cannot divide by
  // zero or report a negative level count.
  const std::vector<double> *ptr = get_field(name);
  if (ptr != nullptr && m_ncols > 0) {
    const std::size_t ncols = static_cast<std::size_t>(m_ncols);
    if (ptr->size() > ncols) {
      return static_cast<int>(ptr->size() / ncols);
    }
  }

  return 1; // Default: 2D field
}

void AtmFieldDataProvider::detect_stacked_fields() {
  m_stacked_field_levels.clear();

  // Scan dynamic fields for slice patterns
  for (const auto &pair : m_fields.dynamic_fields) {
    std::string basename;
    int level_idx;

    if (parse_slice_pattern(pair.first, basename, level_idx)) {
      m_stacked_field_levels[basename].push_back(level_idx);
    }
  }

  // Sort levels for each basename
  for (auto &pair : m_stacked_field_levels) {
    std::sort(pair.second.begin(), pair.second.end());
  }

  // Invalidate field names cache
  m_field_names_cached = false;
}

/// True if `name` refers to a detected stacked field.
///
/// Accepts either the bare basename registered by detect_stacked_fields() or
/// "<basename>_3d". The suffix must follow at least one character.
bool AtmFieldDataProvider::is_stacked_field(const std::string &name) const {
  if (m_stacked_field_levels.count(name) != 0) {
    return true;
  }

  const bool has_3d_suffix =
      name.size() > 3 && name.compare(name.size() - 3, 3, "_3d") == 0;
  if (!has_3d_suffix) {
    return false;
  }

  const std::string stem = name.substr(0, name.size() - 3);
  return m_stacked_field_levels.count(stem) != 0;
}

/// Split a slice name of the form "<basename>_<N>" into its parts.
///
/// N must be a canonical non-negative integer: "0", or no leading zeros and
/// at most 9 digits. Canonical form is required because build_stacked_cache()
/// reconstructs slice names as basename + "_" + std::to_string(N). A padded
/// name such as "T_01" would be looked up as "T_1", and that level would be
/// silently zero-filled. The digit cap keeps std::stoi from throwing
/// std::out_of_range. Names that do not match are not treated as slices.
/// The basename match is greedy, so "a_1_2" yields basename "a_1", level 2.
/// @param name       Candidate field name.
/// @param basename   Set to the part before the final "_<N>" on success.
/// @param level_idx  Set to N on success.
/// @return true if `name` is a slice name; outputs are untouched otherwise.
bool AtmFieldDataProvider::parse_slice_pattern(const std::string &name,
                                               std::string &basename,
                                               int &level_idx) const {
  static const std::regex pattern(R"((.+)_(0|[1-9]\d{0,8}))");
  std::smatch match;

  if (!std::regex_match(name, match, pattern)) {
    return false;
  }

  basename = match[1].str();
  level_idx = std::stoi(match[2].str()); // <= 999999999, cannot overflow int
  return true;
}

/// Stacked buffer for `basename` (bare basename or "<basename>_3d").
///
/// Builds the cache entry on first use. Unknown names yield a reference to a
/// shared, always-empty vector, so callers can test `.empty()` rather than
/// handle a null pointer.
const std::vector<double> &
AtmFieldDataProvider::get_stacked_field(const std::string &basename) const {
  static const std::vector<double> empty;

  build_stacked_cache(basename);

  const auto found = m_stacked_cache.find(basename);
  if (found == m_stacked_cache.end()) {
    return empty;
  }
  return found->second;
}

/// Assemble the contiguous stacked buffer for `basename` into m_stacked_cache.
///
/// `basename` may carry a "_3d" suffix; the cache is keyed by the name exactly
/// as given, so "T" and "T_3d" are stored as separate entries. Names that are
/// not detected stacked fields are ignored. The layout is level-major: level k
/// occupies elements [k*m_ncols, (k+1)*m_ncols). Each level takes the first
/// m_ncols values of slice "<root>_<level>". A missing slice, or one shorter
/// than m_ncols, leaves its level zero-filled.
/// NOTE(review): an entry is built once and returned as-is on later calls;
/// nothing here refreshes it if the underlying slice data changes — confirm
/// callers clear m_stacked_cache when field data is updated.
void AtmFieldDataProvider::build_stacked_cache(
    const std::string &basename) const {
  if (m_stacked_cache.count(basename) != 0) {
    return;
  }

  const bool has_3d_suffix =
      basename.size() > 3 &&
      basename.compare(basename.size() - 3, 3, "_3d") == 0;
  const std::string root =
      has_3d_suffix ? basename.substr(0, basename.size() - 3) : basename;

  const auto levels_it = m_stacked_field_levels.find(root);
  if (levels_it == m_stacked_field_levels.end()) {
    return;
  }

  const std::vector<int> &levels = levels_it->second;
  const int nlevs = static_cast<int>(levels.size());
  std::vector<double> stacked(nlevs * m_ncols, 0.0);

  for (int k = 0; k < nlevs; ++k) {
    const std::vector<double> *slice =
        m_fields.get_field_ptr(root + "_" + std::to_string(levels[k]));
    if (slice == nullptr || slice->size() < static_cast<size_t>(m_ncols)) {
      continue; // leave this level zero-filled
    }
    std::copy_n(slice->begin(), m_ncols, stacked.begin() + k * m_ncols);
  }

  m_stacked_cache[basename] = std::move(stacked);
}

} // namespace impl
} // namespace emulator