fix(openai): add session reference counter and cache lifecycle management for websocket tools
This commit is contained in:
@@ -55,6 +55,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
passthroughSessionID := uuid.NewString()
|
passthroughSessionID := uuid.NewString()
|
||||||
downstreamSessionKey := websocketDownstreamSessionKey(c.Request)
|
downstreamSessionKey := websocketDownstreamSessionKey(c.Request)
|
||||||
|
retainResponsesWebsocketToolCaches(downstreamSessionKey)
|
||||||
clientRemoteAddr := ""
|
clientRemoteAddr := ""
|
||||||
if c != nil && c.Request != nil {
|
if c != nil && c.Request != nil {
|
||||||
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
|
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
|
||||||
@@ -63,6 +64,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
var wsTerminateErr error
|
var wsTerminateErr error
|
||||||
var wsBodyLog strings.Builder
|
var wsBodyLog strings.Builder
|
||||||
defer func() {
|
defer func() {
|
||||||
|
releaseResponsesWebsocketToolCaches(downstreamSessionKey)
|
||||||
if wsTerminateErr != nil {
|
if wsTerminateErr != nil {
|
||||||
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
|
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -16,8 +16,9 @@ const (
|
|||||||
websocketToolOutputCacheTTL = 30 * time.Minute
|
websocketToolOutputCacheTTL = 30 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
var defaultWebsocketToolOutputCache = newWebsocketToolOutputCache(websocketToolOutputCacheTTL, websocketToolOutputCacheMaxPerSession)
|
var defaultWebsocketToolOutputCache = newWebsocketToolOutputCache(0, websocketToolOutputCacheMaxPerSession)
|
||||||
var defaultWebsocketToolCallCache = newWebsocketToolOutputCache(websocketToolOutputCacheTTL, websocketToolOutputCacheMaxPerSession)
|
var defaultWebsocketToolCallCache = newWebsocketToolOutputCache(0, websocketToolOutputCacheMaxPerSession)
|
||||||
|
var defaultWebsocketToolSessionRefs = newWebsocketToolSessionRefCounter()
|
||||||
|
|
||||||
type websocketToolOutputCache struct {
|
type websocketToolOutputCache struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -33,7 +34,7 @@ type websocketToolOutputSession struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newWebsocketToolOutputCache(ttl time.Duration, maxPerSession int) *websocketToolOutputCache {
|
func newWebsocketToolOutputCache(ttl time.Duration, maxPerSession int) *websocketToolOutputCache {
|
||||||
if ttl <= 0 {
|
if ttl < 0 {
|
||||||
ttl = websocketToolOutputCacheTTL
|
ttl = websocketToolOutputCacheTTL
|
||||||
}
|
}
|
||||||
if maxPerSession <= 0 {
|
if maxPerSession <= 0 {
|
||||||
@@ -122,13 +123,22 @@ func (c *websocketToolOutputCache) cleanupLocked(now time.Time) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *websocketToolOutputCache) deleteSession(sessionKey string) {
|
||||||
|
sessionKey = strings.TrimSpace(sessionKey)
|
||||||
|
if sessionKey == "" || c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
delete(c.sessions, sessionKey)
|
||||||
|
}
|
||||||
|
|
||||||
func websocketDownstreamSessionKey(req *http.Request) string {
|
func websocketDownstreamSessionKey(req *http.Request) string {
|
||||||
if req == nil {
|
if req == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
if sessionID := strings.TrimSpace(req.Header.Get("Session_id")); sessionID != "" {
|
|
||||||
return sessionID
|
|
||||||
}
|
|
||||||
if requestID := strings.TrimSpace(req.Header.Get("X-Client-Request-Id")); requestID != "" {
|
if requestID := strings.TrimSpace(req.Header.Get("X-Client-Request-Id")); requestID != "" {
|
||||||
return requestID
|
return requestID
|
||||||
}
|
}
|
||||||
@@ -137,9 +147,74 @@ func websocketDownstreamSessionKey(req *http.Request) string {
|
|||||||
return sessionID
|
return sessionID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if sessionID := strings.TrimSpace(req.Header.Get("Session_id")); sessionID != "" {
|
||||||
|
return sessionID
|
||||||
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type websocketToolSessionRefCounter struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
counts map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWebsocketToolSessionRefCounter() *websocketToolSessionRefCounter {
|
||||||
|
return &websocketToolSessionRefCounter{counts: make(map[string]int)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketToolSessionRefCounter) acquire(sessionKey string) {
|
||||||
|
sessionKey = strings.TrimSpace(sessionKey)
|
||||||
|
if sessionKey == "" || c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
c.counts[sessionKey]++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketToolSessionRefCounter) release(sessionKey string) bool {
|
||||||
|
sessionKey = strings.TrimSpace(sessionKey)
|
||||||
|
if sessionKey == "" || c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
count := c.counts[sessionKey]
|
||||||
|
if count <= 1 {
|
||||||
|
delete(c.counts, sessionKey)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.counts[sessionKey] = count - 1
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func retainResponsesWebsocketToolCaches(sessionKey string) {
|
||||||
|
if defaultWebsocketToolSessionRefs == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defaultWebsocketToolSessionRefs.acquire(sessionKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func releaseResponsesWebsocketToolCaches(sessionKey string) {
|
||||||
|
if defaultWebsocketToolSessionRefs == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !defaultWebsocketToolSessionRefs.release(sessionKey) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if defaultWebsocketToolOutputCache != nil {
|
||||||
|
defaultWebsocketToolOutputCache.deleteSession(sessionKey)
|
||||||
|
}
|
||||||
|
if defaultWebsocketToolCallCache != nil {
|
||||||
|
defaultWebsocketToolCallCache.deleteSession(sessionKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func repairResponsesWebsocketToolCalls(sessionKey string, payload []byte) []byte {
|
func repairResponsesWebsocketToolCalls(sessionKey string, payload []byte) []byte {
|
||||||
return repairResponsesWebsocketToolCallsWithCaches(defaultWebsocketToolOutputCache, defaultWebsocketToolCallCache, sessionKey, payload)
|
return repairResponsesWebsocketToolCallsWithCaches(defaultWebsocketToolOutputCache, defaultWebsocketToolCallCache, sessionKey, payload)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user