MUQ  0.4.3
ConcatenateOperator.cpp
Go to the documentation of this file.
2 
4 
5 using namespace muq::Modeling;
6 
7 ConcatenateOperator::ConcatenateOperator(std::vector<std::shared_ptr<LinearOperator>> const& opsIn,
8  const int rowColIn) : LinearOperator(GetRows(opsIn, rowColIn), GetCols(opsIn, rowColIn)), ops(opsIn), rowCol(rowColIn)
9 {
10  CheckSizes();
11 };
12 
13 
15 Eigen::MatrixXd ConcatenateOperator::Apply(Eigen::Ref<const Eigen::MatrixXd> const& x)
16 {
17  Eigen::MatrixXd output = Eigen::MatrixXd::Zero(nrows, x.cols());
18 
19  if(rowCol==0){
20  int currRow = 0;
21  for(int i=0; i<ops.size(); ++i){
22  output.block(currRow,0,ops.at(i)->rows(), x.cols()) = ops.at(i)->Apply(x);
23  currRow += ops.at(i)->rows();
24  }
25  }else{
26  int currRow = 0;
27  for(int i=0; i<ops.size(); ++i){
28  output += ops.at(i)->Apply( x.block(currRow, 0, ops.at(i)->cols(), x.cols()) );
29  currRow += ops.at(i)->cols();
30  }
31  }
32 
33  return output;
34 }
35 
36 
38 Eigen::MatrixXd ConcatenateOperator::ApplyTranspose(Eigen::Ref<const Eigen::MatrixXd> const& x)
39 {
40  Eigen::MatrixXd output = Eigen::MatrixXd::Zero(ncols, x.cols());
41 
42  if(rowCol==0){
43  int currRow = 0;
44  for(int i=0; i<ops.size(); ++i){
45  output += ops.at(i)->ApplyTranspose( x.block(currRow, 0, ops.at(i)->rows(), x.cols()) );
46  currRow += ops.at(i)->rows();
47  }
48  }else{
49  int currRow = 0;
50  for(int i=0; i<ops.size(); ++i){
51  output.block(currRow,0, ops.at(i)->cols(), x.cols()) = ops.at(i)->ApplyTranspose(x);
52  currRow += ops.at(i)->cols();
53  }
54  }
55 
56  return output;
57 
58 }
59 
60 
61  std::shared_ptr<ConcatenateOperator> ConcatenateOperator::VStack(std::shared_ptr<LinearOperator> Ain,
62  std::shared_ptr<LinearOperator> Bin)
63  {
64  std::vector<std::shared_ptr<LinearOperator>> temp{Ain, Bin};
65  return std::make_shared<ConcatenateOperator>(temp, 0);
66  }
67 
68 std::shared_ptr<ConcatenateOperator> ConcatenateOperator::HStack(std::shared_ptr<LinearOperator> Ain,
69  std::shared_ptr<LinearOperator> Bin)
70 {
71  std::vector<std::shared_ptr<LinearOperator>> temp{Ain, Bin};
72  return std::make_shared<ConcatenateOperator>(temp,1);
73 }
74 
75 
77 {
78 
79  Eigen::MatrixXd output(nrows, ncols);
80  if(rowCol==0){
81  int currRow = 0;
82  for(int i=0; i<ops.size(); ++i){
83  output.block(currRow, 0, ops.at(i)->rows(), ncols) = ops.at(i)->GetMatrix();
84  currRow += ops.at(i)->rows();
85  }
86  }else{
87  int currCol = 0;
88  for(int i=0; i<ops.size(); ++i){
89  output.block(0,currCol, nrows, ops.at(i)->cols()) = ops.at(i)->GetMatrix();
90  currCol += ops.at(i)->cols();
91  }
92  }
93  return output;
94 }
95 
96 int ConcatenateOperator::GetRows(std::vector<std::shared_ptr<LinearOperator>> const& opsIn,
97  const int rowColIn)
98 {
99  assert(opsIn.size()>0);
100 
101  if(rowColIn==0){
102  int count = 0;
103  for(auto& op : opsIn)
104  count += op->rows();
105  return count;
106  }else{
107  return opsIn.at(0)->rows();
108  }
109 }
110 
111 
112 int ConcatenateOperator::GetCols(std::vector<std::shared_ptr<LinearOperator>> const& opsIn,
113  const int rowColIn)
114 {
115  assert(opsIn.size()>0);
116 
117  if(rowColIn==0){
118  return opsIn.at(0)->cols();
119  }else{
120  int count = 0;
121  for(auto& op : opsIn)
122  count += op->cols();
123  return count;
124  }
125 }
126 
127 
129 {
130  if(rowCol==0){
131  const int fixedCols = ops.at(0)->cols();
132  for(int i=1; i<ops.size(); ++i){
133  if(fixedCols != ops.at(i)->cols())
134  throw muq::WrongSizeError("In ConcatenateOperator: Cannot vertically stack operators with different number of columns. Matrix A has " + std::to_string(fixedCols) + " columns but matrix B has " + std::to_string(ops.at(i)->cols()) + " columns.");
135  }
136 
137  }else{
138  const int fixedRows = ops.at(0)->rows();
139  for(int i=1; i<ops.size(); ++i){
140  if(fixedRows != ops.at(i)->rows())
141  throw muq::WrongSizeError("In ConcatenateOperator: Cannot horizontally stack operators with different number of rows. Matrix A has " + std::to_string(fixedRows) + " rows but matrix B has " + std::to_string(ops.at(i)->rows()) + " rows.");
142  }
143  }
144 }
static std::shared_ptr< ConcatenateOperator > HStack(std::shared_ptr< LinearOperator > Ain, std::shared_ptr< LinearOperator > Bin)
static std::shared_ptr< ConcatenateOperator > VStack(std::shared_ptr< LinearOperator > Ain, std::shared_ptr< LinearOperator > Bin)
virtual Eigen::MatrixXd Apply(Eigen::Ref< const Eigen::MatrixXd > const &x) override
static int GetCols(std::vector< std::shared_ptr< LinearOperator >> const &opsIn, const int rowColIn)
std::vector< std::shared_ptr< LinearOperator > > ops
static int GetRows(std::vector< std::shared_ptr< LinearOperator >> const &opsIn, const int rowColIn)
virtual Eigen::MatrixXd GetMatrix() override
ConcatenateOperator(std::vector< std::shared_ptr< LinearOperator >> const &opsIn, const int rowColIn)
virtual Eigen::MatrixXd ApplyTranspose(Eigen::Ref< const Eigen::MatrixXd > const &x) override
Generic linear operator base class.
Exception to throw when matrices, vectors, or arrays are the wrong size.
Definition: Exceptions.h:58
NLOHMANN_BASIC_JSON_TPL_DECLARATION std::string to_string(const NLOHMANN_BASIC_JSON_TPL &j)
user-defined to_string function for JSON values
Definition: json.h:25172