Skip to content

Commit 948f8f2

Browse files
committed
Add return_attention support to Generator
The decoding engine already computes attention weights when requested, but this was only wired through the Translator API. This exposes the same capability for decoder-only models (Generator) by propagating the return_attention flag from GenerationOptions to DecodingOptions and transferring the attention data back to GenerationResult.
1 parent 226c95d commit 948f8f2

4 files changed

Lines changed: 19 additions & 3 deletions

File tree

include/ctranslate2/generation.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ namespace ctranslate2 {
5353

5454
// Include scores in the result.
5555
bool return_scores = false;
56+
// Store attention vectors in the GenerationResult class.
57+
bool return_attention = false;
5658
// Include log probs of each token in the result
5759
bool return_logits_vocab = false;
5860

@@ -81,6 +83,7 @@ namespace ctranslate2 {
8183
std::vector<std::vector<std::string>> sequences;
8284
std::vector<std::vector<size_t>> sequences_ids;
8385
std::vector<float> scores;
86+
std::vector<std::vector<std::vector<float>>> attention;
8487
std::vector<std::vector<StorageView>> logits;
8588

8689
size_t num_sequences() const {

python/cpp/generation_result.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,16 @@ namespace ctranslate2 {
4949
"Generated sequences of token IDs.")
5050
.def_readonly("scores", &GenerationResult::scores,
5151
"Score of each sequence (empty if :obj:`return_scores` was disabled).")
52+
.def_readonly("attention", &GenerationResult::attention,
53+
"Attention matrix of each sequence (empty if :obj:`return_attention` was disabled).")
5254
.def_readonly("logits", &GenerationResult::logits,
5355
"Logits of each sequence (empty if :obj:`return_logits_vocab` was disabled).")
5456

5557
.def("__repr__", [](const GenerationResult& result) {
5658
return "GenerationResult(sequences=" + std::string(py::repr(py::cast(result.sequences)))
5759
+ ", sequences_ids=" + std::string(py::repr(py::cast(result.sequences_ids)))
5860
+ ", scores=" + std::string(py::repr(py::cast(result.scores)))
61+
+ ", attention=" + std::string(py::repr(py::cast(result.attention)))
5962
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
6063
+ ")";
6164
})

python/cpp/generator.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace ctranslate2 {
3333
bool cache_static_prompt,
3434
bool include_prompt_in_result,
3535
bool return_scores,
36+
bool return_attention,
3637
bool return_logits_vocab,
3738
bool return_alternatives,
3839
float min_alternative_expansion_prob,
@@ -59,6 +60,7 @@ namespace ctranslate2 {
5960
options.num_hypotheses = num_hypotheses;
6061
options.return_end_token = return_end_token;
6162
options.return_scores = return_scores;
63+
options.return_attention = return_attention;
6264
options.return_logits_vocab = return_logits_vocab;
6365
options.return_alternatives = return_alternatives;
6466
options.cache_static_prompt = cache_static_prompt;
@@ -205,6 +207,7 @@ namespace ctranslate2 {
205207
py::arg("cache_static_prompt")=true,
206208
py::arg("include_prompt_in_result")=true,
207209
py::arg("return_scores")=false,
210+
py::arg("return_attention")=false,
208211
py::arg("return_logits_vocab")=false,
209212
py::arg("return_alternatives")=false,
210213
py::arg("min_alternative_expansion_prob")=0,
@@ -263,6 +266,7 @@ namespace ctranslate2 {
263266
reuse it for future generations using the same static prompt.
264267
include_prompt_in_result: Include the :obj:`start_tokens` in the result.
265268
return_scores: Include the scores in the output.
269+
return_attention: Include the attention matrices in the output.
266270
return_logits_vocab: Include log probs for each token in the output
267271
return_alternatives: Return alternatives at the first unconstrained decoding position.
268272
min_alternative_expansion_prob: Minimum initial probability to expand an alternative.

src/models/language_model.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ namespace ctranslate2 {
165165
decoding_options.sampling_temperature = options.sampling_temperature;
166166
decoding_options.num_hypotheses = options.num_hypotheses;
167167
decoding_options.return_scores = options.return_scores;
168+
decoding_options.return_attention = options.return_attention;
168169
decoding_options.return_logits_vocab = options.return_logits_vocab;
169170
decoding_options.return_alternatives = options.return_alternatives;
170171
decoding_options.min_alternative_expansion_prob = options.min_alternative_expansion_prob;
@@ -251,9 +252,13 @@ namespace ctranslate2 {
251252

252253
// Remove EOS token.
253254
if (!options.return_end_token) {
254-
for (auto& sequence : result.hypotheses) {
255-
while (!sequence.empty() && is_eos(sequence.back(), end_ids))
256-
sequence.pop_back();
255+
for (size_t h = 0; h < result.hypotheses.size(); ++h) {
256+
while (!result.hypotheses[h].empty()
257+
&& is_eos(result.hypotheses[h].back(), end_ids)) {
258+
result.hypotheses[h].pop_back();
259+
if (!result.attention.empty())
260+
result.attention[h].pop_back();
261+
}
257262
}
258263
}
259264

@@ -269,6 +274,7 @@ namespace ctranslate2 {
269274
final_result.sequences = vocabulary.to_tokens(result.hypotheses);
270275
final_result.sequences_ids = std::move(result.hypotheses);
271276
final_result.scores = std::move(result.scores);
277+
final_result.attention = std::move(result.attention);
272278
final_result.logits = std::move(result.logits_vocab);
273279
final_results.emplace_back(std::move(final_result));
274280
}

0 commit comments

Comments
 (0)