Skip to content

I wrote the HybridSearchAsync feature for this REPO #103

@sky92archangel

Description

@sky92archangel

I wrote the HybridSearchAsync feature for this REPO

Create a new file "Milvus.Collection.Entity.HybridSearch.cs" with the following content:

using System.Buffers;
using System.Buffers.Binary;
using System.Globalization;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text.Json;
using Google.Protobuf.Collections;
using KeyValuePair = Milvus.Client.Grpc.KeyValuePair;

using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Milvus.Client.Grpc;
using Google.Protobuf;
using System.Net.NetworkInformation;

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Google.Protobuf;
using Milvus.Client.Grpc;
using Milvus.Client;
using static System.Net.Mime.MediaTypeNames;

namespace Milvus.Client;
  
public partial class MilvusCollection
{
    /// <summary>
    /// Runs a hybrid search against this collection: every usable dense and sparse (full-text) sub-request
    /// is sent in a single <c>HybridSearchRequest</c> and the server-side ranker fuses the results.
    /// </summary>
    /// <remarks>
    /// <para>
    /// Entries with a blank field name, a null/empty query vector or a blank query text are silently skipped.
    /// </para>
    /// <para>
    /// The following settings are currently fixed: all output fields (<c>"*"</c>), eventual consistency,
    /// rank parameters <c>ranker=RRFRanker</c>, <c>k=60</c>, <c>limit=60</c>, metric <c>IP</c> with
    /// <c>nprobe=32</c> for dense sub-requests and metric <c>BM25</c> for sparse sub-requests.
    /// </para>
    /// </remarks>
    /// <param name="databaseName">Name of the database that contains the collection.</param>
    /// <param name="denseReqs">
    /// Dense sub-requests: vector field name, query embedding and an optional boolean filter expression.
    /// Checked with <c>Verify.NotNull</c>.
    /// </param>
    /// <param name="sparseReqs">
    /// Sparse sub-requests: vector field name, raw query text and an optional boolean filter expression.
    /// Checked with <c>Verify.NotNull</c>.
    /// </param>
    /// <param name="cancellationToken">Token used to cancel the gRPC call.</param>
    /// <returns>The fused search results (ids, scores, limits and the returned field data).</returns>
    public async Task<SearchResults> HybridSearchAsync(
            string databaseName,
            IReadOnlyList<(string annsField, float[] queryVector, string filterExpr)> denseReqs,
            IReadOnlyList<(string annsField, string queryText, string filterExpr)> sparseReqs,
            CancellationToken cancellationToken = default)
    {
        // Argument validation and filtering of unusable entries happen inside CreateRequests.
        List<SearchRequest> requests = CreateRequests(denseReqs, sparseReqs);

        var hybridRequest = new Grpc.HybridSearchRequest
        {
            Base = new MsgBase { MsgType = MsgType.Search },
            CollectionName = Name,
            DbName = databaseName,
            OutputFields = { "*" },
            ConsistencyLevel = Grpc.ConsistencyLevel.Eventually,
            UseDefaultConsistency = false,
            // NOTE(review): other Milvus SDKs appear to describe the ranker with "strategy"/"params" keys;
            // confirm the server honours "ranker"/"k" rather than silently falling back to its defaults.
            RankParams =
            {
                new KeyValuePair { Key = "ranker", Value = "RRFRanker" },
                new KeyValuePair { Key = "k", Value = "60" },
                new KeyValuePair { Key = "limit", Value = "60" }
            },
        };

        hybridRequest.Requests.AddRange(requests);

        Grpc.SearchResults response = await _client.InvokeAsync(
            _client.GrpcClient.HybridSearchAsync, hybridRequest,
            static reply => reply.Status,
            cancellationToken)
          .ConfigureAwait(false);

        List<FieldData> fieldData = ProcessReturnedFieldData(response.Results.FieldsData);

        return new SearchResults
        {
            CollectionName = response.CollectionName,
            FieldsData = fieldData,
            Ids = response.Results.Ids is null ? default : MilvusIds.FromGrpc(response.Results.Ids),
            NumQueries = response.Results.NumQueries,
            Scores = response.Results.Scores,
            Limit = response.Results.TopK,
            Limits = response.Results.Topks,
        };
    }

    /// <summary>
    /// Builds one gRPC <c>SearchRequest</c> per usable entry: all dense requests first, then all sparse
    /// requests, each group in its input order.
    /// </summary>
    /// <param name="denseReqs">Dense sub-requests; entries with a blank field name or a null/empty vector are skipped.</param>
    /// <param name="sparseReqs">Sparse sub-requests; entries with a blank field name or blank query text are skipped.</param>
    /// <returns>The sub-requests to attach to the hybrid search request; may be empty.</returns>
    private static List<SearchRequest> CreateRequests(
        IReadOnlyList<(string annsField, float[] queryVector, string filterExpr)> denseReqs,
        IReadOnlyList<(string annsField, string queryText, string filterExpr)> sparseReqs
        )
    {
        Verify.NotNull(denseReqs);
        Verify.NotNull(sparseReqs);

        var requests = new List<SearchRequest>(denseReqs.Count + sparseReqs.Count);

        foreach ((string annsField, float[] queryVector, string filterExpr) in denseReqs)
        {
            if (string.IsNullOrWhiteSpace(annsField) || queryVector is null || queryVector.Length == 0)
            {
                continue;
            }

            requests.Add(CreateDenseRequest(annsField, queryVector, filterExpr));
        }

        foreach ((string annsField, string queryText, string filterExpr) in sparseReqs)
        {
            if (string.IsNullOrWhiteSpace(annsField) || string.IsNullOrWhiteSpace(queryText))
            {
                continue;
            }

            requests.Add(CreateSparseRequest(annsField, queryText, filterExpr));
        }

        return requests;
    }

    /// <summary>
    /// Builds the sub-request for a dense (float vector) search on a single field.
    /// </summary>
    /// <param name="annsField">Name of the vector field to search.</param>
    /// <param name="queryVector">Query embedding; must be non-null (callers filter out empty vectors).</param>
    /// <param name="filterExpr">Boolean filter expression; null or blank means "no filter".</param>
    /// <param name="nq">Value written to the request's <c>Nq</c> field.</param>
    /// <returns>A request using metric <c>IP</c> and search params <c>{"nprobe":32}</c>.</returns>
    private static SearchRequest CreateDenseRequest(
       string annsField,
       float[] queryVector,
       string filterExpr,
       int nq = 1)
    {
        // The vector is sent as the raw bytes of its float components.
        byte[] vectorBytes = new byte[queryVector.Length * sizeof(float)];
        Buffer.BlockCopy(queryVector, srcOffset: 0, vectorBytes, dstOffset: 0, vectorBytes.Length);

        // BlockCopy copies in host byte order; flip each component on big-endian hosts so the payload is
        // always little-endian. This is a no-op on little-endian machines.
        if (!BitConverter.IsLittleEndian)
        {
            for (int offset = 0; offset < vectorBytes.Length; offset += sizeof(float))
            {
                Array.Reverse(vectorBytes, offset, sizeof(float));
            }
        }

        var placeholders = new PlaceholderGroup();
        placeholders.Placeholders.Add(new PlaceholderValue
        {
            Tag = "$0",
            Type = PlaceholderType.FloatVector,
            Values = { ByteString.CopyFrom(vectorBytes) }
        });

        return new SearchRequest
        {
            Dsl = string.IsNullOrWhiteSpace(filterExpr) ? "" : filterExpr,
            DslType = DslType.BoolExprV1,
            PlaceholderGroup = placeholders.ToByteString(),
            Nq = nq,
            SearchParams =
            {
                new KeyValuePair { Key = "anns_field", Value = annsField },
                new KeyValuePair { Key = "metric_type", Value = "IP" },
                new KeyValuePair { Key = "params", Value = "{\"nprobe\":32}" },
            }
        };
    }

    /// <summary>
    /// Builds the sub-request for a sparse (BM25 full-text) search on a single field. The query text is
    /// sent as UTF-8 in a <c>VarChar</c> placeholder.
    /// </summary>
    /// <param name="annsField">Name of the sparse vector field to search.</param>
    /// <param name="queryText">Raw query text; must be non-null.</param>
    /// <param name="filterExpr">Boolean filter expression; null or blank means "no filter".</param>
    /// <param name="nq">Value written to the request's <c>Nq</c> field.</param>
    /// <returns>A request using metric <c>BM25</c> and search params <c>{"drop_ratio_build":0.2}</c>.</returns>
    private static SearchRequest CreateSparseRequest(
      string annsField,
      string queryText,
      string filterExpr,
      int nq = 1)
    {
        var placeholders = new PlaceholderGroup();
        placeholders.Placeholders.Add(new PlaceholderValue
        {
            // NOTE(review): the dense request tags its placeholder "$0"; confirm whether "$1" is intentional
            // here. It is kept so the request sent by HybridSearchAsync stays unchanged.
            Tag = "$1",
            Type = PlaceholderType.VarChar,
            Values = { ByteString.CopyFromUtf8(queryText) }
        });

        return new SearchRequest
        {
            Dsl = string.IsNullOrWhiteSpace(filterExpr) ? "" : filterExpr,
            DslType = DslType.BoolExprV1,
            PlaceholderGroup = placeholders.ToByteString(),
            Nq = nq,
            SearchParams =
            {
                new KeyValuePair { Key = "metric_type", Value = "BM25" },
                new KeyValuePair { Key = "params", Value = "{\"drop_ratio_build\":0.2}" },
                new KeyValuePair { Key = "anns_field", Value = annsField }
            }
        };
    }
}

  /// <summary>
  /// Turns a raw gRPC search response into a thresholded, ranked list of hits, one dictionary per hit.
  /// </summary>
  /// <param name="response">The gRPC response; a null response, a missing status or a non-zero status code is treated as a failure.</param>
  /// <param name="topK">Maximum number of hits to return.</param>
  /// <param name="threshold">Hits whose "distance" entry is a float below this value are dropped.</param>
  /// <returns>At most <paramref name="topK"/> hits ordered by descending "distance"; hits without that entry sort as 0.</returns>
  /// <exception cref="InvalidOperationException">The response is null, has no status, or reports a non-zero status code.</exception>
  private static List<Dictionary<string, object>> ParseResults(
  Grpc.SearchResults? response,
  int topK,
  float threshold)
  {
      bool succeeded = response?.Status?.Code == 0;
      if (!succeeded)
      {
          string reason = response?.Status?.Reason ?? "未知错误";
          throw new InvalidOperationException($"搜索失败: {reason}");
      }

      // A successful response without a result payload simply yields no hits.
      List<Dictionary<string, object>> hits = response?.Results is { } resultData
          ? ParseSearchResultData(resultData)
          : new List<Dictionary<string, object>>();

      // True only when the hit carries a float "distance" that is below the threshold.
      bool IsBelowThreshold(Dictionary<string, object> hit) =>
          hit.TryGetValue("distance", out var raw) && raw is float distance && distance < threshold;

      // Ranking key: the stored "distance", or 0 when the hit has none.
      float RankingKey(Dictionary<string, object> hit) =>
          hit.TryGetValue("distance", out var raw) ? (float)raw : 0;

      return hits
          .Where(hit => !IsBelowThreshold(hit))
          .OrderByDescending(hit => RankingKey(hit))
          .Take(topK)
          .ToList();
  }


  /// <summary>
  /// Flattens a <see cref="SearchResultData"/> into one dictionary per returned hit.
  /// </summary>
  /// <remarks>
  /// Each dictionary may contain "score", "distance" and "id" (long or string), plus one entry per scalar
  /// output field keyed by the field name. Only long, float and string scalar fields are extracted; other
  /// field kinds are skipped.
  /// </remarks>
  /// <param name="resultData">The result payload of a gRPC search response.</param>
  /// <returns>The hits of all queries, in query order and then result order.</returns>
  private static List<Dictionary<string, object>> ParseSearchResultData(SearchResultData resultData)
  {
      var entities = new List<Dictionary<string, object>>();

      int numQueries = (int)resultData.NumQueries;
      int topK = (int)resultData.TopK;

      // Position of the current query's first hit inside the flattened, query-by-query result arrays.
      int queryOffset = 0;

      for (int queryIndex = 0; queryIndex < numQueries; queryIndex++)
      {
          // A query can return fewer hits than TopK, so query * TopK is not a reliable offset.
          // Topks is read as the per-query hit count (see the Milvus SearchResultData schema);
          // TopK is used only when that count is not provided.
          int hitCount = queryIndex < resultData.Topks.Count
              ? (int)resultData.Topks[queryIndex]
              : topK;

          for (int resultIndex = 0; resultIndex < hitCount; resultIndex++)
          {
              int globalIndex = queryOffset + resultIndex;

              var entity = new Dictionary<string, object>();

              // Score / distance, when present for this hit.
              if (resultData.Scores.Count > globalIndex)
              {
                  entity["score"] = resultData.Scores[globalIndex];
              }

              if (resultData.Distances.Count > globalIndex)
              {
                  entity["distance"] = resultData.Distances[globalIndex];
              }

              // Primary key: integer ids take precedence over string ids.
              if (resultData.Ids?.IntId?.Data?.Count > globalIndex)
              {
                  entity["id"] = resultData.Ids.IntId.Data[globalIndex];
              }
              else if (resultData.Ids?.StrId?.Data?.Count > globalIndex)
              {
                  entity["id"] = resultData.Ids.StrId.Data[globalIndex];
              }

              // Scalar output fields. Each arm is bounds-checked so that a column shorter than expected
              // is skipped instead of throwing ArgumentOutOfRangeException.
              foreach (var fieldData in resultData.FieldsData)
              {
                  object? value = fieldData.Scalars switch
                  {
                      { LongData: { Data: { } data } } when globalIndex < data.Count => data[globalIndex],
                      { FloatData: { Data: { } data } } when globalIndex < data.Count => data[globalIndex],
                      { StringData: { Data: { } data } } when globalIndex < data.Count => data[globalIndex],
                      // Other scalar kinds (bool, int, double, ...) are not extracted yet.
                      _ => null
                  };

                  if (value != null)
                  {
                      entity[fieldData.FieldName] = value;
                  }
              }

              entities.Add(entity);
          }

          queryOffset += hitCount;
      }

      return entities;
  }

How to use:

   // Dense sub-requests, one tuple each: (vector field name, query embedding, filter expression).
   // An empty filter string means "no filter".
   var denseReqs = new List<(string, float[], string)>() {
       new ("describ_dense", searchVectors[0].ToArray(), "" ) };
   // Sparse / full-text sub-requests, one tuple each: (vector field name, query text, filter expression).
   var sparsReqs = new List<(string, string, string)>() {
       new ("name_property_bm25_sparse", "test_text_",""),
       new ("describ_bm25_sparse", "test_text_1",""),
   };

   // Send all sub-requests in one hybrid search against the "test" database.
   var response = await collection!.HybridSearchAsync(
       databaseName: "test",
       denseReqs,
       sparsReqs);

  // NOTE(review): HybridSearchAsync is declared to return SearchResults, while ParseResults takes a
  // Grpc.SearchResults? - confirm these are the same type, otherwise this call will not compile.
  // topK and threshold are not defined in this snippet and must be supplied by the caller.
  var results = ParseResults(response, topK, threshold);

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions