/// @file libtorch_backend.cpp
/// @brief LibTorch (TorchScript) inference backend implementation.
/// Location: common/src/inference/libtorch_backend.cpp
#include "libtorch_backend.hpp"
// Standard library
#include <cstring>  // std::memcpy (used in infer())
#include <iostream>
// LibTorch headers
#include <torch/script.h>
#include <torch/torch.h>
namespace emulator {
namespace inference {
// Private implementation (pimpl): keeps LibTorch types out of the public
// header so callers of LibTorchBackend need not depend on torch headers.
struct LibTorchBackend::Impl {
// Deserialized TorchScript module; populated by initialize().
torch::jit::script::Module model;
// Execution device; initialize() may switch this to a CUDA device.
torch::Device device = torch::kCPU;
// Dtype the model runs in; initialize() may select kFloat16 on CUDA.
torch::ScalarType dtype = torch::kFloat32;
// True once torch::jit::load() succeeded; checked by infer().
bool model_loaded = false;
};
LibTorchBackend::LibTorchBackend() = default;
// Destructor releases the loaded model and device state if the backend is
// still initialized (i.e. finalize() was not called explicitly).
LibTorchBackend::~LibTorchBackend() {
if (m_initialized) {
finalize();
}
}
bool LibTorchBackend::initialize(const InferenceConfig &config) {
m_config = config;
m_impl = std::make_unique<Impl>();
// Configure execution device
if (m_config.device_id >= 0 && torch::cuda::is_available()) {
m_impl->device = torch::Device(torch::kCUDA, m_config.device_id);
} else {
m_impl->device = torch::kCPU;
}
// Configure precision
// Default to Float64 for E3SM compatibility (E3SM uses double precision)
// FP16 or FP32 only available/recommended on CUDA for acceleration
if (m_config.use_fp16 && m_impl->device.is_cuda()) {
m_impl->dtype = torch::kFloat16;
} else {
// Use Float32 by default
// NOTE: E3SM uses double precision by default
// TODO: make this configurable if needed
m_impl->dtype = torch::kFloat32;
}
// Load TorchScript model
try {
m_impl->model = torch::jit::load(m_config.model_path, m_impl->device);
m_impl->model.eval(); // Set to evaluation mode (disables dropout, etc.)
} catch (const c10::Error &e) {
std::cerr << "[LibTorchBackend] Failed to load model: " << e.what()
<< std::endl;
return false;
}
m_impl->model_loaded = true;
m_initialized = true;
return true;
}
/// Run one forward pass of the loaded model.
/// @param inputs      Host buffer of Float64 inputs, laid out by EmulatorAtm:
///                    spatial mode [batch, C_in, H, W] flattened, otherwise
///                    pointwise [batch, C_in].
/// @param outputs     Host buffer for Float64 outputs; must hold
///                    batch_size * C_out (* H * W in spatial mode) doubles.
/// @param batch_size  Number of samples in the batch.
/// @return true on success; false if the backend is uninitialized, the model
///         throws, or the model's output size does not match the buffer.
bool LibTorchBackend::infer(const double *inputs, double *outputs,
                            int batch_size) {
  if (!m_initialized || !m_impl || !m_impl->model_loaded) {
    std::cerr << "[LibTorchBackend::infer] ERROR: Backend not initialized!"
              << std::endl;
    return false;
  }
  if (batch_size <= 0) {
    // Nothing to do for an empty batch.
    return true;
  }
  const int C_in = m_config.input_channels;
  const int C_out = m_config.output_channels;
  try {
    torch::NoGradGuard no_grad; // Inference only: no autograd bookkeeping.

    // Wrap caller-owned host memory (Float64) without copying. EmulatorAtm
    // has already arranged the data in the correct shape:
    //   - Spatial mode:   [batch_size, C_in, H, W] flattened
    //   - Pointwise mode: [batch_size, C_in]
    std::vector<int64_t> input_shape;
    int64_t cells_per_sample = 1;
    if (m_config.spatial_mode) {
      input_shape = {batch_size, C_in, m_config.grid_height,
                     m_config.grid_width};
      cells_per_sample =
          static_cast<int64_t>(m_config.grid_height) * m_config.grid_width;
    } else {
      input_shape = {batch_size, C_in};
    }
    // TODO: make the host dtype configurable if needed.
    torch::Tensor input_tensor = torch::from_blob(
        const_cast<double *>(inputs), input_shape,
        torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU));

    // Move/cast to the model's device and dtype (chosen in initialize()).
    // TODO: have EmulatorComp send us the right dtype directly.
    input_tensor = input_tensor.to(m_impl->device, m_impl->dtype);

    // Execute the model forward pass.
    std::vector<torch::jit::IValue> model_inputs;
    model_inputs.push_back(input_tensor);
    torch::Tensor output_tensor =
        m_impl->model.forward(model_inputs).toTensor();

    // Bring the result back to host double precision for the caller.
    output_tensor = output_tensor.to(torch::kCPU, torch::kFloat64);
    if (!output_tensor.is_contiguous()) {
      output_tensor = output_tensor.contiguous();
    }

    // Guard the copy: the caller's buffer is sized for
    // [batch_size, C_out(, H, W)], so a model that emits a different number
    // of elements must not be allowed to overrun (or silently underfill) it.
    const int64_t expected_size =
        static_cast<int64_t>(batch_size) * C_out * cells_per_sample;
    if (output_tensor.numel() != expected_size) {
      std::cerr << "[LibTorchBackend::infer] ERROR: model produced "
                << output_tensor.numel() << " elements, expected "
                << expected_size << std::endl;
      return false;
    }
    std::memcpy(outputs, output_tensor.data_ptr<double>(),
                static_cast<size_t>(expected_size) * sizeof(double));
  } catch (const c10::Error &e) {
    std::cerr << "[LibTorchBackend::infer] FATAL c10::Error: " << e.what()
              << std::endl;
    return false;
  } catch (const std::exception &e) {
    std::cerr << "[LibTorchBackend::infer] FATAL exception: " << e.what()
              << std::endl;
    return false;
  }
  return true;
}
// Tear down the backend: mark it as requiring a fresh initialize() and
// release the loaded model plus all associated device state.
void LibTorchBackend::finalize() {
  m_initialized = false;
  m_impl.reset(); // Destroys the TorchScript module held by the pimpl.
}
// Report the model's memory footprint in bytes.
// NOTE(review): m_model_memory_bytes is never assigned anywhere in this
// translation unit -- presumably it is zero-initialized in the header.
// Confirm whether it is meant to be populated during initialize().
size_t LibTorchBackend::get_memory_usage_bytes() const {
return m_model_memory_bytes;
}
} // namespace inference
} // namespace emulator