1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
| '''
函数说明 torch.masked_select( input, mask ,out = None ) -> 张量
根据掩码张量mask中的二元值(0,1),取输入张量中的指定项( mask为一个 ByteTensor),将取值返回到一个新的1D张量---是打平的张量; 张量 mask须跟input张量有相同数量的元素数目,但形状或维度不需要相同。 注意: 返回的张量不与原始张量共享内存空间。 ''' a = torch.randn(3,4) print(a) mask2 = a.ge(0.3) mask3 = a.le(0.7) print(mask2) print(a[mask2]) print(a[mask3]) print(torch.masked_select(a,mask2))
''' output: tensor([[-0.7769, -0.0803, -0.4235, -0.3562], [-0.4744, 1.2078, 0.6371, -0.6981], [-1.1653, -0.3432, -2.3189, 0.1708]]) tensor([[ True, False, True, True], [False, False, False, False], [ True, False, False, False]]) tensor([1.2078, 0.6371])
tensor([-0.7769, -0.0803, -0.4235, -0.3562, -0.4744, 0.6371, -0.6981, -1.1653, -0.3432, -2.3189, 0.1708]) tensor([1.2078, 0.6371])
'''
|