Skip to content

File libtorch_backend.hpp

File List > common > src > inference > libtorch_backend.hpp

Go to the documentation of this file

#ifndef LIBTORCH_BACKEND_HPP
#define LIBTORCH_BACKEND_HPP

#include "inference_backend.hpp"
#include <cstddef>
#include <memory>
#include <string>

namespace emulator {
namespace inference {

/// @brief InferenceBackend implementation reported as "LibTorch"
///        (BackendType::LIBTORCH).
///
/// All LibTorch-specific state is hidden behind the forward-declared Impl
/// struct (pimpl), so this header includes no LibTorch headers.
/// NOTE(review): presumably loads and runs a LibTorch/TorchScript model
/// described by the InferenceConfig — confirm against the implementation.
///
/// Instances are not copyable: the std::unique_ptr<Impl> member cannot be
/// copied, and the copy operations are deleted explicitly below to make that
/// intent clear. Because the destructor and copy operations are
/// user-declared, no implicit move operations are generated either.
class LibTorchBackend : public InferenceBackend {
public:
  /// @brief Constructs a backend whose m_initialized flag starts as false.
  LibTorchBackend();

  /// @brief Destroys the backend.
  ///
  /// Declared here but defined out of line: destroying std::unique_ptr<Impl>
  /// requires Impl to be a complete type, which it is not in this header.
  ~LibTorchBackend() override;

  LibTorchBackend(const LibTorchBackend &) = delete;
  LibTorchBackend &operator=(const LibTorchBackend &) = delete;

  /// @brief Prepares the backend for inference.
  /// @param config Backend configuration (presumably stored in m_config —
  ///               confirm in the implementation).
  /// @return Presumably true on success, false on failure — TODO confirm.
  bool initialize(const InferenceConfig &config) override;

  /// @brief Runs inference on a batch of samples.
  /// @param inputs     Read-only input buffer.
  /// @param outputs    Buffer the results are written into.
  /// @param batch_size Number of samples in the batch.
  /// @return Presumably true on success, false on failure — TODO confirm.
  /// NOTE(review): assumes inputs/outputs hold batch_size times the
  /// per-sample sizes given by the config — confirm against the caller.
  bool infer(const double *inputs, double *outputs, int batch_size) override;

  /// @brief Tears the backend down.
  /// NOTE(review): presumably releases the model and clears m_initialized —
  /// confirm in the implementation.
  void finalize() override;

  /// @brief Human-readable backend name; always "LibTorch".
  std::string name() const override { return "LibTorch"; }

  /// @brief Returns the current value of m_initialized.
  bool is_initialized() const override { return m_initialized; }

  /// @brief Backend discriminator; always BackendType::LIBTORCH.
  BackendType type() const override { return BackendType::LIBTORCH; }

  /// @brief Memory attributed to the loaded model, in bytes.
  /// NOTE(review): presumably returns m_model_memory_bytes — confirm.
  std::size_t get_memory_usage_bytes() const;

private:
  bool m_initialized = false;           ///< Value reported by is_initialized().
  InferenceConfig m_config;             ///< Presumably the config given to initialize().
  std::size_t m_model_memory_bytes = 0; ///< Presumably backs get_memory_usage_bytes().

  /// LibTorch-specific state; defined out of line (pimpl).
  struct Impl;
  std::unique_ptr<Impl> m_impl;
};

} // namespace inference
} // namespace emulator

#endif // LIBTORCH_BACKEND_HPP