Skip to content

Commit 0434fc8

Browse files
committed
[tmva][sofie] Restructure emitted code to be differentiable with Clad
This commit refactors the `doInfer()` function, which implements the inference, from a member function of the `Session` struct into a free function that takes the `Session` by `const`-reference, so that the emitted code can be differentiated with Clad.
1 parent e9a81b8 commit 0434fc8

7 files changed

Lines changed: 249 additions & 72 deletions

File tree

tmva/sofie/inc/TMVA/RModel.hxx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,14 @@ private:
3434
std::vector<std::string> fDimShapeNames; // parameter names used to define the shapes
3535
std::vector<std::string> fOutputTensorNames;
3636
std::vector<std::string> fInputTensorNames; // input tensor names using ONNX order
37+
std::vector<std::string> fPointerMemberNames;
3738

39+
inline std::string AddTensorMember(std::string const &name) {
   // Register `name` so a pointer data member is emitted for this tensor,
   // then hand back the conventional "tensor_<name>" identifier used by
   // the generated inference code.
   fPointerMemberNames.emplace_back(name);
   std::string memberName{"tensor_"};
   memberName += name;
   return memberName;
}
3843

44+
bool IsInputTensorShapeParam(std::string const &name) const;
3945

4046
std::vector<std::unique_ptr<ROperator>> fOperators;
4147

tmva/sofie/inc/TMVA/ROperator_Gemm.hxx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,9 +393,12 @@ namespace SOFIE{
393393
<< (fAttrTransB ? "true, " : "false, ")
394394
<< (fAttrTransA ? "true, " : "false, ")
395395
<< n << ", " << m << ", " << k << ", ";
396-
out << std::setprecision(std::numeric_limits<float>::max_digits10) << fAttrAlpha << ", tensor_" << fNB;
396+
// TODO: the cast to (float *) is not needed here from the C++ language perspective (the arguments to
397+
// Gemm_Call are const already), but Clad bug https://github.com/vgvassilev/clad/issues/1721 is requiring
398+
// us to do this cast to keep Clad working. Remove this hack once the Clad issue is fixed.
399+
out << std::setprecision(std::numeric_limits<float>::max_digits10) << fAttrAlpha << ", (float*)tensor_" << fNB;
397400
if (extraB) out << " + " << opName << "_B_offset";
398-
out << ", tensor_" << fNA;
401+
out << ", (float*)tensor_" << fNA; // TODO: same here
399402
if (extraA) out << " + " << opName << "_A_offset";
400403
out << ", " << std::setprecision(std::numeric_limits<float>::max_digits10) << fAttrBeta << ",";
401404
// in the case of bias and no broadcasting needed

tmva/sofie/inc/TMVA/ROperator_NonZero.hxx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,16 @@ public:
101101
}
102102
}
103103
}
104+
104105
std::string GenerateSessionMembersCode(std::string /*opName*/) override {
105106
if (fIsOutputConstant) return "";
106107
// define output value used as max non zero with max size = input shape * N
107108
auto inputLength = ConvertDimShapeToLength(fShapeX);
108109
std::stringstream out;
109-
out << SP << "size_t v_NonZero_" << fNX << " = " << inputLength << ";\n";
110+
out << SP << "size_t fV_NonZero_" << fNX << " = " << inputLength << ";\n";
110111
return out.str();
111112
}
112113

113-
114114
std::string Generate(std::string opName) override {
115115
if (fIsOutputConstant) {
116116
return "";
@@ -133,7 +133,7 @@ public:
133133

134134
// loop on input indices
135135
out << SP << "size_t offset_" << opName << " = 0;\n";
136-
out << SP << vnonzero << " = 0;\n";
136+
out << SP << "size_t " << vnonzero << " = 0;\n";
137137
for (size_t j = 0; j < dims; j++) {
138138
std::string index = "i_" + std::to_string(j);
139139
for (size_t k = 0; k <= j; k++) out << SP;

tmva/sofie/inc/TMVA/SOFIE_common.hxx

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -681,16 +681,6 @@ void col2im(const Dtype* data_col, const int channels,
681681
//std::cout << "finishing col2imp" << std::endl;
682682
}
683683

684-
// Used at the end of infer() to fill the return object.
685-
template <class T>
686-
void FillOutput(T const *arr, std::vector<T> &out, std::size_t n)
687-
{
688-
out.resize(n);
689-
for (std::size_t i = 0; i < n; ++i) {
690-
out[i] = arr[i];
691-
}
692-
}
693-
694684
} // end namespace UTILITY
695685

696686
namespace BLAS{

0 commit comments

Comments
 (0)