Skip to content

MilvusConnectionManager

Milvus Connection Manager for Talk2KnowledgeGraphs.

This module provides centralized connection management for the Milvus database, removing the dependency on frontend session state and enabling proper separation of concerns between frontend and backend.

MilvusConnectionManager

Centralized Milvus connection manager for backend tools with singleton pattern.

This class handles:

- Connection establishment and management
- Database switching
- Connection health checks
- Graceful error handling
- Thread-safe singleton pattern

Parameters:

Name Type Description Default
cfg dict[str, Any]

Configuration object containing Milvus connection parameters

required
Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
class MilvusConnectionManager:
    """
    Centralized Milvus connection manager for backend tools with singleton pattern.

    This class handles:
    - Connection establishment and management
    - Database switching
    - Connection health checks
    - Graceful error handling
    - One shared instance per distinct connection configuration

    Instances are cached per ``(host, port, user, database_name, alias)``, so
    constructing the class twice with the same settings returns the same object.
    Instance creation is guarded by a class-level lock and ORM connection setup
    by a per-instance lock; the lazily created ``MilvusClient`` /
    ``AsyncMilvusClient`` objects are not created under a lock.

    Args:
        cfg: Configuration object containing Milvus connection parameters.
            It is read with attribute access (``cfg.milvus_db.host`` and so on),
            so it must support that style of access.
    """

    # Cache of managers keyed by connection parameters; writes happen under _lock.
    _instances = {}
    _lock = threading.Lock()

    def __new__(cls, cfg: dict[str, Any]):
        """
        Create singleton instance based on database configuration.

        Args:
            cfg: Configuration dictionary containing Milvus DB settings

        Returns:
            MilvusConnectionManager: Singleton instance for the given config
        """
        # Create a unique key based on connection parameters
        # (the password is intentionally not part of the key).
        config_key = (
            cfg.milvus_db.host,
            int(cfg.milvus_db.port),
            cfg.milvus_db.user,
            cfg.milvus_db.database_name,
            cfg.milvus_db.alias,
        )

        # Keep the instance in a local so that a concurrent clear_instances()
        # cannot turn the final lookup into a KeyError.
        instance = cls._instances.get(config_key)
        if instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created it
                # between the unlocked lookup and acquiring the lock.
                instance = cls._instances.get(config_key)
                if instance is None:
                    instance = super().__new__(cls)
                    cls._instances[config_key] = instance
                    logger.info(
                        "Created new MilvusConnectionManager singleton for database: %s",
                        cfg.milvus_db.database_name,
                    )
        else:
            logger.debug(
                "Reusing existing MilvusConnectionManager singleton for database: %s",
                cfg.milvus_db.database_name,
            )

        return instance

    def __init__(self, cfg: dict[str, Any]):
        """
        Initialize the Milvus connection manager.

        ``__init__`` runs every time the class is called, including when
        ``__new__`` returned a cached instance, so it returns early once the
        instance has been set up.

        Args:
            cfg: Configuration dictionary containing Milvus DB settings
        """
        # Prevent re-initialization of singleton instance
        if hasattr(self, "_initialized"):
            return

        self.cfg = cfg
        self.alias = cfg.milvus_db.alias
        self.host = cfg.milvus_db.host
        self.port = int(cfg.milvus_db.port)  # Ensure port is integer
        self.user = cfg.milvus_db.user
        self.password = cfg.milvus_db.password
        self.database_name = cfg.milvus_db.database_name

        # Thread lock for connection operations
        self._connection_lock = threading.Lock()

        # Both clients are created lazily by get_sync_client / get_async_client
        self._sync_client = None
        self._async_client = None

        # Mark as initialized
        self._initialized = True

        logger.info("MilvusConnectionManager initialized for database: %s", self.database_name)

    def get_sync_client(self) -> MilvusClient:
        """
        Get or create a synchronous MilvusClient.

        The client is built on first use from ``http://<host>:<port>`` with a
        ``<user>:<password>`` token and cached on the instance.

        Returns:
            MilvusClient: Configured synchronous client
        """
        if self._sync_client is None:
            self._sync_client = MilvusClient(
                uri=f"http://{self.host}:{self.port}",
                token=f"{self.user}:{self.password}",
                db_name=self.database_name,
            )
            logger.info("Created synchronous MilvusClient for database: %s", self.database_name)
        return self._sync_client

    def get_async_client(self) -> AsyncMilvusClient | None:
        """
        Get or create an asynchronous AsyncMilvusClient.

        Returns:
            The cached or newly created client, or ``None`` when creation fails
            with one of the handled exception types (the error is logged and the
            caller is expected to fall back).
        """
        if self._async_client is None:
            try:
                self._async_client = AsyncMilvusClient(
                    uri=f"http://{self.host}:{self.port}",
                    token=f"{self.user}:{self.password}",
                    db_name=self.database_name,
                )
                logger.info(
                    "Created asynchronous AsyncMilvusClient for database: %s",
                    self.database_name,
                )
            except (MilvusException, RuntimeError, ConnectionError, OSError) as e:
                logger.error("Failed to create async client: %s", str(e))
                # Don't raise here, let the calling method handle the fallback
                return None
        return self._async_client

    def ensure_connection(self) -> bool:
        """
        Ensure Milvus connection exists, create if not.

        This method checks if a connection with the configured alias exists,
        creates one if it doesn't, and then switches that connection to the
        configured database. The whole sequence runs under the instance's
        connection lock.

        Returns:
            bool: Always True; every failure is reported by raising.

        Raises:
            MilvusException: If connection cannot be established (other
                exception types are wrapped in a MilvusException).
        """
        with self._connection_lock:
            try:
                # Check if connection already exists
                if not connections.has_connection(self.alias):
                    logger.info("Creating new Milvus connection with alias: %s", self.alias)
                    connections.connect(
                        alias=self.alias,
                        host=self.host,
                        port=self.port,
                        user=self.user,
                        password=self.password,
                    )
                    logger.info(
                        "Successfully connected to Milvus at %s:%s",
                        self.host,
                        self.port,
                    )
                else:
                    logger.debug("Milvus connection already exists with alias: %s", self.alias)

                # Switch to the correct database on *this* alias; without
                # `using` the ORM would target the "default" connection.
                db.using_database(self.database_name, using=self.alias)
                logger.debug("Using Milvus database: %s", self.database_name)

                return True

            except MilvusException as e:
                logger.error("Failed to establish Milvus connection: %s", str(e))
                raise
            except Exception as e:
                logger.error("Unexpected error during Milvus connection: %s", str(e))
                raise MilvusException(f"Connection failed: {str(e)}") from e

    def get_connection_info(self) -> dict[str, Any]:
        """
        Get current connection information.

        Returns:
            Dict with ``alias``, ``host``, ``port``, ``database``, ``connected``
            and ``connection_address`` (``None`` when not connected). If looking
            up the connection fails with a handled exception, a reduced dict with
            ``alias``, ``connected=False`` and ``error`` is returned instead.
        """
        try:
            connected = connections.has_connection(self.alias)
            conn_addr = connections.get_connection_addr(self.alias) if connected else None
            return {
                "alias": self.alias,
                "host": self.host,
                "port": self.port,
                "database": self.database_name,
                "connected": connected,
                "connection_address": conn_addr,
            }
        except (MilvusException, RuntimeError, ConnectionError, OSError) as e:
            logger.error("Error getting connection info: %s", str(e))
            return {"alias": self.alias, "connected": False, "error": str(e)}

    def test_connection(self) -> bool:
        """
        Test the connection by opening the ``<database_name>_nodes`` collection.

        Returns:
            bool: True if the connection could be ensured and the collection
            handle created, False if a handled exception occurred.
        """
        try:
            self.ensure_connection()

            # Try to get a collection to test the connection
            test_collection_name = f"{self.database_name}_nodes"
            Collection(name=test_collection_name, using=self.alias)

            logger.debug("Connection test successful")
            return True

        except (MilvusException, RuntimeError, ConnectionError, OSError) as e:
            logger.error("Connection test failed: %s", str(e))
            return False

    @staticmethod
    def _close_async_client(async_client) -> None:
        """Close ``async_client`` from synchronous code, waiting at most 5 seconds."""
        try:
            loop = asyncio.get_running_loop()
            # If there's a running loop, schedule the close on it.
            # NOTE(review): the task is fire-and-forget; no reference is kept.
            loop.create_task(async_client.close())
        except RuntimeError:
            # No running loop: drive the coroutine with asyncio.run in a worker
            # thread. The executor is shut down without waiting so that the
            # timeout below really bounds how long this call blocks (a `with`
            # block would wait for the worker on exit regardless).
            executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            try:
                executor.submit(lambda: asyncio.run(async_client.close())).result(timeout=5)
            finally:
                executor.shutdown(wait=False)

    def disconnect(self) -> bool:
        """
        Disconnect from Milvus (both sync and async clients).

        Drops the ORM connection for this alias if present, closes the async
        client if one was created, and clears both cached client references.

        Returns:
            bool: True if disconnected successfully, False otherwise
        """
        try:
            success = True

            # Disconnect sync client
            if connections.has_connection(self.alias):
                connections.disconnect(self.alias)
                logger.info("Disconnected sync connection with alias: %s", self.alias)

            # Detach the async client first so the reference is cleared however
            # the close attempt ends.
            async_client, self._async_client = self._async_client, None
            if async_client is not None:
                try:
                    self._close_async_client(async_client)
                    logger.info("Closed async client for database: %s", self.database_name)
                # concurrent.futures.TimeoutError is only an alias of the builtin
                # TimeoutError from Python 3.11 on, so name both.
                except (TimeoutError, concurrent.futures.TimeoutError, RuntimeError) as e:
                    logger.warning("Error closing async client: %s", str(e))
                    success = False

            # Clear sync client reference
            if self._sync_client is not None:
                self._sync_client = None
                logger.info("Cleared sync client reference")

            return success

        except (MilvusException, RuntimeError, ConnectionError, OSError) as e:
            logger.error("Error disconnecting from Milvus: %s", str(e))
            return False

    def get_collection(self, collection_name: str) -> Collection:
        """
        Get a Milvus collection, ensuring connection is established.

        Args:
            collection_name: Name of the collection to retrieve

        Returns:
            Collection: The requested Milvus collection, already loaded

        Raises:
            MilvusException: If collection cannot be retrieved
        """
        try:
            self.ensure_connection()
            collection = Collection(name=collection_name, using=self.alias)
            collection.load()  # Load collection data
            logger.debug("Successfully loaded collection: %s", collection_name)
            return collection

        except Exception as e:
            logger.error("Failed to get collection %s: %s", collection_name, str(e))
            raise MilvusException(f"Failed to get collection {collection_name}: {str(e)}") from e

    async def async_search(self, params: SearchParams) -> list:
        """
        Perform asynchronous vector search with sync fallback.

        If the async path raises a MilvusException (including failure to create
        the async client), the search is retried through the ORM API in a
        worker thread.

        Args:
            params: SearchParams object containing all search parameters

        Returns:
            List of search results

        Raises:
            MilvusException: If the sync fallback fails as well
        """
        try:
            async_client = self.get_async_client()
            if async_client is None:
                raise MilvusException("Failed to create async client")

            # Ensure collection is loaded before searching
            await async_client.load_collection(collection_name=params.collection_name)

            results = await async_client.search(
                collection_name=params.collection_name,
                data=params.data,
                anns_field=params.anns_field,
                search_params=params.search_params,
                limit=params.limit,
                output_fields=params.output_fields or [],
            )
            logger.debug("Async search completed for collection: %s", params.collection_name)
            return results
        except MilvusException as e:
            logger.warning(
                "Async search failed for collection %s: %s, falling back to sync",
                params.collection_name,
                str(e),
            )
            # Fallback to sync operation
            return await asyncio.to_thread(self._sync_search, params)

    def _sync_search(self, params: SearchParams) -> list:
        """
        Sync fallback for search operations.

        Raises:
            MilvusException: If connecting, loading or searching fails
        """
        try:
            # The ORM connection may never have been opened if callers only
            # used the async client so far.
            self.ensure_connection()
            collection = Collection(name=params.collection_name, using=self.alias)
            collection.load()
            results = collection.search(
                data=params.data,
                anns_field=params.anns_field,
                param=params.search_params,
                limit=params.limit,
                output_fields=params.output_fields or [],
            )
            logger.debug(
                "Sync fallback search completed for collection: %s",
                params.collection_name,
            )
            return results
        except Exception as e:
            logger.error(
                "Sync fallback search failed for collection %s: %s",
                params.collection_name,
                str(e),
            )
            raise MilvusException(f"Search failed (sync fallback): {str(e)}") from e

    async def async_query(self, params: QueryParams) -> list:
        """
        Perform asynchronous query with sync fallback.

        If the async path raises a MilvusException (including failure to create
        the async client), the query is retried through the ORM API in a
        worker thread.

        Args:
            params: QueryParams object containing all query parameters

        Returns:
            List of query results

        Raises:
            MilvusException: If the sync fallback fails as well
        """
        try:
            async_client = self.get_async_client()
            if async_client is None:
                raise MilvusException("Failed to create async client")

            # Ensure collection is loaded before querying
            await async_client.load_collection(collection_name=params.collection_name)

            results = await async_client.query(
                collection_name=params.collection_name,
                filter=params.expr,
                output_fields=params.output_fields or [],
                limit=params.limit,
            )
            logger.debug("Async query completed for collection: %s", params.collection_name)
            return results
        except MilvusException as e:
            logger.warning(
                "Async query failed for collection %s: %s, falling back to sync",
                params.collection_name,
                str(e),
            )
            # Fallback to sync operation
            return await asyncio.to_thread(self._sync_query, params)

    def _sync_query(self, params: QueryParams) -> list:
        """
        Sync fallback for query operations.

        Raises:
            MilvusException: If connecting, loading or querying fails
        """
        try:
            # The ORM connection may never have been opened if callers only
            # used the async client so far.
            self.ensure_connection()
            collection = Collection(name=params.collection_name, using=self.alias)
            collection.load()
            results = collection.query(
                expr=params.expr,
                output_fields=params.output_fields or [],
                limit=params.limit,
            )
            logger.debug(
                "Sync fallback query completed for collection: %s",
                params.collection_name,
            )
            return results
        except Exception as e:
            logger.error(
                "Sync fallback query failed for collection %s: %s",
                params.collection_name,
                str(e),
            )
            raise MilvusException(f"Query failed (sync fallback): {str(e)}") from e

    async def async_load_collection(self, collection_name: str) -> bool:
        """
        Asynchronously load a collection.

        Args:
            collection_name: Name of the collection to load

        Returns:
            bool: True if loaded successfully

        Raises:
            MilvusException: If the async client is unavailable or loading fails
        """
        try:
            async_client = self.get_async_client()
            # get_async_client returns None when the client cannot be created;
            # report that explicitly instead of failing on a None attribute.
            if async_client is None:
                raise MilvusException("Failed to create async client")
            await async_client.load_collection(collection_name=collection_name)
            logger.debug("Async load completed for collection: %s", collection_name)
            return True
        except Exception as e:
            logger.error("Async load failed for collection %s: %s", collection_name, str(e))
            raise MilvusException(f"Async load failed: {str(e)}") from e

    async def async_get_collection_stats(self, collection_name: str) -> dict:
        """
        Get collection statistics asynchronously.

        Args:
            collection_name: Name of the collection

        Returns:
            dict: ``{"num_entities": <count>}`` for the collection

        Raises:
            MilvusException: If the statistics cannot be read
        """

        def _count_entities():
            # Runs in a worker thread; make sure the ORM connection exists there.
            self.ensure_connection()
            return Collection(name=collection_name, using=self.alias).num_entities

        try:
            # Note: Using sync client methods through asyncio.to_thread as fallback
            # since AsyncMilvusClient might not have all stat methods
            stats = await asyncio.to_thread(_count_entities)
            return {"num_entities": stats}
        except Exception as e:
            logger.error(
                "Failed to get async collection stats for %s: %s",
                collection_name,
                str(e),
            )
            raise MilvusException(f"Failed to get collection stats: {str(e)}") from e

    @classmethod
    def get_instance(cls, cfg: dict[str, Any]) -> "MilvusConnectionManager":
        """
        Get singleton instance for the given configuration.

        Equivalent to calling the class directly.

        Args:
            cfg: Configuration dictionary containing Milvus DB settings

        Returns:
            MilvusConnectionManager: Singleton instance for the given config
        """
        return cls(cfg)

    @classmethod
    def clear_instances(cls):
        """
        Clear all singleton instances. Useful for testing or cleanup.

        Every cached instance is disconnected before the cache is emptied.
        """
        with cls._lock:
            # Disconnect all existing connections before clearing
            for instance in cls._instances.values():
                instance.disconnect()
            cls._instances.clear()
            logger.info("Cleared all MilvusConnectionManager singleton instances")

    @classmethod
    def from_config(cls, cfg: dict[str, Any]) -> "MilvusConnectionManager":
        """
        Create a MilvusConnectionManager from configuration.

        Equivalent to calling the class directly.

        Args:
            cfg: Configuration object or dictionary

        Returns:
            MilvusConnectionManager: Configured connection manager instance
        """
        return cls(cfg)

    @classmethod
    def from_hydra_config(
        cls,
        config_path: str = "../configs",
        config_name: str = "config",
        overrides: list | None = None,
    ) -> "MilvusConnectionManager":
        """
        Create a MilvusConnectionManager from Hydra configuration.

        This method loads the Milvus database configuration using Hydra,
        providing complete backend separation from frontend configs.

        Args:
            config_path: Path to the configs directory
            config_name: Name of the main config file
            overrides: List of config overrides; defaults to
                ``["utils/database/milvus=default"]`` when omitted

        Returns:
            MilvusConnectionManager: Configured connection manager instance

        Raises:
            MilvusException: If the configuration cannot be loaded or the
                manager cannot be constructed from it

        Example:
            # Load with default database config
            conn_manager = MilvusConnectionManager.from_hydra_config()

            # Load with specific overrides
            conn_manager = MilvusConnectionManager.from_hydra_config(
                overrides=["utils/database/milvus=default"]
            )
        """
        if overrides is None:
            overrides = ["utils/database/milvus=default"]

        try:
            with hydra.initialize(version_base=None, config_path=config_path):
                cfg_all = hydra.compose(config_name=config_name, overrides=overrides)
                cfg = cfg_all.utils.database.milvus  # Extract utils.database.milvus section
                logger.info("Loaded Milvus config from Hydra with overrides: %s", overrides)
                return cls(cfg)
        except Exception as e:
            logger.error("Failed to load Hydra configuration: %s", str(e))
            raise MilvusException(f"Configuration loading failed: {str(e)}") from e

__init__(cfg)

Initialize the Milvus connection manager.

Parameters:

Name Type Description Default
cfg dict[str, Any]

Configuration dictionary containing Milvus DB settings

required
Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def __init__(self, cfg: dict[str, Any]):
    """
    Set up connection settings and lazy client slots for this manager.

    Because the class hands back cached instances, this initializer can run
    more than once on the same object; every run after the first is a no-op.

    Args:
        cfg: Configuration dictionary containing Milvus DB settings
    """
    # Already set up on an earlier construction: leave the state untouched.
    if hasattr(self, "_initialized"):
        return

    self.cfg = cfg
    db_cfg = cfg.milvus_db
    self.alias = db_cfg.alias
    self.host = db_cfg.host
    self.port = int(db_cfg.port)  # config may carry the port as a string
    self.user = db_cfg.user
    self.password = db_cfg.password
    self.database_name = db_cfg.database_name

    # Serializes connection operations on this instance.
    self._connection_lock = threading.Lock()

    # Sync and async clients start empty and are filled in on first use.
    self._sync_client = None
    self._async_client = None

    # Flag checked by the guard at the top of this method.
    self._initialized = True

    logger.info("MilvusConnectionManager initialized for database: %s", self.database_name)

__new__(cfg)

Create singleton instance based on database configuration.

Parameters:

Name Type Description Default
cfg dict[str, Any]

Configuration dictionary containing Milvus DB settings

required

Returns:

Name Type Description
MilvusConnectionManager

Singleton instance for the given config

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def __new__(cls, cfg: dict[str, Any]):
    """
    Return the manager instance associated with the given configuration.

    One instance is kept per distinct combination of host, port, user,
    database name and alias; a new one is created only for an unseen
    combination.

    Args:
        cfg: Configuration dictionary containing Milvus DB settings

    Returns:
        MilvusConnectionManager: Singleton instance for the given config
    """
    db_cfg = cfg.milvus_db
    # Identity of a connection; the port is normalized to int first.
    config_key = (
        db_cfg.host,
        int(db_cfg.port),
        db_cfg.user,
        db_cfg.database_name,
        db_cfg.alias,
    )

    if config_key in cls._instances:
        logger.debug(
            "Reusing existing MilvusConnectionManager singleton for database: %s",
            db_cfg.database_name,
        )
        return cls._instances[config_key]

    with cls._lock:
        # Check again now that the lock is held; another caller may have
        # registered the instance in the meantime.
        if config_key not in cls._instances:
            cls._instances[config_key] = super().__new__(cls)
            logger.info(
                "Created new MilvusConnectionManager singleton for database: %s",
                db_cfg.database_name,
            )

    return cls._instances[config_key]

_sync_query(params)

Sync fallback for query operations.

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
def _sync_query(self, params: QueryParams) -> list:
    """
    Sync fallback for query operations.

    Args:
        params: QueryParams object containing all query parameters

    Returns:
        List of query results

    Raises:
        MilvusException: If connecting, loading or querying fails
    """
    try:
        # The ORM connection may never have been opened if callers only used
        # the async client so far, so establish it before touching Collection.
        self.ensure_connection()
        # Bind the collection to this manager's alias rather than "default".
        collection = Collection(name=params.collection_name, using=self.alias)
        collection.load()
        results = collection.query(
            expr=params.expr,
            output_fields=params.output_fields or [],
            limit=params.limit,
        )
        logger.debug(
            "Sync fallback query completed for collection: %s",
            params.collection_name,
        )
        return results
    except Exception as e:
        logger.error(
            "Sync fallback query failed for collection %s: %s",
            params.collection_name,
            str(e),
        )
        raise MilvusException(f"Query failed (sync fallback): {str(e)}") from e

_sync_search(params)

Sync fallback for search operations.

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
def _sync_search(self, params: SearchParams) -> list:
    """
    Sync fallback for search operations.

    Args:
        params: SearchParams object containing all search parameters

    Returns:
        List of search results

    Raises:
        MilvusException: If connecting, loading or searching fails
    """
    try:
        # The ORM connection may never have been opened if callers only used
        # the async client so far, so establish it before touching Collection.
        self.ensure_connection()
        # Bind the collection to this manager's alias rather than "default".
        collection = Collection(name=params.collection_name, using=self.alias)
        collection.load()
        results = collection.search(
            data=params.data,
            anns_field=params.anns_field,
            param=params.search_params,
            limit=params.limit,
            output_fields=params.output_fields or [],
        )
        logger.debug(
            "Sync fallback search completed for collection: %s",
            params.collection_name,
        )
        return results
    except Exception as e:
        logger.error(
            "Sync fallback search failed for collection %s: %s",
            params.collection_name,
            str(e),
        )
        raise MilvusException(f"Search failed (sync fallback): {str(e)}") from e

async_get_collection_stats(collection_name) async

Get collection statistics asynchronously.

Parameters:

Name Type Description Default
collection_name str

Name of the collection

required

Returns:

Name Type Description
dict dict

Collection statistics

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
async def async_get_collection_stats(self, collection_name: str) -> dict:
    """
    Get collection statistics asynchronously.

    Args:
        collection_name: Name of the collection

    Returns:
        dict: ``{"num_entities": <count>}`` for the collection

    Raises:
        MilvusException: If the statistics cannot be read
    """

    def _count_entities():
        # Runs in a worker thread. Open the ORM connection first (it may not
        # exist yet) and bind the collection to this manager's alias.
        self.ensure_connection()
        return Collection(name=collection_name, using=self.alias).num_entities

    try:
        # Note: Using sync client methods through asyncio.to_thread as fallback
        # since AsyncMilvusClient might not have all stat methods
        stats = await asyncio.to_thread(_count_entities)
        return {"num_entities": stats}
    except Exception as e:
        logger.error(
            "Failed to get async collection stats for %s: %s",
            collection_name,
            str(e),
        )
        raise MilvusException(f"Failed to get collection stats: {str(e)}") from e

async_load_collection(collection_name) async

Asynchronously load a collection.

Parameters:

Name Type Description Default
collection_name str

Name of the collection to load

required

Returns:

Name Type Description
bool bool

True if loaded successfully

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
async def async_load_collection(self, collection_name: str) -> bool:
    """
    Asynchronously load a collection.

    Args:
        collection_name: Name of the collection to load

    Returns:
        bool: True if loaded successfully

    Raises:
        MilvusException: If the async client is unavailable or loading fails
    """
    try:
        async_client = self.get_async_client()
        # get_async_client returns None when the client cannot be created;
        # report that explicitly, as async_search and async_query do, instead
        # of failing on a None attribute lookup.
        if async_client is None:
            raise MilvusException("Failed to create async client")
        await async_client.load_collection(collection_name=collection_name)
        logger.debug("Async load completed for collection: %s", collection_name)
        return True
    except Exception as e:
        logger.error("Async load failed for collection %s: %s", collection_name, str(e))
        raise MilvusException(f"Async load failed: {str(e)}") from e

async_query(params) async

Perform asynchronous query with sync fallback.

Parameters:

Name Type Description Default
params QueryParams

QueryParams object containing all query parameters

required

Returns:

Type Description
list

List of query results

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
async def async_query(self, params: QueryParams) -> list:
    """
    Run a query through the async client, retrying synchronously on failure.

    The target collection is loaded before the query is issued. When a
    MilvusException is raised along the way (including the case where no
    async client could be created), the query is re-run with
    ``self._sync_query`` via ``asyncio.to_thread``.

    Args:
        params: QueryParams object containing all query parameters

    Returns:
        List of query results
    """
    try:
        client = self.get_async_client()
        if client is None:
            raise MilvusException("Failed to create async client")

        # The collection must be loaded before it can be queried
        await client.load_collection(collection_name=params.collection_name)

        query_kwargs = {
            "collection_name": params.collection_name,
            "filter": params.expr,
            "output_fields": params.output_fields or [],
            "limit": params.limit,
        }
        rows = await client.query(**query_kwargs)
    except MilvusException as err:
        logger.warning(
            "Async query failed for collection %s: %s, falling back to sync",
            params.collection_name,
            str(err),
        )
        # Sync fallback, kept off the event loop thread
        return await asyncio.to_thread(self._sync_query, params)
    else:
        logger.debug("Async query completed for collection: %s", params.collection_name)
        return rows

async_search(params) async

Perform asynchronous vector search.

Parameters:

Name Type Description Default
params SearchParams

SearchParams object containing all search parameters

required

Returns:

Type Description
list

List of search results

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
async def async_search(self, params: SearchParams) -> list:
    """
    Run a vector search through the async client, retrying synchronously on failure.

    The target collection is loaded before the search is issued. When a
    MilvusException is raised along the way (including the case where no
    async client could be created), the search is re-run with
    ``self._sync_search`` via ``asyncio.to_thread``.

    Args:
        params: SearchParams object containing all search parameters

    Returns:
        List of search results
    """
    try:
        client = self.get_async_client()
        if client is None:
            raise MilvusException("Failed to create async client")

        # The collection must be loaded before it can be searched
        await client.load_collection(collection_name=params.collection_name)

        search_kwargs = {
            "collection_name": params.collection_name,
            "data": params.data,
            "anns_field": params.anns_field,
            "search_params": params.search_params,
            "limit": params.limit,
            "output_fields": params.output_fields or [],
        }
        hits = await client.search(**search_kwargs)
    except MilvusException as err:
        logger.warning(
            "Async search failed for collection %s: %s, falling back to sync",
            params.collection_name,
            str(err),
        )
        # Sync fallback, kept off the event loop thread
        return await asyncio.to_thread(self._sync_search, params)
    else:
        logger.debug("Async search completed for collection: %s", params.collection_name)
        return hits

clear_instances() classmethod

Clear all singleton instances. Useful for testing or cleanup.

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
520
521
522
523
524
525
526
527
528
529
530
@classmethod
def clear_instances(cls):
    """
    Disconnect and forget every cached singleton instance.

    Intended for tests and cleanup. Runs entirely under the class-level lock.
    """
    with cls._lock:
        registry = cls._instances
        # Tear down each manager's connections before dropping it
        for manager in registry.values():
            manager.disconnect()
        registry.clear()
        logger.info("Cleared all MilvusConnectionManager singleton instances")

disconnect()

Disconnect from Milvus (both sync and async clients).

Returns:

Name Type Description
bool bool

True if disconnected successfully, False otherwise

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
def disconnect(self) -> bool:
    """
    Disconnect from Milvus (both sync and async clients).

    Drops the sync connection registered under ``self.alias`` (if any),
    closes the async client (if one was created) and clears both cached
    client references.

    Returns:
        bool: True if disconnected successfully, False otherwise
    """
    try:
        success = True

        # Disconnect sync client
        if connections.has_connection(self.alias):
            connections.disconnect(self.alias)
            logger.info("Disconnected sync connection with alias: %s", self.alias)

        # Close async client if it exists
        async_client = self._async_client
        if async_client is not None:
            # Clear the cached reference up front so it is dropped however the
            # close attempt ends; the local variable keeps the client alive
            # for the close() call itself.
            self._async_client = None
            try:
                try:
                    loop = asyncio.get_running_loop()
                except RuntimeError:
                    loop = None

                if loop is not None:
                    # Inside a running loop we cannot block on close(), so it
                    # is scheduled instead. The event loop only keeps weak
                    # references to tasks, so anchor the task on the instance
                    # to keep it from being garbage-collected mid-flight.
                    self._async_close_task = loop.create_task(async_client.close())
                else:
                    # No running loop: drive close() to completion on a worker
                    # thread, waiting at most 5 seconds.
                    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
                    try:
                        executor.submit(lambda: asyncio.run(async_client.close())).result(
                            timeout=5
                        )
                    finally:
                        # wait=False: a context-manager exit would block until
                        # the worker finishes and defeat the timeout above.
                        executor.shutdown(wait=False)

                logger.info("Closed async client for database: %s", self.database_name)
            except (concurrent.futures.TimeoutError, TimeoutError, RuntimeError) as e:
                # concurrent.futures.TimeoutError is only an alias of the
                # builtin TimeoutError from Python 3.11 on; catch both.
                logger.warning("Error closing async client: %s", str(e))
                success = False

        # Clear sync client reference
        if self._sync_client is not None:
            self._sync_client = None
            logger.info("Cleared sync client reference")

        return success

    except (MilvusException, RuntimeError, ConnectionError, OSError) as e:
        logger.error("Error disconnecting from Milvus: %s", str(e))
        return False

ensure_connection()

Ensure Milvus connection exists, create if not.

This method checks if a connection with the specified alias exists, and creates one if it doesn't. It also switches to the correct database. Thread-safe implementation with connection locking.

Returns:

Name Type Description
bool bool

True if connection is established, False otherwise

Raises:

Type Description
MilvusException

If connection cannot be established

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
def ensure_connection(self) -> bool:
    """
    Ensure Milvus connection exists, create if not.

    This method checks if a connection with the specified alias exists,
    and creates one if it doesn't. It then switches that connection to
    ``self.database_name``. The whole sequence runs under
    ``self._connection_lock``.

    Returns:
        bool: True once the connection is established and the database is
        selected (failures raise instead of returning False)

    Raises:
        MilvusException: If connection cannot be established
    """
    with self._connection_lock:
        try:
            # Check if connection already exists
            if not connections.has_connection(self.alias):
                logger.info("Creating new Milvus connection with alias: %s", self.alias)
                connections.connect(
                    alias=self.alias,
                    host=self.host,
                    port=self.port,
                    user=self.user,
                    password=self.password,
                )
                logger.info(
                    "Successfully connected to Milvus at %s:%s",
                    self.host,
                    self.port,
                )
            else:
                logger.debug("Milvus connection already exists with alias: %s", self.alias)

            # Switch to the correct database on *this* manager's connection.
            # Without ``using`` pymilvus targets the "default" alias, which is
            # the wrong connection whenever self.alias differs from it.
            db.using_database(self.database_name, using=self.alias)
            logger.debug("Using Milvus database: %s", self.database_name)

            return True

        except MilvusException as e:
            logger.error("Failed to establish Milvus connection: %s", str(e))
            raise
        except Exception as e:
            logger.error("Unexpected error during Milvus connection: %s", str(e))
            raise MilvusException(f"Connection failed: {str(e)}") from e

from_config(cfg) classmethod

Create a MilvusConnectionManager from configuration.

Parameters:

Name Type Description Default
cfg dict[str, Any]

Configuration object or dictionary

required

Returns:

Name Type Description
MilvusConnectionManager MilvusConnectionManager

Configured connection manager instance

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
532
533
534
535
536
537
538
539
540
541
542
543
@classmethod
def from_config(cls, cfg: dict[str, Any]) -> "MilvusConnectionManager":
    """
    Alternate constructor: build a manager straight from a configuration.

    Args:
        cfg: Configuration object or dictionary

    Returns:
        MilvusConnectionManager: The result of calling the class with ``cfg``
    """
    manager = cls(cfg)
    return manager

from_hydra_config(config_path='../configs', config_name='config', overrides=None) classmethod

Create a MilvusConnectionManager from Hydra configuration.

This method loads the Milvus database configuration using Hydra, providing complete backend separation from frontend configs.

Parameters:

Name Type Description Default
config_path str

Path to the configs directory

'../configs'
config_name str

Name of the main config file

'config'
overrides list | None

List of config overrides

None

Returns:

Name Type Description
MilvusConnectionManager MilvusConnectionManager

Configured connection manager instance

Example

Load with default database config

conn_manager = MilvusConnectionManager.from_hydra_config()

Load with specific overrides

conn_manager = MilvusConnectionManager.from_hydra_config( overrides=["utils/database/milvus=default"] )

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
@classmethod
def from_hydra_config(
    cls,
    config_path: str = "../configs",
    config_name: str = "config",
    overrides: list | None = None,
) -> "MilvusConnectionManager":
    """
    Build a MilvusConnectionManager from a Hydra-composed configuration.

    Hydra is initialized against ``config_path``, ``config_name`` is composed
    with the given overrides, and the ``utils.database.milvus`` section of
    the result is handed to the class constructor.

    Args:
        config_path: Path to the configs directory
        config_name: Name of the main config file
        overrides: List of config overrides; when None,
            ``["utils/database/milvus=default"]`` is used

    Returns:
        MilvusConnectionManager: Configured connection manager instance

    Raises:
        MilvusException: If the configuration cannot be loaded

    Example:
        # Load with default database config
        conn_manager = MilvusConnectionManager.from_hydra_config()

        # Load with specific overrides
        conn_manager = MilvusConnectionManager.from_hydra_config(
            overrides=["utils/database/milvus=default"]
        )
    """
    # Fall back to the default Milvus config group when nothing was requested
    overrides = ["utils/database/milvus=default"] if overrides is None else overrides

    try:
        with hydra.initialize(version_base=None, config_path=config_path):
            composed = hydra.compose(config_name=config_name, overrides=overrides)
            # Only the utils.database.milvus section is relevant here
            milvus_cfg = composed.utils.database.milvus
            logger.info("Loaded Milvus config from Hydra with overrides: %s", overrides)
            return cls(milvus_cfg)
    except Exception as err:
        logger.error("Failed to load Hydra configuration: %s", str(err))
        raise MilvusException(f"Configuration loading failed: {str(err)}") from err

get_async_client()

Get or create an asynchronous AsyncMilvusClient.

Returns:

Name Type Description
AsyncMilvusClient AsyncMilvusClient

Configured asynchronous client, or None if the client could not be created

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
def get_async_client(self) -> AsyncMilvusClient | None:
    """
    Get or create an asynchronous AsyncMilvusClient.

    The client is created on first use and cached on the instance; later
    calls return the cached client.

    Returns:
        AsyncMilvusClient | None: The cached or newly created client, or
        None when creation fails with MilvusException, RuntimeError,
        ConnectionError or OSError. Callers must check for None.
    """
    if self._async_client is not None:
        return self._async_client

    try:
        self._async_client = AsyncMilvusClient(
            uri=f"http://{self.host}:{self.port}",
            token=f"{self.user}:{self.password}",
            db_name=self.database_name,
        )
    except (MilvusException, RuntimeError, ConnectionError, OSError) as e:
        logger.error("Failed to create async client: %s", str(e))
        # Don't raise here, let the calling method handle the fallback
        return None

    logger.info(
        "Created asynchronous AsyncMilvusClient for database: %s",
        self.database_name,
    )
    return self._async_client

get_collection(collection_name)

Get a Milvus collection, ensuring connection is established. Thread-safe implementation.

Parameters:

Name Type Description Default
collection_name str

Name of the collection to retrieve

required

Returns:

Name Type Description
Collection Collection

The requested Milvus collection

Raises:

Type Description
MilvusException

If collection cannot be retrieved

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
def get_collection(self, collection_name: str) -> Collection:
    """
    Get a loaded Milvus collection, ensuring connection is established.

    The connection is ensured via ``ensure_connection()``, then the
    collection is opened on this manager's connection alias and loaded.

    Args:
        collection_name: Name of the collection to retrieve

    Returns:
        Collection: The requested Milvus collection, after ``load()``

    Raises:
        MilvusException: If collection cannot be retrieved
    """
    try:
        self.ensure_connection()
        # Bind to this manager's alias; without ``using`` pymilvus would use
        # the "default" connection, which is wrong when self.alias differs.
        collection = Collection(name=collection_name, using=self.alias)
        collection.load()  # Load collection data
        logger.debug("Successfully loaded collection: %s", collection_name)
        return collection

    except Exception as e:
        logger.error("Failed to get collection %s: %s", collection_name, str(e))
        raise MilvusException(f"Failed to get collection {collection_name}: {str(e)}") from e

get_connection_info()

Get current connection information.

Returns:

Type Description
dict[str, Any]

Dict containing connection details

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def get_connection_info(self) -> dict[str, Any]:
    """
    Describe the connection registered under this manager's alias.

    Returns:
        Dict with ``alias``, ``host``, ``port``, ``database``, ``connected``
        and ``connection_address`` (None when not connected). If the lookup
        itself fails, a reduced dict with ``alias``, ``connected=False`` and
        ``error`` is returned instead.
    """
    try:
        is_connected = bool(connections.has_connection(self.alias))
        # The address is only looked up for a live connection
        address = connections.get_connection_addr(self.alias) if is_connected else None
        return {
            "alias": self.alias,
            "host": self.host,
            "port": self.port,
            "database": self.database_name,
            "connected": is_connected,
            "connection_address": address,
        }
    except (MilvusException, RuntimeError, ConnectionError, OSError) as err:
        logger.error("Error getting connection info: %s", str(err))
        return {"alias": self.alias, "connected": False, "error": str(err)}

get_instance(cfg) classmethod

Get singleton instance for the given configuration.

Parameters:

Name Type Description Default
cfg dict[str, Any]

Configuration dictionary containing Milvus DB settings

required

Returns:

Name Type Description
MilvusConnectionManager MilvusConnectionManager

Singleton instance for the given config

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
507
508
509
510
511
512
513
514
515
516
517
518
@classmethod
def get_instance(cls, cfg: dict[str, Any]) -> "MilvusConnectionManager":
    """
    Return the manager instance for the given configuration.

    This simply calls the class with ``cfg``.

    Args:
        cfg: Configuration dictionary containing Milvus DB settings

    Returns:
        MilvusConnectionManager: Instance for the given config
    """
    instance = cls(cfg)
    return instance

get_sync_client()

Get or create a synchronous MilvusClient.

Returns:

Name Type Description
MilvusClient MilvusClient

Configured synchronous client

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def get_sync_client(self) -> MilvusClient:
    """
    Return the cached synchronous MilvusClient, creating it on first use.

    Returns:
        MilvusClient: Configured synchronous client
    """
    if self._sync_client is not None:
        return self._sync_client

    endpoint = f"http://{self.host}:{self.port}"
    credentials = f"{self.user}:{self.password}"
    self._sync_client = MilvusClient(
        uri=endpoint,
        token=credentials,
        db_name=self.database_name,
    )
    logger.info("Created synchronous MilvusClient for database: %s", self.database_name)
    return self._sync_client

test_connection()

Test the connection by ensuring it is established and then instantiating the `<database_name>_nodes` collection.

Returns:

Name Type Description
bool bool

True if connection is healthy, False otherwise

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
def test_connection(self) -> bool:
    """
    Test the connection by opening the ``<database_name>_nodes`` collection.

    The connection is first ensured via ``ensure_connection()``; the probe
    collection is then instantiated on this manager's connection alias.

    Returns:
        bool: True if connection is healthy, False otherwise
    """
    try:
        self.ensure_connection()

        # Try to get a collection to test the connection. Bind it to this
        # manager's alias; without ``using`` pymilvus would probe the
        # "default" connection instead.
        test_collection_name = f"{self.database_name}_nodes"
        Collection(name=test_collection_name, using=self.alias)

        logger.debug("Connection test successful")
        return True

    except (MilvusException, RuntimeError, ConnectionError, OSError) as e:
        logger.error("Connection test failed: %s", str(e))
        return False

QueryParams dataclass

Parameters for query operations.

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
37
38
39
40
41
42
43
44
@dataclass
class QueryParams:
    """Parameters for query operations.

    Bundles the arguments that ``async_query`` forwards to the Milvus
    client's ``query`` call.
    """

    # Name of the collection to load and query
    collection_name: str
    # Filter expression; passed to the client as ``filter``
    expr: str
    # Fields to return; None is treated as an empty list by async_query
    output_fields: list | None = None
    # Forwarded unchanged as the client's ``limit`` argument
    limit: int | None = None

SearchParams dataclass

Parameters for search operations.

Source code in aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py
25
26
27
28
29
30
31
32
33
34
@dataclass
class SearchParams:
    """Parameters for search operations.

    Bundles the arguments that ``async_search`` forwards to the Milvus
    client's ``search`` call.
    """

    # Name of the collection to load and search
    collection_name: str
    # Forwarded as the client's ``data`` argument
    data: list
    # Forwarded as the client's ``anns_field`` argument
    anns_field: str
    # Forwarded as the client's ``search_params`` argument
    search_params: dict
    # Forwarded as the client's ``limit`` argument
    limit: int
    # Fields to return; None is treated as an empty list by async_search
    output_fields: list | None = None