Skip to content

Commit 575f9e0

Browse files
committed
[tmva][sofie] Restructure emitted code to be differentiable with Clad
1 parent 5acd4ab commit 575f9e0

7 files changed

Lines changed: 129 additions & 88 deletions

File tree

tmva/sofie/inc/TMVA/RModel.hxx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,12 @@ private:
3434
std::vector<std::string> fDimShapeNames; // parameter names used to define the shapes
3535
std::vector<std::string> fOutputTensorNames;
3636
std::vector<std::string> fInputTensorNames; // input tensor names using ONNX order
37+
std::vector<std::string> fPointerMemberNames;
3738

38-
39+
inline std::string AddTensorMember(std::string const &name) {
40+
fPointerMemberNames.push_back(name);
41+
return "tensor_" + name;
42+
}
3943

4044
std::vector<std::unique_ptr<ROperator>> fOperators;
4145

tmva/sofie/inc/TMVA/ROperator_Gemm.hxx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,9 +393,12 @@ namespace SOFIE{
393393
<< (fAttrTransB ? "true, " : "false, ")
394394
<< (fAttrTransA ? "true, " : "false, ")
395395
<< n << ", " << m << ", " << k << ", ";
396-
out << std::setprecision(std::numeric_limits<float>::max_digits10) << fAttrAlpha << ", tensor_" << fNB;
396+
// TODO: the cast to (float *) is not needed here from the C++ language perspective (the arguments to
397+
// Gemm_Call are const already), but Clad bug https://github.com/vgvassilev/clad/issues/1721 is requiring
398+
// us to do this cast to keep Clad working. Remove this hack once the Clad issue is fixed.
399+
out << std::setprecision(std::numeric_limits<float>::max_digits10) << fAttrAlpha << ", (float*)tensor_" << fNB;
397400
if (extraB) out << " + " << opName << "_B_offset";
398-
out << ", tensor_" << fNA;
401+
out << ", (float*)tensor_" << fNA; // TODO: same here
399402
if (extraA) out << " + " << opName << "_A_offset";
400403
out << ", " << std::setprecision(std::numeric_limits<float>::max_digits10) << fAttrBeta << ",";
401404
// in the case of bias and no broadcasting needed

tmva/sofie/inc/TMVA/ROperator_LSTM.icc

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ std::string ROperator_LSTM<T>::GenerateSessionMembersCode(std::string opName)
245245
size_t input_size = fShapeX[2];
246246

247247
if (fAttrLayout != 0) {
248-
out << "std::vector<" << fType << "> fVec_" << opName << "_input = std::vector<" << fType << ">("
248+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_input = std::vector<" << fType << ">("
249249
<< seq_length * batch_size * input_size << ");\n";
250250
out << "std::vector<" << fType << "> fVec_" << opName << "_initial_hidden_state = std::vector<" << fType << ">("
251251
<< num_directions * batch_size * fAttrHiddenSize << ");\n";
@@ -254,24 +254,24 @@ std::string ROperator_LSTM<T>::GenerateSessionMembersCode(std::string opName)
254254
}
255255
// Set the feedforward
256256
size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
257-
out << "std::vector<" << fType << "> fVec_" << opName << "_ff_input_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
258-
out << "std::vector<" << fType << "> fVec_" << opName << "_ff_output_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
259-
out << "std::vector<" << fType << "> fVec_" << opName << "_ff_cell_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
257+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_ff_input_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
258+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_ff_output_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
259+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_ff_cell_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
260260
if (fAttrInputForget == 0)
261-
out << "std::vector<" << fType << "> fVec_" << opName << "_ff_forget_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
261+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_ff_forget_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
262262
// gate results
263263
size_t hs_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
264-
out << "std::vector<" << fType << "> fVec_" << opName << "_input_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
265-
out << "std::vector<" << fType << "> fVec_" << opName << "_output_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
266-
out << "std::vector<" << fType << "> fVec_" << opName << "_cell_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
264+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_input_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
265+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_output_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
266+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_cell_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
267267
if (fAttrInputForget == 0)
268-
out << "std::vector<" << fType << "> fVec_" << opName << "_forget_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
268+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_forget_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
269269
// cell state
270-
out << "std::vector<" << fType << "> fVec_" << opName << "_cell_state = std::vector<" << fType << ">(" << hs_size << ");\n";
271-
out << "std::vector<" << fType << "> fVec_" << opName << "_new_cell_state = std::vector<" << fType << ">(" << hs_size << ");\n";
270+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_cell_state = std::vector<" << fType << ">(" << hs_size << ");\n";
271+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_new_cell_state = std::vector<" << fType << ">(" << hs_size << ");\n";
272272
// hidden state
273273
if (fAttrLayout != 0 || fNY.empty()) {
274-
out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_state = std::vector<" << fType << ">(" << hs_size << ");\n";
274+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_hidden_state = std::vector<" << fType << ">(" << hs_size << ");\n";
275275
}
276276

277277
out << "\n";
@@ -313,11 +313,11 @@ auto ROperator_LSTM<T>::Generate(std::string OpName)
313313
// Set the initial hidden state
314314
if (!fNInitial_h.empty()) {
315315
if (fAttrLayout == 0) {
316-
out << SP << fType << " *" << OpName << "_initial_hidden_state = " << " tensor_"
316+
out << SP << fType << " const*" << OpName << "_initial_hidden_state = " << " tensor_"
317317
<< fNInitial_h << ";\n";
318318
} else {
319319
if (fUseSession)
320-
out << SP << fType << " * " << OpName << "_initial_hidden_state = this->fVec_" << OpName
320+
out << SP << fType << " const* " << OpName << "_initial_hidden_state = this->fVec_" << OpName
321321
<< "_initial_hidden_state.data();\n";
322322
else
323323
out << SP << fType << " " << OpName << "_initial_hidden_state[" << num_directions * batch_size *
@@ -339,11 +339,11 @@ auto ROperator_LSTM<T>::Generate(std::string OpName)
339339
// Set the initial cell state
340340
if (!fNInitial_c.empty()) {
341341
if (fAttrLayout == 0) {
342-
out << SP << fType << " *" << OpName << "_initial_cell_state = " << " tensor_"
342+
out << SP << fType << " const*" << OpName << "_initial_cell_state = " << " tensor_"
343343
<< fNInitial_c << ";\n";
344344
} else {
345345
if (fUseSession)
346-
out << SP << fType << " * " << OpName << "_initial_cell_state = this->fVec_" << OpName
346+
out << SP << fType << " const* " << OpName << "_initial_cell_state = this->fVec_" << OpName
347347
<< "_initial_cell_state.data();\n";
348348
else
349349
out << SP << fType << " " << OpName << "_initial_cell_state[" << num_directions * batch_size *

tmva/sofie/inc/TMVA/ROperator_RNN.icc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,16 +198,16 @@ std::string ROperator_RNN<T>::GenerateSessionMembersCode(std::string opName)
198198
size_t input_size = fShapeX[2];
199199

200200
if (fAttrLayout != 0) {
201-
out << "std::vector<" << fType << "> fVec_" << opName << "_input = std::vector<" << fType << ">("
201+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_input = std::vector<" << fType << ">("
202202
<< seq_length * batch_size * input_size << ");\n";
203-
out << "std::vector<" << fType << "> fVec_" << opName << "_initial_hidden_state = std::vector<" << fType << ">("
203+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_initial_hidden_state = std::vector<" << fType << ">("
204204
<< num_directions * batch_size * fAttrHiddenSize << ");\n";
205205
}
206-
out << "std::vector<" << fType << "> fVec_" << opName << "_feedforward = std::vector<" << fType << ">("
206+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_feedforward = std::vector<" << fType << ">("
207207
<< seq_length * batch_size * fAttrHiddenSize << ");\n";
208208

209209
if (fAttrLayout != 0 || fNY.empty()) {
210-
out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_state = std::vector<" << fType << ">("
210+
out << "mutable std::vector<" << fType << "> fVec_" << opName << "_hidden_state = std::vector<" << fType << ">("
211211
<< seq_length * num_directions * batch_size * fAttrHiddenSize << ");\n";
212212
}
213213

tmva/sofie/inc/TMVA/SOFIE_common.hxx

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -685,16 +685,6 @@ void col2im(const Dtype* data_col, const int channels,
685685
//std::cout << "finishing col2imp" << std::endl;
686686
}
687687

688-
// Used at the end of infer() to fill the return object.
689-
template <class T>
690-
void FillOutput(T const *arr, std::vector<T> &out, std::size_t n)
691-
{
692-
out.resize(n);
693-
for (std::size_t i = 0; i < n; ++i) {
694-
out[i] = arr[i];
695-
}
696-
}
697-
698688
} // end namespace UTILITY
699689

700690
namespace BLAS{

0 commit comments

Comments
 (0)