2021 네이버 부스트캠프 - Ai tech

Week_3 Pytorch - Transfer Learning, 모델 저장 및 불러오기

미미수 2021. 8. 19. 23:41

딥러닝 모델을 짜다보면 모델 or 가중치들을 공유하거나 불러오는 일이 다분합니다. 

불러온 pretrained 모델에 input 형식만 잘 맞춰줘도 어느정도 성능은 보장이 됩니다.

이를 Transfer Learning, 전이학습이라고 합니다.

 

Transfer Learning : 이미지넷과 같이 아주 큰 데이터셋에 훈련된 모델의 가중치를 가지고 와서 우리가 해결하고자 하는 과제에 맞게 재보정해서 사용하는 것

 

전이학습을 위해서는 모델을 불러와야하는데,

Pytorch에서 모델을 불러오는 법에는 크게 두가지가 있습니다.


1. Model의 state_dict를 저장 및 불러오기

모델의 파라미터를 저장

torch.save(model.state_dict(),PATH)

model.load_state_dict(torch.load(PATH))

 

Model을 생성하면 자동으로 state_dict라는 Ordered_dict type의 변수가 생성됩니다. 

여기에는 모든 레이어들에 대한 구체적인 정보들이 담겨있습니다. 자세한 내용은 Pytorch 공식문서에 나와있습니다.

 

state_dict는 같은 모델의 형태에서 불러야 합니다.

 

 

2. Model의 아키텍쳐까지 함께 저장

 

torch.save(model, PATH)

model = torch.load(PATH)

 

Python의 pickle 모듈을 사용하여 전체 모듈을 저장하게 됩니다.

 

 

PyTorch에서는 모델을 저장할 때 .pt 또는 .pth 확장자를 사용하는 것이 일반적인 규칙입니다.

하지만 .pth는 python과 겹쳐서 되도록이면 .pt를 쓰는것을 추천합니다.