Google在今年的I/O大會,發布數項TensorFlow與Keras深度學習工具更新,重點包括可讓開發者能夠簡單存取預訓練模型的模組化函式庫,還推出可用於同步分散式模型運算的擴充套件DTensor,而藉由新的JAX2TF API,開發者便能夠在TensorFlow生態系中,使用JAX數值函式庫編寫的模型。

模組化函式庫KerasCV與KerasNLP,可簡化電腦視覺和自然語言處理預訓練模型存取。這兩個新的模組化函式庫,讓開發者只要編寫幾行程式碼,就可在應用程式中整合圖像分類或是文字生成等機器學習功能。由於這兩個函式庫皆是Keras的一部分,而Keras在TensorFlow 2.0成為內建進階API,因此開發者能夠直接在TensorFlow中使用Keras,這也就代表KerasCV、KerasNLP與TensorFlow生態系可完全整合。

TensorFlow擴充套件DTensor透過組合並微調多種平行技術,以支援更大且高效能的模型訓練。以往機器學習開發人員可以透過資料平行技術擴展模型,將資料拆分之後,供水平擴展的模型實例訓練使用,不過這種擴展訓練方法有一個嚴重的限制,即是要求模型在單個硬體裝置執行。

但隨著模型越來越大,單一裝置的運算能力可能不足以處理龐大的模型,因此開發者開始需要將模型擴展到更多硬體裝置上執行。也就是說訓練龐大模型不僅需要資料平行性,還需要模型平行性,將模型分割成可以平行訓練的分片。

而DTensor不只支援資料平行性,也提供模型平行性,透過結合這兩種技術,更有效地擴展模型,同時DTensor也不受加速器類型的限制,支援TPU、GPU等各種運算裝置。

Google也釋出輕量級API JAX2TF,來加速機器學習研究生產化的速度。Google開發的Python函式庫JAX被大量用於高效數值運算上,同時JAX也支援硬體加速,能夠在GPU或TPU上高速處理大型資料集和複雜運算,但要把JAX用於生產中仍不是一件直覺簡單的事。

而JAX2TF API的出現,是要讓JAX能夠更簡單地進入TensorFlow生態系,使開發者可以將JAX模型部署到TensorFlow Serving伺服器或是TFLite裝置上,並在TensorFlow中繼續訓練JAX模型,甚至是將JAX模型和TensorFlow模型融合,以獲得更大的靈活性。

除了以上更新,開發團隊也預告,他們即將推出TensorFlow量化API,該API將會是TensorFlow 2的原生量化工具,能夠在不影響模型品質的前提下,進一步縮小模型,並且提升模型執行速度。

熱門新聞

Advertisement