MUQ  0.4.3
ProductKernel.cpp
Go to the documentation of this file.
2 
4 
5 using namespace muq::Approximation;
6 
7 ProductKernel::ProductKernel(std::shared_ptr<KernelBase> kernel1In,
8  std::shared_ptr<KernelBase> kernel2In) : KernelBase(kernel1In->inputDim,
9  std::max(kernel1In->coDim, kernel2In->coDim),
10  kernel1In->numParams + kernel2In->numParams),
11  kernel1(kernel1In),
12  kernel2(kernel2In)
13 {
14  assert((kernel1->coDim==kernel2->coDim) | (kernel1->coDim==1) | (kernel2->coDim==1));
15 
16  cachedParams.resize(numParams);
17  cachedParams.head(kernel1In->numParams) = kernel1In->GetParams();
18  cachedParams.tail(kernel2In->numParams) = kernel2In->GetParams();
19 
20 };
21 
22 
23 void ProductKernel::FillBlock(Eigen::Ref<const Eigen::VectorXd> const& x1,
24  Eigen::Ref<const Eigen::VectorXd> const& x2,
25  Eigen::Ref<const Eigen::VectorXd> const& params,
26  Eigen::Ref<Eigen::MatrixXd> block) const
27 {
28  Eigen::MatrixXd temp1(kernel1->coDim, kernel1->coDim);
29  Eigen::MatrixXd temp2(kernel2->coDim, kernel2->coDim);
30 
31  kernel1->FillBlock(x1, x2, params.head(kernel1->numParams), temp1);
32  kernel2->FillBlock(x1, x2, params.tail(kernel2->numParams), temp2);
33 
34  if(kernel1->coDim==kernel2->coDim)
35  {
36  block = Eigen::MatrixXd(temp1.array() * temp2.array());
37  }
38  else if(kernel1->coDim==1)
39  {
40  block = temp1(0,0)*temp2;
41  }
42  else if(kernel2->coDim==1)
43  {
44  block = temp2(0,0)*temp1;
45  }
46  else
47  {
48  std::cerr << "\nERROR: Something unexpected happened with the dimensions of the kernels in this product.\n";
49  assert(false);
50  }
51 }
52 
53 
54 std::vector<std::shared_ptr<KernelBase>> ProductKernel::GetSeperableComponents()
55 {
56  // Check if the dimensions of the components are distinct
57  bool isSeperable = true;
58  for(unsigned leftDim : kernel1->dimInds)
59  {
60  for(unsigned rightDim : kernel2->dimInds)
61  {
62  if(leftDim==rightDim)
63  {
64  isSeperable = false;
65  break;
66  }
67  }
68 
69  if(!isSeperable)
70  break;
71  }
72 
73  if(isSeperable)
74  {
75  std::vector<std::shared_ptr<KernelBase>> output, output2;
76  output = kernel1->GetSeperableComponents();
77  output2 = kernel2->GetSeperableComponents();
78  output.insert(output.end(), output2.begin(), output2.end());
79 
80  return output;
81  }
82  else
83  {
84  return std::vector<std::shared_ptr<KernelBase>>(1, this->Clone());
85  }
86 
87 };
88 
89 
90 
91 std::tuple<std::shared_ptr<muq::Modeling::LinearSDE>, std::shared_ptr<muq::Modeling::LinearOperator>, Eigen::MatrixXd> ProductKernel::GetStateSpace(boost::property_tree::ptree sdeOptions) const
92 {
93  auto periodicCast1 = std::dynamic_pointer_cast<PeriodicKernel>(kernel1);
94  auto periodicCast2 = std::dynamic_pointer_cast<PeriodicKernel>(kernel2);
95 
96  if(periodicCast1 && (!periodicCast2)){
97  return GetProductStateSpace(periodicCast1, kernel2, sdeOptions);
98  }else if((!periodicCast1) && periodicCast2){
99  return GetProductStateSpace(periodicCast2, kernel1, sdeOptions);
100  }else{
101  int status = 0;
102 
103 
104  std::string type1 = muq::Utilities::GetTypeName(kernel1);
105  std::string type2 = muq::Utilities::GetTypeName(kernel2);
106 
107  throw muq::NotImplementedError("ERROR in ProductKernel::GetStateSpace(). The GetStateSpace() function has not been implemented for these types: \"" + type1 + "\" and \"" + type2 + "\"");
108  }
109 };
110 
111 
112 
113 // See "Explicit Link Between Periodic
114 std::tuple<std::shared_ptr<muq::Modeling::LinearSDE>, std::shared_ptr<muq::Modeling::LinearOperator>, Eigen::MatrixXd> ProductKernel::GetProductStateSpace(std::shared_ptr<PeriodicKernel> const& kernel1,
115  std::shared_ptr<KernelBase> const& kernel2,
116  boost::property_tree::ptree sdeOptions) const
117 {
118 
119  auto periodicGP = kernel1->GetStateSpace(sdeOptions);
120  auto periodicSDE = std::get<0>(periodicGP);
121 
122  auto periodicF = std::dynamic_pointer_cast<muq::Modeling::BlockDiagonalOperator>(periodicSDE->GetF());
123  assert(periodicF);
124 
125  auto periodicL = std::dynamic_pointer_cast<muq::Modeling::BlockDiagonalOperator>(periodicSDE->GetL());
126  assert(periodicL);
127 
128  auto otherGP = kernel2->GetStateSpace(sdeOptions);
129  auto otherSDE = std::get<0>(otherGP);
130  auto otherF = otherSDE->GetF();
131  auto otherL = otherSDE->GetL();
132  auto otherH = std::get<1>(otherGP);
133 
135  std::vector<std::shared_ptr<muq::Modeling::LinearOperator>> newBlocks( periodicF->GetBlocks().size() );
136  for(int i=0; i<newBlocks.size(); ++i)
137  newBlocks.at(i) = muq::Modeling::KroneckerSum(otherF, periodicF->GetBlock(i) );
138 
139  auto newF = std::make_shared<muq::Modeling::BlockDiagonalOperator>(newBlocks);
140 
142  for(int i=0; i<newBlocks.size(); ++i)
143  newBlocks.at(i) = std::make_shared<muq::Modeling::KroneckerProductOperator>(otherL, periodicL->GetBlock(i) );
144 
145  auto newL = std::make_shared<muq::Modeling::BlockDiagonalOperator>(newBlocks);
146 
148  Eigen::MatrixXd Hblock(1,2);
149  Hblock << 1.0, 0.0;
150 
151  for(int i=0; i<newBlocks.size(); ++i)
152  newBlocks.at(i) = std::make_shared<muq::Modeling::KroneckerProductOperator>(otherH, muq::Modeling::LinearOperator::Create(Hblock) );
153 
154  auto newH = std::make_shared<muq::Modeling::BlockRowOperator>(newBlocks);
155 
156  // Construct Pinf
157  Eigen::MatrixXd periodicP = std::get<2>(periodicGP);
158  Eigen::MatrixXd otherP = std::get<2>(otherGP);
159 
160  Eigen::MatrixXd Pinf = Eigen::MatrixXd::Zero(periodicP.rows()*otherP.rows(), periodicP.cols()*otherP.cols());
161  for(int i=0; i<newBlocks.size(); ++i)
162  Pinf.block(2*i*otherP.rows(), 2*i*otherP.rows(), 2*otherP.rows(), 2*otherP.cols()) = muq::Modeling::KroneckerProduct(otherP, periodicP.block(2*i,2*i,2,2));
163 
164  // Construct Q
165  Eigen::MatrixXd const& otherQ = otherSDE->GetQ();
166  Eigen::MatrixXd Q = Eigen::MatrixXd::Zero(otherQ.rows()*periodicP.rows(), otherQ.cols()*periodicP.cols());
167  for(int i=0; i<newBlocks.size(); ++i)
168  Q.block(2*i*otherQ.rows(), 2*i*otherQ.rows(), 2*otherQ.rows(), 2*otherQ.cols()) = muq::Modeling::KroneckerProduct(otherQ, periodicP.block(2*i,2*i,2,2));
169 
170  // Construct the new statespace GP
171  auto newSDE = std::make_shared<muq::Modeling::LinearSDE>(newF, newL, Q, sdeOptions);
172  return std::make_tuple(newSDE, newH, Pinf);
173 }
174 
175 
176 std::shared_ptr<ProductKernel> muq::Approximation::operator*(std::shared_ptr<KernelBase> k1, std::shared_ptr<KernelBase> k2)
177 {
178  return std::make_shared<ProductKernel>(k1,k2);
179 }
Base class for all covariance kernels.
Definition: KernelBase.h:36
const unsigned int numParams
Definition: KernelBase.h:134
Eigen::VectorXd cachedParams
Definition: KernelBase.h:156
std::tuple< std::shared_ptr< muq::Modeling::LinearSDE >, std::shared_ptr< muq::Modeling::LinearOperator >, Eigen::MatrixXd > GetProductStateSpace(std::shared_ptr< PeriodicKernel > const &kernel1, std::shared_ptr< KernelBase > const &kernel2, boost::property_tree::ptree sdeOptions) const
virtual std::vector< std::shared_ptr< KernelBase > > GetSeperableComponents() override
Overridden by ProductKernel.
virtual void FillBlock(Eigen::Ref< const Eigen::VectorXd > const &x1, Eigen::Ref< const Eigen::VectorXd > const &x2, Eigen::Ref< const Eigen::VectorXd > const &params, Eigen::Ref< Eigen::MatrixXd > block) const override
virtual std::tuple< std::shared_ptr< muq::Modeling::LinearSDE >, std::shared_ptr< muq::Modeling::LinearOperator >, Eigen::MatrixXd > GetStateSpace(boost::property_tree::ptree sdeOptions=boost::property_tree::ptree()) const override
Returns a state space representation of the covariance kernel.
virtual std::shared_ptr< KernelBase > Clone() const override
Definition: ProductKernel.h:89
std::shared_ptr< KernelBase > kernel2
std::shared_ptr< KernelBase > kernel1
ProductKernel(std::shared_ptr< KernelBase > kernel1In, std::shared_ptr< KernelBase > kernel2In)
Exception class used to signal that a virtual base-class function has not been implemented.
Definition: Exceptions.h:22
LinearTransformMean< MeanType > operator*(Eigen::MatrixXd const &A, MeanType const &K)
Eigen::MatrixXd KroneckerProduct(Eigen::Ref< const Eigen::MatrixXd > const &A, Eigen::Ref< const Eigen::MatrixXd > const &B)
std::shared_ptr< LinearOperator > KroneckerSum(std::shared_ptr< LinearOperator > A, std::shared_ptr< LinearOperator > B)
std::string GetTypeName(PointerType const &ptr)
Definition: Demangler.h:15