File size: 1,017 Bytes
0a948c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch
def create_grid_mask(seq_length, trunck_length, fill_triangle):
    """Build a square 0/1 mask in which positions are grouped into trunks.

    The ``seq_length`` positions are split into consecutive trunks of
    ``trunck_length`` positions each (the last trunk may be shorter).
    Entry ``[i, j]`` of the result is 1 when column ``j`` belongs to the
    same trunk as row ``i``.  When ``fill_triangle`` is truthy, entry
    ``[i, j]`` is additionally 1 whenever ``j <= i`` (lower triangle plus
    main diagonal).  Every other entry is 0.

    Note: this builds the mask without taking any ``seen_length`` into
    account (carried over from the original author's comment).

    Args:
        seq_length: Side length of the square mask; must be positive.
        trunck_length: Number of positions per trunk; must be positive.
        fill_triangle: If truthy, also set the lower triangle (including
            the main diagonal) to 1; otherwise only the per-trunk blocks
            on the diagonal are set.

    Returns:
        A ``(seq_length, seq_length)`` tensor of zeros and ones in torch's
        default dtype (the same dtype ``torch.ones`` would produce).

    Raises:
        AssertionError: If ``seq_length`` is not positive.
        ValueError: If ``trunck_length`` is not positive.
    """
    assert seq_length > 0, "seq_length must be positive"
    # A zero trunk length would divide by zero and a negative one would
    # yield a meaningless mask, so reject both explicitly.
    if trunck_length <= 0:
        raise ValueError(
            f"trunck_length must be positive, got {trunck_length!r}"
        )

    positions = torch.arange(seq_length)
    # Index of the trunk each position falls into.
    trunk_ids = positions // trunck_length

    # visible[i, j] is True when row i and column j share a trunk.
    visible = trunk_ids.unsqueeze(1) == trunk_ids.unsqueeze(0)
    if fill_triangle:
        # Also expose everything at or before the row's own position.
        visible = visible | (positions.unsqueeze(0) <= positions.unsqueeze(1))

    return visible.to(torch.get_default_dtype())
if __name__ == "__main__":
    # Demo: 8 positions in trunks of 3, with the lower triangle filled in.
    demo_mask = create_grid_mask(seq_length=8, trunck_length=3, fill_triangle=True)
    print(demo_mask.int())
    # Expected output:
    # tensor([[1, 1, 1, 0, 0, 0, 0, 0],
    #         [1, 1, 1, 0, 0, 0, 0, 0],
    #         [1, 1, 1, 0, 0, 0, 0, 0],
    #         [1, 1, 1, 1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 1, 1, 1, 1],
    #         [1, 1, 1, 1, 1, 1, 1, 1]]
|