Knowledge Graph (KG) Dataloader with Lightning
This guide shows how to load and use a Knowledge Graph dataset with the PyTorch Lightning LightningDataModule.
To load the KG data and access its training, validation, and test splits, follow the steps below:
Step 1: 📦 Import the module
In [13]:
Copied!
import os
import sys

# Make the repository root (two levels up, where `vpeleaderboard/` is located)
# importable. The membership check keeps re-runs of this cell from appending
# the same path to sys.path again.
_repo_root = os.path.abspath("../../")
if _repo_root not in sys.path:
    sys.path.append(_repo_root)
In [14]:
Copied!
from vpeleaderboard.data.src.kg.biobridge_datamodule_hetero import BioBridgeDataModule
from vpeleaderboard.data.src.kg.biobridge_datamodule_hetero import BioBridgeDataModule
Step 2: ⚙️ Initialize the BioBridgeDataModule
Load the BioBRIDGE-PrimeKG Hydra configuration and use it to initialize the BioBridgeDataModule:
In [15]:
Copied!
import hydra

# Compose the BioBRIDGE-PrimeKG data config (its "default" entry) without
# @hydra.main, then build the datamodule from it exactly once.
with hydra.initialize(
    config_path="../../vpeleaderboard/configs/data/kg/BioBRIDGE-PrimeKG",
    version_base=None,
):
    cfg = hydra.compose(config_name="default")
    dm = BioBridgeDataModule(cfg)
Step 3: 🧹 Prepare data
Prepare the KG data by loading and caching it:
In [16]:
Copied!
# Load data, embeddings, and node/edge mappings (per the step description,
# this also caches them). Call it once; repeating it only redoes the work.
dm.prepare_data()
Step 4: 🧠 Setup the data splits
Split the data into training, validation, and test sets:
In [17]:
Copied!
# Build HeteroData and apply RandomLinkSplit (populates dm.data with the
# per-split graphs used below). Call it once.
dm.setup()
Step 5: 🧪 Access the dataloaders
First confirm that `setup()` created the splits, then retrieve the standard Lightning dataloaders:
In [18]:
Copied!
# Sanity check: after setup() the split dict should include 'train', 'val'
# and 'test' (the output below also shows an 'init' entry for the full graph).
print(dm.data.keys())
dict_keys(['init', 'train', 'val', 'test'])
Training data
In [19]:
Copied!
# Fetch and show the first training batch. In the output below it holds a
# single graph (every node type has ptr=[2]).
train_loader = dm.train_dataloader()
train_batch = next(iter(train_loader))
print(train_batch)
HeteroDataBatch( biological_process={ num_nodes=27409, x=[27409, 768], node_name=[1], batch=[27409], ptr=[2], }, cellular_component={ num_nodes=4011, x=[4011, 768], node_name=[1], batch=[4011], ptr=[2], }, disease={ num_nodes=17054, x=[17054, 768], node_name=[1], batch=[17054], ptr=[2], }, drug={ num_nodes=6759, x=[6759, 512], node_name=[1], batch=[6759], ptr=[2], }, gene/protein={ num_nodes=18797, x=[18797, 2560], node_name=[1], batch=[18797], ptr=[2], }, molecular_function={ num_nodes=10951, x=[10951, 768], node_name=[1], batch=[10951], ptr=[2], }, (gene/protein, ppi, gene/protein)={ edge_index=[2, 440447], pos_edge_label=[440447], pos_edge_label_index=[2, 440447], neg_edge_label=[440447], neg_edge_label_index=[2, 440447], }, (drug, carrier, gene/protein)={ edge_index=[2, 571], pos_edge_label=[571], pos_edge_label_index=[2, 571], neg_edge_label=[571], neg_edge_label_index=[2, 571], }, (drug, enzyme, gene/protein)={ edge_index=[2, 3542], pos_edge_label=[3542], pos_edge_label_index=[2, 3542], neg_edge_label=[3542], neg_edge_label_index=[2, 3542], }, (drug, target, gene/protein)={ edge_index=[2, 10444], pos_edge_label=[10444], pos_edge_label_index=[2, 10444], neg_edge_label=[10444], neg_edge_label_index=[2, 10444], }, (drug, transporter, gene/protein)={ edge_index=[2, 2065], pos_edge_label=[2065], pos_edge_label_index=[2, 2065], neg_edge_label=[2065], neg_edge_label_index=[2, 2065], }, (drug, contraindication, disease)={ edge_index=[2, 21028], pos_edge_label=[21028], pos_edge_label_index=[2, 21028], neg_edge_label=[21028], neg_edge_label_index=[2, 21028], }, (drug, indication, disease)={ edge_index=[2, 6123], pos_edge_label=[6123], pos_edge_label_index=[2, 6123], neg_edge_label=[6123], neg_edge_label_index=[2, 6123], }, (drug, off-label use, disease)={ edge_index=[2, 1712], pos_edge_label=[1712], pos_edge_label_index=[2, 1712], neg_edge_label=[1712], neg_edge_label_index=[2, 1712], }, (drug, synergistic interaction, drug)={ edge_index=[2, 1563945], 
pos_edge_label=[1563945], pos_edge_label_index=[2, 1563945], neg_edge_label=[1563945], neg_edge_label_index=[2, 1563945], }, (gene/protein, associated with, disease)={ edge_index=[2, 54574], pos_edge_label=[54574], pos_edge_label_index=[2, 54574], neg_edge_label=[54574], neg_edge_label_index=[2, 54574], }, (disease, parent-child, disease)={ edge_index=[2, 45073], pos_edge_label=[45073], pos_edge_label_index=[2, 45073], neg_edge_label=[45073], neg_edge_label_index=[2, 45073], }, (biological_process, parent-child, biological_process)={ edge_index=[2, 69741], pos_edge_label=[69741], pos_edge_label_index=[2, 69741], neg_edge_label=[69741], neg_edge_label_index=[2, 69741], }, (molecular_function, parent-child, molecular_function)={ edge_index=[2, 18506], pos_edge_label=[18506], pos_edge_label_index=[2, 18506], neg_edge_label=[18506], neg_edge_label_index=[2, 18506], }, (cellular_component, parent-child, cellular_component)={ edge_index=[2, 6440], pos_edge_label=[6440], pos_edge_label_index=[2, 6440], neg_edge_label=[6440], neg_edge_label_index=[2, 6440], }, (gene/protein, interacts with, molecular_function)={ edge_index=[2, 46733], pos_edge_label=[46733], pos_edge_label_index=[2, 46733], neg_edge_label=[46733], neg_edge_label_index=[2, 46733], }, (gene/protein, interacts with, cellular_component)={ edge_index=[2, 52327], pos_edge_label=[52327], pos_edge_label_index=[2, 52327], neg_edge_label=[52327], neg_edge_label_index=[2, 52327], }, (gene/protein, interacts with, biological_process)={ edge_index=[2, 95425], pos_edge_label=[95425], pos_edge_label_index=[2, 95425], neg_edge_label=[95425], neg_edge_label_index=[2, 95425], }, (gene/protein, carrier, drug)={ edge_index=[2, 571], pos_edge_label=[571], pos_edge_label_index=[2, 571], neg_edge_label=[571], neg_edge_label_index=[2, 571], }, (gene/protein, enzyme, drug)={ edge_index=[2, 3542], pos_edge_label=[3542], pos_edge_label_index=[2, 3542], neg_edge_label=[3542], neg_edge_label_index=[2, 3542], }, (gene/protein, target, 
drug)={ edge_index=[2, 10444], pos_edge_label=[10444], pos_edge_label_index=[2, 10444], neg_edge_label=[10444], neg_edge_label_index=[2, 10444], }, (gene/protein, transporter, drug)={ edge_index=[2, 2065], pos_edge_label=[2065], pos_edge_label_index=[2, 2065], neg_edge_label=[2065], neg_edge_label_index=[2, 2065], }, (disease, contraindication, drug)={ edge_index=[2, 21028], pos_edge_label=[21028], pos_edge_label_index=[2, 21028], neg_edge_label=[21028], neg_edge_label_index=[2, 21028], }, (disease, indication, drug)={ edge_index=[2, 6123], pos_edge_label=[6123], pos_edge_label_index=[2, 6123], neg_edge_label=[6123], neg_edge_label_index=[2, 6123], }, (disease, off-label use, drug)={ edge_index=[2, 1712], pos_edge_label=[1712], pos_edge_label_index=[2, 1712], neg_edge_label=[1712], neg_edge_label_index=[2, 1712], }, (disease, associated with, gene/protein)={ edge_index=[2, 54574], pos_edge_label=[54574], pos_edge_label_index=[2, 54574], neg_edge_label=[54574], neg_edge_label_index=[2, 54574], }, (molecular_function, interacts with, gene/protein)={ edge_index=[2, 46733], pos_edge_label=[46733], pos_edge_label_index=[2, 46733], neg_edge_label=[46733], neg_edge_label_index=[2, 46733], }, (cellular_component, interacts with, gene/protein)={ edge_index=[2, 52327], pos_edge_label=[52327], pos_edge_label_index=[2, 52327], neg_edge_label=[52327], neg_edge_label_index=[2, 52327], }, (biological_process, interacts with, gene/protein)={ edge_index=[2, 95425], pos_edge_label=[95425], pos_edge_label_index=[2, 95425], neg_edge_label=[95425], neg_edge_label_index=[2, 95425], } )
Validation data
In [20]:
Copied!
# Fetch and show the first validation batch.
val_loader = dm.val_dataloader()
val_batch = next(iter(val_loader))
print(val_batch)
HeteroDataBatch( biological_process={ num_nodes=27409, x=[27409, 768], node_name=[1], batch=[27409], ptr=[2], }, cellular_component={ num_nodes=4011, x=[4011, 768], node_name=[1], batch=[4011], ptr=[2], }, disease={ num_nodes=17054, x=[17054, 768], node_name=[1], batch=[17054], ptr=[2], }, drug={ num_nodes=6759, x=[6759, 512], node_name=[1], batch=[6759], ptr=[2], }, gene/protein={ num_nodes=18797, x=[18797, 2560], node_name=[1], batch=[18797], ptr=[2], }, molecular_function={ num_nodes=10951, x=[10951, 768], node_name=[1], batch=[10951], ptr=[2], }, (gene/protein, ppi, gene/protein)={ edge_index=[2, 440447], pos_edge_label=[62920], pos_edge_label_index=[2, 62920], neg_edge_label=[62920], neg_edge_label_index=[2, 62920], }, (drug, carrier, gene/protein)={ edge_index=[2, 571], pos_edge_label=[81], pos_edge_label_index=[2, 81], neg_edge_label=[81], neg_edge_label_index=[2, 81], }, (drug, enzyme, gene/protein)={ edge_index=[2, 3542], pos_edge_label=[506], pos_edge_label_index=[2, 506], neg_edge_label=[506], neg_edge_label_index=[2, 506], }, (drug, target, gene/protein)={ edge_index=[2, 10444], pos_edge_label=[1492], pos_edge_label_index=[2, 1492], neg_edge_label=[1492], neg_edge_label_index=[2, 1492], }, (drug, transporter, gene/protein)={ edge_index=[2, 2065], pos_edge_label=[295], pos_edge_label_index=[2, 295], neg_edge_label=[295], neg_edge_label_index=[2, 295], }, (drug, contraindication, disease)={ edge_index=[2, 21028], pos_edge_label=[3004], pos_edge_label_index=[2, 3004], neg_edge_label=[3004], neg_edge_label_index=[2, 3004], }, (drug, indication, disease)={ edge_index=[2, 6123], pos_edge_label=[874], pos_edge_label_index=[2, 874], neg_edge_label=[874], neg_edge_label_index=[2, 874], }, (drug, off-label use, disease)={ edge_index=[2, 1712], pos_edge_label=[244], pos_edge_label_index=[2, 244], neg_edge_label=[244], neg_edge_label_index=[2, 244], }, (drug, synergistic interaction, drug)={ edge_index=[2, 1563945], pos_edge_label=[223420], pos_edge_label_index=[2, 
223420], neg_edge_label=[223420], neg_edge_label_index=[2, 223420], }, (gene/protein, associated with, disease)={ edge_index=[2, 54574], pos_edge_label=[7796], pos_edge_label_index=[2, 7796], neg_edge_label=[7796], neg_edge_label_index=[2, 7796], }, (disease, parent-child, disease)={ edge_index=[2, 45073], pos_edge_label=[6438], pos_edge_label_index=[2, 6438], neg_edge_label=[6438], neg_edge_label_index=[2, 6438], }, (biological_process, parent-child, biological_process)={ edge_index=[2, 69741], pos_edge_label=[9963], pos_edge_label_index=[2, 9963], neg_edge_label=[9963], neg_edge_label_index=[2, 9963], }, (molecular_function, parent-child, molecular_function)={ edge_index=[2, 18506], pos_edge_label=[2643], pos_edge_label_index=[2, 2643], neg_edge_label=[2643], neg_edge_label_index=[2, 2643], }, (cellular_component, parent-child, cellular_component)={ edge_index=[2, 6440], pos_edge_label=[920], pos_edge_label_index=[2, 920], neg_edge_label=[920], neg_edge_label_index=[2, 920], }, (gene/protein, interacts with, molecular_function)={ edge_index=[2, 46733], pos_edge_label=[6676], pos_edge_label_index=[2, 6676], neg_edge_label=[6676], neg_edge_label_index=[2, 6676], }, (gene/protein, interacts with, cellular_component)={ edge_index=[2, 52327], pos_edge_label=[7475], pos_edge_label_index=[2, 7475], neg_edge_label=[7475], neg_edge_label_index=[2, 7475], }, (gene/protein, interacts with, biological_process)={ edge_index=[2, 95425], pos_edge_label=[13632], pos_edge_label_index=[2, 13632], neg_edge_label=[13632], neg_edge_label_index=[2, 13632], }, (gene/protein, carrier, drug)={ edge_index=[2, 571], pos_edge_label=[81], pos_edge_label_index=[2, 81], neg_edge_label=[81], neg_edge_label_index=[2, 81], }, (gene/protein, enzyme, drug)={ edge_index=[2, 3542], pos_edge_label=[506], pos_edge_label_index=[2, 506], neg_edge_label=[506], neg_edge_label_index=[2, 506], }, (gene/protein, target, drug)={ edge_index=[2, 10444], pos_edge_label=[1492], pos_edge_label_index=[2, 1492], 
neg_edge_label=[1492], neg_edge_label_index=[2, 1492], }, (gene/protein, transporter, drug)={ edge_index=[2, 2065], pos_edge_label=[295], pos_edge_label_index=[2, 295], neg_edge_label=[295], neg_edge_label_index=[2, 295], }, (disease, contraindication, drug)={ edge_index=[2, 21028], pos_edge_label=[3004], pos_edge_label_index=[2, 3004], neg_edge_label=[3004], neg_edge_label_index=[2, 3004], }, (disease, indication, drug)={ edge_index=[2, 6123], pos_edge_label=[874], pos_edge_label_index=[2, 874], neg_edge_label=[874], neg_edge_label_index=[2, 874], }, (disease, off-label use, drug)={ edge_index=[2, 1712], pos_edge_label=[244], pos_edge_label_index=[2, 244], neg_edge_label=[244], neg_edge_label_index=[2, 244], }, (disease, associated with, gene/protein)={ edge_index=[2, 54574], pos_edge_label=[7796], pos_edge_label_index=[2, 7796], neg_edge_label=[7796], neg_edge_label_index=[2, 7796], }, (molecular_function, interacts with, gene/protein)={ edge_index=[2, 46733], pos_edge_label=[6676], pos_edge_label_index=[2, 6676], neg_edge_label=[6676], neg_edge_label_index=[2, 6676], }, (cellular_component, interacts with, gene/protein)={ edge_index=[2, 52327], pos_edge_label=[7475], pos_edge_label_index=[2, 7475], neg_edge_label=[7475], neg_edge_label_index=[2, 7475], }, (biological_process, interacts with, gene/protein)={ edge_index=[2, 95425], pos_edge_label=[13632], pos_edge_label_index=[2, 13632], neg_edge_label=[13632], neg_edge_label_index=[2, 13632], } )
Test data
In [21]:
Copied!
# Fetch and show the first test batch.
test_loader = dm.test_dataloader()
test_batch = next(iter(test_loader))
print(test_batch)
HeteroDataBatch( biological_process={ num_nodes=27409, x=[27409, 768], node_name=[1], batch=[27409], ptr=[2], }, cellular_component={ num_nodes=4011, x=[4011, 768], node_name=[1], batch=[4011], ptr=[2], }, disease={ num_nodes=17054, x=[17054, 768], node_name=[1], batch=[17054], ptr=[2], }, drug={ num_nodes=6759, x=[6759, 512], node_name=[1], batch=[6759], ptr=[2], }, gene/protein={ num_nodes=18797, x=[18797, 2560], node_name=[1], batch=[18797], ptr=[2], }, molecular_function={ num_nodes=10951, x=[10951, 768], node_name=[1], batch=[10951], ptr=[2], }, (gene/protein, ppi, gene/protein)={ edge_index=[2, 503367], pos_edge_label=[125841], pos_edge_label_index=[2, 125841], neg_edge_label=[125841], neg_edge_label_index=[2, 125841], }, (drug, carrier, gene/protein)={ edge_index=[2, 652], pos_edge_label=[162], pos_edge_label_index=[2, 162], neg_edge_label=[162], neg_edge_label_index=[2, 162], }, (drug, enzyme, gene/protein)={ edge_index=[2, 4048], pos_edge_label=[1012], pos_edge_label_index=[2, 1012], neg_edge_label=[1012], neg_edge_label_index=[2, 1012], }, (drug, target, gene/protein)={ edge_index=[2, 11936], pos_edge_label=[2984], pos_edge_label_index=[2, 2984], neg_edge_label=[2984], neg_edge_label_index=[2, 2984], }, (drug, transporter, gene/protein)={ edge_index=[2, 2360], pos_edge_label=[590], pos_edge_label_index=[2, 590], neg_edge_label=[590], neg_edge_label_index=[2, 590], }, (drug, contraindication, disease)={ edge_index=[2, 24032], pos_edge_label=[6008], pos_edge_label_index=[2, 6008], neg_edge_label=[6008], neg_edge_label_index=[2, 6008], }, (drug, indication, disease)={ edge_index=[2, 6997], pos_edge_label=[1749], pos_edge_label_index=[2, 1749], neg_edge_label=[1749], neg_edge_label_index=[2, 1749], }, (drug, off-label use, disease)={ edge_index=[2, 1956], pos_edge_label=[489], pos_edge_label_index=[2, 489], neg_edge_label=[489], neg_edge_label_index=[2, 489], }, (drug, synergistic interaction, drug)={ edge_index=[2, 1787365], pos_edge_label=[446841], 
pos_edge_label_index=[2, 446841], neg_edge_label=[446841], neg_edge_label_index=[2, 446841], }, (gene/protein, associated with, disease)={ edge_index=[2, 62370], pos_edge_label=[15592], pos_edge_label_index=[2, 15592], neg_edge_label=[15592], neg_edge_label_index=[2, 15592], }, (disease, parent-child, disease)={ edge_index=[2, 51511], pos_edge_label=[12877], pos_edge_label_index=[2, 12877], neg_edge_label=[12877], neg_edge_label_index=[2, 12877], }, (biological_process, parent-child, biological_process)={ edge_index=[2, 79704], pos_edge_label=[19926], pos_edge_label_index=[2, 19926], neg_edge_label=[19926], neg_edge_label_index=[2, 19926], }, (molecular_function, parent-child, molecular_function)={ edge_index=[2, 21149], pos_edge_label=[5287], pos_edge_label_index=[2, 5287], neg_edge_label=[5287], neg_edge_label_index=[2, 5287], }, (cellular_component, parent-child, cellular_component)={ edge_index=[2, 7360], pos_edge_label=[1840], pos_edge_label_index=[2, 1840], neg_edge_label=[1840], neg_edge_label_index=[2, 1840], }, (gene/protein, interacts with, molecular_function)={ edge_index=[2, 53409], pos_edge_label=[13352], pos_edge_label_index=[2, 13352], neg_edge_label=[13352], neg_edge_label_index=[2, 13352], }, (gene/protein, interacts with, cellular_component)={ edge_index=[2, 59802], pos_edge_label=[14950], pos_edge_label_index=[2, 14950], neg_edge_label=[14950], neg_edge_label_index=[2, 14950], }, (gene/protein, interacts with, biological_process)={ edge_index=[2, 109057], pos_edge_label=[27264], pos_edge_label_index=[2, 27264], neg_edge_label=[27264], neg_edge_label_index=[2, 27264], }, (gene/protein, carrier, drug)={ edge_index=[2, 652], pos_edge_label=[162], pos_edge_label_index=[2, 162], neg_edge_label=[162], neg_edge_label_index=[2, 162], }, (gene/protein, enzyme, drug)={ edge_index=[2, 4048], pos_edge_label=[1012], pos_edge_label_index=[2, 1012], neg_edge_label=[1012], neg_edge_label_index=[2, 1012], }, (gene/protein, target, drug)={ edge_index=[2, 11936], 
pos_edge_label=[2984], pos_edge_label_index=[2, 2984], neg_edge_label=[2984], neg_edge_label_index=[2, 2984], }, (gene/protein, transporter, drug)={ edge_index=[2, 2360], pos_edge_label=[590], pos_edge_label_index=[2, 590], neg_edge_label=[590], neg_edge_label_index=[2, 590], }, (disease, contraindication, drug)={ edge_index=[2, 24032], pos_edge_label=[6008], pos_edge_label_index=[2, 6008], neg_edge_label=[6008], neg_edge_label_index=[2, 6008], }, (disease, indication, drug)={ edge_index=[2, 6997], pos_edge_label=[1749], pos_edge_label_index=[2, 1749], neg_edge_label=[1749], neg_edge_label_index=[2, 1749], }, (disease, off-label use, drug)={ edge_index=[2, 1956], pos_edge_label=[489], pos_edge_label_index=[2, 489], neg_edge_label=[489], neg_edge_label_index=[2, 489], }, (disease, associated with, gene/protein)={ edge_index=[2, 62370], pos_edge_label=[15592], pos_edge_label_index=[2, 15592], neg_edge_label=[15592], neg_edge_label_index=[2, 15592], }, (molecular_function, interacts with, gene/protein)={ edge_index=[2, 53409], pos_edge_label=[13352], pos_edge_label_index=[2, 13352], neg_edge_label=[13352], neg_edge_label_index=[2, 13352], }, (cellular_component, interacts with, gene/protein)={ edge_index=[2, 59802], pos_edge_label=[14950], pos_edge_label_index=[2, 14950], neg_edge_label=[14950], neg_edge_label_index=[2, 14950], }, (biological_process, interacts with, gene/protein)={ edge_index=[2, 109057], pos_edge_label=[27264], pos_edge_label_index=[2, 27264], neg_edge_label=[27264], neg_edge_label_index=[2, 27264], } )
In [22]:
Copied!
# Print the message-passing edge_index of every relation type, per split.
# In the output below, val shares train's edge_index while test's is larger
# (it also contains the validation edges).
for split, title in (("train", "Train"), ("val", "Validation"), ("test", "Test")):
    print(f"{title} Edge Index:")
    print(dm.data[split].edge_index_dict)
Train Edge Index: {('gene/protein', 'ppi', 'gene/protein'): tensor([[ 7867, 2004, 13232, ..., 13920, 1095, 518], [ 4494, 12809, 8137, ..., 3937, 3186, 4023]]), ('drug', 'carrier', 'gene/protein'): tensor([[ 423, 132, 318, ..., 397, 279, 510], [4706, 4293, 2380, ..., 4293, 9441, 4706]]), ('drug', 'enzyme', 'gene/protein'): tensor([[ 171, 579, 456, ..., 1723, 908, 1639], [ 8900, 748, 1642, ..., 13047, 3936, 3854]]), ('drug', 'target', 'gene/protein'): tensor([[ 4628, 212, 795, ..., 751, 5172, 130], [ 1283, 10928, 5579, ..., 1011, 651, 375]]), ('drug', 'transporter', 'gene/protein'): tensor([[ 1560, 662, 237, ..., 179, 5533, 5558], [10907, 10907, 8392, ..., 15304, 5243, 1428]]), ('drug', 'contraindication', 'disease'): tensor([[ 165, 16, 254, ..., 29, 1648, 265], [ 8894, 10133, 5456, ..., 5990, 6872, 6356]]), ('drug', 'indication', 'disease'): tensor([[ 1300, 113, 18, ..., 1766, 142, 436], [ 5939, 6398, 294, ..., 6709, 7143, 11103]]), ('drug', 'off-label use', 'disease'): tensor([[ 5724, 918, 215, ..., 11, 986, 184], [11089, 2553, 11301, ..., 5632, 8811, 11118]]), ('drug', 'synergistic interaction', 'drug'): tensor([[4995, 2424, 1009, ..., 987, 189, 300], [1424, 6227, 1703, ..., 5498, 51, 5497]]), ('gene/protein', 'associated with', 'disease'): tensor([[ 7516, 702, 924, ..., 7836, 3193, 13385], [11206, 1621, 2484, ..., 10329, 941, 11206]]), ('disease', 'parent-child', 'disease'): tensor([[ 6194, 9545, 16592, ..., 6546, 6534, 3816], [ 6666, 9544, 8982, ..., 4638, 6746, 6601]]), ('biological_process', 'parent-child', 'biological_process'): tensor([[ 1736, 26394, 1411, ..., 18205, 1664, 5826], [ 6062, 7095, 2702, ..., 652, 2650, 17883]]), ('molecular_function', 'parent-child', 'molecular_function'): tensor([[ 9267, 218, 1506, ..., 70, 70, 4294], [ 1198, 3872, 10158, ..., 3260, 3049, 307]]), ('cellular_component', 'parent-child', 'cellular_component'): tensor([[2946, 718, 101, ..., 3209, 645, 4002], [ 412, 817, 2114, ..., 296, 108, 870]]), ('gene/protein', 'interacts 
with', 'molecular_function'): tensor([[ 357, 15134, 8707, ..., 9894, 5358, 8668], [ 9178, 10771, 1488, ..., 178, 5880, 99]]), ('gene/protein', 'interacts with', 'cellular_component'): tensor([[ 4620, 177, 1102, ..., 12796, 4907, 12381], [ 588, 2593, 2812, ..., 2516, 53, 501]]), ('gene/protein', 'interacts with', 'biological_process'): tensor([[ 145, 6636, 2193, ..., 4106, 10453, 7670], [ 792, 8433, 1581, ..., 20710, 5978, 8746]]), ('gene/protein', 'carrier', 'drug'): tensor([[9441, 2380, 111, ..., 4706, 4293, 4293], [ 156, 213, 65, ..., 537, 246, 432]]), ('gene/protein', 'enzyme', 'drug'): tensor([[10794, 6277, 1642, ..., 1278, 13620, 7821], [ 866, 1034, 1749, ..., 1003, 393, 35]]), ('gene/protein', 'target', 'drug'): tensor([[11138, 1918, 5238, ..., 16374, 969, 8999], [ 3400, 1124, 2314, ..., 1692, 2479, 4522]]), ('gene/protein', 'transporter', 'drug'): tensor([[ 3169, 4131, 10044, ..., 3169, 1308, 10044], [ 28, 1439, 814, ..., 1532, 5555, 36]]), ('disease', 'contraindication', 'drug'): tensor([[ 4479, 6446, 7005, ..., 11207, 7066, 11061], [ 1632, 704, 1002, ..., 696, 205, 886]]), ('disease', 'indication', 'drug'): tensor([[ 7701, 7038, 11177, ..., 1050, 2111, 5798], [ 37, 1851, 3261, ..., 3449, 479, 1637]]), ('disease', 'off-label use', 'drug'): tensor([[ 7635, 7564, 11063, ..., 7154, 5476, 4306], [ 685, 107, 1319, ..., 851, 5717, 714]]), ('disease', 'associated with', 'gene/protein'): tensor([[ 7350, 11490, 5448, ..., 8052, 1155, 346], [ 1614, 11112, 14110, ..., 12461, 3550, 10665]]), ('molecular_function', 'interacts with', 'gene/protein'): tensor([[ 178, 745, 1714, ..., 1785, 264, 7477], [4708, 4707, 2865, ..., 9079, 1717, 7181]]), ('cellular_component', 'interacts with', 'gene/protein'): tensor([[ 539, 3722, 397, ..., 876, 833, 501], [17865, 8619, 6032, ..., 1024, 9094, 7501]]), ('biological_process', 'interacts with', 'gene/protein'): tensor([[ 506, 97, 10817, ..., 8444, 19099, 20570], [ 2410, 7239, 8505, ..., 1470, 10809, 6874]])} Validation Edge Index: 
{('gene/protein', 'ppi', 'gene/protein'): tensor([[ 7867, 2004, 13232, ..., 13920, 1095, 518], [ 4494, 12809, 8137, ..., 3937, 3186, 4023]]), ('drug', 'carrier', 'gene/protein'): tensor([[ 423, 132, 318, ..., 397, 279, 510], [4706, 4293, 2380, ..., 4293, 9441, 4706]]), ('drug', 'enzyme', 'gene/protein'): tensor([[ 171, 579, 456, ..., 1723, 908, 1639], [ 8900, 748, 1642, ..., 13047, 3936, 3854]]), ('drug', 'target', 'gene/protein'): tensor([[ 4628, 212, 795, ..., 751, 5172, 130], [ 1283, 10928, 5579, ..., 1011, 651, 375]]), ('drug', 'transporter', 'gene/protein'): tensor([[ 1560, 662, 237, ..., 179, 5533, 5558], [10907, 10907, 8392, ..., 15304, 5243, 1428]]), ('drug', 'contraindication', 'disease'): tensor([[ 165, 16, 254, ..., 29, 1648, 265], [ 8894, 10133, 5456, ..., 5990, 6872, 6356]]), ('drug', 'indication', 'disease'): tensor([[ 1300, 113, 18, ..., 1766, 142, 436], [ 5939, 6398, 294, ..., 6709, 7143, 11103]]), ('drug', 'off-label use', 'disease'): tensor([[ 5724, 918, 215, ..., 11, 986, 184], [11089, 2553, 11301, ..., 5632, 8811, 11118]]), ('drug', 'synergistic interaction', 'drug'): tensor([[4995, 2424, 1009, ..., 987, 189, 300], [1424, 6227, 1703, ..., 5498, 51, 5497]]), ('gene/protein', 'associated with', 'disease'): tensor([[ 7516, 702, 924, ..., 7836, 3193, 13385], [11206, 1621, 2484, ..., 10329, 941, 11206]]), ('disease', 'parent-child', 'disease'): tensor([[ 6194, 9545, 16592, ..., 6546, 6534, 3816], [ 6666, 9544, 8982, ..., 4638, 6746, 6601]]), ('biological_process', 'parent-child', 'biological_process'): tensor([[ 1736, 26394, 1411, ..., 18205, 1664, 5826], [ 6062, 7095, 2702, ..., 652, 2650, 17883]]), ('molecular_function', 'parent-child', 'molecular_function'): tensor([[ 9267, 218, 1506, ..., 70, 70, 4294], [ 1198, 3872, 10158, ..., 3260, 3049, 307]]), ('cellular_component', 'parent-child', 'cellular_component'): tensor([[2946, 718, 101, ..., 3209, 645, 4002], [ 412, 817, 2114, ..., 296, 108, 870]]), ('gene/protein', 'interacts with', 
'molecular_function'): tensor([[ 357, 15134, 8707, ..., 9894, 5358, 8668], [ 9178, 10771, 1488, ..., 178, 5880, 99]]), ('gene/protein', 'interacts with', 'cellular_component'): tensor([[ 4620, 177, 1102, ..., 12796, 4907, 12381], [ 588, 2593, 2812, ..., 2516, 53, 501]]), ('gene/protein', 'interacts with', 'biological_process'): tensor([[ 145, 6636, 2193, ..., 4106, 10453, 7670], [ 792, 8433, 1581, ..., 20710, 5978, 8746]]), ('gene/protein', 'carrier', 'drug'): tensor([[9441, 2380, 111, ..., 4706, 4293, 4293], [ 156, 213, 65, ..., 537, 246, 432]]), ('gene/protein', 'enzyme', 'drug'): tensor([[10794, 6277, 1642, ..., 1278, 13620, 7821], [ 866, 1034, 1749, ..., 1003, 393, 35]]), ('gene/protein', 'target', 'drug'): tensor([[11138, 1918, 5238, ..., 16374, 969, 8999], [ 3400, 1124, 2314, ..., 1692, 2479, 4522]]), ('gene/protein', 'transporter', 'drug'): tensor([[ 3169, 4131, 10044, ..., 3169, 1308, 10044], [ 28, 1439, 814, ..., 1532, 5555, 36]]), ('disease', 'contraindication', 'drug'): tensor([[ 4479, 6446, 7005, ..., 11207, 7066, 11061], [ 1632, 704, 1002, ..., 696, 205, 886]]), ('disease', 'indication', 'drug'): tensor([[ 7701, 7038, 11177, ..., 1050, 2111, 5798], [ 37, 1851, 3261, ..., 3449, 479, 1637]]), ('disease', 'off-label use', 'drug'): tensor([[ 7635, 7564, 11063, ..., 7154, 5476, 4306], [ 685, 107, 1319, ..., 851, 5717, 714]]), ('disease', 'associated with', 'gene/protein'): tensor([[ 7350, 11490, 5448, ..., 8052, 1155, 346], [ 1614, 11112, 14110, ..., 12461, 3550, 10665]]), ('molecular_function', 'interacts with', 'gene/protein'): tensor([[ 178, 745, 1714, ..., 1785, 264, 7477], [4708, 4707, 2865, ..., 9079, 1717, 7181]]), ('cellular_component', 'interacts with', 'gene/protein'): tensor([[ 539, 3722, 397, ..., 876, 833, 501], [17865, 8619, 6032, ..., 1024, 9094, 7501]]), ('biological_process', 'interacts with', 'gene/protein'): tensor([[ 506, 97, 10817, ..., 8444, 19099, 20570], [ 2410, 7239, 8505, ..., 1470, 10809, 6874]])} Test Edge Index: 
{('gene/protein', 'ppi', 'gene/protein'): tensor([[ 7867, 2004, 13232, ..., 7199, 38, 569], [ 4494, 12809, 8137, ..., 12107, 25, 6890]]), ('drug', 'carrier', 'gene/protein'): tensor([[ 423, 132, 318, ..., 100, 28, 143], [4706, 4293, 2380, ..., 2506, 2380, 4293]]), ('drug', 'enzyme', 'gene/protein'): tensor([[ 171, 579, 456, ..., 259, 1113, 1034], [ 8900, 748, 1642, ..., 13047, 17831, 13620]]), ('drug', 'target', 'gene/protein'): tensor([[ 4628, 212, 795, ..., 3524, 1098, 1850], [ 1283, 10928, 5579, ..., 6068, 2464, 10018]]), ('drug', 'transporter', 'gene/protein'): tensor([[ 1560, 662, 237, ..., 701, 278, 394], [10907, 10907, 8392, ..., 16445, 4131, 3882]]), ('drug', 'contraindication', 'disease'): tensor([[ 165, 16, 254, ..., 978, 1898, 199], [ 8894, 10133, 5456, ..., 8393, 2068, 7742]]), ('drug', 'indication', 'disease'): tensor([[1300, 113, 18, ..., 5899, 2490, 1383], [5939, 6398, 294, ..., 7689, 8761, 9816]]), ('drug', 'off-label use', 'disease'): tensor([[ 5724, 918, 215, ..., 625, 247, 207], [11089, 2553, 11301, ..., 4574, 11118, 5967]]), ('drug', 'synergistic interaction', 'drug'): tensor([[4995, 2424, 1009, ..., 6316, 343, 1730], [1424, 6227, 1703, ..., 1061, 1595, 5620]]), ('gene/protein', 'associated with', 'disease'): tensor([[ 7516, 702, 924, ..., 8290, 11875, 1859], [11206, 1621, 2484, ..., 3864, 3723, 10065]]), ('disease', 'parent-child', 'disease'): tensor([[ 6194, 9545, 16592, ..., 14099, 9035, 9448], [ 6666, 9544, 8982, ..., 8874, 5088, 3521]]), ('biological_process', 'parent-child', 'biological_process'): tensor([[ 1736, 26394, 1411, ..., 542, 14657, 15967], [ 6062, 7095, 2702, ..., 7604, 703, 249]]), ('molecular_function', 'parent-child', 'molecular_function'): tensor([[ 9267, 218, 1506, ..., 110, 5353, 7666], [ 1198, 3872, 10158, ..., 3506, 85, 551]]), ('cellular_component', 'parent-child', 'cellular_component'): tensor([[2946, 718, 101, ..., 786, 241, 93], [ 412, 817, 2114, ..., 236, 425, 101]]), ('gene/protein', 'interacts with', 
'molecular_function'): tensor([[ 357, 15134, 8707, ..., 1724, 7704, 858], [ 9178, 10771, 1488, ..., 178, 178, 178]]), ('gene/protein', 'interacts with', 'cellular_component'): tensor([[ 4620, 177, 1102, ..., 6566, 13616, 6076], [ 588, 2593, 2812, ..., 235, 568, 506]]), ('gene/protein', 'interacts with', 'biological_process'): tensor([[ 145, 6636, 2193, ..., 15172, 6831, 12820], [ 792, 8433, 1581, ..., 1052, 10817, 5924]]), ('gene/protein', 'carrier', 'drug'): tensor([[9441, 2380, 111, ..., 4293, 4706, 4293], [ 156, 213, 65, ..., 176, 540, 233]]), ('gene/protein', 'enzyme', 'drug'): tensor([[10794, 6277, 1642, ..., 3936, 8130, 10316], [ 866, 1034, 1749, ..., 451, 807, 663]]), ('gene/protein', 'target', 'drug'): tensor([[11138, 1918, 5238, ..., 13708, 7977, 1369], [ 3400, 1124, 2314, ..., 5399, 1036, 2184]]), ('gene/protein', 'transporter', 'drug'): tensor([[ 3169, 4131, 10044, ..., 4131, 10205, 14035], [ 28, 1439, 814, ..., 1192, 546, 954]]), ('disease', 'contraindication', 'drug'): tensor([[4479, 6446, 7005, ..., 6691, 6898, 8872], [1632, 704, 1002, ..., 2128, 312, 155]]), ('disease', 'indication', 'drug'): tensor([[ 7701, 7038, 11177, ..., 6471, 5331, 7442], [ 37, 1851, 3261, ..., 8, 427, 5711]]), ('disease', 'off-label use', 'drug'): tensor([[ 7635, 7564, 11063, ..., 7214, 8260, 10936], [ 685, 107, 1319, ..., 2606, 2340, 1092]]), ('disease', 'associated with', 'gene/protein'): tensor([[ 7350, 11490, 5448, ..., 11106, 7074, 6793], [ 1614, 11112, 14110, ..., 8461, 2851, 5962]]), ('molecular_function', 'interacts with', 'gene/protein'): tensor([[ 178, 745, 1714, ..., 99, 356, 7045], [4708, 4707, 2865, ..., 394, 8937, 8589]]), ('cellular_component', 'interacts with', 'gene/protein'): tensor([[ 539, 3722, 397, ..., 652, 2606, 837], [17865, 8619, 6032, ..., 13789, 10362, 3445]]), ('biological_process', 'interacts with', 'gene/protein'): tensor([[ 506, 97, 10817, ..., 9461, 11852, 97], [ 2410, 7239, 8505, ..., 6481, 9023, 4207]])}
In [23]:
Copied!
# List the node types and (src, relation, dst) edge types of the training graph.
train_graph = dm.data["train"]
print("📌 Available node types:")
print(train_graph.node_types)
print("📌 Available edge types:")
print(train_graph.edge_types)
📌 Available node types: [np.str_('biological_process'), np.str_('cellular_component'), np.str_('disease'), np.str_('drug'), np.str_('gene/protein'), np.str_('molecular_function')] 📌 Available edge types: [('gene/protein', 'ppi', 'gene/protein'), ('drug', 'carrier', 'gene/protein'), ('drug', 'enzyme', 'gene/protein'), ('drug', 'target', 'gene/protein'), ('drug', 'transporter', 'gene/protein'), ('drug', 'contraindication', 'disease'), ('drug', 'indication', 'disease'), ('drug', 'off-label use', 'disease'), ('drug', 'synergistic interaction', 'drug'), ('gene/protein', 'associated with', 'disease'), ('disease', 'parent-child', 'disease'), ('biological_process', 'parent-child', 'biological_process'), ('molecular_function', 'parent-child', 'molecular_function'), ('cellular_component', 'parent-child', 'cellular_component'), ('gene/protein', 'interacts with', 'molecular_function'), ('gene/protein', 'interacts with', 'cellular_component'), ('gene/protein', 'interacts with', 'biological_process'), ('gene/protein', 'carrier', 'drug'), ('gene/protein', 'enzyme', 'drug'), ('gene/protein', 'target', 'drug'), ('gene/protein', 'transporter', 'drug'), ('disease', 'contraindication', 'drug'), ('disease', 'indication', 'drug'), ('disease', 'off-label use', 'drug'), ('disease', 'associated with', 'gene/protein'), ('molecular_function', 'interacts with', 'gene/protein'), ('cellular_component', 'interacts with', 'gene/protein'), ('biological_process', 'interacts with', 'gene/protein')]
Step 6: 🕸️ Visualize a Subgraph of the Knowledge Graph
Plot the 1-hop neighborhood of a node in the heterogeneous knowledge graph (example: DYNC1I2)
In [24]:
Copied!
import pickle
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from torch_geometric.utils import to_networkx

# Plot the 1-hop neighbourhood of one node of the cached heterogeneous KG.
#
# NOTE(review): pickle.load can execute arbitrary code from the file. This is
# only acceptable because the cache is assumed to be a local file written by
# the data module; never point CACHE_PATH at a file from an untrusted source.
CACHE_PATH = "biobridge_cache.pkl"


def _node_names(store):
    """Return the ``node_name`` array of a node store, or None if absent."""
    # PyG storages raise AttributeError (not return None) for missing keys.
    return getattr(store, "node_name", None)


def _decode(name):
    """Return *name* as text, decoding UTF-8 bytes if necessary."""
    return name.decode("utf-8") if isinstance(name, bytes) else name


# Load data
with open(CACHE_PATH, "rb") as f:
    data = pickle.load(f)
hetero_data = data["init"]
mapper = data.get("mapper", {})  # optional name -> index map; might be empty

target_node_type = "gene/protein"
target_node_name = "DYNC1I2"
# Last-resort index, used only when the name cannot be found in the cache.
manual_map = {"DYNC1I2": 6726}

# Resolve the target's local index. Compare against None explicitly: a valid
# index of 0 is falsy and must not fall through to the fallbacks.
local_idx = mapper.get(target_node_type, {}).get(target_node_name)
if local_idx is None:
    # Look the name up in the stored node names instead of trusting a
    # hard-coded index that may point at a different node.
    target_names = _node_names(hetero_data[target_node_type])
    if target_names is not None:
        for idx, name in enumerate(target_names):
            if _decode(name) == target_node_name:
                local_idx = idx
                break
if local_idx is None:
    local_idx = manual_map.get(target_node_name)
if local_idx is None:
    raise ValueError(f"Node {target_node_name} not found")

# Convert the whole graph to NetworkX (can be slow for a large KG).
# nx.Graph makes it undirected and merges parallel edges of different
# relation types into a single edge.
G_full = to_networkx(hetero_data)
G_full = nx.Graph(G_full)

# Relabel integer node ids with (type, local index). This assumes
# to_networkx numbers nodes consecutively per type, in
# hetero_data.node_types order.
offset = 0
mapping = {}
for node_type in hetero_data.node_types:
    n_nodes = hetero_data[node_type].num_nodes
    for i in range(n_nodes):
        mapping[offset + i] = (node_type, i)
    offset += n_nodes
G_full = nx.relabel_nodes(G_full, mapping)

target_node_nx = (target_node_type, local_idx)
if target_node_nx not in G_full:
    raise ValueError(f"Node {target_node_nx} not found in graph")

# Build the induced 1-hop subgraph (target + neighbours + edges among them).
neighbors = list(G_full.adj[target_node_nx])
sub_nodes = [target_node_nx] + neighbors
G_sub = G_full.subgraph(sub_nodes).copy()
print(G_sub)

# Map (type, index) -> human-readable name from hetero_data.node_name.
reverse_mapper = {}
for node_type in hetero_data.node_types:
    node_names = _node_names(hetero_data[node_type])
    if node_names is None:
        continue  # this node type carries no names
    for idx, name in enumerate(node_names):
        reverse_mapper[(node_type, idx)] = _decode(name)
# The target's name is known even if the cache has no name entry for it.
manual_reverse_map = {target_node_nx: target_node_name}
reverse_mapper.update(manual_reverse_map)

# Create labels using node names, fallback to "type_index"
labels = {
    node: reverse_mapper.get(node, f"{node[0]}_{node[1]}")
    for node in G_sub.nodes
}

# One colour per node type present in the subgraph.
node_types = {node: node[0] for node in G_sub.nodes}
unique_types = sorted(set(node_types.values()))
palette = list(mcolors.TABLEAU_COLORS.values())
color_map = {ntype: palette[i % len(palette)] for i, ntype in enumerate(unique_types)}
node_colors = [color_map[node_types[node]] for node in G_sub.nodes]

# Plot
plt.figure(figsize=(10, 8))
pos = nx.spring_layout(G_sub, seed=42)  # fixed seed -> reproducible layout
nx.draw_networkx_nodes(G_sub, pos, node_color=node_colors, node_size=500, alpha=0.85)
nx.draw_networkx_edges(G_sub, pos, edge_color='gray', alpha=0.6)
nx.draw_networkx_labels(G_sub, pos, labels, font_size=8)

# Legend: empty scatter plots act as colour swatches for each node type.
for ntype, color in color_map.items():
    plt.scatter([], [], color=color, label=ntype)
plt.legend(title="Node Types", loc="upper left")
plt.title(f"1-hop Neighbors of {target_node_name} in Knowledge Graph")
plt.axis("off")
plt.tight_layout()
plt.show()
import pickle
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from torch_geometric.utils import to_networkx

# Plot the 1-hop neighbourhood of one node of the cached heterogeneous KG.
#
# NOTE(review): pickle.load can execute arbitrary code from the file. This is
# only acceptable because the cache is assumed to be a local file written by
# the data module; never point CACHE_PATH at a file from an untrusted source.
CACHE_PATH = "biobridge_cache.pkl"


def _node_names(store):
    """Return the ``node_name`` array of a node store, or None if absent."""
    # PyG storages raise AttributeError (not return None) for missing keys.
    return getattr(store, "node_name", None)


def _decode(name):
    """Return *name* as text, decoding UTF-8 bytes if necessary."""
    return name.decode("utf-8") if isinstance(name, bytes) else name


# Load data
with open(CACHE_PATH, "rb") as f:
    data = pickle.load(f)
hetero_data = data["init"]
mapper = data.get("mapper", {})  # optional name -> index map; might be empty

target_node_type = "gene/protein"
target_node_name = "DYNC1I2"
# Last-resort index, used only when the name cannot be found in the cache.
manual_map = {"DYNC1I2": 6726}

# Resolve the target's local index. Compare against None explicitly: a valid
# index of 0 is falsy and must not fall through to the fallbacks.
local_idx = mapper.get(target_node_type, {}).get(target_node_name)
if local_idx is None:
    # Look the name up in the stored node names instead of trusting a
    # hard-coded index that may point at a different node.
    target_names = _node_names(hetero_data[target_node_type])
    if target_names is not None:
        for idx, name in enumerate(target_names):
            if _decode(name) == target_node_name:
                local_idx = idx
                break
if local_idx is None:
    local_idx = manual_map.get(target_node_name)
if local_idx is None:
    raise ValueError(f"Node {target_node_name} not found")

# Convert the whole graph to NetworkX (can be slow for a large KG).
# nx.Graph makes it undirected and merges parallel edges of different
# relation types into a single edge.
G_full = to_networkx(hetero_data)
G_full = nx.Graph(G_full)

# Relabel integer node ids with (type, local index). This assumes
# to_networkx numbers nodes consecutively per type, in
# hetero_data.node_types order.
offset = 0
mapping = {}
for node_type in hetero_data.node_types:
    n_nodes = hetero_data[node_type].num_nodes
    for i in range(n_nodes):
        mapping[offset + i] = (node_type, i)
    offset += n_nodes
G_full = nx.relabel_nodes(G_full, mapping)

target_node_nx = (target_node_type, local_idx)
if target_node_nx not in G_full:
    raise ValueError(f"Node {target_node_nx} not found in graph")

# Build the induced 1-hop subgraph (target + neighbours + edges among them).
neighbors = list(G_full.adj[target_node_nx])
sub_nodes = [target_node_nx] + neighbors
G_sub = G_full.subgraph(sub_nodes).copy()
print(G_sub)

# Map (type, index) -> human-readable name from hetero_data.node_name.
reverse_mapper = {}
for node_type in hetero_data.node_types:
    node_names = _node_names(hetero_data[node_type])
    if node_names is None:
        continue  # this node type carries no names
    for idx, name in enumerate(node_names):
        reverse_mapper[(node_type, idx)] = _decode(name)
# The target's name is known even if the cache has no name entry for it.
manual_reverse_map = {target_node_nx: target_node_name}
reverse_mapper.update(manual_reverse_map)

# Create labels using node names, fallback to "type_index"
labels = {
    node: reverse_mapper.get(node, f"{node[0]}_{node[1]}")
    for node in G_sub.nodes
}

# One colour per node type present in the subgraph.
node_types = {node: node[0] for node in G_sub.nodes}
unique_types = sorted(set(node_types.values()))
palette = list(mcolors.TABLEAU_COLORS.values())
color_map = {ntype: palette[i % len(palette)] for i, ntype in enumerate(unique_types)}
node_colors = [color_map[node_types[node]] for node in G_sub.nodes]

# Plot
plt.figure(figsize=(10, 8))
pos = nx.spring_layout(G_sub, seed=42)  # fixed seed -> reproducible layout
nx.draw_networkx_nodes(G_sub, pos, node_color=node_colors, node_size=500, alpha=0.85)
nx.draw_networkx_edges(G_sub, pos, edge_color='gray', alpha=0.6)
nx.draw_networkx_labels(G_sub, pos, labels, font_size=8)

# Legend: empty scatter plots act as colour swatches for each node type.
for ntype, color in color_map.items():
    plt.scatter([], [], color=color, label=ntype)
plt.legend(title="Node Types", loc="upper left")
plt.title(f"1-hop Neighbors of {target_node_name} in Knowledge Graph")
plt.axis("off")
plt.tight_layout()
plt.show()
Graph with 21 nodes and 47 edges