File emulator_atm.cpp
File List > components > emulator_comps > eatm > src > emulator_atm.cpp
Go to the documentation of this file
#include "emulator_atm.hpp"
#include "emulator_config.hpp"
#include "impl/atm_io.hpp"
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <mpi.h>
namespace emulator {
// Builds the component as CompType::ATM and seeds the inference
// configuration with the STUB backend and verbose output disabled.
EmulatorAtm::EmulatorAtm() : EmulatorComp(CompType::ATM) {
  auto &defaults = m_inference_config;
  defaults.verbose = false;
  defaults.backend = inference::BackendType::STUB;
}
// Forwards `filename` to the component logger's set_file().
//
// filename: path handed to m_logger.set_file() unchanged.
void EmulatorAtm::set_log_file(const std::string &filename) {
m_logger.set_file(filename);
}
// Stores `config` as the active inference configuration. On the root rank a
// summary (component name, backend, model path) is written to the logger.
void EmulatorAtm::set_inference_config(
    const inference::InferenceConfig &config) {
  m_inference_config = config;

  // Only the root rank reports the new configuration.
  if (!is_root()) {
    return;
  }

  std::string summary = "[EmulatorAtm] ";
  summary += get_comp_name(m_type);
  summary += " inference config set:\n backend = ";
  summary += inference::backend_type_to_string(config.backend);
  summary += "\n model_path = ";
  summary += config.model_path;
  m_logger.info(summary);
}
// Sets up the coupling field tables and the index lookup derived from them.
//
// export_fields / import_fields: passed unchanged to
// m_coupling_fields.initialize().
void EmulatorAtm::init_coupling_indices(const std::string &export_fields,
const std::string &import_fields) {
m_coupling_fields.initialize(export_fields, import_fields);
// The index table is built from the field lists initialized just above, so
// this call must come second.
m_coupling_idx.initialize(m_coupling_fields);
}
// Component initialization: loads the configuration, allocates field
// storage, reads (or defaults) the initial state, creates the inference
// backend, sets up diagnostic output, and performs the first export to the
// coupler.
//
// Throws std::runtime_error when
//   - an IC file is configured but cannot be read,
//   - no IC file is configured and the backend is not the stub, or
//   - neither the requested backend nor the stub fallback can be created and
//     initialized.
void EmulatorAtm::init_impl() {
  // Load YAML configuration.
  m_config =
      parse_emulator_config_with_defaults(m_input_file, "eatm", is_root());

  // Configure inference backend from build config. Any value other than
  // "libtorch" selects the stub backend.
  if (m_config.build.inference_backend == "libtorch") {
    m_inference_config.backend = inference::BackendType::LIBTORCH;
  } else {
    m_inference_config.backend = inference::BackendType::STUB;
  }

  // Model path from runtime config.
  m_inference_config.model_path = m_config.runtime.model_path;

  if (is_root()) {
    m_logger.info("[EmulatorAtm] inference config:");
    m_logger.info(" backend=" + inference::backend_type_to_string(
                                    m_inference_config.backend));
    m_logger.info(" model_path=" + m_inference_config.model_path);
    m_logger.info(" spatial_mode=" +
                  std::string(m_config.model_io.spatial_mode ? "true"
                                                             : "false"));
    m_logger.info(" input_channels=" +
                  std::to_string(m_config.model_io.input_variables.size()) +
                  " output_channels=" +
                  std::to_string(m_config.model_io.output_variables.size()));
  }

  // Allocate all field storage.
  m_fields.allocate(m_num_local_cols);

  // Read initial conditions.
  if (!m_config.runtime.ic_file.empty()) {
    if (!read_initial_conditions(m_config.runtime.ic_file)) {
      throw std::runtime_error("[EmulatorAtm] Failed to read IC file: " +
                               m_config.runtime.ic_file);
    }
  } else if (m_inference_config.backend == inference::BackendType::STUB) {
    // Test mode: STUB backend without IC file uses defaults.
    m_fields.set_defaults(m_num_local_cols);
  } else {
    throw std::runtime_error(
        "[EmulatorAtm] runtime.ic_file is required for non-stub backends");
  }

  // Set up inference channel counts.
  // For spatial_mode: input_channels = C (the backend receives [1, C*H*W])
  // For pointwise:    input_channels = C (backend receives [H*W, C])
  m_inference_config.input_channels = m_config.model_io.input_variables.size();
  m_inference_config.output_channels =
      m_config.model_io.output_variables.size();

  // Set spatial mode and grid dimensions for proper tensor reshaping.
  m_inference_config.spatial_mode = m_config.model_io.spatial_mode;
  m_inference_config.grid_height = m_ny;
  m_inference_config.grid_width = m_nx;

  if (is_root()) {
    m_logger.info("[EmulatorAtm] Creating inference backend...");
  }

  m_inference = inference::create_backend(m_inference_config);
  // A null backend is treated the same as one that fails to initialize, so
  // the pointer is never dereferenced unchecked.
  if (!m_inference || !m_inference->initialize(m_inference_config)) {
    if (is_root()) {
      m_logger.warn(
          "[EmulatorAtm] Failed to initialize inference backend, using stub");
    }
    // Fallback to stub.
    m_inference_config.backend = inference::BackendType::STUB;
    m_inference = inference::create_backend(m_inference_config);
    // If even the stub cannot be brought up there is nothing left to fall
    // back to; fail loudly instead of reporting a successful initialization.
    if (!m_inference || !m_inference->initialize(m_inference_config)) {
      throw std::runtime_error(
          "[EmulatorAtm] Failed to initialize stub inference backend");
    }
  }

  if (is_root()) {
    m_logger.info("[EmulatorAtm] Inference backend initialized.");
  }

  // Create field data provider for output manager.
  m_field_provider =
      std::make_unique<impl::AtmFieldDataProvider>(m_fields, m_num_local_cols);

  // Initialize diagnostic output manager.
  std::string case_name = "emulator"; // TODO: Get from runtime config
  m_output_manager.initialize(m_config.diagnostics, m_comm, m_col_gids, m_ny,
                              m_nx, case_name, ".", m_logger);
  m_output_manager.setup(*m_field_provider);

  if (is_root()) {
    m_logger.info("[EmulatorAtm] Diagnostic output manager initialized "
                  "(" +
                  std::to_string(m_config.diagnostics.history_streams.size()) +
                  " history stream(s))");
  }

  // Export initial values to coupler.
  export_coupling_fields();
}
// Advances the component by one coupling step:
// import -> pack inputs -> inference -> unpack outputs -> diagnostics ->
// (optional) restart -> export.
//
// dt: forwarded to the output manager's init_timestep().
void EmulatorAtm::run_impl(int dt) {
  // Pull the latest coupler data into the field store.
  import_coupling_fields();

  // Pack fields into the network input buffer, run the model, and scatter
  // the network output buffer back into fields.
  prepare_inputs();
  run_inference(m_fields.net_inputs, m_fields.net_outputs);
  process_outputs();

  // Diagnostics: let the provider pick up any newly stacked fields produced
  // by the model (e.g., wind_0, wind_1) before the output manager runs.
  auto &provider = *m_field_provider;
  provider.detect_stacked_fields();
  m_output_manager.init_timestep(m_step_count, dt);
  m_output_manager.run(m_step_count, provider);

  // Restart files are only written on steps the output manager flags.
  const bool restart_due = m_output_manager.is_restart_step(m_step_count);
  if (restart_due) {
    m_output_manager.write_restart(provider, m_step_count);
    m_output_manager.write_history_restart(m_step_count);
  }

  // Hand the updated fields back to the coupler.
  export_coupling_fields();
}
// Copies coupler import data into the local field store via
// impl::import_atm_fields(). When no import buffer has been attached, the
// root rank logs an error and nothing is imported.
void EmulatorAtm::import_coupling_fields() {
  const bool have_buffer = (m_import_data != nullptr);
  if (have_buffer) {
    impl::import_atm_fields(m_import_data, m_num_local_cols, m_num_imports,
                            m_coupling_idx, m_fields);
    return;
  }

  if (is_root()) {
    m_logger.error(
        "[EmulatorAtm] import_coupling_fields: No import data pointer!");
  }
}
// Writes the local fields into the coupler export buffer via
// impl::export_atm_fields(). When no export buffer has been attached, the
// root rank logs an error and nothing is exported.
void EmulatorAtm::export_coupling_fields() {
  const bool have_buffer = (m_export_data != nullptr);
  if (have_buffer) {
    impl::export_atm_fields(m_export_data, m_num_local_cols, m_num_exports,
                            m_coupling_idx, m_fields);
    return;
  }

  if (is_root()) {
    m_logger.error("[EmulatorAtm] export_coupling_fields: No export data!");
  }
}
void EmulatorAtm::prepare_inputs() {
const auto &input_vars = m_config.model_io.input_variables;
if (input_vars.empty()) {
if (is_root()) {
m_logger.warn("[EmulatorAtm] No input variables configured!");
}
return;
}
const int C_in = static_cast<int>(input_vars.size());
const int H = m_ny;
const int W = m_nx;
const int HW = H * W; // Should equal m_num_local_cols
// Allocate net_inputs and net_outputs
if (m_config.model_io.spatial_mode) {
// Spatial mode: flatten to [C, H, W] = C*H*W total elements
m_fields.net_inputs.resize(static_cast<size_t>(C_in * HW));
m_fields.net_outputs.resize(m_config.model_io.output_variables.size() *
static_cast<size_t>(HW));
} else {
// Pointwise mode: [H*W, C]
m_fields.net_inputs.resize(static_cast<size_t>(HW * C_in));
m_fields.net_outputs.resize(static_cast<size_t>(HW) *
m_config.model_io.output_variables.size());
}
if (m_config.model_io.spatial_mode) {
// SPATIAL MODE: Pack as [C, H, W] (channel-major)
// For each channel c, copy all H*W spatial values contiguously
// This matches PyTorch NCHW layout (without the N dimension)
for (int c = 0; c < C_in; ++c) {
std::vector<double> *field_ptr = m_fields.get_field_ptr(input_vars[c]);
if (!field_ptr) {
if (is_root()) {
m_logger.error("[EmulatorAtm] Missing input field: " + input_vars[c]);
}
continue;
}
// Copy entire field as one contiguous block for this channel
// net_inputs[c * HW ... (c+1)*HW - 1] = field[0 ... HW-1]
std::memcpy(&m_fields.net_inputs[c * HW], field_ptr->data(),
HW * sizeof(double));
}
} else {
// POINTWISE MODE: Pack as [batch_size, channels] = [HW, C]
// For each grid point, interleave all channels
for (int c = 0; c < C_in; ++c) {
std::vector<double> *field_ptr = m_fields.get_field_ptr(input_vars[c]);
if (!field_ptr) {
if (is_root()) {
m_logger.error("[EmulatorAtm] Missing input field: " + input_vars[c]);
}
continue;
}
for (int col = 0; col < HW; ++col) {
m_fields.net_inputs[col * C_in + c] = (*field_ptr)[col];
}
}
}
}
// Scatters m_fields.net_outputs back into one field per configured output
// variable.
//
// Layout of net_outputs (HW = m_ny * m_nx):
//   - spatial mode:   [C, H, W], channel c occupies [c*HW, (c+1)*HW)
//   - pointwise mode: [HW, C], element (col, c) lives at col * C + c
//
// Each output variable is registered as a dynamic field before it is looked
// up; a field that still cannot be obtained is skipped. Fields are resized to
// HW when their size differs. If net_outputs holds fewer than C*HW values the
// function logs an error (root only) and writes nothing. A non-positive grid
// dimension is treated as an empty grid.
void EmulatorAtm::process_outputs() {
  const auto &output_vars = m_config.model_io.output_variables;
  if (output_vars.empty()) {
    return;
  }

  // Sizes are computed in std::size_t so that channels * points cannot
  // overflow a 32-bit int before being widened.
  const std::size_t num_out = output_vars.size();
  const int H = m_ny;
  const int W = m_nx;
  const std::size_t num_points =
      (H > 0 && W > 0)
          ? static_cast<std::size_t>(H) * static_cast<std::size_t>(W)
          : 0;

  // Verify output size.
  const std::size_t expected_size = num_out * num_points;
  if (m_fields.net_outputs.size() < expected_size) {
    if (is_root()) {
      m_logger.error("[EmulatorAtm] net_outputs size mismatch. Expected " +
                     std::to_string(expected_size) + ", got " +
                     std::to_string(m_fields.net_outputs.size()));
    }
    return;
  }

  const bool spatial = m_config.model_io.spatial_mode;

  for (std::size_t c = 0; c < num_out; ++c) {
    m_fields.register_dynamic_field(output_vars[c]);
    std::vector<double> *field_ptr = m_fields.get_field_ptr(output_vars[c]);
    if (!field_ptr) {
      continue;
    }

    // Ensure field is correctly sized.
    if (field_ptr->size() != num_points) {
      field_ptr->resize(num_points);
    }
    if (num_points == 0) {
      // Nothing to copy; also avoids forming an element address in an empty
      // buffer.
      continue;
    }

    if (spatial) {
      // Copy the contiguous channel block back to the field.
      std::memcpy(field_ptr->data(),
                  m_fields.net_outputs.data() + c * num_points,
                  num_points * sizeof(double));
    } else {
      // De-interleave this channel from the per-point records.
      for (std::size_t col = 0; col < num_points; ++col) {
        (*field_ptr)[col] = m_fields.net_outputs[col * num_out + c];
      }
    }
  }
}
void EmulatorAtm::run_inference(const std::vector<double> &inputs,
std::vector<double> &outputs) {
if (!m_inference || !m_inference->is_initialized()) {
m_logger.error("[EmulatorAtm] FATAL: run_inference() called but no "
"backend is initialized!");
std::cerr << "\n*** EMULATOR ABORT: Inference backend not initialized ***\n"
<< std::endl;
MPI_Abort(m_comm, 1);
}
const int C_in = static_cast<int>(m_config.model_io.input_variables.size());
const int C_out = static_cast<int>(m_config.model_io.output_variables.size());
const int HW = m_ny * m_nx;
// Ensure output buffer is sized correctly
size_t required_size = static_cast<size_t>(C_out * HW);
if (outputs.size() != required_size) {
outputs.resize(required_size);
}
bool success = false;
if (m_config.model_io.spatial_mode) {
// SPATIAL MODE for CNN models:
// We pass batch_size=1 and input_channels = C*H*W
// The data is in [C, H, W] flattened format
// The backend (LibTorchBackend) will reshape to [1, C, H, W] before
// inference
success = m_inference->infer(inputs.data(), outputs.data(), 1);
} else {
// POINTWISE MODE for MLP models:
// Each grid cell is a separate sample
// Backend receives [H*W, C]
success = m_inference->infer(inputs.data(), outputs.data(), HW);
}
if (!success) {
m_logger.error("[EmulatorAtm] FATAL: Inference failed!");
m_logger.error("[EmulatorAtm] Input shape: [" +
std::to_string(m_config.model_io.spatial_mode ? 1 : HW) +
", " + std::to_string(C_in) + "]");
m_logger.error("[EmulatorAtm] Expected output: [" +
std::to_string(m_config.model_io.spatial_mode ? 1 : HW) +
", " + std::to_string(C_out) + "]");
m_logger.error("[EmulatorAtm] Check model path and input/output "
"configuration in atm_in");
std::cerr << "\n*** EMULATOR ABORT: Inference failed! ***\n"
<< "Check log for details.\n"
<< std::endl;
MPI_Abort(m_comm, 1);
}
}
// Shuts the component down: writes final restart output when a field
// provider exists, finalizes the output manager, finalizes and releases the
// inference backend, and deallocates field storage.
void EmulatorAtm::final_impl() {
  if (is_root()) {
    m_logger.info("[EmulatorAtm] Finalizing...");
  }

  // A field provider only exists once initialization has created it; without
  // one there is nothing to write.
  if (m_field_provider) {
    auto &provider = *m_field_provider;
    m_output_manager.write_restart(provider, m_step_count);
    m_output_manager.write_history_restart(m_step_count);
  }

  m_output_manager.finalize();

  // Finalize the backend before dropping ownership of it.
  if (m_inference) {
    m_inference->finalize();
    m_inference.reset();
  }

  m_fields.deallocate();
}
bool EmulatorAtm::read_initial_conditions(const std::string &filename) {
return impl::read_atm_initial_conditions(
filename, m_num_global_cols, m_num_local_cols, m_col_gids, m_lat,
m_fields, m_config.model_io.input_variables, m_logger, is_root());
}
} // namespace emulator