-
Notifications
You must be signed in to change notification settings - Fork 41
Open
Description
I implemented a HybridSearchAsync feature for this repository.
Create a new file named "Milvus.Collection.Entity.HybridSearch.cs" with the following content:
using System.Buffers;
using System.Buffers.Binary;
using System.Globalization;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text.Json;
using Google.Protobuf.Collections;
using KeyValuePair = Milvus.Client.Grpc.KeyValuePair;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Milvus.Client.Grpc;
using Google.Protobuf;
using System.Net.NetworkInformation;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf;
using Milvus.Client.Grpc;
using Milvus.Client;
using static System.Net.Mime.MediaTypeNames;
namespace Milvus.Client;
public partial class MilvusCollection
{
    // Tag for the single placeholder carried by every hybrid sub-request. Each sub-request has its
    // own PlaceholderGroup with exactly one value, so all of them use the same tag (previously the
    // sparse path used "$1" and the helper paths used no tag, which the server cannot match).
    private const string HybridSearchPlaceholderTag = "$0";

    // Reciprocal-rank-fusion parameters and the fused/sub-search result limit (previously "60" inline).
    private const string HybridSearchRrfParams = "{\"k\":60}";
    private const string HybridSearchLimit = "60";

    /// <summary>
    /// Runs a hybrid search over this collection: any number of dense-vector ANN sub-searches and
    /// BM25 full-text sub-searches are sent in one <c>HybridSearchRequest</c> and fused server-side
    /// with reciprocal rank fusion (RRF).
    /// </summary>
    /// <param name="databaseName">Name of the database containing this collection.</param>
    /// <param name="denseReqs">
    /// Dense sub-searches as (vector field name, query vector, optional boolean filter expression).
    /// Entries with a blank field name or a null/empty vector are skipped.
    /// </param>
    /// <param name="sparseReqs">
    /// BM25 full-text sub-searches as (sparse field name, raw query text, optional filter expression).
    /// Entries with a blank field name or blank query text are skipped.
    /// </param>
    /// <param name="cancellationToken">Token used to cancel the gRPC call.</param>
    /// <returns>The fused results; all output fields (<c>*</c>) are requested.</returns>
    /// <exception cref="ArgumentException">No usable dense or sparse sub-search remains after filtering.</exception>
    public async Task<SearchResults> HybridSearchAsync(
        string databaseName,
        IReadOnlyList<(string annsField, float[] queryVector, string filterExpr)> denseReqs,
        IReadOnlyList<(string annsField, string queryText, string filterExpr)> sparseReqs,
        CancellationToken cancellationToken = default)
    {
        List<SearchRequest> requests = CreateRequests(denseReqs, sparseReqs);
        if (requests.Count == 0)
        {
            // Fail locally instead of sending a hybrid request with no sub-searches to the server.
            throw new ArgumentException(
                "At least one dense or sparse sub-search with a field name and a query is required.");
        }

        var hybridRequest = new Grpc.HybridSearchRequest
        {
            Base = new MsgBase { MsgType = MsgType.Search },
            CollectionName = Name,
            DbName = databaseName,
            OutputFields = { "*" },
            ConsistencyLevel = Grpc.ConsistencyLevel.Eventually,
            UseDefaultConsistency = false,
            // The fusion strategy is read from the "strategy"/"params" keys (as sent by pymilvus'
            // RRFRanker); the former "ranker"/"k" keys were not recognized.
            RankParams =
            {
                new KeyValuePair { Key = "strategy", Value = "rrf" },
                new KeyValuePair { Key = "params", Value = HybridSearchRrfParams },
                new KeyValuePair { Key = "limit", Value = HybridSearchLimit },
            },
        };
        hybridRequest.Requests.AddRange(requests);

        Grpc.SearchResults response = await _client.InvokeAsync(
                _client.GrpcClient.HybridSearchAsync,
                hybridRequest,
                static result => result.Status,
                cancellationToken)
            .ConfigureAwait(false);

        List<FieldData> fieldData = ProcessReturnedFieldData(response.Results.FieldsData);
        return new SearchResults
        {
            CollectionName = response.CollectionName,
            FieldsData = fieldData,
            Ids = response.Results.Ids is null ? default : MilvusIds.FromGrpc(response.Results.Ids),
            NumQueries = response.Results.NumQueries,
            Scores = response.Results.Scores,
            Limit = response.Results.TopK,
            Limits = response.Results.Topks,
        };
    }

    /// <summary>
    /// Builds one gRPC sub-request per usable entry: all dense sub-searches first, then all sparse
    /// ones, each group in input order. Unusable entries (blank field, empty vector/text) are skipped.
    /// </summary>
    /// <exception cref="ArgumentNullException">Either list is <see langword="null"/> (via <c>Verify.NotNull</c>).</exception>
    private static List<SearchRequest> CreateRequests(
        IReadOnlyList<(string annsField, float[] queryVector, string filterExpr)> denseReqs,
        IReadOnlyList<(string annsField, string queryText, string filterExpr)> sparseReqs)
    {
        Verify.NotNull(denseReqs);
        Verify.NotNull(sparseReqs);

        var requests = new List<SearchRequest>(denseReqs.Count + sparseReqs.Count);

        foreach (var (annsField, queryVector, filterExpr) in denseReqs)
        {
            if (!string.IsNullOrWhiteSpace(annsField) && queryVector is { Length: > 0 })
            {
                requests.Add(CreateDenseRequest(annsField, queryVector, filterExpr));
            }
        }

        foreach (var (annsField, queryText, filterExpr) in sparseReqs)
        {
            if (!string.IsNullOrWhiteSpace(annsField) && !string.IsNullOrWhiteSpace(queryText))
            {
                requests.Add(CreateSparseRequest(annsField, queryText, filterExpr));
            }
        }

        return requests;
    }

    /// <summary>
    /// Builds a dense-vector ANN sub-request (inner-product metric, <c>nprobe = 32</c>) whose
    /// placeholder carries the raw float32 bytes of <paramref name="queryVector"/>.
    /// </summary>
    private static SearchRequest CreateDenseRequest(
        string annsField,
        float[] queryVector,
        string filterExpr,
        int nq = 1)
    {
        // Buffer.BlockCopy copies the floats in the machine's native byte order.
        byte[] vectorBytes = new byte[queryVector.Length * sizeof(float)];
        Buffer.BlockCopy(queryVector, srcOffset: 0, vectorBytes, dstOffset: 0, vectorBytes.Length);

        var placeholder = new PlaceholderValue
        {
            Tag = HybridSearchPlaceholderTag,
            Type = PlaceholderType.FloatVector,
            Values = { ByteString.CopyFrom(vectorBytes) },
        };

        return CreateSubRequest(annsField, filterExpr, nq, placeholder, metricType: "IP", indexParams: "{\"nprobe\":32}");
    }

    /// <summary>
    /// Builds a BM25 full-text sub-request. The raw query text is sent as a UTF-8 VarChar
    /// placeholder; converting it into a sparse vector is left to the server.
    /// </summary>
    private static SearchRequest CreateSparseRequest(
        string annsField,
        string queryText,
        string filterExpr,
        int nq = 1)
    {
        var placeholder = new PlaceholderValue
        {
            Tag = HybridSearchPlaceholderTag,
            Type = PlaceholderType.VarChar,
            Values = { ByteString.CopyFromUtf8(queryText) },
        };

        // "drop_ratio_search" is the search-time key; "drop_ratio_build" only applies when building the index.
        return CreateSubRequest(annsField, filterExpr, nq, placeholder, metricType: "BM25", indexParams: "{\"drop_ratio_search\":0.2}");
    }

    /// <summary>
    /// Shared builder for one hybrid sub-request: wraps <paramref name="placeholder"/> in its own
    /// placeholder group, applies the optional boolean filter, and sets the target field, metric,
    /// search parameters and per-sub-search <c>topk</c>.
    /// </summary>
    private static SearchRequest CreateSubRequest(
        string annsField,
        string filterExpr,
        int nq,
        PlaceholderValue placeholder,
        string metricType,
        string indexParams)
    {
        var placeholders = new PlaceholderGroup { Placeholders = { placeholder } };

        return new SearchRequest
        {
            // An empty DSL means "no filter".
            Dsl = string.IsNullOrWhiteSpace(filterExpr) ? "" : filterExpr,
            DslType = DslType.BoolExprV1,
            PlaceholderGroup = MessageExtensions.ToByteString(placeholders),
            Nq = nq,
            SearchParams =
            {
                new KeyValuePair { Key = "anns_field", Value = annsField },
                new KeyValuePair { Key = "metric_type", Value = metricType },
                new KeyValuePair { Key = "params", Value = indexParams },
                new KeyValuePair { Key = "topk", Value = HybridSearchLimit },
            },
        };
    }
}
/// <summary>
/// Flattens a raw gRPC search response into one dictionary per hit. Hits whose rank value is
/// below <paramref name="threshold"/> are dropped, and at most <paramref name="topK"/> hits are
/// returned, highest rank value first.
/// </summary>
/// <remarks>
/// The rank value is the hit's <c>"score"</c> entry, falling back to <c>"distance"</c> when no
/// score was returned. Hits with neither value are kept and sort as 0.
/// NOTE(review): this method is declared outside any type in the snippet; it must be moved into
/// a (static) class to compile.
/// </remarks>
/// <exception cref="InvalidOperationException">The response is missing or its status code is non-zero.</exception>
private static List<Dictionary<string, object>> ParseResults(
    Grpc.SearchResults? response,
    int topK,
    float threshold)
{
    if (response?.Status?.Code != 0)
    {
        throw new InvalidOperationException($"搜索失败: {response?.Status?.Reason ?? "未知错误"}");
    }

    List<Dictionary<string, object>> allEntities = response.Results is { } results
        ? ParseSearchResultData(results)
        : new List<Dictionary<string, object>>();

    // `null < threshold` is false, so hits without a rank value are never filtered out.
    return allEntities
        .Where(e => !(RankValue(e) < threshold))
        .OrderByDescending(e => RankValue(e) ?? 0f)
        .Take(topK)
        .ToList();

    // Milvus fills Scores for search/hybrid-search results; Distances is only a fallback.
    static float? RankValue(Dictionary<string, object> entity) =>
        entity.TryGetValue("score", out object? s) && s is float score ? score
        : entity.TryGetValue("distance", out object? d) && d is float distance ? (float?)distance
        : null;
}
/// <summary>
/// Expands a <see cref="SearchResultData"/> into one dictionary per hit. Each dictionary holds
/// the hit's <c>"score"</c>, <c>"distance"</c> and <c>"id"</c> when present, plus one entry per
/// returned scalar output field (long, int, float, double, string or bool).
/// </summary>
/// <remarks>
/// Hits are laid out query after query. <c>Topks</c> gives the real number of hits per query,
/// which can be smaller than <c>TopK</c>. Uniform <c>TopK</c> strides are used only when
/// <c>Topks</c> is absent. Every column access is bounds-checked, so short columns yield missing
/// entries rather than exceptions.
/// NOTE(review): this method is declared outside any type in the snippet; it must be moved into
/// a (static) class to compile.
/// </remarks>
private static List<Dictionary<string, object>> ParseSearchResultData(SearchResultData resultData)
{
    var entities = new List<Dictionary<string, object>>();
    int numQueries = (int)resultData.NumQueries;
    int topK = (int)resultData.TopK;
    bool hasPerQueryCounts = resultData.Topks.Count == numQueries;

    int offset = 0;
    for (int queryIndex = 0; queryIndex < numQueries; queryIndex++)
    {
        int hitCount = hasPerQueryCounts ? (int)resultData.Topks[queryIndex] : topK;

        for (int hit = 0; hit < hitCount; hit++)
        {
            int globalIndex = offset + hit;
            var entity = new Dictionary<string, object>();

            if (resultData.Scores.Count > globalIndex)
            {
                entity["score"] = resultData.Scores[globalIndex];
            }
            if (resultData.Distances.Count > globalIndex)
            {
                entity["distance"] = resultData.Distances[globalIndex];
            }

            if (resultData.Ids?.IntId?.Data?.Count > globalIndex)
            {
                entity["id"] = resultData.Ids.IntId.Data[globalIndex];
            }
            else if (resultData.Ids?.StrId?.Data?.Count > globalIndex)
            {
                entity["id"] = resultData.Ids.StrId.Data[globalIndex];
            }

            foreach (var fieldData in resultData.FieldsData)
            {
                // Only the populated oneof member is non-null; vector fields have no Scalars and yield null.
                object? value = fieldData.Scalars switch
                {
                    { LongData.Data: var longs } when globalIndex < longs.Count => longs[globalIndex],
                    { IntData.Data: var ints } when globalIndex < ints.Count => ints[globalIndex],
                    { FloatData.Data: var floats } when globalIndex < floats.Count => floats[globalIndex],
                    { DoubleData.Data: var doubles } when globalIndex < doubles.Count => doubles[globalIndex],
                    { StringData.Data: var strings } when globalIndex < strings.Count => strings[globalIndex],
                    { BoolData.Data: var bools } when globalIndex < bools.Count => bools[globalIndex],
                    _ => null,
                };

                if (value is not null)
                {
                    entity[fieldData.FieldName] = value;
                }
            }

            entities.Add(entity);
        }

        offset += hitCount;
    }

    return entities;
}
How to use:
// One dense (vector) sub-search plus two BM25 full-text sub-searches, fused server-side with RRF.
var denseReqs = new List<(string annsField, float[] queryVector, string filterExpr)>
{
    ("describ_dense", searchVectors[0].ToArray(), ""),
};
var sparseReqs = new List<(string annsField, string queryText, string filterExpr)>
{
    ("name_property_bm25_sparse", "test_text_", ""),
    ("describ_bm25_sparse", "test_text_1", ""),
};
var response = await collection!.HybridSearchAsync(
    databaseName: "test",
    denseReqs,
    sparseReqs);
// ParseResults takes the raw Grpc.SearchResults message, but HybridSearchAsync returns the
// client-level SearchResults, so the two cannot be chained. Read the fused hits from the
// returned object instead (Ids, Scores, FieldsData, Limits).
var scores = response.Scores;
jlong23
Metadata
Metadata
Assignees
Labels
No labels