2023-04-25
◈기존 코드
from torchvision.models.vgg import cfgs, make_layers #, model_urls
def vgg19(avg_pool: bool = True, pretrained: bool = True,): # init.py 에서 사용
model = VGG19(avg_pool=avg_pool) # VGG19 는 클래스 -> class VGG19(nn.ModuleDict)
if pretrained: # 매개변수로 pretrained 여부 받아왔음
### 에러뜨는 부분 -> model_urls 사용
state_dict = load_state_dict_from_url(model_urls["vgg19"], progress=True)
model.load_state_dict(state_dict)
# model.load_state_dict : 역직렬화된 state_dict를 사용, 모델의 매개변수들을 불러옴.
# state_dict는 간단히 말해 각 체층을 매개변수 Tensor로 매핑한 Python 사전(dict) 객체.
return model
◈ 원인
- 참고: https://stackoverflow.com/questions/67317418/importerror-cannot-import-name-model-urls-from-torchvision-models-vgg
- First of all, for all torchvision > 0.13 users, the model_urls are gone, you shouldn't use it.
- torchvision 0.13 이상 버전에서는 model_urls 기능이 사라졌다.
- 해결책으로 해당 사이트에서는 다음과 같이 제시했지만 VGG19 를 사용해서 그런지 실패함..
# change from your model_urls to this
from torchvision.models.resnet import ResNet50_Weights
org_resnet = torch.utils.model_zoo.load_url(ResNet50_Weights.IMAGENET1K_V2)
- 참고
- 코랩에서 -> pytorch 와 pytorchvision 버전 확인
# 토치 , 토피비젼 버전확인인
import torch
import torchvision
print(torch.__version__)
print(torchvision.__version__)
◈ 해결
- 다음 해결을 위해 직접 pythorch 공식사이트에서 torchvision.models.vgg 코드를 들여다봄
torchvision.models.vgg — Torchvision 0.12 documentation
Shortcuts
pytorch.org
- 다음과 같이 url 들이 저장되어있다.
model_urls = {
"vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
"vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
}
- 기존 코드에서
- state_dict = load_state_dict_from_url(model_urls["vgg19"], progress=True)
- 의 load_state_dict_from_url 의 pytorch 공식문서를 살펴보면
# For checkpoint saved elsewhere
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
- 따라서 vgg19 의 url 주소를 직접 갖다 붙이니 잘 동작한다.
◈해결 코드
### 에러해결 : 아이예 import 에서 model_urls 사용안함
from torchvision.models.vgg import cfgs, make_layers
def vgg19(avg_pool: bool = True, pretrained: bool = True,): # init.py 에서 사용
model = VGG19(avg_pool=avg_pool) # VGG19 는 클래스 -> class VGG19(nn.ModuleDict)
if pretrained: # 매개변수로 pretrained 여부 받아왔음
#state_dict = load_state_dict_from_url(model_urls["vgg19"], progress=True)
### 에러해결 : 수정
checkpoint = 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth' #수정
state_dict = load_state_dict_from_url(checkpoint, progress=True) #수정
model.load_state_dict(state_dict)
# model.load_state_dict : 역직렬화된 state_dict를 사용, 모델의 매개변수들을 불러옴.
# state_dict는 간단히 말해 각 체층을 매개변수 Tensor로 매핑한 Python 사전(dict) 객체.
return model
'프로젝트 에러 > [CGVR] texture-synthesis' 카테고리의 다른 글
pytorch - gpu 사용하기 (0) | 2023.07.14 |
---|---|
"ValueError: The truth value of an array with more than one element isambiguous. Use a.any() or a.all()" (0) | 2023.05.02 |
파이썬 python 주석 단축키 [ ctrl + / ] 안 될 때 (0) | 2023.04.27 |
cv2.error: ~~ !_src.empty() in function 'cv::cvtColor' (0) | 2023.04.26 |