多少人用 PyTorch煉丹時(shí)都會(huì)被這個(gè) bug 困擾。
CUDAerror:outofmemory.
一般情況下,你得找出當(dāng)下占顯存的沒(méi)用的程序,然后 kill 掉如果不行,還需手動(dòng)調(diào)整 batch size 到合適的大小,有點(diǎn)麻煩
現(xiàn)在,有人寫(xiě)了一個(gè) PyTorch wrapper,用一行代碼就能無(wú)痛消除這個(gè) bug。
有多厲害。
相關(guān)項(xiàng)目在 GitHub 才發(fā)布沒(méi)幾天就收獲了 600 + 星。
一行代碼解決內(nèi)存溢出錯(cuò)誤
軟件包名叫 koila,已經(jīng)上傳 PyPI,先安裝一下:
pipinstallkoila
現(xiàn)在,假如你面對(duì)這樣一個(gè) PyTorch 項(xiàng)目:構(gòu)建一個(gè)神經(jīng)網(wǎng)絡(luò)來(lái)對(duì) FashionMNIST 數(shù)據(jù)集中的圖像進(jìn)行分類。
先定義 input,label 和 model:
#AbatchofMNISTimageinput=torch.randn#Abatchoflabelslabel=torch.randn)classNeuralNetwork(Module):def__init__(self):super(NeuralNetwork,self).__init__()self.flatten=Flatten()self.linear_relu_stack=Sequential(Linear(28*28,512),ReLU(),Linear(512,512),ReLU(),Linear(512,10),)defforward(self,x):x=self.flatten(x)logits=self.linear_relu_stack(x)returnlogits
然后定義 loss 函數(shù),計(jì)算輸出和 losses。通過(guò)在PyTorchLightning中設(shè)置混合精度標(biāo)志,它將在可能的情況下自動(dòng)使用半精度,同時(shí)在其他地方保留單精度。
loss_fn=CrossEntropyLoss#Calculatelossesout=nn(t)loss=loss_fn(out,label)#Backwardpassnn.zero_gradloss.backward
好了,如何使用 koila 來(lái)防止內(nèi)存溢出。
超級(jí)簡(jiǎn)單!
只需在第一行代碼,也就是把輸入用 lazy 張量 wrap 起來(lái),并指定 bacth 維度,koila 就能自動(dòng)幫你計(jì)算剩余的 GPU 內(nèi)存并使用正確的 batch size 了。
在本例中,batch=0,則修改如下:
input=lazy,batch=0)
完事兒!就這樣和 PyTorch煉丹時(shí)的 OOM 報(bào)錯(cuò)說(shuō)拜拜。。
靈感來(lái)自 TensorFlow 的靜態(tài) / 懶惰評(píng)估
下面就來(lái)說(shuō)說(shuō) koila 背后的工作原理。
CUDA error: out of memory這個(gè)報(bào)錯(cuò)通常發(fā)生在前向傳遞中,因?yàn)檫@時(shí)需要保存很多臨時(shí)變量。
koila 的靈感來(lái)自 TensorFlow 的靜態(tài) / 懶惰評(píng)估。
它通過(guò)構(gòu)建圖,并僅在必要時(shí)運(yùn)行訪問(wèn)所有相關(guān)信息,來(lái)確定模型真正需要多少資源。
而只需計(jì)算臨時(shí)變量的 shape 就能計(jì)算各變量的內(nèi)存使用情況,而知道了在前向傳遞中使用了多少內(nèi)存,koila 也就能自動(dòng)選擇最佳 batch size 了。
又是算 shape 又是算內(nèi)存的,koila 聽(tīng)起來(lái)就很慢。
NO。
即使是像 GPT—3 這種具有 96 層的巨大模型,其計(jì)算圖中也只有幾百個(gè)節(jié)點(diǎn)。
而 Koila 的算法是在線性時(shí)間內(nèi)運(yùn)行,任何現(xiàn)代計(jì)算機(jī)都能夠立即處理這樣的圖計(jì)算,再加上大部分計(jì)算都是單個(gè)張量,所以,koila 運(yùn)行起來(lái)一點(diǎn)也不慢。
你又會(huì)問(wèn)了,PyTorch Lightning 的 batch size 搜索功能不是也可以解決這個(gè)問(wèn)題嗎。
是的,它也可以。
而 koila 靈活又輕量,只需一行代碼就能解決問(wèn)題,非常大快人心有沒(méi)有。
不過(guò)目前,koila 還不適用于分布式數(shù)據(jù)的并行訓(xùn)練方法,未來(lái)才會(huì)支持多 GPU。
以及現(xiàn)在只適用于常見(jiàn)的 nn.Module 類。
項(xiàng)目地址:點(diǎn)此直達(dá)
參考鏈接:點(diǎn)此直達(dá)
。鄭重聲明:此文內(nèi)容為本網(wǎng)站轉(zhuǎn)載企業(yè)宣傳資訊,目的在于傳播更多信息,與本站立場(chǎng)無(wú)關(guān)。僅供讀者參考,并請(qǐng)自行核實(shí)相關(guān)內(nèi)容。
|