diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e738723f..b29136716 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ - `certificates`: [v1.2.0](services/certificates/CHANGELOG.md#v120) - **Feature:** Switch from `v2beta` API version to `v2` version. - **Breaking change:** Rename `CreateCertificateResponse` to `GetCertificateResponse` +- `core`: + - [v0.21.0](core/CHANGELOG.md#v0210) + - **Deprecation:** KeyFlow `SetToken` and `GetToken` will be removed after 2026-07-01. Use GetAccessToken instead and rely on client refresh. + - **Feature:** Support Workload Identity Federation flow - `sfs`: - [v0.2.0](services/sfs/CHANGELOG.md) - **Breaking change:** Remove region configuration in `APIClient` diff --git a/README.md b/README.md index 69d23ae86..9ca8dcace 100644 --- a/README.md +++ b/README.md @@ -234,4 +234,4 @@ See the [release documentation](./RELEASE.md) for further information. ## License -Apache 2.0 +Apache 2.0 \ No newline at end of file diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index 8b1d2fb86..d7d8ecf5c 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -1,3 +1,7 @@ +## v0.21.0 +- **Deprecation:** KeyFlow `SetToken` and `GetToken` will be removed after 2026-07-01. Use GetAccessToken instead and rely on client refresh. +- **Feature:** Support Workload Identity Federation flow + ## v0.20.1 - **Improvement:** Improve error message when passing a PEM encoded file to as service account key diff --git a/core/VERSION b/core/VERSION index 2c80271d5..759e855fb 100644 --- a/core/VERSION +++ b/core/VERSION @@ -1 +1 @@ -v0.20.1 +v0.21.0 diff --git a/core/auth/auth.go b/core/auth/auth.go index 568847aea..b393afbb7 100644 --- a/core/auth/auth.go +++ b/core/auth/auth.go @@ -51,6 +51,12 @@ func SetupAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) { return nil, fmt.Errorf("configuring no auth client: %w", err) } return noAuthRoundTripper, nil + } else if cfg.WorkloadIdentityFederation { + wifRoundTripper, err := WorkloadIdentityFederationAuth(cfg) + if err != nil { + return nil, fmt.Errorf("configuring no auth client: %w", err) + } + return wifRoundTripper, nil } else if cfg.ServiceAccountKey != "" || cfg.ServiceAccountKeyPath != "" { keyRoundTripper, err := KeyAuth(cfg) if err != nil { @@ -84,14 +90,18 @@ func DefaultAuth(cfg *config.Configuration) (rt http.RoundTripper, err error) { cfg = &config.Configuration{} } - // Key flow - rt, err = KeyAuth(cfg) + // WIF flow + rt, err = WorkloadIdentityFederationAuth(cfg) if err != nil { - keyFlowErr := err - // Token flow - rt, err = TokenAuth(cfg) + // Key flow + rt, err = KeyAuth(cfg) if err != nil { - return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err) + keyFlowErr := err + // Token flow + rt, err = TokenAuth(cfg) + if err != nil { + return nil, fmt.Errorf("no valid credentials were found: trying key flow: %s, trying token flow: %w", keyFlowErr.Error(), err) + } } } return rt, nil @@ -221,6 +231,29 @@ func KeyAuth(cfg *config.Configuration) (http.RoundTripper, error) { return client, nil } +// WorkloadIdentityFederationAuth configures the wif flow and returns an http.RoundTripper +// that can be used to make authenticated requests using an access token +func WorkloadIdentityFederationAuth(cfg *config.Configuration) (http.RoundTripper, error) { + wifConfig := clients.WorkloadIdentityFederationFlowConfig{ + TokenUrl: cfg.TokenCustomUrl, + BackgroundTokenRefreshContext: cfg.BackgroundTokenRefreshContext, + ClientID: cfg.ServiceAccountEmail, + TokenExpiration: cfg.ServiceAccountFederatedTokenExpiration, + FederatedTokenFunction: cfg.ServiceAccountFederatedTokenFunc, + } + + if cfg.HTTPClient != nil && cfg.HTTPClient.Transport != nil { + wifConfig.HTTPTransport = cfg.HTTPClient.Transport + } + + client := &clients.WorkloadIdentityFederationFlow{} + if err := client.Init(&wifConfig); err != nil { + return nil, fmt.Errorf("error initializing client: %w", err) + } + + return client, nil +} + // readCredentialsFile reads the credentials file from the specified path and returns Credentials func readCredentialsFile(path string) (*Credentials, error) { if path == "" { diff --git a/core/auth/auth_test.go b/core/auth/auth_test.go index a7c776946..b861bf581 100644 --- a/core/auth/auth_test.go +++ b/core/auth/auth_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/stackitcloud/stackit-sdk-go/core/clients" "github.com/stackitcloud/stackit-sdk-go/core/config" @@ -121,6 +122,32 @@ func TestSetupAuth(t *testing.T) { } }() + // create a wif assertion file + wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt") + if errs != nil { + t.Fatalf("Creating temporary file: %s", err) + } + defer func() { + _ = wifAssertionFile.Close() + err := os.Remove(wifAssertionFile.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + Subject: "sub", + }).SignedString([]byte("test")) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + + _, errs = wifAssertionFile.WriteString(string(token)) + if errs != nil { + t.Fatalf("Writing wif assertion to temporary file: %s", err) + } + // create a credentials file with saKey and private key credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt") if errs != nil { @@ -147,12 +174,19 @@ func TestSetupAuth(t *testing.T) { desc string config *config.Configuration setToken bool + setWorkloadIdentity bool setKeys bool setKeyPaths bool setCredentialsFilePathToken bool setCredentialsFilePathKey bool isValid bool }{ + { + desc: "wif_config", + config: nil, + setWorkloadIdentity: true, + isValid: true, + }, { desc: "token_config", config: nil, @@ -241,6 +275,12 @@ func TestSetupAuth(t *testing.T) { t.Setenv("STACKIT_CREDENTIALS_PATH", "") } + if test.setWorkloadIdentity { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name()) + } else { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "") + } + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email") authRoundTripper, err := SetupAuth(test.config) @@ -253,7 +293,7 @@ func TestSetupAuth(t *testing.T) { t.Fatalf("Test didn't return error on invalid test case") } - if test.isValid && authRoundTripper == nil { + if authRoundTripper == nil && test.isValid { t.Fatalf("Roundtripper returned is nil for valid test case") } }) @@ -381,6 +421,32 @@ func TestDefaultAuth(t *testing.T) { t.Fatalf("Writing private key to temporary file: %s", err) } + // create a wif assertion file + wifAssertionFile, errs := os.CreateTemp("", "temp-*.txt") + if errs != nil { + t.Fatalf("Creating temporary file: %s", err) + } + defer func() { + _ = wifAssertionFile.Close() + err := os.Remove(wifAssertionFile.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), + Subject: "sub", + }).SignedString([]byte("test")) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + + _, errs = wifAssertionFile.WriteString(string(token)) + if errs != nil { + t.Fatalf("Writing wif assertion to temporary file: %s", err) + } + // create a credentials file with saKey and private key credentialsKeyFile, errs := os.CreateTemp("", "temp-*.txt") if errs != nil { @@ -409,6 +475,7 @@ func TestDefaultAuth(t *testing.T) { setKeyPaths bool setKeys bool setCredentialsFilePathKey bool + setWorkloadIdentity bool isValid bool expectedFlow string }{ @@ -418,6 +485,14 @@ func TestDefaultAuth(t *testing.T) { isValid: true, expectedFlow: "token", }, + { + desc: "wif_precedes_key_precedes_token", + setToken: true, + setKeyPaths: true, + setWorkloadIdentity: true, + isValid: true, + expectedFlow: "wif", + }, { desc: "key_precedes_token", setToken: true, @@ -475,6 +550,13 @@ func TestDefaultAuth(t *testing.T) { } else { t.Setenv("STACKIT_SERVICE_ACCOUNT_TOKEN", "") } + + if test.setWorkloadIdentity { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", wifAssertionFile.Name()) + } else { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", "") + } + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", "test-email") // Get the default authentication client and ensure that it's not nil @@ -501,6 +583,10 @@ func TestDefaultAuth(t *testing.T) { if _, ok := authClient.(*clients.KeyFlow); !ok { t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient)) } + case "wif": + if _, ok := authClient.(*clients.WorkloadIdentityFederationFlow); !ok { + t.Fatalf("Expected key flow, got %s", reflect.TypeOf(authClient)) + } } } }) diff --git a/core/clients/auth_flow.go b/core/clients/auth_flow.go new file mode 100644 index 000000000..5a0c18960 --- /dev/null +++ b/core/clients/auth_flow.go @@ -0,0 +1,88 @@ +package clients + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stackitcloud/stackit-sdk-go/core/oapierror" +) + +const ( + defaultTokenExpirationLeeway = time.Second * 5 +) + +type AuthFlow interface { + RoundTrip(req *http.Request) (*http.Response, error) + GetAccessToken() (string, error) + getBackgroundTokenRefreshContext() context.Context + refreshAccessToken() error +} + +// TokenResponseBody is the API response +// when requesting a new token +type TokenResponseBody struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + // Deprecated: RefreshToken is no longer used and the SDK will not attempt to refresh tokens using it but will instead use the AuthFlow implementation to get new tokens. + // This will be removed after 2026-07-01. + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` +} + +func parseTokenResponse(res *http.Response) (*TokenResponseBody, error) { + if res == nil { + return nil, fmt.Errorf("received bad response from API") + } + if res.StatusCode != http.StatusOK { + body, err := io.ReadAll(res.Body) + if err != nil { + // Fail silently, omit body from error + // We're trying to show error details, so it's unnecessary to fail because of this err + body = []byte{} + } + return nil, &oapierror.GenericOpenAPIError{ + StatusCode: res.StatusCode, + Body: body, + } + } + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + + token := &TokenResponseBody{} + err = json.Unmarshal(body, token) + if err != nil { + return nil, fmt.Errorf("unmarshal token response: %w", err) + } + return token, nil +} + +func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) { + if token == "" { + return true, nil + } + + // We can safely use ParseUnverified because we are not authenticating the user at this point. + // We're just checking the expiration time + tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) + if err != nil { + return false, fmt.Errorf("parse token: %w", err) + } + + expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime() + if err != nil { + return false, fmt.Errorf("get expiration timestamp: %w", err) + } + + // Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring + // between retrieving the token and upstream systems validating it. + now := time.Now().Add(tokenExpirationLeeway) + return now.After(expirationTimestampNumeric.Time), nil +} diff --git a/core/clients/key_flow_continuous_refresh.go b/core/clients/continuous_refresh.go similarity index 67% rename from core/clients/key_flow_continuous_refresh.go rename to core/clients/continuous_refresh.go index f5129aa02..effb5e668 100644 --- a/core/clients/key_flow_continuous_refresh.go +++ b/core/clients/continuous_refresh.go @@ -17,12 +17,12 @@ var ( defaultTimeBetweenTries = 5 * time.Minute ) -// Continuously refreshes the token of a key flow, retrying if the token API returns 5xx errrors. Writes to stderr when it terminates. +// Continuously refreshes the token of an auth flow, retrying if the token API returns 5xx errrors. Writes to stderr when it terminates. // -// To terminate this routine, close the context in keyFlow.config.BackgroundTokenRefreshContext. -func continuousRefreshToken(keyflow *KeyFlow) { +// To terminate this routine, close the context in flow.getBackgroundTokenRefreshContext(). +func continuousRefreshToken(flow AuthFlow) { refresher := &continuousTokenRefresher{ - keyFlow: keyflow, + flow: flow, timeStartBeforeTokenExpiration: defaultTimeStartBeforeTokenExpiration, timeBetweenContextCheck: defaultTimeBetweenContextCheck, timeBetweenTries: defaultTimeBetweenTries, @@ -32,36 +32,26 @@ func continuousRefreshToken(keyflow *KeyFlow) { } type continuousTokenRefresher struct { - keyFlow *KeyFlow + flow AuthFlow // Token refresh tries start at [Access token expiration timestamp] - [This duration] timeStartBeforeTokenExpiration time.Duration timeBetweenContextCheck time.Duration timeBetweenTries time.Duration } -// Continuously refreshes the token of a key flow, retrying if the token API returns 5xx errrors. Always returns with a non-nil error. +// Continuously refreshes the token of an auth flow, retrying if the token API returns 5xx errrors. Always returns with a non-nil error. // -// To terminate this routine, close the context in refresher.keyFlow.config.BackgroundTokenRefreshContext. +// To terminate this routine, close the context in refresher.flow.getBackgroundTokenRefreshContext(). func (refresher *continuousTokenRefresher) continuousRefreshToken() error { // Compute timestamp where we'll refresh token // Access token may be empty at this point, we have to check it var startRefreshTimestamp time.Time - var accessToken string - refresher.keyFlow.tokenMutex.RLock() - if refresher.keyFlow.token != nil { - accessToken = refresher.keyFlow.token.AccessToken - } - refresher.keyFlow.tokenMutex.RUnlock() - if accessToken == "" { - startRefreshTimestamp = time.Now() - } else { - expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp() - if err != nil { - return fmt.Errorf("get access token expiration timestamp: %w", err) - } - startRefreshTimestamp = expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration) + expirationTimestamp, err := refresher.getAccessTokenExpirationTimestamp() + if err != nil { + return fmt.Errorf("get access token expiration timestamp: %w", err) } + startRefreshTimestamp = expirationTimestamp.Add(-refresher.timeStartBeforeTokenExpiration) for { err := refresher.waitUntilTimestamp(startRefreshTimestamp) @@ -69,7 +59,7 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { return err } - err = refresher.keyFlow.config.BackgroundTokenRefreshContext.Err() + err = refresher.flow.getBackgroundTokenRefreshContext().Err() if err != nil { return fmt.Errorf("check context: %w", err) } @@ -92,13 +82,14 @@ func (refresher *continuousTokenRefresher) continuousRefreshToken() error { } func (refresher *continuousTokenRefresher) getAccessTokenExpirationTimestamp() (*time.Time, error) { - refresher.keyFlow.tokenMutex.RLock() - token := refresher.keyFlow.token.AccessToken - refresher.keyFlow.tokenMutex.RUnlock() + accessToken, err := refresher.flow.GetAccessToken() + if err != nil { + return nil, err + } // We can safely use ParseUnverified because we are not doing authentication of any kind // We're just checking the expiration time - tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) + tokenParsed, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{}) if err != nil { return nil, fmt.Errorf("parse token: %w", err) } @@ -111,7 +102,7 @@ func (refresher *continuousTokenRefresher) getAccessTokenExpirationTimestamp() ( func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Time) error { for time.Now().Before(timestamp) { - err := refresher.keyFlow.config.BackgroundTokenRefreshContext.Err() + err := refresher.flow.getBackgroundTokenRefreshContext().Err() if err != nil { return fmt.Errorf("check context: %w", err) } @@ -125,7 +116,7 @@ func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Tim // - (false, nil) if not successful but should be retried. // - (_, err) if not successful and shouldn't be retried. func (refresher *continuousTokenRefresher) refreshToken() (bool, error) { - err := refresher.keyFlow.recreateAccessToken() + err := refresher.flow.refreshAccessToken() if err == nil { return true, nil } diff --git a/core/clients/key_flow_continuous_refresh_test.go b/core/clients/continuous_refresh_test.go similarity index 74% rename from core/clients/key_flow_continuous_refresh_test.go rename to core/clients/continuous_refresh_test.go index 7c7ee9565..2ad3f20d4 100644 --- a/core/clients/key_flow_continuous_refresh_test.go +++ b/core/clients/continuous_refresh_test.go @@ -12,7 +12,6 @@ import ( "time" "github.com/golang-jwt/jwt/v5" - "github.com/stackitcloud/stackit-sdk-go/core/oapierror" ) @@ -22,9 +21,9 @@ func TestContinuousRefreshToken(t *testing.T) { jwt.TimePrecision = time.Millisecond // Refresher settings - timeStartBeforeTokenExpiration := 500 * time.Millisecond - timeBetweenContextCheck := 10 * time.Millisecond - timeBetweenTries := 100 * time.Millisecond + timeStartBeforeTokenExpiration := 0 * time.Second + timeBetweenContextCheck := 50 * time.Millisecond + timeBetweenTries := 500 * time.Millisecond // All generated acess tokens will have this time to live accessTokensTimeToLive := 1 * time.Second @@ -34,16 +33,20 @@ func TestContinuousRefreshToken(t *testing.T) { contextClosesIn time.Duration doError error expectedNumberDoCalls int - expectedCallRange []int // Optional: for tests that can have variable call counts }{ + { + desc: "update access token never", + contextClosesIn: 900 * time.Millisecond, // Should allow no refresh + expectedNumberDoCalls: 0, + }, { desc: "update access token once", - contextClosesIn: 700 * time.Millisecond, // Should allow one refresh + contextClosesIn: 1900 * time.Millisecond, // Should allow one refresh expectedNumberDoCalls: 1, }, { desc: "update access token twice", - contextClosesIn: 1300 * time.Millisecond, // Should allow two refreshes + contextClosesIn: 2900 * time.Millisecond, // Should allow two refreshes expectedNumberDoCalls: 2, }, { @@ -62,14 +65,14 @@ func TestContinuousRefreshToken(t *testing.T) { expectedNumberDoCalls: 0, }, { - desc: "refresh token fails - non-API error", - contextClosesIn: 700 * time.Millisecond, + desc: "refresh token fails - error", + contextClosesIn: 1900 * time.Millisecond, doError: fmt.Errorf("something went wrong"), expectedNumberDoCalls: 1, }, { desc: "refresh token fails - API non-5xx error", - contextClosesIn: 700 * time.Millisecond, + contextClosesIn: 1900 * time.Millisecond, doError: &oapierror.GenericOpenAPIError{ StatusCode: http.StatusBadRequest, }, @@ -77,92 +80,35 @@ func TestContinuousRefreshToken(t *testing.T) { }, { desc: "refresh token fails - API 5xx error", - contextClosesIn: 800 * time.Millisecond, + contextClosesIn: 2900 * time.Millisecond, doError: &oapierror.GenericOpenAPIError{ StatusCode: http.StatusInternalServerError, }, - expectedNumberDoCalls: 3, - expectedCallRange: []int{3, 4}, // Allow 3 or 4 calls due to timing race condition + expectedNumberDoCalls: 4, }, } for _, tt := range tests { + tt := tt t.Run(tt.desc, func(t *testing.T) { - accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)), - }).SignedString([]byte("test")) - if err != nil { - t.Fatalf("failed to create access token: %v", err) - } - - refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("test")) + t.Parallel() + accessToken, err := signToken(accessTokensTimeToLive) if err != nil { - t.Fatalf("failed to create refresh token: %v", err) + t.Fatalf("failed to sign access token: %v", err) } - - numberDoCalls := 0 - mockDo := func(_ *http.Request) (resp *http.Response, err error) { - numberDoCalls++ // count refresh attempts - if tt.doError != nil { - return nil, tt.doError - } - newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)), - }).SignedString([]byte("test")) - if err != nil { - t.Fatalf("Do call: failed to create access token: %v", err) - } - responseBodyStruct := TokenResponseBody{ - AccessToken: newAccessToken, - RefreshToken: refreshToken, - } - responseBody, err := json.Marshal(responseBodyStruct) - if err != nil { - t.Fatalf("Do call: failed to marshal response: %v", err) - } - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - } - return response, nil - } - ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn) defer cancel() - keyFlow := &KeyFlow{} - privateKeyBytes, err := generatePrivateKey() - if err != nil { - t.Fatalf("Error generating private key: %s", err) - } - keyFlowConfig := &KeyFlowConfig{ - ServiceAccountKey: fixtureServiceAccountKey(), - PrivateKey: string(privateKeyBytes), - AuthHTTPClient: &http.Client{ - Transport: mockTransportFn{mockDo}, - }, - HTTPTransport: mockTransportFn{mockDo}, - BackgroundTokenRefreshContext: nil, - } - err = keyFlow.Init(keyFlowConfig) - if err != nil { - t.Fatalf("failed to initialize key flow: %v", err) - } - - // Set the token after initialization - err = keyFlow.SetToken(accessToken, refreshToken) - if err != nil { - t.Fatalf("failed to set token: %v", err) + authFlow := &fakeAuthFlow{ + backgroundTokenRefreshContext: ctx, + doError: tt.doError, + accessTokensTimeToLive: accessTokensTimeToLive, + accessToken: accessToken, } - // Set the context for continuous refresh - keyFlow.config.BackgroundTokenRefreshContext = ctx - refresher := &continuousTokenRefresher{ - keyFlow: keyFlow, + flow: authFlow, timeStartBeforeTokenExpiration: timeStartBeforeTokenExpiration, timeBetweenContextCheck: timeBetweenContextCheck, timeBetweenTries: timeBetweenTries, @@ -172,13 +118,8 @@ func TestContinuousRefreshToken(t *testing.T) { if err == nil { t.Fatalf("routine finished with non-nil error") } - - // Check if we have a range of expected calls (for timing-sensitive tests) - if tt.expectedCallRange != nil { - if !contains(tt.expectedCallRange, numberDoCalls) { - t.Fatalf("expected %v calls to API to refresh token, got %d", tt.expectedCallRange, numberDoCalls) - } - } else if numberDoCalls != tt.expectedNumberDoCalls { + numberDoCalls := authFlow.getTokenCalls() + if numberDoCalls != tt.expectedNumberDoCalls { t.Fatalf("expected %d calls to API to refresh token, got %d", tt.expectedNumberDoCalls, numberDoCalls) } }) @@ -214,18 +155,14 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { chanUnblockContinuousRefreshToken := make(chan bool) // The access token at the start - accessTokenFirst, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(10 * time.Second)), - }).SignedString([]byte("token-first")) + accessTokenFirst, err := signToken(10 * time.Second) if err != nil { t.Fatalf("failed to create first access token: %v", err) } // The access token that will replace accessTokenFirst // Has a much longer expiration timestamp - accessTokenSecond, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("token-second")) + accessTokenSecond, err := signToken(time.Hour) if err != nil { t.Fatalf("failed to create second access token: %v", err) } @@ -235,9 +172,7 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { } // The refresh token used to update the access token - refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("test")) + refreshToken, err := signToken(time.Hour) if err != nil { t.Fatalf("failed to create refresh token: %v", err) } @@ -264,9 +199,7 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { // This handles the continuous nature of the refresh routine if currentTestPhase > 1 { // Return a valid response for any additional auth requests - newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)), - }).SignedString([]byte("additional-token")) + newAccessToken, err := signToken(time.Hour) if err != nil { t.Fatalf("Do call: failed to create additional access token: %v", err) } @@ -419,7 +352,7 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { // Create a custom refresher with shorter timing for the test refresher := &continuousTokenRefresher{ - keyFlow: keyFlow, + flow: keyFlow, timeStartBeforeTokenExpiration: 9 * time.Second, // Start 9 seconds before expiration timeBetweenContextCheck: 5 * time.Millisecond, timeBetweenTries: 40 * time.Millisecond, @@ -476,11 +409,54 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) { } } -func contains(arr []int, val int) bool { - for _, v := range arr { - if v == val { - return true - } +func signToken(expiration time.Duration) (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiration)), + }).SignedString([]byte("test")) +} + +var _ AuthFlow = &fakeAuthFlow{} + +type fakeAuthFlow struct { + backgroundTokenRefreshContext context.Context + tokenCounter int + doError error + accessTokensTimeToLive time.Duration + accessToken string +} + +func (f *fakeAuthFlow) RoundTrip(_ *http.Request) (*http.Response, error) { + return nil, nil +} +func (f *fakeAuthFlow) GetAccessToken() (string, error) { + expired, err := tokenExpired(f.accessToken, 0) + if err != nil { + return "", err + } + if !expired { + return f.accessToken, nil + } + f.tokenCounter++ + if f.doError != nil { + return "", f.doError + } + accessToken, err := signToken(f.accessTokensTimeToLive) + if err != nil { + return "", f.doError } - return false + f.accessToken = accessToken + return accessToken, nil +} + +func (f *fakeAuthFlow) refreshAccessToken() error { + _, err := f.GetAccessToken() + return err +} + +func (f *fakeAuthFlow) getBackgroundTokenRefreshContext() context.Context { + return f.backgroundTokenRefreshContext +} + +func (f *fakeAuthFlow) getTokenCalls() int { + return f.tokenCounter } diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index 589774314..bac4022cd 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -4,11 +4,9 @@ import ( "context" "crypto/rsa" "crypto/x509" - "encoding/json" "encoding/pem" "errors" "fmt" - "io" "net/http" "net/url" "regexp" @@ -32,10 +30,10 @@ const ( tokenAPI = "https://service-account.api.stackit.cloud/token" //nolint:gosec // linter false positive defaultTokenType = "Bearer" defaultScope = "" - - defaultTokenExpirationLeeway = time.Second * 5 ) +var _ AuthFlow = &KeyFlow{} + // KeyFlow handles auth with SA key type KeyFlow struct { rt http.RoundTripper @@ -65,16 +63,6 @@ type KeyFlowConfig struct { AuthHTTPClient *http.Client } -// TokenResponseBody is the API response -// when requesting a new token -type TokenResponseBody struct { - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` - RefreshToken string `json:"refresh_token"` - Scope string `json:"scope"` - TokenType string `json:"token_type"` -} - // ServiceAccountKeyResponse is the API response // when creating a new SA key type ServiceAccountKeyResponse struct { @@ -114,6 +102,9 @@ func (c *KeyFlow) GetServiceAccountEmail() string { } // GetToken returns the token field +// +// Deprecated: Use GetAccessToken instead. +// This will be removed after 2026-07-01. func (c *KeyFlow) GetToken() TokenResponseBody { c.tokenMutex.RLock() defer c.tokenMutex.RUnlock() @@ -160,6 +151,9 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { // SetToken can be used to set an access and refresh token manually in the client. // The other fields in the token field are determined by inspecting the token or setting default values. +// +// Deprecated: This method will be removed in future versions. Access tokens are going to be automatically managed by the client. +// This will be removed after 2026-07-01. func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { // We can safely use ParseUnverified because we are not authenticating the user, // We are parsing the token just to get the expiration time claim @@ -176,8 +170,8 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error { c.token = &TokenResponseBody{ AccessToken: accessToken, ExpiresIn: int(exp.Time.Unix()), - Scope: defaultScope, RefreshToken: refreshToken, + Scope: defaultScope, TokenType: defaultTokenType, } c.tokenMutex.Unlock() @@ -203,7 +197,6 @@ func (c *KeyFlow) GetAccessToken() (string, error) { if c.rt == nil { return "", fmt.Errorf("nil http round tripper, please run Init()") } - var accessToken string c.tokenMutex.RLock() @@ -237,6 +230,14 @@ func (c *KeyFlow) GetAccessToken() (string, error) { return accessToken, nil } +func (c *KeyFlow) refreshAccessToken() error { + return c.recreateAccessToken() +} + +func (c *KeyFlow) getBackgroundTokenRefreshContext() context.Context { + return c.config.BackgroundTokenRefreshContext +} + // validate the client is configured well func (c *KeyFlow) validate() error { if c.config.ServiceAccountKey == nil { @@ -307,11 +308,20 @@ func (c *KeyFlow) createAccessToken() (err error) { err = fmt.Errorf("close request access token response: %w", tempErr) } }() - return c.parseTokenResponse(res) + token, err := parseTokenResponse(res) + if err != nil { + return err + } + c.tokenMutex.Lock() + c.token = token + c.tokenMutex.Unlock() + return nil } // createAccessTokenWithRefreshToken creates an access token using // an existing pre-validated refresh token +// Deprecated: This method will be removed in future versions. Access tokens are going to be refreshed without refresh token. +// This will be removed after 2026-07-01. func (c *KeyFlow) createAccessTokenWithRefreshToken() (err error) { c.tokenMutex.RLock() refreshToken := c.token.RefreshToken @@ -327,7 +337,14 @@ func (c *KeyFlow) createAccessTokenWithRefreshToken() (err error) { err = fmt.Errorf("close request access token with refresh token response: %w", tempErr) } }() - return c.parseTokenResponse(res) + token, err := parseTokenResponse(res) + if err != nil { + return err + } + c.tokenMutex.Lock() + c.token = token + c.tokenMutex.Unlock() + return nil } // generateSelfSignedJWT generates JWT token @@ -338,7 +355,7 @@ func (c *KeyFlow) generateSelfSignedJWT() (string, error) { "jti": uuid.New(), "aud": c.key.Credentials.Aud, "iat": jwt.NewNumericDate(time.Now()), - "exp": jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), + "exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), } token := jwt.NewWithClaims(jwt.SigningMethodRS512, claims) token.Header["kid"] = c.key.Credentials.Kid @@ -358,6 +375,7 @@ func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error) } else { body.Set("assertion", assertion) } + payload := strings.NewReader(body.Encode()) req, err := http.NewRequest(http.MethodPost, c.config.TokenUrl, payload) if err != nil { @@ -367,60 +385,3 @@ func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error) return c.authClient.Do(req) } - -// parseTokenResponse parses the response from the server -func (c *KeyFlow) parseTokenResponse(res *http.Response) error { - if res == nil { - return fmt.Errorf("received bad response from API") - } - if res.StatusCode != http.StatusOK { - body, err := io.ReadAll(res.Body) - if err != nil { - // Fail silently, omit body from error - // We're trying to show error details, so it's unnecessary to fail because of this err - body = []byte{} - } - return &oapierror.GenericOpenAPIError{ - StatusCode: res.StatusCode, - Body: body, - } - } - body, err := io.ReadAll(res.Body) - if err != nil { - return err - } - - c.tokenMutex.Lock() - c.token = &TokenResponseBody{} - err = json.Unmarshal(body, c.token) - c.tokenMutex.Unlock() - if err != nil { - return fmt.Errorf("unmarshal token response: %w", err) - } - - return nil -} - -func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) { - if token == "" { - return true, nil - } - - // We can safely use ParseUnverified because we are not authenticating the user at this point. - // We're just checking the expiration time - tokenParsed, _, err := jwt.NewParser().ParseUnverified(token, &jwt.RegisteredClaims{}) - if err != nil { - return false, fmt.Errorf("parse token: %w", err) - } - - expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime() - if err != nil { - return false, fmt.Errorf("get expiration timestamp: %w", err) - } - - // Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring - // between retrieving the token and upstream systems validating it. - now := time.Now().Add(tokenExpirationLeeway) - - return now.After(expirationTimestampNumeric.Time), nil -} diff --git a/core/clients/key_flow_test.go b/core/clients/key_flow_test.go index 9803f24ee..045bd28d2 100644 --- a/core/clients/key_flow_test.go +++ b/core/clients/key_flow_test.go @@ -178,8 +178,8 @@ func TestSetToken(t *testing.T) { AccessToken: accessToken, ExpiresIn: int(timestamp.Unix()), RefreshToken: tt.refreshToken, - Scope: defaultScope, - TokenType: defaultTokenType, + Scope: "", + TokenType: "Bearer", } if !cmp.Equal(expectedKeyFlowToken, keyFlow.token) { t.Errorf("The returned result is wrong. Expected %+v, got %+v", expectedKeyFlowToken, keyFlow.token) @@ -194,25 +194,25 @@ func TestTokenExpired(t *testing.T) { tests := []struct { desc string tokenInvalid bool - tokenExpiresAt time.Time + tokenDuration time.Duration expectedErr bool expectedIsExpired bool }{ { desc: "token valid", - tokenExpiresAt: time.Now().Add(time.Hour), + tokenDuration: time.Hour, expectedErr: false, expectedIsExpired: false, }, { desc: "token expired", - tokenExpiresAt: time.Now().Add(-time.Hour), + tokenDuration: -time.Hour, expectedErr: false, expectedIsExpired: true, }, { desc: "token almost expired", - tokenExpiresAt: time.Now().Add(tokenExpirationLeeway), + tokenDuration: tokenExpirationLeeway, expectedErr: false, expectedIsExpired: true, }, @@ -228,9 +228,7 @@ func TestTokenExpired(t *testing.T) { var err error token := "foo" if !tt.tokenInvalid { - token, err = jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(tt.tokenExpiresAt), - }).SignedString([]byte("test")) + token, err = signToken(tt.tokenDuration) if err != nil { t.Fatalf("failed to create token: %v", err) } @@ -442,10 +440,9 @@ func TestKeyFlow_Do(t *testing.T) { res.Header().Set("Content-Type", "application/json") token := &TokenResponseBody{ - AccessToken: testBearerToken, - ExpiresIn: 2147483647, - RefreshToken: testBearerToken, - TokenType: "Bearer", + AccessToken: testBearerToken, + ExpiresIn: 2147483647, + TokenType: "Bearer", } if err := json.NewEncoder(res.Body).Encode(token); err != nil { diff --git a/core/clients/workload_identity_flow.go b/core/clients/workload_identity_flow.go new file mode 100644 index 000000000..73a1ed272 --- /dev/null +++ b/core/clients/workload_identity_flow.go @@ -0,0 +1,223 @@ +package clients + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/stackitcloud/stackit-sdk-go/core/oidcadapters" + "github.com/stackitcloud/stackit-sdk-go/core/utils" +) + +const ( + clientIDEnv = "STACKIT_SERVICE_ACCOUNT_EMAIL" + FederatedTokenFileEnv = "STACKIT_FEDERATED_TOKEN_FILE" //nolint:gosec // This is not a secret, just the env variable name + wifTokenEndpointEnv = "STACKIT_IDP_TOKEN_ENDPOINT" //nolint:gosec // This is not a secret, just the env variable name + wifTokenExpirationEnv = "STACKIT_IDP_TOKEN_EXPIRATION_SECONDS" //nolint:gosec // This is not a secret, just the env variable name + + wifClientAssertionType = "urn:schwarz:params:oauth:client-assertion-type:workload-jwt" + wifGrantType = "client_credentials" + defaultWifTokenEndpoint = "https://accounts.stackit.cloud/oauth/v2/token" //nolint:gosec // This is not a secret, just the public endpoint for default value + defaultFederatedTokenPath = "/var/run/secrets/stackit.cloud/serviceaccount/token" //nolint:gosec // This is not a secret, just the default path for workload identity token + defaultWifExpirationToken = "1h" +) + +var ( + _ = utils.GetEnvOrDefault(wifTokenExpirationEnv, defaultWifExpirationToken) // Not used yet +) + +var _ AuthFlow = &WorkloadIdentityFederationFlow{} + +// WorkloadIdentityFlow handles auth with Workload Identity Federation +type WorkloadIdentityFederationFlow struct { + rt http.RoundTripper + authClient *http.Client + config *WorkloadIdentityFederationFlowConfig + + tokenMutex sync.RWMutex + token *TokenResponseBody + + // If the current access token would expire in less than TokenExpirationLeeway, + // the client will refresh it early to prevent clock skew or other timing issues. + tokenExpirationLeeway time.Duration +} + +// KeyFlowConfig is the flow config +type WorkloadIdentityFederationFlowConfig struct { + TokenUrl string + ClientID string + TokenExpiration string // Not supported yet + BackgroundTokenRefreshContext context.Context // Functionality is enabled if this isn't nil + HTTPTransport http.RoundTripper + AuthHTTPClient *http.Client + FederatedTokenFunction oidcadapters.OIDCTokenFunc // Function to get the federated token +} + +// GetConfig returns the flow configuration +func (c *WorkloadIdentityFederationFlow) GetConfig() WorkloadIdentityFederationFlowConfig { + if c.config == nil { + return WorkloadIdentityFederationFlowConfig{} + } + return *c.config +} + +// GetAccessToken implements AuthFlow. +func (c *WorkloadIdentityFederationFlow) GetAccessToken() (string, error) { + if c.rt == nil { + return "", fmt.Errorf("nil http round tripper, please run Init()") + } + var accessToken string + + c.tokenMutex.RLock() + if c.token != nil { + accessToken = c.token.AccessToken + } + c.tokenMutex.RUnlock() + + accessTokenExpired, err := tokenExpired(accessToken, c.tokenExpirationLeeway) + if err != nil { + return "", fmt.Errorf("check access token is expired: %w", err) + } + if !accessTokenExpired { + return accessToken, nil + } + if err = c.createAccessToken(); err != nil { + return "", fmt.Errorf("get new access token: %w", err) + } + + c.tokenMutex.RLock() + accessToken = c.token.AccessToken + c.tokenMutex.RUnlock() + + return accessToken, nil +} + +func (c *WorkloadIdentityFederationFlow) refreshAccessToken() error { + return c.createAccessToken() +} + +// RoundTrip implements the http.RoundTripper interface. +// It gets a token, adds it to the request's authorization header, and performs the request. +func (c *WorkloadIdentityFederationFlow) RoundTrip(req *http.Request) (*http.Response, error) { + if c.rt == nil { + return nil, fmt.Errorf("please run Init()") + } + + accessToken, err := c.GetAccessToken() + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + return c.rt.RoundTrip(req) +} + +// getBackgroundTokenRefreshContext implements AuthFlow. +func (c *WorkloadIdentityFederationFlow) getBackgroundTokenRefreshContext() context.Context { + return c.config.BackgroundTokenRefreshContext +} + +func (c *WorkloadIdentityFederationFlow) Init(cfg *WorkloadIdentityFederationFlowConfig) error { + // No concurrency at this point, so no mutex check needed + c.token = &TokenResponseBody{} + c.config = cfg + + if c.config.TokenUrl == "" { + c.config.TokenUrl = utils.GetEnvOrDefault(wifTokenEndpointEnv, defaultWifTokenEndpoint) + } + + if c.config.ClientID == "" { + c.config.ClientID = utils.GetEnvOrDefault(clientIDEnv, "") + } + + if c.config.FederatedTokenFunction == nil { + c.config.FederatedTokenFunction = oidcadapters.ReadJWTFromFileSystem(utils.GetEnvOrDefault(FederatedTokenFileEnv, defaultFederatedTokenPath)) + } + + c.tokenExpirationLeeway = defaultTokenExpirationLeeway + + if c.rt = cfg.HTTPTransport; c.rt == nil { + c.rt = http.DefaultTransport + } + + if c.authClient = cfg.AuthHTTPClient; cfg.AuthHTTPClient == nil { + c.authClient = &http.Client{ + Transport: c.rt, + Timeout: DefaultClientTimeout, + } + } + + err := c.validate() + if err != nil { + return err + } + + if c.config.BackgroundTokenRefreshContext != nil { + go continuousRefreshToken(c) + } + return nil +} + +// validate the client is configured well +func (c *WorkloadIdentityFederationFlow) validate() error { + if c.config.ClientID == "" { + return fmt.Errorf("client ID cannot be empty") + } + if c.config.TokenUrl == "" { + return fmt.Errorf("token URL cannot be empty") + } + if _, err := c.config.FederatedTokenFunction(context.Background()); err != nil { + return fmt.Errorf("error reading federated token file - %w", err) + } + if c.tokenExpirationLeeway < 0 { + return fmt.Errorf("token expiration leeway cannot be negative") + } + + return nil +} + +// createAccessToken creates an access token using self signed JWT +func (c *WorkloadIdentityFederationFlow) createAccessToken() error { + clientAssertion, err := c.config.FederatedTokenFunction(context.Background()) + if err != nil { + return err + } + res, err := c.requestToken(c.config.ClientID, clientAssertion) + if err != nil { + return err + } + defer func() { + tempErr := res.Body.Close() + if tempErr != nil && err == nil { + err = fmt.Errorf("close request access token response: %w", tempErr) + } + }() + token, err := parseTokenResponse(res) + if err != nil { + return err + } + c.tokenMutex.Lock() + c.token = token + c.tokenMutex.Unlock() + return nil +} + +func (c *WorkloadIdentityFederationFlow) requestToken(clientID, assertion string) (*http.Response, error) { + body := url.Values{} + body.Set("grant_type", wifGrantType) + body.Set("client_assertion_type", wifClientAssertionType) + body.Set("client_assertion", assertion) + body.Set("client_id", clientID) + + payload := strings.NewReader(body.Encode()) + req, err := http.NewRequest(http.MethodPost, c.config.TokenUrl, payload) + if err != nil { + return nil, err + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + return c.authClient.Do(req) +} diff --git a/core/clients/workload_identity_flow_test.go b/core/clients/workload_identity_flow_test.go new file mode 100644 index 000000000..7d59593f4 --- /dev/null +++ b/core/clients/workload_identity_flow_test.go @@ -0,0 +1,310 @@ +package clients + +import ( + "context" + "encoding/json" + "log" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/stackitcloud/stackit-sdk-go/core/oidcadapters" + + "github.com/golang-jwt/jwt/v5" +) + +func TestWorkloadIdentityFlowInit(t *testing.T) { + tests := []struct { + name string + clientID string + clientIDAsEnv bool + customTokenUrl string + customTokenUrlEnv bool + tokenExpiration string + validAssertion bool + tokenFilePathAsEnv bool + missingTokenFilePath bool + wantErr bool + }{ + { + name: "ok setting all", + clientID: "test@stackit.cloud", + validAssertion: true, + wantErr: false, + }, + { + name: "ok using defaults from envs", + clientID: "test@stackit.cloud", + clientIDAsEnv: true, + tokenFilePathAsEnv: true, + customTokenUrlEnv: true, + validAssertion: true, + wantErr: false, + }, + { + name: "ok using defaults from envs", + clientID: "test@stackit.cloud", + clientIDAsEnv: true, + tokenFilePathAsEnv: true, + customTokenUrlEnv: true, + validAssertion: true, + wantErr: false, + }, + { + name: "missing client id", + validAssertion: true, + wantErr: true, + }, + { + name: "invalid assertion", + clientID: "test@stackit.cloud", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + flow := &WorkloadIdentityFederationFlow{} + flowConfig := &WorkloadIdentityFederationFlowConfig{} + if tt.customTokenUrl != "" { + if tt.customTokenUrlEnv { + t.Setenv("STACKIT_IDP_TOKEN_ENDPOINT", tt.customTokenUrl) + } else { + flowConfig.TokenUrl = tt.customTokenUrl + } + } + + if tt.clientID != "" { + if tt.clientIDAsEnv { + t.Setenv("STACKIT_SERVICE_ACCOUNT_EMAIL", tt.clientID) + } else { + flowConfig.ClientID = tt.clientID + } + } + if tt.tokenExpiration != "" { + flowConfig.TokenExpiration = tt.tokenExpiration + } + + if !tt.missingTokenFilePath { + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer func() { + err := os.Remove(file.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + if tt.validAssertion { + token, err := signTokenWithSubject("subject", time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + err = os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + if err != nil { + t.Fatalf("writing temporary file: %s", err) + } + } + if tt.tokenFilePathAsEnv { + t.Setenv("STACKIT_FEDERATED_TOKEN_FILE", file.Name()) + } else { + flowConfig.FederatedTokenFunction = oidcadapters.ReadJWTFromFileSystem(file.Name()) + } + } + + if err := flow.Init(flowConfig); (err != nil) != tt.wantErr { + t.Errorf("KeyFlow.Init() error = %v, wantErr %v", err, tt.wantErr) + } + if flow.config == nil { + t.Error("config is nil") + } + + if flow.config.ClientID != tt.clientID { + t.Errorf("clientID mismatch, want %s, got %s", tt.clientID, flow.config.ClientID) + } + + if tt.customTokenUrl != "" && flow.config.TokenUrl != tt.customTokenUrl { + t.Errorf("tokenUrl mismatch, want %s, got %s", tt.customTokenUrl, flow.config.TokenUrl) + } + + if tt.customTokenUrl == "" && flow.config.TokenUrl != "https://accounts.stackit.cloud/oauth/v2/token" { + t.Errorf("tokenUrl mismatch, want %s, got %s", "https://accounts.stackit.cloud/oauth/v2/token", flow.config.TokenUrl) + } + + if tt.tokenExpiration != "" && flow.config.TokenExpiration != tt.tokenExpiration { + t.Errorf("tokenExpiration mismatch, want %s, got %s", tt.tokenExpiration, flow.config.TokenExpiration) + } + }) + } +} + +func signTokenWithSubject(sub string, expiration time.Duration) (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiration)), + Subject: sub, + }).SignedString([]byte("test")) +} + +func TestWorkloadIdentityFlowRoundTrip(t *testing.T) { + validSub := "valid-sub" + serviceAccountSub := "sa-sub" + tests := []struct { + name string + clientID string + validAssertion bool + injectToken bool + wantErr bool + }{ + { + name: "ok setting all", + clientID: "test@stackit.cloud", + validAssertion: true, + wantErr: false, + }, + { + name: "injected token ok", + clientID: "test@stackit.cloud", + validAssertion: true, + injectToken: true, + wantErr: false, + }, + { + name: "invalid assertion", + clientID: "test@stackit.cloud", + validAssertion: false, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + if err != nil { + t.Fatalf("failed to parse form: %v", err) + } + assertionType := r.PostForm.Get("client_assertion_type") + if assertionType != "urn:schwarz:params:oauth:client-assertion-type:workload-jwt" { + t.Fatalf("invalid assertion type: %s", assertionType) + } + grantType := r.PostForm.Get("grant_type") + if grantType != "client_credentials" { + t.Fatalf("invalid grant type: %s", assertionType) + } + context, _, err := jwt.NewParser().ParseUnverified(r.PostForm.Get("client_assertion"), jwt.MapClaims{}) + if err != nil { + t.Fatalf("failed to validate token: %v", err) + } + sub, err := context.Claims.GetSubject() + if err != nil { + t.Fatalf("failed to validate token sub: %v", err) + } + if sub != validSub { + w.WriteHeader(http.StatusUnauthorized) + return + } + + token, err := signTokenWithSubject(serviceAccountSub, time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + + tokenResponse := &TokenResponseBody{ + AccessToken: token, + ExpiresIn: 60, + TokenType: "Bearer", + } + + payload, err := json.Marshal(tokenResponse) + if err != nil { + t.Fatalf("failed to create token payload: %v", err) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err = w.Write(payload) + if err != nil { + t.Fatalf("writing response: %s", err) + } + })) + t.Cleanup(authServer.Close) + + protectedResource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + context, _, err := jwt.NewParser().ParseUnverified(strings.Fields(r.Header.Get("Authorization"))[1], jwt.MapClaims{}) + if err != nil { + t.Fatalf("failed to validate token: %v", err) + } + + sub, err := context.Claims.GetSubject() + if err != nil { + t.Fatalf("failed to validate token sub: %v", err) + } + if sub != serviceAccountSub { + t.Fatalf("invalid token on protected resource: %v", err) + } + + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(protectedResource.Close) + + flow := &WorkloadIdentityFederationFlow{} + flowConfig := &WorkloadIdentityFederationFlowConfig{} + flowConfig.TokenUrl = authServer.URL + + flowConfig.ClientID = tt.clientID + + subject := "wrong" + if tt.validAssertion { + subject = validSub + } + token, err := signTokenWithSubject(subject, time.Minute) + if err != nil { + t.Fatalf("failed to create token: %v", err) + } + + if tt.injectToken { + flowConfig.FederatedTokenFunction = func(context.Context) (string, error) { + return token, nil + } + } else { + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer func() { + err := os.Remove(file.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + flowConfig.FederatedTokenFunction = oidcadapters.ReadJWTFromFileSystem(file.Name()) + err = os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + if err != nil { + t.Fatalf("writing temporary file: %s", err) + } + } + + if err := flow.Init(flowConfig); err != nil { + t.Errorf("KeyFlow.Init() error = %v", err) + } + if flow.config == nil { + t.Error("config is nil") + } + + client := http.Client{ + Transport: flow, + } + resp, err := client.Get(protectedResource.URL) + if (err != nil || resp.StatusCode != http.StatusOK) && !tt.wantErr { + t.Fatalf("failed request to protected resource: %v", err) + } + if resp != nil && resp.Body != nil { + if err := resp.Body.Close(); err != nil { + t.Errorf("resp.Body.Close() error = %v", err) + } + } + }) + } +} diff --git a/core/config/config.go b/core/config/config.go index 93002c02a..ec3bae10e 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -11,6 +11,7 @@ import ( "time" "github.com/stackitcloud/stackit-sdk-go/core/clients" + "github.com/stackitcloud/stackit-sdk-go/core/oidcadapters" ) const ( @@ -75,26 +76,29 @@ type Middleware func(http.RoundTripper) http.RoundTripper // Configuration stores the configuration of the API client type Configuration struct { - Host string `json:"host,omitempty"` - Scheme string `json:"scheme,omitempty"` - DefaultHeader map[string]string `json:"defaultHeader,omitempty"` - UserAgent string `json:"userAgent,omitempty"` - Debug bool `json:"debug,omitempty"` - NoAuth bool `json:"noAuth,omitempty"` - ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` // Deprecated: ServiceAccountEmail is not required and will be removed after 12th June 2025. - Token string `json:"token,omitempty"` - ServiceAccountKey string `json:"serviceAccountKey,omitempty"` - PrivateKey string `json:"privateKey,omitempty"` - ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` - PrivateKeyPath string `json:"privateKeyPath,omitempty"` - CredentialsFilePath string `json:"credentialsFilePath,omitempty"` - TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` - Region string `json:"region,omitempty"` - CustomAuth http.RoundTripper - Servers ServerConfigurations - OperationServers map[string]ServerConfigurations - HTTPClient *http.Client - Middleware []Middleware + Host string `json:"host,omitempty"` + Scheme string `json:"scheme,omitempty"` + DefaultHeader map[string]string `json:"defaultHeader,omitempty"` + UserAgent string `json:"userAgent,omitempty"` + Debug bool `json:"debug,omitempty"` + NoAuth bool `json:"noAuth,omitempty"` + WorkloadIdentityFederation bool `json:"workloadIdentityFederation,omitempty"` + ServiceAccountFederatedTokenExpiration string `json:"serviceAccountFederatedTokenExpiration,omitempty"` + ServiceAccountFederatedTokenFunc oidcadapters.OIDCTokenFunc `json:"serviceAccountFederatedTokenFunc,omitempty"` + ServiceAccountEmail string `json:"serviceAccountEmail,omitempty"` + Token string `json:"token,omitempty"` + ServiceAccountKey string `json:"serviceAccountKey,omitempty"` + PrivateKey string `json:"privateKey,omitempty"` + ServiceAccountKeyPath string `json:"serviceAccountKeyPath,omitempty"` + PrivateKeyPath string `json:"privateKeyPath,omitempty"` + CredentialsFilePath string `json:"credentialsFilePath,omitempty"` + TokenCustomUrl string `json:"tokenCustomUrl,omitempty"` + Region string `json:"region,omitempty"` + CustomAuth http.RoundTripper + Servers ServerConfigurations + OperationServers map[string]ServerConfigurations + HTTPClient *http.Client + Middleware []Middleware // If != nil, a goroutine will be launched that will refresh the service account's access token when it's close to being expired. // The goroutine is killed whenever this context is canceled. @@ -176,8 +180,6 @@ func WithTokenEndpoint(url string) ConfigurationOption { } // WithServiceAccountEmail returns a ConfigurationOption that sets the service account email -// -// Deprecated: WithServiceAccountEmail is not required and will be removed after 12th June 2025. func WithServiceAccountEmail(serviceAccountEmail string) ConfigurationOption { return func(config *Configuration) error { config.ServiceAccountEmail = serviceAccountEmail @@ -237,6 +239,48 @@ func WithToken(token string) ConfigurationOption { } } +// WithWorkloadIdentityFederationAuth returns a ConfigurationOption that sets workload identity flow to be used for authentication in API calls +func WithWorkloadIdentityFederationAuth() ConfigurationOption { + return func(config *Configuration) error { + config.WorkloadIdentityFederation = true + return nil + } +} + +// WithWorkloadIdentityFederationFunc returns a ConfigurationOption that sets the function to get the federated token for workload identity federation flow +func WithWorkloadIdentityFederationFunc(function oidcadapters.OIDCTokenFunc) ConfigurationOption { + return func(config *Configuration) error { + config.ServiceAccountFederatedTokenFunc = function + return nil + } +} + +// WithWorkloadIdentityFederationPath returns a ConfigurationOption that sets the custom path to the federated token file for workload identity federation flow +func WithWorkloadIdentityFederationPath(path string) ConfigurationOption { + return func(config *Configuration) error { + config.ServiceAccountFederatedTokenFunc = oidcadapters.ReadJWTFromFileSystem(path) + return nil + } +} + +// WithWorkloadIdentityFederationFunc returns a ConfigurationOption that sets the id token for workload identity federation flow +func WithWorkloadIdentityFederationToken(token string) ConfigurationOption { + return func(config *Configuration) error { + config.ServiceAccountFederatedTokenFunc = func(context.Context) (string, error) { + return token, nil + } + return nil + } +} + +// WithWorkloadIdentityFederationTokenExpiration returns a ConfigurationOption that sets the token expiration for workload identity federation flow +func WithWorkloadIdentityFederationTokenExpiration(expiration string) ConfigurationOption { + return func(config *Configuration) error { + config.ServiceAccountFederatedTokenExpiration = expiration + return nil + } +} + // Deprecated: retry options were removed to reduce complexity of the client. If this functionality is needed, you can provide your own custom HTTP client. This option has no effect, and will be removed in a later update func WithMaxRetries(_ int) ConfigurationOption { return func(_ *Configuration) error { diff --git a/core/oidcadapters/filesystem.go b/core/oidcadapters/filesystem.go new file mode 100644 index 000000000..9947d4088 --- /dev/null +++ b/core/oidcadapters/filesystem.go @@ -0,0 +1,27 @@ +package oidcadapters + +import ( + "context" + "os" + + "github.com/golang-jwt/jwt/v5" +) + +var ( + parser *jwt.Parser = jwt.NewParser() +) + +func ReadJWTFromFileSystem(tokenFilePath string) OIDCTokenFunc { + return func(context.Context) (string, error) { + token, err := os.ReadFile(tokenFilePath) + if err != nil { + return "", err + } + tokenStr := string(token) + _, _, err = parser.ParseUnverified(tokenStr, jwt.MapClaims{}) + if err != nil { + return "", err + } + return tokenStr, nil + } +} diff --git a/core/oidcadapters/filesystem_test.go b/core/oidcadapters/filesystem_test.go new file mode 100644 index 000000000..998f0765c --- /dev/null +++ b/core/oidcadapters/filesystem_test.go @@ -0,0 +1,62 @@ +package oidcadapters + +import ( + "context" + "log" + "os" + "testing" +) + +func TestReadJWTFromFileSystem(t *testing.T) { + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer func() { + err := os.Remove(file.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + + token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.KMUFsIDTnFmyG3nMiGM6H9FNFUROf3wh7SmqJp-QV30" // nolint:gosec // This is a fake token for testing purposes only + err = os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + if err != nil { + t.Fatalf("Writing temporary file: %s", err) + } + _, err = ReadJWTFromFileSystem(file.Name())(context.Background()) + if err != nil { + t.Fatalf("Reading JWT from file system: %s", err) + } +} + +func TestReadRandomContentFromFileSystem(t *testing.T) { + file, err := os.CreateTemp("", "*.token") + if err != nil { + log.Fatal(err) + } + defer func() { + err := os.Remove(file.Name()) + if err != nil { + t.Fatalf("Removing temporary file: %s", err) + } + }() + + token := "invalid random content" + err = os.WriteFile(file.Name(), []byte(token), os.ModeAppend) + if err != nil { + t.Fatalf("Writing temporary file: %s", err) + } + + _, err = ReadJWTFromFileSystem(file.Name())(context.Background()) + if err == nil { + t.Fatalf("Reading JWT from file system must fail") + } +} + +func TestReadMissingFileFromFileSystem(t *testing.T) { + _, err := ReadJWTFromFileSystem("/path/to/nonexistent/file.token")(context.Background()) + if err == nil { + t.Fatalf("Reading JWT from file system must fail") + } +} diff --git a/core/oidcadapters/githubactions.go b/core/oidcadapters/githubactions.go new file mode 100644 index 000000000..589c74ca6 --- /dev/null +++ b/core/oidcadapters/githubactions.go @@ -0,0 +1,59 @@ +package oidcadapters + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" +) + +func RequestGHOIDCToken(oidcRequestUrl, oidcRequestToken string) OIDCTokenFunc { + return func(ctx context.Context) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, oidcRequestUrl, http.NoBody) + if err != nil { + return "", fmt.Errorf("githubAssertion: failed to build request: %w", err) + } + + query, err := url.ParseQuery(req.URL.RawQuery) + if err != nil { + return "", fmt.Errorf("githubAssertion: cannot parse URL query") + } + + if query.Get("audience") == "" { + query.Set("audience", "sts.accounts.stackit.cloud") + req.URL.RawQuery = query.Encode() + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oidcRequestToken)) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("githubAssertion: cannot request token: %w", err) + } + + defer func() { + _ = resp.Body.Close() + }() + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return "", fmt.Errorf("githubAssertion: cannot parse response: %w", err) + } + + if c := resp.StatusCode; c < 200 || c > 299 { + return "", fmt.Errorf("githubAssertion: received HTTP status %d with response: %s", resp.StatusCode, body) + } + + var tokenRes struct { + Value string `json:"value"` + } + if err := json.Unmarshal(body, &tokenRes); err != nil { + return "", fmt.Errorf("githubAssertion: cannot unmarshal response: %w", err) + } + + return tokenRes.Value, nil + } +} diff --git a/core/oidcadapters/types.go b/core/oidcadapters/types.go new file mode 100644 index 000000000..c1d92ef42 --- /dev/null +++ b/core/oidcadapters/types.go @@ -0,0 +1,5 @@ +package oidcadapters + +import "context" + +type OIDCTokenFunc func(context.Context) (string, error) diff --git a/core/utils/utils.go b/core/utils/utils.go index e36612c0e..9d5c8818e 100644 --- a/core/utils/utils.go +++ b/core/utils/utils.go @@ -1,5 +1,7 @@ package utils +import "os" + // Ptr Returns the pointer to any type T func Ptr[T any](v T) *T { return &v @@ -27,3 +29,10 @@ func EnumSliceToStringSlice[T ~string](inputSlice []T) []string { return result } + +func GetEnvOrDefault(envVar, defaultValue string) string { + if value := os.Getenv(envVar); value != "" { + return value + } + return defaultValue +}