[Python] 利用Keras數字辨識_辨識多種電腦字體的數字part2(建模、預測)


終於來到Part2!!

擷取一堆0~9的電腦字體數字做訓練,接著再拿另一群做測試

我們來看看用Keras的效能如何



前面已經先說過圖片裁切資料前處理Part1

接著這篇我們來看如何建模以及進行預測!


1. 建立模型


#建立一個線性堆疊的模型
model = Sequential()
#建立卷積層1
model.add(Conv2D(filters=16,kernel_size=(5,5),padding='same',input_shape=(28,28,1),activation='relu'))
#建立池化層1。
model.add(MaxPooling2D(pool_size=(2,2)))
#建立卷積層2
model.add(Conv2D(filters=36,kernel_size=(5,5),padding='same',activation='relu'))
#建立池化層2
model.add(MaxPooling2D(pool_size=(2,2)))
#避免overfitting
model.add(Dropout(0.25))
#建立平坦層
model.add(Flatten())
#建立隱藏層,且避免overfitting
model.add(Dense(128,activation='relu'))
model.add(Dropout(0.25))
#建立輸出層
model.add(Dense(10,activation='softmax'))

在這裡說明一下裡面的一些參數:

activation:設定激勵函數,最常用的就是relu
input_shape(a,b,c):a,b為輸入影像的大小,c為單色或彩色
padding='same':使影像經過卷積計算後大小不變
kernel_size(a,b):濾鏡的大小為axb
filters:濾鏡的層數




2. 定義訓練方式 & 進行訓練


#設定損失函數,最佳化方法,以及評估模型等。
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

###開始訓練
train_history=model.fit(x=x_train4_nor,y=y_train_one,validation_split=0.1,epochs=25,batch_size=300,verbose=2)

在這裡也說明一下裡面的一些參數:

x:影像特徵值,也就是訓練資料
y:影像實際值,也就是訓練資料的正確數字
validation_split:設定訓練及驗證資料比例 (1/10)
epochs:訓練周期
batch_size:每一批次多少筆資料

verbose:顯示訓練過程
- acc:模型訓練精度
- val_acc:模型在驗證集上的精度
- loss訓練的的損失值。

會不斷顯示訓練周期(epochs)25次的結果(這裡epochs=25)


如果跑越多次,精確度會更高一些,甚至可以接近100%正確喔!

(資料若夠多的話也可以提升精確度)



3. 預測測試資料並檢視結果



#輸入影像,並儲存結果。
prediction=model.predict_classes(x_test_nor)

#交叉表
pd.crosstab(y_test,prediction,rownames=['label'],colnames=['predict'])






4. 圖型化呈現(可略,但有這步感覺厲害多了)



###畫出圖形
def plot_images_labels(images,labels,prediction,idx,num=10):
    fig=plt.gcf()
    fig.set_size_inches(12,14)
    if num>25:
        num=15
    for i in range(0,num):
        ax=plt.subplot(5,5,1+i)
        ax.imshow(np.reshape(images[idx],(28,28)), cmap='binary')
        title="label=" +str(labels[idx])
        if len(prediction)>0:
            title+=",predict="+str(prediction[idx])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        idx=idx+1
    plt.show()

# 印出10張實際圖形,並顯示預測值
plot_images_labels(x_test,y_test,prediction,idx=1)  




完成電腦數字的測試後,整個成就感上升阿!!(後面還有實際發票辨識測試要做...)

這樣就完整結束了!


**發票辨識也可以像辨識電腦數字一樣精確喔!