Skip to content

Commit

Permalink
addressed comments
Browse files Browse the repository at this point in the history
  • Loading branch information
eavanvalkenburg committed Nov 15, 2024
1 parent 6adc585 commit a1a4032
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from collections.abc import Sequence
from typing import Any, ClassVar, Generic, TypeVar

from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
else:
Expand All @@ -33,6 +31,7 @@
VectorSearchOptions,
)
from semantic_kernel.data.vector_search.vector_search import VectorSearchBase
from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult
from semantic_kernel.data.vector_search.vector_text_search import VectorTextSearchMixin
from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin
from semantic_kernel.exceptions import MemoryConnectorException, MemoryConnectorInitializationError
Expand Down
3 changes: 2 additions & 1 deletion python/semantic_kernel/connectors/memory/weaviate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ def extract_vectors_from_weaviate_object_based_on_data_model_definition(
# region VectorSearch helpers


def _create_filter_from_vector_search_filters(filters: VectorSearchFilter | None) -> "_Filters | None":
def create_filter_from_vector_search_filters(filters: VectorSearchFilter | None) -> "_Filters | None":
"""Create a Weaviate filter from a vector search filter."""
if not filters:
return None
weaviate_filters: list["_Filters"] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from weaviate.classes.query import Filter, MetadataQuery
from weaviate.collections.classes.data import DataObject
from weaviate.collections.collection import CollectionAsync
from weaviate.exceptions import WeaviateClosedClientError

from semantic_kernel.connectors.memory.weaviate.utils import (
_create_filter_from_vector_search_filters,
create_filter_from_vector_search_filters,
data_model_definition_to_weaviate_named_vectors,
data_model_definition_to_weaviate_properties,
extract_key_from_dict_record_based_on_data_model_definition,
Expand All @@ -43,6 +44,7 @@
MemoryConnectorException,
MemoryConnectorInitializationError,
)
from semantic_kernel.exceptions.memory_connector_exceptions import VectorStoreModelValidationError
from semantic_kernel.kernel_types import OneOrMany
from semantic_kernel.utils.experimental_decorator import experimental_class

Expand Down Expand Up @@ -172,6 +174,10 @@ async def _inner_upsert(
try:
collection: CollectionAsync = self.async_client.collections.get(self.collection_name)
response = await collection.data.insert_many(records)
except WeaviateClosedClientError as ex:
raise MemoryConnectorException(
"Client is closed, please use the context manager or self.async_client.connect."
) from ex
except Exception as ex:
raise MemoryConnectorException(f"Failed to upsert records: {ex}")

Expand All @@ -187,6 +193,10 @@ async def _inner_get(self, keys: Sequence[TKey], **kwargs: Any) -> OneOrMany[Any
)

return result.objects
except WeaviateClosedClientError as ex:
raise MemoryConnectorException(
"Client is closed, please use the context manager or self.async_client.connect."
) from ex
except Exception as ex:
raise MemoryConnectorException(f"Failed to get records: {ex}")

Expand All @@ -195,6 +205,10 @@ async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None:
try:
collection: CollectionAsync = self.async_client.collections.get(self.collection_name)
await collection.data.delete_many(where=Filter.any_of([Filter.by_id().equal(key) for key in keys]))
except WeaviateClosedClientError as ex:
raise MemoryConnectorException(
"Client is closed, please use the context manager or self.async_client.connect."
) from ex
except Exception as ex:
raise MemoryConnectorException(f"Failed to delete records: {ex}")

Expand All @@ -210,7 +224,7 @@ async def _inner_search(
vector_field = self.data_model_definition.try_get_vector_field(options.vector_field_name)
collection: CollectionAsync = self.async_client.collections.get(self.collection_name)
args = {
"filters": _create_filter_from_vector_search_filters(options.filter),
"filters": create_filter_from_vector_search_filters(options.filter),
"include_vector": options.include_vectors,
"limit": options.top,
"offset": options.skip,
Expand Down Expand Up @@ -287,6 +301,10 @@ async def _inner_vectorized_search(
return_metadata=MetadataQuery(distance=True),
**args,
)
except WeaviateClosedClientError as ex:
raise MemoryConnectorException(
"Client is closed, please use the context manager or self.async_client.connect."
) from ex
except Exception as ex:
raise MemoryConnectorException(f"Failed searching using a vector: {ex}") from ex

Expand Down Expand Up @@ -349,6 +367,10 @@ async def create_collection(self, **kwargs) -> None:
if kwargs:
try:
await self.async_client.collections.create(**kwargs)
except WeaviateClosedClientError as ex:
raise MemoryConnectorException(
"Client is closed, please use the context manager or self.async_client.connect."
) from ex
except Exception as ex:
raise MemoryConnectorException(f"Failed to create collection: {ex}") from ex
try:
Expand All @@ -364,6 +386,10 @@ async def create_collection(self, **kwargs) -> None:
if self.named_vectors
else None,
)
except WeaviateClosedClientError as ex:
raise MemoryConnectorException(
"Client is closed, please use the context manager or self.async_client.connect."
) from ex
except Exception as ex:
raise MemoryConnectorException(f"Failed to create collection: {ex}") from ex

Expand All @@ -379,6 +405,10 @@ async def does_collection_exist(self, **kwargs) -> bool:
"""
try:
return await self.async_client.collections.exists(self.collection_name)
except WeaviateClosedClientError as ex:
raise MemoryConnectorException(
"Client is closed, please use the context manager or self.async_client.connect."
) from ex
except Exception as ex:
raise MemoryConnectorException(f"Failed to check if collection exists: {ex}") from ex

Expand All @@ -391,18 +421,28 @@ async def delete_collection(self, **kwargs) -> None:
"""
try:
await self.async_client.collections.delete(self.collection_name)
except WeaviateClosedClientError as ex:
raise MemoryConnectorException(
"Client is closed, please use the context manager or self.async_client.connect."
) from ex
except Exception as ex:
raise MemoryConnectorException(f"Failed to delete collection: {ex}") from ex

@override
async def __aenter__(self) -> "WeaviateCollection":
"""Enter the context manager."""
if not await self.async_client.is_ready():
await self.async_client.connect()
await self.async_client.connect()
return self

@override
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
"""Exit the context manager."""
if self.managed_client:
await self.async_client.close()

def _validate_data_model(self):
super()._validate_data_model()
if self.named_vectors and len(self.data_model_definition.vector_field_names) > 1:
raise VectorStoreModelValidationError(
"Named vectors must be enabled if there are more then 1 vector fields in the data model definition."
)

0 comments on commit a1a4032

Please sign in to comment.