探索 Tensorflow 2.0
前言
近期 Tensorflow 發佈了 2.0 的穩定版本。
這個版本重點在於易用性的改進,加強與開源神經網路函式庫 Keras 的整合,並且簡化 API 降低功能重複。
現在 TensorFlow 交換格式都與 SavedModel 統一,SavedModel 格式內容具有完整的 TensorFlow 程式,包含權重和計算,不需要原始建置模型的程式碼就能執行,這對於模型共享或是部署非常有用。
開發團隊為 TensorFlow 2.0 的 API 進行了調整,許多 API 符號經重新命名或是刪除,參數名稱也被更改,整體來說,調整後能讓 API 的使用經驗更加一致清楚。
本文旨在將 初探 Tensorflow 機器學習、Tf.js 實作 Chrome 擴充功能 的內容進行更新補充。
欲升級 Tensorflow 至 2.0 版本可用以下指令:
pip install tensorflow==2.0
影像辨識-訓練
建構神經網路
TensorFlow 2.0 提供的 API 更加人性化了。
建立神經網路層只需呼叫Sequential
,並將所需要的層疊入即可。
此處用Dense
全連接層建立了隱藏層(20 個神經元、激勵函數 relu)、輸出層(10 個神經元)。
且第一層需指定input_shape
參數,(266,)
表示輸入的shape
為[*,266]
。
接著使用compile
指定最佳化器(optimizer)、誤差計算方法(loss)、訓練成效指標(metrics)。
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(20, activation='relu', input_shape=(266,)), # 隱藏層
tf.keras.layers.Dense(10, activation='softmax') # 輸出層
])
model.compile(optimizer='adam', # 最佳化器
loss='sparse_categorical_crossentropy', # 誤差計算
metrics=['accuracy']) # 成效指標
訓練並儲存結果
TensorFlow 2.0 呼叫fit
丟入x
、y
就可以訓練了,訓練 100 步。
最後儲存訓練後的 saved_model 模型以及 tfjs 的模型。
導入tensorflowjs
使用save_keras_model
就可以直接導出 tfjs 模型了,不需再額外轉檔。
import os
import tensorflow as tf
import tensorflowjs as tfjs
import numpy as np
from PIL import Image
x, y = read_pic()
x = z_score(x)
model.fit(x, y, epochs=100) # 訓練 100 步
tf.saved_model.save(model, "../model_v2") # 儲存 saved_model 模型
tfjs.converters.save_keras_model(model, "../web_model_v2") # 儲存 tfjs 模型
影像辨識-推理
推理並輸出預測值
透過saved_model.save
儲存的檔案已是模型,故神經網路不需重建。
TensorFlow 2.0 讀取已存在的模型只需要呼叫saved_model.load
即可輕鬆達成。
輸入的dtype
必須符合規定,一般是float32
,shape
也需與input_shape
參數相符。
將輸入直接丟進方才讀出來的模型裡就可以進行推理了:model(x)
。
最後,再將其轉換成 array 類型,並用argmax(1)
取最大值的索引即可。
import os, requests
import tensorflow as tf
import numpy as np
from PIL import Image
from io import BytesIO
x = load_pic()
x = z_score(x)
x = np.array(x, dtype=np.float32) # 轉成 float32
model = tf.saved_model.load("model_v2") # 讀取模型
print("=====")
print(np.array(model(x)).argmax(1)) # 進行推理
print("=====")
核心 JS 撰寫
load_model()
透過save_keras_model
導出的 tfjs 模型,於 TensorFlow.js 中要以loadLayersModel
方式載入。
故將loadGraphModel
直接改為loadLayersModel
即可相容,但也需留意輸入的dtype
與shape
是否正確。
class Sess
{
async load_model()
{
console.log("[*]模型下載中...")
const startTime = performance.now()
try
{
// this.model = await tf.loadGraphModel(MODEL_URL)
this.model = await tf.loadLayersModel(MODEL_URL)
tf.tidy(() => { this.model.predict(tf.zeros([4, 266])) })
const totalTime = Math.floor(performance.now() - startTime)
console.log(`[*]模型初始化完成,共耗時 ${totalTime} ms`)
}
catch { console.error(`[!]無法從下列網址載入模型: ${MODEL_URL}`) }
const predict = this.predict(this.z_score(this.load_pic()))
const string = predict.join("")
console.log(`[*]預測驗證碼為:${string}`)
document.querySelector("#baseContent_cph_confirm_txt").value = string
console.log("[*]已填入 inputText")
console.log("[*]執行完畢")
}
}