StarkQA-PrimeKG Evaluation¶
In this tutorial, we perform a question-answering task on the StarkQA-PrimeKG dataset using textual embeddings of the queries and nodes.
The following are important publication and repository links related to this work.
# Import necessary libraries
import os
import ast
from typing import Any, Union, List, Dict, Optional
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank
from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
import sys
sys.path.append('../../..')
from aiagents4pharma.talk2knowledgegraphs.datasets.starkqa_primekg import StarkQAPrimeKG
Load StarkQA-PrimeKG¶
The StarkQAPrimeKG class downloads the data from the HuggingFace Hub if it is not available locally.
Otherwise, the data is loaded from the local directory defined in local_dir.
# Define starkqa primekg data by providing a local directory where the data is stored
starkqa_data = StarkQAPrimeKG(local_dir="../../../../data/starkqa_primekg/")
To load the StarkQA data and its split indices, we call the following methods.
# Invoke a method to load the data
starkqa_data.load_data()
# Get the split indices of the StarkQA-PrimeKG QA pairs
starkqa_split_indices = starkqa_data.get_starkqa_split_indicies()
Loading StarkQAPrimeKG dataset...
../../../../data/starkqa_primekg/qa/prime/stark_qa/stark_qa.csv already exists. Loading the data from the local directory.
Loading StarkQAPrimeKG embeddings...
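As an optional sanity check (a minimal sketch assuming starkqa_split_indices behaves as a dict-like mapping of split names to index arrays, which is how it is used later in this notebook), we can inspect the available splits and their sizes:

# Print each split name and the number of query indices it contains
# (assumes a dict-like mapping of split names to index arrays)
for split_name, split_idx in starkqa_split_indices.items():
    print(f"{split_name}: {len(split_idx)} queries")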
Load Textual Embeddings¶
In the tutorial_starkqa_primekg_textual_embeddings.ipynb notebook, we have shown how to obtain textual embeddings of the query, node, and edge information.
Therefore, we only need to load the pre-computed embeddings here.
We can retrieve the query embeddings as follows:
# Load the embedded queries dataframe from the parquet file
starkqa_df = pd.read_parquet(os.path.join(starkqa_data.local_dir, 'starkqaprimekg_queries_embedded.parquet'), engine='pyarrow')
# Basic conversion of the answer_ids from string to list
starkqa_df['answer_ids'] = starkqa_df.apply(lambda x: ast.literal_eval(x['answer_ids']), axis=1)
# Check the dataframe of query embeddings
starkqa_df.head()
|   | id | query | answer_ids | query_embedded |
|---|---|---|---|---|
| 0 | 0 | Could you identify any skin diseases associate... | [95886] | [0.050286733, 0.0050845086, 0.06326583, 0.0360... |
| 1 | 1 | What drugs target the CYP3A4 enzyme and are us... | [15450] | [0.009708624, 0.01434415, -0.07435164, -0.0736... |
| 2 | 2 | What is the name of the condition characterize... | [98851, 98853] | [-0.058651656, -0.0031773308, 0.015822958, -0.... |
| 3 | 3 | What drugs are used to treat epithelioid sarco... | [15698] | [-0.035772394, 0.064148985, -0.018727051, -0.0... |
| 4 | 4 | Can you supply a compilation of genes and prot... | [7161, 22045] | [-0.072102964, -0.008873461, -0.007186646, 0.0... |
Next, we can retrieve the node embeddings as follows:
# Load the enriched nodes dataframe from the parquet file
primekg_nodes = pd.read_parquet(os.path.join(starkqa_data.local_dir, 'starkqaprimekg_nodes_embedded.parquet'), engine='pyarrow')
# Check the dataframe of node embeddings
primekg_nodes.head()
|   | node_id | node_name | node_type | enriched_node | x |
|---|---|---|---|---|---|
| 0 | 0 | PHYHIP | gene/protein | PHYHIP belongs to gene/protein category. Enabl... | [-0.06876933, 0.00096770556, -0.0630331, -0.04... |
| 1 | 1 | GPANK1 | gene/protein | GPANK1 belongs to gene/protein category. This ... | [-0.08932163, 0.031602174, -0.102335155, -0.03... |
| 2 | 2 | ZRSR2 | gene/protein | ZRSR2 belongs to gene/protein category. This g... | [-0.10059608, -0.020288778, 0.008750704, 0.003... |
| 3 | 3 | NRF1 | gene/protein | NRF1 belongs to gene/protein category. This ge... | [-0.09837414, -0.02768978, -0.061966445, 0.026... |
| 4 | 4 | PI4KA | gene/protein | PI4KA belongs to gene/protein category. This g... | [-0.03965294, -0.0017360917, -0.12756099, -5.2... |
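As a quick check (a minimal sketch, not part of the original pipeline), we can verify that the query and node embeddings share the same dimensionality, since the VSS model below scores candidates with a dot product between the two:

# The dot product used by VSS requires query and node embeddings of equal dimension
print(len(starkqa_df['query_embedded'].iloc[0]), len(primekg_nodes['x'].iloc[0]))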
Prepare Metrics¶
To measure the performance of the model, we first prepare the evaluation metrics.
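For reference, these ranking metrics are standard. For a single query whose highest-ranked correct answer appears at position $r$ and which has $|A|$ correct answers,

$$
\text{reciprocal rank} = \frac{1}{r}, \qquad
\text{Hit@}k = \mathbb{1}\left[r \le k\right], \qquad
\text{Recall@}k = \frac{|\{\text{correct answers in the top-}k\}|}{|A|}.
$$

MRR, MAP, and the other numbers reported at the end of this notebook are the means of the corresponding per-query scores over the test split.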
# Metrics
eval_metrics = [
    "mrr",
    "map",
    "rprecision",
    "recall@5",
    "recall@10",
    "recall@20",
    "recall@50",
    "recall@100",
    "hit@1",
    "hit@3",
    "hit@5",
    "hit@10",
    "hit@20",
    "hit@50",
]
eval_csv = pd.DataFrame(columns=["idx", "query_id", "pred_rank"] + eval_metrics)
eval_csv
| idx | query_id | pred_rank | mrr | map | rprecision | recall@5 | recall@10 | recall@20 | recall@50 | recall@100 | hit@1 | hit@3 | hit@5 | hit@10 | hit@20 | hit@50 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
# Define the evaluation function based on StarkQA evaluation
# https://github.com/snap-stanford/stark/blob/main/stark_qa/evaluator.py
def evaluate(candidate_ids: List[int],
             pred_ids: List[int],
             pred: torch.Tensor,
             answer_ids: List[torch.LongTensor],
             metrics: List[str] = ['mrr', 'hit@3', 'recall@20'],
             device: str = 'cpu') -> List[Dict[str, float]]:
    """
    Evaluate the model predictions.

    Args:
        candidate_ids: List of candidate node ids.
        pred_ids: List of node ids corresponding to the rows of `pred`.
        pred: Tensor of prediction scores of shape (num_pred_ids, num_queries).
        answer_ids: List of ground-truth node id tensors, one per query.
        metrics: List of metrics to compute.
        device: Device to use.

    Returns:
        A list with one dictionary per query, mapping each metric name to its score.
    """
    # Scatter the prediction scores into a dense tensor over all candidate ids,
    # filling unscored candidates with a value below the minimum score
    all_pred = torch.ones((max(candidate_ids) + 1, pred.shape[1]), dtype=torch.float) * (pred.min() - 1)
    all_pred[pred_ids, :] = pred
    all_pred = all_pred[candidate_ids].t().to(device)

    # Build a boolean ground-truth matrix of shape (num_queries, num_candidates)
    bool_gd = torch.zeros((max(candidate_ids) + 1, pred.shape[1]), dtype=torch.bool)
    bool_gd[torch.concat(answer_ids),
            torch.repeat_interleave(torch.arange(len(answer_ids)),
                                    torch.tensor(list(map(len, answer_ids))))] = True
    bool_gd = bool_gd[candidate_ids].t().to(device)

    # Compute the requested metrics for each query
    results = []
    for i in range(len(answer_ids)):
        query_metrics = {}
        for metric in metrics:
            k = int(metric.split('@')[-1]) if '@' in metric else None
            if metric == 'mrr':
                result = retrieval_reciprocal_rank(all_pred[i], bool_gd[i])
            elif metric == 'rprecision':
                result = retrieval_r_precision(all_pred[i], bool_gd[i])
            elif 'hit' in metric:
                result = retrieval_hit_rate(all_pred[i], bool_gd[i], top_k=k)
            elif 'recall' in metric:
                result = retrieval_recall(all_pred[i], bool_gd[i], top_k=k)
            elif 'precision' in metric:
                result = retrieval_precision(all_pred[i], bool_gd[i], top_k=k)
            elif 'map' in metric:
                result = retrieval_average_precision(all_pred[i], bool_gd[i], top_k=k)
            elif 'ndcg' in metric:
                result = retrieval_normalized_dcg(all_pred[i], bool_gd[i], top_k=k)
            query_metrics[metric] = float(result)
        results.append(query_metrics)
    return results
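To make the expected shapes concrete, here is a small hypothetical call (the toy_* names and values are illustrative only): four candidate nodes, two queries, and one answer set per query. The function returns one dictionary of metric values per query.

# Hypothetical toy example: `pred` has shape (num_candidates, num_queries)
toy_candidate_ids = [0, 1, 2, 3]
toy_pred = torch.tensor([[0.9, 0.1],
                         [0.2, 0.8],
                         [0.7, 0.3],
                         [0.1, 0.6]])
toy_answer_ids = [torch.LongTensor([0]), torch.LongTensor([1, 3])]
evaluate(toy_candidate_ids, toy_candidate_ids, toy_pred, toy_answer_ids, metrics=['mrr', 'hit@1'])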
Vector Similarity Search (VSS)¶
A simple baseline that we evaluate here is vector similarity search (VSS).
It scores each node of StarkQA-PrimeKG by the similarity between its embedding and the query embedding, and retrieves the highest-scoring nodes as answer candidates.
Please refer to the paper for details; the sketch right below illustrates the idea for a single query, followed by the batched evaluation code.
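As a minimal sketch (illustrative only; the variable names here are not part of the original pipeline), VSS scores every node by the dot product between its embedding and the query embedding, then keeps the top-k most similar nodes:

# Single-query VSS sketch: score all nodes against the first query and take the top 5
query_emb = torch.tensor(starkqa_df['query_embedded'].iloc[0])           # shape: (d,)
node_embs = torch.tensor(np.array(primekg_nodes['x'].values.tolist()))   # shape: (num_nodes, d)
scores = node_embs @ query_emb                                           # shape: (num_nodes,)
torch.topk(scores, k=5).indices.tolist()

The batched loop below applies the same dot-product scoring to the whole test split and passes the scores to evaluate.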
# Parameters
split = 'test-0.1' # For simplicity, we use the small set of test data
batch_size = 256
model = 'vss'
save_topk = 100 # Top-K predictions to be considered
# Use testing split indices
indices = starkqa_split_indices[split].tolist()
# Check device availability
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device
'cuda:0'
# Prepare variables for evaluation
candidate_ids = torch.LongTensor(primekg_nodes.node_id.tolist())
# Loop through the test data
for batch_idx in tqdm(range(0, len(indices), batch_size or len(indices))):
    # For demonstration purposes, only the first batch is evaluated
    if batch_idx == 0:
        batch_indices = [idx for idx in indices[batch_idx : min(batch_idx + batch_size, len(indices))]]
        if len(batch_indices) == 0:
            continue

        # Get the query ids, queries, queries_embedded, and answer ids from the dataframe
        query_ids, queries, queries_embedded, answer_ids = zip(
            *[starkqa_df[['id', 'query', 'query_embedded', 'answer_ids']].iloc[idx] for idx in batch_indices]
        )

        # Using VSS, we calculate similarities between query and candidate embeddings
        similarity = torch.matmul(torch.tensor(np.array(queries_embedded)).to(device),
                                  torch.tensor(np.array(primekg_nodes.x.values.tolist())).T.to(device)).cpu()

        # Measure performance
        pred_ids = candidate_ids
        pred = similarity.t()
        answer_ids = [torch.LongTensor(answer_id) for answer_id in answer_ids]
        results = evaluate(candidate_ids, pred_ids, pred, answer_ids, metrics=eval_metrics)
        for i, result in enumerate(results):
            result["idx"], result["query_id"] = batch_indices[i], query_ids[i]
            result["pred_rank"] = pred_ids[torch.argsort(pred[:, i], descending=True)[save_topk]].tolist()
            eval_csv = pd.concat([eval_csv, pd.DataFrame([result]).astype(eval_csv.dtypes)], ignore_index=True)
100%|██████████| 2/2 [00:30<00:00, 15.35s/it]
We can further inspect the eval_csv dataframe to observe the per-query performance of the model.
# Check the evaluation results
eval_csv
|   | idx | query_id | pred_rank | mrr | map | rprecision | recall@5 | recall@10 | recall@20 | recall@50 | recall@100 | hit@1 | hit@3 | hit@5 | hit@10 | hit@20 | hit@50 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 9 | 9 | 3163 | 0.090909 | 0.038068 | 0.0 | 0.0 | 0.0 | 0.333333 | 0.333333 | 0.333333 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 |
| 1 | 26 | 26 | 31190 | 0.333333 | 0.333333 | 0.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 0.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 |
| 2 | 88 | 88 | 99533 | 0.076923 | 0.076923 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 |
| 3 | 195 | 195 | 36650 | 0.008696 | 0.008696 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 4 | 231 | 231 | 22840 | 0.010638 | 0.010638 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 251 | 9977 | 9977 | 71495 | 0.000312 | 0.000312 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 252 | 9992 | 9992 | 72393 | 0.012658 | 0.006716 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.5 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 253 | 9996 | 9996 | 103422 | 0.000014 | 0.000014 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 254 | 10139 | 10139 | 127757 | 0.034483 | 0.034483 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 |
| 255 | 10191 | 10191 | 22104 | 0.000015 | 0.000015 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |

256 rows × 17 columns
Finally, we can take the average of the evaluation metrics over the test samples.
# Taking the mean of the evaluation metrics
eval_csv[eval_metrics].mean()
mrr           0.127865
map           0.085517
rprecision    0.049698
recall@5      0.113364
recall@10     0.171124
recall@20     0.240542
recall@50     0.316177
recall@100    0.365272
hit@1         0.070312
hit@3         0.140625
hit@5         0.183594
hit@10            0.25
hit@20        0.324219
hit@50        0.414062
dtype: object
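Optionally, the per-query results can be persisted for later analysis; the file name below is just an example.

# Save the per-query evaluation results to a CSV file (example file name)
eval_csv.to_csv(os.path.join(starkqa_data.local_dir, 'starkqaprimekg_vss_eval.csv'), index=False)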