訓練器
tflearn.helpers.trainer.Trainer (train_ops, graph=None, clip_gradients=5.0, tensorboard_dir='/tmp/tflearn_logs/', tensorboard_verbose=0, checkpoint_path=None, best_checkpoint_path=None, max_checkpoints=None, keep_checkpoint_every_n_hours=10000.0, random_seed=None, session=None, best_val_accuracy=0.0)
泛用類,用來處理任何 TensorFlow 圖形訓練。它需要使用 TrainOp
來指定所有最佳化參數。
引數
- train_ops:
TrainOp
的清單。用來執行最佳化的網路訓練操作清單。 - graph:
tf.Graph
。要使用的 TensorFlow 圖形。預設:預設 tf 圖形。 - clip_gradients:
float
。剪裁梯度。預設:5.0。 - tensorboard_dir:
str
。Tensorboard 日誌目錄。預設:"/tmp/tflearn_logs/". - tensorboard_verbose:
int
。詳細層級。它支援
0 - Loss, Accuracy. (Best Speed)
1 - Loss, Accuracy, Gradients.
2 - Loss, Accuracy, Gradients, Weights.
3 - Loss, Accuracy, Gradients, Weights, Activations, Sparsity.(Best Visualization)
- checkpoint_path:
str
。儲存模型檢查點的路徑。如果為無,將不會儲存模型檢查點。預設:無。 - best_checkpoint_path:
str
。當驗證率達到目前訓練階段最高點且大於 best_val_accuracy 時,儲存模型的路徑。預設:無。 - max_checkpoints:
int
或 None。檢查點最大數量。如果為無,則無限制。預設:無。 - keep_checkpoint_every_n_hours:
float
。每個模型檢查點之間的小時數。 - random_seed:
int
。隨機種子,用於測試可重製性。預設:無。 - session:
Session
。執行操作的階段。如果為無,將會建立新的階段。注意:提供階段時,變數必須事先初始化,否則會產生錯誤。 - best_val_accuracy:
float
模型權重儲存到 best_checkpoint_path 之前必須達到的最小驗證準確率。這允許使用者跳過提早儲存,並在繼續訓練已重新載入的模型時設定最小儲存點。預設:0.0。
方法
fit (feed_dicts, n_epoch=10, val_feed_dicts=None, show_metric=False, snapshot_step=None, snapshot_epoch=True, shuffle_all=None, dprep_dict=None, daug_dict=None, excl_trainops=None, run_id=None, callbacks=[])
使用餵入的資料字典訓練網路。
範例
# 1 Optimizer
trainer.fit(feed_dicts={input1: X, output1: Y},val_feed_dicts={input1: X, output1: Y})
trainer.fit(feed_dicts={input1: X1, input2: X2, output1: Y},val_feed_dicts=0.1) # 10% of data used for validation
# 2 Optimizers
trainer.fit(feed_dicts=[{in1: X1, out1:Y}, {in2: X2, out2:Y2}],val_feed_dicts=[{in1: X1, out1:Y}, {in2: X2, out2:Y2}])
引數
- feed_dicts:
字典
或字典
列表。將資料輸入網路的字典。其遵循 TensorFlow 輸入字典規格:「{佔位符:資料}」。在多重最佳化器的狀況下,預期會是字典的列表,這些字典分別會輸入最佳化器。 - n_epoch:
整數
。執行次數。 - val_feed_dicts:
字典
、字典
列表、浮點數
或浮點數
列表。用於驗證的資料。輸入字典遵循上述feed_dicts
相同的規格,也可以提供浮點數
來區分訓練資料驗證(請注意,這將會將資料隨機洗牌)。 - show_metric:
布林值
。如果是 True,將會在每個步驟中計算並顯示準確性。可能導致較慢的訓練。 - snapshot_step:
整數
。如果不是 None,則網路將會在每一個提供的步驟執行快照(計算驗證損失/準確度,並儲存模型,如果已在 `Trainer` 中指定 `checkpoint_path`)。 - snapshot_epoch:
布林值
。如果是 True,則在每個時代結束時執行網路快照。 - shuffle_all:
布林值
。如果是 True,則將隨機處理所有資料批次(覆寫TrainOp
洗牌參數行為)。 - dprep_dict:使用
佔位符
作爲鍵和資料預處理
作爲值的字典
。對給定的佔位符套用即時資料預處理(訓練和測試時均套用)。 - daug_dict:使用
佔位符
作爲鍵和資料增強
作爲值的字典
。對給定的佔位符套用即時資料增強(僅在訓練時套用)。 - excl_trainops:
TrainOp
的列表
。訓練過程中將排除的訓練操作列表。 - run_id:
字串
。目前執行的名稱。用於 Tensorboard 顯示。如果未提供名稱,則會產生一個隨機名稱。 - callbacks:
回呼
或列表
。訓練生命週期中使用的自訂回呼函式
fit_batch (feed_dicts, dprep_dict=None, daug_dict=None)
使用單一批次訓練網路。
引數
- feed_dicts:
字典
或字典
列表。將資料輸入網路的字典。其遵循 TensorFlow 輸入字典規格:「{佔位符:資料}」。在多重最佳化器的狀況下,預期會是字典的列表,這些字典分別會輸入最佳化器。 - dprep_dict:使用
佔位符
作爲鍵和資料預處理
作爲值的字典
。對給定的佔位符套用即時資料預處理(訓練和測試時均套用)。 - daug_dict:使用
佔位符
作爲鍵和資料增強
作爲值的字典
。對給定的佔位符套用即時資料增強(僅在訓練時套用)。
restore (model_file, trainable_variable_only=False, variable_name_map=None, scope_for_restore=None, create_new_session=True, verbose=False)
還原 TensorFlow 模型
引數
- model_file:要還原的 TensorFlow 模型的路徑
- trainable_variable_only:如果是 True,就只還原可訓練變數。
- variable_name_map:- 一個 (樣式、替換) 的元組,提供一個正規表達式樣式和替換,在從模型檔案還原之前,會將其套用至變數名稱 -- 或是一個功能 map_func,用於執行對應,呼叫方式如下:name_in_file = map_func(existing_var_op_name)。這個功能會傳回 None 表示該變數不應還原。
- scope_for_restore:用於限制還原變量的範圍的字串。- 還原時也會從變數名稱中移除範圍名稱字首。
- create_new_session:如果要保留目前的工作階段,請設定為 False。設定為 True (預設值) 可建立新的工作階段,並重新初始化所有變數。
- verbose:設定為 True 可顯示使用 scope_for_restore 或 variable_name_map 時所還原的變數清單
save (model_file, global_step=None)
儲存 Tensorflow 模型
引數
- model_file:
str
。Tensorflow 模型的儲存路徑 - global_step:
int
。要附加到模型檔案名稱的訓練步驟 (選擇性)。
訓練操作
tflearn.helpers.trainer.TrainOp (loss, optimizer, metric=None, batch_size=64, ema=0.0, trainable_vars=None, shuffle=True, step_tensor=None, validation_monitors=None, validation_batch_size=None, name=None, graph=None)
TrainOp 代表一組用於最佳化網路的運算。
TrainOp 用於留存最佳化器的所有訓練參數。Trainer
類別隨後會全都實體化,特別考慮到所有最佳化器 (設定名稱、範圍... 設定最佳化作業...)。
引數
- loss:
Tensor
。評估網路成本的損失運算。最佳化器會使用此成本函數來訓練網路。 - optimizer:
Optimizer
。Tensorflow 最佳化器。用於訓練網路的最佳化器。 - metric:
Tensor
。評估時要使用的衡量指標巨集。 - batch_size:
int
。饋入此最佳化器的資料批次大小。預設值:64。 - ema:
float
。指數移動平均。 - trainable_vars:
tf.Variable
的清單。可用於訓練的可訓練變數清單。預設值:所有可訓練變數。 - shuffle:
bool
。混排資料。 - step_tensor:
tf.Tensor
。保留訓練步驟的變數。如果未提供,將建立此變數。提早定義步驟巨集可能會對網路建立造成幫助,例如學習率衰減。 - validation_monitors:
Tensor
物件的清單。驗證過程中要運算的變數清單,此清單也會用於產生要輸出到 TensorBoard 的摘要。例如,可藉此在訓練期間定期記錄混淆矩陣或 AUC 衡量指標。每個變數應為秩 1,例如形狀 [None]。 - validation_batch_size:
int
或 None。若是int
,指定要使用於驗證資料饋入的批次大小;否則,預設值與batch_size
相同。 - name:
str
。此類別的名稱 (選擇性)。 - graph:
tf.Graph
。用於訓練的 Tensorflow 圖形。預設值:預設 tf 圖形。
方法
initialize_fit (feed_dict, val_feed_dict, dprep_dict, daug_dict, show_metric, summ_writer, coord)
初始化用於饋入訓練程序的資料。訓練器在開始調整資料前使用此初始化。
引數
- feed_dict:
dict
。要饋入的資料字典。 - val_feed_dict:
dict
或float
。驗證資料字典作為輸入或驗證分割。 - dprep_dict:
dict
。資料預處理字典(以佔位符為鍵,對應DataPreprocessing
物件為值)。 - daug_dict:
dict
。資料擴充字典(以佔位符為鍵,對應DataAugmentation
物件為值)。 - show_metric:
bool
。如果為 True,則在每個步驟顯示準確率。 - summ_writer:
SummaryWriter
。摘要撰寫器,用於 Tensorboard 記錄。
initialize_training_ops (i, 會話, tensorboard_詳細, 裁剪梯度)
初始化所有用於訓練的 ops。由於一個網路可以有多個最佳化器,因此分配了一個身分碼「i」以符號區分。此目的是於初始化所有訓練 ops 時供 Trainer
使用。
引數
- i:
int
。此最佳化器訓練程序的 ID。 - 會話:
tf.Session
。用來訓練網路的會話。 - tensorboard_詳細:
int
。詳細記錄。支援
0 - Loss, Accuracy.
1 - Loss, Accuracy, Gradients.
2 - Loss, Accuracy, Gradients, Weights.
3 - Loss, Accuracy, Gradients, Weights, Activations, Sparsity..
- 裁剪梯度:
float
。裁剪梯度的選項。