package xai import ( "context" "encoding/json" "net/http" "net/http/httptest" "net/url" "strings" "testing" ) func TestBuildAuthorizeURLIncludesXAIRequiredParameters(t *testing.T) { authURL, err := BuildAuthorizeURL(AuthorizeURLParams{ AuthorizationEndpoint: "https://auth.x.ai/oauth/authorize", RedirectURI: "http://127.0.0.1:56121/callback", CodeChallenge: "challenge", State: "state-123", Nonce: "nonce-123", }) if err != nil { t.Fatalf("BuildAuthorizeURL() error = %v", err) } parsed, errParse := url.Parse(authURL) if errParse != nil { t.Fatalf("parse authorize URL: %v", errParse) } if parsed.Scheme != "https" || parsed.Host != "auth.x.ai" || parsed.Path != "/oauth/authorize" { t.Fatalf("authorize URL endpoint = %s://%s%s", parsed.Scheme, parsed.Host, parsed.Path) } query := parsed.Query() want := map[string]string{ "response_type": "code", "client_id": ClientID, "redirect_uri": "http://127.0.0.1:56121/callback", "scope": Scope, "code_challenge": "challenge", "code_challenge_method": "S256", "state": "state-123", "nonce": "nonce-123", "plan": "generic", "referrer": "cli-proxy-api", } for key, value := range want { if got := query.Get(key); got != value { t.Fatalf("%s = %q, want %q", key, got, value) } } } func TestValidateOAuthEndpointRejectsNonXAIOrigin(t *testing.T) { if _, err := ValidateOAuthEndpoint("https://auth.x.ai/oauth/token", "token_endpoint"); err != nil { t.Fatalf("ValidateOAuthEndpoint(xai) error = %v", err) } if _, err := ValidateOAuthEndpoint("http://auth.x.ai/oauth/token", "token_endpoint"); err == nil { t.Fatal("expected non-HTTPS endpoint to be rejected") } if _, err := ValidateOAuthEndpoint("https://evil.example/oauth/token", "token_endpoint"); err == nil { t.Fatal("expected non-xAI endpoint to be rejected") } } func TestRefreshTokensPostsClientIDAndRefreshToken(t *testing.T) { var gotForm url.Values server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { t.Fatalf("method = %s, want POST", r.Method) } if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/x-www-form-urlencoded") { t.Fatalf("Content-Type = %q, want form", got) } if err := r.ParseForm(); err != nil { t.Fatalf("ParseForm() error = %v", err) } gotForm = r.PostForm w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]any{ "access_token": "new-access", "refresh_token": "new-refresh", "token_type": "Bearer", "expires_in": 3600, }) })) defer server.Close() auth := NewXAIAuth(nil) tokenData, err := auth.RefreshTokens(context.Background(), "old-refresh", server.URL) if err != nil { t.Fatalf("RefreshTokens() error = %v", err) } if tokenData.AccessToken != "new-access" { t.Fatalf("access token = %q, want new-access", tokenData.AccessToken) } if gotForm.Get("grant_type") != "refresh_token" { t.Fatalf("grant_type = %q, want refresh_token", gotForm.Get("grant_type")) } if gotForm.Get("client_id") != ClientID { t.Fatalf("client_id = %q, want %q", gotForm.Get("client_id"), ClientID) } if gotForm.Get("refresh_token") != "old-refresh" { t.Fatalf("refresh_token = %q, want old-refresh", gotForm.Get("refresh_token")) } }