AI programming/pytorch
[pytorch] einops.rearrange()
Bull_
2024. 4. 13. 20:30
pip install einops
from einops import rearrange
x = torch.randn(64, 3, 32, 32)
patch_size = 4 # 4 pixels
print('x :', x.shape)
patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
# 64 3 32 32 = 64 3 8*4 8*4 -> 64 8*8 4*4*3
print('patches :', patches.shape)
x : torch.Size([64, 3, 32, 32])
patches : torch.Size([64, 64, 48])
rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
64 3 32 32에서 32를 (8 * 4)로 끊어서 텐서를 나눠 줄 수 있다.
즉 3번째 자리에 보이는 (h s1)을 h와 s1으로 나눌 수 있다.
그러니까 [64, 3, 32, 32] → [64, 3, (8 4), (8 4)] → [64, (8 8), (4 4 3)] → [64, 64, 48] 텐서로 변환해준다.