Knowledge Graph (KG) Dataloader with Lightning
This guide shows how to load and use a Knowledge Graph dataset with the PyTorch Lightning LightningDataModule.
To load the KG data and build the train, validation, and test dataloaders, follow the steps below:
Step 1: 📦 Import the module
In [13]:
Copied!
import os
import sys

# Go up to the root where `vpeleaderboard/` is located and make it importable.
# The membership check keeps re-running this cell from adding the same
# directory to sys.path more than once.
repo_root = os.path.abspath("../../")
if repo_root not in sys.path:
    sys.path.append(repo_root)
In [14]:
Copied!
from vpeleaderboard.data.src.kg.biobridge_datamodule_hetero import BioBridgeDataModule
from vpeleaderboard.data.src.kg.biobridge_datamodule_hetero import BioBridgeDataModule
Step 2: ⚙️ Initialize the BioBridgeDataModule
Compose the Hydra configuration for the dataset and pass it to the BioBridgeDataModule:
In [15]:
Copied!
import hydra

# Compose the "default" config found in the BioBRIDGE-PrimeKG config directory
# and build the data module from it. The data module is constructed while the
# Hydra context is still active, which is safe whether or not its constructor
# needs Hydra; `cfg` and `dm` both remain available after the block.
with hydra.initialize(config_path="../../vpeleaderboard/configs/data/kg/BioBRIDGE-PrimeKG", version_base=None):
    cfg = hydra.compose(config_name="default")
    dm = BioBridgeDataModule(cfg)
Step 3: 🧹 Prepare data
Prepare the KG data by loading and caching it:
In [16]:
Copied!
# Load data, embeddings, and node/edge mappings.
# Called once: a second call would only repeat the same work.
dm.prepare_data()
Step 4: 🧠 Setup the data splits
Split the data into training, validation, and test sets:
In [17]:
Copied!
# Build HeteroData and apply RandomLinkSplit.
# Called once: a second call would only repeat the split.
dm.setup()
Step 5: 🧪 Access the dataloaders
First confirm that the splits were created, then retrieve the standard Lightning dataloaders:
In [18]:
Copied!
# Sanity check: after setup() the keys should be 'init', 'train', 'val' and 'test'.
print(dm.data.keys())
dict_keys(['init', 'train', 'val', 'test'])
Training data
In [19]:
Copied!
# Fetch the training dataloader and print the first batch it yields.
train_loader = dm.train_dataloader()
train_batch = next(iter(train_loader))
print(train_batch)
HeteroDataBatch(
biological_process={
num_nodes=27409,
x=[27409, 768],
node_name=[1],
batch=[27409],
ptr=[2],
},
cellular_component={
num_nodes=4011,
x=[4011, 768],
node_name=[1],
batch=[4011],
ptr=[2],
},
disease={
num_nodes=17054,
x=[17054, 768],
node_name=[1],
batch=[17054],
ptr=[2],
},
drug={
num_nodes=6759,
x=[6759, 512],
node_name=[1],
batch=[6759],
ptr=[2],
},
gene/protein={
num_nodes=18797,
x=[18797, 2560],
node_name=[1],
batch=[18797],
ptr=[2],
},
molecular_function={
num_nodes=10951,
x=[10951, 768],
node_name=[1],
batch=[10951],
ptr=[2],
},
(gene/protein, ppi, gene/protein)={
edge_index=[2, 440447],
pos_edge_label=[440447],
pos_edge_label_index=[2, 440447],
neg_edge_label=[440447],
neg_edge_label_index=[2, 440447],
},
(drug, carrier, gene/protein)={
edge_index=[2, 571],
pos_edge_label=[571],
pos_edge_label_index=[2, 571],
neg_edge_label=[571],
neg_edge_label_index=[2, 571],
},
(drug, enzyme, gene/protein)={
edge_index=[2, 3542],
pos_edge_label=[3542],
pos_edge_label_index=[2, 3542],
neg_edge_label=[3542],
neg_edge_label_index=[2, 3542],
},
(drug, target, gene/protein)={
edge_index=[2, 10444],
pos_edge_label=[10444],
pos_edge_label_index=[2, 10444],
neg_edge_label=[10444],
neg_edge_label_index=[2, 10444],
},
(drug, transporter, gene/protein)={
edge_index=[2, 2065],
pos_edge_label=[2065],
pos_edge_label_index=[2, 2065],
neg_edge_label=[2065],
neg_edge_label_index=[2, 2065],
},
(drug, contraindication, disease)={
edge_index=[2, 21028],
pos_edge_label=[21028],
pos_edge_label_index=[2, 21028],
neg_edge_label=[21028],
neg_edge_label_index=[2, 21028],
},
(drug, indication, disease)={
edge_index=[2, 6123],
pos_edge_label=[6123],
pos_edge_label_index=[2, 6123],
neg_edge_label=[6123],
neg_edge_label_index=[2, 6123],
},
(drug, off-label use, disease)={
edge_index=[2, 1712],
pos_edge_label=[1712],
pos_edge_label_index=[2, 1712],
neg_edge_label=[1712],
neg_edge_label_index=[2, 1712],
},
(drug, synergistic interaction, drug)={
edge_index=[2, 1563945],
pos_edge_label=[1563945],
pos_edge_label_index=[2, 1563945],
neg_edge_label=[1563945],
neg_edge_label_index=[2, 1563945],
},
(gene/protein, associated with, disease)={
edge_index=[2, 54574],
pos_edge_label=[54574],
pos_edge_label_index=[2, 54574],
neg_edge_label=[54574],
neg_edge_label_index=[2, 54574],
},
(disease, parent-child, disease)={
edge_index=[2, 45073],
pos_edge_label=[45073],
pos_edge_label_index=[2, 45073],
neg_edge_label=[45073],
neg_edge_label_index=[2, 45073],
},
(biological_process, parent-child, biological_process)={
edge_index=[2, 69741],
pos_edge_label=[69741],
pos_edge_label_index=[2, 69741],
neg_edge_label=[69741],
neg_edge_label_index=[2, 69741],
},
(molecular_function, parent-child, molecular_function)={
edge_index=[2, 18506],
pos_edge_label=[18506],
pos_edge_label_index=[2, 18506],
neg_edge_label=[18506],
neg_edge_label_index=[2, 18506],
},
(cellular_component, parent-child, cellular_component)={
edge_index=[2, 6440],
pos_edge_label=[6440],
pos_edge_label_index=[2, 6440],
neg_edge_label=[6440],
neg_edge_label_index=[2, 6440],
},
(gene/protein, interacts with, molecular_function)={
edge_index=[2, 46733],
pos_edge_label=[46733],
pos_edge_label_index=[2, 46733],
neg_edge_label=[46733],
neg_edge_label_index=[2, 46733],
},
(gene/protein, interacts with, cellular_component)={
edge_index=[2, 52327],
pos_edge_label=[52327],
pos_edge_label_index=[2, 52327],
neg_edge_label=[52327],
neg_edge_label_index=[2, 52327],
},
(gene/protein, interacts with, biological_process)={
edge_index=[2, 95425],
pos_edge_label=[95425],
pos_edge_label_index=[2, 95425],
neg_edge_label=[95425],
neg_edge_label_index=[2, 95425],
},
(gene/protein, carrier, drug)={
edge_index=[2, 571],
pos_edge_label=[571],
pos_edge_label_index=[2, 571],
neg_edge_label=[571],
neg_edge_label_index=[2, 571],
},
(gene/protein, enzyme, drug)={
edge_index=[2, 3542],
pos_edge_label=[3542],
pos_edge_label_index=[2, 3542],
neg_edge_label=[3542],
neg_edge_label_index=[2, 3542],
},
(gene/protein, target, drug)={
edge_index=[2, 10444],
pos_edge_label=[10444],
pos_edge_label_index=[2, 10444],
neg_edge_label=[10444],
neg_edge_label_index=[2, 10444],
},
(gene/protein, transporter, drug)={
edge_index=[2, 2065],
pos_edge_label=[2065],
pos_edge_label_index=[2, 2065],
neg_edge_label=[2065],
neg_edge_label_index=[2, 2065],
},
(disease, contraindication, drug)={
edge_index=[2, 21028],
pos_edge_label=[21028],
pos_edge_label_index=[2, 21028],
neg_edge_label=[21028],
neg_edge_label_index=[2, 21028],
},
(disease, indication, drug)={
edge_index=[2, 6123],
pos_edge_label=[6123],
pos_edge_label_index=[2, 6123],
neg_edge_label=[6123],
neg_edge_label_index=[2, 6123],
},
(disease, off-label use, drug)={
edge_index=[2, 1712],
pos_edge_label=[1712],
pos_edge_label_index=[2, 1712],
neg_edge_label=[1712],
neg_edge_label_index=[2, 1712],
},
(disease, associated with, gene/protein)={
edge_index=[2, 54574],
pos_edge_label=[54574],
pos_edge_label_index=[2, 54574],
neg_edge_label=[54574],
neg_edge_label_index=[2, 54574],
},
(molecular_function, interacts with, gene/protein)={
edge_index=[2, 46733],
pos_edge_label=[46733],
pos_edge_label_index=[2, 46733],
neg_edge_label=[46733],
neg_edge_label_index=[2, 46733],
},
(cellular_component, interacts with, gene/protein)={
edge_index=[2, 52327],
pos_edge_label=[52327],
pos_edge_label_index=[2, 52327],
neg_edge_label=[52327],
neg_edge_label_index=[2, 52327],
},
(biological_process, interacts with, gene/protein)={
edge_index=[2, 95425],
pos_edge_label=[95425],
pos_edge_label_index=[2, 95425],
neg_edge_label=[95425],
neg_edge_label_index=[2, 95425],
}
)
Validation data
In [20]:
Copied!
# Fetch the validation dataloader and print the first batch it yields.
val_loader = dm.val_dataloader()
val_batch = next(iter(val_loader))
print(val_batch)
HeteroDataBatch(
biological_process={
num_nodes=27409,
x=[27409, 768],
node_name=[1],
batch=[27409],
ptr=[2],
},
cellular_component={
num_nodes=4011,
x=[4011, 768],
node_name=[1],
batch=[4011],
ptr=[2],
},
disease={
num_nodes=17054,
x=[17054, 768],
node_name=[1],
batch=[17054],
ptr=[2],
},
drug={
num_nodes=6759,
x=[6759, 512],
node_name=[1],
batch=[6759],
ptr=[2],
},
gene/protein={
num_nodes=18797,
x=[18797, 2560],
node_name=[1],
batch=[18797],
ptr=[2],
},
molecular_function={
num_nodes=10951,
x=[10951, 768],
node_name=[1],
batch=[10951],
ptr=[2],
},
(gene/protein, ppi, gene/protein)={
edge_index=[2, 440447],
pos_edge_label=[62920],
pos_edge_label_index=[2, 62920],
neg_edge_label=[62920],
neg_edge_label_index=[2, 62920],
},
(drug, carrier, gene/protein)={
edge_index=[2, 571],
pos_edge_label=[81],
pos_edge_label_index=[2, 81],
neg_edge_label=[81],
neg_edge_label_index=[2, 81],
},
(drug, enzyme, gene/protein)={
edge_index=[2, 3542],
pos_edge_label=[506],
pos_edge_label_index=[2, 506],
neg_edge_label=[506],
neg_edge_label_index=[2, 506],
},
(drug, target, gene/protein)={
edge_index=[2, 10444],
pos_edge_label=[1492],
pos_edge_label_index=[2, 1492],
neg_edge_label=[1492],
neg_edge_label_index=[2, 1492],
},
(drug, transporter, gene/protein)={
edge_index=[2, 2065],
pos_edge_label=[295],
pos_edge_label_index=[2, 295],
neg_edge_label=[295],
neg_edge_label_index=[2, 295],
},
(drug, contraindication, disease)={
edge_index=[2, 21028],
pos_edge_label=[3004],
pos_edge_label_index=[2, 3004],
neg_edge_label=[3004],
neg_edge_label_index=[2, 3004],
},
(drug, indication, disease)={
edge_index=[2, 6123],
pos_edge_label=[874],
pos_edge_label_index=[2, 874],
neg_edge_label=[874],
neg_edge_label_index=[2, 874],
},
(drug, off-label use, disease)={
edge_index=[2, 1712],
pos_edge_label=[244],
pos_edge_label_index=[2, 244],
neg_edge_label=[244],
neg_edge_label_index=[2, 244],
},
(drug, synergistic interaction, drug)={
edge_index=[2, 1563945],
pos_edge_label=[223420],
pos_edge_label_index=[2, 223420],
neg_edge_label=[223420],
neg_edge_label_index=[2, 223420],
},
(gene/protein, associated with, disease)={
edge_index=[2, 54574],
pos_edge_label=[7796],
pos_edge_label_index=[2, 7796],
neg_edge_label=[7796],
neg_edge_label_index=[2, 7796],
},
(disease, parent-child, disease)={
edge_index=[2, 45073],
pos_edge_label=[6438],
pos_edge_label_index=[2, 6438],
neg_edge_label=[6438],
neg_edge_label_index=[2, 6438],
},
(biological_process, parent-child, biological_process)={
edge_index=[2, 69741],
pos_edge_label=[9963],
pos_edge_label_index=[2, 9963],
neg_edge_label=[9963],
neg_edge_label_index=[2, 9963],
},
(molecular_function, parent-child, molecular_function)={
edge_index=[2, 18506],
pos_edge_label=[2643],
pos_edge_label_index=[2, 2643],
neg_edge_label=[2643],
neg_edge_label_index=[2, 2643],
},
(cellular_component, parent-child, cellular_component)={
edge_index=[2, 6440],
pos_edge_label=[920],
pos_edge_label_index=[2, 920],
neg_edge_label=[920],
neg_edge_label_index=[2, 920],
},
(gene/protein, interacts with, molecular_function)={
edge_index=[2, 46733],
pos_edge_label=[6676],
pos_edge_label_index=[2, 6676],
neg_edge_label=[6676],
neg_edge_label_index=[2, 6676],
},
(gene/protein, interacts with, cellular_component)={
edge_index=[2, 52327],
pos_edge_label=[7475],
pos_edge_label_index=[2, 7475],
neg_edge_label=[7475],
neg_edge_label_index=[2, 7475],
},
(gene/protein, interacts with, biological_process)={
edge_index=[2, 95425],
pos_edge_label=[13632],
pos_edge_label_index=[2, 13632],
neg_edge_label=[13632],
neg_edge_label_index=[2, 13632],
},
(gene/protein, carrier, drug)={
edge_index=[2, 571],
pos_edge_label=[81],
pos_edge_label_index=[2, 81],
neg_edge_label=[81],
neg_edge_label_index=[2, 81],
},
(gene/protein, enzyme, drug)={
edge_index=[2, 3542],
pos_edge_label=[506],
pos_edge_label_index=[2, 506],
neg_edge_label=[506],
neg_edge_label_index=[2, 506],
},
(gene/protein, target, drug)={
edge_index=[2, 10444],
pos_edge_label=[1492],
pos_edge_label_index=[2, 1492],
neg_edge_label=[1492],
neg_edge_label_index=[2, 1492],
},
(gene/protein, transporter, drug)={
edge_index=[2, 2065],
pos_edge_label=[295],
pos_edge_label_index=[2, 295],
neg_edge_label=[295],
neg_edge_label_index=[2, 295],
},
(disease, contraindication, drug)={
edge_index=[2, 21028],
pos_edge_label=[3004],
pos_edge_label_index=[2, 3004],
neg_edge_label=[3004],
neg_edge_label_index=[2, 3004],
},
(disease, indication, drug)={
edge_index=[2, 6123],
pos_edge_label=[874],
pos_edge_label_index=[2, 874],
neg_edge_label=[874],
neg_edge_label_index=[2, 874],
},
(disease, off-label use, drug)={
edge_index=[2, 1712],
pos_edge_label=[244],
pos_edge_label_index=[2, 244],
neg_edge_label=[244],
neg_edge_label_index=[2, 244],
},
(disease, associated with, gene/protein)={
edge_index=[2, 54574],
pos_edge_label=[7796],
pos_edge_label_index=[2, 7796],
neg_edge_label=[7796],
neg_edge_label_index=[2, 7796],
},
(molecular_function, interacts with, gene/protein)={
edge_index=[2, 46733],
pos_edge_label=[6676],
pos_edge_label_index=[2, 6676],
neg_edge_label=[6676],
neg_edge_label_index=[2, 6676],
},
(cellular_component, interacts with, gene/protein)={
edge_index=[2, 52327],
pos_edge_label=[7475],
pos_edge_label_index=[2, 7475],
neg_edge_label=[7475],
neg_edge_label_index=[2, 7475],
},
(biological_process, interacts with, gene/protein)={
edge_index=[2, 95425],
pos_edge_label=[13632],
pos_edge_label_index=[2, 13632],
neg_edge_label=[13632],
neg_edge_label_index=[2, 13632],
}
)
Test data
In [21]:
Copied!
# Fetch the test dataloader and print the first batch it yields.
test_loader = dm.test_dataloader()
test_batch = next(iter(test_loader))
print(test_batch)
HeteroDataBatch(
biological_process={
num_nodes=27409,
x=[27409, 768],
node_name=[1],
batch=[27409],
ptr=[2],
},
cellular_component={
num_nodes=4011,
x=[4011, 768],
node_name=[1],
batch=[4011],
ptr=[2],
},
disease={
num_nodes=17054,
x=[17054, 768],
node_name=[1],
batch=[17054],
ptr=[2],
},
drug={
num_nodes=6759,
x=[6759, 512],
node_name=[1],
batch=[6759],
ptr=[2],
},
gene/protein={
num_nodes=18797,
x=[18797, 2560],
node_name=[1],
batch=[18797],
ptr=[2],
},
molecular_function={
num_nodes=10951,
x=[10951, 768],
node_name=[1],
batch=[10951],
ptr=[2],
},
(gene/protein, ppi, gene/protein)={
edge_index=[2, 503367],
pos_edge_label=[125841],
pos_edge_label_index=[2, 125841],
neg_edge_label=[125841],
neg_edge_label_index=[2, 125841],
},
(drug, carrier, gene/protein)={
edge_index=[2, 652],
pos_edge_label=[162],
pos_edge_label_index=[2, 162],
neg_edge_label=[162],
neg_edge_label_index=[2, 162],
},
(drug, enzyme, gene/protein)={
edge_index=[2, 4048],
pos_edge_label=[1012],
pos_edge_label_index=[2, 1012],
neg_edge_label=[1012],
neg_edge_label_index=[2, 1012],
},
(drug, target, gene/protein)={
edge_index=[2, 11936],
pos_edge_label=[2984],
pos_edge_label_index=[2, 2984],
neg_edge_label=[2984],
neg_edge_label_index=[2, 2984],
},
(drug, transporter, gene/protein)={
edge_index=[2, 2360],
pos_edge_label=[590],
pos_edge_label_index=[2, 590],
neg_edge_label=[590],
neg_edge_label_index=[2, 590],
},
(drug, contraindication, disease)={
edge_index=[2, 24032],
pos_edge_label=[6008],
pos_edge_label_index=[2, 6008],
neg_edge_label=[6008],
neg_edge_label_index=[2, 6008],
},
(drug, indication, disease)={
edge_index=[2, 6997],
pos_edge_label=[1749],
pos_edge_label_index=[2, 1749],
neg_edge_label=[1749],
neg_edge_label_index=[2, 1749],
},
(drug, off-label use, disease)={
edge_index=[2, 1956],
pos_edge_label=[489],
pos_edge_label_index=[2, 489],
neg_edge_label=[489],
neg_edge_label_index=[2, 489],
},
(drug, synergistic interaction, drug)={
edge_index=[2, 1787365],
pos_edge_label=[446841],
pos_edge_label_index=[2, 446841],
neg_edge_label=[446841],
neg_edge_label_index=[2, 446841],
},
(gene/protein, associated with, disease)={
edge_index=[2, 62370],
pos_edge_label=[15592],
pos_edge_label_index=[2, 15592],
neg_edge_label=[15592],
neg_edge_label_index=[2, 15592],
},
(disease, parent-child, disease)={
edge_index=[2, 51511],
pos_edge_label=[12877],
pos_edge_label_index=[2, 12877],
neg_edge_label=[12877],
neg_edge_label_index=[2, 12877],
},
(biological_process, parent-child, biological_process)={
edge_index=[2, 79704],
pos_edge_label=[19926],
pos_edge_label_index=[2, 19926],
neg_edge_label=[19926],
neg_edge_label_index=[2, 19926],
},
(molecular_function, parent-child, molecular_function)={
edge_index=[2, 21149],
pos_edge_label=[5287],
pos_edge_label_index=[2, 5287],
neg_edge_label=[5287],
neg_edge_label_index=[2, 5287],
},
(cellular_component, parent-child, cellular_component)={
edge_index=[2, 7360],
pos_edge_label=[1840],
pos_edge_label_index=[2, 1840],
neg_edge_label=[1840],
neg_edge_label_index=[2, 1840],
},
(gene/protein, interacts with, molecular_function)={
edge_index=[2, 53409],
pos_edge_label=[13352],
pos_edge_label_index=[2, 13352],
neg_edge_label=[13352],
neg_edge_label_index=[2, 13352],
},
(gene/protein, interacts with, cellular_component)={
edge_index=[2, 59802],
pos_edge_label=[14950],
pos_edge_label_index=[2, 14950],
neg_edge_label=[14950],
neg_edge_label_index=[2, 14950],
},
(gene/protein, interacts with, biological_process)={
edge_index=[2, 109057],
pos_edge_label=[27264],
pos_edge_label_index=[2, 27264],
neg_edge_label=[27264],
neg_edge_label_index=[2, 27264],
},
(gene/protein, carrier, drug)={
edge_index=[2, 652],
pos_edge_label=[162],
pos_edge_label_index=[2, 162],
neg_edge_label=[162],
neg_edge_label_index=[2, 162],
},
(gene/protein, enzyme, drug)={
edge_index=[2, 4048],
pos_edge_label=[1012],
pos_edge_label_index=[2, 1012],
neg_edge_label=[1012],
neg_edge_label_index=[2, 1012],
},
(gene/protein, target, drug)={
edge_index=[2, 11936],
pos_edge_label=[2984],
pos_edge_label_index=[2, 2984],
neg_edge_label=[2984],
neg_edge_label_index=[2, 2984],
},
(gene/protein, transporter, drug)={
edge_index=[2, 2360],
pos_edge_label=[590],
pos_edge_label_index=[2, 590],
neg_edge_label=[590],
neg_edge_label_index=[2, 590],
},
(disease, contraindication, drug)={
edge_index=[2, 24032],
pos_edge_label=[6008],
pos_edge_label_index=[2, 6008],
neg_edge_label=[6008],
neg_edge_label_index=[2, 6008],
},
(disease, indication, drug)={
edge_index=[2, 6997],
pos_edge_label=[1749],
pos_edge_label_index=[2, 1749],
neg_edge_label=[1749],
neg_edge_label_index=[2, 1749],
},
(disease, off-label use, drug)={
edge_index=[2, 1956],
pos_edge_label=[489],
pos_edge_label_index=[2, 489],
neg_edge_label=[489],
neg_edge_label_index=[2, 489],
},
(disease, associated with, gene/protein)={
edge_index=[2, 62370],
pos_edge_label=[15592],
pos_edge_label_index=[2, 15592],
neg_edge_label=[15592],
neg_edge_label_index=[2, 15592],
},
(molecular_function, interacts with, gene/protein)={
edge_index=[2, 53409],
pos_edge_label=[13352],
pos_edge_label_index=[2, 13352],
neg_edge_label=[13352],
neg_edge_label_index=[2, 13352],
},
(cellular_component, interacts with, gene/protein)={
edge_index=[2, 59802],
pos_edge_label=[14950],
pos_edge_label_index=[2, 14950],
neg_edge_label=[14950],
neg_edge_label_index=[2, 14950],
},
(biological_process, interacts with, gene/protein)={
edge_index=[2, 109057],
pos_edge_label=[27264],
pos_edge_label_index=[2, 27264],
neg_edge_label=[27264],
neg_edge_label_index=[2, 27264],
}
)
In [22]:
Copied!
# Print the edge_index dictionary of every split. The titles and the order
# (train, validation, test) are the same as in the original three print pairs.
for title, split in (("Train", "train"), ("Validation", "val"), ("Test", "test")):
    print(f"{title} Edge Index:")
    print(dm.data[split].edge_index_dict)
Train Edge Index:
{('gene/protein', 'ppi', 'gene/protein'): tensor([[ 7867, 2004, 13232, ..., 13920, 1095, 518],
[ 4494, 12809, 8137, ..., 3937, 3186, 4023]]), ('drug', 'carrier', 'gene/protein'): tensor([[ 423, 132, 318, ..., 397, 279, 510],
[4706, 4293, 2380, ..., 4293, 9441, 4706]]), ('drug', 'enzyme', 'gene/protein'): tensor([[ 171, 579, 456, ..., 1723, 908, 1639],
[ 8900, 748, 1642, ..., 13047, 3936, 3854]]), ('drug', 'target', 'gene/protein'): tensor([[ 4628, 212, 795, ..., 751, 5172, 130],
[ 1283, 10928, 5579, ..., 1011, 651, 375]]), ('drug', 'transporter', 'gene/protein'): tensor([[ 1560, 662, 237, ..., 179, 5533, 5558],
[10907, 10907, 8392, ..., 15304, 5243, 1428]]), ('drug', 'contraindication', 'disease'): tensor([[ 165, 16, 254, ..., 29, 1648, 265],
[ 8894, 10133, 5456, ..., 5990, 6872, 6356]]), ('drug', 'indication', 'disease'): tensor([[ 1300, 113, 18, ..., 1766, 142, 436],
[ 5939, 6398, 294, ..., 6709, 7143, 11103]]), ('drug', 'off-label use', 'disease'): tensor([[ 5724, 918, 215, ..., 11, 986, 184],
[11089, 2553, 11301, ..., 5632, 8811, 11118]]), ('drug', 'synergistic interaction', 'drug'): tensor([[4995, 2424, 1009, ..., 987, 189, 300],
[1424, 6227, 1703, ..., 5498, 51, 5497]]), ('gene/protein', 'associated with', 'disease'): tensor([[ 7516, 702, 924, ..., 7836, 3193, 13385],
[11206, 1621, 2484, ..., 10329, 941, 11206]]), ('disease', 'parent-child', 'disease'): tensor([[ 6194, 9545, 16592, ..., 6546, 6534, 3816],
[ 6666, 9544, 8982, ..., 4638, 6746, 6601]]), ('biological_process', 'parent-child', 'biological_process'): tensor([[ 1736, 26394, 1411, ..., 18205, 1664, 5826],
[ 6062, 7095, 2702, ..., 652, 2650, 17883]]), ('molecular_function', 'parent-child', 'molecular_function'): tensor([[ 9267, 218, 1506, ..., 70, 70, 4294],
[ 1198, 3872, 10158, ..., 3260, 3049, 307]]), ('cellular_component', 'parent-child', 'cellular_component'): tensor([[2946, 718, 101, ..., 3209, 645, 4002],
[ 412, 817, 2114, ..., 296, 108, 870]]), ('gene/protein', 'interacts with', 'molecular_function'): tensor([[ 357, 15134, 8707, ..., 9894, 5358, 8668],
[ 9178, 10771, 1488, ..., 178, 5880, 99]]), ('gene/protein', 'interacts with', 'cellular_component'): tensor([[ 4620, 177, 1102, ..., 12796, 4907, 12381],
[ 588, 2593, 2812, ..., 2516, 53, 501]]), ('gene/protein', 'interacts with', 'biological_process'): tensor([[ 145, 6636, 2193, ..., 4106, 10453, 7670],
[ 792, 8433, 1581, ..., 20710, 5978, 8746]]), ('gene/protein', 'carrier', 'drug'): tensor([[9441, 2380, 111, ..., 4706, 4293, 4293],
[ 156, 213, 65, ..., 537, 246, 432]]), ('gene/protein', 'enzyme', 'drug'): tensor([[10794, 6277, 1642, ..., 1278, 13620, 7821],
[ 866, 1034, 1749, ..., 1003, 393, 35]]), ('gene/protein', 'target', 'drug'): tensor([[11138, 1918, 5238, ..., 16374, 969, 8999],
[ 3400, 1124, 2314, ..., 1692, 2479, 4522]]), ('gene/protein', 'transporter', 'drug'): tensor([[ 3169, 4131, 10044, ..., 3169, 1308, 10044],
[ 28, 1439, 814, ..., 1532, 5555, 36]]), ('disease', 'contraindication', 'drug'): tensor([[ 4479, 6446, 7005, ..., 11207, 7066, 11061],
[ 1632, 704, 1002, ..., 696, 205, 886]]), ('disease', 'indication', 'drug'): tensor([[ 7701, 7038, 11177, ..., 1050, 2111, 5798],
[ 37, 1851, 3261, ..., 3449, 479, 1637]]), ('disease', 'off-label use', 'drug'): tensor([[ 7635, 7564, 11063, ..., 7154, 5476, 4306],
[ 685, 107, 1319, ..., 851, 5717, 714]]), ('disease', 'associated with', 'gene/protein'): tensor([[ 7350, 11490, 5448, ..., 8052, 1155, 346],
[ 1614, 11112, 14110, ..., 12461, 3550, 10665]]), ('molecular_function', 'interacts with', 'gene/protein'): tensor([[ 178, 745, 1714, ..., 1785, 264, 7477],
[4708, 4707, 2865, ..., 9079, 1717, 7181]]), ('cellular_component', 'interacts with', 'gene/protein'): tensor([[ 539, 3722, 397, ..., 876, 833, 501],
[17865, 8619, 6032, ..., 1024, 9094, 7501]]), ('biological_process', 'interacts with', 'gene/protein'): tensor([[ 506, 97, 10817, ..., 8444, 19099, 20570],
[ 2410, 7239, 8505, ..., 1470, 10809, 6874]])}
Validation Edge Index:
{('gene/protein', 'ppi', 'gene/protein'): tensor([[ 7867, 2004, 13232, ..., 13920, 1095, 518],
[ 4494, 12809, 8137, ..., 3937, 3186, 4023]]), ('drug', 'carrier', 'gene/protein'): tensor([[ 423, 132, 318, ..., 397, 279, 510],
[4706, 4293, 2380, ..., 4293, 9441, 4706]]), ('drug', 'enzyme', 'gene/protein'): tensor([[ 171, 579, 456, ..., 1723, 908, 1639],
[ 8900, 748, 1642, ..., 13047, 3936, 3854]]), ('drug', 'target', 'gene/protein'): tensor([[ 4628, 212, 795, ..., 751, 5172, 130],
[ 1283, 10928, 5579, ..., 1011, 651, 375]]), ('drug', 'transporter', 'gene/protein'): tensor([[ 1560, 662, 237, ..., 179, 5533, 5558],
[10907, 10907, 8392, ..., 15304, 5243, 1428]]), ('drug', 'contraindication', 'disease'): tensor([[ 165, 16, 254, ..., 29, 1648, 265],
[ 8894, 10133, 5456, ..., 5990, 6872, 6356]]), ('drug', 'indication', 'disease'): tensor([[ 1300, 113, 18, ..., 1766, 142, 436],
[ 5939, 6398, 294, ..., 6709, 7143, 11103]]), ('drug', 'off-label use', 'disease'): tensor([[ 5724, 918, 215, ..., 11, 986, 184],
[11089, 2553, 11301, ..., 5632, 8811, 11118]]), ('drug', 'synergistic interaction', 'drug'): tensor([[4995, 2424, 1009, ..., 987, 189, 300],
[1424, 6227, 1703, ..., 5498, 51, 5497]]), ('gene/protein', 'associated with', 'disease'): tensor([[ 7516, 702, 924, ..., 7836, 3193, 13385],
[11206, 1621, 2484, ..., 10329, 941, 11206]]), ('disease', 'parent-child', 'disease'): tensor([[ 6194, 9545, 16592, ..., 6546, 6534, 3816],
[ 6666, 9544, 8982, ..., 4638, 6746, 6601]]), ('biological_process', 'parent-child', 'biological_process'): tensor([[ 1736, 26394, 1411, ..., 18205, 1664, 5826],
[ 6062, 7095, 2702, ..., 652, 2650, 17883]]), ('molecular_function', 'parent-child', 'molecular_function'): tensor([[ 9267, 218, 1506, ..., 70, 70, 4294],
[ 1198, 3872, 10158, ..., 3260, 3049, 307]]), ('cellular_component', 'parent-child', 'cellular_component'): tensor([[2946, 718, 101, ..., 3209, 645, 4002],
[ 412, 817, 2114, ..., 296, 108, 870]]), ('gene/protein', 'interacts with', 'molecular_function'): tensor([[ 357, 15134, 8707, ..., 9894, 5358, 8668],
[ 9178, 10771, 1488, ..., 178, 5880, 99]]), ('gene/protein', 'interacts with', 'cellular_component'): tensor([[ 4620, 177, 1102, ..., 12796, 4907, 12381],
[ 588, 2593, 2812, ..., 2516, 53, 501]]), ('gene/protein', 'interacts with', 'biological_process'): tensor([[ 145, 6636, 2193, ..., 4106, 10453, 7670],
[ 792, 8433, 1581, ..., 20710, 5978, 8746]]), ('gene/protein', 'carrier', 'drug'): tensor([[9441, 2380, 111, ..., 4706, 4293, 4293],
[ 156, 213, 65, ..., 537, 246, 432]]), ('gene/protein', 'enzyme', 'drug'): tensor([[10794, 6277, 1642, ..., 1278, 13620, 7821],
[ 866, 1034, 1749, ..., 1003, 393, 35]]), ('gene/protein', 'target', 'drug'): tensor([[11138, 1918, 5238, ..., 16374, 969, 8999],
[ 3400, 1124, 2314, ..., 1692, 2479, 4522]]), ('gene/protein', 'transporter', 'drug'): tensor([[ 3169, 4131, 10044, ..., 3169, 1308, 10044],
[ 28, 1439, 814, ..., 1532, 5555, 36]]), ('disease', 'contraindication', 'drug'): tensor([[ 4479, 6446, 7005, ..., 11207, 7066, 11061],
[ 1632, 704, 1002, ..., 696, 205, 886]]), ('disease', 'indication', 'drug'): tensor([[ 7701, 7038, 11177, ..., 1050, 2111, 5798],
[ 37, 1851, 3261, ..., 3449, 479, 1637]]), ('disease', 'off-label use', 'drug'): tensor([[ 7635, 7564, 11063, ..., 7154, 5476, 4306],
[ 685, 107, 1319, ..., 851, 5717, 714]]), ('disease', 'associated with', 'gene/protein'): tensor([[ 7350, 11490, 5448, ..., 8052, 1155, 346],
[ 1614, 11112, 14110, ..., 12461, 3550, 10665]]), ('molecular_function', 'interacts with', 'gene/protein'): tensor([[ 178, 745, 1714, ..., 1785, 264, 7477],
[4708, 4707, 2865, ..., 9079, 1717, 7181]]), ('cellular_component', 'interacts with', 'gene/protein'): tensor([[ 539, 3722, 397, ..., 876, 833, 501],
[17865, 8619, 6032, ..., 1024, 9094, 7501]]), ('biological_process', 'interacts with', 'gene/protein'): tensor([[ 506, 97, 10817, ..., 8444, 19099, 20570],
[ 2410, 7239, 8505, ..., 1470, 10809, 6874]])}
Test Edge Index:
{('gene/protein', 'ppi', 'gene/protein'): tensor([[ 7867, 2004, 13232, ..., 7199, 38, 569],
[ 4494, 12809, 8137, ..., 12107, 25, 6890]]), ('drug', 'carrier', 'gene/protein'): tensor([[ 423, 132, 318, ..., 100, 28, 143],
[4706, 4293, 2380, ..., 2506, 2380, 4293]]), ('drug', 'enzyme', 'gene/protein'): tensor([[ 171, 579, 456, ..., 259, 1113, 1034],
[ 8900, 748, 1642, ..., 13047, 17831, 13620]]), ('drug', 'target', 'gene/protein'): tensor([[ 4628, 212, 795, ..., 3524, 1098, 1850],
[ 1283, 10928, 5579, ..., 6068, 2464, 10018]]), ('drug', 'transporter', 'gene/protein'): tensor([[ 1560, 662, 237, ..., 701, 278, 394],
[10907, 10907, 8392, ..., 16445, 4131, 3882]]), ('drug', 'contraindication', 'disease'): tensor([[ 165, 16, 254, ..., 978, 1898, 199],
[ 8894, 10133, 5456, ..., 8393, 2068, 7742]]), ('drug', 'indication', 'disease'): tensor([[1300, 113, 18, ..., 5899, 2490, 1383],
[5939, 6398, 294, ..., 7689, 8761, 9816]]), ('drug', 'off-label use', 'disease'): tensor([[ 5724, 918, 215, ..., 625, 247, 207],
[11089, 2553, 11301, ..., 4574, 11118, 5967]]), ('drug', 'synergistic interaction', 'drug'): tensor([[4995, 2424, 1009, ..., 6316, 343, 1730],
[1424, 6227, 1703, ..., 1061, 1595, 5620]]), ('gene/protein', 'associated with', 'disease'): tensor([[ 7516, 702, 924, ..., 8290, 11875, 1859],
[11206, 1621, 2484, ..., 3864, 3723, 10065]]), ('disease', 'parent-child', 'disease'): tensor([[ 6194, 9545, 16592, ..., 14099, 9035, 9448],
[ 6666, 9544, 8982, ..., 8874, 5088, 3521]]), ('biological_process', 'parent-child', 'biological_process'): tensor([[ 1736, 26394, 1411, ..., 542, 14657, 15967],
[ 6062, 7095, 2702, ..., 7604, 703, 249]]), ('molecular_function', 'parent-child', 'molecular_function'): tensor([[ 9267, 218, 1506, ..., 110, 5353, 7666],
[ 1198, 3872, 10158, ..., 3506, 85, 551]]), ('cellular_component', 'parent-child', 'cellular_component'): tensor([[2946, 718, 101, ..., 786, 241, 93],
[ 412, 817, 2114, ..., 236, 425, 101]]), ('gene/protein', 'interacts with', 'molecular_function'): tensor([[ 357, 15134, 8707, ..., 1724, 7704, 858],
[ 9178, 10771, 1488, ..., 178, 178, 178]]), ('gene/protein', 'interacts with', 'cellular_component'): tensor([[ 4620, 177, 1102, ..., 6566, 13616, 6076],
[ 588, 2593, 2812, ..., 235, 568, 506]]), ('gene/protein', 'interacts with', 'biological_process'): tensor([[ 145, 6636, 2193, ..., 15172, 6831, 12820],
[ 792, 8433, 1581, ..., 1052, 10817, 5924]]), ('gene/protein', 'carrier', 'drug'): tensor([[9441, 2380, 111, ..., 4293, 4706, 4293],
[ 156, 213, 65, ..., 176, 540, 233]]), ('gene/protein', 'enzyme', 'drug'): tensor([[10794, 6277, 1642, ..., 3936, 8130, 10316],
[ 866, 1034, 1749, ..., 451, 807, 663]]), ('gene/protein', 'target', 'drug'): tensor([[11138, 1918, 5238, ..., 13708, 7977, 1369],
[ 3400, 1124, 2314, ..., 5399, 1036, 2184]]), ('gene/protein', 'transporter', 'drug'): tensor([[ 3169, 4131, 10044, ..., 4131, 10205, 14035],
[ 28, 1439, 814, ..., 1192, 546, 954]]), ('disease', 'contraindication', 'drug'): tensor([[4479, 6446, 7005, ..., 6691, 6898, 8872],
[1632, 704, 1002, ..., 2128, 312, 155]]), ('disease', 'indication', 'drug'): tensor([[ 7701, 7038, 11177, ..., 6471, 5331, 7442],
[ 37, 1851, 3261, ..., 8, 427, 5711]]), ('disease', 'off-label use', 'drug'): tensor([[ 7635, 7564, 11063, ..., 7214, 8260, 10936],
[ 685, 107, 1319, ..., 2606, 2340, 1092]]), ('disease', 'associated with', 'gene/protein'): tensor([[ 7350, 11490, 5448, ..., 11106, 7074, 6793],
[ 1614, 11112, 14110, ..., 8461, 2851, 5962]]), ('molecular_function', 'interacts with', 'gene/protein'): tensor([[ 178, 745, 1714, ..., 99, 356, 7045],
[4708, 4707, 2865, ..., 394, 8937, 8589]]), ('cellular_component', 'interacts with', 'gene/protein'): tensor([[ 539, 3722, 397, ..., 652, 2606, 837],
[17865, 8619, 6032, ..., 13789, 10362, 3445]]), ('biological_process', 'interacts with', 'gene/protein'): tensor([[ 506, 97, 10817, ..., 9461, 11852, 97],
[ 2410, 7239, 8505, ..., 6481, 9023, 4207]])}
In [23]:
Copied!
# List the node types and the (source, relation, target) edge types that are
# present in the training split.
print("📌 Available node types:")
print(dm.data["train"].node_types)
print("📌 Available edge types:")
print(dm.data["train"].edge_types)
📌 Available node types:
[np.str_('biological_process'), np.str_('cellular_component'), np.str_('disease'), np.str_('drug'), np.str_('gene/protein'), np.str_('molecular_function')]
📌 Available edge types:
[('gene/protein', 'ppi', 'gene/protein'), ('drug', 'carrier', 'gene/protein'), ('drug', 'enzyme', 'gene/protein'), ('drug', 'target', 'gene/protein'), ('drug', 'transporter', 'gene/protein'), ('drug', 'contraindication', 'disease'), ('drug', 'indication', 'disease'), ('drug', 'off-label use', 'disease'), ('drug', 'synergistic interaction', 'drug'), ('gene/protein', 'associated with', 'disease'), ('disease', 'parent-child', 'disease'), ('biological_process', 'parent-child', 'biological_process'), ('molecular_function', 'parent-child', 'molecular_function'), ('cellular_component', 'parent-child', 'cellular_component'), ('gene/protein', 'interacts with', 'molecular_function'), ('gene/protein', 'interacts with', 'cellular_component'), ('gene/protein', 'interacts with', 'biological_process'), ('gene/protein', 'carrier', 'drug'), ('gene/protein', 'enzyme', 'drug'), ('gene/protein', 'target', 'drug'), ('gene/protein', 'transporter', 'drug'), ('disease', 'contraindication', 'drug'), ('disease', 'indication', 'drug'), ('disease', 'off-label use', 'drug'), ('disease', 'associated with', 'gene/protein'), ('molecular_function', 'interacts with', 'gene/protein'), ('cellular_component', 'interacts with', 'gene/protein'), ('biological_process', 'interacts with', 'gene/protein')]
Step 6: 🕸️ Visualize a Subgraph of the Knowledge Graph
Plotting Local Neighborhood of a Node in a Heterogeneous Knowledge Graph (DYNC1I2 Example)
In [24]:
Copied!
import pickle
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from torch_geometric.utils import to_networkx
# Plot the 1-hop neighbourhood of one named node of the cached knowledge
# graph, colouring every node by its node type.

# Load data
# NOTE(security): pickle.load can execute arbitrary code embedded in the
# file. Only open a cache file that was produced by a trusted local run.
with open("biobridge_cache.pkl", "rb") as f:
    data = pickle.load(f)

hetero_data = data["init"]
mapper = data.get("mapper") or {}  # might be missing, None or empty

target_node_type = "gene/protein"
target_node_name = "DYNC1I2"

# Manual map fallback, consulted only when the cache has no entry for the name.
manual_map = {"DYNC1I2": 6726}

# NOTE(review): assumes `mapper` is shaped {node_type: {node_name: index}};
# any other shape simply falls through to `manual_map`.
# Compare against None explicitly: index 0 is a valid node index and a plain
# `or` would discard it in favour of the manual fallback.
local_idx = mapper.get(target_node_type, {}).get(target_node_name)
if local_idx is None:
    local_idx = manual_map.get(target_node_name)
if local_idx is None:
    raise ValueError(f"Node {target_node_name} not found")
# Normalise to a plain int so the (type, index) key built below compares
# equal to the keys created from range() in the relabelling step.
local_idx = int(local_idx)

# Convert to NetworkX graph and collapse it to an undirected simple graph.
G_full = nx.Graph(to_networkx(hetero_data))

# Relabel nodes with (type, index).
# The loop assumes to_networkx numbers nodes consecutively, one node type
# after another, in the order of hetero_data.node_types.
offset = 0
mapping = {}
for node_type in hetero_data.node_types:
    n_nodes = hetero_data[node_type].num_nodes
    for i in range(n_nodes):
        mapping[offset + i] = (node_type, i)
    offset += n_nodes
G_full = nx.relabel_nodes(G_full, mapping)

target_node_nx = (target_node_type, local_idx)
if target_node_nx not in G_full:
    raise ValueError(f"Node {target_node_nx} not found in graph")

# Build 1-hop subgraph: the target plus every node adjacent to it.
neighbors = list(G_full.adj[target_node_nx])
sub_nodes = [target_node_nx] + neighbors
G_sub = G_full.subgraph(sub_nodes).copy()
print(G_sub)

# Build reverse_mapper, (type, index) -> display name, from the per-type
# `node_name` attribute stored on the graph itself.
reverse_mapper = {}
for node_type in hetero_data.node_types:
    # Plain attribute access raises AttributeError when a node store has no
    # `node_name`, so fetch it with a default to make the skip reachable.
    node_names = getattr(hetero_data[node_type], "node_name", None)
    if node_names is None:
        continue  # skip if missing
    for idx, name in enumerate(node_names):
        # node_name might be bytes if pickled, decode if needed
        if isinstance(name, bytes):
            name = name.decode("utf-8")
        reverse_mapper[(node_type, idx)] = name

# Add manual reverse map for known missing mappings.
# NOTE(review): `update` also overrides a name already stored for the target
# index, so a wrong manual index would be labelled with the target name.
manual_reverse_map = {
    (target_node_type, local_idx): target_node_name,
}
reverse_mapper.update(manual_reverse_map)

# Create labels using node names, fallback to "type_index"
labels = {
    node: reverse_mapper.get(node, f"{node[0]}_{node[1]}")
    for node in G_sub.nodes
}

# One colour per node type; the palette is reused cyclically if there are
# more node types than Tableau colours.
node_types = {node: node[0] for node in G_sub.nodes}
unique_types = sorted(set(node_types.values()))
palette = list(mcolors.TABLEAU_COLORS.values())
color_map = {ntype: palette[i % len(palette)] for i, ntype in enumerate(unique_types)}
node_colors = [color_map[node_types[node]] for node in G_sub.nodes]

# Plot (fixed seed keeps the spring layout reproducible between runs)
plt.figure(figsize=(10, 8))
pos = nx.spring_layout(G_sub, seed=42)
nx.draw_networkx_nodes(G_sub, pos, node_color=node_colors, node_size=500, alpha=0.85)
nx.draw_networkx_edges(G_sub, pos, edge_color='gray', alpha=0.6)
nx.draw_networkx_labels(G_sub, pos, labels, font_size=8)

# Legend: empty scatter calls act as proxy handles, one per node type.
for ntype, color in color_map.items():
    plt.scatter([], [], color=color, label=ntype)
plt.legend(title="Node Types", loc="upper left")
plt.title(f"1-hop Neighbors of {target_node_name} in Knowledge Graph")
plt.axis("off")
plt.tight_layout()
plt.show()
import pickle
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from torch_geometric.utils import to_networkx
# Plot the 1-hop neighbourhood of one named node of the cached knowledge
# graph, colouring every node by its node type.

# Load data
# NOTE(security): pickle.load can execute arbitrary code embedded in the
# file. Only open a cache file that was produced by a trusted local run.
with open("biobridge_cache.pkl", "rb") as f:
    data = pickle.load(f)

hetero_data = data["init"]
mapper = data.get("mapper") or {}  # might be missing, None or empty

target_node_type = "gene/protein"
target_node_name = "DYNC1I2"

# Manual map fallback, consulted only when the cache has no entry for the name.
manual_map = {"DYNC1I2": 6726}

# NOTE(review): assumes `mapper` is shaped {node_type: {node_name: index}};
# any other shape simply falls through to `manual_map`.
# Compare against None explicitly: index 0 is a valid node index and a plain
# `or` would discard it in favour of the manual fallback.
local_idx = mapper.get(target_node_type, {}).get(target_node_name)
if local_idx is None:
    local_idx = manual_map.get(target_node_name)
if local_idx is None:
    raise ValueError(f"Node {target_node_name} not found")
# Normalise to a plain int so the (type, index) key built below compares
# equal to the keys created from range() in the relabelling step.
local_idx = int(local_idx)

# Convert to NetworkX graph and collapse it to an undirected simple graph.
G_full = nx.Graph(to_networkx(hetero_data))

# Relabel nodes with (type, index).
# The loop assumes to_networkx numbers nodes consecutively, one node type
# after another, in the order of hetero_data.node_types.
offset = 0
mapping = {}
for node_type in hetero_data.node_types:
    n_nodes = hetero_data[node_type].num_nodes
    for i in range(n_nodes):
        mapping[offset + i] = (node_type, i)
    offset += n_nodes
G_full = nx.relabel_nodes(G_full, mapping)

target_node_nx = (target_node_type, local_idx)
if target_node_nx not in G_full:
    raise ValueError(f"Node {target_node_nx} not found in graph")

# Build 1-hop subgraph: the target plus every node adjacent to it.
neighbors = list(G_full.adj[target_node_nx])
sub_nodes = [target_node_nx] + neighbors
G_sub = G_full.subgraph(sub_nodes).copy()
print(G_sub)

# Build reverse_mapper, (type, index) -> display name, from the per-type
# `node_name` attribute stored on the graph itself.
reverse_mapper = {}
for node_type in hetero_data.node_types:
    # Plain attribute access raises AttributeError when a node store has no
    # `node_name`, so fetch it with a default to make the skip reachable.
    node_names = getattr(hetero_data[node_type], "node_name", None)
    if node_names is None:
        continue  # skip if missing
    for idx, name in enumerate(node_names):
        # node_name might be bytes if pickled, decode if needed
        if isinstance(name, bytes):
            name = name.decode("utf-8")
        reverse_mapper[(node_type, idx)] = name

# Add manual reverse map for known missing mappings.
# NOTE(review): `update` also overrides a name already stored for the target
# index, so a wrong manual index would be labelled with the target name.
manual_reverse_map = {
    (target_node_type, local_idx): target_node_name,
}
reverse_mapper.update(manual_reverse_map)

# Create labels using node names, fallback to "type_index"
labels = {
    node: reverse_mapper.get(node, f"{node[0]}_{node[1]}")
    for node in G_sub.nodes
}

# One colour per node type; the palette is reused cyclically if there are
# more node types than Tableau colours.
node_types = {node: node[0] for node in G_sub.nodes}
unique_types = sorted(set(node_types.values()))
palette = list(mcolors.TABLEAU_COLORS.values())
color_map = {ntype: palette[i % len(palette)] for i, ntype in enumerate(unique_types)}
node_colors = [color_map[node_types[node]] for node in G_sub.nodes]

# Plot (fixed seed keeps the spring layout reproducible between runs)
plt.figure(figsize=(10, 8))
pos = nx.spring_layout(G_sub, seed=42)
nx.draw_networkx_nodes(G_sub, pos, node_color=node_colors, node_size=500, alpha=0.85)
nx.draw_networkx_edges(G_sub, pos, edge_color='gray', alpha=0.6)
nx.draw_networkx_labels(G_sub, pos, labels, font_size=8)

# Legend: empty scatter calls act as proxy handles, one per node type.
for ntype, color in color_map.items():
    plt.scatter([], [], color=color, label=ntype)
plt.legend(title="Node Types", loc="upper left")
plt.title(f"1-hop Neighbors of {target_node_name} in Knowledge Graph")
plt.axis("off")
plt.tight_layout()
plt.show()
Graph with 21 nodes and 47 edges