Skip to content

StarkQA-PrimeKG

Class for loading StarkQAPrimeKG dataset.

StarkQAPrimeKG

Bases: Dataset

Class for loading StarkQAPrimeKG dataset. It downloads the data from the HuggingFace repo and stores it in the local directory. The data is then loaded into pandas DataFrame of QA pairs, dictionary of split indices, and node information.

Source code in aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
class StarkQAPrimeKG(Dataset):
    """
    Class for loading StarkQAPrimeKG dataset.
    It downloads the data from the HuggingFace repo and stores it in the local directory.
    The data is then loaded into pandas DataFrame of QA pairs, dictionary of split indices,
    and node information.
    """

    def __init__(self, local_dir: str = "../../../data/starkqa_primekg/"):
        """
        Constructor for StarkQAPrimeKG class.

        Args:
            local_dir (str): The local directory to store the dataset files.
        """
        self.name: str = "starkqa_primekg"
        self.hf_repo_id: str = "snap-stanford/stark"
        self.local_dir: str = local_dir
        # Attributes to store the data; all remain None until load_data() is called.
        self.starkqa: pd.DataFrame = None
        self.starkqa_split_idx: dict = None
        self.starkqa_node_info: dict = None
        self.query_emb_dict: dict = None
        self.node_emb_dict: dict = None

        # Set up the dataset (creates the local directory if needed)
        self.setup()

    def setup(self):
        """
        A method to set up the dataset.
        Creates the local data directory if it does not already exist.
        """
        # Create the full path (not os.path.dirname of it) so the directory is
        # made whether or not local_dir carries a trailing path separator.
        os.makedirs(self.local_dir, exist_ok=True)

    def _load_stark_repo(self) -> tuple[pd.DataFrame, dict, dict]:
        """
        Private method to load related files of StarkQAPrimeKG dataset,
        downloading them from the HuggingFace Hub on first use.

        Returns:
            The QA-pairs dataframe of StarkQAPrimeKG dataset.
            The split indices of StarkQAPrimeKG dataset.
            The node information of StarkQAPrimeKG dataset.
        """
        # Download the file if it does not exist in the local directory
        # Otherwise, load the data from the local directory
        local_file = os.path.join(self.local_dir, "qa/prime/stark_qa/stark_qa.csv")
        if os.path.exists(local_file):
            print(f"{local_file} already exists. Loading the data from the local directory.")
        else:
            print(f"Downloading files from {self.hf_repo_id}")

            # List all related files in the HuggingFace Hub repository,
            # keeping only the prime QA/SKB files and skipping the raw dumps.
            files = list_repo_files(self.hf_repo_id, repo_type="dataset")
            files = [f for f in files if ((f.startswith("qa/prime/") or
                                           f.startswith("skb/prime/")) and f.find("raw") == -1)]

            # Download and save each file in the specified folder
            for file in tqdm(files):
                _ = hf_hub_download(self.hf_repo_id,
                                    file,
                                    repo_type="dataset",
                                    local_dir=self.local_dir)

            # Unzip the processed files
            shutil.unpack_archive(
                os.path.join(self.local_dir, "skb/prime/processed.zip"),
                os.path.join(self.local_dir, "skb/prime/")
            )

        # Load StarkQA dataframe
        starkqa = pd.read_csv(
            os.path.join(self.local_dir, "qa/prime/stark_qa/stark_qa.csv"),
            low_memory=False)

        # Read split indices. Build an id -> row-position map once instead of
        # calling list.index() per query id, which is O(n) per call and O(n^2)
        # overall for large splits. (QA ids are assumed unique.)
        qa_indices = sorted(starkqa['id'].tolist())
        id_to_pos = {qa_id: pos for pos, qa_id in enumerate(qa_indices)}
        starkqa_split_idx = {}
        for split in ['train', 'val', 'test', 'test-0.1']:
            indices_file = os.path.join(self.local_dir, "qa/prime/split", f'{split}.index')
            with open(indices_file, 'r', encoding='utf-8') as f:
                indices = f.read().strip().split('\n')
            query_ids = [int(idx) for idx in indices]
            starkqa_split_idx[split] = np.array(
                [id_to_pos[query_id] for query_id in query_ids]
            )

        # Load the node info of PrimeKG preprocessed for StarkQA
        with open(os.path.join(self.local_dir, 'skb/prime/processed/node_info.pkl'), 'rb') as f:
            starkqa_node_info = pickle.load(f)

        return starkqa, starkqa_split_idx, starkqa_node_info

    def _load_stark_embeddings(self) -> tuple[dict, dict]:
        """
        Private method to load the embeddings of StarkQAPrimeKG dataset.

        Returns:
            The query embeddings of StarkQAPrimeKG dataset.
            The node embeddings of StarkQAPrimeKG dataset.
        """
        # Load the provided embeddings of query and nodes
        # Note that they utilized 'text-embedding-ada-002' for embeddings
        emb_model = 'text-embedding-ada-002'
        query_emb_url = 'https://drive.google.com/uc?id=1MshwJttPZsHEM2cKA5T13SIrsLeBEdyU'
        node_emb_url = 'https://drive.google.com/uc?id=16EJvCMbgkVrQ0BuIBvLBp-BYPaye-Edy'

        # Prepare respective directories to store the embeddings
        emb_dir = os.path.join(self.local_dir, emb_model)
        query_emb_dir = os.path.join(emb_dir, "query")
        node_emb_dir = os.path.join(emb_dir, "doc")
        os.makedirs(query_emb_dir, exist_ok=True)
        os.makedirs(node_emb_dir, exist_ok=True)
        query_emb_path = os.path.join(query_emb_dir, "query_emb_dict.pt")
        node_emb_path = os.path.join(node_emb_dir, "candidate_emb_dict.pt")

        # Download the embeddings if they do not exist in the local directory
        if not os.path.exists(query_emb_path) or not os.path.exists(node_emb_path):
            # Download the query embeddings
            gdown.download(query_emb_url, query_emb_path, quiet=False)

            # Download the node embeddings
            gdown.download(node_emb_url, node_emb_path, quiet=False)

        # Load the embeddings.
        # NOTE(review): torch.load unpickles the file; the Google Drive source is
        # trusted upstream, but consider weights_only=True on newer torch versions.
        query_emb_dict = torch.load(query_emb_path)
        node_emb_dict = torch.load(node_emb_path)

        return query_emb_dict, node_emb_dict

    def load_data(self):
        """
        Load the StarkQAPrimeKG dataset into pandas DataFrame of QA pairs,
        dictionary of split indices, and node information.
        """
        print("Loading StarkQAPrimeKG dataset...")
        self.starkqa, self.starkqa_split_idx, self.starkqa_node_info = self._load_stark_repo()

        print("Loading StarkQAPrimeKG embeddings...")
        self.query_emb_dict, self.node_emb_dict = self._load_stark_embeddings()


    def get_starkqa(self) -> pd.DataFrame:
        """
        Get the dataframe of StarkQAPrimeKG dataset, containing the QA pairs.

        Returns:
            The QA-pairs dataframe of StarkQAPrimeKG dataset (None until load_data()).
        """
        return self.starkqa

    def get_starkqa_split_indicies(self) -> dict:
        """
        Get the split indices of StarkQAPrimeKG dataset.
        (Method name keeps the historical 'indicies' spelling: it is public API.)

        Returns:
            The split indices of StarkQAPrimeKG dataset.
        """
        return self.starkqa_split_idx

    def get_starkqa_node_info(self) -> dict:
        """
        Get the node information of StarkQAPrimeKG dataset.

        Returns:
            The node information of StarkQAPrimeKG dataset.
        """
        return self.starkqa_node_info

    def get_query_embeddings(self) -> dict:
        """
        Get the query embeddings of StarkQAPrimeKG dataset.

        Returns:
            The query embeddings of StarkQAPrimeKG dataset.
        """
        return self.query_emb_dict

    def get_node_embeddings(self) -> dict:
        """
        Get the node embeddings of StarkQAPrimeKG dataset.

        Returns:
            The node embeddings of StarkQAPrimeKG dataset.
        """
        return self.node_emb_dict

__init__(local_dir='../../../data/starkqa_primekg/')

Constructor for StarkQAPrimeKG class.

Parameters:

Name Type Description Default
local_dir str

The local directory to store the dataset files.

'../../../data/starkqa_primekg/'
Source code in aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def __init__(self, local_dir: str = "../../../data/starkqa_primekg/"):
    """
    Constructor for StarkQAPrimeKG class.

    Args:
        local_dir (str): The local directory to store the dataset files.
            NOTE(review): setup() applies os.path.dirname() to this value, so a
            trailing path separator is presumably expected — confirm with callers.
    """
    # Dataset identifier and the HuggingFace Hub repo it is downloaded from.
    self.name: str = "starkqa_primekg"
    self.hf_repo_id: str = "snap-stanford/stark"
    self.local_dir: str = local_dir
    # Attributes to store the data; all remain None until load_data() is called.
    self.starkqa: pd.DataFrame = None
    self.starkqa_split_idx: dict = None
    self.starkqa_node_info: dict = None
    self.query_emb_dict: dict = None
    self.node_emb_dict: dict = None

    # Set up the dataset (creates the local directory if needed)
    self.setup()

_load_stark_embeddings()

Private method to load the embeddings of StarkQAPrimeKG dataset.

Returns:

Type Description
dict

The query embeddings of StarkQAPrimeKG dataset.

dict

The node embeddings of StarkQAPrimeKG dataset.

Source code in aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def _load_stark_embeddings(self) -> tuple[dict, dict]:
    """
    Fetch and load the pre-computed embeddings shipped with StarkQAPrimeKG.

    The upstream authors computed them with the 'text-embedding-ada-002'
    model; the files are downloaded from Google Drive on first use and
    cached under the local data directory afterwards.

    Returns:
        The query embeddings of StarkQAPrimeKG dataset.
        The node embeddings of StarkQAPrimeKG dataset.
    """
    model_name = 'text-embedding-ada-002'
    query_url = 'https://drive.google.com/uc?id=1MshwJttPZsHEM2cKA5T13SIrsLeBEdyU'
    node_url = 'https://drive.google.com/uc?id=16EJvCMbgkVrQ0BuIBvLBp-BYPaye-Edy'

    # Lay out <local_dir>/<model>/query and <local_dir>/<model>/doc.
    base_dir = os.path.join(self.local_dir, model_name)
    query_dir = os.path.join(base_dir, "query")
    doc_dir = os.path.join(base_dir, "doc")
    for directory in (query_dir, doc_dir):
        os.makedirs(directory, exist_ok=True)
    query_file = os.path.join(query_dir, "query_emb_dict.pt")
    node_file = os.path.join(doc_dir, "candidate_emb_dict.pt")

    # Download both archives whenever either one is missing locally.
    if not (os.path.exists(query_file) and os.path.exists(node_file)):
        gdown.download(query_url, query_file, quiet=False)
        gdown.download(node_url, node_file, quiet=False)

    # Deserialize the cached tensors and hand both dictionaries back.
    return torch.load(query_file), torch.load(node_file)

_load_stark_repo()

Private method to load related files of StarkQAPrimeKG dataset.

Returns:

Type Description
DataFrame

The QA-pairs dataframe of StarkQAPrimeKG dataset.

dict

The split indices of StarkQAPrimeKG dataset.

dict

The node information of StarkQAPrimeKG dataset.

Source code in aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def _load_stark_repo(self) -> tuple[pd.DataFrame, dict, dict]:
    """
    Private method to load related files of StarkQAPrimeKG dataset,
    downloading them from the HuggingFace Hub on first use.

    Returns:
        The QA-pairs dataframe of StarkQAPrimeKG dataset.
        The split indices of StarkQAPrimeKG dataset.
        The node information of StarkQAPrimeKG dataset.
    """
    # Download the file if it does not exist in the local directory
    # Otherwise, load the data from the local directory
    local_file = os.path.join(self.local_dir, "qa/prime/stark_qa/stark_qa.csv")
    if os.path.exists(local_file):
        print(f"{local_file} already exists. Loading the data from the local directory.")
    else:
        print(f"Downloading files from {self.hf_repo_id}")

        # List all related files in the HuggingFace Hub repository,
        # keeping only the prime QA/SKB files and skipping the raw dumps.
        files = list_repo_files(self.hf_repo_id, repo_type="dataset")
        files = [f for f in files if ((f.startswith("qa/prime/") or
                                       f.startswith("skb/prime/")) and f.find("raw") == -1)]

        # Download and save each file in the specified folder
        for file in tqdm(files):
            _ = hf_hub_download(self.hf_repo_id,
                                file,
                                repo_type="dataset",
                                local_dir=self.local_dir)

        # Unzip the processed files
        shutil.unpack_archive(
            os.path.join(self.local_dir, "skb/prime/processed.zip"),
            os.path.join(self.local_dir, "skb/prime/")
        )

    # Load StarkQA dataframe
    starkqa = pd.read_csv(
        os.path.join(self.local_dir, "qa/prime/stark_qa/stark_qa.csv"),
        low_memory=False)

    # Read split indices. Build an id -> row-position map once instead of
    # calling list.index() per query id, which is O(n) per call and O(n^2)
    # overall for large splits. (QA ids are assumed unique.)
    qa_indices = sorted(starkqa['id'].tolist())
    id_to_pos = {qa_id: pos for pos, qa_id in enumerate(qa_indices)}
    starkqa_split_idx = {}
    for split in ['train', 'val', 'test', 'test-0.1']:
        indices_file = os.path.join(self.local_dir, "qa/prime/split", f'{split}.index')
        with open(indices_file, 'r', encoding='utf-8') as f:
            indices = f.read().strip().split('\n')
        query_ids = [int(idx) for idx in indices]
        starkqa_split_idx[split] = np.array(
            [id_to_pos[query_id] for query_id in query_ids]
        )

    # Load the node info of PrimeKG preprocessed for StarkQA
    with open(os.path.join(self.local_dir, 'skb/prime/processed/node_info.pkl'), 'rb') as f:
        starkqa_node_info = pickle.load(f)

    return starkqa, starkqa_split_idx, starkqa_node_info

get_node_embeddings()

Get the node embeddings of StarkQAPrimeKG dataset.

Returns:

Type Description
dict

The node embeddings of StarkQAPrimeKG dataset.

Source code in aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py
194
195
196
197
198
199
200
201
def get_node_embeddings(self) -> dict:
    """
    Return the node embeddings of the StarkQAPrimeKG dataset.

    Returns:
        The node embeddings, or None if load_data() has not been called yet.
    """
    embeddings = self.node_emb_dict
    return embeddings

get_query_embeddings()

Get the query embeddings of StarkQAPrimeKG dataset.

Returns:

Type Description
dict

The query embeddings of StarkQAPrimeKG dataset.

Source code in aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py
185
186
187
188
189
190
191
192
def get_query_embeddings(self) -> dict:
    """
    Return the query embeddings of the StarkQAPrimeKG dataset.

    Returns:
        The query embeddings, or None if load_data() has not been called yet.
    """
    embeddings = self.query_emb_dict
    return embeddings

get_starkqa()

Get the dataframe of StarkQAPrimeKG dataset, containing the QA pairs.

Returns:

Type Description
DataFrame

The QA-pairs dataframe of StarkQAPrimeKG dataset.

Source code in aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py
158
159
160
161
162
163
164
165
def get_starkqa(self) -> pd.DataFrame:
    """
    Get the dataframe of StarkQAPrimeKG dataset, containing the QA pairs.

    Returns:
        The QA-pairs dataframe of StarkQAPrimeKG dataset
        (None until load_data() is called).
    """
    return self.starkqa

get_starkqa_node_info()

Get the node information of StarkQAPrimeKG dataset.

Returns:

Type Description
dict

The node information of StarkQAPrimeKG dataset.

Source code in aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py
176
177
178
179
180
181
182
183
def get_starkqa_node_info(self) -> dict:
    """
    Return the node information of the StarkQAPrimeKG dataset.

    Returns:
        The node-information dictionary, or None if load_data() has not run yet.
    """
    node_info = self.starkqa_node_info
    return node_info

get_starkqa_split_indicies()

Get the split indices of StarkQAPrimeKG dataset.

Returns:

Type Description
dict

The split indices of StarkQAPrimeKG dataset.

Source code in aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py
167
168
169
170
171
172
173
174
def get_starkqa_split_indicies(self) -> dict:
    """
    Return the split indices of the StarkQAPrimeKG dataset.

    (The 'indicies' spelling is kept: the method name is public API.)

    Returns:
        The split-indices dictionary, or None if load_data() has not run yet.
    """
    split_indices = self.starkqa_split_idx
    return split_indices

load_data()

Load the StarkQAPrimeKG dataset into pandas DataFrame of QA pairs, dictionary of split indices, and node information.

Source code in aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py
146
147
148
149
150
151
152
153
154
155
def load_data(self):
    """
    Load the StarkQAPrimeKG dataset into pandas DataFrame of QA pairs,
    dictionary of split indices, and node information, then the
    pre-computed query/node embeddings.
    """
    print("Loading StarkQAPrimeKG dataset...")
    qa_df, split_idx, node_info = self._load_stark_repo()
    self.starkqa = qa_df
    self.starkqa_split_idx = split_idx
    self.starkqa_node_info = node_info

    print("Loading StarkQAPrimeKG embeddings...")
    query_emb, node_emb = self._load_stark_embeddings()
    self.query_emb_dict = query_emb
    self.node_emb_dict = node_emb

setup()

A method to set up the dataset.

Source code in aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py
44
45
46
47
48
49
def setup(self):
    """
    A method to set up the dataset.
    """
    # Make the directory if it doesn't exist
    os.makedirs(os.path.dirname(self.local_dir), exist_ok=True)