Merge remote-tracking branch 'origin/pr/3239'
This commit is contained in:
@@ -57,11 +57,72 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
|||||||
|
|
||||||
// Convert input array to messages
|
// Convert input array to messages
|
||||||
if input := root.Get("input"); input.Exists() && input.IsArray() {
|
if input := root.Get("input"); input.Exists() && input.IsArray() {
|
||||||
input.ForEach(func(_, item gjson.Result) bool {
|
inputItems := input.Array()
|
||||||
|
outputCallIDs := make(map[string]struct{})
|
||||||
|
for _, item := range inputItems {
|
||||||
|
if item.Get("type").String() != "function_call_output" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
callID := strings.TrimSpace(item.Get("call_id").String())
|
||||||
|
if callID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
outputCallIDs[callID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
pendingToolCalls := make([]interface{}, 0)
|
||||||
|
pendingToolCallIDs := make([]string, 0)
|
||||||
|
awaitingToolOutputs := make(map[string]struct{})
|
||||||
|
deferredMessages := make([][]byte, 0)
|
||||||
|
|
||||||
|
flushPendingToolCalls := func() {
|
||||||
|
if len(pendingToolCalls) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assistantMessage := []byte(`{"role":"assistant","tool_calls":[]}`)
|
||||||
|
assistantMessage, _ = sjson.SetBytes(assistantMessage, "tool_calls", pendingToolCalls)
|
||||||
|
out, _ = sjson.SetRawBytes(out, "messages.-1", assistantMessage)
|
||||||
|
for _, id := range pendingToolCallIDs {
|
||||||
|
if strings.TrimSpace(id) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
awaitingToolOutputs[id] = struct{}{}
|
||||||
|
}
|
||||||
|
pendingToolCalls = pendingToolCalls[:0]
|
||||||
|
pendingToolCallIDs = pendingToolCallIDs[:0]
|
||||||
|
}
|
||||||
|
flushDeferredMessages := func() {
|
||||||
|
for _, message := range deferredMessages {
|
||||||
|
out, _ = sjson.SetRawBytes(out, "messages.-1", message)
|
||||||
|
}
|
||||||
|
deferredMessages = deferredMessages[:0]
|
||||||
|
}
|
||||||
|
hasAwaitingToolOutput := func() bool {
|
||||||
|
for id := range awaitingToolOutputs {
|
||||||
|
if _, ok := outputCallIDs[id]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
appendRegularMessage := func(message []byte) {
|
||||||
|
// Keep tool-call adjacency strict for providers that require
|
||||||
|
// assistant(tool_calls) -> tool(tool_call_id) with no message in between.
|
||||||
|
if hasAwaitingToolOutput() {
|
||||||
|
deferredMessages = append(deferredMessages, message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out, _ = sjson.SetRawBytes(out, "messages.-1", message)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range inputItems {
|
||||||
itemType := item.Get("type").String()
|
itemType := item.Get("type").String()
|
||||||
if itemType == "" && item.Get("role").String() != "" {
|
if itemType == "" && item.Get("role").String() != "" {
|
||||||
itemType = "message"
|
itemType = "message"
|
||||||
}
|
}
|
||||||
|
if itemType != "function_call" {
|
||||||
|
flushPendingToolCalls()
|
||||||
|
}
|
||||||
|
|
||||||
switch itemType {
|
switch itemType {
|
||||||
case "message", "":
|
case "message", "":
|
||||||
@@ -109,12 +170,10 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
|||||||
message, _ = sjson.SetBytes(message, "content", content.String())
|
message, _ = sjson.SetBytes(message, "content", content.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
out, _ = sjson.SetRawBytes(out, "messages.-1", message)
|
appendRegularMessage(message)
|
||||||
|
|
||||||
case "function_call":
|
case "function_call":
|
||||||
// Handle function call conversion to assistant message with tool_calls
|
// Buffer consecutive function calls and emit them as one assistant message.
|
||||||
assistantMessage := []byte(`{"role":"assistant","tool_calls":[]}`)
|
|
||||||
|
|
||||||
toolCall := []byte(`{"id":"","type":"function","function":{"name":"","arguments":""}}`)
|
toolCall := []byte(`{"id":"","type":"function","function":{"name":"","arguments":""}}`)
|
||||||
|
|
||||||
if callId := item.Get("call_id"); callId.Exists() {
|
if callId := item.Get("call_id"); callId.Exists() {
|
||||||
@@ -128,16 +187,19 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
|||||||
if arguments := item.Get("arguments"); arguments.Exists() {
|
if arguments := item.Get("arguments"); arguments.Exists() {
|
||||||
toolCall, _ = sjson.SetBytes(toolCall, "function.arguments", arguments.String())
|
toolCall, _ = sjson.SetBytes(toolCall, "function.arguments", arguments.String())
|
||||||
}
|
}
|
||||||
|
pendingToolCalls = append(pendingToolCalls, gjson.ParseBytes(toolCall).Value())
|
||||||
assistantMessage, _ = sjson.SetRawBytes(assistantMessage, "tool_calls.0", toolCall)
|
if callID := strings.TrimSpace(item.Get("call_id").String()); callID != "" {
|
||||||
out, _ = sjson.SetRawBytes(out, "messages.-1", assistantMessage)
|
pendingToolCallIDs = append(pendingToolCallIDs, callID)
|
||||||
|
}
|
||||||
|
|
||||||
case "function_call_output":
|
case "function_call_output":
|
||||||
// Handle function call output conversion to tool message
|
// Handle function call output conversion to tool message
|
||||||
toolMessage := []byte(`{"role":"tool","tool_call_id":"","content":""}`)
|
toolMessage := []byte(`{"role":"tool","tool_call_id":"","content":""}`)
|
||||||
|
callID := ""
|
||||||
|
|
||||||
if callId := item.Get("call_id"); callId.Exists() {
|
if callId := item.Get("call_id"); callId.Exists() {
|
||||||
toolMessage, _ = sjson.SetBytes(toolMessage, "tool_call_id", callId.String())
|
callID = strings.TrimSpace(callId.String())
|
||||||
|
toolMessage, _ = sjson.SetBytes(toolMessage, "tool_call_id", callID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if output := item.Get("output"); output.Exists() {
|
if output := item.Get("output"); output.Exists() {
|
||||||
@@ -145,10 +207,17 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
|||||||
}
|
}
|
||||||
|
|
||||||
out, _ = sjson.SetRawBytes(out, "messages.-1", toolMessage)
|
out, _ = sjson.SetRawBytes(out, "messages.-1", toolMessage)
|
||||||
|
if callID != "" {
|
||||||
|
delete(awaitingToolOutputs, callID)
|
||||||
|
}
|
||||||
|
if len(awaitingToolOutputs) == 0 && len(deferredMessages) > 0 {
|
||||||
|
flushDeferredMessages()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
}
|
||||||
})
|
flushPendingToolCalls()
|
||||||
|
flushDeferredMessages()
|
||||||
} else if input.Type == gjson.String {
|
} else if input.Type == gjson.String {
|
||||||
msg := []byte(`{}`)
|
msg := []byte(`{}`)
|
||||||
msg, _ = sjson.SetBytes(msg, "role", "user")
|
msg, _ = sjson.SetBytes(msg, "role", "user")
|
||||||
|
|||||||
@@ -0,0 +1,124 @@
|
|||||||
|
package responses
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func prettyJSONForTest(raw []byte) string {
|
||||||
|
if !gjson.ValidBytes(raw) {
|
||||||
|
return string(raw)
|
||||||
|
}
|
||||||
|
var out bytes.Buffer
|
||||||
|
if err := json.Indent(&out, raw, "", " "); err != nil {
|
||||||
|
return string(raw)
|
||||||
|
}
|
||||||
|
return out.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertOpenAIResponsesRequestToOpenAIChatCompletions_MergeConsecutiveFunctionCalls(t *testing.T) {
|
||||||
|
raw := []byte(`{
|
||||||
|
"input": [
|
||||||
|
{"type":"function_call","call_id":"exec_command:0","name":"exec_command","arguments":"{\"cmd\":\"ls\"}"},
|
||||||
|
{"type":"function_call","call_id":"exec_command:1","name":"exec_command","arguments":"{\"cmd\":\"pwd\"}"},
|
||||||
|
{"type":"function_call_output","call_id":"exec_command:0","output":"ok0"},
|
||||||
|
{"type":"function_call_output","call_id":"exec_command:1","output":"ok1"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
t.Logf("input json:\n%s", prettyJSONForTest(raw))
|
||||||
|
|
||||||
|
out := ConvertOpenAIResponsesRequestToOpenAIChatCompletions("kimi-k2.6", raw, true)
|
||||||
|
t.Logf("output json:\n%s", prettyJSONForTest(out))
|
||||||
|
|
||||||
|
msgs := gjson.GetBytes(out, "messages")
|
||||||
|
if !msgs.Exists() || !msgs.IsArray() {
|
||||||
|
t.Fatalf("messages should be an array")
|
||||||
|
}
|
||||||
|
if got := len(msgs.Array()); got != 3 {
|
||||||
|
t.Fatalf("messages count = %d, want %d", got, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.role").String(); got != "assistant" {
|
||||||
|
t.Fatalf("messages.0.role = %q, want %q", got, "assistant")
|
||||||
|
}
|
||||||
|
if got := len(gjson.GetBytes(out, "messages.0.tool_calls").Array()); got != 2 {
|
||||||
|
t.Fatalf("messages.0.tool_calls length = %d, want %d", got, 2)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.tool_calls.0.id").String(); got != "exec_command:0" {
|
||||||
|
t.Fatalf("messages.0.tool_calls.0.id = %q, want %q", got, "exec_command:0")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.tool_calls.1.id").String(); got != "exec_command:1" {
|
||||||
|
t.Fatalf("messages.0.tool_calls.1.id = %q, want %q", got, "exec_command:1")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "messages.1.tool_call_id").String(); got != "exec_command:0" {
|
||||||
|
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "exec_command:0")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.2.tool_call_id").String(); got != "exec_command:1" {
|
||||||
|
t.Fatalf("messages.2.tool_call_id = %q, want %q", got, "exec_command:1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertOpenAIResponsesRequestToOpenAIChatCompletions_SplitFunctionCallsWhenInterrupted(t *testing.T) {
|
||||||
|
raw := []byte(`{
|
||||||
|
"input": [
|
||||||
|
{"type":"function_call","call_id":"call_a","name":"tool_a","arguments":"{}"},
|
||||||
|
{"type":"message","role":"user","content":"next"},
|
||||||
|
{"type":"function_call","call_id":"call_b","name":"tool_b","arguments":"{}"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
t.Logf("input json:\n%s", prettyJSONForTest(raw))
|
||||||
|
|
||||||
|
out := ConvertOpenAIResponsesRequestToOpenAIChatCompletions("kimi-k2.6", raw, false)
|
||||||
|
t.Logf("output json:\n%s", prettyJSONForTest(out))
|
||||||
|
|
||||||
|
if got := len(gjson.GetBytes(out, "messages").Array()); got != 3 {
|
||||||
|
t.Fatalf("messages count = %d, want %d", got, 3)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.tool_calls.0.id").String(); got != "call_a" {
|
||||||
|
t.Fatalf("messages.0.tool_calls.0.id = %q, want %q", got, "call_a")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.2.tool_calls.0.id").String(); got != "call_b" {
|
||||||
|
t.Fatalf("messages.2.tool_calls.0.id = %q, want %q", got, "call_b")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertOpenAIResponsesRequestToOpenAIChatCompletions_DefersMessageUntilToolOutput(t *testing.T) {
|
||||||
|
raw := []byte(`{
|
||||||
|
"input": [
|
||||||
|
{"type":"function_call","call_id":"call_x","name":"exec_command","arguments":"{\"cmd\":\"echo hi\"}"},
|
||||||
|
{"type":"message","role":"user","content":"Approved command prefix saved"},
|
||||||
|
{"type":"function_call_output","call_id":"call_x","output":"ok"},
|
||||||
|
{"type":"message","role":"user","content":"next"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
t.Logf("input json:\n%s", prettyJSONForTest(raw))
|
||||||
|
|
||||||
|
out := ConvertOpenAIResponsesRequestToOpenAIChatCompletions("kimi-k2.6", raw, true)
|
||||||
|
t.Logf("output json:\n%s", prettyJSONForTest(out))
|
||||||
|
|
||||||
|
if got := len(gjson.GetBytes(out, "messages").Array()); got != 4 {
|
||||||
|
t.Fatalf("messages count = %d, want %d", got, 4)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.0.role").String(); got != "assistant" {
|
||||||
|
t.Fatalf("messages.0.role = %q, want %q", got, "assistant")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.1.role").String(); got != "tool" {
|
||||||
|
t.Fatalf("messages.1.role = %q, want %q", got, "tool")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.1.tool_call_id").String(); got != "call_x" {
|
||||||
|
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_x")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.2.role").String(); got != "user" {
|
||||||
|
t.Fatalf("messages.2.role = %q, want %q", got, "user")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.2.content").String(); got != "Approved command prefix saved" {
|
||||||
|
t.Fatalf("messages.2.content = %q, want %q", got, "Approved command prefix saved")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.3.content").String(); got != "next" {
|
||||||
|
t.Fatalf("messages.3.content = %q, want %q", got, "next")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user