Skip to content

Commit cb9b348

Browse files
committed
[tmva][sofie] Use this pointer when accessing Session data members
This makes the code easier to reason about for both human readers, and automatic post-processing steps where usage of Session data members needs to be identified. This is done in preparation for the SOFIE-emitted code refactor to make it differentiable with Clad.
1 parent f7a2844 commit cb9b348

3 files changed

Lines changed: 20 additions & 20 deletions

File tree

tmva/sofie/inc/TMVA/ROperator_LSTM.icc

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ auto ROperator_LSTM<T>::Generate(std::string OpName)
295295
out << SP << fType << " const *" << OpName << "_input = tensor_" << fNX << ";\n";
296296
} else {
297297
if (fUseSession)
298-
out << SP << fType << " * " << OpName << "_input = fVec_" << OpName << "_input.data();\n";
298+
out << SP << fType << " * " << OpName << "_input = this->fVec_" << OpName << "_input.data();\n";
299299
else
300300
out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "] = {0};\n";
301301

@@ -317,7 +317,7 @@ auto ROperator_LSTM<T>::Generate(std::string OpName)
317317
<< fNInitial_h << ";\n";
318318
} else {
319319
if (fUseSession)
320-
out << SP << fType << " * " << OpName << "_initial_hidden_state = fVec_" << OpName
320+
out << SP << fType << " * " << OpName << "_initial_hidden_state = this->fVec_" << OpName
321321
<< "_initial_hidden_state.data();\n";
322322
else
323323
out << SP << fType << " " << OpName << "_initial_hidden_state[" << num_directions * batch_size *
@@ -343,7 +343,7 @@ auto ROperator_LSTM<T>::Generate(std::string OpName)
343343
<< fNInitial_c << ";\n";
344344
} else {
345345
if (fUseSession)
346-
out << SP << fType << " * " << OpName << "_initial_cell_state = fVec_" << OpName
346+
out << SP << fType << " * " << OpName << "_initial_cell_state = this->fVec_" << OpName
347347
<< "_initial_cell_state.data();\n";
348348
else
349349
out << SP << fType << " " << OpName << "_initial_cell_state[" << num_directions * batch_size *
@@ -365,11 +365,11 @@ auto ROperator_LSTM<T>::Generate(std::string OpName)
365365
// Set the feedforward
366366
size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
367367
if (fUseSession) {
368-
out << SP << fType << " * " << OpName << "_ff_input_gate = fVec_" << OpName << "_ff_input_gate.data();\n";
369-
out << SP << fType << " * " << OpName << "_ff_output_gate = fVec_" << OpName << "_ff_output_gate.data();\n";
370-
out << SP << fType << " * " << OpName << "_ff_cell_gate = fVec_" << OpName << "_ff_cell_gate.data();\n";
368+
out << SP << fType << " * " << OpName << "_ff_input_gate = this->fVec_" << OpName << "_ff_input_gate.data();\n";
369+
out << SP << fType << " * " << OpName << "_ff_output_gate = this->fVec_" << OpName << "_ff_output_gate.data();\n";
370+
out << SP << fType << " * " << OpName << "_ff_cell_gate = this->fVec_" << OpName << "_ff_cell_gate.data();\n";
371371
if (fAttrInputForget == 0) {
372-
out << SP << fType << " * " << OpName << "_ff_forget_gate = fVec_" << OpName << "_ff_forget_gate.data();\n";
372+
out << SP << fType << " * " << OpName << "_ff_forget_gate = this->fVec_" << OpName << "_ff_forget_gate.data();\n";
373373
}
374374
} else {
375375
out << SP << fType << " " << OpName << "_ff_input_gate[" << ff_size << "] = {0};\n";
@@ -382,11 +382,11 @@ auto ROperator_LSTM<T>::Generate(std::string OpName)
382382
// Set the gates
383383
size_t hidden_state_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
384384
if (fUseSession) {
385-
out << SP << fType << " * " << OpName << "_input_gate = fVec_" << OpName << "_input_gate.data();\n";
386-
out << SP << fType << " * " << OpName << "_output_gate = fVec_" << OpName << "_output_gate.data();\n";
387-
out << SP << fType << " * " << OpName << "_cell_gate = fVec_" << OpName << "_cell_gate.data();\n";
385+
out << SP << fType << " * " << OpName << "_input_gate = this->fVec_" << OpName << "_input_gate.data();\n";
386+
out << SP << fType << " * " << OpName << "_output_gate = this->fVec_" << OpName << "_output_gate.data();\n";
387+
out << SP << fType << " * " << OpName << "_cell_gate = this->fVec_" << OpName << "_cell_gate.data();\n";
388388
if (fAttrInputForget == 0) {
389-
out << SP << fType << " * " << OpName << "_forget_gate = fVec_" << OpName << "_forget_gate.data();\n";
389+
out << SP << fType << " * " << OpName << "_forget_gate = this->fVec_" << OpName << "_forget_gate.data();\n";
390390
}
391391
} else {
392392
out << SP << fType << " " << OpName << "_input_gate[" << hidden_state_size << "] = {0};\n";
@@ -398,8 +398,8 @@ auto ROperator_LSTM<T>::Generate(std::string OpName)
398398
}
399399
// Set the cell state and the new cell state = h(cell state)
400400
if (fUseSession) {
401-
out << SP << fType << " * " << OpName << "_cell_state = fVec_" << OpName << "_cell_state.data();\n";
402-
out << SP << fType << " * " << OpName << "_new_cell_state = fVec_" << OpName << "_new_cell_state.data();\n";
401+
out << SP << fType << " * " << OpName << "_cell_state = this->fVec_" << OpName << "_cell_state.data();\n";
402+
out << SP << fType << " * " << OpName << "_new_cell_state = this->fVec_" << OpName << "_new_cell_state.data();\n";
403403
} else {
404404
out << SP << fType << " " << OpName << "_cell_state[" << hidden_state_size << "] = {0};\n";
405405
out << SP << fType << " " << OpName << "_new_cell_state[" << hidden_state_size << "] = {0};\n";
@@ -410,7 +410,7 @@ auto ROperator_LSTM<T>::Generate(std::string OpName)
410410
out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
411411
} else {
412412
if (fUseSession) {
413-
out << SP << fType << " * " << OpName << "_hidden_state = fVec_" << OpName << "_hidden_state.data();\n";
413+
out << SP << fType << " * " << OpName << "_hidden_state = this->fVec_" << OpName << "_hidden_state.data();\n";
414414
} else {
415415
out << SP << fType << " " << OpName << "_hidden_state[" << hidden_state_size << "] = {0};\n";
416416
}

tmva/sofie/inc/TMVA/ROperator_RNN.icc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ auto ROperator_RNN<T>::Generate(std::string OpName)
235235
}
236236
} else {
237237
if (fUseSession)
238-
out << SP << fType << " * " << OpName << "_input = fVec_" << OpName << "_input.data();\n";
238+
out << SP << fType << " * " << OpName << "_input = this->fVec_" << OpName << "_input.data();\n";
239239
else
240240
out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "];\n";
241241
out << SP << "for(size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
@@ -256,7 +256,7 @@ auto ROperator_RNN<T>::Generate(std::string OpName)
256256
<< fNInitial_h << ";\n";
257257
} else {
258258
if (fUseSession)
259-
out << SP << fType << " * " << OpName << "_initial_hidden_state = fVec_" << OpName
259+
out << SP << fType << " * " << OpName << "_initial_hidden_state = this->fVec_" << OpName
260260
<< "_initial_hidden_state.data();\n";
261261
else
262262
out << fType << " " << OpName << "_initial_hidden_state[" << num_directions * batch_size *
@@ -276,7 +276,7 @@ auto ROperator_RNN<T>::Generate(std::string OpName)
276276
}
277277

278278
if (fUseSession)
279-
out << SP << fType << " * " << OpName << "_feedforward = fVec_" << OpName
279+
out << SP << fType << " * " << OpName << "_feedforward = this->fVec_" << OpName
280280
<< "_feedforward.data();\n";
281281
else
282282
out << SP << fType << " " << OpName << "_feedforward[" << seq_length * batch_size * fAttrHiddenSize << "] = {0};\n";
@@ -286,7 +286,7 @@ auto ROperator_RNN<T>::Generate(std::string OpName)
286286
out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
287287
} else {
288288
if (fUseSession)
289-
out << SP << fType << " * " << OpName << "_hidden_state = fVec_" << OpName << "_hidden_state.data();\n";
289+
out << SP << fType << " * " << OpName << "_hidden_state = this->fVec_" << OpName << "_hidden_state.data();\n";
290290
else
291291
out << SP << fType << " " << OpName << "_hidden_state[" << seq_length * num_directions *
292292
batch_size * fAttrHiddenSize << "] = {0};\n";

tmva/sofie/inc/TMVA/ROperator_Random.hxx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,13 @@ public:
125125
throw std::runtime_error("TMVA SOFIE RandomNormal op : no mean or scale are defined");
126126
float mean = fParams["mean"];
127127
float scale = fParams["scale"];
128-
out << SP << SP << "tensor_" << fNY << "[i] = fRndmEngine->Gaus(" << mean << "," << scale << ");\n";
128+
out << SP << SP << "tensor_" << fNY << "[i] = this->fRndmEngine->Gaus(" << mean << "," << scale << ");\n";
129129
} else if (fMode == kUniform) {
130130
if (fParams.count("high") == 0 || fParams.count("low") == 0)
131131
throw std::runtime_error("TMVA SOFIE RandomUniform op : no low or high are defined");
132132
float high = fParams["high"];
133133
float low = fParams["low"];
134-
out << SP << SP << "tensor_" << fNY << "[i] = fRndmEngine->Uniform(" << low << "," << high << ");\n";
134+
out << SP << SP << "tensor_" << fNY << "[i] = this->fRndmEngine->Uniform(" << low << "," << high << ");\n";
135135
}
136136
}
137137
out << SP << "}\n";

0 commit comments

Comments
 (0)