Rumah > Peranti teknologi > AI > Pautkan contoh kod ramalan menggunakan Pytorch Geometric

Pautkan contoh kod ramalan menggunakan Pytorch Geometric

王林
Lepaskan: 2023-10-20 19:33:08
ke hadapan
1086 orang telah melayarinya

PyTorch Geometric (PyG) ialah alat utama untuk membina model rangkaian saraf graf dan bereksperimen dengan pelbagai lilitan graf. Dalam artikel ini kami akan memperkenalkannya melalui ramalan pautan.

使用Pytorch Geometric 进行链接预测代码示例

Ramalan pautan menjawab soalan: Dua nod yang manakah harus dipautkan antara satu sama lain Kami akan menyediakan data untuk pemodelan dengan melakukan "pemisahan transformasi". Sediakan pemuat data graf khusus untuk pemprosesan kelompok. Bina model dalam Torch Geometric, latihnya menggunakan PyTorch Lightning dan semak prestasi model.

Persediaan perpustakaan

  • Obor Ini tidak memerlukan pengenalan lanjut
  • Obor ialah perpustakaan utama rangkaian saraf graf Geometrik dan fokus artikel ini
  • PyTorch Lightning digunakan untuk melatih, menala dan mengesahkan model. Ia memudahkan operasi latihan
  • Sklearn Metrics dan Torchmetrics digunakan untuk menyemak prestasi model.
  • PyTorch Geometric mempunyai beberapa kebergantungan khusus, jika anda menghadapi masalah memasangnya, sila rujuk dokumentasi rasminya.

Penyediaan Data

Kami akan menggunakan set data petikan Cora ML. Set data boleh diakses melalui Torch Geometric.

 data = tg.datasets.CitationFull(root="data", name="Cora_ML")
Salin selepas log masuk

Secara lalai, set data Geometri Obor boleh mengembalikan berbilang graf. Mari lihat rupa graf tunggal

data[0] > Data(x=[2995, 2879], edge_index=[2, 16316], y=[2995])
Salin selepas log masuk

di mana X ialah ciri nod. edge_index ialah matriks 2 x (n tepi) (dimensi pertama = 2, ditafsirkan sebagai: baris 0 - nod sumber/"pengirim", baris 1 - nod sasaran/"penerima").

Pemisahan Pautan

Kami akan mulakan dengan membahagikan pautan dalam set data. Gunakan 20% pautan graf sebagai set pengesahan dan 10% sebagai set ujian. Sampel negatif tidak akan ditambahkan pada set data latihan di sini kerana pautan negatif tersebut akan dibuat dengan segera oleh pemuat data kelompok.

Secara umum, pensampelan negatif menghasilkan sampel "palsu" (dalam kes kami pautan antara nod), jadi model belajar cara membezakan antara pautan sebenar dan palsu. Persampelan negatif adalah berdasarkan teori dan matematik persampelan dan mempunyai beberapa sifat statistik yang bagus.

Pertama: mari buat objek split pautan.

 link_splitter = tg.transforms.RandomLinkSplit(num_val=0.2, num_test=0.1, add_negative_train_samples=False,disjoint_train_ratio=0.8)
Salin selepas log masuk

disjoint_train_ratio melaraskan bilangan tepi yang akan digunakan sebagai maklumat latihan dalam fasa "penyeliaan". Bahagian tepi yang tinggal akan digunakan untuk penghantaran mesej (fasa pemindahan maklumat dalam rangkaian).

Terdapat sekurang-kurangnya dua kaedah pembahagian tepi dalam rangkaian saraf graf: pembahagian induktif dan pembahagian konduktif. Kaedah transformasi menganggap GNN perlu mempelajari corak struktur daripada struktur graf. Dalam tetapan induktif, label nod/tepi boleh digunakan untuk pembelajaran. Terdapat dua kertas kerja pada penghujung kertas ini yang membincangkan konsep-konsep ini secara terperinci dan menyediakan pemformalan tambahan: ([1], [3]).

 train_g, val_g, test_g = link_splitter(data[0])  > Data(x=[2995, 2879], edge_index=[2, 2285], y=[2995], edge_label=[9137], edge_label_index=[2, 9137])
Salin selepas log masuk

Selepas operasi ini, kami mempunyai beberapa atribut baharu:

label_tepi: menerangkan sama ada tepi itu benar/salah. Inilah yang ingin kita ramalkan.

edge_label_index ialah matriks 2 x NUM EDGES yang digunakan untuk menyimpan pautan nod.

Mari kita lihat pengedaran sampel

th.unique(train_g.edge_label, return_counts=True) > (tensor([1.]), tensor([9137]))  th.unique(val_g.edge_label, return_counts=True) > (tensor([0., 1.]), tensor([3263, 3263]))  th.unique(val_g.edge_label, return_counts=True) > (tensor([0., 1.]), tensor([3263, 3263]))
Salin selepas log masuk

Untuk data latihan tidak ada sisi negatif (kami akan menciptanya semasa latihan), untuk set val/ujian - sudah ada beberapa pautan "palsu" dalam 50: nisbah 50.

Model

Kini kita boleh membina model menggunakan GNN dengan membina

kelas GNN(nn.Module):

def __init__(self, dim_in: int, conv_sizes: Tuple[int, ...], act_f: nn.Module = th.relu, dropout: float = 0.1,*args, **kwargs):super().__init__()self.dim_in = dim_inself.dim_out = conv_sizes[-1]self.dropout = dropoutself.act_f = act_flast_in = dim_inlayers = [] # Here we build subsequent graph convolutions.for conv_sz in conv_sizes:# Single graph convolution layerconv = tgnn.SAGEConv(in_channels=last_in, out_channels=conv_sz, *args, **kwargs)last_in = conv_szlayers.append(conv)self.layers = nn.ModuleList(layers) def forward(self, x: th.Tensor, edge_index: th.Tensor) -> th.Tensor:h = x# For every graph convolution in the network...for conv in self.layers:# ... perform node embedding via message passingh = conv(h, edge_index)h = self.act_f(h)if self.dropout:h = nn.functional.dropout(h, p=self.dropout, training=self.training)return h
Salin selepas log masuk

Bahagian penting dalam model ini ialah set lilitan graf - dalam kes kami Ia SAGEConv. Takrif formal lilitan SAGE ialah:

使用Pytorch Geometric 进行链接预测代码示例å¾ç

v ialah nod semasa, N(v) jiran nod v. Untuk mengetahui lebih lanjut tentang jenis lilitan ini, lihat kertas asal daripada GraphSAGE[1]

Mari semak sama ada model boleh membuat ramalan menggunakan data yang disediakan. Input kepada model PyG di sini ialah matriks ciri nod X dan pautan yang menentukan edge_index.

gnn = GNN(train_g.x.size()[1], conv_sizes=[512, 256, 128]) with th.no_grad():out = gnn(train_g.x, train_g.edge_index)  out   > tensor([[0.0000, 0.0000, 0.0051, ..., 0.0997, 0.0000, 0.0000],[0.0107, 0.0000, 0.0576, ..., 0.0651, 0.0000, 0.0000],[0.0000, 0.0000, 0.0102, ..., 0.0973, 0.0000, 0.0000],...,[0.0000, 0.0000, 0.0549, ..., 0.0671, 0.0000, 0.0000],[0.0000, 0.0000, 0.0166, ..., 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0034, ..., 0.1111, 0.0000, 0.0000]])
Salin selepas log masuk

Keluaran model kami ialah matriks benam nod dengan dimensi: N nod x saiz benam.

PyTorch Lightning

PyTorch Lightning digunakan terutamanya untuk latihan, tetapi di sini kami menambah lapisan Linear selepas output GNN sebagai kepala output untuk meramalkan sama ada untuk dipautkan.

class LinkPredModel(pl.LightningModule):

def __init__(self,dim_in: int,conv_sizes: Tuple[int, ...], act_f: nn.Module = th.relu, dropout: float = 0.1,lr: float = 0.01,*args, **kwargs):super().__init__() # Our inner GNN modelself.gnn = GNN(dim_in, conv_sizes=conv_sizes, act_f=act_f, dropout=dropout) # Final prediction model on links.self.lin_pred = nn.Linear(self.gnn.dim_out, 1)self.lr = lr def forward(self, x: th.Tensor, edge_index: th.Tensor) -> th.Tensor:# Step 1: make node embeddings using GNN.h = self.gnn(x, edge_index) # Take source nodes embeddings- sendersh_src = h[edge_index[0, :]]# Take target node embeddings - receiversh_dst = h[edge_index[1, :]] # Calculate the product between themsrc_dst_mult = h_src * h_dst# Apply non-linearityout = self.lin_pred(src_dst_mult)return out def _step(self, batch: th.Tensor, phase: str='train') -> th.Tensor:yhat_edge = self(batch.x, batch.edge_label_index).squeeze()y = batch.edge_labelloss = nn.functional.binary_cross_entropy_with_logits(input=yhat_edge, target=y)f1 = tm.functional.f1_score(preds=yhat_edge, target=y, task='binary')prec = tm.functional.precision(preds=yhat_edge, target=y, task='binary')recall = tm.functional.recall(preds=yhat_edge, target=y, task='binary') # Watch for logging here - we need to provide batch_size, as (at the time of this implementation)# PL cannot understand the batch size.self.log(f"{phase}_f1", f1, batch_size=batch.edge_label_index.shape[1])self.log(f"{phase}_loss", loss, batch_size=batch.edge_label_index.shape[1])self.log(f"{phase}_precision", prec, batch_size=batch.edge_label_index.shape[1])self.log(f"{phase}_recall", recall, batch_size=batch.edge_label_index.shape[1])return loss def training_step(self, batch, batch_idx):return self._step(batch) def validation_step(self, batch, batch_idx):return self._step(batch, "val") def test_step(self, batch, batch_idx):return self._step(batch, "test") def predict_step(self, batch):x, edge_index = batchreturn self(x, edge_index) def configure_optimizers(self):return th.optim.Adam(self.parameters(), lr=self.lr)
Salin selepas log masuk

Peranan PyTorch Lightning adalah untuk membantu kami memudahkan langkah latihan Kami hanya perlu mengkonfigurasi beberapa fungsi Kami boleh menggunakan arahan berikut untuk menguji sama ada model itu tersedia

Latihan

Untuk langkah latihan, pemuat data memerlukan pemprosesan khas.

Data graf memerlukan pemprosesan khas - terutamanya ramalan pautan. PyG mempunyai beberapa kelas pemuat data khusus yang bertanggungjawab untuk menjana kelompok dengan betul. Kami akan menggunakan: tg.loader.LinkNeighborLoader, yang menerima input berikut:

Data untuk dimuatkan secara pukal (imej). num_neighbors Bilangan maksimum jiran untuk dimuatkan setiap nod semasa satu "hop". Senarai yang menyatakan bilangan jiran 1 - 2 - 3 -…-K. Terutamanya berguna untuk grafik yang sangat besar.

edge_label_index atribut yang sudah menunjukkan pautan benar/salah.

neg_sampling_ratio - nisbah sampel negatif kepada sampel sebenar.

 model = LinkPredModel(val_g.x.size()[1], conv_sizes=[512, 256, 128]) with th.no_grad():out = model.predict_step((val_g.x, val_g.edge_label_index))
Salin selepas log masuk
Berikut ialah model latihan

 train_loader = tg.loader.LinkNeighborLoader(train_g,num_neighbors=[-1, 10, 5],batch_size=128,edge_label_index=train_g.edge_label_index, # "on the fly" negative sampling creation for batchneg_sampling_ratio=0.5 )  val_loader = tg.loader.LinkNeighborLoader(val_g,num_neighbors=[-1, 10, 5],batch_size=128,edge_label_index=val_g.edge_label_index,edge_label=val_g.edge_label, # negative samples for val set are done already as ground-truthneg_sampling_ratio=0.0 )  test_loader = tg.loader.LinkNeighborLoader(test_g,num_neighbors=[-1, 10, 5],batch_size=128,edge_label_index=test_g.edge_label_index,edge_label=test_g.edge_label, # negative samples for test set are done already as ground-truthneg_sampling_ratio=0.0 )
Salin selepas log masuk
Uji pengesahan data, lihat laporan klasifikasi dan keluk ROC.

model = LinkPredModel(val_g.x.size()[1], conv_sizes=[512, 256, 128]) trainer = pl.Trainer(max_epochs=20, log_every_n_steps=5)  # Validate before training - we will see results of untrained model. trainer.validate(model, val_loader)  # Train the model trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)
Salin selepas log masuk
Hasilnya nampak cantik:

precision recall f1-score support0.0 0.68 0.70 0.69 16311.0 0.69 0.66 0.68 1631accuracy 0.68 3262macro avg 0.68 0.68 0.68 3262
Salin selepas log masuk

ROC曲线也不错

使用Pytorch Geometric 进行链接预测代码示例

我们训练的模型并不特别复杂,也没有经过精心调整,但它完成了工作。当然这只是一个为了演示使用的小型数据集。

总结

图神经网络尽管看起来很复杂,但是PyTorch Geometric为我们提供了一个很好的解决方案。我们可以直接使用其中内置的模型实现,这方便了我们使用和简化了入门的门槛。

本文代码:https://github.com/maddataanalyst/blogposts_code/blob/main/graph_nns_series/pyg_pyl_perfect_match/pytorch-geometric-lightning-perfect-match.ipynb

Atas ialah kandungan terperinci Pautkan contoh kod ramalan menggunakan Pytorch Geometric. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Label berkaitan:
sumber:51cto.com
Kenyataan Laman Web ini
Kandungan artikel ini disumbangkan secara sukarela oleh netizen, dan hak cipta adalah milik pengarang asal. Laman web ini tidak memikul tanggungjawab undang-undang yang sepadan. Jika anda menemui sebarang kandungan yang disyaki plagiarisme atau pelanggaran, sila hubungi admin@php.cn
Tutorial Popular
Lagi>
Muat turun terkini
Lagi>
kesan web
Kod sumber laman web
Bahan laman web
Templat hujung hadapan