From 5e0be03d18964021ebe4e3d440b84e791fcc62dd Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 9 Mar 2026 11:44:39 +0100 Subject: [PATCH] feat: add OAuth 2.1 authentication for stdio mode Add PKCE and device flow OAuth support for stdio mode, enabling browser-based authentication as an alternative to PATs. Flow priority (security-ordered): 1. PKCE + browser auto-open (native) 2. PKCE + URL elicitation (Docker with bound port) 3. Device flow fallback (more phishable, last resort) Key changes: - internal/oauth: self-contained OAuth manager with PKCE and device flow - internal/buildinfo: build-time OAuth credential injection via ldflags - BearerAuthTransport: added TokenProvider for dynamic token resolution - OAuth middleware intercepts tools/call to trigger lazy authentication - Scope-based tool filtering using existing SupportedScopes - PAT remains optional when OAuth credentials are configured Security: - PKCE S256 prevents code interception - State parameter prevents CSRF - Callback binds to 127.0.0.1 only - URL elicitation for sensitive URLs (never exposed to LLM) - Tokens stored in memory only, never persisted to disk - ReadHeaderTimeout prevents Slowloris on callback server - html/template auto-escaping prevents XSS in callback pages Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/workflows/docker-publish.yml | 3 + .github/workflows/goreleaser.yml | 2 + .goreleaser.yaml | 2 +- Dockerfile | 9 +- README.md | 33 +++ cmd/github-mcp-server/main.go | 72 ++++- docs/oauth-authentication.md | 131 +++++++++ go.mod | 2 +- internal/buildinfo/buildinfo.go | 16 ++ internal/ghmcp/server.go | 107 ++++++- internal/oauth/manager.go | 364 ++++++++++++++++++++++++ internal/oauth/oauth.go | 256 +++++++++++++++++ internal/oauth/oauth_test.go | 265 +++++++++++++++++ internal/oauth/templates/error.html | 60 ++++ internal/oauth/templates/success.html | 56 ++++ pkg/http/transport/bearer.go | 12 +- third-party-licenses.darwin.md | 1 + third-party-licenses.linux.md | 1 + third-party-licenses.windows.md | 1 + third-party/golang.org/x/oauth2/LICENSE | 27 ++ 20 files changed, 1403 insertions(+), 17 deletions(-) create mode 100644 docs/oauth-authentication.md create mode 100644 internal/buildinfo/buildinfo.go create mode 100644 internal/oauth/manager.go create mode 100644 internal/oauth/oauth.go create mode 100644 internal/oauth/oauth_test.go create mode 100644 internal/oauth/templates/error.html create mode 100644 internal/oauth/templates/success.html create mode 100644 third-party/golang.org/x/oauth2/LICENSE diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index f03d08121..1e62ea9a5 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -117,6 +117,9 @@ jobs: platforms: linux/amd64,linux/arm64 build-args: | VERSION=${{ github.ref_name }} + secrets: | + oauth_client_id=${{ secrets.OAUTH_CLIENT_ID }} + oauth_client_secret=${{ secrets.OAUTH_CLIENT_SECRET }} # Sign the resulting Docker image digest except on PRs. # This will only write to the public Rekor transparency log when the Docker diff --git a/.github/workflows/goreleaser.yml b/.github/workflows/goreleaser.yml index f8eddc076..d672c7d65 100644 --- a/.github/workflows/goreleaser.yml +++ b/.github/workflows/goreleaser.yml @@ -45,6 +45,8 @@ jobs: workdir: . env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + OAUTH_CLIENT_ID: ${{ secrets.OAUTH_CLIENT_ID }} + OAUTH_CLIENT_SECRET: ${{ secrets.OAUTH_CLIENT_SECRET }} - name: Generate signed build provenance attestations for workflow artifacts uses: actions/attest-build-provenance@v3 diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 54f6b9f40..36dfc47bc 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -9,7 +9,7 @@ builds: - env: - CGO_ENABLED=0 ldflags: - - -s -w -X main.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.Date}} + - -s -w -X main.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.Date}} -X github.com/github/github-mcp-server/internal/buildinfo.OAuthClientID={{ .Env.OAUTH_CLIENT_ID }} -X github.com/github/github-mcp-server/internal/buildinfo.OAuthClientSecret={{ .Env.OAUTH_CLIENT_SECRET }} goos: - linux - windows diff --git a/Dockerfile b/Dockerfile index 90c8b4007..9d9d86318 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,10 +23,15 @@ COPY . . # Copy built UI assets over the placeholder COPY --from=ui-build /app/pkg/github/ui_dist/* ./pkg/github/ui_dist/ -# Build the server +# Build the server. OAuth credentials are injected via build secrets to avoid +# leaking them in image history. Secrets are read at build time only. RUN --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ - CGO_ENABLED=0 go build -ldflags="-s -w -X main.version=${VERSION} -X main.commit=$(git rev-parse HEAD) -X main.date=$(date -u +%Y-%m-%dT%H:%M:%SZ)" \ + --mount=type=secret,id=oauth_client_id \ + --mount=type=secret,id=oauth_client_secret \ + export OAUTH_CLIENT_ID="$(cat /run/secrets/oauth_client_id 2>/dev/null || echo '')" && \ + export OAUTH_CLIENT_SECRET="$(cat /run/secrets/oauth_client_secret 2>/dev/null || echo '')" && \ + CGO_ENABLED=0 go build -ldflags="-s -w -X main.version=${VERSION} -X main.commit=$(git rev-parse HEAD) -X main.date=$(date -u +%Y-%m-%dT%H:%M:%SZ) -X github.com/github/github-mcp-server/internal/buildinfo.OAuthClientID=${OAUTH_CLIENT_ID} -X github.com/github/github-mcp-server/internal/buildinfo.OAuthClientSecret=${OAUTH_CLIENT_SECRET}" \ -o /bin/github-mcp-server ./cmd/github-mcp-server # Make a stage to run the app diff --git a/README.md b/README.md index 1b926b132..1abcf7774 100644 --- a/README.md +++ b/README.md @@ -239,6 +239,39 @@ To keep your GitHub PAT secure and reusable across different MCP hosts: +### OAuth Authentication (stdio mode) + +For stdio mode, you can use OAuth 2.1 instead of a Personal Access Token. The server triggers the OAuth flow on the first tool call: + +| Environment | Flow | Setup | +|-------------|------|-------| +| Docker with port | PKCE (URL elicitation) | Set `GITHUB_OAUTH_CLIENT_ID` + bind port | +| Docker without port | Device flow (enter code at github.com/login/device) | Set `GITHUB_OAUTH_CLIENT_ID` | +| Native binary | PKCE (browser auto-opens) | Set `GITHUB_OAUTH_CLIENT_ID` | + +**Docker example (PKCE with bound port — recommended):** +```json +{ + "mcpServers": { + "github": { + "command": "docker", + "args": ["run", "-i", "--rm", + "-e", "GITHUB_OAUTH_CLIENT_ID", + "-e", "GITHUB_OAUTH_CLIENT_SECRET", + "-e", "GITHUB_OAUTH_CALLBACK_PORT=8085", + "-p", "127.0.0.1:8085:8085", + "ghcr.io/github/github-mcp-server"], + "env": { + "GITHUB_OAUTH_CLIENT_ID": "your_client_id", + "GITHUB_OAUTH_CLIENT_SECRET": "your_client_secret" + } + } + } +} +``` + +See [docs/oauth-authentication.md](docs/oauth-authentication.md) for full setup instructions, including how to create a GitHub OAuth App. + ### GitHub Enterprise Server and Enterprise Cloud with data residency (ghe.com) The flag `--gh-host` and the environment variable `GITHUB_HOST` can be used to set diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 05c2c6e0b..d2bfc9930 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -7,9 +7,12 @@ import ( "strings" "time" + "github.com/github/github-mcp-server/internal/buildinfo" "github.com/github/github-mcp-server/internal/ghmcp" + "github.com/github/github-mcp-server/internal/oauth" "github.com/github/github-mcp-server/pkg/github" ghhttp "github.com/github/github-mcp-server/pkg/http" + ghoauth "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/spf13/viper" @@ -34,8 +37,12 @@ var ( Long: `Start a server that communicates via standard input/output streams using JSON-RPC messages.`, RunE: func(_ *cobra.Command, _ []string) error { token := viper.GetString("personal_access_token") - if token == "" { - return errors.New("GITHUB_PERSONAL_ACCESS_TOKEN not set") + + // Resolve OAuth credentials: explicit config > build-time > none + oauthClientID, oauthClientSecret := resolveOAuthCredentials() + + if token == "" && oauthClientID == "" { + return errors.New("GITHUB_PERSONAL_ACCESS_TOKEN not set and no OAuth credentials available") } // If you're wondering why we're not using viper.GetStringSlice("toolsets"), @@ -96,6 +103,22 @@ var ( ExcludeTools: excludeTools, RepoAccessCacheTTL: &ttl, } + + // Configure OAuth if credentials are available and no PAT is set. + // PAT takes priority — if both are configured, PAT is used directly. + if token == "" && oauthClientID != "" { + oauthScopes := getOAuthScopes() + oauthCfg := oauth.GetGitHubOAuthConfig( + oauthClientID, + oauthClientSecret, + oauthScopes, + viper.GetString("host"), + viper.GetInt("oauth-callback-port"), + ) + stdioServerConfig.OAuthManager = oauth.NewManager(oauthCfg, nil) + stdioServerConfig.OAuthScopes = oauthScopes + } + return ghmcp.RunStdioServer(stdioServerConfig) }, } @@ -154,6 +177,12 @@ func init() { httpCmd.Flags().String("base-path", "", "Externally visible base path for the HTTP server (for OAuth resource metadata)") httpCmd.Flags().Bool("scope-challenge", false, "Enable OAuth scope challenge responses") + // OAuth flags (stdio only) + stdioCmd.Flags().String("oauth-client-id", "", "OAuth client ID for browser-based authentication") + stdioCmd.Flags().String("oauth-client-secret", "", "OAuth client secret") + stdioCmd.Flags().StringSlice("oauth-scopes", nil, "Explicit OAuth scopes to request (overrides automatic computation)") + stdioCmd.Flags().Int("oauth-callback-port", 0, "Fixed port for OAuth callback server (0 for random, required for Docker with -p)") + // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) _ = viper.BindPFlag("tools", rootCmd.PersistentFlags().Lookup("tools")) @@ -173,6 +202,10 @@ func init() { _ = viper.BindPFlag("base-url", httpCmd.Flags().Lookup("base-url")) _ = viper.BindPFlag("base-path", httpCmd.Flags().Lookup("base-path")) _ = viper.BindPFlag("scope-challenge", httpCmd.Flags().Lookup("scope-challenge")) + _ = viper.BindPFlag("oauth-client-id", stdioCmd.Flags().Lookup("oauth-client-id")) + _ = viper.BindPFlag("oauth-client-secret", stdioCmd.Flags().Lookup("oauth-client-secret")) + _ = viper.BindPFlag("oauth-scopes", stdioCmd.Flags().Lookup("oauth-scopes")) + _ = viper.BindPFlag("oauth-callback-port", stdioCmd.Flags().Lookup("oauth-callback-port")) // Add subcommands rootCmd.AddCommand(stdioCmd) rootCmd.AddCommand(httpCmd) @@ -200,3 +233,38 @@ func wordSepNormalizeFunc(_ *pflag.FlagSet, name string) pflag.NormalizedName { } return pflag.NormalizedName(name) } + +// resolveOAuthCredentials returns OAuth client credentials from the best +// available source. Priority: explicit config > build-time baked > none. +func resolveOAuthCredentials() (clientID, clientSecret string) { + clientID = viper.GetString("oauth-client-id") + clientSecret = viper.GetString("oauth-client-secret") + if clientID != "" { + return clientID, clientSecret + } + + if buildinfo.OAuthClientID != "" { + return buildinfo.OAuthClientID, buildinfo.OAuthClientSecret + } + + return "", "" +} + +// getOAuthScopes returns the OAuth scopes to request. Uses explicit override +// if provided, otherwise falls back to the canonical SupportedScopes list +// which covers all tools the server may expose. +func getOAuthScopes() []string { + + if viper.IsSet("oauth-scopes") { + var scopes []string + if err := viper.UnmarshalKey("oauth-scopes", &scopes); err == nil && len(scopes) > 0 { + return scopes + } + } + + // Use the canonical list maintained alongside the HTTP OAuth metadata. + // This requests all scopes any tool might need. The consent screen shows + // the user exactly what is being requested, and scope-based tool filtering + // hides tools the granted token cannot satisfy. + return ghoauth.SupportedScopes +} diff --git a/docs/oauth-authentication.md b/docs/oauth-authentication.md new file mode 100644 index 000000000..e97860b66 --- /dev/null +++ b/docs/oauth-authentication.md @@ -0,0 +1,131 @@ +# OAuth Authentication (stdio mode) + +The GitHub MCP Server supports OAuth 2.1 authentication for stdio mode, allowing users to authenticate via their browser instead of manually creating Personal Access Tokens. + +## How It Works + +When no `GITHUB_PERSONAL_ACCESS_TOKEN` is configured and OAuth credentials are available, the server starts without a token. On the first tool call, it triggers the OAuth flow: + +1. **PKCE flow** (primary): A local callback server starts, your browser opens to GitHub's authorization page, and the token is received via callback. If the browser cannot open (e.g., Docker), the authorization URL is shown via [MCP URL elicitation](https://modelcontextprotocol.io/specification/2025-11-25/client/elicitation). + +2. **Device flow** (fallback): If the callback server cannot start (e.g., Docker without port binding), the server falls back to GitHub's [device flow](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps#device-flow). A code is displayed that you enter at [github.com/login/device](https://github.com/login/device). + +### Authentication Priority + +| Priority | Source | Notes | +|----------|--------|-------| +| 1 (highest) | `GITHUB_PERSONAL_ACCESS_TOKEN` | PAT is used directly, OAuth is skipped | +| 2 | `GITHUB_OAUTH_CLIENT_ID` (env/flag) | Explicit OAuth credentials | +| 3 | Built-in credentials | Baked into official releases via build flags | + +## Docker Setup (Recommended) + +Docker is the standard distribution method. The recommended setup uses PKCE with a bound port: + +```json +{ + "mcpServers": { + "github": { + "command": "docker", + "args": [ + "run", "-i", "--rm", + "-e", "GITHUB_OAUTH_CLIENT_ID", + "-e", "GITHUB_OAUTH_CLIENT_SECRET", + "-e", "GITHUB_OAUTH_CALLBACK_PORT=8085", + "-p", "127.0.0.1:8085:8085", + "ghcr.io/github/github-mcp-server" + ], + "env": { + "GITHUB_OAUTH_CLIENT_ID": "your_client_id", + "GITHUB_OAUTH_CLIENT_SECRET": "your_client_secret" + } + } + } +} +``` + +> **Security**: Always bind to `127.0.0.1` (not `0.0.0.0`) to restrict the callback to localhost. + +### Docker Without Port Binding (Device Flow) + +If you cannot bind a port, the server falls back to device flow: + +```json +{ + "mcpServers": { + "github": { + "command": "docker", + "args": [ + "run", "-i", "--rm", + "-e", "GITHUB_OAUTH_CLIENT_ID", + "-e", "GITHUB_OAUTH_CLIENT_SECRET", + "ghcr.io/github/github-mcp-server" + ], + "env": { + "GITHUB_OAUTH_CLIENT_ID": "your_client_id", + "GITHUB_OAUTH_CLIENT_SECRET": "your_client_secret" + } + } + } +} +``` + +## Native Binary Setup + +For native binaries, PKCE works automatically with a random port: + +```bash +export GITHUB_OAUTH_CLIENT_ID="your_client_id" +export GITHUB_OAUTH_CLIENT_SECRET="your_client_secret" +./github-mcp-server stdio +``` + +The browser opens automatically. No port configuration needed. + +## Creating a GitHub OAuth App + +1. Go to **GitHub Settings** → **Developer settings** → **OAuth Apps** +2. Click **New OAuth App** +3. Fill in: + - **Application name**: e.g., "GitHub MCP Server" + - **Homepage URL**: `https://github.com/github/github-mcp-server` + - **Authorization callback URL**: `http://localhost:8085/callback` (match your `--oauth-callback-port`) +4. Click **Register application** +5. Copy the **Client ID** and generate a **Client Secret** + +> **Note**: The callback URL must be registered even for device flow, though it won't be used. + +## Configuration Reference + +| Environment Variable | Flag | Description | +|---------------------|------|-------------| +| `GITHUB_OAUTH_CLIENT_ID` | `--oauth-client-id` | OAuth client ID | +| `GITHUB_OAUTH_CLIENT_SECRET` | `--oauth-client-secret` | OAuth client secret | +| `GITHUB_OAUTH_CALLBACK_PORT` | `--oauth-callback-port` | Fixed callback port (0 = random) | +| `GITHUB_OAUTH_SCOPES` | `--oauth-scopes` | Override automatic scope selection | + +## Security Design + +### PKCE (Proof Key for Code Exchange) +All authorization code flows use PKCE with S256 challenge, preventing authorization code interception even if an attacker can observe the callback. + +### Fixed Port Considerations +Docker requires a fixed callback port for port mapping. This is acceptable because: +- **PKCE verifier** is generated per-flow and never leaves the process — an attacker who intercepts the callback cannot exchange the code +- **State parameter** prevents CSRF — the callback validates state match +- **Callback server binds to 127.0.0.1** — not accessible from outside the host +- **Short-lived** — the server shuts down immediately after receiving the callback + +### Token Handling +- Tokens are stored **in memory only** — never written to disk +- OAuth token takes precedence over PAT if both become available +- The server requests only the scopes needed by the configured tools + +### URL Elicitation Security +When the browser cannot auto-open, the authorization URL is shown via MCP URL-mode elicitation. This is secure because: +- URL elicitation presents the URL to the user without exposing it to the LLM context +- The MCP client shows the full URL for user inspection before navigation +- Credentials flow directly between the user's browser and GitHub — never through the MCP channel + +### Device Flow as Fallback +Device flow is more susceptible to social engineering than PKCE (the device code could theoretically be phished), which is why PKCE is always attempted first. Device flow is only used when a callback server cannot be started. diff --git a/go.mod b/go.mod index 2bacfe759..f8a6a3dd8 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 github.com/yosida95/uritemplate/v3 v3.0.2 + golang.org/x/oauth2 v0.34.0 ) require ( @@ -45,7 +46,6 @@ require ( go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect golang.org/x/net v0.38.0 // indirect - golang.org/x/oauth2 v0.34.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.28.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/internal/buildinfo/buildinfo.go b/internal/buildinfo/buildinfo.go new file mode 100644 index 000000000..152e843d3 --- /dev/null +++ b/internal/buildinfo/buildinfo.go @@ -0,0 +1,16 @@ +// Package buildinfo contains variables that are set at build time via ldflags. +// These allow official releases to include default OAuth credentials without +// requiring end-user configuration. +// +// Example ldflags usage: +// +// go build -ldflags="-X github.com/github/github-mcp-server/internal/buildinfo.OAuthClientID=xxx" +package buildinfo + +// OAuthClientID is the default OAuth client ID, set at build time. +var OAuthClientID string + +// OAuthClientSecret is the default OAuth client secret, set at build time. +// Note: For public OAuth clients (native apps), the client secret is not +// truly secret per OAuth 2.1 — security relies on PKCE, not the secret. +var OAuthClientSecret string diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 5c4e7f6f1..a8e729d26 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -12,6 +12,7 @@ import ( "syscall" "time" + "github.com/github/github-mcp-server/internal/oauth" "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/http/transport" @@ -37,7 +38,8 @@ type githubClients struct { } // createGitHubClients creates all the GitHub API clients needed by the server. -func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolver) (*githubClients, error) { +// If tokenProvider is set, it is used for dynamic token resolution (OAuth). +func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolver, tokenProvider func() string) (*githubClients, error) { restURL, err := apiHost.BaseRESTURL(context.Background()) if err != nil { return nil, fmt.Errorf("failed to get base REST URL: %w", err) @@ -58,20 +60,34 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv return nil, fmt.Errorf("failed to get Raw URL: %w", err) } - // Construct REST client - restClient := gogithub.NewClient(nil).WithAuthToken(cfg.Token) + // Construct REST client. + // When a tokenProvider is configured (OAuth), we skip WithAuthToken and + // use BearerAuthTransport exclusively. This avoids double-wrapping: go-github's + // WithAuthToken installs its own round tripper that would overwrite the + // dynamic token with the static one. + var restClient *gogithub.Client + if tokenProvider != nil { + restClient = gogithub.NewClient(&http.Client{ + Transport: &transport.BearerAuthTransport{ + Transport: http.DefaultTransport, + TokenProvider: tokenProvider, + }, + }) + } else { + restClient = gogithub.NewClient(nil).WithAuthToken(cfg.Token) + } restClient.UserAgent = fmt.Sprintf("github-mcp-server/%s", cfg.Version) restClient.BaseURL = restURL restClient.UploadURL = uploadURL // Construct GraphQL client - // We use NewEnterpriseClient unconditionally since we already parsed the API host gqlHTTPClient := &http.Client{ Transport: &transport.BearerAuthTransport{ Transport: &transport.GraphQLFeaturesTransport{ Transport: http.DefaultTransport, }, - Token: cfg.Token, + Token: cfg.Token, + TokenProvider: tokenProvider, }, } @@ -101,13 +117,15 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv }, nil } -func NewStdioMCPServer(ctx context.Context, cfg github.MCPServerConfig) (*mcp.Server, error) { +// NewStdioMCPServer creates an MCP server for stdio mode. +// tokenProvider, if non-nil, enables dynamic token resolution (for OAuth). +func NewStdioMCPServer(ctx context.Context, cfg github.MCPServerConfig, tokenProvider func() string) (*mcp.Server, error) { apiHost, err := utils.NewAPIHost(cfg.Host) if err != nil { return nil, fmt.Errorf("failed to parse API host: %w", err) } - clients, err := createGitHubClients(cfg, apiHost) + clients, err := createGitHubClients(cfg, apiHost, tokenProvider) if err != nil { return nil, fmt.Errorf("failed to create GitHub clients: %w", err) } @@ -222,6 +240,15 @@ type StdioServerConfig struct { // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. RepoAccessCacheTTL *time.Duration + + // OAuthManager, if set, enables OAuth 2.1 authentication for stdio mode. + // When configured, the server starts without a token and triggers the OAuth + // flow lazily on the first tool call that requires authentication. + OAuthManager *oauth.Manager + + // OAuthScopes are the OAuth scopes that were requested. Used for + // scope-based tool filtering (hiding tools the token can't satisfy). + OAuthScopes []string } // RunStdioServer is not concurrent safe. @@ -252,7 +279,8 @@ func RunStdioServer(cfg StdioServerConfig) error { // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. // Fine-grained PATs and other token types don't support this, so we skip filtering. var tokenScopes []string - if strings.HasPrefix(cfg.Token, "ghp_") { + switch { + case strings.HasPrefix(cfg.Token, "ghp_"): fetchedScopes, err := fetchTokenScopesForHost(ctx, cfg.Token, cfg.Host) if err != nil { logger.Warn("failed to fetch token scopes, continuing without scope filtering", "error", err) @@ -260,10 +288,30 @@ func RunStdioServer(cfg StdioServerConfig) error { tokenScopes = fetchedScopes logger.Info("token scopes fetched for filtering", "scopes", tokenScopes) } - } else { + case cfg.OAuthManager != nil: + // For OAuth, use the requested scopes for tool filtering. This hides + // tools requiring scopes the OAuth token won't have, avoiding dead tools + // that waste context and can never succeed. STDIO does not support + // scope challenge / step-up auth, so filtering is the only option. + tokenScopes = cfg.OAuthScopes + logger.Info("using OAuth scopes for tool filtering", "scopes", tokenScopes) + default: logger.Debug("skipping scope filtering for non-PAT token") } + // Build the token provider. For OAuth, the token is obtained lazily + // after server startup. The provider returns the current token from + // whichever source is available (PAT or OAuth). + var tokenProvider func() string + if cfg.OAuthManager != nil { + tokenProvider = func() string { + if t := cfg.OAuthManager.GetAccessToken(); t != "" { + return t + } + return cfg.Token + } + } + ghServer, err := NewStdioMCPServer(ctx, github.MCPServerConfig{ Version: cfg.Version, Host: cfg.Host, @@ -281,11 +329,18 @@ func RunStdioServer(cfg StdioServerConfig) error { Logger: logger, RepoAccessTTL: cfg.RepoAccessCacheTTL, TokenScopes: tokenScopes, - }) + }, tokenProvider) if err != nil { return fmt.Errorf("failed to create MCP server: %w", err) } + // Add OAuth middleware: intercepts tool calls and triggers the OAuth + // flow if no token is available yet. The middleware blocks the tool + // call until authentication completes, then retries transparently. + if cfg.OAuthManager != nil { + ghServer.AddReceivingMiddleware(createOAuthMiddleware(cfg.OAuthManager, logger)) + } + if cfg.ExportTranslations { // Once server is initialized, all translations are loaded dumpTranslations() @@ -376,6 +431,38 @@ func addUserAgentsMiddleware(cfg github.MCPServerConfig, restClient *gogithub.Cl } } +// createOAuthMiddleware returns middleware that triggers OAuth authentication +// on tool calls when no token is available. It accesses the MCP session from +// the request to use elicitation for the OAuth flow. +func createOAuthMiddleware(oauthMgr *oauth.Manager, logger *slog.Logger) func(next mcp.MethodHandler) mcp.MethodHandler { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, request mcp.Request) (mcp.Result, error) { + if method != "tools/call" { + return next(ctx, method, request) + } + + if oauthMgr.HasToken() { + return next(ctx, method, request) + } + + // Extract session from the request for elicitation + callReq, ok := request.(*mcp.CallToolRequest) + if !ok { + return next(ctx, method, request) + } + + logger.Info("no token available, triggering OAuth flow") + if err := oauthMgr.RequestAuthentication(ctx, callReq.Session); err != nil { + logger.Error("OAuth authentication failed", "error", err) + return nil, fmt.Errorf("authentication required: %w", err) + } + + logger.Info("OAuth authentication successful") + return next(ctx, method, request) + } + } +} + // fetchTokenScopesForHost fetches the OAuth scopes for a token from the GitHub API. // It constructs the appropriate API host URL based on the configured host. func fetchTokenScopesForHost(ctx context.Context, token, host string) ([]string, error) { diff --git a/internal/oauth/manager.go b/internal/oauth/manager.go new file mode 100644 index 000000000..4877cac7f --- /dev/null +++ b/internal/oauth/manager.go @@ -0,0 +1,364 @@ +package oauth + +import ( + "context" + "fmt" + "log/slog" + "os" + "sync" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "golang.org/x/oauth2" +) + +// Manager handles OAuth authentication state and flow orchestration. +// +// Flow priority (security-ordered): +// 1. PKCE + browser auto-open (native binary — no elicitation needed) +// 2. PKCE + URL elicitation (Docker with bound port, or native when browser fails) +// 3. Device flow (fallback — more phishable, used only when PKCE is unavailable) +type Manager struct { + config Config + logger *slog.Logger + mu sync.RWMutex + token *Result + authInProgress bool + authDone chan struct{} +} + +// NewManager creates a new OAuth manager with the given configuration. +func NewManager(cfg Config, logger *slog.Logger) *Manager { + if logger == nil { + logger = slog.New(slog.NewTextHandler(os.Stderr, nil)) + } + return &Manager{ + config: cfg, + logger: logger, + } +} + +// HasToken returns true if a valid token is available. +func (m *Manager) HasToken() bool { + m.mu.RLock() + defer m.mu.RUnlock() + return m.token != nil && m.token.AccessToken != "" +} + +// GetAccessToken returns the current access token, or empty string if none. +func (m *Manager) GetAccessToken() string { + m.mu.RLock() + defer m.mu.RUnlock() + if m.token == nil { + return "" + } + return m.token.AccessToken +} + +// RequestAuthentication triggers the OAuth flow. +// If authentication is already in progress from another goroutine, this waits +// for it to complete rather than starting a duplicate flow. +func (m *Manager) RequestAuthentication(ctx context.Context, session *mcp.ServerSession) error { + m.mu.Lock() + if m.authInProgress { + authDone := m.authDone + m.mu.Unlock() + + select { + case <-authDone: + if m.HasToken() { + return nil + } + return fmt.Errorf("authentication failed") + case <-ctx.Done(): + return ctx.Err() + } + } + + m.authInProgress = true + m.authDone = make(chan struct{}) + m.mu.Unlock() + + defer func() { + m.mu.Lock() + m.authInProgress = false + close(m.authDone) + m.mu.Unlock() + }() + + // Always attempt PKCE first — it's more secure than device flow. + // Skip PKCE only when it cannot work: random port inside Docker + // (random ports can't be mapped, and the browser can't auto-open). + if m.config.CallbackPort == 0 && IsRunningInDocker() { + m.logger.Info("Docker detected with no callback port configured, using device flow") + return m.startDeviceFlow(ctx, session) + } + + err := m.startPKCEFlow(ctx, session) + if err == nil { + return nil + } + + m.logger.Info("PKCE flow unavailable, falling back to device flow", "reason", err) + + // Device flow fallback — used when PKCE callback server cannot start + // (e.g., Docker without port binding). Device flow is more phishable + // than PKCE, so it's only a fallback. + return m.startDeviceFlow(ctx, session) +} + +// startPKCEFlow runs the PKCE authorization code flow. +// +// Steps: +// 1. Start local callback server (127.0.0.1 only) +// 2. Try to open the auth URL in the user's browser +// 3. If browser fails, use URL elicitation to show the URL securely +// 4. Wait for the callback with the authorization code +// 5. Exchange the code for a token using the PKCE verifier +func (m *Manager) startPKCEFlow(ctx context.Context, session *mcp.ServerSession) error { + verifier, err := generatePKCEVerifier() + if err != nil { + return fmt.Errorf("PKCE setup failed: %w", err) + } + + state, err := generateRandomToken() + if err != nil { + return fmt.Errorf("state generation failed: %w", err) + } + + listener, port, err := startLocalServer(m.config.CallbackPort) + if err != nil { + return fmt.Errorf("callback server failed: %w", err) + } + + oauth2Cfg := &oauth2.Config{ + ClientID: m.config.ClientID, + ClientSecret: m.config.ClientSecret, + RedirectURL: fmt.Sprintf("http://localhost:%d/callback", port), + Scopes: m.config.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: m.config.AuthURL, + TokenURL: m.config.TokenURL, + }, + } + + authURL := oauth2Cfg.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) + + codeChan := make(chan string, 1) + errChan := make(chan error, 1) + server := createCallbackServer(state, codeChan, errChan, listener) + + cleanup := func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = server.Shutdown(shutdownCtx) + _ = listener.Close() + } + + // Try browser auto-open first (works on native, fails in Docker) + browserErr := openBrowser(authURL) + if browserErr != nil { + m.logger.Debug("browser auto-open failed, trying URL elicitation", "error", browserErr) + } + + // If browser didn't open, use URL elicitation to show the auth URL. + // URL mode elicitation is secure: the MCP client shows the URL to the + // user without exposing it to the LLM context. + elicitCancelChan := make(chan struct{}, 1) + elicitCtx, cancelElicit := context.WithCancel(ctx) + defer cancelElicit() + + if browserErr != nil { + if !m.tryURLElicitation(elicitCtx, session, authURL, elicitCancelChan) { + // No browser, no URL elicitation — PKCE cannot proceed. + // Caller will fall back to device flow. + cleanup() + return fmt.Errorf("no browser available and client does not support URL elicitation") + } + } + + select { + case code := <-codeChan: + cancelElicit() + token, exchangeErr := oauth2Cfg.Exchange(ctx, code, oauth2.VerifierOption(verifier)) + cleanup() + if exchangeErr != nil { + return fmt.Errorf("failed to exchange code for token: %w", exchangeErr) + } + + m.setToken(&Result{ + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + TokenType: token.TokenType, + Expiry: token.Expiry, + }) + return nil + + case err := <-errChan: + cleanup() + return fmt.Errorf("OAuth callback error: %w", err) + + case <-elicitCancelChan: + cleanup() + return fmt.Errorf("OAuth authorization was cancelled by user") + + case <-ctx.Done(): + cleanup() + return ctx.Err() + + case <-time.After(DefaultAuthTimeout): + cleanup() + return fmt.Errorf("OAuth timeout after %v — please try again", DefaultAuthTimeout) + } +} + +// tryURLElicitation attempts to show the auth URL via MCP URL-mode elicitation. +// Returns true if elicitation was started, false if unavailable. +func (m *Manager) tryURLElicitation(ctx context.Context, session *mcp.ServerSession, authURL string, cancelChan chan<- struct{}) bool { + if session == nil { + return false + } + + // Check if client supports URL elicitation + params := session.InitializeParams() + if params == nil || params.Capabilities == nil || + params.Capabilities.Elicitation == nil || + params.Capabilities.Elicitation.URL == nil { + return false + } + + go func() { + elicitID, _ := generateRandomToken() + result, err := session.Elicit(ctx, &mcp.ElicitParams{ + Mode: "url", + URL: authURL, + ElicitationID: elicitID, + Message: "Please visit the URL to authorize GitHub MCP Server.", + }) + if err != nil || result == nil || result.Action == "cancel" || result.Action == "decline" { + select { + case cancelChan <- struct{}{}: + default: + } + } + }() + + return true +} + +// startDeviceFlow runs the device authorization flow. +// This is the fallback when PKCE is unavailable (no port binding). +// Device flow is inherently more phishable than PKCE because the device +// code could be socially engineered — it should only be used as a fallback. +func (m *Manager) startDeviceFlow(ctx context.Context, session *mcp.ServerSession) error { + oauth2Cfg := &oauth2.Config{ + ClientID: m.config.ClientID, + ClientSecret: m.config.ClientSecret, + Scopes: m.config.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: m.config.AuthURL, + TokenURL: m.config.TokenURL, + DeviceAuthURL: m.config.DeviceAuthURL, + }, + } + + deviceAuth, err := oauth2Cfg.DeviceAuth(ctx) + if err != nil { + return fmt.Errorf("failed to get device authorization: %w", err) + } + + pollCtx, cancelPoll := context.WithCancel(ctx) + defer cancelPoll() + + m.showDeviceCode(pollCtx, session, deviceAuth, cancelPoll) + + token, err := oauth2Cfg.DeviceAccessToken(pollCtx, deviceAuth) + if err != nil { + if pollCtx.Err() != nil { + return fmt.Errorf("OAuth authorization was cancelled by user") + } + return fmt.Errorf("failed to get device access token: %w", err) + } + + m.setToken(&Result{ + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + TokenType: token.TokenType, + Expiry: token.Expiry, + }) + + return nil +} + +// showDeviceCode displays the device code to the user via the best available channel. +// Priority: URL elicitation → form elicitation → stderr. +func (m *Manager) showDeviceCode(ctx context.Context, session *mcp.ServerSession, deviceAuth *oauth2.DeviceAuthResponse, cancelPoll context.CancelFunc) { + message := fmt.Sprintf("Visit %s and enter code: %s", deviceAuth.VerificationURI, deviceAuth.UserCode) + + if session == nil { + m.logger.Info(message) + fmt.Fprintf(os.Stderr, "\n%s\n\n", message) + return + } + + // Try URL elicitation first (most secure display) + params := session.InitializeParams() + supportsURL := params != nil && params.Capabilities != nil && + params.Capabilities.Elicitation != nil && + params.Capabilities.Elicitation.URL != nil + + if supportsURL { + go func() { + elicitID, _ := generateRandomToken() + result, err := session.Elicit(ctx, &mcp.ElicitParams{ + Mode: "url", + URL: deviceAuth.VerificationURI, + ElicitationID: elicitID, + Message: fmt.Sprintf("Enter the code: %s", deviceAuth.UserCode), + }) + if err != nil || result == nil || result.Action == "cancel" || result.Action == "decline" { + cancelPoll() + } + }() + return + } + + // Try form elicitation — device codes are safe to display via form mode + // (they are short-lived, require user action, and are designed for display) + supportsForm := params != nil && params.Capabilities != nil && + params.Capabilities.Elicitation != nil + + if supportsForm { + go func() { + result, err := session.Elicit(ctx, &mcp.ElicitParams{ + Mode: "form", + Message: message, + RequestedSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "acknowledged": map[string]any{ + "type": "boolean", + "title": "I have entered the code", + "description": message, + "default": false, + }, + }, + }, + }) + if err != nil || result == nil || result.Action == "cancel" || result.Action == "decline" { + cancelPoll() + } + }() + return + } + + // Last resort: stderr (no elicitation available) + m.logger.Info(message) + fmt.Fprintf(os.Stderr, "\n%s\n\n", message) +} + +func (m *Manager) setToken(token *Result) { + m.mu.Lock() + defer m.mu.Unlock() + m.token = token +} diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go new file mode 100644 index 000000000..014da3b01 --- /dev/null +++ b/internal/oauth/oauth.go @@ -0,0 +1,256 @@ +package oauth + +import ( + "crypto/rand" + "embed" + "encoding/base64" + "fmt" + "html/template" + "io" + "net" + "net/http" + "os" + "os/exec" + "runtime" + "strings" + "time" +) + +//go:embed templates/*.html +var templateFS embed.FS + +var ( + errorTemplate *template.Template + successTemplate *template.Template +) + +func init() { + var err error + errorTemplate, err = template.ParseFS(templateFS, "templates/error.html") + if err != nil { + panic(fmt.Sprintf("failed to parse error template: %v", err)) + } + successTemplate, err = template.ParseFS(templateFS, "templates/success.html") + if err != nil { + panic(fmt.Sprintf("failed to parse success template: %v", err)) + } +} + +// DefaultAuthTimeout is the timeout for the OAuth authorization flow. +const DefaultAuthTimeout = 5 * time.Minute + +// Config holds the OAuth configuration. +type Config struct { + ClientID string + ClientSecret string + Scopes []string + AuthURL string + TokenURL string + Host string // GitHub host for constructing OAuth URLs + DeviceAuthURL string + CallbackPort int // Fixed callback port (0 for random) +} + +// Result contains the OAuth flow result. +// +// GitHub OAuth App tokens do not expire, but GitHub App tokens do. +// Callers should handle re-authentication when API calls fail with auth errors. +type Result struct { + AccessToken string + RefreshToken string + TokenType string + Expiry time.Time +} + +// generatePKCEVerifier generates a PKCE code verifier (43 base64url chars from 32 random bytes). +func generatePKCEVerifier() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("failed to generate PKCE verifier: %w", err) + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// generateRandomToken generates a cryptographically random URL-safe token. +func generateRandomToken() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// IsRunningInDocker detects if the process is running inside a Docker container. +// On non-Linux systems this always returns false since detection relies on +// Linux-specific filesystem paths. +func IsRunningInDocker() bool { + if runtime.GOOS != "linux" { + return false + } + + if _, err := os.Stat("/.dockerenv"); err == nil { + return true + } + + data, err := os.ReadFile("/proc/1/cgroup") + if err == nil && (strings.Contains(string(data), "docker") || strings.Contains(string(data), "containerd")) { + return true + } + + return false +} + +// startLocalServer starts a local HTTP callback server. +// When port is 0 (random), binds to 127.0.0.1 only (native binary, secure). +// When port is explicitly set, binds to 0.0.0.0 so Docker port mapping +// (iptables DNAT to the container's eth0) can reach it. +func startLocalServer(port int) (net.Listener, int, error) { + host := "127.0.0.1" + if port > 0 { + host = "0.0.0.0" + } + addr := fmt.Sprintf("%s:%d", host, port) + listener, err := net.Listen("tcp", addr) + if err != nil { + return nil, 0, fmt.Errorf("failed to start listener on %s: %w", addr, err) + } + + actualPort := listener.Addr().(*net.TCPAddr).Port + return listener, actualPort, nil +} + +// createCallbackHandler creates an HTTP handler for the OAuth callback. +// It validates the state parameter for CSRF protection and captures the authorization code. +func createCallbackHandler(expectedState string, codeChan chan<- string, errChan chan<- error) http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + if errMsg := r.URL.Query().Get("error"); errMsg != "" { + errDesc := r.URL.Query().Get("error_description") + if errDesc != "" { + errMsg = fmt.Sprintf("%s: %s", errMsg, errDesc) + } + errChan <- fmt.Errorf("authorization failed: %s", errMsg) + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + // html/template auto-escapes ErrorMessage to prevent XSS + if err := errorTemplate.Execute(w, struct{ ErrorMessage string }{ErrorMessage: errMsg}); err != nil { + http.Error(w, "Internal error", http.StatusInternalServerError) + } + return + } + + if state := r.URL.Query().Get("state"); state != expectedState { + errChan <- fmt.Errorf("state mismatch (possible CSRF attack)") + http.Error(w, "State mismatch", http.StatusBadRequest) + return + } + + code := r.URL.Query().Get("code") + if code == "" { + errChan <- fmt.Errorf("no authorization code received") + http.Error(w, "No code received", http.StatusBadRequest) + return + } + + codeChan <- code + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if err := successTemplate.Execute(w, nil); err != nil { + http.Error(w, "Internal error", http.StatusInternalServerError) + } + }) + + return mux +} + +// createCallbackServer creates and starts an HTTP server for the OAuth callback. +func createCallbackServer(expectedState string, codeChan chan<- string, errChan chan<- error, listener net.Listener) *http.Server { + handler := createCallbackHandler(expectedState, codeChan, errChan) + server := &http.Server{ + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks + } + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + errChan <- fmt.Errorf("callback server error: %w", err) + } + }() + + return server +} + +// openBrowser tries to open the URL in the default browser. +func openBrowser(url string) error { + var cmd *exec.Cmd + + switch runtime.GOOS { + case "linux": + cmd = exec.Command("xdg-open", url) + case "darwin": + cmd = exec.Command("open", url) + case "windows": + cmd = exec.Command("cmd", "/c", "start", url) + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } + + cmd.Stdout = io.Discard + cmd.Stderr = io.Discard + return cmd.Start() +} + +// GetGitHubOAuthConfig returns a Config for the specified GitHub host. +func GetGitHubOAuthConfig(clientID, clientSecret string, scopes []string, host string, callbackPort int) Config { + authURL, tokenURL, deviceAuthURL := getOAuthEndpoints(host) + + return Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Scopes: scopes, + AuthURL: authURL, + TokenURL: tokenURL, + DeviceAuthURL: deviceAuthURL, + Host: host, + CallbackPort: callbackPort, + } +} + +// getOAuthEndpoints returns the appropriate OAuth endpoints based on the host. +func getOAuthEndpoints(host string) (authURL, tokenURL, deviceAuthURL string) { + if host == "" { + return "https://github.com/login/oauth/authorize", + "https://github.com/login/oauth/access_token", + "https://github.com/login/device/code" + } + + hostURL := host + if !strings.HasPrefix(host, "http://") && !strings.HasPrefix(host, "https://") { + hostURL = "https://" + host + } + + var scheme, hostname string + if strings.HasPrefix(hostURL, "https://") { + scheme = "https" + hostname = strings.TrimPrefix(hostURL, "https://") + } else if strings.HasPrefix(hostURL, "http://") { + scheme = "http" + hostname = strings.TrimPrefix(hostURL, "http://") + } + + if idx := strings.Index(hostname, "/"); idx > 0 { + hostname = hostname[:idx] + } + + // Strip api. subdomain for github.com (api.github.com → github.com) + if hostname == "api.github.com" { + hostname = "github.com" + } + + authURL = fmt.Sprintf("%s://%s/login/oauth/authorize", scheme, hostname) + tokenURL = fmt.Sprintf("%s://%s/login/oauth/access_token", scheme, hostname) + deviceAuthURL = fmt.Sprintf("%s://%s/login/device/code", scheme, hostname) + + return authURL, tokenURL, deviceAuthURL +} diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go new file mode 100644 index 000000000..f915c1276 --- /dev/null +++ b/internal/oauth/oauth_test.go @@ -0,0 +1,265 @@ +package oauth + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGeneratePKCEVerifier(t *testing.T) { + verifier, err := generatePKCEVerifier() + require.NoError(t, err) + + // Base64URL encoding of 32 bytes = 43 characters + assert.GreaterOrEqual(t, len(verifier), 43) + + verifier2, err := generatePKCEVerifier() + require.NoError(t, err) + assert.NotEqual(t, verifier, verifier2) +} + +func TestGenerateRandomToken(t *testing.T) { + token1, err := generateRandomToken() + require.NoError(t, err) + assert.GreaterOrEqual(t, len(token1), 20) + + token2, err := generateRandomToken() + require.NoError(t, err) + assert.NotEqual(t, token1, token2) +} + +func TestGetGitHubOAuthConfig(t *testing.T) { + tests := []struct { + name string + host string + callbackPort int + wantAuthURL string + wantTokenURL string + wantDeviceURL string + }{ + { + name: "default github.com", + host: "", + wantAuthURL: "https://github.com/login/oauth/authorize", + wantTokenURL: "https://github.com/login/oauth/access_token", + wantDeviceURL: "https://github.com/login/device/code", + }, + { + name: "GHES host with scheme", + host: "https://github.enterprise.com", + callbackPort: 8085, + wantAuthURL: "https://github.enterprise.com/login/oauth/authorize", + wantTokenURL: "https://github.enterprise.com/login/oauth/access_token", + wantDeviceURL: "https://github.enterprise.com/login/device/code", + }, + { + name: "GHEC host (ghe.com)", + host: "https://mycompany.ghe.com", + wantAuthURL: "https://mycompany.ghe.com/login/oauth/authorize", + wantTokenURL: "https://mycompany.ghe.com/login/oauth/access_token", + wantDeviceURL: "https://mycompany.ghe.com/login/device/code", + }, + { + name: "host without scheme defaults to https", + host: "github.enterprise.com", + wantAuthURL: "https://github.enterprise.com/login/oauth/authorize", + wantTokenURL: "https://github.enterprise.com/login/oauth/access_token", + wantDeviceURL: "https://github.enterprise.com/login/device/code", + }, + { + name: "api.github.com strips api subdomain", + host: "api.github.com", + wantAuthURL: "https://github.com/login/oauth/authorize", + wantTokenURL: "https://github.com/login/oauth/access_token", + wantDeviceURL: "https://github.com/login/device/code", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := GetGitHubOAuthConfig("cid", "csecret", []string{"repo"}, tt.host, tt.callbackPort) + + assert.Equal(t, "cid", cfg.ClientID) + assert.Equal(t, "csecret", cfg.ClientSecret) + assert.Equal(t, []string{"repo"}, cfg.Scopes) + assert.Equal(t, tt.wantAuthURL, cfg.AuthURL) + assert.Equal(t, tt.wantTokenURL, cfg.TokenURL) + assert.Equal(t, tt.wantDeviceURL, cfg.DeviceAuthURL) + assert.Equal(t, tt.callbackPort, cfg.CallbackPort) + }) + } +} + +func TestStartLocalServer(t *testing.T) { + t.Run("random port binds to localhost", func(t *testing.T) { + listener, port, err := startLocalServer(0) + require.NoError(t, err) + defer listener.Close() + + assert.Greater(t, port, 0) + // Random port binds to 127.0.0.1 (secure, native only) + assert.Contains(t, listener.Addr().String(), "127.0.0.1:") + }) + + t.Run("fixed port binds to all interfaces", func(t *testing.T) { + fixedPort := 54321 + listener, port, err := startLocalServer(fixedPort) + require.NoError(t, err) + defer listener.Close() + + assert.Equal(t, fixedPort, port) + // Fixed port binds to all interfaces (0.0.0.0 or [::]) for Docker port mapping + addr := listener.Addr().String() + assert.True(t, strings.Contains(addr, "0.0.0.0:") || strings.Contains(addr, "[::]:"), + "expected all-interface bind, got %s", addr) + }) +} + +func TestCallbackHandler(t *testing.T) { + expectedState := "test-state-12345" + + t.Run("successful callback", func(t *testing.T) { + codeChan := make(chan string, 1) + errChan := make(chan error, 1) + handler := createCallbackHandler(expectedState, codeChan, errChan) + + req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state-12345", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Header().Get("Content-Type"), "text/html") + assert.Contains(t, w.Body.String(), "Authorization Successful") + + select { + case code := <-codeChan: + assert.Equal(t, "test-code", code) + default: + t.Fatal("expected code on channel") + } + }) + + t.Run("state mismatch", func(t *testing.T) { + codeChan := make(chan string, 1) + errChan := make(chan error, 1) + handler := createCallbackHandler(expectedState, codeChan, errChan) + + req := httptest.NewRequest("GET", "/callback?code=test-code&state=wrong-state", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) + + select { + case err := <-errChan: + assert.Contains(t, err.Error(), "state mismatch") + default: + t.Fatal("expected error on channel") + } + }) + + t.Run("missing code", func(t *testing.T) { + codeChan := make(chan string, 1) + errChan := make(chan error, 1) + handler := createCallbackHandler(expectedState, codeChan, errChan) + + req := httptest.NewRequest("GET", "/callback?state=test-state-12345", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) + + select { + case err := <-errChan: + assert.Contains(t, err.Error(), "no authorization code") + default: + t.Fatal("expected error on channel") + } + }) + + t.Run("OAuth error response", func(t *testing.T) { + codeChan := make(chan string, 1) + errChan := make(chan error, 1) + handler := createCallbackHandler(expectedState, codeChan, errChan) + + req := httptest.NewRequest("GET", "/callback?error=access_denied&error_description=User+denied+access", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) // Error template renders with 200 + assert.Contains(t, w.Body.String(), "Authorization Failed") + + select { + case err := <-errChan: + assert.Contains(t, err.Error(), "access_denied") + assert.Contains(t, err.Error(), "User denied access") + default: + t.Fatal("expected error on channel") + } + }) + + t.Run("XSS prevention in error messages", func(t *testing.T) { + codeChan := make(chan string, 1) + errChan := make(chan error, 1) + handler := createCallbackHandler(expectedState, codeChan, errChan) + + // Attempt XSS via error parameter — html/template auto-escapes + req := httptest.NewRequest("GET", `/callback?error=`, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + body := w.Body.String() + assert.NotContains(t, body, "