package com.springai;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.qdrant.client.grpc.Collections;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import org.checkerframework.checker.units.qual.A;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.util.CollectionUtils;

import java.util.List;

import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;


@SpringBootTest
public class AppTest {

    /**
     * Name of the Qdrant collection created by {@link #testApp()}.
     * NOTE(review): presumably must match the collection the injected EmbeddingStore is
     * configured with — that configuration is not visible here, confirm.
     */
    private static final String COLLECTION_NAME = "testv";

    /**
     * Vector dimension of the collection.
     * NOTE(review): must equal the output dimension of the injected EmbeddingModel — confirm
     * against the model configuration.
     */
    private static final int VECTOR_SIZE = 1024;

    @Autowired
    private QdrantClient qdrantGrpcClient;

    @Autowired
    private EmbeddingStore<TextSegment> embeddingStore;

    @Autowired
    private EmbeddingModel embeddingModel;

    /**
     * Creates the Qdrant collection {@value #COLLECTION_NAME} using cosine distance and
     * {@value #VECTOR_SIZE}-dimensional vectors.
     *
     * <p>The client call is asynchronous. Its future is awaited so that the success message is
     * printed only after the server has completed the request, and any failure reported by the
     * server surfaces as a test failure instead of being silently dropped.
     */
    @Test
    public void testApp() throws Exception {
        var vectorParams = Collections.VectorParams.newBuilder()
                .setDistance(Collections.Distance.Cosine)
                .setSize(VECTOR_SIZE)
                .build();
        qdrantGrpcClient.createCollectionAsync(COLLECTION_NAME, vectorParams).get();
        System.out.println("创建成功");
    }

    /**
     * Embeds a single text segment tagged with metadata {@code author=lisi} and adds it to the
     * embedding store.
     */
    @Test
    public void testApp2() {
        TextSegment segment = TextSegment.from("客服女生人数是55人");
        // The metadata is stored alongside the vector so that testApp3 can filter on it.
        segment.metadata().put("author", "lisi");
        Embedding embedding = embeddingModel.embed(segment).content();
        embeddingStore.add(embedding, segment);
    }

    /**
     * Embeds a question and searches the store for the single most similar segment.
     * Only segments whose metadata {@code author} equals {@code "lisi"} are considered.
     * Every match is printed; a short notice is printed when nothing matches.
     */
    @Test
    public void testApp3() {
        // The user's question, and its embedding used as the search query.
        String question = "你们的客服人数多少";
        Embedding queryEmbedding = embeddingModel.embed(question).content();

        EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
                .queryEmbedding(queryEmbedding)
                .maxResults(1)
                .filter(metadataKey("author").isEqualTo("lisi"))
                .build();

        EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
        var matches = result.matches();

        // Check for an empty result before touching any element; an unconditional get(0)
        // would throw IndexOutOfBoundsException when the filter excludes everything.
        if (CollectionUtils.isEmpty(matches)) {
            System.out.println("No matches found");
            return;
        }
        for (var match : matches) {
            System.out.println(match);
        }
    }
}

常用过滤器:

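| Filter名称 | 功能 | 使用示例 |
| --- | --- | --- |
| And | 同时满足多个条件 | `Filter.and(condition1, condition2)` |
| Or | 满足其中任意一个条件 | `Filter.or(condition1, condition2)` |
| Not | 不满足条件 | `Filter.not(condition)` |
| IsEqualTo | 等于 | `new IsEqualTo("field", "value")` |
| IsGreaterThan | 大于 | `new IsGreaterThan("field", value)` |
| IsLessThan | 小于 | `new IsLessThan("field", value)` |
| IsIn | 在列表内 | `new IsIn("field", listOfValues)` |

这些过滤器也可以用 `MetadataFilterBuilder` 的链式写法构造,例如上文代码中的 `metadataKey("author").isEqualTo("lisi")` 等价于 `new IsEqualTo("author", "lisi")`。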
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐