Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ai-cache] Implement a WASM plugin for LLM result retrieval based on vector similarity #1290

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4f7bfbd
fix bugs
johnlanni Jul 31, 2024
0f9e816
fix bugs
Suchun-sv Aug 1, 2024
ff1bce6
fix bugs
Suchun-sv Aug 12, 2024
1e9d42e
init
EnableAsync Aug 15, 2024
f2a9ff6
fix conflict
Suchun-sv Aug 23, 2024
5cbae03
Merge branch 'alibaba:main' into main
Suchun-sv Aug 23, 2024
27b2f71
alter some errors
Suchun-sv Aug 24, 2024
130f2ee
fix: embedding error
EnableAsync Aug 24, 2024
56314d7
fix bugs && update interface design
Suchun-sv Aug 24, 2024
85549d0
fix bugs && refine the variable names
Suchun-sv Aug 25, 2024
8444f5e
update design for cache to support extension
Suchun-sv Aug 25, 2024
a655bc4
Merge branch 'alibaba:main' into main
Suchun-sv Sep 5, 2024
d68fa88
Refined the code; README.md content needs to be updated.
Suchun-sv Sep 5, 2024
5179392
fix bugs, README.md to be updated
Suchun-sv Sep 6, 2024
ece7e2f
fix bugs, refine variable name, update README.md
Suchun-sv Sep 6, 2024
e868a1a
Merge branch 'alibaba:main' into main
Suchun-sv Sep 6, 2024
138a526
delete folder
Suchun-sv Sep 6, 2024
e8ad550
fix typos
Suchun-sv Sep 6, 2024
c83f5c4
fix typos
Suchun-sv Sep 6, 2024
f3d3292
change append to appendMsg
Suchun-sv Sep 6, 2024
b0cf29d
fix bugs and refine code
Suchun-sv Sep 11, 2024
4a18f96
Merge branch 'main' into main
Suchun-sv Sep 11, 2024
21c9a79
fix bugs and update the SetEx function
Suchun-sv Sep 12, 2024
1767896
Merge branch 'main' into main
Suchun-sv Sep 12, 2024
71b9530
Optimize query flow logic (not fully tested)
Suchun-sv Sep 17, 2024
51b9ccc
Fix bugs and verify removal of cache setting
Suchun-sv Sep 21, 2024
3583bc9
fix bugs and update logic as requested
Suchun-sv Sep 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions plugins/wasm-go/extensions/ai-cache/cache/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ type ProviderConfig struct {
cacheKeyPrefix string
}

// GetProviderType returns the cache provider type identifier — the value
// parsed from the "type" field in FromJson. An empty string means no cache
// provider has been configured.
func (c *ProviderConfig) GetProviderType() string {
	return c.typ
}

func (c *ProviderConfig) FromJson(json gjson.Result) {
c.typ = json.Get("type").String()
c.serviceName = json.Get("serviceName").String()
Expand Down
43 changes: 29 additions & 14 deletions plugins/wasm-go/extensions/ai-cache/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ type PluginConfig struct {
}

func (c *PluginConfig) FromJson(json gjson.Result) {
c.embeddingProviderConfig.FromJson(json.Get("embedding"))
c.vectorProviderConfig.FromJson(json.Get("vector"))
c.embeddingProviderConfig.FromJson(json.Get("embedding"))
c.cacheProviderConfig.FromJson(json.Get("cache"))

c.CacheKeyFrom = json.Get("cacheKeyFrom").String()
Expand All @@ -54,20 +54,25 @@ func (c *PluginConfig) FromJson(json gjson.Result) {

c.StreamResponseTemplate = json.Get("streamResponseTemplate").String()
if c.StreamResponseTemplate == "" {
c.StreamResponseTemplate = `data:{"id":"ai-cache.hit","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n"
c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n"
}
c.ResponseTemplate = json.Get("responseTemplate").String()
if c.ResponseTemplate == "" {
c.ResponseTemplate = `{"id":"ai-cache.hit","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
}
}

func (c *PluginConfig) Validate() error {
if err := c.cacheProviderConfig.Validate(); err != nil {
return err
// if cache provider is configured, validate it
if c.cacheProviderConfig.GetProviderType() != "" {
if err := c.cacheProviderConfig.Validate(); err != nil {
return err
}
}
if err := c.embeddingProviderConfig.Validate(); err != nil {
return err
if c.embeddingProviderConfig.GetProviderType() != "" {
if err := c.embeddingProviderConfig.Validate(); err != nil {
return err
}
}
if err := c.vectorProviderConfig.Validate(); err != nil {
return err
Expand All @@ -77,15 +82,25 @@ func (c *PluginConfig) Validate() error {

func (c *PluginConfig) Complete(log wrapper.Log) error {
var err error
c.embeddingProvider, err = embedding.CreateProvider(c.embeddingProviderConfig)
if err != nil {
return err
if c.embeddingProviderConfig.GetProviderType() != "" {
c.embeddingProvider, err = embedding.CreateProvider(c.embeddingProviderConfig)
if err != nil {
return err
}
} else {
log.Info("embedding provider is not configured")
c.embeddingProvider = nil
}
c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig)
if err != nil {
return err
if c.cacheProviderConfig.GetProviderType() != "" {
c.cacheProvider, err = cache.CreateProvider(c.cacheProviderConfig)
if err != nil {
return err
}
} else {
log.Info("cache provider is not configured")
c.cacheProvider = nil
}
c.cacheProvider, err = cache.CreateProvider(c.cacheProviderConfig)
c.vectorProvider, err = vector.CreateProvider(c.vectorProviderConfig)
if err != nil {
return err
}
Expand Down
44 changes: 25 additions & 19 deletions plugins/wasm-go/extensions/ai-cache/core.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package main

import (
"encoding/json"
"errors"
"fmt"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-cache/vector"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/go-errors/errors"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/tidwall/resp"
)
Expand Down Expand Up @@ -50,29 +49,32 @@ func handleCacheResponse(key string, response resp.Value, ctx wrapper.HttpContex
if useSimilaritySearch {
if err := performSimilaritySearch(key, ctx, config, log, key, stream); err != nil {
log.Errorf("Failed to perform similarity search for key: %s, error: %v", key, err)
proxywasm.ResumeHttpRequest()
}
} else {
proxywasm.ResumeHttpRequest()
}
proxywasm.ResumeHttpRequest()
}

// processCacheHit handles a successful cache hit.
func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpContext, config config.PluginConfig, log wrapper.Log) {
escapedResponse, err := json.Marshal(response)
log.Debugf("Cached response for key %s: %s", key, escapedResponse)

if err != nil {
handleInternalError(err, "Failed to marshal cached response", log)
return
if stream {
log.Debug("streaming response is not supported for cache hit yet")
stream = false
}
// escapedResponse, err := json.Marshal(response)
// log.Debugf("Cached response for key %s: %s", key, escapedResponse)

ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil)
// if err != nil {
// handleInternalError(err, "Failed to marshal cached response", log)
// return
// }
log.Debugf("Cached response for key %s: %s", key, response)

contentType := "application/json; charset=utf-8"
if stream {
contentType = "text/event-stream; charset=utf-8"
}
ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil)

proxywasm.SendHttpResponse(200, [][2]string{{"content-type", contentType}}, []byte(fmt.Sprintf(config.ResponseTemplate, escapedResponse)), -1)
// proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ResponseTemplate, escapedResponse)), -1)
proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ResponseTemplate, response)), -1)
}

// performSimilaritySearch determines the appropriate similarity search method to use.
Expand Down Expand Up @@ -149,7 +151,7 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht
}

mostSimilarData := results[0]
log.Debugf("Most similar key found: %s with score: %f", mostSimilarData.Text, mostSimilarData.Score)
log.Debugf("For key: %s, the most similar key found: %s with score: %f", key, mostSimilarData.Text, mostSimilarData.Score)

simThresholdProvider, ok := config.GetVectorProvider().(vector.SimilarityThresholdProvider)
if !ok {
Expand All @@ -165,7 +167,11 @@ func handleQueryResults(key string, results []vector.QueryResult, ctx wrapper.Ht
processCacheHit(key, mostSimilarData.Answer, stream, ctx, config, log)
} else {
// otherwise, continue to check cache for the most similar key
CheckCacheForKey(mostSimilarData.Text, ctx, config, log, stream, false)
err = CheckCacheForKey(mostSimilarData.Text, ctx, config, log, stream, false)
if err != nil {
log.Errorf("check cache for key: %s failed, error: %v", mostSimilarData.Text, err)
proxywasm.ResumeHttpRequest()
}
}
} else {
log.Infof("Score too high for key: %s with score: %f above threshold", mostSimilarData.Text, mostSimilarData.Score)
Expand Down Expand Up @@ -220,9 +226,9 @@ func uploadEmbeddingAndAnswer(ctx wrapper.HttpContext, config config.PluginConfi
}

// Attempt to upload answer embedding first
if ansEmbUploader, ok := activeVectorProvider.(vector.AnswerEmbeddingUploader); ok {
if ansEmbUploader, ok := activeVectorProvider.(vector.AnswerAndEmbeddingUploader); ok {
log.Infof("[onHttpResponseBody] uploading answer embedding for key: %s", key)
err := ansEmbUploader.UploadAnswerEmbedding(key, emb, value, ctx, log, nil)
err := ansEmbUploader.UploadAnswerAndEmbedding(key, emb, value, ctx, log, nil)
if err != nil {
log.Warnf("[onHttpResponseBody] failed to upload answer embedding for key: %s, error: %v", key, err)
} else {
Expand Down
19 changes: 10 additions & 9 deletions plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@ package embedding
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"

"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)

const (
DOMAIN = "dashscope.aliyuncs.com"
PORT = 443
DEFAULT_MODEL_NAME = "text-embedding-v1"
ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding"
DASHSCOPE_DOMAIN = "dashscope.aliyuncs.com"
DASHSCOPE_PORT = 443
DASHSCOPE_DEFAULT_MODEL_NAME = "text-embedding-v1"
DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding"
)

type dashScopeProviderInitializer struct {
Expand All @@ -28,10 +29,10 @@ func (d *dashScopeProviderInitializer) ValidateConfig(config ProviderConfig) err

func (d *dashScopeProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
if c.servicePort == 0 {
c.servicePort = PORT
c.servicePort = DASHSCOPE_PORT
}
if c.serviceDomain == "" {
c.serviceDomain = DOMAIN
c.serviceDomain = DASHSCOPE_DOMAIN
}
return &DSProvider{
config: c,
Expand Down Expand Up @@ -95,7 +96,7 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin
model := d.config.model

if model == "" {
model = DEFAULT_MODEL_NAME
model = DASHSCOPE_DEFAULT_MODEL_NAME
}
data := EmbeddingRequest{
Model: model,
Expand Down Expand Up @@ -124,7 +125,7 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin
{"Content-Type", "application/json"},
}

return ENDPOINT, headers, requestBody, err
return DASHSCOPE_ENDPOINT, headers, requestBody, err
}

type Result struct {
Expand Down Expand Up @@ -168,7 +169,7 @@ func (d *DSProvider) GetEmbedding(

resp, err = d.parseTextEmbedding(responseBody)
if err != nil {
err = errors.New("failed to parse response: " + err.Error())
err = fmt.Errorf("failed to parse response: %v", err)
callback(nil, err)
return
}
Expand Down
28 changes: 20 additions & 8 deletions plugins/wasm-go/extensions/ai-cache/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,21 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config config.PluginConfig,
}

func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
log.Debugf("[onHttpResponseBody] chunk: %s", string(chunk))
log.Debugf("[onHttpResponseBody] escaped chunk: %q", string(chunk))
log.Debugf("[onHttpResponseBody] isLastChunk: %v", isLastChunk)

// if strings.HasSuffix(string(chunk), "[DONE] \n\n") {
// isLastChunk = true
// }

// If the context contains TOOL_CALLS_CONTEXT_KEY, bypass caching
if ctx.GetContext(TOOL_CALLS_CONTEXT_KEY) != nil {
return chunk
}

keyI := ctx.GetContext(CACHE_KEY_CONTEXT_KEY)
if keyI == nil {
key := ctx.GetContext(CACHE_KEY_CONTEXT_KEY)
if key == nil {
log.Debug("[onHttpResponseBody] key is nil, bypass caching")
return chunk
}

Expand All @@ -137,17 +142,24 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config config.PluginConfig, chu
}

// Handle last chunk
value, err := processFinalChunk(ctx, config, chunk, log)
var value string
var err error

if len(chunk) > 0 {
value, err = processNonEmptyChunk(ctx, config, chunk, log)
} else {
value, err = processEmptyChunk(ctx, config, chunk, log)
}

if err != nil {
log.Warnf("[onHttpResponseBody] failed to process final chunk: %v", err)
log.Warnf("[onHttpResponseBody] failed to process chunk: %v", err)
return chunk
}

// Cache the final value
cacheResponse(ctx, config, keyI.(string), value, log)
cacheResponse(ctx, config, key.(string), value, log)

// Handle embedding upload if available
uploadEmbeddingAndAnswer(ctx, config, keyI.(string), value, log)
uploadEmbeddingAndAnswer(ctx, config, key.(string), value, log)

return chunk
}
11 changes: 10 additions & 1 deletion plugins/wasm-go/extensions/ai-cache/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func savePartialMessage(ctx wrapper.HttpContext, partialMessage []byte, messages
}

// Processes the final chunk and returns the parsed value or an error
func processFinalChunk(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) {
func processNonEmptyChunk(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) {
stream := ctx.GetContext(STREAM_CONTEXT_KEY)
var value string

Expand All @@ -122,6 +122,15 @@ func processFinalChunk(ctx wrapper.HttpContext, config config.PluginConfig, chun
return value, nil
}

// processEmptyChunk handles a final, empty body chunk. It returns the
// response content accumulated so far under CACHE_CONTENT_CONTEXT_KEY;
// when nothing has been accumulated it falls back to the chunk itself
// (which is empty here). It never returns an error.
func processEmptyChunk(ctx wrapper.HttpContext, config config.PluginConfig, chunk []byte, log wrapper.Log) (string, error) {
	cached := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
	if cached == nil {
		// No accumulated content; return the (empty) chunk as-is.
		return string(chunk), nil
	}
	return cached.(string), nil
}

// Appends the final body chunk to the existing body content
func appendFinalBody(ctx wrapper.HttpContext, chunk []byte) []byte {
tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
Expand Down
16 changes: 14 additions & 2 deletions plugins/wasm-go/extensions/ai-cache/vector/dashvector.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ func (d *DvProvider) QueryEmbedding(
return err
}

// checkField safely extracts the string value stored under key in fields.
// It returns "" when the key is absent or when the stored value is not a
// string — the original unchecked assertion val.(string) would panic on a
// non-string field value coming back from the vector store.
func checkField(fields map[string]interface{}, key string) string {
	if val, ok := fields[key]; ok {
		if s, ok := val.(string); ok {
			return s
		}
	}
	return ""
}

func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpContext, log wrapper.Log) ([]QueryResult, error) {
resp, err := d.parseQueryResponse(responseBody)
if err != nil {
Expand All @@ -171,9 +178,10 @@ func (d *DvProvider) ParseQueryResponse(responseBody []byte, ctx wrapper.HttpCon

for _, output := range resp.Output {
result := QueryResult{
Text: output.Fields["query"].(string),
Text: checkField(output.Fields, "query"),
Embedding: output.Vector,
Score: output.Score,
Answer: checkField(output.Fields, "answer"),
}
results = append(results, result)
}
Expand Down Expand Up @@ -234,7 +242,11 @@ func (d *DvProvider) UploadEmbedding(queryString string, queryEmb []float64, ctx
return err
}

func (d *DvProvider) UploadAnswerEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
// GetSimilarityThreshold returns the similarity score threshold used by the
// caller to decide whether a vector query result counts as a cache hit.
// NOTE(review): this returns the package-level `threshold` value defined
// elsewhere in this file — confirm it should not instead come from the
// provider's configuration.
func (d *DvProvider) GetSimilarityThreshold() float64 {
	return threshold
}

func (d *DvProvider) UploadAnswerAndEmbedding(queryString string, queryEmb []float64, queryAnswer string, ctx wrapper.HttpContext, log wrapper.Log, callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
url, body, headers, err := d.constructUploadParameters(queryEmb, queryString, queryAnswer)
if err != nil {
return err
Expand Down
8 changes: 6 additions & 2 deletions plugins/wasm-go/extensions/ai-cache/vector/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ type EmbeddingUploader interface {
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error
}

type AnswerEmbeddingUploader interface {
UploadAnswerEmbedding(
type AnswerAndEmbeddingUploader interface {
UploadAnswerAndEmbedding(
queryString string,
queryEmb []float64,
answer string,
Expand Down Expand Up @@ -118,6 +118,10 @@ type ProviderConfig struct {
// ChromaTimeout uint32 `require:"false" yaml:"ChromaTimeout" json:"ChromaTimeout"`
}

// GetProviderType returns the vector provider type identifier — the value
// parsed from the "type" field in FromJson. An empty string means no vector
// provider has been configured.
func (c *ProviderConfig) GetProviderType() string {
	return c.typ
}

func (c *ProviderConfig) FromJson(json gjson.Result) {
c.typ = json.Get("type").String()
// DashVector
Expand Down