feat(executor): enhance Qwen system message handling with strict injection and merging rules
Closes: #2537
This commit is contained in:
@@ -172,32 +172,101 @@ func timeUntilNextDay() time.Duration {
|
|||||||
return tomorrow.Sub(now)
|
return tomorrow.Sub(now)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensureQwenSystemMessage prepends a default system message if none exists in "messages".
|
// ensureQwenSystemMessage ensures the request has a single system message at the beginning.
|
||||||
|
// It always injects the default system prompt and merges any user-provided system messages
|
||||||
|
// into the injected system message content to satisfy Qwen's strict message ordering rules.
|
||||||
func ensureQwenSystemMessage(payload []byte) ([]byte, error) {
|
func ensureQwenSystemMessage(payload []byte) ([]byte, error) {
|
||||||
messages := gjson.GetBytes(payload, "messages")
|
isInjectedSystemPart := func(part gjson.Result) bool {
|
||||||
if messages.Exists() && messages.IsArray() {
|
if !part.Exists() || !part.IsObject() {
|
||||||
var buf bytes.Buffer
|
return false
|
||||||
buf.WriteByte('[')
|
|
||||||
buf.Write(qwenDefaultSystemMessage)
|
|
||||||
for _, msg := range messages.Array() {
|
|
||||||
buf.WriteByte(',')
|
|
||||||
buf.WriteString(msg.Raw)
|
|
||||||
}
|
}
|
||||||
buf.WriteByte(']')
|
if !strings.EqualFold(part.Get("type").String(), "text") {
|
||||||
updated, errSet := sjson.SetRawBytes(payload, "messages", buf.Bytes())
|
return false
|
||||||
if errSet != nil {
|
|
||||||
return nil, fmt.Errorf("qwen executor: set default system message failed: %w", errSet)
|
|
||||||
}
|
}
|
||||||
return updated, nil
|
if !strings.EqualFold(part.Get("cache_control.type").String(), "ephemeral") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
text := part.Get("text").String()
|
||||||
|
return text == "" || text == "You are Qwen Code."
|
||||||
}
|
}
|
||||||
|
|
||||||
var buf bytes.Buffer
|
defaultParts := gjson.ParseBytes(qwenDefaultSystemMessage).Get("content")
|
||||||
buf.WriteByte('[')
|
var systemParts []any
|
||||||
buf.Write(qwenDefaultSystemMessage)
|
if defaultParts.Exists() && defaultParts.IsArray() {
|
||||||
buf.WriteByte(']')
|
for _, part := range defaultParts.Array() {
|
||||||
updated, errSet := sjson.SetRawBytes(payload, "messages", buf.Bytes())
|
systemParts = append(systemParts, part.Value())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(systemParts) == 0 {
|
||||||
|
systemParts = append(systemParts, map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "You are Qwen Code.",
|
||||||
|
"cache_control": map[string]any{
|
||||||
|
"type": "ephemeral",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
appendSystemContent := func(content gjson.Result) {
|
||||||
|
makeTextPart := func(text string) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": text,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !content.Exists() || content.Type == gjson.Null {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if content.IsArray() {
|
||||||
|
for _, part := range content.Array() {
|
||||||
|
if part.Type == gjson.String {
|
||||||
|
systemParts = append(systemParts, makeTextPart(part.String()))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isInjectedSystemPart(part) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
systemParts = append(systemParts, part.Value())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
systemParts = append(systemParts, makeTextPart(content.String()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if content.IsObject() {
|
||||||
|
if isInjectedSystemPart(content) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
systemParts = append(systemParts, content.Value())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
systemParts = append(systemParts, makeTextPart(content.String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
messages := gjson.GetBytes(payload, "messages")
|
||||||
|
var nonSystemMessages []any
|
||||||
|
if messages.Exists() && messages.IsArray() {
|
||||||
|
for _, msg := range messages.Array() {
|
||||||
|
if strings.EqualFold(msg.Get("role").String(), "system") {
|
||||||
|
appendSystemContent(msg.Get("content"))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nonSystemMessages = append(nonSystemMessages, msg.Value())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newMessages := make([]any, 0, 1+len(nonSystemMessages))
|
||||||
|
newMessages = append(newMessages, map[string]any{
|
||||||
|
"role": "system",
|
||||||
|
"content": systemParts,
|
||||||
|
})
|
||||||
|
newMessages = append(newMessages, nonSystemMessages...)
|
||||||
|
|
||||||
|
updated, errSet := sjson.SetBytes(payload, "messages", newMessages)
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
return nil, fmt.Errorf("qwen executor: set default system message failed: %w", errSet)
|
return nil, fmt.Errorf("qwen executor: set system message failed: %w", errSet)
|
||||||
}
|
}
|
||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestQwenExecutorParseSuffix(t *testing.T) {
|
func TestQwenExecutorParseSuffix(t *testing.T) {
|
||||||
@@ -28,3 +29,123 @@ func TestQwenExecutorParseSuffix(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEnsureQwenSystemMessage_MergeStringSystem(t *testing.T) {
|
||||||
|
payload := []byte(`{
|
||||||
|
"model": "qwen3.6-plus",
|
||||||
|
"stream": true,
|
||||||
|
"messages": [
|
||||||
|
{ "role": "system", "content": "ABCDEFG" },
|
||||||
|
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := ensureQwenSystemMessage(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := gjson.GetBytes(out, "messages").Array()
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||||
|
}
|
||||||
|
if msgs[0].Get("role").String() != "system" {
|
||||||
|
t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system")
|
||||||
|
}
|
||||||
|
parts := msgs[0].Get("content").Array()
|
||||||
|
if len(parts) != 2 {
|
||||||
|
t.Fatalf("messages[0].content length = %d, want 2", len(parts))
|
||||||
|
}
|
||||||
|
if parts[0].Get("text").String() != "You are Qwen Code." || parts[0].Get("cache_control.type").String() != "ephemeral" {
|
||||||
|
t.Fatalf("messages[0].content[0] = %s, want injected system part", parts[0].Raw)
|
||||||
|
}
|
||||||
|
if parts[1].Get("type").String() != "text" || parts[1].Get("text").String() != "ABCDEFG" {
|
||||||
|
t.Fatalf("messages[0].content[1] = %s, want text part with ABCDEFG", parts[1].Raw)
|
||||||
|
}
|
||||||
|
if msgs[1].Get("role").String() != "user" {
|
||||||
|
t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureQwenSystemMessage_MergeObjectSystem(t *testing.T) {
|
||||||
|
payload := []byte(`{
|
||||||
|
"messages": [
|
||||||
|
{ "role": "system", "content": { "type": "text", "text": "ABCDEFG" } },
|
||||||
|
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := ensureQwenSystemMessage(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := gjson.GetBytes(out, "messages").Array()
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||||
|
}
|
||||||
|
parts := msgs[0].Get("content").Array()
|
||||||
|
if len(parts) != 2 {
|
||||||
|
t.Fatalf("messages[0].content length = %d, want 2", len(parts))
|
||||||
|
}
|
||||||
|
if parts[1].Get("text").String() != "ABCDEFG" {
|
||||||
|
t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "ABCDEFG")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureQwenSystemMessage_PrependsWhenMissing(t *testing.T) {
|
||||||
|
payload := []byte(`{
|
||||||
|
"messages": [
|
||||||
|
{ "role": "user", "content": [ { "type": "text", "text": "你好" } ] }
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := ensureQwenSystemMessage(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := gjson.GetBytes(out, "messages").Array()
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||||
|
}
|
||||||
|
if msgs[0].Get("role").String() != "system" {
|
||||||
|
t.Fatalf("messages[0].role = %q, want %q", msgs[0].Get("role").String(), "system")
|
||||||
|
}
|
||||||
|
if !msgs[0].Get("content").IsArray() || len(msgs[0].Get("content").Array()) == 0 {
|
||||||
|
t.Fatalf("messages[0].content = %s, want non-empty array", msgs[0].Get("content").Raw)
|
||||||
|
}
|
||||||
|
if msgs[1].Get("role").String() != "user" {
|
||||||
|
t.Fatalf("messages[1].role = %q, want %q", msgs[1].Get("role").String(), "user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureQwenSystemMessage_MergesMultipleSystemMessages(t *testing.T) {
|
||||||
|
payload := []byte(`{
|
||||||
|
"messages": [
|
||||||
|
{ "role": "system", "content": "A" },
|
||||||
|
{ "role": "user", "content": [ { "type": "text", "text": "hi" } ] },
|
||||||
|
{ "role": "system", "content": "B" }
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := ensureQwenSystemMessage(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ensureQwenSystemMessage() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := gjson.GetBytes(out, "messages").Array()
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Fatalf("messages length = %d, want 2", len(msgs))
|
||||||
|
}
|
||||||
|
parts := msgs[0].Get("content").Array()
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Fatalf("messages[0].content length = %d, want 3", len(parts))
|
||||||
|
}
|
||||||
|
if parts[1].Get("text").String() != "A" {
|
||||||
|
t.Fatalf("messages[0].content[1].text = %q, want %q", parts[1].Get("text").String(), "A")
|
||||||
|
}
|
||||||
|
if parts[2].Get("text").String() != "B" {
|
||||||
|
t.Fatalf("messages[0].content[2].text = %q, want %q", parts[2].Get("text").String(), "B")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user