PyTorch Reshaping with None

Original link: https://blog.detorch.xyz/post/2025-06-21-pytorch-reshaping-with-none.md

The `sequence_mask` function in the Dive into Deep Learning book is used to mask padding tokens in sequences during attention calculations. This is necessary when dealing with sequences of variable lengths, where shorter sequences are padded to match the maximum length. The function takes a 2D tensor `X` (representing sequences) and a 1D tensor `valid_len` (representing the actual length of each sequence) as input. It creates a boolean mask where `True` indicates valid (non-padded) tokens and `False` indicates padded tokens. This mask is generated by comparing a range of indices (0 to `max_len`-1) with the `valid_len` for each sequence using broadcasting. Elements in `X` corresponding to `False` in the mask are replaced with a specified `value` (typically a large negative number) to prevent attention from being paid to them. The post then presents an equivalent implementation using `reshape` instead of `None` indexing, arguing it's more readable. Both versions achieve the same outcome: creating a mask to effectively ignore padding tokens during attention.

Original text

Currently I am learning the attention mechanism from the Dive into Deep Learning book. In the book I came across the following implementation, used in its masked softmax:

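import torch

def sequence_mask(X, valid_len, value=-1e6):
    """X is a 2D tensor (number_of_points, max_len); valid_len is a 1D tensor (number_of_points,)."""
    max_len = X.size(1)
    # Position indices, shape (1, max_len), compared against lengths, shape (number_of_points, 1).
    mask = torch.arange(max_len, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    # Overwrite the padded positions in place.
    X[~mask] = value
    return X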
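To see why the fill value is a large negative number, here is a minimal sketch (not the book's code) of what softmax does with it. The `scores` tensor is a made-up example in which the last two positions are assumed to have already been overwritten with `value`.

```python
import torch

# Hypothetical attention scores for one sequence; the last two positions
# are assumed to be padding that was already filled with -1e6.
scores = torch.tensor([[2.0, 1.0, -1e6, -1e6]])

# exp(-1e6) underflows to 0, so the filled positions receive no weight
# and the remaining weights still sum to 1.
weights = torch.softmax(scores, dim=-1)
print(weights)
```

The two filled positions end up with zero weight, which is the whole purpose of the mask.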

In sequential data processing, such as natural language processing, the sequence length can vary from one data point to the next. For example:

1 : "Welcome To My Blog"

2 : "Hello World"

To give every sequence the same length, we fill the remaining positions with a special padding token:

1 : "Welcome To My Blog"

2 : "Hello World blnk blnk"

In attention, we do not want to attend to the blnk tokens, so we create a mask for them. In the code, max_len is the maximum sequence length and valid_len holds the actual length of each sequence: 4 for the 1st data point and 2 for the 2nd.

Let's say we have the following vocabulary: ['blnk', 'Welcome', 'To', 'My', 'Blog', 'Hello', 'World']. Each token is replaced by its index in that list, so X will be:

X = [
    [1,2,3,4],
    [5,6,0,0]
]
valid_len = [4,2]

and the mask must be:

[
    [1,1,1,1],
    [1,1,0,0]
]

We use the broadcasting mechanism to create the mask. First, torch.arange(max_len, dtype=torch.float32, device=X.device) creates a 1D tensor of shape (max_len,); in our example it is [0,1,2,3]. To compare it with the length of every sequence we would need [[0,1,2,3],[0,1,2,3]], but we do not have to build that by hand, because broadcasting expands it for us. For that to work, the comparison has to be applied to operands of shape (1, max_len) and (number_of_points, 1), which broadcast to (number_of_points, max_len). If you do not know the broadcasting mechanism, you can read about it in the PyTorch documentation.
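The broadcasting step can be sketched on its own, with the intermediate shapes made visible. The lengths here are assumed for illustration: one sequence with 4 valid tokens and one with 2, padded to max_len = 4.

```python
import torch

max_len = 4
valid_len = torch.tensor([4, 2])   # assumed lengths for this illustration

positions = torch.arange(max_len)  # shape (4,): tensor([0, 1, 2, 3])
row = positions[None, :]           # shape (1, 4)
col = valid_len[:, None]           # shape (2, 1)

# (1, 4) compared with (2, 1) broadcasts to (2, 4):
# mask[i, j] is True when position j is smaller than valid_len[i].
mask = row < col
print(row.shape, col.shape, mask.shape)
print(mask)
```

Each row of the mask is True for the first valid_len[i] positions and False for the rest.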

Now we come to the point: to get those shapes, the book indexes with None. For a 1D tensor, torch.arange(max_len, dtype=torch.float32, device=X.device)[None, :] is equivalent to torch.arange(max_len, dtype=torch.float32, device=X.device).reshape(1, -1), and valid_len[:, None] is equivalent to valid_len.reshape(-1, 1).
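The equivalence is easy to check directly. The sketch below uses small made-up tensors and also includes `unsqueeze`, a third spelling of the same operation that the post itself does not use.

```python
import torch

a = torch.arange(4)        # 1D tensor, shape (4,)
v = torch.tensor([4, 2])   # 1D tensor, shape (2,)

# Adding a leading axis: None indexing, reshape and unsqueeze agree.
assert torch.equal(a[None, :], a.reshape(1, -1))
assert torch.equal(a[None, :], a.unsqueeze(0))

# Adding a trailing axis.
assert torch.equal(v[:, None], v.reshape(-1, 1))
assert torch.equal(v[:, None], v.unsqueeze(1))

print(a[None, :].shape, v[:, None].shape)
```

Note that this holds because the inputs are 1D. On a tensor with more dimensions, `reshape(1, -1)` would flatten everything into a single row, while `[None, :]` only adds a leading axis.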

To be honest, I find reshape more readable, so my version of this function is:

def sequence_mask_with_reshape(X, valid_len, value=-1e6):
    """X is a 2D tensor (number_of_points, max_len); valid_len is a 1D tensor (number_of_points,)."""
    max_len = X.size(1)
    mask = torch.arange(max_len, dtype=torch.float32, device=X.device).reshape(1, -1) < valid_len.reshape(-1, 1)
    X[~mask] = value
    return X

As you can see, it is more readable.
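A quick check that the two versions behave the same, as a sketch: both functions are repeated so the snippet runs on its own, and the input tensor and lengths are made up for the test. Because both functions modify X in place, each call gets its own clone.

```python
import torch

def sequence_mask(X, valid_len, value=-1e6):
    max_len = X.size(1)
    mask = torch.arange(max_len, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X

def sequence_mask_with_reshape(X, valid_len, value=-1e6):
    max_len = X.size(1)
    mask = torch.arange(max_len, dtype=torch.float32, device=X.device).reshape(1, -1) < valid_len.reshape(-1, 1)
    X[~mask] = value
    return X

# Made-up inputs: two sequences padded to length 4, with 4 and 2 valid positions.
scores = torch.ones(2, 4)
valid_len = torch.tensor([4, 2])

# Clone before each call, since the functions overwrite their input.
out_none = sequence_mask(scores.clone(), valid_len)
out_reshape = sequence_mask_with_reshape(scores.clone(), valid_len)

assert torch.equal(out_none, out_reshape)
print(out_none)
```

The first row is left untouched and the last two entries of the second row are replaced by -1e6 in both outputs.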
