Skip to content

BioBridge Dataloader

Loads the BioBridge dataset and prepares it for training and evaluation using LightningDataModule from PyTorch Lightning, with optional caching.

BioBridgeDataModule

Bases: LightningDataModule

LightningDataModule for the BioBridge dataset.

Source code in vpeleaderboard/data/src/kg/biobridge_datamodule_hetero.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
class BioBridgeDataModule(LightningDataModule):
    """
    LightningDataModule for the BioBridge dataset.

    Orchestrates three stages:
      1. ``prepare_data`` loads the raw heterogeneous graph, reusing a
         pickle cache at ``cfg.data.cache_path`` when one exists.
      2. ``setup`` derives train/val/test link-prediction splits with
         ``RandomLinkSplit`` and re-caches them alongside the raw graph.
      3. ``train/val/test_dataloader`` expose each split as a
         single-graph DataLoader.
    """

    def __init__(self, cfg: DictConfig) -> None:
        """
        Initializes the BioBridgeDataModule.

        Args:
            cfg (DictConfig): Configuration object with dataset parameters.
        """
        super().__init__()
        self.save_hyperparameters(logger=False)
        self.cfg = cfg
        self.primekg_dir = cfg.data.primekg_dir
        self.biobridge_dir = cfg.data.biobridge_dir
        self.batch_size = cfg.data.batch_size
        self.cache_path = cfg.data.cache_path
        self.biobridge = None
        self.mapper = {}
        # Maps split name ("init" / "train" / "val" / "test") -> HeteroData.
        self.data = {}

    @staticmethod
    def _validate_cache(cached_data: Any, cache_path: str) -> None:
        """
        Validates an object deserialized from the cache file.

        Args:
            cached_data (Any): Object loaded from the cache.
            cache_path (str): Path the object was read from (for messages).

        Raises:
            ValueError: If the cached object is not a dictionary.
            KeyError: If the 'init' key is missing.
            TypeError: If the 'init' entry is not a HeteroData instance.
        """
        # Check if the cache is a dictionary
        if not isinstance(cached_data, dict):
            raise ValueError(f"Cached data at {cache_path} is not a dictionary.")
        if "init" not in cached_data:
            raise KeyError(f"Missing 'init' key in cached data from {cache_path}.")
        # Check that 'init' is a HeteroData object
        if not isinstance(cached_data["init"], HeteroData):
            raise TypeError("init in cached data is not of type HeteroData.")

    def _write_cache(self, log_message: str) -> None:
        """
        Pickles ``self.data`` to ``self.cache_path`` and logs the path.

        Args:
            log_message (str): "%s"-style message logged with the cache path.
        """
        with open(self.cache_path, "wb") as f:
            pickle.dump(self.data, f)
        logger.info(log_message, self.cache_path)

    def prepare_data(self) -> None:
        """
        Loads and processes the data, optionally using cached data.
        If cache is invalid or not found, processes data freshly.

        Returns:
            None
        """
        if os.path.exists(self.cache_path):
            logger.info("Cache file found. Loading cached data...")

            # NOTE(security): pickle is only safe because the cache is
            # produced locally by this module; never point cache_path at
            # untrusted files.
            with open(self.cache_path, "rb") as f:
                cached_data = pickle.load(f)

            self._validate_cache(cached_data, self.cache_path)
            self.data = cached_data
            logger.info("DEBUG (prepare_data): "
                  "Cached data looks like a valid HeteroData dictionary.")
            return  # Exit if cache is valid

        # Cache not found. Process raw data for BioBridge.
        biobridge_processor = Biobridgepreparedata(self.cfg)
        biobridge_processor.pre_process()  # builds biobridge_processor.data["init"]

        # self.data is always a dict (initialized in __init__), so the key
        # can be assigned directly.
        self.data["init"] = biobridge_processor.data["init"]
        self._write_cache("✅ Cached raw data to %s")

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Sets up training, validation, and test splits using RandomLinkSplit.

        Args:
            stage (Optional[str]): Optional stage indicator.

        Returns:
            None
        """
        # Splits already present (e.g. restored from cache): nothing to do.
        if "train" in self.data:
            return
        with hydra.initialize(version_base=None,
                              config_path="../../../configs/data/kg/BioBRIDGE-PrimeKG"):
            # Load the configuration for RandomLinkSplit
            cfg: DictConfig = hydra.compose(config_name="default")

        transform = RandomLinkSplit(
            num_val=cfg.random_link_split.num_val,
            num_test=cfg.random_link_split.num_test,
            is_undirected=cfg.random_link_split.is_undirected,
            add_negative_train_samples=cfg.random_link_split.add_negative_train_samples,
            neg_sampling_ratio=cfg.random_link_split.neg_sampling_ratio,
            split_labels=cfg.random_link_split.split_labels,
            edge_types=self.data["init"].edge_types,
        )

        self.data["train"], self.data["val"], self.data["test"] = transform(self.data["init"])
        self._write_cache("✅ Cached train/val/test splits to %s")

    def _split_dataloader(self, split: str) -> GeoDataLoader:
        """
        Builds the single-graph dataloader for one split.

        Args:
            split (str): One of "train", "val", or "test".

        Returns:
            GeoDataLoader: DataLoader yielding the whole split graph.

        Raises:
            RuntimeError: If the split has not been created by `setup()`.
        """
        if split not in self.data:
            raise RuntimeError(f"Please run `setup()` before calling {split}_dataloader().")
        # One HeteroData graph per epoch: wrap it in a single-element list.
        return GeoDataLoader([self.data[split]], batch_size=1, shuffle=False)

    def train_dataloader(self) -> GeoDataLoader:
        """
        Returns the training dataloader.

        Returns:
            GeoDataLoader: DataLoader for training set.
        """
        return self._split_dataloader("train")

    def val_dataloader(self) -> GeoDataLoader:
        """
        Returns the validation dataloader.

        Returns:
            GeoDataLoader: DataLoader for validation set.
        """
        return self._split_dataloader("val")

    def test_dataloader(self) -> GeoDataLoader:
        """
        Returns the test dataloader.

        Returns:
            GeoDataLoader: DataLoader for test set.
        """
        return self._split_dataloader("test")

    def teardown(self, stage: Optional[str] = None) -> None:
        """No resources to release."""

    def state_dict(self) -> Dict[Any, Any]:
        """
        Returns the internal state of the data module.

        Returns:
            dict: Empty dictionary (no state to save).
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """No state to restore."""

__init__(cfg)

Initializes the BioBridgeDataModule.

Parameters:

Name Type Description Default
cfg DictConfig

Configuration object with dataset parameters.

required
Source code in vpeleaderboard/data/src/kg/biobridge_datamodule_hetero.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def __init__(self, cfg: DictConfig) -> None:
    """
    Initializes the BioBridgeDataModule.

    Args:
        cfg (DictConfig): Configuration object with dataset parameters.
    """
    super().__init__()
    # Record cfg as hyperparameters without forwarding them to the logger.
    self.save_hyperparameters(logger=False)
    self.cfg = cfg
    # Paths and sizes pulled from the `data` section of the config.
    self.primekg_dir = cfg.data.primekg_dir
    self.biobridge_dir = cfg.data.biobridge_dir
    self.batch_size = cfg.data.batch_size
    self.cache_path = cfg.data.cache_path
    self.biobridge = None  # populated later, outside this constructor
    self.mapper = {}
    # Holds "init" and, after setup(), the "train"/"val"/"test" splits.
    self.data = {}

prepare_data()

Loads and processes the data, optionally using cached data. If cache is invalid or not found, processes data freshly.

Returns:

Type Description
None

None

Source code in vpeleaderboard/data/src/kg/biobridge_datamodule_hetero.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def prepare_data(self) -> None:
    """
    Loads and processes the data, optionally using cached data.
    If cache is invalid or not found, processes data freshly.

    Returns:
        None
    """

    if os.path.exists(self.cache_path):
        logger.info("Cache file found. Loading cached data...")

        # NOTE(review): pickle.load is only safe on trusted, locally
        # produced cache files — confirm cache_path is never user-supplied.
        with open(self.cache_path, "rb") as f:
            cached_data = pickle.load(f)

        # Check if the cache is a dictionary
        if not isinstance(cached_data, dict):
            raise ValueError(f"Cached data at {self.cache_path} is not a dictionary.")

        if "init" not in cached_data:
            raise KeyError(f"Missing 'init' key in cached data from {self.cache_path}.")

        # Check that 'init' is a HeteroData object
        if not isinstance(cached_data["init"], HeteroData):
            raise TypeError("init in cached data is not of type HeteroData.")
        self.data = cached_data
        logger.info("DEBUG (prepare_data): "
              "Cached data looks like a valid HeteroData dictionary.")
        return  # Exit if cache is valid

    # Cache not found or failed to load. Processing raw data for BioBridge
    biobridge_processor = Biobridgepreparedata(self.cfg)
    biobridge_processor.pre_process() # This method builds biobridge_processor.data["init"]

    # Ensure self.data is a dict before trying to assign a key.
    if not isinstance(self.data, dict):
        self.data = {} # Ensure it's a dict
    self.data["init"] = biobridge_processor.data["init"]

    # Persist the raw graph so the next run can skip pre-processing.
    with open(self.cache_path, "wb") as f:
        pickle.dump(self.data, f)
    logger.info("✅ Cached raw data to %s", self.cache_path)

setup(stage=None)

Sets up training, validation, and test splits using RandomLinkSplit.

Parameters:

Name Type Description Default
stage Optional[str]

Optional stage indicator.

None

Returns:

Type Description
None

None

Source code in vpeleaderboard/data/src/kg/biobridge_datamodule_hetero.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def setup(self, stage: Optional[str] = None) -> None:
    """
    Sets up training, validation, and test splits using RandomLinkSplit.

    Args:
        stage (Optional[str]): Optional stage indicator.

    Returns:
        None
    """
    # Splits already present (e.g. restored from the cache): nothing to do.
    if "train" in self.data:
        return
    with hydra.initialize(version_base=None,
                          config_path="../../../configs/data/kg/BioBRIDGE-PrimeKG"):
        # Load the configuration for RandomLinkSplit
        cfg: DictConfig = hydra.compose(config_name="default")

    transform = RandomLinkSplit(
        num_val=cfg.random_link_split.num_val,
        num_test=cfg.random_link_split.num_test,
        is_undirected=cfg.random_link_split.is_undirected,
        add_negative_train_samples=cfg.random_link_split.add_negative_train_samples,
        neg_sampling_ratio=cfg.random_link_split.neg_sampling_ratio,
        split_labels=cfg.random_link_split.split_labels,
        edge_types=self.data["init"].edge_types,
    )

    # RandomLinkSplit returns a (train, val, test) tuple of graphs.
    self.data["train"], self.data["val"], self.data["test"] = transform(self.data["init"])

    # Re-cache so subsequent runs restore the splits directly.
    with open(self.cache_path, "wb") as f:
        pickle.dump(self.data, f)
    logger.info("✅ Cached train/val/test splits to %s", self.cache_path)

state_dict()

Returns the internal state of the data module.

Returns:

Name Type Description
dict Dict[Any, Any]

Empty dictionary (no state to save).

Source code in vpeleaderboard/data/src/kg/biobridge_datamodule_hetero.py
154
155
156
157
158
159
160
161
def state_dict(self) -> Dict[Any, Any]:
    """
    Returns the internal state of the data module.

    Returns:
        dict: Empty dictionary (no state to save).
    """
    # Lightning checkpointing hook; this module keeps no restorable state.
    return {}

test_dataloader()

Returns the test dataloader.

Returns:

Name Type Description
GeoDataLoader DataLoader

DataLoader for test set.

Source code in vpeleaderboard/data/src/kg/biobridge_datamodule_hetero.py
140
141
142
143
144
145
146
147
148
149
def test_dataloader(self) -> GeoDataLoader:
    """
    Returns the test dataloader.

    Returns:
        GeoDataLoader: DataLoader for test set.
    """
    if "test" not in self.data:
        raise RuntimeError("Please run `setup()` before calling test_dataloader().")
    # Single-element list + batch_size=1: the whole test graph per epoch.
    return GeoDataLoader([self.data["test"]], batch_size=1, shuffle=False)

train_dataloader()

Returns the training dataloader.

Returns:

Name Type Description
GeoDataLoader DataLoader

DataLoader for training set.

Source code in vpeleaderboard/data/src/kg/biobridge_datamodule_hetero.py
118
119
120
121
122
123
124
125
126
127
def train_dataloader(self) -> GeoDataLoader:
    """
    Returns the training dataloader.

    Returns:
        GeoDataLoader: DataLoader for training set.
    """
    if "train" not in self.data:
        raise RuntimeError("Please run `setup()` before calling train_dataloader().")
    # Single-element list + batch_size=1: the whole training graph per epoch.
    return GeoDataLoader([self.data["train"]], batch_size=1, shuffle=False)

val_dataloader()

Returns the validation dataloader.

Returns:

Name Type Description
GeoDataLoader DataLoader

DataLoader for validation set.

Source code in vpeleaderboard/data/src/kg/biobridge_datamodule_hetero.py
129
130
131
132
133
134
135
136
137
138
def val_dataloader(self) -> GeoDataLoader:
    """
    Returns the validation dataloader.

    Returns:
        GeoDataLoader: DataLoader for validation set.
    """
    if "val" not in self.data:
        raise RuntimeError("Please run `setup()` before calling val_dataloader().")
    # Single-element list + batch_size=1: the whole validation graph per epoch.
    return GeoDataLoader([self.data["val"]], batch_size=1, shuffle=False)