딥러닝 모델을 짜다보면 모델 or 가중치들을 공유하거나 불러오는 일이 다분합니다.
불러온 pretrained 모델에 input 형식만 잘 맞춰줘도 어느정도 성능은 보장이 됩니다.
이를 Transfer Learning, 전이학습이라고 합니다.
Transfer Learning : 이미지넷과 같이 아주 큰 데이터셋에 훈련된 모델의 가중치를 가지고 와서 우리가 해결하고자 하는 과제에 맞게 재보정해서 사용하는 것
전이학습을 위해서는 모델을 불러와야하는데,
Pytorch에서 모델을 불러오는 법에는 크게 두가지가 있습니다.
1. Model의 state_dict를 저장 및 불러오기
모델의 파라미터를 저장
torch.save(model.state_dict(),PATH)
model.load_state_dict(torch.load(PATH))
Model을 생성하면 자동으로 state_dict라는 Ordered_dict type의 변수가 생성됩니다.
여기에는 모든 레이어들에 대한 구체적인 정보들이 담겨있습니다. 자세한 내용은 Pytorch 공식문서에 나와있습니다.
state_dict는 같은 모델의 형태에서 불러야 합니다.
2. Model의 아키텍쳐까지 함께 저장
torch.save(model, PATH)
model = torch.load(PATH)
Python의 pickle 모듈을 사용하여 전체 모듈을 저장하게 됩니다.
PyTorch에서는 모델을 저장할 때 .pt 또는 .pth 확장자를 사용하는 것이 일반적인 규칙입니다.
하지만 .pth는 python과 겹쳐서 되도록이면 .pt를 쓰는것을 추천합니다.
'2021 네이버 부스트캠프 - Ai tech' 카테고리의 다른 글
Week_3 Pytorch - Out of Memory, OOM 해결 (0) | 2021.08.20 |
---|---|
Week_3 Pytorch - Dataset & Dataloader (0) | 2021.08.20 |
Week_3 Pytorch - view vs. reshape (0) | 2021.08.17 |
Week_3 Pytorch - Squeeze vs. Unsqueeze (0) | 2021.08.17 |
Week_2 딥러닝에서 비선형성이 중요한 이유 (0) | 2021.08.12 |