2021 네이버 부스트캠프 - Ai tech

Week_3 Pytorch - Squeeze vs. Unsqueeze

미미수 2021. 8. 17. 19:49

Python과 Pytorch의 특장점 중 하나는 미리 구현된 라이브러리와 모듈들이 풍부하다는 점이다.

하지만 풍부할 수록 알아야할 가짓수는 n만개,, 

그중에서도 헷갈리지만 유용한 개념 몇가지에 대해 리뷰하겠다.

 

우선 Squeeze 와 Unsqueeze!

기본적으로 data dimension manipulation에 관한 함수이다.

 

딥러닝 모델을 구축하다보면 input data의 차원을 맞춰야 하는 경우가 생기는데, 그럴때 유용하게 사용가능하다.

Squeeze : Tensor의 차원을 줄이는 것

 

squeeze_tensor = torch.rand(size=(2,1,2))
#tensor([[[0.1117, 0.8158]],
#       [[0.2626, 0.4839]]])

squeeze_tensor.squeeze()
#tensor([[0.1117, 0.8158],
#       [0.2626, 0.4839]])

 

torch.Size([2, 1, 2]) → torch.Size([2, 2])

로 size가 1인 차원을 줄여준다.

 

t = torch.rand(size=(1,1,2,2,3))
t.squeeze().shape

# torch.Size([2, 2, 3])

size가 1인 차원이 여러개여도 다 제거해준다.

 

원하는 dimension을 선택해서 제거해줄 수도 있다.

t = torch.rand(size=(1,1,2,2,3))
t.squeeze(1).shape

# torch.Size([1, 2, 2, 3])

 

 

Unsqueeze : Tensor의 차원을 늘리는 것

원하는 차원의 size를 1로 늘려준다. 차원을 원하는 위치에 쑤셔넣는다고 보면 된다. (고로 얘는 dimestion을 꼭 인자로 전달해줘야함)

unsqueeze_tensor = torch.rand(size=(2,2))

unsqueeze_tensor.unsqueeze(0).shape # [1,2,2]
unsqueeze_tensor.unsqueeze(1).shape # [2,1,2]
unsqueeze_tensor.unsqueeze(2).shape # [2,2,1]