diff --git a/internal/slack/slack.go b/internal/slack/slack.go index 62bd79c..86fbbb6 100644 --- a/internal/slack/slack.go +++ b/internal/slack/slack.go @@ -3,6 +3,7 @@ package slack import ( "errors" "fmt" + "time" "github.com/elliotchance/pie/v2" "github.com/rs/zerolog/log" @@ -14,7 +15,14 @@ type IService interface { } type service struct { - client iclient + client iclient + maxAttempts int + initialBackoff time.Duration +} + +type conversationsResult struct { + Channels []slack.Channel + NextCursor string } // New creates a new Slack service @@ -24,19 +32,26 @@ func New(token string, debug bool) (IService, error) { return nil, errors.New("failed to create slack client") } - s := service{&client{client: slackClient}} + s := service{ + client: &client{client: slackClient}, + maxAttempts: 5, + initialBackoff: 2 * time.Second, + } return &s, nil } // PostMessage posts a message to the given slack channel func (s *service) PostMessage(channelName string, options ...slack.MsgOption) (ts string, err error) { - channel, err := s.findSlackChannel(channelName) + channel, err := runWithRetries(func() (*slack.Channel, error) { return s.findSlackChannel(channelName) }, s.maxAttempts, s.initialBackoff) if err != nil { return } - _, ts, err = s.client.PostMessage(channel.ID, options...) + ts, err = runWithRetries(func() (string, error) { + _, msgTs, err := s.client.PostMessage(channel.ID, options...) 
+ return msgTs, err + }, s.maxAttempts, s.initialBackoff) if err != nil { return ts, errors.Join(errors.New("failed to post slack message"), err) } @@ -54,15 +69,25 @@ func (s *service) findSlackChannel(channelName string) (channel *slack.Channel, var channelTypes = []string{"private_channel", "public_channel"} for { - if channels, nextCursor, err = s.client.GetConversations(&slack.GetConversationsParameters{ - ExcludeArchived: true, - Cursor: nextCursor, - Types: channelTypes, - Limit: 1000, - }); err != nil { - return nil, errors.Join(errors.New("failed to get slack channel list"), err) + result, opErr := runWithRetries(func() (conversationsResult, error) { + convChannels, convCursor, convErr := s.client.GetConversations(&slack.GetConversationsParameters{ + ExcludeArchived: true, + Cursor: nextCursor, + Types: channelTypes, + Limit: 1000, + }) + if convErr != nil { + return conversationsResult{}, convErr + } + return conversationsResult{Channels: convChannels, NextCursor: convCursor}, nil + }, s.maxAttempts, s.initialBackoff) + if opErr != nil { + return nil, errors.Join(errors.New("failed to get slack channel list"), opErr) } + channels = result.Channels + nextCursor = result.NextCursor + idx := pie.FindFirstUsing(channels, func(c slack.Channel) bool { return c.Name == channelName }) if idx > -1 { log.Info().Str("channel", channelName).Msg("Found slack channel") @@ -75,3 +100,46 @@ func (s *service) findSlackChannel(channelName string) (channel *slack.Channel, log.Debug().Str("channel", channelName).Str("nextPage", nextCursor).Msg("Channel not found in current page, fetching next page") } } + +func runWithRetries[T any](operation func() (T, error), maxAttempts int, backoff time.Duration) (result T, err error) { + if maxAttempts <= 0 { + maxAttempts = 1 + } + + for attempt := 1; attempt <= maxAttempts; attempt++ { + result, err = operation() + if err == nil { + return result, nil + } + + if attempt == maxAttempts { + break + } + + var sleepDuration time.Duration 
+ var rateLimitErr *slack.RateLimitedError + + if errors.As(err, &rateLimitErr) { + // Override the standard backoff with Slack's requested wait time + if rateLimitErr.RetryAfter > 0 { + sleepDuration = rateLimitErr.RetryAfter + } else { + // Use exponential backoff: backoff * 2^(attempt-1) + sleepDuration = backoff * time.Duration(1<<(attempt-1)) + } + + log.Warn(). + Err(err). + Int("attempt", attempt). + Dur("retry_after", sleepDuration). + Msg("Hit Slack rate limit, backing off dynamically") + } else { + sleepDuration = backoff * time.Duration(1<<(attempt-1)) + log.Warn().Err(err).Int("attempt", attempt).Dur("backoff", sleepDuration).Msg("Operation failed, retrying with exponential backoff") + } + + time.Sleep(sleepDuration) + } + + return result, fmt.Errorf("operation failed after %d attempts: %w", maxAttempts, err) +} diff --git a/internal/slack/slack_test.go b/internal/slack/slack_test.go index c4a1cf5..2bc96c7 100644 --- a/internal/slack/slack_test.go +++ b/internal/slack/slack_test.go @@ -1,7 +1,9 @@ package slack import ( + "errors" "testing" + "time" "github.com/slack-go/slack" "github.com/stretchr/testify/assert" @@ -35,7 +37,11 @@ func TestPostMessage(t *testing.T) { ) mockClient.On("PostMessage", channelID, mock.Anything).Return("", "", nil) - svc := service{&mockClient} + svc := service{ + client: &mockClient, + maxAttempts: 3, + initialBackoff: 2 * time.Second, + } _, err := svc.PostMessage(channelName, message) @@ -66,7 +72,11 @@ func TestFindSlackChannel(t *testing.T) { nil, ) - svc := service{&mockClient} + svc := service{ + client: &mockClient, + maxAttempts: 3, + initialBackoff: 2 * time.Second, + } channel, err := svc.findSlackChannel(channelName) @@ -89,3 +99,89 @@ func (c *mockClient) GetConversations(params *slack.GetConversationsParameters) args := c.Called(params) return args.Get(0).([]slack.Channel), args.String(1), args.Error(2) } + +// TestPostMessageWithRateLimitRetry verifies retry happens on rate limit errors +func 
TestPostMessageWithRateLimitRetry(t *testing.T) { + channelID := "test-channel" + + mockClient := mockClient{} + mockClient.On("GetConversations", mock.Anything).Return( + []slack.Channel{ + { + GroupConversation: slack.GroupConversation{ + Conversation: slack.Conversation{ID: channelID}, + Name: "test-channel", + }, + }, + }, + "", + nil, + ) + // First call fails with a generic error (not a *slack.RateLimitedError) + mockClient.On("PostMessage", channelID, mock.Anything).Return("", "", errors.New("error: rate limit")).Once() + // Second call succeeds + mockClient.On("PostMessage", channelID, mock.Anything).Return("", "ts123", nil).Once() + + // Create service with minimal backoff for testing (1 microsecond instead of 2s) + svc := service{ + client: &mockClient, + maxAttempts: 3, + initialBackoff: 1 * time.Microsecond, + } + + start := time.Now() + ts, err := svc.PostMessage(channelID, slack.MsgOptionText("test", false)) + + assert.Nil(t, err) + assert.Equal(t, "ts123", ts) + mockClient.AssertExpectations(t) + + // Verify it waited (at least the backoff time, which is now 1 microsecond) + elapsed := time.Since(start) + assert.GreaterOrEqual(t, elapsed, 1*time.Microsecond, "should have waited for backoff") +} +func TestPostMessageWithDynamicRateLimitRetry(t *testing.T) { + channelID := "test-channel" + mockClient := mockClient{} + + // Setup mock channel resolution + mockClient.On("GetConversations", mock.Anything).Return( + []slack.Channel{ + { + GroupConversation: slack.GroupConversation{ + Conversation: slack.Conversation{ID: channelID}, + Name: "test-channel", + }, + }, + }, + "", + nil, + ) + + expectedWait := 50 * time.Millisecond + rateLimitErr := &slack.RateLimitedError{ + RetryAfter: expectedWait, + } + + mockClient.On("PostMessage", channelID, mock.Anything). + Return("", "", rateLimitErr).Once() + + mockClient.On("PostMessage", channelID, mock.Anything). 
+ Return("", "ts123", nil).Once() + + svc := service{ + client: &mockClient, + maxAttempts: 3, + initialBackoff: 1 * time.Millisecond, + } + + start := time.Now() + ts, err := svc.PostMessage(channelID, slack.MsgOptionText("test", false)) + + assert.Nil(t, err) + assert.Equal(t, "ts123", ts) + mockClient.AssertExpectations(t) + + elapsed := time.Since(start) + assert.GreaterOrEqual(t, elapsed, expectedWait, "should have used Slack's dynamic RetryAfter backoff") +} diff --git a/lgtm.toml b/lgtm.toml index a71f55f..352360d 100644 --- a/lgtm.toml +++ b/lgtm.toml @@ -1,7 +1,7 @@ technologies = ["Golang"] categories = ["Correctness", "Quality", "Testing", "Security"] exclude = ["go.mod", "go.sum"] -model = "gemini-2.5-flash-preview-*" +model = "gemini-2.5-pro" silent = false publish = true ai_retries = 2