MUQ  0.4.3
LinearTransformKernel.h
Go to the documentation of this file.
1 #ifndef LINEARTRANSFORMKERNEL_H
2 #define LINEARTRANSFORMKERNEL_H
3 
5 
6 
7 namespace muq
8 {
9 namespace Approximation
10 {
11 
23 {
24 
25 public:
26 
27 
28  LinearTransformKernel(Eigen::MatrixXd const& Ain,
29  std::shared_ptr<KernelBase> Kin) : KernelBase(Kin->inputDim, Ain.rows(), Kin->numParams),
30  A(Ain), K(Kin)
31  {
32  assert(Ain.cols() == Kin->coDim);
33  cachedParams = Kin->GetParams();
34  };
35 
37 
38 
39  virtual void FillBlock(Eigen::Ref<const Eigen::VectorXd> const& x1,
40  Eigen::Ref<const Eigen::VectorXd> const& x2,
41  Eigen::Ref<const Eigen::VectorXd> const& params,
42  Eigen::Ref<Eigen::MatrixXd> block) const override
43  {
44  Eigen::MatrixXd temp(K->coDim,K->coDim);
45  K->FillBlock(x1,x2, params, temp);
46  block = A*temp*A.transpose();
47  }
48 
49  virtual void FillPosDerivBlock(Eigen::Ref<const Eigen::VectorXd> const& x1,
50  Eigen::Ref<const Eigen::VectorXd> const& x2,
51  Eigen::Ref<const Eigen::VectorXd> const& params,
52  std::vector<int> const& wrts,
53  Eigen::Ref<Eigen::MatrixXd> block) const override
54  {
55  Eigen::MatrixXd temp(K->coDim,K->coDim);
56  K->FillPosDerivBlock(x1,x2, params, wrts, temp);
57  block = A*temp*A.transpose();
58  }
59 
60  virtual std::shared_ptr<KernelBase> Clone() const override{return std::make_shared<LinearTransformKernel>(A,K);};
61 
62 private:
63  Eigen::MatrixXd A;
64  std::shared_ptr<KernelBase> K;
65 
66 };
67 
68 
69 template<typename KernelType>
70 LinearTransformKernel TransformKernel(Eigen::MatrixXd const& A, KernelType const& K)
71 {
72  return LinearTransformKernel(A,K.Clone());
73 }
74 
75 template<typename KernelType,
77 LinearTransformKernel operator*(Eigen::MatrixXd const& A, KernelType const& kernel)
78 {
79  return LinearTransformKernel(A, kernel.Clone());
80 }
81 
82 }
83 }
84 
85 #endif
Base class for all covariance kernels.
Definition: KernelBase.h:36
const unsigned int inputDim
Definition: KernelBase.h:132
const unsigned int numParams
Definition: KernelBase.h:134
Eigen::VectorXd cachedParams
Definition: KernelBase.h:156
virtual void FillPosDerivBlock(Eigen::Ref< const Eigen::VectorXd > const &x1, Eigen::Ref< const Eigen::VectorXd > const &x2, Eigen::Ref< const Eigen::VectorXd > const &params, std::vector< int > const &wrts, Eigen::Ref< Eigen::MatrixXd > block) const override
LinearTransformKernel(Eigen::MatrixXd const &Ain, std::shared_ptr< KernelBase > Kin)
virtual void FillBlock(Eigen::Ref< const Eigen::VectorXd > const &x1, Eigen::Ref< const Eigen::VectorXd > const &x2, Eigen::Ref< const Eigen::VectorXd > const &params, Eigen::Ref< Eigen::MatrixXd > block) const override
virtual std::shared_ptr< KernelBase > Clone() const override
LinearTransformMean< MeanType > operator*(Eigen::MatrixXd const &A, MeanType const &K)
LinearTransformKernel TransformKernel(Eigen::MatrixXd const &A, KernelType const &K)
int int FloatType value
Definition: json.h:15223