訓練器

tflearn.helpers.trainer.Trainer (train_ops, graph=None, clip_gradients=5.0, tensorboard_dir='/tmp/tflearn_logs/', tensorboard_verbose=0, checkpoint_path=None, best_checkpoint_path=None, max_checkpoints=None, keep_checkpoint_every_n_hours=10000.0, random_seed=None, session=None, best_val_accuracy=0.0)

泛用類,用來處理任何 TensorFlow 圖形訓練。它需要使用 TrainOp 來指定所有最佳化參數。

引數

  • train_ops: TrainOp 的清單。用來執行最佳化的網路訓練操作清單。
  • graph: tf.Graph。要使用的 TensorFlow 圖形。預設:預設 tf 圖形。
  • clip_gradients: float。剪裁梯度。預設:5.0。
  • tensorboard_dir: str。Tensorboard 日誌目錄。預設:"/tmp/tflearn_logs/".
  • tensorboard_verbose: int。詳細層級。它支援
0 - Loss, Accuracy. (Best Speed)
1 - Loss, Accuracy, Gradients.
2 - Loss, Accuracy, Gradients, Weights.
3 - Loss, Accuracy, Gradients, Weights, Activations, Sparsity.(Best Visualization)
  • checkpoint_path: str。儲存模型檢查點的路徑。如果為無,將不會儲存模型檢查點。預設:無。
  • best_checkpoint_path: str。當驗證率達到目前訓練階段最高點且大於 best_val_accuracy 時,儲存模型的路徑。預設:無。
  • max_checkpoints: int 或 None。檢查點最大數量。如果為無,則無限制。預設:無。
  • keep_checkpoint_every_n_hours: float。每個模型檢查點之間的小時數。
  • random_seed: int。隨機種子,用於測試可重製性。預設:無。
  • session: Session。執行操作的階段。如果為無,將會建立新的階段。注意:提供階段時,變數必須事先初始化,否則會產生錯誤。
  • best_val_accuracy: float 模型權重儲存到 best_checkpoint_path 之前必須達到的最小驗證準確率。這允許使用者跳過提早儲存,並在繼續訓練已重新載入的模型時設定最小儲存點。預設:0.0。

方法

fit (feed_dicts, n_epoch=10, val_feed_dicts=None, show_metric=False, snapshot_step=None, snapshot_epoch=True, shuffle_all=None, dprep_dict=None, daug_dict=None, excl_trainops=None, run_id=None, callbacks=[])

使用餵入的資料字典訓練網路。

範例
# 1 Optimizer
trainer.fit(feed_dicts={input1: X, output1: Y},val_feed_dicts={input1: X, output1: Y})
trainer.fit(feed_dicts={input1: X1, input2: X2, output1: Y},val_feed_dicts=0.1) # 10% of data used for validation

# 2 Optimizers
trainer.fit(feed_dicts=[{in1: X1, out1:Y}, {in2: X2, out2:Y2}],val_feed_dicts=[{in1: X1, out1:Y}, {in2: X2, out2:Y2}])
引數
  • feed_dicts字典字典列表。將資料輸入網路的字典。其遵循 TensorFlow 輸入字典規格:「{佔位符:資料}」。在多重最佳化器的狀況下,預期會是字典的列表,這些字典分別會輸入最佳化器。
  • n_epoch整數。執行次數。
  • val_feed_dicts字典字典列表、浮點數浮點數列表。用於驗證的資料。輸入字典遵循上述 feed_dicts 相同的規格,也可以提供 浮點數來區分訓練資料驗證(請注意,這將會將資料隨機洗牌)。
  • show_metric布林值。如果是 True,將會在每個步驟中計算並顯示準確性。可能導致較慢的訓練。
  • snapshot_step整數。如果不是 None,則網路將會在每一個提供的步驟執行快照(計算驗證損失/準確度,並儲存模型,如果已在 `Trainer` 中指定 `checkpoint_path`)。
  • snapshot_epoch布林值。如果是 True,則在每個時代結束時執行網路快照。
  • shuffle_all布林值。如果是 True,則將隨機處理所有資料批次(覆寫 TrainOp 洗牌參數行為)。
  • dprep_dict:使用 佔位符作爲鍵和 資料預處理作爲值的字典。對給定的佔位符套用即時資料預處理(訓練和測試時均套用)。
  • daug_dict:使用 佔位符作爲鍵和 資料增強作爲值的字典。對給定的佔位符套用即時資料增強(僅在訓練時套用)。
  • excl_trainopsTrainOp列表。訓練過程中將排除的訓練操作列表。
  • run_id字串。目前執行的名稱。用於 Tensorboard 顯示。如果未提供名稱,則會產生一個隨機名稱。
  • callbacks回呼列表。訓練生命週期中使用的自訂回呼函式

fit_batch (feed_dicts, dprep_dict=None, daug_dict=None)

使用單一批次訓練網路。

引數
  • feed_dicts字典字典列表。將資料輸入網路的字典。其遵循 TensorFlow 輸入字典規格:「{佔位符:資料}」。在多重最佳化器的狀況下,預期會是字典的列表,這些字典分別會輸入最佳化器。
  • dprep_dict:使用 佔位符作爲鍵和 資料預處理作爲值的字典。對給定的佔位符套用即時資料預處理(訓練和測試時均套用)。
  • daug_dict:使用 佔位符作爲鍵和 資料增強作爲值的字典。對給定的佔位符套用即時資料增強(僅在訓練時套用)。

restore (model_file, trainable_variable_only=False, variable_name_map=None, scope_for_restore=None, create_new_session=True, verbose=False)

還原 TensorFlow 模型

引數
  • model_file:要還原的 TensorFlow 模型的路徑
  • trainable_variable_only:如果是 True,就只還原可訓練變數。
  • variable_name_map:- 一個 (樣式、替換) 的元組,提供一個正規表達式樣式和替換,在從模型檔案還原之前,會將其套用至變數名稱 -- 或是一個功能 map_func,用於執行對應,呼叫方式如下:name_in_file = map_func(existing_var_op_name)。這個功能會傳回 None 表示該變數不應還原。
  • scope_for_restore:用於限制還原變量的範圍的字串。- 還原時也會從變數名稱中移除範圍名稱字首。
  • create_new_session:如果要保留目前的工作階段,請設定為 False。設定為 True (預設值) 可建立新的工作階段,並重新初始化所有變數。
  • verbose:設定為 True 可顯示使用 scope_for_restore 或 variable_name_map 時所還原的變數清單

save (model_file, global_step=None)

儲存 Tensorflow 模型

引數
  • model_filestr。Tensorflow 模型的儲存路徑
  • global_stepint。要附加到模型檔案名稱的訓練步驟 (選擇性)。

訓練操作

tflearn.helpers.trainer.TrainOp (loss, optimizer, metric=None, batch_size=64, ema=0.0, trainable_vars=None, shuffle=True, step_tensor=None, validation_monitors=None, validation_batch_size=None, name=None, graph=None)

TrainOp 代表一組用於最佳化網路的運算。

TrainOp 用於留存最佳化器的所有訓練參數。Trainer 類別隨後會全都實體化,特別考慮到所有最佳化器 (設定名稱、範圍... 設定最佳化作業...)。

引數

  • lossTensor。評估網路成本的損失運算。最佳化器會使用此成本函數來訓練網路。
  • optimizerOptimizer。Tensorflow 最佳化器。用於訓練網路的最佳化器。
  • metricTensor。評估時要使用的衡量指標巨集。
  • batch_sizeint。饋入此最佳化器的資料批次大小。預設值:64。
  • emafloat。指數移動平均。
  • trainable_varstf.Variable 的清單。可用於訓練的可訓練變數清單。預設值:所有可訓練變數。
  • shufflebool。混排資料。
  • step_tensortf.Tensor。保留訓練步驟的變數。如果未提供,將建立此變數。提早定義步驟巨集可能會對網路建立造成幫助,例如學習率衰減。
  • validation_monitorsTensor 物件的清單。驗證過程中要運算的變數清單,此清單也會用於產生要輸出到 TensorBoard 的摘要。例如,可藉此在訓練期間定期記錄混淆矩陣或 AUC 衡量指標。每個變數應為秩 1,例如形狀 [None]。
  • validation_batch_sizeint 或 None。若是 int,指定要使用於驗證資料饋入的批次大小;否則,預設值與 batch_size 相同。
  • namestr。此類別的名稱 (選擇性)。
  • graphtf.Graph。用於訓練的 Tensorflow 圖形。預設值:預設 tf 圖形。

方法

initialize_fit (feed_dict, val_feed_dict, dprep_dict, daug_dict, show_metric, summ_writer, coord)

初始化用於饋入訓練程序的資料。訓練器在開始調整資料前使用此初始化。

引數
  • feed_dictdict。要饋入的資料字典。
  • val_feed_dict: dictfloat。驗證資料字典作為輸入或驗證分割。
  • dprep_dict: dict。資料預處理字典(以佔位符為鍵,對應 DataPreprocessing 物件為值)。
  • daug_dict: dict。資料擴充字典(以佔位符為鍵,對應 DataAugmentation 物件為值)。
  • show_metric: bool。如果為 True,則在每個步驟顯示準確率。
  • summ_writer: SummaryWriter。摘要撰寫器,用於 Tensorboard 記錄。

initialize_training_ops (i, 會話, tensorboard_詳細, 裁剪梯度)

初始化所有用於訓練的 ops。由於一個網路可以有多個最佳化器,因此分配了一個身分碼「i」以符號區分。此目的是於初始化所有訓練 ops 時供 Trainer 使用。

引數
  • i: int。此最佳化器訓練程序的 ID。
  • 會話: tf.Session。用來訓練網路的會話。
  • tensorboard_詳細: int。詳細記錄。支援
0 - Loss, Accuracy.
1 - Loss, Accuracy, Gradients.
2 - Loss, Accuracy, Gradients, Weights.
3 - Loss, Accuracy, Gradients, Weights, Activations, Sparsity..
  • 裁剪梯度: float。裁剪梯度的選項。