# Demo of torch.stack / torch.cat and Tensor.split / Tensor.chunk:
# six random (ROWS, COLS) tensors are combined, then cut apart again along
# dim 0, printing the resulting shapes.
import torch

ROWS, COLS = 32, 8  # shape of every input tensor

a = torch.randn(ROWS, COLS)
a1 = torch.randn(ROWS, COLS)
a2 = torch.randn(ROWS, COLS)
b = torch.randn(ROWS, COLS)
b2 = torch.randn(ROWS, COLS)
b3 = torch.randn(ROWS, COLS)

# stack inserts a new leading dimension: result is (2, ROWS, COLS).
c = torch.stack([a, b], dim=0)
print(c.shape)

# cat joins along an existing dimension. One call yields the same row order
# as prepending a1, a2, b2, b3 one at a time onto cat([a, b]), but without
# re-copying the growing tensor on every step.
d = torch.cat([b3, b2, a2, a1, a, b], dim=0)
print(d.shape)  # (6 * ROWS, COLS)

# split with a list: explicit section sizes, which must sum to d.size(0).
n, m = d.split([2 * ROWS, 4 * ROWS], dim=0)
print(n.shape)
print(m.shape)

# split with an int: sections of that size (the last one may be smaller).
# NOTE: the name `nn` shadows the conventional `import torch.nn as nn` alias.
nn, mm = d.split(3 * ROWS, dim=0)
print(nn.shape)
print(mm.shape)

# chunk takes a number of sections rather than a section size;
# here 6 * ROWS rows divide evenly into 2 chunks.
nnn, mmm = d.chunk(2, dim=0)
print(nnn.shape)
print(mmm.shape)