Torchvision介紹
Torchvision是基于Pytorch的視覺(jué)深度學(xué)習(xí)遷移學(xué)習(xí)訓(xùn)練框架,當(dāng)前支持的圖像分類、對(duì)象檢測(cè)、實(shí)例分割、語(yǔ)義分割、姿態(tài)評(píng)估模型的遷移學(xué)習(xí)訓(xùn)練與評(píng)估。支持對(duì)數(shù)據(jù)集的合成、變換、增強(qiáng)等,此外還支持預(yù)訓(xùn)練模型庫(kù)下載相關(guān)的模型,直接預(yù)測(cè)推理。
預(yù)訓(xùn)練模型使用
Torchvision從0.13版本開(kāi)始預(yù)訓(xùn)練模型支持多源backbone設(shè)置,以圖像分類的ResNet網(wǎng)絡(luò)模型為例:
支持多個(gè)不同的數(shù)據(jù)集上不同精度的預(yù)訓(xùn)練模型,下載模型,轉(zhuǎn)化為推理模型
對(duì)輸入圖像實(shí)現(xiàn)預(yù)處理
本地加載模型
Torchvision中支持的預(yù)訓(xùn)練模型當(dāng)你使用的時(shí)候都會(huì)加載模型的預(yù)訓(xùn)練模型,然后才可以加載你自己的權(quán)重文件,如果你不想加載torchvision的預(yù)訓(xùn)練模型,只想從本地加載pt或者pth文件實(shí)現(xiàn)推理或者訓(xùn)練的時(shí)候,一定要通過(guò)下面的方式完成,以Faster-RCNN為例:
# Load the model from local host num_classes = len(self.labels) self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=num_classes, pretrained_backbone=False) self.model.load_state_dict(torch.load(self.model_file)) self.model.eval() self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()]) # 使用GPU train_on_gpu = torch.cuda.is_available() if train_on_gpu: self.model.cuda()
就這樣解鎖了在torchvision框架下如何從本地加載預(yù)訓(xùn)練模型文件或者定義訓(xùn)練模型文件。
審核編輯:湯梓紅
-
模型
+關(guān)注
關(guān)注
1文章
3226瀏覽量
48806 -
深度學(xué)習(xí)
+關(guān)注
關(guān)注
73文章
5500瀏覽量
121109 -
pytorch
+關(guān)注
關(guān)注
2文章
807瀏覽量
13198
原文標(biāo)題:torchvision中怎么加載本地模型實(shí)現(xiàn)訓(xùn)練與推理
文章出處:【微信號(hào):CVSCHOOL,微信公眾號(hào):OpenCV學(xué)堂】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論