Skip to content

PCST (Milvus Multimodal)

Extraction of a multimodal subgraph using the Prize-Collecting Steiner Tree (PCST) algorithm.

MultimodalPCSTPruning

Bases: NamedTuple

Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever (He et al., 'G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and Question Answering', NeurIPS 2024) paper. https://arxiv.org/abs/2402.07630 https://github.com/XiaoxinHe/G-Retriever/blob/main/src/dataset/utils/retrieval.py

Parameters:

Name Type Description Default
topk

The number of top nodes to consider.

required
topk_e

The number of top edges to consider.

required
cost_e

The cost of the edges.

required
c_const

The constant value for the cost of the edges computation.

required
root

The root node of the subgraph, -1 for unrooted.

required
num_clusters

The number of clusters.

required
pruning

The pruning strategy to use.

required
verbosity_level

The verbosity level.

required
Source code in aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
class MultimodalPCSTPruning(NamedTuple):
    """
    Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever
    (He et al., 'G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and
    Question Answering', NeurIPS 2024) paper.
    https://arxiv.org/abs/2402.07630
    https://github.com/XiaoxinHe/G-Retriever/blob/main/src/dataset/utils/retrieval.py

    Args:
        topk: The number of top nodes to consider.
        topk_e: The number of top edges to consider.
        cost_e: The cost of the edges.
        c_const: The constant value for the cost of the edges computation.
        root: The root node of the subgraph, -1 for unrooted.
        num_clusters: The number of clusters.
        pruning: The pruning strategy to use.
        verbosity_level: The verbosity level.
    """
    topk: int = 3
    topk_e: int = 3
    cost_e: float = 0.5
    c_const: float = 0.01
    root: int = -1
    num_clusters: int = 1
    pruning: str = "gw"
    verbosity_level: int = 0
    use_description: bool = False
    metric_type: str = "IP"  # Inner Product

    def prepare_collections(self, cfg: dict, modality: str) -> dict:
        """
        Prepare the collections for nodes, node-type specific nodes, and edges in Milvus.

        Args:
            cfg: The configuration dictionary containing the Milvus setup.
            modality: The modality to use for the subgraph extraction.

        Returns:
            A dictionary containing the collections of nodes, node-type specific nodes, and edges.
        """
        # Initialize the collections dictionary
        colls = {}

        # Load the collection for nodes
        colls["nodes"] = Collection(name=f"{cfg.milvus_db.database_name}_nodes")

        if modality != "prompt":
            # Load the collection for the specific node type
            colls["nodes_type"] = Collection(
                f"{cfg.milvus_db.database_name}_nodes_{modality.replace('/', '_')}"
            )

        # Load the collection for edges
        colls["edges"] = Collection(name=f"{cfg.milvus_db.database_name}_edges")

        # Load the collections
        for coll in colls.values():
            coll.load()

        return colls

    def _compute_node_prizes(self,
                             query_emb: list,
                             colls: dict) -> dict:
        """
        Compute the node prizes based on the cosine similarity between the query and nodes.

        Args:
            query_emb: The query embedding. This can be an embedding of
                a prompt, sequence, or any other feature to be used for the subgraph extraction.
            colls: The collections of nodes, node-type specific nodes, and edges in Milvus.

        Returns:
            The prizes of the nodes.
        """
        # Intialize several variables
        topk = min(self.topk, colls["nodes"].num_entities)
        n_prizes = py.zeros(colls["nodes"].num_entities, dtype=py.float32)

        # Calculate cosine similarity for text features and update the score
        if self.use_description:
            # Search the collection with the text embedding
            res = colls["nodes"].search(
                data=[query_emb],
                anns_field="desc_emb",
                param={"metric_type": self.metric_type},
                limit=topk,
                output_fields=["node_id"])
        else:
            # Search the collection with the query embedding
            res = colls["nodes_type"].search(
                data=[query_emb],
                anns_field="feat_emb",
                param={"metric_type": self.metric_type},
                limit=topk,
                output_fields=["node_id"])

        # Update the prizes based on the search results
        n_prizes[[r.id for r in res[0]]] = py.arange(topk, 0, -1).astype(py.float32)

        return n_prizes

    def _compute_edge_prizes(self,
                             text_emb: list,
                             colls: dict) -> py.ndarray:
        """
        Compute the node prizes based on the cosine similarity between the query and nodes.

        Args:
            text_emb: The textual description embedding.
            colls: The collections of nodes, node-type specific nodes, and edges in Milvus.

        Returns:
            The prizes of the nodes.
        """
        # Intialize several variables
        topk_e = min(self.topk_e, colls["edges"].num_entities)
        e_prizes = py.zeros(colls["edges"].num_entities, dtype=py.float32)

        # Search the collection with the query embedding
        res = colls["edges"].search(
            data=[text_emb],
            anns_field="feat_emb",
            param={"metric_type": self.metric_type},
            limit=topk_e, # Only retrieve the top-k edges
            # limit=colls["edges"].num_entities,
            output_fields=["head_id", "tail_id"])

        # Update the prizes based on the search results
        e_prizes[[r.id for r in res[0]]] = [r.score for r in res[0]]

        # Further process the edge_prizes
        unique_prizes, inverse_indices = py.unique(e_prizes, return_inverse=True)
        topk_e_values = unique_prizes[py.argsort(-unique_prizes)[:topk_e]]
        # e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        for k in range(topk_e):
            indices = inverse_indices == (unique_prizes == topk_e_values[k]).nonzero()[0]
            value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
            e_prizes[indices] = value
            last_topk_e_value = value * (1 - self.c_const)

        return e_prizes

    def compute_prizes(self,
                       text_emb: list,
                       query_emb: list,
                       colls: dict) -> dict:
        """
        Compute the node prizes based on the cosine similarity between the query and nodes,
        as well as the edge prizes based on the cosine similarity between the query and edges.
        Note that the node and edge embeddings shall use the same embedding model and dimensions
        with the query.

        Args:
            text_emb: The textual description embedding.
            query_emb: The query embedding. This can be an embedding of
                a prompt, sequence, or any other feature to be used for the subgraph extraction.
            colls: The collections of nodes, node-type specific nodes, and edges in Milvus.

        Returns:
            The prizes of the nodes and edges.
        """
        # Compute prizes for nodes
        logger.log(logging.INFO, "_compute_node_prizes")
        n_prizes = self._compute_node_prizes(query_emb, colls)

        # Compute prizes for edges
        logger.log(logging.INFO, "_compute_edge_prizes")
        e_prizes = self._compute_edge_prizes(text_emb, colls)

        return {"nodes": n_prizes, "edges": e_prizes}

    def compute_subgraph_costs(self,
                               edge_index: py.ndarray,
                               num_nodes: int,
                               prizes: dict) -> Tuple[py.ndarray, py.ndarray, py.ndarray]:
        """
        Compute the costs in constructing the subgraph proposed by G-Retriever paper.

        Args:
            edge_index: The edge index of the graph, consisting of source and destination nodes.
            num_nodes: The number of nodes in the graph.
            prizes: The prizes of the nodes and the edges.

        Returns:
            edges: The edges of the subgraph, consisting of edges and number of edges without
                virtual edges.
            prizes: The prizes of the subgraph.
            costs: The costs of the subgraph.
        """
        # Initialize several variables
        real_ = {}
        virt_ = {}

        # Update edge cost threshold
        updated_cost_e = min(
            self.cost_e,
            py.max(prizes["edges"]).item() * (1 - self.c_const / 2),
        )

        # Masks for real and virtual edges
        logger.log(logging.INFO, "Creating masks for real and virtual edges")
        real_["mask"] = prizes["edges"] <= updated_cost_e
        virt_["mask"] = ~real_["mask"]

        # Real edge indices
        logger.log(logging.INFO, "Computing real edges")
        real_["indices"] = py.nonzero(real_["mask"])[0]
        real_["src"] = edge_index[0][real_["indices"]]
        real_["dst"] = edge_index[1][real_["indices"]]
        real_["edges"] = py.stack([real_["src"], real_["dst"]], axis=1)
        real_["costs"] = updated_cost_e - prizes["edges"][real_["indices"]]

        # Edge index mapping: local real edge idx -> original global index
        logger.log(logging.INFO, "Creating mapping for real edges")
        mapping_edges = dict(zip(range(len(real_["indices"])), real_["indices"].tolist()))

        # Virtual edge handling
        logger.log(logging.INFO, "Computing virtual edges")
        virt_["indices"] = py.nonzero(virt_["mask"])[0]
        virt_["src"] = edge_index[0][virt_["indices"]]
        virt_["dst"] = edge_index[1][virt_["indices"]]
        virt_["prizes"] = prizes["edges"][virt_["indices"]] - updated_cost_e

        # Generate virtual node IDs
        logger.log(logging.INFO, "Generating virtual node IDs")
        virt_["num"] = virt_["indices"].shape[0]
        virt_["node_ids"] = py.arange(num_nodes, num_nodes + virt_["num"])

        # Virtual edges: (src → virtual), (virtual → dst)
        logger.log(logging.INFO, "Creating virtual edges")
        virt_["edges_1"] = py.stack([virt_["src"], virt_["node_ids"]], axis=1)
        virt_["edges_2"] = py.stack([virt_["node_ids"], virt_["dst"]], axis=1)
        virt_["edges"] = py.concatenate([virt_["edges_1"],
                                         virt_["edges_2"]], axis=0)
        virt_["costs"] = py.zeros((virt_["edges"].shape[0],), dtype=real_["costs"].dtype)

        # Combine real and virtual edges/costs
        logger.log(logging.INFO, "Combining real and virtual edges/costs")
        all_edges = py.concatenate([real_["edges"], virt_["edges"]], axis=0)
        all_costs = py.concatenate([real_["costs"], virt_["costs"]], axis=0)

        # Final prizes
        logger.log(logging.INFO, "Getting final prizes")
        final_prizes = py.concatenate([prizes["nodes"], virt_["prizes"]], axis=0)

        # Mapping virtual node ID -> edge index in original graph
        logger.log(logging.INFO, "Creating mapping for virtual nodes")
        mapping_nodes = dict(zip(virt_["node_ids"].tolist(), virt_["indices"].tolist()))

        # Build return values
        logger.log(logging.INFO, "Building return values")
        edges_dict = {
            "edges": all_edges,
            "num_prior_edges": real_["edges"].shape[0],
        }
        mapping = {
            "edges": mapping_edges,
            "nodes": mapping_nodes,
        }

        return edges_dict, final_prizes, all_costs, mapping

    def get_subgraph_nodes_edges(self,
                                 num_nodes: int,
                                 vertices: py.ndarray,
                                 edges_dict: dict,
                                 mapping: dict) -> dict:
        """
        Get the selected nodes and edges of the subgraph based on the vertices and edges computed
        by the PCST algorithm.

        Args:
            num_nodes: The number of nodes in the graph.
            vertices: The vertices selected by the PCST algorithm.
            edges_dict: A dictionary containing the edges and the number of prior edges.
            mapping: A dictionary containing the mapping of nodes and edges.

        Returns:
            The selected nodes and edges of the extracted subgraph.
        """
        # Get edges information
        edges = edges_dict["edges"]
        num_prior_edges = edges_dict["num_prior_edges"]
        # Get edges information
        edges = edges_dict["edges"]
        num_prior_edges = edges_dict["num_prior_edges"]
        # Retrieve the selected nodes and edges based on the given vertices and edges
        subgraph_nodes = vertices[vertices < num_nodes]
        subgraph_edges = [mapping["edges"][e.item()] for e in edges if e < num_prior_edges]
        virtual_vertices = vertices[vertices >= num_nodes]
        if len(virtual_vertices) > 0:
            virtual_vertices = vertices[vertices >= num_nodes]
            virtual_edges = [mapping["nodes"][i.item()] for i in virtual_vertices]
            subgraph_edges = py.array(subgraph_edges + virtual_edges)
        edge_index = edges_dict["edge_index"][:, subgraph_edges]
        subgraph_nodes = py.unique(
            py.concatenate(
                [subgraph_nodes, edge_index[0], edge_index[1]]
            )
        )

        return {"nodes": subgraph_nodes, "edges": subgraph_edges}

    def extract_subgraph(self,
                         text_emb: list,
                         query_emb: list,
                         modality: str,
                         cfg: dict) -> dict:
        """
        Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.

        Args:
            text_emb: The textual description embedding.
            query_emb: The query embedding. This can be an embedding of
                a prompt, sequence, or any other feature to be used for the subgraph extraction.
            modality: The modality to use for the subgraph extraction
                (e.g., "text", "sequence", "smiles").
            cfg: The configuration dictionary containing the Milvus setup.

        Returns:
            The selected nodes and edges of the subgraph.
        """
        # Load the collections for nodes
        logger.log(logging.INFO, "Preparing collections")
        colls = self.prepare_collections(cfg, modality)

        # Load cache edge index
        logger.log(logging.INFO, "Loading cache edge index")
        with open(cfg.milvus_db.cache_edge_index_path, "rb") as f:
            edge_index = pickle.load(f)
            edge_index = py.array(edge_index)

        # Assert the topk and topk_e values for subgraph retrieval
        assert self.topk > 0, "topk must be greater than or equal to 0"
        assert self.topk_e > 0, "topk_e must be greater than or equal to 0"

        # Retrieve the top-k nodes and edges based on the query embedding
        logger.log(logging.INFO, "compute_prizes")
        prizes = self.compute_prizes(text_emb, query_emb, colls)

        # Compute costs in constructing the subgraph
        logger.log(logging.INFO, "compute_subgraph_costs")
        edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(
            edge_index, colls["nodes"].num_entities, prizes)

        # Retrieve the subgraph using the PCST algorithm
        logger.log(logging.INFO, "Running PCST algorithm")
        result_vertices, result_edges = pcst_fast.pcst_fast(
            edges_dict["edges"].tolist(),
            prizes.tolist(),
            costs.tolist(),
            self.root,
            self.num_clusters,
            self.pruning,
            self.verbosity_level,
        )

        # Get subgraph nodes and edges based on the result of the PCST algorithm
        logger.log(logging.INFO, "Getting subgraph nodes and edges")
        subgraph = self.get_subgraph_nodes_edges(
            colls["nodes"].num_entities,
            py.asarray(result_vertices),
            {"edges": py.asarray(result_edges),
             "num_prior_edges": edges_dict["num_prior_edges"],
             "edge_index": edge_index},
            mapping)
        print(subgraph)

        return subgraph

_compute_edge_prizes(text_emb, colls)

Compute the edge prizes based on the similarity between the textual description embedding and the edges.

Parameters:

Name Type Description Default
text_emb list

The textual description embedding.

required
colls dict

The collections of nodes, node-type specific nodes, and edges in Milvus.

required

Returns:

Type Description
ndarray

The prizes of the edges.

Source code in aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def _compute_edge_prizes(self,
                         text_emb: list,
                         colls: dict) -> py.ndarray:
    """
    Compute the node prizes based on the cosine similarity between the query and nodes.

    Args:
        text_emb: The textual description embedding.
        colls: The collections of nodes, node-type specific nodes, and edges in Milvus.

    Returns:
        The prizes of the nodes.
    """
    # Intialize several variables
    topk_e = min(self.topk_e, colls["edges"].num_entities)
    e_prizes = py.zeros(colls["edges"].num_entities, dtype=py.float32)

    # Search the collection with the query embedding
    res = colls["edges"].search(
        data=[text_emb],
        anns_field="feat_emb",
        param={"metric_type": self.metric_type},
        limit=topk_e, # Only retrieve the top-k edges
        # limit=colls["edges"].num_entities,
        output_fields=["head_id", "tail_id"])

    # Update the prizes based on the search results
    e_prizes[[r.id for r in res[0]]] = [r.score for r in res[0]]

    # Further process the edge_prizes
    unique_prizes, inverse_indices = py.unique(e_prizes, return_inverse=True)
    topk_e_values = unique_prizes[py.argsort(-unique_prizes)[:topk_e]]
    # e_prizes[e_prizes < topk_e_values[-1]] = 0.0
    last_topk_e_value = topk_e
    for k in range(topk_e):
        indices = inverse_indices == (unique_prizes == topk_e_values[k]).nonzero()[0]
        value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
        e_prizes[indices] = value
        last_topk_e_value = value * (1 - self.c_const)

    return e_prizes

_compute_node_prizes(query_emb, colls)

Compute the node prizes based on the cosine similarity between the query and nodes.

Parameters:

Name Type Description Default
query_emb list

The query embedding. This can be an embedding of a prompt, sequence, or any other feature to be used for the subgraph extraction.

required
colls dict

The collections of nodes, node-type specific nodes, and edges in Milvus.

required

Returns:

Type Description
ndarray

The prizes of the nodes.

Source code in aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def _compute_node_prizes(self,
                         query_emb: list,
                         colls: dict) -> dict:
    """
    Compute the node prizes based on the cosine similarity between the query and nodes.

    Args:
        query_emb: The query embedding. This can be an embedding of
            a prompt, sequence, or any other feature to be used for the subgraph extraction.
        colls: The collections of nodes, node-type specific nodes, and edges in Milvus.

    Returns:
        The prizes of the nodes.
    """
    # Intialize several variables
    topk = min(self.topk, colls["nodes"].num_entities)
    n_prizes = py.zeros(colls["nodes"].num_entities, dtype=py.float32)

    # Calculate cosine similarity for text features and update the score
    if self.use_description:
        # Search the collection with the text embedding
        res = colls["nodes"].search(
            data=[query_emb],
            anns_field="desc_emb",
            param={"metric_type": self.metric_type},
            limit=topk,
            output_fields=["node_id"])
    else:
        # Search the collection with the query embedding
        res = colls["nodes_type"].search(
            data=[query_emb],
            anns_field="feat_emb",
            param={"metric_type": self.metric_type},
            limit=topk,
            output_fields=["node_id"])

    # Update the prizes based on the search results
    n_prizes[[r.id for r in res[0]]] = py.arange(topk, 0, -1).astype(py.float32)

    return n_prizes

compute_prizes(text_emb, query_emb, colls)

Compute the node prizes based on the cosine similarity between the query and nodes, as well as the edge prizes based on the cosine similarity between the query and edges. Note that the node and edge embeddings shall use the same embedding model and dimensions with the query.

Parameters:

Name Type Description Default
text_emb list

The textual description embedding.

required
query_emb list

The query embedding. This can be an embedding of a prompt, sequence, or any other feature to be used for the subgraph extraction.

required
colls dict

The collections of nodes, node-type specific nodes, and edges in Milvus.

required

Returns:

Type Description
dict

The prizes of the nodes and edges.

Source code in aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
def compute_prizes(self,
                   text_emb: list,
                   query_emb: list,
                   colls: dict) -> dict:
    """
    Score the graph for prize-collecting: node prizes from the query embedding and
    edge prizes from the textual description embedding.

    The node and edge embeddings stored in Milvus are expected to come from the same
    embedding model (and have the same dimensions) as the query.

    Args:
        text_emb: The textual description embedding, used to rank edges.
        query_emb: The query embedding (prompt, sequence, or any other feature), used
            to rank nodes.
        colls: The collections of nodes, node-type specific nodes, and edges in Milvus.

    Returns:
        A dictionary holding the node prizes under "nodes" and the edge prizes under "edges".
    """
    prizes = {}

    # Node prizes first ...
    logger.log(logging.INFO, "_compute_node_prizes")
    prizes["nodes"] = self._compute_node_prizes(query_emb, colls)

    # ... then edge prizes
    logger.log(logging.INFO, "_compute_edge_prizes")
    prizes["edges"] = self._compute_edge_prizes(text_emb, colls)

    return prizes

compute_subgraph_costs(edge_index, num_nodes, prizes)

Compute the costs in constructing the subgraph proposed by G-Retriever paper.

Parameters:

Name Type Description Default
edge_index ndarray

The edge index of the graph, consisting of source and destination nodes.

required
num_nodes int

The number of nodes in the graph.

required
prizes dict

The prizes of the nodes and the edges.

required

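Returns:

Name Type Description
edges_dict dict

The edges of the subgraph (real edges followed by virtual edges) under "edges", and the number of edges without virtual edges under "num_prior_edges".

prizes ndarray

The prizes of the subgraph: the node prizes followed by the prizes of the virtual nodes.

costs ndarray

The costs of the subgraph edges; virtual edges have zero cost.

mapping dict

The mapping from local real-edge indices ("edges") and from virtual node IDs ("nodes") to edge indices in the original graph.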

Source code in aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
def compute_subgraph_costs(self,
                           edge_index: py.ndarray,
                           num_nodes: int,
                           prizes: dict) -> Tuple[dict, py.ndarray, py.ndarray, dict]:
    """
    Compute the costs in constructing the subgraph proposed by G-Retriever paper.

    Edges whose prize is at most the (possibly lowered) edge cost stay real edges with
    cost ``cost - prize``. Every other edge is replaced by a virtual node carrying prize
    ``prize - cost``, connected to both endpoints by zero-cost virtual edges.

    Args:
        edge_index: The edge index of the graph, row 0 holding source nodes and row 1
            destination nodes.
        num_nodes: The number of nodes in the graph; virtual node IDs start here.
        prizes: The prizes of the nodes ("nodes") and the edges ("edges").

    Returns:
        edges_dict: The real edges followed by the virtual edges under "edges", and the
            number of real edges under "num_prior_edges".
        prizes: The node prizes followed by the virtual node prizes.
        costs: The cost of every edge in ``edges_dict["edges"]``.
        mapping: "edges" maps local real-edge indices and "nodes" maps virtual node IDs
            to edge indices in the original graph.
    """
    # Initialize several variables
    real_ = {}
    virt_ = {}

    # Update edge cost threshold
    updated_cost_e = min(
        self.cost_e,
        py.max(prizes["edges"]).item() * (1 - self.c_const / 2),
    )

    # Masks for real and virtual edges
    logger.log(logging.INFO, "Creating masks for real and virtual edges")
    real_["mask"] = prizes["edges"] <= updated_cost_e
    virt_["mask"] = ~real_["mask"]

    # Real edge indices
    logger.log(logging.INFO, "Computing real edges")
    real_["indices"] = py.nonzero(real_["mask"])[0]
    real_["src"] = edge_index[0][real_["indices"]]
    real_["dst"] = edge_index[1][real_["indices"]]
    real_["edges"] = py.stack([real_["src"], real_["dst"]], axis=1)
    real_["costs"] = updated_cost_e - prizes["edges"][real_["indices"]]

    # Edge index mapping: local real edge idx -> original global index
    logger.log(logging.INFO, "Creating mapping for real edges")
    mapping_edges = dict(zip(range(len(real_["indices"])), real_["indices"].tolist()))

    # Virtual edge handling
    logger.log(logging.INFO, "Computing virtual edges")
    virt_["indices"] = py.nonzero(virt_["mask"])[0]
    virt_["src"] = edge_index[0][virt_["indices"]]
    virt_["dst"] = edge_index[1][virt_["indices"]]
    virt_["prizes"] = prizes["edges"][virt_["indices"]] - updated_cost_e

    # Generate virtual node IDs
    logger.log(logging.INFO, "Generating virtual node IDs")
    virt_["num"] = virt_["indices"].shape[0]
    virt_["node_ids"] = py.arange(num_nodes, num_nodes + virt_["num"])

    # Virtual edges: (src → virtual), (virtual → dst)
    logger.log(logging.INFO, "Creating virtual edges")
    virt_["edges_1"] = py.stack([virt_["src"], virt_["node_ids"]], axis=1)
    virt_["edges_2"] = py.stack([virt_["node_ids"], virt_["dst"]], axis=1)
    virt_["edges"] = py.concatenate([virt_["edges_1"],
                                     virt_["edges_2"]], axis=0)
    virt_["costs"] = py.zeros((virt_["edges"].shape[0],), dtype=real_["costs"].dtype)

    # Combine real and virtual edges/costs
    logger.log(logging.INFO, "Combining real and virtual edges/costs")
    all_edges = py.concatenate([real_["edges"], virt_["edges"]], axis=0)
    all_costs = py.concatenate([real_["costs"], virt_["costs"]], axis=0)

    # Final prizes
    logger.log(logging.INFO, "Getting final prizes")
    final_prizes = py.concatenate([prizes["nodes"], virt_["prizes"]], axis=0)

    # Mapping virtual node ID -> edge index in original graph
    logger.log(logging.INFO, "Creating mapping for virtual nodes")
    mapping_nodes = dict(zip(virt_["node_ids"].tolist(), virt_["indices"].tolist()))

    # Build return values
    logger.log(logging.INFO, "Building return values")
    edges_dict = {
        "edges": all_edges,
        "num_prior_edges": real_["edges"].shape[0],
    }
    mapping = {
        "edges": mapping_edges,
        "nodes": mapping_nodes,
    }

    return edges_dict, final_prizes, all_costs, mapping

extract_subgraph(text_emb, query_emb, modality, cfg)

Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.

Parameters:

Name Type Description Default
text_emb list

The textual description embedding.

required
query_emb list

The query embedding. This can be an embedding of a prompt, sequence, or any other feature to be used for the subgraph extraction.

required
modality str

The modality to use for the subgraph extraction (e.g., "text", "sequence", "smiles").

required
cfg dict

The configuration dictionary containing the Milvus setup.

required

Returns:

Type Description
dict

The selected nodes and edges of the subgraph.

Source code in aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
def extract_subgraph(self,
                     text_emb: list,
                     query_emb: list,
                     modality: str,
                     cfg: dict) -> dict:
    """
    Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.

    Pipeline: open the Milvus collections, load the cached edge index, compute prizes
    for the query, build the PCST input graph (real edges plus virtual nodes/edges),
    run ``pcst_fast``, and translate its result back onto the original graph.

    Args:
        text_emb: The textual description embedding.
        query_emb: The query embedding. This can be an embedding of
            a prompt, sequence, or any other feature to be used for the subgraph extraction.
        modality: The modality to use for the subgraph extraction
            (e.g., "text", "sequence", "smiles").
        cfg: The configuration containing the Milvus setup. It is read via attribute
            access (``cfg.milvus_db.cache_edge_index_path``), so an attribute-accessible
            config object (presumably a Hydra/OmegaConf config) is expected rather than
            a plain ``dict``.

    Returns:
        The selected nodes and edges of the subgraph, as produced by
        ``self.get_subgraph_nodes_edges`` (keys ``"nodes"`` and ``"edges"``).

    Raises:
        AssertionError: If ``self.topk`` or ``self.topk_e`` is not strictly positive.
    """
    # Validate the retrieval sizes first so a bad configuration fails before any
    # Milvus or disk I/O is performed. Both must be strictly positive.
    assert self.topk > 0, "topk must be greater than 0"
    assert self.topk_e > 0, "topk_e must be greater than 0"

    # Load the collections for nodes
    logger.log(logging.INFO, "Preparing collections")
    colls = self.prepare_collections(cfg, modality)

    # Load cache edge index.
    # NOTE(security): pickle.load can execute arbitrary code while unpickling; the
    # file at cfg.milvus_db.cache_edge_index_path must come from a trusted source.
    logger.log(logging.INFO, "Loading cache edge index")
    with open(cfg.milvus_db.cache_edge_index_path, "rb") as f:
        edge_index = pickle.load(f)
    edge_index = py.array(edge_index)

    # Number of real nodes in the graph; query it once so the cost construction and
    # the result mapping are guaranteed to use the same value.
    num_nodes = colls["nodes"].num_entities

    # Retrieve the top-k nodes and edges based on the query embedding
    logger.log(logging.INFO, "compute_prizes")
    prizes = self.compute_prizes(text_emb, query_emb, colls)

    # Compute costs in constructing the subgraph (real + virtual edges, final prizes,
    # and the mappings needed to translate the PCST result back to the original graph)
    logger.log(logging.INFO, "compute_subgraph_costs")
    edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(
        edge_index, num_nodes, prizes)

    # Retrieve the subgraph using the PCST algorithm
    logger.log(logging.INFO, "Running PCST algorithm")
    result_vertices, result_edges = pcst_fast.pcst_fast(
        edges_dict["edges"].tolist(),
        prizes.tolist(),
        costs.tolist(),
        self.root,
        self.num_clusters,
        self.pruning,
        self.verbosity_level,
    )

    # Get subgraph nodes and edges based on the result of the PCST algorithm
    logger.log(logging.INFO, "Getting subgraph nodes and edges")
    subgraph = self.get_subgraph_nodes_edges(
        num_nodes,
        py.asarray(result_vertices),
        {"edges": py.asarray(result_edges),
         "num_prior_edges": edges_dict["num_prior_edges"],
         "edge_index": edge_index},
        mapping)
    # Debug-level log instead of a stray print, so library callers' stdout stays clean.
    logger.log(logging.DEBUG, "Extracted subgraph: %s", subgraph)

    return subgraph

get_subgraph_nodes_edges(num_nodes, vertices, edges_dict, mapping)

Get the selected nodes and edges of the subgraph based on the vertices and edges computed by the PCST algorithm.

Parameters:

Name Type Description Default
num_nodes int

The number of nodes in the graph.

required
vertices ndarray

The vertices selected by the PCST algorithm.

required
edges_dict dict

A dictionary containing the edges selected by the PCST algorithm, the number of prior (real) edges, and the original graph's edge index.

required
mapping dict

A dictionary containing the mapping of nodes and edges.

required

Returns:

Type Description
dict

The selected nodes and edges of the extracted subgraph.

Source code in aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
def get_subgraph_nodes_edges(self,
                             num_nodes: int,
                             vertices: py.ndarray,
                             edges_dict: dict,
                             mapping: dict) -> dict:
    """
    Get the selected nodes and edges of the subgraph based on the vertices and edges computed
    by the PCST algorithm.

    Vertex ids ``>= num_nodes`` are virtual nodes, and PCST edge indices
    ``>= num_prior_edges`` are virtual edges. Selected real edges and selected virtual
    nodes are both translated back into edge indices of the original graph via
    ``mapping``.

    Args:
        num_nodes: The number of nodes in the graph.
        vertices: The vertices selected by the PCST algorithm.
        edges_dict: A dictionary with keys ``"edges"`` (edge indices selected by PCST),
            ``"num_prior_edges"`` (number of real edges; indices at or above it are
            virtual) and ``"edge_index"`` (the original graph's edge index, indexed
            column-wise as ``edge_index[:, edge_ids]``).
        mapping: A dictionary with keys ``"edges"`` (real PCST edge index -> original
            edge index) and ``"nodes"`` (virtual node id -> original edge index).

    Returns:
        A dict with ``"nodes"`` (unique node ids: the real selected vertices plus both
        endpoints of every selected edge) and ``"edges"`` (integer array of original
        edge indices).
    """
    # Get edges information
    edges = edges_dict["edges"]
    num_prior_edges = edges_dict["num_prior_edges"]

    # Split the selected vertices into real nodes and virtual nodes
    subgraph_nodes = vertices[vertices < num_nodes]
    virtual_vertices = vertices[vertices >= num_nodes]

    # Real PCST edges map directly to original edges. Virtual edges are not mapped
    # here; the original edge they belong to is recovered from the virtual node.
    subgraph_edges = [mapping["edges"][e.item()] for e in edges if e < num_prior_edges]
    subgraph_edges += [mapping["nodes"][v.item()] for v in virtual_vertices]

    # Always return an integer array. This gives callers one type whether or not
    # virtual nodes were selected, and keeps column indexing valid when the
    # selection is empty (an empty float array cannot be used as an index).
    subgraph_edges = py.asarray(subgraph_edges, dtype=py.int64)

    # Include both endpoints (rows 0 and 1) of every selected edge in the node set
    edge_index = edges_dict["edge_index"][:, subgraph_edges]
    subgraph_nodes = py.unique(
        py.concatenate(
            [subgraph_nodes, edge_index[0], edge_index[1]]
        )
    )

    return {"nodes": subgraph_nodes, "edges": subgraph_edges}

prepare_collections(cfg, modality)

Prepare the collections for nodes, node-type-specific nodes, and edges in Milvus.

Parameters:

Name Type Description Default
cfg dict

The configuration dictionary containing the Milvus setup.

required
modality str

The modality to use for the subgraph extraction.

required

Returns:

Type Description
dict

A dictionary containing the collections of nodes, node-type-specific nodes, and edges.

Source code in aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def prepare_collections(self, cfg: dict, modality: str) -> dict:
    """
    Open and load the Milvus collections used for subgraph extraction.

    The node collection and the edge collection are always opened. A modality-specific
    node collection is opened as well unless ``modality`` is ``"prompt"``. ``load()`` is
    called on every opened collection, in the order nodes, node type, edges.

    Args:
        cfg: The configuration containing the Milvus setup; read via attribute access
            (``cfg.milvus_db.database_name``).
        modality: The modality to use for the subgraph extraction. Any ``/`` in it is
            replaced by ``_`` when forming the collection name.

    Returns:
        A dictionary with keys ``"nodes"``, optionally ``"nodes_type"``, and ``"edges"``,
        each mapped to its Milvus ``Collection``.
    """
    prefix = cfg.milvus_db.database_name

    # Collection names keyed by role, in the order they are opened and loaded.
    names = {"nodes": f"{prefix}_nodes"}
    if modality != "prompt":
        names["nodes_type"] = f"{prefix}_nodes_{modality.replace('/', '_')}"
    names["edges"] = f"{prefix}_edges"

    colls = {role: Collection(name=coll_name) for role, coll_name in names.items()}

    for collection in colls.values():
        collection.load()

    return colls