-
Notifications
You must be signed in to change notification settings - Fork 0
/
similarity_search.py
75 lines (56 loc) · 1.91 KB
/
similarity_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from pymilvus import connections, Collection
import json
import numpy as np
import gensim.downloader as api
import random
# Name of the pretrained word2vec model fetched through the gensim downloader.
WORD2VEC_MODEL_NAME = "word2vec-google-news-300"

# Address of the Milvus server this script talks to.
MILVUS_ENDPOINT = {"host": 'localhost', "port": '19530'}

# Word-vector lookup used by calculate_average_vector().
model = api.load(WORD2VEC_MODEL_NAME)

# Connect to the Milvus server before any collection is opened.
connections.connect(**MILVUS_ENDPOINT)
def calculate_average_vector(text):
    """Return the mean word vector of the whitespace-separated tokens in *text*.

    Tokens that are not in the module-level ``model`` vocabulary are skipped.

    Args:
        text: String to embed. ``None`` or an empty string is treated as
            containing no tokens.

    Returns:
        A 1-D NumPy array of length ``model.vector_size``: the element-wise
        mean of the known tokens' vectors, or all zeros when no token is
        known.
    """
    # A JSON null reaches this function as None; treat it like empty text
    # instead of failing on None.split().
    tokens = text.split() if text else []
    vectors = [model[token] for token in tokens if token in model]
    if not vectors:
        # gensim stores word2vec vectors as float32 by default; using the
        # same dtype here keeps a concatenated embedding from being promoted
        # to float64 whenever one field has no known tokens.
        return np.zeros(model.vector_size, dtype=np.float32)
    return np.mean(vectors, axis=0)
def process_and_store_embeddings(data):
    """Build one embedding for a record from its title and description.

    Each of the two fields is averaged into a vector by
    ``calculate_average_vector`` (a missing key is embedded as the empty
    string), and the two vectors are joined end to end, title first.
    Despite the name, this function only returns the vector; it does not
    store anything.

    Args:
        data: Mapping that may contain ``'title'`` and ``'description'``.

    Returns:
        The concatenation of the title vector and the description vector.
    """
    field_vectors = [
        calculate_average_vector(data.get(field, ''))
        for field in ('title', 'description')
    ]
    return np.concatenate(field_vectors)
# Load the sample record from a JSON file
json_file_path = "sample.json"
with open(json_file_path, 'r', encoding='utf-8') as f:
    sample_data = json.load(f)

# Build the query embedding for the sample record
sample_embedding = process_and_store_embeddings(sample_data)

# Print the sample embedding as comma-separated values
print(", ".join(str(x) for x in sample_embedding))

# Open the existing collection and load it so it can be searched
collection_name = "tender_collection"
collection = Collection(collection_name)
collection.load()
print("Collection loaded successfully")

# Search the collection for the entries closest to the sample embedding.
# The cleanup runs in `finally` so the collection is released and the
# connection closed even when the search raises.
try:
    results = collection.search(
        # search() takes a batch of query vectors, so the single embedding
        # is wrapped in a list (and converted to plain Python floats).
        data=[sample_embedding.tolist()],
        anns_field="embedding",
        limit=10,
        expr=None,
        param={
            "metric_type": "L2",
            "params": {"nprobe": 16},
        },
        output_fields=['idNumber'],
    )
    print("Search results:", results)
finally:
    collection.release()
    # disconnect() requires the connection alias; connect() above did not
    # pass one, so the connection was registered under "default".
    connections.disconnect("default")