圖卷積和消息傳遞理論的可視化詳解
來源:Deephub Imba
假設現(xiàn)在需要設計治療某些疾病的****物。有一個其中包含成功治療疾病的****物和不起作用的****物數(shù)據(jù)集,現(xiàn)在需要設計一種新****,并且想知道它是否可以治療這種疾病。如果可以創(chuàng)建一個有意義的****物表示,就可以訓練一個分類器來預測它是否對疾病治療有用。我們的****物是分子式,可以用圖表表示。該圖的節(jié)點是原子。也可以用特征向量 x 來描述原子(它可以由原子屬性組成,如質(zhì)量、電子數(shù)或其他)。為了對分子進行分類,我們希望利用有關(guān)其空間結(jié)構(gòu)和原子特征的知識來獲得一些有意義的表示。
以圖形表示的分子示例。原子有它們的特征向量 X。特征向量中的索引表示節(jié)點索引。
最直接的方法是聚合特征向量,例如,簡單地取它們的平均值:
這是一個有效的解決方案,但它忽略了重要的分子空間結(jié)構(gòu)。
我們可以提出另一種想法:用鄰接矩陣表示分子圖,并用特征向量“擴展”其深度。我們得到了一個偽圖像 [8, 8, N],其中 N 是節(jié)點特征向量 x 的維數(shù)?,F(xiàn)在可以使用常規(guī)卷積神經(jīng)網(wǎng)絡并提取分子嵌入。
圖結(jié)構(gòu)可以表示為鄰接矩陣。節(jié)點特征可以表示為圖像中的通道(括號代表連接)。
這種方法利用了圖結(jié)構(gòu),但有一個巨大的缺點:如果改變節(jié)點的順序會得到不同的表示。所以這樣的表示不是置換不變量。但是鄰接矩陣中的節(jié)點順序是任意的, 例如,可以將列順序從 [0, 1, 2, 3, 4, 5, 6, 7] 更改為 [0, 2, 1, 3, 5, 4, 7, 6],它仍然是 圖的有效鄰接矩陣。所以可以創(chuàng)建所有可能的排列并將它們堆疊在一起,這會使我們有 1625702400 個可能的鄰接矩陣(8!* 8?。?shù)據(jù)量太大了,所以應該找到更好的解決方案。
但是問題是,我們?nèi)绾握峡臻g信息并有效地做到這一點?上面的例子可以讓我們想到卷積的概念,但它應該在圖上完成。
所以圖卷積就出現(xiàn)了
當對圖像應用常規(guī)卷積時會發(fā)生什么?相鄰像素的值乘以過濾器權(quán)重并相加。我們可以在圖表上做類似的事情嗎?是的,可以在矩陣 X 中堆疊節(jié)點特征向量并將它們乘以鄰接矩陣 A,然后得到了更新的特征 X`,它結(jié)合了有關(guān)節(jié)點最近鄰居的信息。為簡單起見,讓我們考慮一個具有標量節(jié)點特征的示例:
標量值節(jié)點特征的示例。僅針對節(jié)點 0 說明了 1 跳距離,但對于所有其他節(jié)點也是一樣的。
每個節(jié)點都會獲得有關(guān)其最近鄰居的信息(也稱為 1 跳距離)。鄰接矩陣上的乘法將特征從一個節(jié)點傳播到另一個節(jié)點。
在圖像域中可以通過增加濾波器大小來擴展感受野。在圖中則可以考慮更遠的鄰居。如果將 A^2 乘以 X——關(guān)于 2 跳距離節(jié)點的信息會傳播到節(jié)點:
節(jié)點 0 現(xiàn)在具有關(guān)于節(jié)點 2 的信息,該信息位于 2 跳距離內(nèi)。該圖僅針對節(jié)點 0 說明了躍點,但對于所有其他節(jié)點也是如此。
矩陣 A 的更高冪的行為方式相同:乘以 A^n 會導致特征從 n 跳距離節(jié)點傳播,所以可以通過將乘法添加到鄰接矩陣的更高次方來擴展“感受野”。為了概括這一操作,可以將節(jié)點更新的函數(shù)定義為具有某些權(quán)重 w 的此類乘法之和:
多項式圖卷積濾波器。A——圖鄰接矩陣,w——標量權(quán)重,x——初始節(jié)點特征,x'——更新節(jié)點特征。
新特征 x' 是來自 n 跳距離的節(jié)點的某種混合,相應距離的影響由權(quán)重 w 控制。這樣的操作可以被認為是一個圖卷積,濾波器 P 由權(quán)重 w 參數(shù)化。與圖像上的卷積類似,圖卷積濾波器也可以具有不同的感受野并聚合有關(guān)節(jié)點鄰居的信息,但鄰居的結(jié)構(gòu)不像圖像中的卷積核那樣規(guī)則。
這樣的多項式與一般卷積一樣是置換等變性的??梢允褂脠D拉普拉斯算子而不是鄰接矩陣來傳遞特征差異而不是節(jié)點之間的特征值(也可以使用標準化的鄰接矩陣)。
將圖卷積表示為多項式的能力可以從一般的譜圖卷積( spectral graph convolutions)中推導出來。例如,利用帶有圖拉普拉斯算子的切比雪夫多項式的濾波器提供了直接譜圖卷積的近似值 [1]。
并且可以輕松地將其推廣到具有相同方程的節(jié)點特征的任何維度上。但在更高維度的情況下,處理的是節(jié)點特征矩陣 X 而不是節(jié)點特征向量。例如,對于 N 個節(jié)點和節(jié)點中的 1 或 M 個特征,我們得到:
x——節(jié)點特征向量,X——堆疊節(jié)點特征,M——節(jié)點特征向量的維度,N——節(jié)點數(shù)量。
可以將特征向量的“深度”維度視為圖像卷積中的“通道”。
現(xiàn)在用另外一種不同的方式看看上面的討論。繼續(xù)采用上面討論的一個簡單的多項式卷積,只有兩個第一項,讓 w 等于 1:
現(xiàn)在如果將圖特征矩陣 X 乘以 (I + A) 可以得到以下結(jié)果:
對于每個節(jié)點,都添加了相鄰節(jié)點的總和。因此該操作可以表示如下:
N(i) 表示節(jié)點 i 的一跳距離鄰居。
在這個例子中,“update”和“aggregate”只是簡單的求和函數(shù)。
這種關(guān)于節(jié)點特征更新被稱為消息傳遞機制。這樣的消息傳遞的單次迭代等效于帶有過濾器 P= I + A 的圖卷積。那么如果想從更遠的節(jié)點傳播信息,我們可以再次重復這樣的操作幾次,從而用更多的多項式項逼近圖卷積。
但是需要注意的是:如果重復多次圖卷積,可能會導致圖過度平滑,其中每個節(jié)點嵌入對于所有連接的節(jié)點都變成相同的平均向量。
那么如何增強消息傳遞的表達能力?可以嘗試聚合和更新函數(shù),并額外轉(zhuǎn)換節(jié)點特征:
W1——更新節(jié)點特征的權(quán)重矩陣,W2——更新相鄰節(jié)點特征的權(quán)重矩陣。
可以使用任何排列不變函數(shù)進行聚合,例如 sum、max、mean 或更復雜的函數(shù),例如 DeepSets。
例如,評估消息傳遞的基本方法之一是 GCN 層:
第一眼看到這個公式可能并不熟悉,但讓我們使用“更新”和“聚合”函數(shù)來看看它:
使用單個矩陣 W 代替兩個權(quán)重矩陣 W1 和 W2。更新函數(shù)是求和,聚合函數(shù)是歸一化節(jié)點特征的總和,包括節(jié)點特征 i。d——表示節(jié)點度。
這樣就使用一個權(quán)重矩陣 W 而不是兩個,并使用 Kipf 和 Welling 歸一化求和作為聚合,還有一個求和作為更新函數(shù)。聚合操作評估鄰居和節(jié)點 i 本身,這相當于將自循環(huán)( self-loops)添加到圖中。
所以具有消息傳遞機制的 GNN 可以表示為多次重復的聚合和更新函數(shù)。消息傳遞的每次迭代都可以被視為一個新的 GNN 層。節(jié)點更新的所有操作都是可微的,并且可以使用可以學習的權(quán)重矩陣進行參數(shù)化?,F(xiàn)在我們可以構(gòu)建一個圖卷積網(wǎng)絡并探索它是如何執(zhí)行的。
使用上面提到的 GCN 層構(gòu)建和訓練圖神經(jīng)網(wǎng)絡。對于這個例子,我將使用 PyG 庫和 [2] 中提供的 AIDS 圖數(shù)據(jù)集。它由 2000 個代表分子化合物的圖表組成:其中 1600 個被認為對 HIV 無活性,其中 400 個對 HIV 有活性。每個節(jié)點都有一個包含 38 個特征的特征向量。以下是數(shù)據(jù)集中分子圖表示的示例:
使用 networkx 庫可視化來自 AIDS 數(shù)據(jù)集的樣本。
為簡單起見,我們將構(gòu)建一個只有 3 個 GCN 層的模型。嵌入空間可視化的最終嵌入維度將是 2-d。為了獲得圖嵌入,將使用均值聚合。為了對分子進行分類,將在圖嵌入之后使用一個簡單的線性分類器。
具有三個 GCN 層、平均池化和線性分類器的圖神經(jīng)網(wǎng)絡。
對于第一次消息傳遞的迭代(第 1 層),初始特征向量被投影到 256 維空間。在第二個消息傳遞期間(第 2 層),特征向量在同一維度上更新。在第三次消息傳遞(第 3 層)期間,特征被投影到二維空間,然后對所有節(jié)點特征進行平均以獲得最終的圖嵌入。最后,這些嵌入被輸送到線性分類器。選擇二維維度只是為了可視化,更高的維度肯定會更好。這樣的模型可以使用 PyG 庫來實現(xiàn):
from torch import nn
from torch.nn import functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
class GCNModel(nn.Module):
def __init__(self, feature_node_dim=38, num_classes=2, hidden_dim=256, out_dim=2):
super(GCNModel, self).__init__()
torch.manual_seed(123)
self.conv1 = GCNConv(feature_node_dim, hidden_dim)
self.conv2 = GCNConv(hidden_dim, hidden_dim)
self.conv3 = GCNConv(hidden_dim, out_dim)
self.linear = nn.Linear(out_dim, num_classes)
def forward(self, x, edge_index, batch):
# Graph convolutions with nonlinearity:
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
x = F.relu(x)
x = self.conv3(x, edge_index)
# Graph embedding:
x_embed = global_mean_pool(x, batch)
# Linear classifier:
x = self.linear(x_embed)
return x, x_embed
在其訓練期間,可以可視化圖嵌入和分類器決策邊界??梢钥吹较鬟f操作如何使僅使用 3 個圖卷積層的生成有意義的圖嵌入的。這里使用隨機初始化的模型嵌入并沒有線性可分分布:
上圖是對隨機初始化的模型進行正向傳播得到的分子嵌入
但在訓練過程中,分子嵌入很快變成線性可分:
即使是 3 個圖卷積層也可以生成有意義的二維分子嵌入,這些嵌入可以使用線性模型進行分類,在驗證集上具有約 82% 的準確度。
在本文中介紹了圖卷積如何表示為多項式,以及如何使用消息傳遞機制來近似它。這種具有附加特征變換的方法具有強大的表示能力。本文中僅僅觸及了圖卷積和圖神經(jīng)網(wǎng)絡的皮毛。圖卷積層和聚合函數(shù)有十幾種不同的體系結(jié)構(gòu)。并且在圖上能夠完成的任務任務也很多,如節(jié)點分類、邊緣重建等。所以如果想深入挖掘,PyG教程是一個很好的開始。
*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權(quán)請聯(lián)系工作人員刪除。
電氣符號相關(guān)文章:電氣符號大全
萬用表相關(guān)文章:萬用表怎么用
可控硅相關(guān)文章:可控硅工作原理
手機電池相關(guān)文章:手機電池修復
pa相關(guān)文章:pa是什么
晶體管相關(guān)文章:晶體管工作原理
電荷放大器相關(guān)文章:電荷放大器原理 晶體管相關(guān)文章:晶體管原理 調(diào)速器相關(guān)文章:調(diào)速器原理