refactor: replace sjson.Set usage with sjson.SetBytes to optimize mutable JSON transformations

This commit is contained in:
Luis Pater
2026-03-19 17:58:54 +08:00
parent 56073ded69
commit 2bd646ad70
73 changed files with 3008 additions and 2944 deletions

View File

@@ -191,58 +191,58 @@ func convertCompletionsRequestToChatCompletions(rawJSON []byte) []byte {
}
// Create chat completions structure
out := `{"model":"","messages":[{"role":"user","content":""}]}`
out := []byte(`{"model":"","messages":[{"role":"user","content":""}]}`)
// Set model
if model := root.Get("model"); model.Exists() {
out, _ = sjson.Set(out, "model", model.String())
out, _ = sjson.SetBytes(out, "model", model.String())
}
// Set the prompt as user message content
out, _ = sjson.Set(out, "messages.0.content", prompt)
out, _ = sjson.SetBytes(out, "messages.0.content", prompt)
// Copy other parameters from completions to chat completions
if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
out, _ = sjson.SetBytes(out, "max_tokens", maxTokens.Int())
}
if temperature := root.Get("temperature"); temperature.Exists() {
out, _ = sjson.Set(out, "temperature", temperature.Float())
out, _ = sjson.SetBytes(out, "temperature", temperature.Float())
}
if topP := root.Get("top_p"); topP.Exists() {
out, _ = sjson.Set(out, "top_p", topP.Float())
out, _ = sjson.SetBytes(out, "top_p", topP.Float())
}
if frequencyPenalty := root.Get("frequency_penalty"); frequencyPenalty.Exists() {
out, _ = sjson.Set(out, "frequency_penalty", frequencyPenalty.Float())
out, _ = sjson.SetBytes(out, "frequency_penalty", frequencyPenalty.Float())
}
if presencePenalty := root.Get("presence_penalty"); presencePenalty.Exists() {
out, _ = sjson.Set(out, "presence_penalty", presencePenalty.Float())
out, _ = sjson.SetBytes(out, "presence_penalty", presencePenalty.Float())
}
if stop := root.Get("stop"); stop.Exists() {
out, _ = sjson.SetRaw(out, "stop", stop.Raw)
out, _ = sjson.SetRawBytes(out, "stop", []byte(stop.Raw))
}
if stream := root.Get("stream"); stream.Exists() {
out, _ = sjson.Set(out, "stream", stream.Bool())
out, _ = sjson.SetBytes(out, "stream", stream.Bool())
}
if logprobs := root.Get("logprobs"); logprobs.Exists() {
out, _ = sjson.Set(out, "logprobs", logprobs.Bool())
out, _ = sjson.SetBytes(out, "logprobs", logprobs.Bool())
}
if topLogprobs := root.Get("top_logprobs"); topLogprobs.Exists() {
out, _ = sjson.Set(out, "top_logprobs", topLogprobs.Int())
out, _ = sjson.SetBytes(out, "top_logprobs", topLogprobs.Int())
}
if echo := root.Get("echo"); echo.Exists() {
out, _ = sjson.Set(out, "echo", echo.Bool())
out, _ = sjson.SetBytes(out, "echo", echo.Bool())
}
return []byte(out)
return out
}
// convertChatCompletionsResponseToCompletions converts chat completions API response back to completions format.
@@ -257,23 +257,23 @@ func convertChatCompletionsResponseToCompletions(rawJSON []byte) []byte {
root := gjson.ParseBytes(rawJSON)
// Base completions response structure
out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`
out := []byte(`{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`)
// Copy basic fields
if id := root.Get("id"); id.Exists() {
out, _ = sjson.Set(out, "id", id.String())
out, _ = sjson.SetBytes(out, "id", id.String())
}
if created := root.Get("created"); created.Exists() {
out, _ = sjson.Set(out, "created", created.Int())
out, _ = sjson.SetBytes(out, "created", created.Int())
}
if model := root.Get("model"); model.Exists() {
out, _ = sjson.Set(out, "model", model.String())
out, _ = sjson.SetBytes(out, "model", model.String())
}
if usage := root.Get("usage"); usage.Exists() {
out, _ = sjson.SetRaw(out, "usage", usage.Raw)
out, _ = sjson.SetRawBytes(out, "usage", []byte(usage.Raw))
}
// Convert choices from chat completions to completions format
@@ -313,10 +313,10 @@ func convertChatCompletionsResponseToCompletions(rawJSON []byte) []byte {
if len(choices) > 0 {
choicesJSON, _ := json.Marshal(choices)
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
out, _ = sjson.SetRawBytes(out, "choices", choicesJSON)
}
return []byte(out)
return out
}
// convertChatCompletionsStreamChunkToCompletions converts a streaming chat completions chunk to completions format.
@@ -357,19 +357,19 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
}
// Base completions stream response structure
out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`
out := []byte(`{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`)
// Copy basic fields
if id := root.Get("id"); id.Exists() {
out, _ = sjson.Set(out, "id", id.String())
out, _ = sjson.SetBytes(out, "id", id.String())
}
if created := root.Get("created"); created.Exists() {
out, _ = sjson.Set(out, "created", created.Int())
out, _ = sjson.SetBytes(out, "created", created.Int())
}
if model := root.Get("model"); model.Exists() {
out, _ = sjson.Set(out, "model", model.String())
out, _ = sjson.SetBytes(out, "model", model.String())
}
// Convert choices from chat completions delta to completions format
@@ -408,15 +408,15 @@ func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
if len(choices) > 0 {
choicesJSON, _ := json.Marshal(choices)
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
out, _ = sjson.SetRawBytes(out, "choices", choicesJSON)
}
// Copy usage if present
if usage := root.Get("usage"); usage.Exists() {
out, _ = sjson.SetRaw(out, "usage", usage.Raw)
out, _ = sjson.SetRawBytes(out, "usage", []byte(usage.Raw))
}
return []byte(out)
return out
}
// handleNonStreamingResponse handles non-streaming chat completion responses

View File

@@ -13,16 +13,16 @@ func HasResponseTransformerByFormatName(from, to Format) bool {
}
// TranslateStreamByFormatName converts streaming responses between schemas by their string identifiers.
func TranslateStreamByFormatName(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
func TranslateStreamByFormatName(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
return TranslateStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
// TranslateNonStreamByFormatName converts non-streaming responses between schemas by their string identifiers.
func TranslateNonStreamByFormatName(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
func TranslateNonStreamByFormatName(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
return TranslateNonStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
// TranslateTokenCountByFormatName converts token counts between schemas by their string identifiers.
func TranslateTokenCountByFormatName(ctx context.Context, from, to Format, count int64, rawJSON []byte) string {
func TranslateTokenCountByFormatName(ctx context.Context, from, to Format, count int64, rawJSON []byte) []byte {
return TranslateTokenCount(ctx, from, to, count, rawJSON)
}

View File

@@ -16,7 +16,7 @@ type ResponseEnvelope struct {
Model string
Stream bool
Body []byte
Chunks []string
Chunks [][]byte
}
// RequestMiddleware decorates request translation.
@@ -87,7 +87,7 @@ func (p *Pipeline) TranslateResponse(ctx context.Context, from, to Format, resp
if input.Stream {
input.Chunks = p.registry.TranslateStream(ctx, from, to, input.Model, originalReq, translatedReq, input.Body, param)
} else {
input.Body = []byte(p.registry.TranslateNonStream(ctx, from, to, input.Model, originalReq, translatedReq, input.Body, param))
input.Body = p.registry.TranslateNonStream(ctx, from, to, input.Model, originalReq, translatedReq, input.Body, param)
}
input.Format = to
return input, nil

View File

@@ -66,7 +66,7 @@ func (r *Registry) HasResponseTransformer(from, to Format) bool {
}
// TranslateStream applies the registered streaming response translator.
func (r *Registry) TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
func (r *Registry) TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
r.mu.RLock()
defer r.mu.RUnlock()
@@ -75,11 +75,11 @@ func (r *Registry) TranslateStream(ctx context.Context, from, to Format, model s
return fn.Stream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
}
return []string{string(rawJSON)}
return [][]byte{rawJSON}
}
// TranslateNonStream applies the registered non-stream response translator.
func (r *Registry) TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
func (r *Registry) TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
r.mu.RLock()
defer r.mu.RUnlock()
@@ -88,11 +88,11 @@ func (r *Registry) TranslateNonStream(ctx context.Context, from, to Format, mode
return fn.NonStream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
}
return string(rawJSON)
return rawJSON
}
// TranslateNonStream applies the registered non-stream response translator.
func (r *Registry) TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) string {
// TranslateTokenCount applies the registered token count response translator.
func (r *Registry) TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) []byte {
r.mu.RLock()
defer r.mu.RUnlock()
@@ -101,7 +101,7 @@ func (r *Registry) TranslateTokenCount(ctx context.Context, from, to Format, cou
return fn.TokenCount(ctx, count)
}
}
return string(rawJSON)
return rawJSON
}
var defaultRegistry = NewRegistry()
@@ -127,16 +127,16 @@ func HasResponseTransformer(from, to Format) bool {
}
// TranslateStream is a helper on the default registry.
func TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
func TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
return defaultRegistry.TranslateStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
// TranslateNonStream is a helper on the default registry.
func TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
func TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
return defaultRegistry.TranslateNonStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
// TranslateTokenCount is a helper on the default registry.
func TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) string {
func TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) []byte {
return defaultRegistry.TranslateTokenCount(ctx, from, to, count, rawJSON)
}

View File

@@ -0,0 +1,52 @@
package translator
import (
"bytes"
"context"
"testing"
)
func TestRegistryTranslateStreamReturnsByteChunks(t *testing.T) {
registry := NewRegistry()
registry.Register(FormatOpenAI, FormatGemini, nil, ResponseTransform{
Stream: func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
return [][]byte{append([]byte(nil), rawJSON...)}
},
})
got := registry.TranslateStream(context.Background(), FormatGemini, FormatOpenAI, "model", nil, nil, []byte(`{"chunk":true}`), nil)
if len(got) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(got))
}
if !bytes.Equal(got[0], []byte(`{"chunk":true}`)) {
t.Fatalf("unexpected chunk: %s", got[0])
}
}
func TestRegistryTranslateNonStreamReturnsBytes(t *testing.T) {
registry := NewRegistry()
registry.Register(FormatOpenAI, FormatGemini, nil, ResponseTransform{
NonStream: func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte {
return append([]byte(nil), rawJSON...)
},
})
got := registry.TranslateNonStream(context.Background(), FormatGemini, FormatOpenAI, "model", nil, nil, []byte(`{"done":true}`), nil)
if !bytes.Equal(got, []byte(`{"done":true}`)) {
t.Fatalf("unexpected payload: %s", got)
}
}
func TestRegistryTranslateTokenCountReturnsBytes(t *testing.T) {
registry := NewRegistry()
registry.Register(FormatOpenAI, FormatGemini, nil, ResponseTransform{
TokenCount: func(ctx context.Context, count int64) []byte {
return []byte(`{"totalTokens":7}`)
},
})
got := registry.TranslateTokenCount(context.Background(), FormatGemini, FormatOpenAI, 7, []byte(`{"fallback":true}`))
if !bytes.Equal(got, []byte(`{"totalTokens":7}`)) {
t.Fatalf("unexpected payload: %s", got)
}
}

View File

@@ -10,17 +10,17 @@ type RequestTransform func(model string, rawJSON []byte, stream bool) []byte
// ResponseStreamTransform is a function type that converts a streaming response from a source schema to a target schema.
// It takes a context, the model name, the raw JSON of the original and converted requests, the raw JSON of the current response chunk, and an optional parameter.
// It returns a slice of strings, where each string is a chunk of the converted streaming response.
type ResponseStreamTransform func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string
// It returns a slice of byte chunks containing the converted streaming response.
type ResponseStreamTransform func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte
// ResponseNonStreamTransform is a function type that converts a non-streaming response from a source schema to a target schema.
// It takes a context, the model name, the raw JSON of the original and converted requests, the raw JSON of the response, and an optional parameter.
// It returns the converted response as a single string.
type ResponseNonStreamTransform func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string
// It returns the converted response as a single byte slice.
type ResponseNonStreamTransform func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []byte
// ResponseTokenCountTransform is a function type that transforms a token count from a source format to a target format.
// It takes a context and the token count as an int64, and returns the transformed token count as a string.
type ResponseTokenCountTransform func(ctx context.Context, count int64) string
// It takes a context and the token count as an int64, and returns the transformed token count as bytes.
type ResponseTokenCountTransform func(ctx context.Context, count int64) []byte
// ResponseTransform is a struct that groups together the functions for transforming streaming and non-streaming responses,
// as well as token counts.