forked from ricardorei/lightning-text-classification
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
21 lines (18 loc) · 695 Bytes
/
utils.py
File metadata and controls
21 lines (18 loc) · 695 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# -*- coding: utf-8 -*-
from datetime import datetime
import torch
def mask_fill(
    fill_value: float,
    tokens: torch.Tensor,
    embeddings: torch.Tensor,
    padding_index: int,
) -> torch.Tensor:
    """
    Function that masks embeddings representing padded elements.

    :param fill_value: the value to fill the embeddings belonging to padded tokens.
    :param tokens: The input sequences [bsz x seq_len].
    :param embeddings: word embeddings [bsz x seq_len x hiddens].
    :param padding_index: Index of the padding token.

    :return: A new tensor with the same shape and dtype as ``embeddings`` in which
        every position whose token equals ``padding_index`` is set to
        ``fill_value``. ``embeddings`` itself is left unmodified.
    """
    # [bsz x seq_len] -> [bsz x seq_len x 1] so the mask broadcasts over the
    # hidden dimension of ``embeddings``.
    padding_mask = tokens.eq(padding_index).unsqueeze(-1)
    # The fill is applied in float32 and the result is cast back to the input
    # dtype. ``masked_fill`` (out-of-place) is used deliberately: ``.float()``
    # returns ``embeddings`` itself when it is already float32, so an in-place
    # ``masked_fill_`` would silently overwrite the caller's tensor.
    return embeddings.float().masked_fill(padding_mask, fill_value).type_as(embeddings)