from typing import List, Optional, Dict, Tuple, Union

from chromadb import (
    EphemeralClient,
    PersistentClient,
    DEFAULT_TENANT,
    DEFAULT_DATABASE,
    CloudClient,
    AsyncHttpClient,
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.types import IncludeEnum, QueryResult

from autorag.utils.util import apply_recursive
from autorag.vectordb.base import BaseVectorStore


class Chroma(BaseVectorStore):
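    """Chroma vector store for AutoRAG.

    Wraps a Chroma collection built from one of four client backends
    (ephemeral, persistent, http, cloud) and stores/queries text embeddings
    produced by the configured embedding model.
    """
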
    def __init__(
        self,
        embedding_model: Union[str, List[dict]],
        collection_name: str,
        embedding_batch: int = 100,
        client_type: str = "persistent",
        similarity_metric: str = "cosine",
        path: Optional[str] = None,
        host: str = "localhost",
        port: int = 8000,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        api_key: Optional[str] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ):
        super().__init__(embedding_model, similarity_metric, embedding_batch)
        # Build the underlying Chroma client according to client_type.
        if client_type == "ephemeral":
            self.client = EphemeralClient(tenant=tenant, database=database)
        elif client_type == "persistent":
            assert path is not None, "path must be provided for persistent client"
            self.client = PersistentClient(path=path, tenant=tenant, database=database)
        elif client_type == "http":
            self.client = AsyncHttpClient(
                host=host,
                port=port,
                ssl=ssl,
                headers=headers,
                tenant=tenant,
                database=database,
            )
        elif client_type == "cloud":
            self.client = CloudClient(
                tenant=tenant,
                database=database,
                api_key=api_key,
            )
        else:
            raise ValueError(
                f"client_type {client_type} is not supported\n"
                "supported client types are: ephemeral, persistent, http, cloud"
            )

        # "hnsw:space" tells Chroma which distance function the HNSW index uses.
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": similarity_metric},
        )

    async def add(self, ids: List[str], texts: List[str]):
        """Embed the given texts and add them to the collection under the given ids."""
        texts = self.truncated_inputs(texts)
        text_embeddings = await self.embedding.aget_text_embedding_batch(texts)
        if isinstance(self.collection, AsyncCollection):
            await self.collection.add(ids=ids, embeddings=text_embeddings)
        else:
            self.collection.add(ids=ids, embeddings=text_embeddings)

    async def fetch(self, ids: List[str]) -> List[List[float]]:
        """Return the stored embedding vectors for the given ids."""
        if isinstance(self.collection, AsyncCollection):
            fetch_result = await self.collection.get(
                ids, include=[IncludeEnum.embeddings]
            )
        else:
            fetch_result = self.collection.get(ids, include=[IncludeEnum.embeddings])
        fetch_embeddings = fetch_result["embeddings"]
        return fetch_embeddings

    async def is_exist(self, ids: List[str]) -> List[bool]:
        """Return one boolean per id, indicating whether it already exists in the collection."""
        if isinstance(self.collection, AsyncCollection):
            fetched_result = await self.collection.get(ids, include=[])
        else:
            fetched_result = self.collection.get(ids, include=[])
        existed_ids = fetched_result["ids"]
        return list(map(lambda x: x in existed_ids, ids))

    async def query(
        self, queries: List[str], top_k: int, **kwargs
    ) -> Tuple[List[List[str]], List[List[float]]]:
        """Embed the queries and return the top_k ids and similarity scores per query."""
        queries = self.truncated_inputs(queries)
        query_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(queries)
        if isinstance(self.collection, AsyncCollection):
            query_result: QueryResult = await self.collection.query(
                query_embeddings=query_embeddings, n_results=top_k
            )
        else:
            query_result: QueryResult = self.collection.query(
                query_embeddings=query_embeddings, n_results=top_k
            )
        ids = query_result["ids"]
        scores = query_result["distances"]
        # Chroma returns distances; convert to similarity-style scores (1 - distance).
        scores = apply_recursive(lambda x: 1 - x, scores)
        return ids, scores

    async def delete(self, ids: List[str]):
        if isinstance(self.collection, AsyncCollection):
            await self.collection.delete(ids)
        else:
            self.collection.delete(ids)
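

# Usage sketch (not authoritative): a minimal example of driving this store end to end
# with a persistent client. The embedding_model key "openai" and the local path are
# assumptions for illustration; substitute whatever your AutoRAG config registers.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        store = Chroma(
            embedding_model="openai",  # assumed embedding model key
            collection_name="demo_collection",
            client_type="persistent",
            path="./chroma_demo",  # hypothetical local storage path
        )
        await store.add(ids=["doc-1", "doc-2"], texts=["hello chroma", "goodbye chroma"])
        print(await store.is_exist(["doc-1", "missing-id"]))  # expected: [True, False]
        ids, scores = await store.query(["hello"], top_k=1)
        print(ids, scores)  # scores are 1 - distance, so higher means more similar

    asyncio.run(_demo())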