Skip to content

File inference_backend.hpp

File List > common > src > inference > inference_backend.hpp

Go to the documentation of this file

#ifndef INFERENCE_BACKEND_HPP
#define INFERENCE_BACKEND_HPP

#include <memory>
#include <string>
#include <vector>

// MPI type handling:
// Use the real <mpi.h> when the preprocessor can find it; otherwise provide
// minimal stub definitions so non-MPI builds still compile code that names
// MPI_Comm / MPI_COMM_WORLD / MPI_COMM_NULL.
//
// The __has_include probe is deliberately nested inside a separate
// `defined(__has_include)` test. On a preprocessor that lacks __has_include,
// the operand `__has_include(<mpi.h>)` is not valid #if syntax, and `&&` does
// not prevent it from being parsed, so the one-line form can be a hard error.
// On such compilers the stubs below are used even if MPI is installed.
#if defined(__has_include)
#if __has_include(<mpi.h>)
#include <mpi.h>
#define INFERENCE_BACKEND_HAS_MPI_H 1
#endif
#endif

#if !defined(INFERENCE_BACKEND_HAS_MPI_H)
// No mpi.h available - provide stub definitions for non-MPI builds.
// NOTE(review): both stub handles are 0, so `comm == MPI_COMM_NULL` is also
// true for MPI_COMM_WORLD in stub builds; confirm no caller relies on telling
// them apart.
using MPI_Comm = int;
#define MPI_COMM_WORLD 0
#define MPI_COMM_NULL 0
#endif

namespace emulator {
namespace inference {

/// Inference backend implementations selectable at run time.
enum class BackendType {
  STUB,     ///< Stub backend; also the fallback of parse_backend_type() for
            ///< unrecognized names.
  LIBTORCH, ///< LibTorch backend (selected by "libtorch" or "torch").
  // Future backends (uncomment when implemented):
  // PYTORCH, ///< Python interpreter backend (native PyTorch, no tracing)
  // ONNX,    ///< ONNX Runtime backend (cross-framework, optimized)
  // LAPIS,   ///< Kokkos interop backend (GPU memory sharing with E3SM)
};

/// Return the canonical upper-case name of @p type ("STUB", "LIBTORCH").
///
/// @param type Backend to name.
/// @return The enumerator's name, or "UNKNOWN" for a value that does not
///         correspond to any enumerator (e.g. produced by casting an int).
inline std::string backend_type_to_string(BackendType type) {
  // No `default:` label on purpose: with every enumerator listed, -Wswitch
  // reports this switch when a new BackendType is added but not named here.
  switch (type) {
  case BackendType::STUB:
    return "STUB";
  case BackendType::LIBTORCH:
    return "LIBTORCH";
  }
  // Only reachable for out-of-range values cast into the enum.
  return "UNKNOWN";
}

/// Parse a backend name, e.g. from a configuration value.
///
/// Matching ignores ASCII case and surrounding whitespace, so "LIBTORCH",
/// "libtorch", "LibTorch", "torch" and "  TORCH  " all select
/// BackendType::LIBTORCH. parse_backend_type(backend_type_to_string(t)) == t
/// holds for every enumerator.
///
/// @param str Backend name to parse.
/// @return The matching backend; any unrecognized (or empty) name yields
///         BackendType::STUB.
inline BackendType parse_backend_type(const std::string &str) {
  // Trim surrounding whitespace so padded values (e.g. blank-filled
  // fixed-length strings or trailing newlines) still match.
  const char *const whitespace = " \t\n\r\f\v";
  const std::string::size_type first = str.find_first_not_of(whitespace);
  if (first == std::string::npos) {
    return BackendType::STUB;
  }
  const std::string::size_type last = str.find_last_not_of(whitespace);
  std::string key = str.substr(first, last - first + 1);

  // ASCII-only lower-casing: avoids <cctype> locale dependence and the UB of
  // passing negative char values to std::tolower.
  for (char &c : key) {
    if (c >= 'A' && c <= 'Z') {
      c = static_cast<char>(c - 'A' + 'a');
    }
  }

  if (key == "libtorch" || key == "torch") {
    return BackendType::LIBTORCH;
  }
  return BackendType::STUB;
}

/// Settings handed to InferenceBackend::initialize().
struct InferenceConfig {
  /// Backend implementation to use; defaults to the stub backend.
  BackendType backend{BackendType::STUB};

  /// Location of the model to load. NOTE(review): the expected file format
  /// is backend-specific — confirm in each backend's initialize().
  std::string model_path{};

  /// Device index; defaults to -1. NOTE(review): presumably -1 means "CPU /
  /// no GPU" — confirm against the backend implementations.
  int device_id{-1};
  /// Request reduced (FP16) precision, if the backend honours it.
  bool use_fp16{false};
  /// Request extra diagnostic output from the backend.
  bool verbose{false};

  /// Input / output feature counts. NOTE(review): presumably the per-sample
  /// sizes of the buffers passed to infer(); the 44/50 defaults are
  /// model-specific — confirm they match the deployed model.
  int input_channels{44};
  int output_channels{50};

  // --- Spatial mode (for CNN models) ---
  /// Treat inputs as a 2-D grid rather than independent samples.
  bool spatial_mode{false};
  /// Grid dimensions used in spatial mode (0 = unset).
  int grid_height{0};
  int grid_width{0};

  // --- Validation and dry-run ---
  /// Dry-run flag. NOTE(review): exact effect is defined by each backend.
  bool dry_run{false};
  /// Variable names the model is expected to consume / produce.
  /// NOTE(review): presumably checked by validate() — confirm.
  std::vector<std::string> expected_input_vars{};
  std::vector<std::string> expected_output_vars{};
};

/// Outcome of a validation pass: an overall pass/fail flag plus the error
/// and warning messages that were collected.
struct ValidationResult {
  /// Starts true; add_error() sets it to false.
  bool valid{true};
  /// Messages recorded through add_error().
  std::vector<std::string> errors{};
  /// Messages recorded through add_warning(); these do not affect `valid`.
  std::vector<std::string> warnings{};

  /// Mark the result invalid and record @p msg as an error.
  void add_error(const std::string &msg) {
    // Flag first so the result is already invalid even if the append throws.
    valid = false;
    errors.emplace_back(msg);
  }

  /// Record @p msg as a warning without changing `valid`.
  void add_warning(const std::string &msg) {
    warnings.emplace_back(msg);
  }

  /// True when at least one warning has been recorded.
  bool has_warnings() const {
    const bool none = warnings.empty();
    return !none;
  }
};

/// Abstract interface implemented by every inference backend.
///
/// NOTE(review): the intended lifecycle is not enforced here; from the
/// method set it is presumably initialize() -> infer()* -> finalize().
/// Confirm against the concrete backends.
class InferenceBackend {
public:
  /// Virtual so backends can be destroyed through a base pointer
  /// (e.g. the std::unique_ptr returned by create_backend()).
  virtual ~InferenceBackend() = default;

  /// Set the backend up from @p config.
  /// @return Status flag; presumably true on success — confirm.
  virtual bool initialize(const InferenceConfig &config) = 0;

  /// Run the model on a batch.
  /// @param inputs     Read-only input buffer.
  /// @param outputs    Buffer the results are written to.
  /// @param batch_size Number of samples in the batch.
  /// @return Status flag; presumably true on success — confirm.
  /// NOTE(review): buffer sizes/layout are not stated here; presumably
  /// batch_size * input_channels and batch_size * output_channels doubles.
  virtual bool infer(const double *inputs, double *outputs, int batch_size) = 0;

  /// Release whatever the backend acquired.
  virtual void finalize() = 0;

  /// Human-readable backend name.
  virtual std::string name() const = 0;

  /// Whether the backend currently reports itself as initialized.
  virtual bool is_initialized() const = 0;

  /// Check backend/model consistency. The base version reports success
  /// with no errors or warnings; backends override it to add real checks.
  virtual ValidationResult validate() const {
    return {};
  }

  /// Which BackendType this object implements.
  virtual BackendType type() const = 0;
};

std::unique_ptr<InferenceBackend> create_backend(BackendType type);

std::unique_ptr<InferenceBackend> create_backend(const InferenceConfig &config);

} // namespace inference
} // namespace emulator

#endif // INFERENCE_BACKEND_HPP