原創(chuàng) | 谷歌JAX 助力科學(xué)計(jì)算
谷歌最新推出的JAX,官方定義為CPU、GPU和TPU上的NumPy。它具有出色的自動(dòng)微分(differentiation)功能,是可用于高性能機(jī)器學(xué)習(xí)研究的python庫。Numpy在科學(xué)計(jì)算領(lǐng)域十分普及,但是在深度學(xué)習(xí)領(lǐng)域,由于它不支持自動(dòng)微分和GPU加速,所以更多的是使用Tensorflow或Pytorch這樣的深度學(xué)習(xí)框架。然而谷歌之前推出的Tensorflow API有一些比較混亂的情況,在1.x的迭代中,就存在如原子op、layers等不同層次的API。面對(duì)不同類型的用戶,使用粒度不同的多層API本身并不是什么問題。但同層次的API也有多種競(jìng)品,如slim和layers等實(shí)則提高了學(xué)習(xí)成本和遷移成本。而JAX使用 XLA 在諸如GPU和TPU的加速器上編譯和運(yùn)行NumPy。它與 NumPy API 非常相似, numpy 完成的事情幾乎都可以用 jax.numpy 完成,從而避免了直接定義API這件事。
下面簡要介紹JAX的幾個(gè)特性,并同時(shí)給出一些示例讓讀者能夠快速入門上手。最后我們將結(jié)合科學(xué)計(jì)算的實(shí)例,展現(xiàn)google JAX在科學(xué)計(jì)算方面的巨大威力。
1.JAX特性
1)自動(dòng)微分:
在深度學(xué)習(xí)領(lǐng)域,網(wǎng)絡(luò)參數(shù)的優(yōu)化是通過基于梯度的反向傳播算法實(shí)現(xiàn)的。因此能夠?qū)崿F(xiàn)任意數(shù)值函數(shù)的微分對(duì)于機(jī)器學(xué)習(xí)有著十分重要的意義。下面結(jié)合官方文檔的例子簡要介紹這一特性。
首先介紹最簡單的grad求一階微分:可以直接通過grad函數(shù)求某一函數(shù)在某位置的梯度值
import jax.numpy as jnpfrom jax import grad, jit, vmapgrad_tanh = grad(jnp.tanh)print(grad_tanh(2.0))[OUT]:0.070650816
當(dāng)然如果想對(duì)雙切正弦函數(shù)繼續(xù)求二階,三階導(dǎo)數(shù),也可以這樣做:
print(grad(grad(jnp.tanh))(2.0))print(grad(grad(grad(jnp.tanh)))(2.0))[OUT]:-0.136218680.25265405
除此之外,還可以利用hessian、jacfwd 和 jacrev 等方法實(shí)現(xiàn)函數(shù)轉(zhuǎn)換,它們的功能分別是求解海森矩陣,以及利用前向或反向模式求解雅克比矩陣。Jacfwd和jacrev可以得到一樣的結(jié)果,但是在不同的情形下求解效率不同,這是因?yàn)閮烧弑澈髮?duì)應(yīng)的微分幾何中的push forward和pull back方法。而前面提到的grad則是基于反向模式。
在一些擬牛頓法的優(yōu)化算法中,常常需要利用二階的海森矩陣。為了實(shí)現(xiàn)海森矩陣的求解。為了實(shí)現(xiàn)這一目標(biāo),我們可以使用jacfwd(jacrev(f))或者jacrev(jacfwd(f))。但是前者的效率更高,因?yàn)閮?nèi)層的雅克比矩陣計(jì)算是通過類似于一個(gè)1維損失函數(shù)對(duì)n維向量的求導(dǎo),明顯使用反向模式更為合適。外層則通常是n維函數(shù)對(duì)n維向量的求導(dǎo),正向模式更有優(yōu)勢(shì)。
2)向量化
無論是科學(xué)計(jì)算或者機(jī)器學(xué)習(xí)的研究中,我們都會(huì)將定義的優(yōu)化目標(biāo)函數(shù)應(yīng)用到大量數(shù)據(jù)中,例如在神經(jīng)網(wǎng)絡(luò)中我們?nèi)ビ?jì)算每一個(gè)批次的損失函數(shù)值。JAX 通過 vmap 轉(zhuǎn)換實(shí)現(xiàn)自動(dòng)向量化,簡化了這種形式的編程。
下面結(jié)合幾個(gè)例子,說明這一用法:
vmap有3個(gè)最重要的參數(shù):
fun: 代表需要進(jìn)行向量化操作的具體函數(shù);
in_axes:輸入格式為元組,代表fun中每個(gè)輸入?yún)?shù)中,使用哪一個(gè)維度進(jìn)行向量化;
out_axes: 經(jīng)過fun計(jì)算后,每組輸出在哪個(gè)維度輸出。
我們先來看二維情況下的一些例子:
import jax.numpy as jnpimport numpy as npimport jax
(1)先定義a,b兩個(gè)二維數(shù)組(array)
a = np.array(([1,3],[23, 5]))print(a)[out]: [[ 1 3][23 5]]b = np.array(([11,7],[19,13]))print(b)[OUT]: [[11 7][19 13]]
(2)正常的兩個(gè)矩陣element-wise的相加
print(jnp.add(a,b))#[[1+11, 3+7]]# [[23+19, 5+13]][OUT]: [[12 10][42 18]]
(3)矩陣a的行 + 矩陣b的行,然后根據(jù)out_axes=0輸出,0表示行輸出
print(jax.vmap(jnp.add, in_axes=(0,0), out_axes=0)(a,b))#[[1+11, 3+7]]#[[23+19, 5+13]][OUT]: [[12 10][42 18]]
(4)矩陣a的行 + 矩陣b的行,然后根據(jù)out_axes=1輸出,1表示列輸出
print(jax.vmap(jnp.add, in_axes=(0,0), out_axes=1)(a,b))# [[1+11, 3+7]]#[[23+19, 5+13]] 再以列轉(zhuǎn)置輸出[OUT]: [[12 42][10 18]]
理解了上面的例子之后,現(xiàn)在開始增加難度,換成三維的例子:
from jax.numpy import jnpA, B, C, D = 2, 3, 4, 5def foo(tree_arg):x, (y, z) = tree_argreturn jnp.dot(x, jnp.dot(y, z))from jax import vmapK = 6 # batch sizex = jnp.ones((K, A, B)) # batch axis in different locationsy = jnp.ones((B, K, C))z = jnp.ones((C, D, K))tree = (x, (y, z))vfoo = vmap(foo, in_axes=((0, (1, 2)),))print(vfoo(tree).shape)
你能夠計(jì)算最后的輸出嗎?
讓我們一起來分析一下。在這段代碼中分別定義了三個(gè)全1矩陣x,y,z,他們的維度分別是6*2*3,3*6*4,4*5*6。而tree則控制了foo函數(shù)中矩陣連續(xù)點(diǎn)積的順序。根據(jù)in_axes可知,y和z的點(diǎn)積最后結(jié)果為6個(gè)3*5的子矩陣,這是由于y和z此時(shí)相當(dāng)于6個(gè)y的子矩陣(3*4維)和6個(gè)z的子矩陣(4*5維)點(diǎn)積。再與x點(diǎn)積,得到的最終結(jié)果為(6,2,5)。
3)JIT編譯
XLA是TensorFlow底層做JIT編譯優(yōu)化的工具,XLA可以對(duì)計(jì)算圖做算子Fusion,將多個(gè)GPU Kernel合并成少量的GPU Kernel,用以減少調(diào)用次數(shù),可以大量節(jié)省GPU Memory IO時(shí)間。Jax本身并沒有重新做執(zhí)行引擎層面的東西,而是直接復(fù)用TensorFlow中的XLA Backend進(jìn)行靜態(tài)編譯,以此實(shí)現(xiàn)加速。
jit的基本使用方法非常簡單,直接調(diào)用jax.jit()或使用@jax.jit裝飾函數(shù)即可:
import jax.numpy as jnpfrom jax import jitdef slow_f(x):# Element-wise ops see a large benefit from fusionreturn x * x + x * 2.0x = jnp.ones((5000, 5000))fast_f = jax.jit(slow_f) # 靜態(tài)編譯slow_f;%timeit -n10 -r3 fast_f(x)%timeit -n10 -r3 slow_f(x)10 loops, best of 3: 24.2 ms per loop10 loops, best of 3: 82.8 ms per loop
運(yùn)行時(shí)間結(jié)果:fast_f(x)是slow_f(x) 在CPU上運(yùn)行速度的3.5倍!靜態(tài)編譯大大加速了程序的運(yùn)行速度。如圖1 所示。
圖 1 tensorflow和JAX中的XLA backend
2.JAX在科學(xué)計(jì)算中的應(yīng)用
分子動(dòng)力學(xué)是現(xiàn)代計(jì)算凝聚態(tài)物理的重要力量。它經(jīng)常用于模擬材料。下面的實(shí)例將展現(xiàn)JAX在以分子動(dòng)力學(xué)為代表的科學(xué)計(jì)算領(lǐng)域的巨大潛力。
首先簡單介紹一下分子動(dòng)力學(xué)。分子動(dòng)力學(xué)的基本任務(wù)就是獲得研究對(duì)象在不同時(shí)刻的位置和速度,然后基于統(tǒng)計(jì)力學(xué)的知識(shí)獲取想得到的物理量,解釋對(duì)象的行為和性質(zhì)。
它的主要步驟包括:
第一步,設(shè)置研究對(duì)象組成粒子的初始位置和速度;第二步,基于粒子的位置計(jì)算每個(gè)粒子的合力,并基于牛頓第二定計(jì)算粒子的加速度。(這里可能有小伙伴會(huì)問,如何計(jì)算?我們下文的勢(shì)函數(shù)將為大家解釋);第三步,基于加速度算下一時(shí)刻粒子速度,根據(jù)速度計(jì)算下一時(shí)刻位置。
不斷循環(huán)2-3步,得到粒子的運(yùn)動(dòng)軌跡。
如需要獲得所有粒子的軌跡,根據(jù)牛頓運(yùn)動(dòng)方程,需要知道粒子的初始位置和速度,質(zhì)量以及受力。粒子的受力是勢(shì)能函數(shù)的負(fù)梯度,所以在分子動(dòng)力學(xué)模擬中,必須確定所有原子之間的勢(shì)能函數(shù),即勢(shì)能關(guān)于兩個(gè)原子之間相對(duì)位置的函數(shù),這個(gè)勢(shì)函數(shù)我們也稱之為力場(chǎng)。
在分子動(dòng)力學(xué)中,復(fù)雜力場(chǎng)的優(yōu)化是一類重要的問題。ReaxFF就是其中的代表。相比于傳統(tǒng)力場(chǎng)基于靜態(tài)化學(xué)鍵以及不隨化學(xué)環(huán)境改變的靜態(tài)電荷假設(shè),ReaxFF引入鍵級(jí)勢(shì)的概念,這允許鍵在整個(gè)模擬過程里形成和斷開,并動(dòng)態(tài)地為原子分配電荷。也正是由于這些特性的存在,反應(yīng)力場(chǎng)的形式明顯比經(jīng)典力場(chǎng)更為復(fù)雜。這使得我們將其計(jì)算的能量等值與密度泛函或者實(shí)驗(yàn)值對(duì)比得到的損失函數(shù)進(jìn)行反饋優(yōu)化時(shí)更為困難,如圖2 所示。
圖2 反應(yīng)力場(chǎng)的參數(shù)構(gòu)成
各種全局優(yōu)化方法,例如遺傳算法,模擬退火算法,進(jìn)化算法以及粒子群優(yōu)化算法等等往往沒有利用任何梯度信息,這使得這些搜索成本可能會(huì)非常昂貴。而JAX的出現(xiàn)為這一問題的解決帶來了可能。
JAX-REAXFF:
1)流程
圖3 Jax-ReaxFF流程
圖3是Jax-ReaxFF的任務(wù)流概述,可以將其大致分為兩個(gè)階段:聚類和主優(yōu)化循環(huán)。而主優(yōu)化循環(huán)則分別包括利用梯度信息的能量最小化和力場(chǎng)參數(shù)優(yōu)化。
聚類只要是根據(jù)相互作用列表進(jìn)行聚類,在內(nèi)存中正確對(duì)齊,以確保有效的單指令多數(shù)據(jù)(SIMD)并行化提高效率。
而主優(yōu)化循環(huán)中能量最小化的過程是尋找能量最低最穩(wěn)定幾何構(gòu)型的過程。它的具體做法是利用JAX求體系勢(shì)能對(duì)原子坐標(biāo)的梯度,進(jìn)行優(yōu)化。力場(chǎng)參數(shù)的優(yōu)化在原文中則分別使用了兩種擬牛頓優(yōu)化方法——L-BFGS和SLSQP。這通scipy.optimize.minimize函數(shù)實(shí)現(xiàn),其中向該函數(shù)直接傳入JAX求解梯度的方法以提高效率。能量最小化和力場(chǎng)參數(shù)優(yōu)化迭代循環(huán)。
圖4 JAX-ReaxFF主循環(huán)優(yōu)化
Github地址:https://github.com/cagrikymk/JAX-ReaxFF
2)效果
作者在多個(gè)數(shù)據(jù)集上分別實(shí)現(xiàn)了參數(shù)的優(yōu)化,可以看到相比于其他算法,利用JAX梯度信息的優(yōu)化具有明顯的速度優(yōu)勢(shì)。
圖5 金屬鈷數(shù)據(jù)集結(jié)果
參考文獻(xiàn):https://pubs.acs.org/doi/pdf/10.1021/acs.jctc.2c00363https://jax.readthedocs.io/en/latest/faq.htmlhttps://zhuanlan.zhihu.com/p/474724292https://arxiv.org/abs/2010.09063https://mp.weixin.qq.com/s/AoygUZK886RClDBnp1v3jw
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。