refactor: replace sjson.Set usage with sjson.SetBytes to optimize mutable JSON transformations
This commit is contained in:
@@ -191,58 +191,58 @@ func convertCompletionsRequestToChatCompletions(rawJSON []byte) []byte {
|
||||
}
|
||||
|
||||
// Create chat completions structure
|
||||
out := `{"model":"","messages":[{"role":"user","content":""}]}`
|
||||
out := []byte(`{"model":"","messages":[{"role":"user","content":""}]}`)
|
||||
|
||||
// Set model
|
||||
if model := root.Get("model"); model.Exists() {
|
||||
out, _ = sjson.Set(out, "model", model.String())
|
||||
out, _ = sjson.SetBytes(out, "model", model.String())
|
||||
}
|
||||
|
||||
// Set the prompt as user message content
|
||||
out, _ = sjson.Set(out, "messages.0.content", prompt)
|
||||
out, _ = sjson.SetBytes(out, "messages.0.content", prompt)
|
||||
|
||||
// Copy other parameters from completions to chat completions
|
||||
if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
|
||||
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
|
||||
out, _ = sjson.SetBytes(out, "max_tokens", maxTokens.Int())
|
||||
}
|
||||
|
||||
if temperature := root.Get("temperature"); temperature.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", temperature.Float())
|
||||
out, _ = sjson.SetBytes(out, "temperature", temperature.Float())
|
||||
}
|
||||
|
||||
if topP := root.Get("top_p"); topP.Exists() {
|
||||
out, _ = sjson.Set(out, "top_p", topP.Float())
|
||||
out, _ = sjson.SetBytes(out, "top_p", topP.Float())
|
||||
}
|
||||
|
||||
if frequencyPenalty := root.Get("frequency_penalty"); frequencyPenalty.Exists() {
|
||||
out, _ = sjson.Set(out, "frequency_penalty", frequencyPenalty.Float())
|
||||
out, _ = sjson.SetBytes(out, "frequency_penalty", frequencyPenalty.Float())
|
||||
}
|
||||
|
||||
if presencePenalty := root.Get("presence_penalty"); presencePenalty.Exists() {
|
||||
out, _ = sjson.Set(out, "presence_penalty", presencePenalty.Float())
|
||||
out, _ = sjson.SetBytes(out, "presence_penalty", presencePenalty.Float())
|
||||
}
|
||||
|
||||
if stop := root.Get("stop"); stop.Exists() {
|
||||
out, _ = sjson.SetRaw(out, "stop", stop.Raw)
|
||||
out, _ = sjson.SetRawBytes(out, "stop", []byte(stop.Raw))
|
||||
}
|
||||
|
||||
if stream := root.Get("stream"); stream.Exists() {
|
||||
out, _ = sjson.Set(out, "stream", stream.Bool())
|
||||
out, _ = sjson.SetBytes(out, "stream", stream.Bool())
|
||||
}
|
||||
|
||||
if logprobs := root.Get("logprobs"); logprobs.Exists() {
|
||||
out, _ = sjson.Set(out, "logprobs", logprobs.Bool())
|
||||
out, _ = sjson.SetBytes(out, "logprobs", logprobs.Bool())
|
||||
}
|
||||
|
||||
if topLogprobs := root.Get("top_logprobs"); topLogprobs.Exists() {
|
||||
out, _ = sjson.Set(out, "top_logprobs", topLogprobs.Int())
|
||||
out, _ = sjson.SetBytes(out, "top_logprobs", topLogprobs.Int())
|
||||
}
|
||||
|
||||
if echo := root.Get("echo"); echo.Exists() {
|
||||
out, _ = sjson.Set(out, "echo", echo.Bool())
|
||||
out, _ = sjson.SetBytes(out, "echo", echo.Bool())
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// convertChatCompletionsResponseToCompletions converts chat completions API response back to completions format.
|
||||
@@ -257,23 +257,23 @@ func convertChatCompletionsResponseToCompletions(rawJSON []byte) []byte {
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
// Base completions response structure
|
||||
out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`
|
||||
out := []byte(`{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`)
|
||||
|
||||
// Copy basic fields
|
||||
if id := root.Get("id"); id.Exists() {
|
||||
out, _ = sjson.Set(out, "id", id.String())
|
||||
out, _ = sjson.SetBytes(out, "id", id.String())
|
||||
}
|
||||
|
||||
if created := root.Get("created"); created.Exists() {
|
||||
out, _ = sjson.Set(out, "created", created.Int())
|
||||
out, _ = sjson.SetBytes(out, "created", created.Int())
|
||||
}
|
||||
|
||||
if model := root.Get("model"); model.Exists() {
|
||||
out, _ = sjson.Set(out, "model", model.String())
|
||||
out, _ = sjson.SetBytes(out, "model", model.String())
|
||||
}
|
||||
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
out, _ = sjson.SetRaw(out, "usage", usage.Raw)
|
||||
out, _ = sjson.SetRawBytes(out, "usage", []byte(usage.Raw))
|
||||
}
|
||||
|
||||
// Convert choices from chat completions to completions format
|
||||
@@ -313,10 +313,10 @@ func convertChatCompletionsResponseToCompletions(rawJSON []byte) []byte {
|
||||
|
||||
if len(choices) > 0 {
|
||||
choicesJSON, _ := json.Marshal(choices)
|
||||
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
|
||||
out, _ = sjson.SetRawBytes(out, "choices", choicesJSON)
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// convertChatCompletionsStreamChunkToCompletions converts a streaming chat completions chunk to completions format.
|
||||
@@ -357,19 +357,19 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
|
||||
}
|
||||
|
||||
// Base completions stream response structure
|
||||
out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`
|
||||
out := []byte(`{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`)
|
||||
|
||||
// Copy basic fields
|
||||
if id := root.Get("id"); id.Exists() {
|
||||
out, _ = sjson.Set(out, "id", id.String())
|
||||
out, _ = sjson.SetBytes(out, "id", id.String())
|
||||
}
|
||||
|
||||
if created := root.Get("created"); created.Exists() {
|
||||
out, _ = sjson.Set(out, "created", created.Int())
|
||||
out, _ = sjson.SetBytes(out, "created", created.Int())
|
||||
}
|
||||
|
||||
if model := root.Get("model"); model.Exists() {
|
||||
out, _ = sjson.Set(out, "model", model.String())
|
||||
out, _ = sjson.SetBytes(out, "model", model.String())
|
||||
}
|
||||
|
||||
// Convert choices from chat completions delta to completions format
|
||||
@@ -408,15 +408,15 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
|
||||
|
||||
if len(choices) > 0 {
|
||||
choicesJSON, _ := json.Marshal(choices)
|
||||
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
|
||||
out, _ = sjson.SetRawBytes(out, "choices", choicesJSON)
|
||||
}
|
||||
|
||||
// Copy usage if present
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
out, _ = sjson.SetRaw(out, "usage", usage.Raw)
|
||||
out, _ = sjson.SetRawBytes(out, "usage", []byte(usage.Raw))
|
||||
}
|
||||
|
||||
return []byte(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// handleNonStreamingResponse handles non-streaming chat completion responses
|
||||
|
||||
@@ -13,16 +13,16 @@ func HasResponseTransformerByFormatName(from, to Format) bool {
|
||||
}
|
||||
|
||||
// TranslateStreamByFormatName converts streaming responses between schemas by their string identifiers.
|
||||
func TranslateStreamByFormatName(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
func TranslateStreamByFormatName(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
return TranslateStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
}
|
||||
|
||||
// TranslateNonStreamByFormatName converts non-streaming responses between schemas by their string identifiers.
|
||||
func TranslateNonStreamByFormatName(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
||||
func TranslateNonStreamByFormatName(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||
return TranslateNonStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
}
|
||||
|
||||
// TranslateTokenCountByFormatName converts token counts between schemas by their string identifiers.
|
||||
func TranslateTokenCountByFormatName(ctx context.Context, from, to Format, count int64, rawJSON []byte) string {
|
||||
func TranslateTokenCountByFormatName(ctx context.Context, from, to Format, count int64, rawJSON []byte) []byte {
|
||||
return TranslateTokenCount(ctx, from, to, count, rawJSON)
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type ResponseEnvelope struct {
|
||||
Model string
|
||||
Stream bool
|
||||
Body []byte
|
||||
Chunks []string
|
||||
Chunks [][]byte
|
||||
}
|
||||
|
||||
// RequestMiddleware decorates request translation.
|
||||
@@ -87,7 +87,7 @@ func (p *Pipeline) TranslateResponse(ctx context.Context, from, to Format, resp
|
||||
if input.Stream {
|
||||
input.Chunks = p.registry.TranslateStream(ctx, from, to, input.Model, originalReq, translatedReq, input.Body, param)
|
||||
} else {
|
||||
input.Body = []byte(p.registry.TranslateNonStream(ctx, from, to, input.Model, originalReq, translatedReq, input.Body, param))
|
||||
input.Body = p.registry.TranslateNonStream(ctx, from, to, input.Model, originalReq, translatedReq, input.Body, param)
|
||||
}
|
||||
input.Format = to
|
||||
return input, nil
|
||||
|
||||
@@ -66,7 +66,7 @@ func (r *Registry) HasResponseTransformer(from, to Format) bool {
|
||||
}
|
||||
|
||||
// TranslateStream applies the registered streaming response translator.
|
||||
func (r *Registry) TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
func (r *Registry) TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
@@ -75,11 +75,11 @@ func (r *Registry) TranslateStream(ctx context.Context, from, to Format, model s
|
||||
return fn.Stream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
}
|
||||
}
|
||||
return []string{string(rawJSON)}
|
||||
return [][]byte{rawJSON}
|
||||
}
|
||||
|
||||
// TranslateNonStream applies the registered non-stream response translator.
|
||||
func (r *Registry) TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
||||
func (r *Registry) TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
@@ -88,11 +88,11 @@ func (r *Registry) TranslateNonStream(ctx context.Context, from, to Format, mode
|
||||
return fn.NonStream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
}
|
||||
}
|
||||
return string(rawJSON)
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
// TranslateNonStream applies the registered non-stream response translator.
|
||||
func (r *Registry) TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) string {
|
||||
// TranslateTokenCount applies the registered token count response translator.
|
||||
func (r *Registry) TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) []byte {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
@@ -101,7 +101,7 @@ func (r *Registry) TranslateTokenCount(ctx context.Context, from, to Format, cou
|
||||
return fn.TokenCount(ctx, count)
|
||||
}
|
||||
}
|
||||
return string(rawJSON)
|
||||
return rawJSON
|
||||
}
|
||||
|
||||
var defaultRegistry = NewRegistry()
|
||||
@@ -127,16 +127,16 @@ func HasResponseTransformer(from, to Format) bool {
|
||||
}
|
||||
|
||||
// TranslateStream is a helper on the default registry.
|
||||
func TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
func TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
return defaultRegistry.TranslateStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
}
|
||||
|
||||
// TranslateNonStream is a helper on the default registry.
|
||||
func TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
|
||||
func TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||
return defaultRegistry.TranslateNonStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
|
||||
}
|
||||
|
||||
// TranslateTokenCount is a helper on the default registry.
|
||||
func TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) string {
|
||||
func TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) []byte {
|
||||
return defaultRegistry.TranslateTokenCount(ctx, from, to, count, rawJSON)
|
||||
}
|
||||
|
||||
52
sdk/translator/registry_bytes_test.go
Normal file
52
sdk/translator/registry_bytes_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package translator
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRegistryTranslateStreamReturnsByteChunks(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
registry.Register(FormatOpenAI, FormatGemini, nil, ResponseTransform{
|
||||
Stream: func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
|
||||
return [][]byte{append([]byte(nil), rawJSON...)}
|
||||
},
|
||||
})
|
||||
|
||||
got := registry.TranslateStream(context.Background(), FormatGemini, FormatOpenAI, "model", nil, nil, []byte(`{"chunk":true}`), nil)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(got))
|
||||
}
|
||||
if !bytes.Equal(got[0], []byte(`{"chunk":true}`)) {
|
||||
t.Fatalf("unexpected chunk: %s", got[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryTranslateNonStreamReturnsBytes(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
registry.Register(FormatOpenAI, FormatGemini, nil, ResponseTransform{
|
||||
NonStream: func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
|
||||
return append([]byte(nil), rawJSON...)
|
||||
},
|
||||
})
|
||||
|
||||
got := registry.TranslateNonStream(context.Background(), FormatGemini, FormatOpenAI, "model", nil, nil, []byte(`{"done":true}`), nil)
|
||||
if !bytes.Equal(got, []byte(`{"done":true}`)) {
|
||||
t.Fatalf("unexpected payload: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryTranslateTokenCountReturnsBytes(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
registry.Register(FormatOpenAI, FormatGemini, nil, ResponseTransform{
|
||||
TokenCount: func(ctx context.Context, count int64) []byte {
|
||||
return []byte(`{"totalTokens":7}`)
|
||||
},
|
||||
})
|
||||
|
||||
got := registry.TranslateTokenCount(context.Background(), FormatGemini, FormatOpenAI, 7, []byte(`{"fallback":true}`))
|
||||
if !bytes.Equal(got, []byte(`{"totalTokens":7}`)) {
|
||||
t.Fatalf("unexpected payload: %s", got)
|
||||
}
|
||||
}
|
||||
@@ -10,17 +10,17 @@ type RequestTransform func(model string, rawJSON []byte, stream bool) []byte
|
||||
|
||||
// ResponseStreamTransform is a function type that converts a streaming response from a source schema to a target schema.
|
||||
// It takes a context, the model name, the raw JSON of the original and converted requests, the raw JSON of the current response chunk, and an optional parameter.
|
||||
// It returns a slice of strings, where each string is a chunk of the converted streaming response.
|
||||
type ResponseStreamTransform func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string
|
||||
// It returns a slice of byte chunks containing the converted streaming response.
|
||||
type ResponseStreamTransform func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte
|
||||
|
||||
// ResponseNonStreamTransform is a function type that converts a non-streaming response from a source schema to a target schema.
|
||||
// It takes a context, the model name, the raw JSON of the original and converted requests, the raw JSON of the response, and an optional parameter.
|
||||
// It returns the converted response as a single string.
|
||||
type ResponseNonStreamTransform func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string
|
||||
// It returns the converted response as a single byte slice.
|
||||
type ResponseNonStreamTransform func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte
|
||||
|
||||
// ResponseTokenCountTransform is a function type that transforms a token count from a source format to a target format.
|
||||
// It takes a context and the token count as an int64, and returns the transformed token count as a string.
|
||||
type ResponseTokenCountTransform func(ctx context.Context, count int64) string
|
||||
// It takes a context and the token count as an int64, and returns the transformed token count as bytes.
|
||||
type ResponseTokenCountTransform func(ctx context.Context, count int64) []byte
|
||||
|
||||
// ResponseTransform is a struct that groups together the functions for transforming streaming and non-streaming responses,
|
||||
// as well as token counts.
|
||||
|
||||
Reference in New Issue
Block a user