MUQ  0.4.3
ParallelMIComponentFactory.h
Go to the documentation of this file.
1 #ifndef PARALLELMICOMPONENTFACTORY_H_
2 #define PARALLELMICOMPONENTFACTORY_H_
3 
4 #include "MUQ/config.h"
5 
6 #if MUQ_HAS_MPI
7 
8 #if !MUQ_HAS_PARCER
9 #error
10 #endif
11 
12 #include "spdlog/spdlog.h"
14 
15 namespace muq {
16  namespace SamplingAlgorithms {
17 
31 
32  public:
33 
34  ParallelMIComponentFactory (std::shared_ptr<parcer::Communicator> comm, std::shared_ptr<parcer::Communicator> global_comm, std::shared_ptr<MIComponentFactory> componentFactory)
36  {
37 
38  if (comm->GetRank() != 0) {
39  while (true) {
40  spdlog::trace("Parallel factory rank {} waiting...", comm->GetRank());
41  ControlFlag command = comm->Recv<ControlFlag>(0, WorkgroupTag);
42  spdlog::trace("Parallel factory rank {} received command", comm->GetRank(), command);
43  if (command == ControlFlag::FINALIZE) {
44  samplingProblems.clear(); // Tear down models synchronously
45  comm->Barrier();
46  spdlog::trace("Parallel factory rank {} passed finalize barrier", comm->GetRank());
47  break;
48  }
49  if (command == ControlFlag::INIT_PROBLEM) {
50  auto index = std::make_shared<MultiIndex>(comm->Recv<MultiIndex>(0, WorkgroupTag));
51  int id = comm->Recv<int>(0, WorkgroupTag);
52  spdlog::trace("Parallel factory rank {} building model index {}", comm->GetRank(), *index);
53  samplingProblems[id] = componentFactory->SamplingProblem(index);//std::make_shared<MySamplingProblem>(index, comm, id, measurements);
54  }
55  else if (command == ControlFlag::LOGDENSITY) {
56  int id = comm->Recv<int>(0, WorkgroupTag);
57  auto state = std::make_shared<SamplingState>(comm->Recv<Eigen::VectorXd>(0, WorkgroupTag));
58  samplingProblems[id]->LogDensity(state);
59  }
60  else if (command == ControlFlag::TEST) {
61  int id = comm->Recv<int>(0, WorkgroupTag);
62  auto state = std::make_shared<SamplingState>(comm->Recv<Eigen::VectorXd>(0, WorkgroupTag));
63 
64  double density;
65  std::cerr << "Not implemented!!!" << std::endl;
66  }
67  else if (command == ControlFlag::QOI) {
68  int id = comm->Recv<int>(0, WorkgroupTag);
69 
70  samplingProblems[id]->QOI();
71  } else {
72  std::cerr << "Unexpected command!" << std::endl;
73  exit(43);
74  }
75  }
76  }
77  }
78  virtual bool IsInverseProblem() override {
79  return componentFactory->IsInverseProblem();
80  }
81 
86  void finalize() {
87  if (comm->GetRank() != 0)
88  return;
89  spdlog::trace("Parallel factory sending finalize");
90  for (int dest = 1; dest < comm->GetSize(); dest++)
92  samplingProblems.clear(); // Tear down models synchronously
93  comm->Barrier();
94  spdlog::trace("Parallel factory finalized");
95  }
96 
97  virtual std::shared_ptr<MCMCProposal> Proposal (std::shared_ptr<MultiIndex> const& index, std::shared_ptr<AbstractSamplingProblem> const& samplingProblem) override {
98  return componentFactory->Proposal(index, samplingProblem);
99  }
100 
101  virtual std::shared_ptr<MultiIndex> FinestIndex() override {
102  return componentFactory->FinestIndex();
103  }
104 
105  virtual std::shared_ptr<MCMCProposal> CoarseProposal (std::shared_ptr<MultiIndex> const& fineIndex,
106  std::shared_ptr<MultiIndex> const& coarseIndex,
107  std::shared_ptr<AbstractSamplingProblem> const& coarseProblem,
108  std::shared_ptr<SingleChainMCMC> const& coarseChain) override {
109  return componentFactory->CoarseProposal(fineIndex, coarseIndex, coarseProblem, coarseChain);
110  }
111 
112  virtual std::shared_ptr<AbstractSamplingProblem> SamplingProblem (std::shared_ptr<MultiIndex> const& index) override {
113  idcnt++;
114  if (comm->GetRank() == 0) {
115  spdlog::debug("Rank {} requesting model {} from parallel factory", comm->GetRank(), *index);
116  for (int dest = 1; dest < comm->GetSize(); dest++) {
118  comm->Send(*index, dest, WorkgroupTag);
119  comm->Send(idcnt, dest, WorkgroupTag);
120  }
121  }
122  samplingProblems[idcnt] = std::make_shared<ParallelAbstractSamplingProblem>(comm, idcnt, componentFactory->SamplingProblem(index));
123  return samplingProblems[idcnt];
124  }
125 
126  virtual std::shared_ptr<MIInterpolation> Interpolation (std::shared_ptr<MultiIndex> const& index) override {
127  return componentFactory->Interpolation(index);
128  }
129 
130  virtual Eigen::VectorXd StartingPoint (std::shared_ptr<MultiIndex> const& index) override {
131  return componentFactory->StartingPoint(index);
132  }
133 
134  private:
135  int idcnt = 0;
136  std::shared_ptr<parcer::Communicator> comm;
137  std::shared_ptr<parcer::Communicator> global_comm;
138  std::shared_ptr<MIComponentFactory> componentFactory;
139 
140  std::map<int, std::shared_ptr<AbstractSamplingProblem>> samplingProblems;
141 
142  };
143 
144  }
145 }
146 
147 #endif
148 
149 #endif
Interface defining models on a multiindex structure.
Wrapper for MIComponentFactory supporting parallel model setup.
virtual Eigen::VectorXd StartingPoint(std::shared_ptr< MultiIndex > const &index) override
virtual std::shared_ptr< MultiIndex > FinestIndex() override
virtual std::shared_ptr< MIInterpolation > Interpolation(std::shared_ptr< MultiIndex > const &index) override
virtual std::shared_ptr< MCMCProposal > Proposal(std::shared_ptr< MultiIndex > const &index, std::shared_ptr< AbstractSamplingProblem > const &samplingProblem) override
virtual std::shared_ptr< MCMCProposal > CoarseProposal(std::shared_ptr< MultiIndex > const &fineIndex, std::shared_ptr< MultiIndex > const &coarseIndex, std::shared_ptr< AbstractSamplingProblem > const &coarseProblem, std::shared_ptr< SingleChainMCMC > const &coarseChain) override
virtual std::shared_ptr< AbstractSamplingProblem > SamplingProblem(std::shared_ptr< MultiIndex > const &index) override
ParallelMIComponentFactory(std::shared_ptr< parcer::Communicator > comm, std::shared_ptr< parcer::Communicator > global_comm, std::shared_ptr< MIComponentFactory > componentFactory)
void finalize()
Stops the worker processes' command loop, freeing them for other tasks.
std::map< int, std::shared_ptr< AbstractSamplingProblem > > samplingProblems
ControlFlag
Flags used by parallel MCMC/MIMCMC type methods.
Definition: ParallelFlags.h:23