偷偷摘套内射激情视频,久久精品99国产国产精,中文字幕无线乱码人妻,中文在线中文a,性爽19p

PyTorch Geometric框架下圖神經(jīng)網(wǎng)絡(luò)的可解釋性機(jī)制:原理、實(shí)現(xiàn)與評(píng)估

人工智能 機(jī)器學(xué)習(xí)
在機(jī)器學(xué)習(xí)領(lǐng)域存在一個(gè)普遍的認(rèn)知誤區(qū),即可解釋性與準(zhǔn)確性存在對(duì)立關(guān)系。這種觀點(diǎn)認(rèn)為可解釋模型在復(fù)雜度上存在固有限制,因此無(wú)法達(dá)到最優(yōu)性能水平,神經(jīng)網(wǎng)絡(luò)之所以能夠在各個(gè)領(lǐng)域占據(jù)主導(dǎo)地位,正是因?yàn)槠涑搅巳祟?lèi)可理解的范疇。

在機(jī)器學(xué)習(xí)領(lǐng)域存在一個(gè)普遍的認(rèn)知誤區(qū),即可解釋性與準(zhǔn)確性存在對(duì)立關(guān)系。這種觀點(diǎn)認(rèn)為可解釋模型在復(fù)雜度上存在固有限制,因此無(wú)法達(dá)到最優(yōu)性能水平,神經(jīng)網(wǎng)絡(luò)之所以能夠在各個(gè)領(lǐng)域占據(jù)主導(dǎo)地位,正是因?yàn)槠涑搅巳祟?lèi)可理解的范疇。

其實(shí)這種觀點(diǎn)存在根本性的謬誤。研究表明,黑盒模型在高風(fēng)險(xiǎn)決策場(chǎng)景中往往表現(xiàn)出準(zhǔn)確性不足的問(wèn)題[1],[2],[3]。因此模型的不可解釋性應(yīng)被視為一個(gè)需要克服的缺陷,而非獲得高準(zhǔn)確性的必要條件。這種缺陷既非必然,也非不可避免,在構(gòu)建可靠的決策系統(tǒng)時(shí)必須得到妥善解決。

解決此問(wèn)題的關(guān)鍵在于可解釋性??山忉屝允侵改P途邆湎蛉祟?lèi)展示其決策過(guò)程的能力[4]。模型需要能夠清晰地展示哪些輸入數(shù)據(jù)、特征或參數(shù)對(duì)其預(yù)測(cè)結(jié)果產(chǎn)生了影響,從而實(shí)現(xiàn)決策過(guò)程的透明化。

PyTorch Geometric的可解釋性模塊為圖機(jī)器學(xué)習(xí)模型提供了一套完整的可解釋性工具[5]。該模塊具有以下核心功能:

  1. 關(guān)鍵圖特性識(shí)別 — 能夠識(shí)別并突出顯示對(duì)模型預(yù)測(cè)具有重要影響的節(jié)點(diǎn)、邊和特征。
  2. 圖結(jié)構(gòu)定制與隔離 — 通過(guò)特定圖組件的掩碼操作或關(guān)注區(qū)域的界定,實(shí)現(xiàn)針對(duì)性的解釋生成。
  3. 圖特性可視化 — 提供多種可視化方法,包括帶有邊權(quán)重透明度的子圖展示和top-k特征重要性條形圖等。
  4. 評(píng)估指標(biāo)體系 — 提供多維度的定量評(píng)估方法,用于衡量解釋的質(zhì)量。

可解釋性模塊的系統(tǒng)架構(gòu)圖:

我們下面使用Reddit數(shù)據(jù)集來(lái)進(jìn)行詳細(xì)的描述。

數(shù)據(jù)集

我們選用Reddit數(shù)據(jù)集作為實(shí)驗(yàn)數(shù)據(jù)。該數(shù)據(jù)集是一個(gè)包含不同社區(qū)Reddit帖子的標(biāo)準(zhǔn)基準(zhǔn)數(shù)據(jù)集,可通過(guò)PyTorch Geometric提供的公開(kāi)數(shù)據(jù)集倉(cāng)庫(kù)直接訪(fǎng)問(wèn)。

Reddit數(shù)據(jù)集的規(guī)模較大,包含232,965個(gè)節(jié)點(diǎn)、114,615,892條邊,每個(gè)節(jié)點(diǎn)具有602維特征,共涉及41個(gè)分類(lèi)類(lèi)別??紤]到數(shù)據(jù)集規(guī)模,我們采用NeighborLoader類(lèi)實(shí)現(xiàn)小批量處理。該類(lèi)提供了一種高效的采樣機(jī)制,可以對(duì)大規(guī)模圖數(shù)據(jù)集中的節(jié)點(diǎn)及其k-跳鄰域進(jìn)行小批量采樣。所以設(shè)置了三個(gè)NeighborLoader實(shí)例,分別用于訓(xùn)練、測(cè)試和可解釋性分析。num_neighbors和batch_size參數(shù)可根據(jù)系統(tǒng)資源情況進(jìn)行調(diào)整。

# 數(shù)據(jù)集加載與預(yù)處理
 dataset = Reddit(root="/tmp/Reddit")  
 data = dataset[0]  
   
 train_loader = NeighborLoader(  
         data,  
         input_nodes=data.train_mask,  
         # a=第一層鄰居采樣數(shù)量
         # b=第二層鄰居采樣數(shù)量
         num_neighbors=[a, b]  
         batch_size=batch_size,  
         shuffle=True  
    )  
   
 test_loader = NeighborLoader(  
         data,  
         input_nodes=data.test_mask,  
         num_neighbors=num_neighbors,  
         batch_size=batch_size,  
         shuffle=False  # 測(cè)試階段保持順序以確??芍貜?fù)性
    )  
   
 explain_loader = NeighborLoader(  
     data,  
     batch_size=batch_size,  
     num_neighbors=num_neighbors,  
     shuffle=True  
 )

GraphSAGE

我們采用GraphSAGE作為基礎(chǔ)模型架構(gòu)。GraphSAGE是一個(gè)專(zhuān)為歸納學(xué)習(xí)設(shè)計(jì)的圖神經(jīng)網(wǎng)絡(luò)框架,其特點(diǎn)是能夠?qū)㈩A(yù)測(cè)能力泛化到未見(jiàn)過(guò)的節(jié)點(diǎn)。模型的高效鄰居采樣機(jī)制使其特別適合處理Reddit這樣的大規(guī)模圖數(shù)據(jù)集。以下代碼展示了模型的核心結(jié)構(gòu)及其訓(xùn)練、測(cè)試方法的實(shí)現(xiàn)。

# GNN模型定義
 class SAGE(torch.nn.Module):  
     def __init__(self, in_channels, hidden_channels, out_channels):  
         super().__init__()  
         self.convs = torch.nn.ModuleList()  
         # 構(gòu)建雙層網(wǎng)絡(luò)結(jié)構(gòu)
         self.convs.append(SAGEConv(in_channels, hidden_channels))  
         self.convs.append(SAGEConv(hidden_channels, out_channels))  
   
     def forward(self, x, edge_index):  
         for i, conv in enumerate(self.convs):  
             x = conv(x, edge_index)  
             if i < len(self.convs) - 1:  
                 x = F.relu(x)  
                 x = F.dropout(x, p=0.5, training=self.training)  
         return x

模型訓(xùn)練實(shí)現(xiàn)

# 訓(xùn)練過(guò)程實(shí)現(xiàn)
 def train(model, loader, optimizer, device, num_train_nodes):  
     model.train()  
     total_loss = 0  
     total_correct = 0  
   
     for batch in tqdm(loader, desc="Training"):  
         # 數(shù)據(jù)遷移至指定計(jì)算設(shè)備
         batch = batch.to(device)  
   
         # 前向傳播計(jì)算
         optimizer.zero_grad()  
         out = model(batch.x, batch.edge_index)  
   
         # 損失計(jì)算與反向傳播
         loss = F.cross_entropy(out[batch.train_mask], batch.y[batch.train_mask])  
         loss.backward()  
         optimizer.step()  
   
         # 計(jì)算當(dāng)前批次訓(xùn)練節(jié)點(diǎn)的預(yù)測(cè)準(zhǔn)確率
         pred = out[batch.train_mask].argmax(dim=-1)  
         total_correct += int((pred == batch.y[batch.train_mask]).sum())  
         total_loss += loss.item()  
   
     return total_loss / len(loader), total_correct / num_train_nodes

模型評(píng)估實(shí)現(xiàn)

# 測(cè)試過(guò)程實(shí)現(xiàn)
 def test(model, loader, device):  
     model.eval()  
     total_correct = 0  
     total_test_nodes = 0  
   
     for batch in tqdm(loader, desc="Testing"):  
         batch = batch.to(device)  
   
         # 預(yù)測(cè)計(jì)算
         with torch.no_grad():  
             out = model(batch.x, batch.edge_index)  
             pred = out.argmax(dim=-1)  
   
         # 評(píng)估測(cè)試節(jié)點(diǎn)的預(yù)測(cè)準(zhǔn)確率
         mask = batch.test_mask  
         total_correct += int((pred[mask] == batch.y[mask]).sum())  
         total_test_nodes += mask.sum().item()  
   
     # 計(jì)算整體測(cè)試準(zhǔn)確率
     accuracy = total_correct / total_test_nodes  
     return accuracy

Explainer模塊配置

要啟用可解釋性分析功能,首先需要完成Explainer的初始化配置。以下是相關(guān)參數(shù)的詳細(xì)說(shuō)明:

model: torch.nn.Module,  
 algorithm: ExplainerAlgorithm,  
 explanation_type: Union[ExplanationType, str],  
 node_mask_type: Optional[Union[MaskType, str]] = None,  
 edge_mask_type: Optional[Union[MaskType, str]] = None,  
 model_config: Union[ModelConfig, Dict[str, Any]],  
 threshold_config: Optional[ThresholdConfig] = None

下面對(duì)各參數(shù)進(jìn)行詳細(xì)說(shuō)明:

**model: torch.nn.Module** — 指定需要進(jìn)行可解釋性分析的PyG模型實(shí)例。

**algorithm: ExplainerAlgorithm** — 可選的解釋器算法:

這里主要要使用_GNNExplainer

  • DummyExplainer: 用于生成隨機(jī)解釋的基準(zhǔn)測(cè)試器
  • GNNExplainer: 基于"GNNExplainer: Generating Explanations for Graph Neural Networks"論文實(shí)現(xiàn)[6]
  • CaptumExplainer: 集成Captum開(kāi)源庫(kù)的解釋器[7]
  • PGExplainer: 基于"Parameterized Explainer for Graph Neural Network"論文實(shí)現(xiàn)[8]
  • AttentionExplainer: 基于注意力機(jī)制的解釋器[9]
  • GraphMaskExplainer: 基于Interpreting Graph Neural Networks for NLP With Differentiable Edge Masking論文實(shí)現(xiàn)[10]

**explanation_type: Union[ExplanationType, str]** — 解釋類(lèi)型配置,包含兩種選項(xiàng):

"model": 針對(duì)模型預(yù)測(cè)機(jī)制的解釋

調(diào)用Explainer時(shí)可通過(guò)index參數(shù)指定待解釋的節(jié)點(diǎn)、邊或圖的索引,實(shí)現(xiàn)精確定位分析。

"phenomenon": 針對(duì)數(shù)據(jù)內(nèi)在特征的解釋

調(diào)用時(shí)需要通過(guò)target參數(shù)指定包含所有節(jié)點(diǎn)真實(shí)標(biāo)簽的張量。這使得Explainer能夠比對(duì)模型預(yù)測(cè)與真實(shí)標(biāo)簽,從而識(shí)別圖中對(duì)模型決策過(guò)程最具影響力的組件(節(jié)點(diǎn)、邊或特征),并評(píng)估其與真實(shí)數(shù)據(jù)分布的一致性。

mask_type參數(shù)配置

**node_mask_type: Optional[Union[MaskType, str]] = None**

**edge_mask_type: Optional[Union[MaskType, str]] = None**

提供四種掩碼策略:

  1. None: 不進(jìn)行掩碼處理
  2. "object": 整體掩碼策略,每次掩碼一個(gè)完整的節(jié)點(diǎn)/邊
  3. "common_attributes": 全局特征掩碼,對(duì)所有節(jié)點(diǎn)/邊的指定特征進(jìn)行掩碼
  4. "attributes": 局部特征掩碼,僅對(duì)指定節(jié)點(diǎn)/邊的特定特征進(jìn)行掩碼

**model_config: Union[ModelConfig, Dict[str, Any]]** — 模型配置參數(shù)集

主要包括:

1.mode: 預(yù)測(cè)任務(wù)類(lèi)型配置,可選值包括:'binary_classification'、'multiclass_classification'或'regression'

2.task_level: 預(yù)測(cè)任務(wù)級(jí)別,可選值包括:'node'、'edge'或'graph'

3.return_type: 模型輸出格式配置,可選值包括:'probs'、'log_probs'或'raw'

**threshold_config: Optional[ThresholdConfig]** — 閾值控制參數(shù),用于精確控制掩碼應(yīng)用的范圍和方式。

1.threshold_type: 閾值類(lèi)型配置,包含以下選項(xiàng):

  • None: 保持原始狀態(tài),保留所有重要性分?jǐn)?shù)
  • "hard": 采用固定閾值截?cái)嗖呗?,將低于指定值的重要性分?jǐn)?shù)置零
  • "topk": 保留重要性分?jǐn)?shù)最高的k個(gè)元素(節(jié)點(diǎn)、邊或特征),其余置零
  • "topk_hard": 類(lèi)似于"topk",但將保留元素的重要性分?jǐn)?shù)統(tǒng)一設(shè)為1,實(shí)現(xiàn)二值化表示

2.value: 閾值參數(shù)設(shè)置

  • 對(duì)于threshold_type = "hard",value取值范圍為[0,1]
  • 對(duì)于threshold_type = "topk"或"topk_hard",value表示保留的元素?cái)?shù)量k

閾值參數(shù)配置的關(guān)鍵考慮:

  • k值過(guò)小可能導(dǎo)致重要信息丟失
  • k值過(guò)大可能引入噪聲信息
  • 存在性能指標(biāo)發(fā)生突變的臨界閾
  • 最優(yōu)閾值的確定通常需要針對(duì)具體應(yīng)用場(chǎng)景進(jìn)行實(shí)驗(yàn)驗(yàn)證

Explainer調(diào)用實(shí)現(xiàn)

Explainer的調(diào)用需要配置以下參數(shù):

x: Union[Tensor, Dict[str, Tensor]],  
 edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]],  
 target: Optional[Tensor] = None,  
 index: Optional[Union[int, Tensor]] = None

各參數(shù)說(shuō)明:

  • x: 節(jié)點(diǎn)特征矩陣(對(duì)應(yīng)data.x或batch.x)
  • edge_index: 邊索引張量(對(duì)應(yīng)data.edge_index或batch.edge_index)
  • target: 真實(shí)標(biāo)簽張量(對(duì)應(yīng)data.y或batch.y)
  • index: 指定待解釋的節(jié)點(diǎn)、邊或圖的索引,可以是單個(gè)整數(shù)、整數(shù)張量或None(表示解釋所有輸出)

實(shí)例分析

假設(shè)模型將索引為x=10的帖子分類(lèi)到某個(gè)特定subreddit,我們可以分析這一預(yù)測(cè)的依據(jù),確定哪些特征對(duì)該預(yù)測(cè)結(jié)果產(chǎn)生了關(guān)鍵影響。下面展示如何初始化和調(diào)用Explainer來(lái)實(shí)現(xiàn)這一分析:

index = 143  
   
 model_explainer = Explainer(  
     model=model,  
     algorithm=GNNExplainer(epochs=50),  
     explanation_type='model',  
     node_mask_type='attributes',  
     model_config=dict(  
         mode='multiclass_classification',  
         task_level='node',  
         return_type='log_probs',  
    )  
     threshold_config=dict(threshold_type='topk', value=20)  
 )

說(shuō)明:

  • 選擇explanation_type='model'用于分析模型的預(yù)測(cè)機(jī)制
  • 設(shè)置node_mask_type='attributes'以研究特征重要性,同時(shí)保持node_edge_type=None以專(zhuān)注于節(jié)點(diǎn)分析
  • model_config配置反映了數(shù)據(jù)集特點(diǎn):41個(gè)類(lèi)別的多分類(lèi)問(wèn)題(mode = 'multiclass_classification'),節(jié)點(diǎn)級(jí)預(yù)測(cè)任務(wù)(task_level = 'node'),使用對(duì)數(shù)概率輸出(return_type = 'log_probs')
  • threshold_config設(shè)置為保留最重要的20個(gè)節(jié)點(diǎn)(threshold_type='topk', value=20)

執(zhí)行分析:

model_explanation = model_explainer(  
     batch.x,  
     batch.edge_index,  
     index=index  
 )

由于設(shè)置了explanation_type = 'model',此處無(wú)需指定target參數(shù),執(zhí)行完成后返回Explanation對(duì)象,包含完整的解釋結(jié)果

Explanation類(lèi)封裝了可解釋性模塊產(chǎn)生的關(guān)鍵分析信息[11]。其結(jié)構(gòu)設(shè)計(jì)如下:

x: Optional[Tensor] = None,  
 edge_index: Optional[Tensor] = None,  
 edge_attr: Optional[Tensor] = None,  
 y: Optional[Union[Tensor, int, float]] = None,  
 pos: Optional[Tensor] = None,  
 time: Optional[Tensor] = None

核心屬性說(shuō)明:

  • x: 節(jié)點(diǎn)特征矩陣,維度為[num_nodes, num_features]
  • edge_index: 邊索引矩陣,維度為[2, num_edges]
  • edge_attr: 邊特征矩陣,維度為[num_edges, num_edge_features]
  • y: 真實(shí)標(biāo)簽,可以是回歸問(wèn)題的目標(biāo)值或分類(lèi)問(wèn)題的類(lèi)別標(biāo)簽
  • pos: 節(jié)點(diǎn)空間坐標(biāo)矩陣,維度為[num_nodes, num_dimension]
  • time: 時(shí)序信息張量,格式根據(jù)具體時(shí)間特征定義(如,time = [2022, 2023, 2024]表示節(jié)點(diǎn)0-2的時(shí)間戳)

解釋結(jié)果分析方法

預(yù)測(cè)行為分析

以下代碼用于獲取模型的初始預(yù)測(cè)結(jié)果:

model.eval()  
 with torch.no_grad():  
     predictions = model_explainer.get_prediction(batch.x, batch.edge_index)

要分析特定圖屬性掩碼對(duì)預(yù)測(cè)的影響,可使用get_masked_prediction方法。例如,分析掩碼節(jié)點(diǎn)5對(duì)預(yù)測(cè)的影響:

# 構(gòu)建掩碼矩陣
 node_mask = torch.ones_like(batch.x)  
 node_mask[5] = 0  # 對(duì)節(jié)點(diǎn)5進(jìn)行掩碼處理
   
 with torch.no_grad():  
     masked_predictions = model_explainer.get_masked_prediction(batch.x, batch.edge_index, node_mask=node_mask)

進(jìn)行預(yù)測(cè)差異分析:

difference = predictions - masked_predictions  
 mean_difference = difference.mean(dim=0).cpu().numpy()  
   
 plt.figure(figsize=(10, 6))  
 plt.plot(mean_difference, color="olive", label="Mean Difference")  
 plt.title('原始預(yù)測(cè)與掩碼預(yù)測(cè)的差異分析')  
 plt.xlabel('類(lèi)別')  
 plt.ylabel('Logits差異均值')  
 plt.legend()  
 plt.show()

該圖展示了節(jié)點(diǎn)5掩碼對(duì)各類(lèi)別預(yù)測(cè)logits的平均影響。正值表示掩碼導(dǎo)致該類(lèi)別的預(yù)測(cè)概率增加,負(fù)值則表示減少。這種可視化有助于理解特定節(jié)點(diǎn)對(duì)模型決策的影響程度和方向。

除了均值分析,還可以采用其他評(píng)估指標(biāo),如:

  • 絕對(duì)差異
  • 相對(duì)差異
  • 均方誤差(MSE)
  • 自定義評(píng)估指標(biāo)

關(guān)鍵子圖提取

為了深入分析圖結(jié)構(gòu)中的重要組件,可以使用以下方法:

get_explanation_subgraph():提取對(duì)解釋具有非零重要性的節(jié)點(diǎn)和邊,返回一個(gè)新的Explanation對(duì)象。這有助于隔離對(duì)預(yù)測(cè)最具影響力的圖結(jié)構(gòu)組件。

get_complement_subgraph():提取重要性為零的節(jié)點(diǎn)和邊,返回一個(gè)新的Explanation對(duì)象。這有助于理解模型認(rèn)為不重要的圖結(jié)構(gòu)部分。

這些方法的主要價(jià)值在于能夠分離和聚焦于感興趣的圖結(jié)構(gòu)組件,尤其是get_explanation_subgraph()可以有效降低來(lái)自無(wú)關(guān)節(jié)點(diǎn)和邊的干擾。

關(guān)鍵特征提取

下代碼展示了如何提取影響節(jié)點(diǎn)預(yù)測(cè)的關(guān)鍵特征。這段代碼改編自visualize_feature_importance方法

node_mask = model_explanation.get('node_mask')  
 if node_mask is None:  
     raise ValueError(f"The attribute 'node_mask' is not available "  
                       f"in '{model_explanation.__class__.__name__}' "  
                       f"(got {model_explanation.available_explanations})")  
 if node_mask.dim() != 2 or node_mask.size(1) <= 1:  
     raise ValueError(f"Cannot compute feature importance for "  
                       f"object-level 'node_mask' "  
                       f"(got shape {node_mask.size()})")  
   
 score = node_mask.sum(dim=0)  
 non_zero_indices = torch.nonzero(score, as_tuple=True)[0]  
 non_zero_scores = score[non_zero_indices]  
   
 # 特征重要性排序
 sorted_indices = non_zero_indices[torch.argsort(non_zero_scores, descending=True)]  
 print(sorted_indices)

輸出示例:

tensor([555, 474,  43, 210, 446, 158, 516, 273, 417, 531], device='cuda:0')

該實(shí)現(xiàn)的關(guān)鍵步驟:

  1. 計(jì)算每個(gè)特征在所有節(jié)點(diǎn)上的累積重要性
  2. 篩選出具有非零重要性的特征
  3. 特征列表的長(zhǎng)度由Explainer初始化時(shí)的ThresholdConfig決定(示例中為10,因?yàn)樵O(shè)置了threshold_config = dict(threshold_type='topk', value=10)

解釋結(jié)果可視化

圖結(jié)構(gòu)可視化

visualize_graph方法用于直觀展示對(duì)模型預(yù)測(cè)有影響的節(jié)點(diǎn)和邊。該方法的一個(gè)重要特性是通過(guò)邊的不透明度表示其重要性(不透明度越高表示重要性越大)。需要注意的是,使用此方法時(shí)Explainer不能設(shè)置edge_mask_type=None

方法定義:

visualize_graph(path: Optional[str] = None,  
                 backend: Optional[str] = None,  
                 node_labels: Optional[List[str]] = None)

參數(shù)說(shuō)明:

  • path: 可視化結(jié)果保存路徑
  • backend: 可視化后端選擇,支持graphviz或networkx
  • node_label: 節(jié)點(diǎn)標(biāo)識(shí)符列表

下面通過(guò)兩個(gè)示例展示不同配置下的可視化效果:

示例1:基礎(chǔ)特征屬性分析

配置:node_mask_type='attributes',不設(shè)置閾值

visual_explainer_1 = Explainer(  
     model=model,  
     algorithm=GNNExplainer(epochs=50),  
     explanation_type='model',  
     node_mask_type='attributes',  
     edge_mask_type='object',  
     model_config=dict(  
         mode='multiclass_classification',  
         task_level='node',  
         return_type='log_probs',  
    )  
 )  
   
 index = 143  
   
 visual_explanation_1 = visual_explainer_1(  
     batch.x,  
     batch.edge_index,  
     index=index  
 )

生成可視化結(jié)果:

visual_explanation_1.visualize_graph('visual_graph_1.png', backend="graphviz")

可視化結(jié)果展示了與節(jié)點(diǎn)143相連的所有節(jié)點(diǎn),這些節(jié)點(diǎn)的特征都對(duì)節(jié)點(diǎn)143的預(yù)測(cè)產(chǎn)生了影響。圖中邊的不透明度差異反映了不同連接對(duì)預(yù)測(cè)結(jié)果的影響程度。由于未設(shè)置閾值,可視化結(jié)果包含了較多的節(jié)點(diǎn)和邊,這有助于全面理解模型的決策過(guò)程,但可能不夠聚焦。

示例2:重要性篩選分析

配置:node_mask_type='attributes',threshold_cnotallow=dict(threshold_type='topk', value=10),edge_mask_type=None

本示例通過(guò)設(shè)置閾值來(lái)篩選最重要的節(jié)點(diǎn),提供更聚焦的分析視圖:

visual_explainer_2 = Explainer(  
     model=model,  
     algorithm=GNNExplainer(epochs=50),  
     explanation_type='model',  
     node_mask_type='attributes',  
     model_config=dict(  
         mode='multiclass_classification',  
         task_level='node',  
         return_type='log_probs',  
    ),  
     threshold_config=dict(threshold_type='topk', value=10)  
 )  
   
 index = 143  
   
 visual_explanation_2 = visual_explainer_2(  
     batch.x,  
     batch.edge_index,  
     index=index  
 )
 
 # 生成可視化結(jié)果
 visual_explanation_2.visualize_graph('visual_graph_2.png', backend="graphviz")

第二種可視化方法通過(guò)限制顯示最重要的10個(gè)節(jié)點(diǎn),提供了更加精煉的分析視圖。邊的不透明度變化不太明顯,這說(shuō)明這些保留下來(lái)的邊對(duì)預(yù)測(cè)結(jié)果具有相近的影響程度。這種篩選后的可視化更適合用于識(shí)別和分析關(guān)鍵影響因素。

特征重要性可視化

visualize_feature_importance方法提供了另一種可視化視角,用于展示影響節(jié)點(diǎn)預(yù)測(cè)的top-k重要特征。使用此方法時(shí),Explainer的初始化配置中不能設(shè)置node_mask_type=None,詳細(xì)實(shí)現(xiàn)可參考方法的源代碼。

方法定義:

visualize_feature_importance(path: Optional[str] = None,  
                              feat_labels: Optional[List[str]] = None,  
                              top_k: Optional[int] = None)[source])

參數(shù)說(shuō)明:

  • path: 可視化結(jié)果保存路徑
  • feat_labels: 特征標(biāo)簽列表,用于增強(qiáng)可讀性
  • top_k: 顯示的重要特征數(shù)量示例調(diào)用:
model_explanation.visualize_feature_importance(top_k=10)

該圖顯示了對(duì)節(jié)點(diǎn)143預(yù)測(cè)結(jié)果影響最大的前10個(gè)特征。這些特征與我們之前通過(guò)分析得到的影響特征列表完全一致,提供了直觀的重要性排序視圖。

解釋質(zhì)量評(píng)估

為了區(qū)分高質(zhì)量解釋和低質(zhì)量解釋?zhuān)枰⒁惶紫到y(tǒng)的評(píng)估機(jī)制。這一評(píng)估機(jī)制對(duì)于判斷不同解釋器(如DummyExplainer與專(zhuān)業(yè)解釋器)的性能差異尤為重要。系統(tǒng)提供了五種評(píng)估指標(biāo)[12]:

基于真實(shí)標(biāo)簽的評(píng)估

groundtruth_metrics用于評(píng)估生成的解釋掩碼與真實(shí)解釋掩碼之間的一致性。這個(gè)指標(biāo)有助于判斷模型識(shí)別的重要特征是否與實(shí)際數(shù)據(jù)中的關(guān)鍵特征相符。

  • 評(píng)估模型解釋與數(shù)據(jù)真實(shí)重要性特征的匹配程度
  • 驗(yàn)證模型的解釋能力是否符合領(lǐng)域知識(shí)
  • 識(shí)別潛在的誤解釋情況

準(zhǔn)確性評(píng)估

fidelity指標(biāo)通過(guò)比較兩種場(chǎng)景下的預(yù)測(cè)差異來(lái)評(píng)估解釋的質(zhì)量:

Fid+(保留重要特征):

  • 僅保留解釋認(rèn)定的重要部分
  • 評(píng)估這些部分是否足以重現(xiàn)原始預(yù)測(cè)

Fid-(移除重要特征):

  • 移除解釋認(rèn)定的重要部分
  • 評(píng)估這些部分的缺失是否會(huì)顯著改變預(yù)測(cè)結(jié)果

評(píng)估標(biāo)準(zhǔn):

  • 高質(zhì)量解釋?xiě)?yīng)具有高Fid+值,表明保留的重要特征能夠很好地支持原始預(yù)測(cè)
  • 同時(shí)應(yīng)具有低Fid-值,表明移除這些特征會(huì)導(dǎo)致預(yù)測(cè)結(jié)果發(fā)生顯著變化

綜合特征化評(píng)分

characterization_score將Fid+和Fid-兩個(gè)指標(biāo)整合為單一評(píng)分,提供更全面的評(píng)估視角:

  • Fid+:評(píng)估保留重要特征的效果(目標(biāo)值接近1)
  • Fid-:評(píng)估移除重要特征的影響(目標(biāo)值接近0)
  • 權(quán)重配置:默認(rèn)兩者權(quán)重相等(各0.5),可根據(jù)具體應(yīng)用場(chǎng)景調(diào)整

準(zhǔn)確性曲線(xiàn)分析

fidelity_curve_auc提供了一個(gè)更加動(dòng)態(tài)的評(píng)估視角,通過(guò)測(cè)量不同閾值下解釋質(zhì)量的變化來(lái)生成完整的性能曲線(xiàn):

評(píng)估機(jī)制:

  • 通過(guò)調(diào)整重要特征的閾值進(jìn)行多次準(zhǔn)確性測(cè)量
  • 計(jì)算測(cè)量結(jié)果的曲線(xiàn)下面積(AUC)
  • 分析解釋質(zhì)量隨特征數(shù)量變化的穩(wěn)定性

結(jié)果解讀:

  • AUC = 1:解釋在所有閾值下均保持高準(zhǔn)確性
  • AUC = 0:解釋在所有閾值下均表現(xiàn)不佳
  • AUC值越高表明解釋的穩(wěn)健性越好

相比特征化評(píng)分,曲線(xiàn)分析的優(yōu)勢(shì)在于能夠提供全范圍閾值下的性能表現(xiàn),而不是僅關(guān)注特定點(diǎn)的表現(xiàn)。

示例:

from torch_geometric.explain.metric import (  
    fidelity,  
    characterization_score,  
    fidelity_curve_auc,  
    unfaithfulness  
 )  
   
 # 驗(yàn)證解釋結(jié)果
 is_valid = model_explanation.validate()  
   
 # 計(jì)算準(zhǔn)確性指標(biāo)
 fid_pos, fid_neg = fidelity(  
    explainer=metric_explainer,  
    explanation=metric_explanation  
 )  
   
 # 計(jì)算特征化評(píng)分
 char_score = characterization_score(  
     fid_pos,  
     fid_neg,  
     pos_weight=0.7,    # 提高正向影響的權(quán)重  
     neg_weight=0.3     # 降低負(fù)向影響的權(quán)重          
 )  
 
 # 準(zhǔn)確性曲線(xiàn)AUC計(jì)算
 pos_fidelity = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5])  
 neg_fidelity = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])  
   
 # 定義評(píng)估閾值點(diǎn)
 x = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])  
   
 # 計(jì)算AUC
 auc = fidelity_curve_auc(pos_fidelity, neg_fidelity, x)  
 
 # 輸出評(píng)估結(jié)果
 print(f"準(zhǔn)確性指標(biāo): {fid_pos}, {fid_neg}")  
 print(f"特征化評(píng)分: {char_score}")  
 print("準(zhǔn)確性曲線(xiàn)AUC:", auc.item())

總結(jié)

圖神經(jīng)網(wǎng)絡(luò)的可解釋性研究對(duì)于提升模型的可信度和實(shí)用價(jià)值具有重要意義。通過(guò)PyTorch Geometric的可解釋性模塊,我們實(shí)現(xiàn)了對(duì)復(fù)雜模型決策過(guò)程的系統(tǒng)分析和理解。

責(zé)任編輯:華軒 來(lái)源: DeepHub IMBA
相關(guān)推薦

2019-11-08 10:17:41

人工智能機(jī)器學(xué)習(xí)技術(shù)

2023-03-09 12:12:38

算法準(zhǔn)則

2020-05-14 08:40:57

神經(jīng)網(wǎng)絡(luò)決策樹(shù)AI

2022-06-07 11:14:23

神經(jīng)網(wǎng)絡(luò)AI中科院

2019-08-29 18:07:51

機(jī)器學(xué)習(xí)人工智能

2023-03-07 16:48:54

算法可解釋性

2024-08-23 13:40:00

AI模型

2022-07-28 09:00:00

深度學(xué)習(xí)網(wǎng)絡(luò)類(lèi)型架構(gòu)

2025-01-13 08:13:18

2024-09-18 05:25:00

可解釋性人工智能AI

2024-05-28 08:00:00

人工智能機(jī)器學(xué)習(xí)

2025-03-10 08:34:39

2025-06-16 08:51:00

2021-01-08 10:47:07

機(jī)器學(xué)習(xí)模型算法

2024-04-30 14:54:10

2022-06-14 14:48:09

AI圖像GAN

2023-05-04 07:23:04

因果推斷貝葉斯因果網(wǎng)絡(luò)

2023-08-15 10:04:40

2022-05-25 14:21:01

神經(jīng)網(wǎng)絡(luò)框架技術(shù)

2025-07-08 08:38:09

推理錨點(diǎn)LLM大模型
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)