14 namespace pt = boost::property_tree;
19 ParallelTempering::ParallelTempering(boost::property_tree::ptree opts,
20 std::shared_ptr<InferenceProblem>
const& problem) :
ParallelTempering(opts, ExtractTemps(opts), ExtractKernels(opts, problem))
24 Eigen::VectorXd inverseTemps,
25 std::vector<std::shared_ptr<TransitionKernel>> kernels) :
ParallelTempering(opts, inverseTemps, StackObjects(kernels))
31 Eigen::VectorXd inverseTemps,
32 std::vector<std::vector<std::shared_ptr<TransitionKernel>>> kernelsIn) : numTemps(inverseTemps.size()),
35 numSamps(opts.
get<double>(
"NumSamples")),
36 burnIn(opts.
get(
"BurnIn",0)),
37 printLevel(opts.
get(
"PrintLevel",3)),
38 swapIncr(opts.
get(
"Swap Increment", 2)),
39 seoSwaps(opts.
get(
"Swap Type",
"DEO")==
"SEO"),
40 cumulativeSwapProb(Eigen::VectorXd::Zero(numTemps-1)),
41 successfulSwaps(Eigen::VectorXd::Zero(numTemps-1)),
42 attemptedSwaps(Eigen::VectorXd::Zero(numTemps-1)),
43 nextAdaptInd(opts.
get(
"Adapt Start", 100))
45 if(std::abs(inverseTemps(0))>std::numeric_limits<double>::epsilon()){
46 std::stringstream msg;
47 msg <<
"In ParallelTempering constructor. First inverse temperature in schedule must be 0.0.";
48 throw std::invalid_argument(msg.str());
51 if(std::abs(inverseTemps(inverseTemps.size()-1)-1.0)>std::numeric_limits<double>::epsilon()){
52 std::stringstream msg;
53 msg <<
"In ParallelTempering constructor. Last inverse temperature in schedule must be 1.0.";
54 throw std::invalid_argument(msg.str());
58 if(inverseTemps.minCoeff()<-std::numeric_limits<double>::epsilon()){
59 std::stringstream msg;
60 msg <<
"In ParallelTempering constructor. Inverse temperatures must be in [0,1], but found a minimum temperature of " << inverseTemps.minCoeff();
61 throw std::invalid_argument(msg.str());
64 if(inverseTemps.maxCoeff()>1+std::numeric_limits<double>::epsilon()){
65 std::stringstream msg;
66 msg <<
"In ParallelTempering constructor. Inverse temperatures must be in [0,1], but found a maximum temperature of " << inverseTemps.maxCoeff();
67 throw std::invalid_argument(msg.str());
71 for(
unsigned int i=0; i<
kernels.size(); ++i){
72 problems.at(i) = std::dynamic_pointer_cast<InferenceProblem>(
kernels.at(i).at(0)->Problem() );
75 std::stringstream msg;
76 msg <<
"In ParallelTempering constructor. Could not cast sampling problem for dimension " << i <<
" into an InfereceProblem.";
77 throw std::invalid_argument(msg.str());
82 for(
unsigned int i=0; i<
kernels.size()-1; ++i){
83 for(
unsigned j=i+1; j<
kernels.size(); ++j){
85 std::stringstream msg;
86 msg <<
"In ParallelTempering constructor. Found pointers to the same sampling problem, which prevents setting the temperature at different levels.";
87 throw std::invalid_argument(msg.str());
95 for(
unsigned int i=0; i<
kernels.size(); ++i){
96 problems.at(i)->SetInverseTemp(inverseTemps(i));
97 chains.at(i) = std::make_shared<MarkovChain>();
105 std::stringstream msg;
106 msg <<
" In ParallelTempering::SetState, the size of the argument x0 is " << x0.size() <<
", but the temperature schedule has " <<
numTemps <<
" levels.";
107 throw std::invalid_argument(msg.str());
116 std::vector<std::shared_ptr<SamplingState>> states(
numTemps);
117 for(
unsigned int i=0; i<
numTemps; ++i)
118 states.at(i) = std::make_shared<SamplingState>(x0);
126 std::stringstream msg;
127 msg <<
" In ParallelTempering::SetState, the size of the argument x0 is " << x0.size() <<
", but the temperature schedule has " <<
numTemps <<
" levels.";
128 throw std::invalid_argument(msg.str());
131 std::vector<std::shared_ptr<SamplingState>> states(
numTemps);
132 for(
unsigned int i=0; i<
numTemps; ++i)
133 states.at(i) = std::make_shared<SamplingState>(x0.at(i));
141 return problems.at(chainInd)->GetInverseTemp();
161 const unsigned int printIncr = std::floor(
numSamps /
double(10));
162 unsigned int nextPrintInd = printIncr;
163 unsigned int nextSwapInd =
swapIncr;
167 std::cout <<
"Starting parallel tempering sampler..." << std::endl;
176 nextPrintInd += printIncr;
199 std::cout <<
"Completed in " <<
totalTime <<
" seconds." << std::endl;
214 for(
unsigned int i=0; i<
numTemps; ++i)
215 currBetas(i) =
problems.at(i)->GetInverseTemp();
220 for(
unsigned int i=1; i<
numTemps; ++i)
223 if((cumProbs.maxCoeff()-cumProbs.minCoeff())<1
e-8)
226 for(
unsigned int i=1; i<
numTemps-1; ++i){
232 if(cumProbs(j) >= desiredVal)
236 std::cout <<
"Cumulative probs: " << cumProbs.transpose() << std::endl;
237 std::cout <<
"desiredVal: " << desiredVal << std::endl;
242 double w = (desiredVal - cumProbs(j-1)) / (cumProbs(j) - cumProbs(j-1));
243 double newTemp =
w*currBetas(j) + (1.0-
w)*currBetas(j-1);
246 problems.at(i)->SetInverseTemp(newTemp);
258 for(
unsigned int i=0; i<
numTemps; ++i)
259 output(i) =
problems.at(i)->GetInverseTemp();
265 std::cout << prefix << int(std::floor(
double((currInd - 1) * 100) /
double(
numSamps))) <<
"% Complete" << std::endl;
268 std::streamsize ss = std::cout.precision();
269 std::cout.precision(2);
271 std::cout << prefix <<
" Inverse Temps: " <<
CollectInverseTemps().transpose() << std::endl;
272 std::cout.precision(ss);
276 std::cout << prefix <<
" Kernel 0:\n";
277 for(
int blockInd=0; blockInd<
kernels.at(0).size(); ++blockInd){
278 std::cout << prefix <<
" Block " << blockInd <<
":\n";
279 kernels.at(0).at(blockInd)->PrintStatus(prefix +
" ");
282 for(
int chainInd=0; chainInd<
numTemps; ++chainInd){
283 std::cout << prefix <<
" Kernel " << chainInd <<
":\n";
284 for(
int blockInd=0; blockInd<
kernels.at(chainInd).size(); ++blockInd){
285 std::cout << prefix <<
" Block " << blockInd <<
":\n";
286 kernels.at(chainInd).at(blockInd)->PrintStatus(prefix +
" ");
294 if(!state->HasMeta(
"InverseTemp")){
295 std::stringstream msg;
296 msg <<
"Error in ParallelTempering::SwapStates. Tried swapping states with a state that does not have temperature metadata. The state must have the \"InverseTemp\" metadata, which is typically set in InferenceProblem::LogDensity.";
297 throw std::runtime_error(msg.str());
300 if(!state->HasMeta(
"LogLikelihood")){
301 std::stringstream msg;
302 msg <<
"Error in ParallelTempering::SwapStates. Tried swapping states with a state that does not have likelihood metadata. The state must have the \"LogLikelihood\" metadata, which is typically set in InferenceProblem::LogDensity.";
303 throw std::runtime_error(msg.str());
306 if(!state->HasMeta(
"LogPrior")){
307 std::stringstream msg;
308 msg <<
"Error in ParallelTempering::SwapStates. Tried swapping states with a state that does not have prior metadata. The state must have the \"LogPrior\" metadata, which is typically set in InferenceProblem::LogDensity.";
309 throw std::runtime_error(msg.str());
316 unsigned int startInd;
318 startInd = RandomGenerator::GetUniformInt(0,1);
323 double beta1, beta2, logLikely1, logLikely2, alpha;
325 for(
unsigned int i=startInd; i<
numTemps-1; i+=2){
335 alpha = std::exp( (beta1 - beta2)*(logLikely2 - logLikely1));
340 if(RandomGenerator::GetUniform() < alpha){
358 bool initError =
false;
366 std::stringstream msg;
367 msg <<
"\nERROR in ParallelTempering::Sample. Trying to sample chain but previous (or initial) state has not been set.\n";
368 throw std::runtime_error(msg.str());
371 auto startTime = std::chrono::high_resolution_clock::now();
373 std::vector<std::vector<std::shared_ptr<SamplingState>>> newStates(
numTemps);
377 for(
unsigned int chainInd=0; chainInd<
numTemps; ++chainInd){
378 newStates.at(chainInd).resize(
kernels.at(chainInd).size());
381 for(
int kernInd=0; kernInd<
kernels.at(chainInd).size(); ++kernInd){
385 prevStates.at(chainInd)->meta[
"IsProposal"] =
false;
393 double now = std::chrono::duration<double>(std::chrono::high_resolution_clock::now()-startTime).count();
394 for(
auto& state : newStates.at(chainInd))
395 state->meta[
"time"] = now;
398 kernels.at(chainInd).at(kernInd)->PostStep(
sampNums.at(chainInd), newStates.at(chainInd));
405 auto endTime = std::chrono::high_resolution_clock::now();
406 totalTime += std::chrono::duration<double>(endTime - startTime).count();
413 assert(newStates.size() ==
numTemps);
415 for(
unsigned int chainInd=0; chainInd<
numTemps; ++chainInd){
416 for(
unsigned int stateInd=0; stateInd<newStates.at(chainInd).size(); ++stateInd){
420 chains.at(chainInd)->Add(newStates.at(chainInd).at(stateInd));
423 if(newStates.at(chainInd).at(stateInd)->HasMeta(
"QOI")) {
424 std::shared_ptr<SamplingState> qoi =
AnyCast(newStates.at(chainInd).at(stateInd)->meta[
"QOI"]);
425 QOIs.at(chainInd)->Add(qoi);
437 for(
unsigned int chainInd=0; chainInd<
numTemps; ++chainInd){
438 prevStates.at(chainInd) = newStates.at(chainInd).back();
445 std::vector<std::vector<std::shared_ptr<TransitionKernel>>> newKernels(kerns.size());
446 for(
unsigned int i=0; i<kerns.size(); ++i){
447 newKernels.resize(1);
448 newKernels.at(i).at(0) = kerns.at(i);
455 std::string allTemps = opts.get<std::string>(
"Inverse Temperatures");
458 Eigen::VectorXd inverseTemps(tempStrings.size());
459 for(
unsigned int i=0; i<tempStrings.size(); ++i)
460 inverseTemps(i) = std::stod(tempStrings.at(i));
466 std::shared_ptr<InferenceProblem>
const& problem)
468 std::string allKernelString = opts.get<std::string>(
"Kernel Lists");
471 std::vector<std::vector<std::shared_ptr<TransitionKernel>>>
kernels(chainStrings.size());
473 for(
unsigned int chainInd=0; chainInd<chainStrings.size(); ++chainInd){
476 unsigned int numBlocks = kernelNames.size();
478 kernels.at(chainInd).resize(numBlocks);
481 for(
int i=0; i<numBlocks; ++i) {
482 boost::property_tree::ptree subTree = opts.get_child(kernelNames.at(i));
483 subTree.put(
"BlockIndex",i);
485 auto prob = problem->Clone();
486 prob->AddOptions(subTree);
Defines an MCMC sampler with multiple chains running on problems with different temperatures.
std::vector< std::vector< std::shared_ptr< TransitionKernel > > > kernels
unsigned int nextAdaptInd
const unsigned int numTemps
Number of temperatures in the temperature schedule.
std::vector< std::shared_ptr< MarkovChain > > chains
void SaveSamples(std::vector< std::vector< std::shared_ptr< SamplingState >>> const &newStates)
double GetInverseTemp(unsigned int chainInd) const
void SetState(std::vector< std::shared_ptr< SamplingState >> const &x0)
Set the state of the MCMC chain.
Eigen::VectorXd CollectInverseTemps() const
std::vector< unsigned int > sampNums
Eigen::VectorXd cumulativeSwapProb
static std::vector< std::vector< std::shared_ptr< TransitionKernel > > > StackKernels(std::vector< std::shared_ptr< TransitionKernel >> const &kerns)
std::shared_ptr< SaveSchedulerBase > scheduler
static Eigen::VectorXd ExtractTemps(boost::property_tree::ptree opts)
void AdaptTemperatures()
Adapts the temperatures according to the procedure outlined in Section 5 of the cited reference (the citation is missing from this listing).
Eigen::VectorXd attemptedSwaps
std::vector< std::shared_ptr< SamplingState > > prevStates
void PrintStatus(unsigned int currInd) const
std::vector< std::shared_ptr< MarkovChain > > QOIs
static std::vector< std::vector< std::shared_ptr< TransitionKernel > > > ExtractKernels(boost::property_tree::ptree opts, std::shared_ptr< InferenceProblem > const &prob)
std::vector< std::shared_ptr< TransitionKernel > > const & Kernels(unsigned int chainInd) const
std::shared_ptr< MarkovChain > Run()
ParallelTempering(boost::property_tree::ptree opts, std::shared_ptr< InferenceProblem > const &problem)
static void CheckForMeta(std::shared_ptr< SamplingState > const &state)
Checks a sampling state to make sure it has the metadata necessary to swap states.
bool ShouldSave(unsigned int chainInd, unsigned int sampNum) const
Returns true if a sample of a particular chain should be saved.
Eigen::VectorXd successfulSwaps
static std::vector< std::vector< T > > StackObjects(std::vector< T > const &kerns)
std::vector< std::shared_ptr< InferenceProblem > > problems
static std::shared_ptr< TransitionKernel > Construct(boost::property_tree::ptree const &pt, std::shared_ptr< AbstractSamplingProblem > problem)
Static constructor for the transition kernel.
Class for easily casting boost::any's in assignment operations.
std::vector< std::string > Split(std::string str, char delim=',')
Split a string into parts based on a particular character delimiter. Also strips whitespace from each part.
auto get(const nlohmann::detail::iteration_proxy_value< IteratorType > &i) -> decltype(i.key())