MUQ  0.4.3
CrankNicolsonProposal.cpp
Go to the documentation of this file.
2 
4 
10 #include "MUQ/Modeling/WorkPiece.h"
11 
12 using namespace muq::SamplingAlgorithms;
13 using namespace muq::Modeling;
14 
16 
17 CrankNicolsonProposal::CrankNicolsonProposal(boost::property_tree::ptree const& pt,
18  std::shared_ptr<AbstractSamplingProblem> prob,
19  std::shared_ptr<GaussianBase> priorIn) : MCMCProposal(pt,prob),
20  beta(pt.get("Beta",0.5)),
21  priorDist(priorIn)
22 {
23 }
24 
25 CrankNicolsonProposal::CrankNicolsonProposal(boost::property_tree::ptree const& pt,
26  std::shared_ptr<AbstractSamplingProblem> prob) : MCMCProposal(pt,prob),
27  beta(pt.get("Beta",0.5))
28 {
29  ExtractPrior(prob, pt.get<std::string>("PriorNode"));
30 }
31 
32 
33 std::shared_ptr<SamplingState> CrankNicolsonProposal::Sample(std::shared_ptr<SamplingState> const& currentState)
34 {
35  // the mean of the proposal is the current point
36  std::vector<Eigen::VectorXd> props = currentState->state;
37 
38  std::vector<Eigen::VectorXd> hypers = GetPriorInputs(currentState->state);
39 
40  Eigen::VectorXd priorSamp = priorDist->Sample(hypers);
41 
42  props.at(blockInd) = priorDist->GetMean() + sqrt(1.0-beta*beta)*(currentState->state.at(blockInd)-priorDist->GetMean()) + beta*(priorSamp - priorDist->GetMean());
43 
44  // store the new state in the output
45  return std::make_shared<SamplingState>(props, 1.0);
46 }
47 
48 double CrankNicolsonProposal::LogDensity(std::shared_ptr<SamplingState> const& currState,
49  std::shared_ptr<SamplingState> const& propState)
50 {
51  std::vector<Eigen::VectorXd> hypers = GetPriorInputs(currState->state);
52  if(hypers.size()>0)
53  priorDist->ResetHyperparameters(WorkPiece::ToRefVector(hypers));
54 
55  Eigen::VectorXd diff = (propState->state.at(blockInd) - priorDist->GetMean() - sqrt(1.0-beta*beta)*(currState->state.at(blockInd)-priorDist->GetMean()))/beta;
56 
57  hypers.insert(hypers.begin(), (diff + priorDist->GetMean()).eval());
58  return priorDist->LogDensity(hypers);
59 }
60 
61 std::vector<Eigen::VectorXd> CrankNicolsonProposal::GetPriorInputs(std::vector<Eigen::VectorXd> const& currState)
62 {
63  std::vector<Eigen::VectorXd> hyperParams;
64 
65  if(priorMeanModel){
67  for(int i=0; i<priorMeanInds.size(); ++i)
68  meanIns.push_back( std::cref( currState.at(priorMeanInds.at(i)) ) );
69 
70  hyperParams.push_back(priorMeanModel->Evaluate(meanIns).at(0));
71  }
72 
73  if(priorCovModel){
75  for(int i=0; i<priorCovInds.size(); ++i)
76  covIns.push_back( std::cref( currState.at(priorCovInds.at(i)) ) );
77 
78  hyperParams.push_back(priorCovModel->Evaluate(covIns).at(0));
79  }
80 
81  return hyperParams;
82 }
83 
84 
85 void CrankNicolsonProposal::ExtractPrior(std::shared_ptr<AbstractSamplingProblem> const& prob,
86  std::string nodeName)
87 {
88  // Cast the abstract base class into a sampling problem
89  std::shared_ptr<SamplingProblem> prob2 = std::dynamic_pointer_cast<SamplingProblem>(prob);
90  assert(prob2);
91 
92  // From the sampling problem, extract the ModPiece and try to cast it to a ModGraphPiece
93  std::shared_ptr<ModPiece> targetDens = prob2->GetDistribution();
94  std::shared_ptr<ModGraphPiece> targetDens2 = std::dynamic_pointer_cast<ModGraphPiece>(targetDens);
95  assert(targetDens2);
96 
97  // Get the graph
98  auto graph = targetDens2->GetGraph();
99 
100  // Get the prior piece corresponding to the Gaussian name
101  auto priorPiece = graph->GetPiece(nodeName);
102  assert(priorPiece);
103 
104  // Get the prior distribution
105  std::shared_ptr<Density> priorDens = std::dynamic_pointer_cast<Density>(priorPiece);
106  assert(priorDens);
107 
108  priorDist = std::dynamic_pointer_cast<GaussianBase>(priorDens->GetDistribution());
109  assert(priorDist);
110 
111  // Check to see if the prior has a mean or covariance input.
112  auto gaussPrior = std::dynamic_pointer_cast<Gaussian>(priorDist);
113  if(gaussPrior==nullptr)
114  return;
115 
116  Gaussian::InputMask inputTypes = gaussPrior->GetInputTypes();
117 
118  if(inputTypes == Gaussian::None)
119  return;
120 
121  auto newGraph = graph->Clone();
122  auto constPieces = targetDens2->GetConstantPieces();
123  for(int i=0; i<constPieces.size(); ++i)
124  newGraph->RemoveNode( newGraph->GetName(constPieces.at(i)) );
125 
126  // Check if there is a mean input
127  if(inputTypes & Gaussian::Mean){
128  std::string meanInput = graph->GetParent(nodeName, 1);
129  if(newGraph->HasNode(meanInput)){
130  priorMeanModel = newGraph->CreateModPiece(meanInput);
131  priorMeanInds = targetDens2->MatchInputs(std::dynamic_pointer_cast<ModGraphPiece>(priorMeanModel));
132  }else{
133  priorMeanModel = std::make_shared<IdentityOperator>(priorDens->inputSizes(1));
134  auto iter = std::find(constPieces.begin(), constPieces.end(), graph->GetPiece(meanInput));
135  priorMeanInds.push_back( std::distance(constPieces.begin(),iter));
136  }
137  }
138 
139  // Check if there is a covariance or precision input
140  if((inputTypes & Gaussian::DiagCovariance) || (inputTypes & Gaussian::FullCovariance) || (inputTypes & Gaussian::DiagPrecision)||(inputTypes & Gaussian::FullPrecision)){
141  int covInd = (inputTypes & Gaussian::Mean) ? 2 : 1;
142  priorUsesCov = (inputTypes & Gaussian::DiagCovariance) || (inputTypes & Gaussian::FullCovariance);
143 
144  std::string covInput = graph->GetParent(nodeName, covInd);
145  if(newGraph->HasNode(covInput)){
146  priorCovModel = newGraph->CreateModPiece(covInput);
147  priorCovInds = targetDens2->MatchInputs(std::dynamic_pointer_cast<ModGraphPiece>(priorCovModel));
148  }else{
149  priorCovModel = std::make_shared<IdentityOperator>(priorDens->inputSizes(covInd));
150  auto iter = std::find(constPieces.begin(), constPieces.end(), graph->GetPiece(covInput));
151  priorCovInds.push_back( std::distance(constPieces.begin(),iter));
152  }
153  }
154 
155  // if(priorDens->inputSizes.size()>1){
156  //
157  // // Create a new graph that does not have the constant pieces
158  // auto newGraph = graph->Clone();
159  // auto constPieces = targetDens2->GetConstantPieces();
160  // for(int i=0; i<constPieces.size(); ++i)
161  // newGraph->RemoveNode( newGraph->GetName(constPieces.at(i)) );
162  //
163  // std::vector<std::string> gaussInputs = graph->GetParents(nodeName);
164  // inputPieces.resize(gaussInputs.size()-1);
165  // blockInds.resize(gaussInputs.size()-1);
166  //
167  // std::vector<std::vector<int>> blockInds(gaussInputs.size()-1);
168  //
169  // for(int i=0; i<inputPieces.size(); ++i){
170  //
171  // if(newGraph->HasNode(gaussInputs.at(i+1))){
172  // inputPieces.at(i) = newGraph->CreateModPiece(gaussInputs.at(i+1));
173  // blockInds.at(i) = targetDens2->MatchInputs(std::dynamic_pointer_cast<ModGraphPiece>(inputPieces.at(i)));
174  // }else{
175  // inputPieces.at(i) = std::make_shared<IdentityOperator>(priorDens->inputSizes(i));
176  // auto iter = std::find(constPieces.begin(), constPieces.end(), graph->GetPiece(gaussInputs.at(i+1)));
177  // blockInds.at(i).push_back( std::distance(constPieces.begin(),iter));
178  // }
179  // }
180  //
181  // }
182 
183 }
REGISTER_MCMC_PROPOSAL(CrankNicolsonProposal) CrankNicolsonProposal
An implementation of the dimension-independent pCN proposal.
std::shared_ptr< muq::Modeling::ModPiece > priorMeanModel
std::shared_ptr< muq::Modeling::GaussianBase > priorDist
The proposal distribution.
void ExtractPrior(std::shared_ptr< AbstractSamplingProblem > const &prob, std::string nodeName)
std::vector< Eigen::VectorXd > GetPriorInputs(std::vector< Eigen::VectorXd > const &currState)
virtual double LogDensity(std::shared_ptr< SamplingState > const &currState, std::shared_ptr< SamplingState > const &propState) override
std::shared_ptr< muq::Modeling::ModPiece > priorCovModel
CrankNicolsonProposal(boost::property_tree::ptree const &pt, std::shared_ptr< AbstractSamplingProblem > prob, std::shared_ptr< muq::Modeling::GaussianBase > prior)
virtual std::shared_ptr< SamplingState > Sample(std::shared_ptr< SamplingState > const &currentState) override
std::shared_ptr< AbstractSamplingProblem > prob
Definition: MCMCProposal.h:81
std::vector< std::reference_wrapper< const T > > ref_vector
A vector of references to something ...
Definition: WorkPiece.h:37
auto get(const nlohmann::detail::iteration_proxy_value< IteratorType > &i) -> decltype(i.key())
Definition: json.h:3956