《NLP》Using MLP with BERT encoding

Jimmy Lin
Sep 7, 2021

--

好的這是花了四天(對,我就爛)才終於搞定一個小小的微不足道作業的我

雖然這東西其實很簡單,根本不太需要講解,但我還是決定貢獻一下自己的慘痛經驗給大家

這次要用的東西還算簡單,就用PyTorch實作MLP(Multi-Level Peceptions)跟Google提供的BERT編碼來預測一段文字是正面還是負面的意義。

資料集我就不提供了,各位可以自行搜尋。

總之整個流程大概是以下這個樣子:

  1. 讀取資料、篩選資料
  2. Load BERT Model
  3. 把資料塞成DataSet的形式
  4. 用Pytorch提供的Iterator去讀取裡面的內容
  5. 寫下MLP的內容(init定義、forward實作)
  6. 設定好MLP的參數、需要用到的optimizer & criterion(loss function)
  7. 丟下去Train、Validation然後Test

第一步就沒啥好說,隨便CSV做一做 sample取一取,有很多方法啦,如果只是想自己練習的也有一些內建的datasets可以用(EX:datasets.IMDB),最後把他存成檔案再丟出去(方便之後讀取成Dataset的形式)

第二步就很基本的call BERT Model
首先 pip install transformers
接著做以下動作:

from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained(‘bert-base-uncased’)
model = BertModel.from_pretrained(“bert-base-uncased”)

接下來就是塞進去DataSets的形式,這裡首先要定義fields,每個field都是data.Field,那裡面可以對資料做一些前處理跟預先切割好,當然你不特別寫的話就會當作沒這回事。
定義好Fields之後就Call TabularDataset,告訴他資料來源、格式還有Field,詳細的參數可以參考網路上的document,更詳細

接下來就只要記得用BucketIterator再把data塞進去就好了,在這裡面有幾個很重要的東西:

  1. Batch Size
  2. Device
  3. sort

首先是Batch size,這東西其實就是告訴電腦你一次要丟幾筆資料進去裡面Train,想當然爾越大的話對RAM或者顯卡記憶體的要求就越高,我的3070 8 GB完全不夠看,分分鐘就被塞滿了
再來是Device,這個地方可以很簡單的設定成torch.device(‘cuda’),但要注意,如果顯卡的驅動沒安裝好、加上pytorch沒有裝好有cuda的版本的話是不能用的,所以記得要先用好這兩個地方,一個滿方便的不怕bug的寫法應該就是先用torch.cuda.is_available()確認cuda可不可用,可以就設成顯卡跑,不能就設成CPU跑
最後是sort,train model的時候有可能會遇到以下這個Error:

TypeError: ‘<’ not supported between instances of ‘Example’ and ‘Example’

遇到這個問題只要在Iterator讀取的時候加上sort=False就能解決了

接著是寫MLP的詳細內容,其實這邊反而沒想像中難,MLP主要就是透過多層簡化把輸入最終轉換成我們希望可以拿來對比標籤的數量

舉例來說,我身為一個人的參數有100個(身高、體重、體脂肪率、學歷......等等),但最終分類只有3個參數(性別、人種、年齡),也就是說我給了你關於一個人的各種data,希望你最後判斷他是幾歲的哪裡人跟性別,這種時候我們就是要想辦法讓100個參數最後映射到3個參數而已
那這個方法很多元,你可以首先讓100個參數映射到99個、再映射到98個......以此類推,或者你也能直接100→50→25→5→3這樣轉變,細節有很多可以自己設定去嘗試的,我就不多做解釋
forward主要就是你實際把資料送進來的時候會怎麼處理,在這裡你可以直接呼叫你在init定義好的function,或者是你也可以先對資料做一些處理,比如說讓他被送進BERT Model裡處理後再把處理後的資料拿出來開始映射

然後找一下optimizer跟criterion要用哪個(內建有很多種function,可以隨便試試看),丟下去train(這邊程式網路上都找的到很多範例,我就不多做解釋了)

簡單來說,這個作業就只是下載資料→丟進Model裡面Train→跑Test,但我花了四天才寫完,只能說自己的鍋最大,像是我一開始花了一天多自己處理資料、死都不用內建的Tabular或者Iterator,結果把自己搞死還沒成功,最後才含淚用內建function

另外各位在二分法的時候也有機會遇到train loss一直卡在.693的情況,如果去查文章的話會發現這基本上代表參數卡在一個極端值或者是很快就收斂到無限接近於0,所以無論怎麼訓練都像用猜的一樣,可行的解決辦法大致上有:

  1. 處理一下數據分布的問題
  2. 調低learning rate讓他不要一下就跑到一個local extreme的值去

但最後我也不知道我改了什麼,總之一直重跑了四個小時後,他突然就能跑了,大概是這樣

--

--

Jimmy Lin
Jimmy Lin

Written by Jimmy Lin

Hi, I'm Jimmy. I graduated from ASU and work in Amazon now.

No responses yet