diff --git a/api/auth_middleware.go b/api/auth_middleware.go index 06811e3c..c408d668 100644 --- a/api/auth_middleware.go +++ b/api/auth_middleware.go @@ -270,6 +270,7 @@ func (app *ApiServer) authMiddleware(c *fiber.Ctx) error { if entry, ok := app.lookupOAuthAccessToken(c, bearerToken); ok { if myId == 0 || entry.UserID == myId { wallet = strings.ToLower(entry.ClientID) + c.Locals("oauthScope", entry.Scope) if myId == 0 { myId = entry.UserID c.Locals("myId", int(entry.UserID)) @@ -339,6 +340,17 @@ func (app *ApiServer) requireAuthMiddleware(c *fiber.Ctx) error { return c.Next() } +// Middleware that asserts the request carries write scope when authenticated via +// an OAuth PKCE token. Non-OAuth auth methods (signature, api_access_key) are +// always allowed through. Must be placed after authMiddleware. +func (app *ApiServer) requireWriteScope(c *fiber.Ctx) error { + if scope, ok := c.Locals("oauthScope").(string); ok && scope != "" && scope != "write" { + return fiber.NewError(fiber.StatusForbidden, "OAuth token scope insufficient: write scope required") + } + + return c.Next() +} + // Get a user from their wallet address. // // Note: Do NOT use this with `getAuthedWallet()` to infer the current user. diff --git a/api/auth_middleware_test.go b/api/auth_middleware_test.go index 98d8d651..d7edebe7 100644 --- a/api/auth_middleware_test.go +++ b/api/auth_middleware_test.go @@ -315,6 +315,50 @@ func TestGetApiSignerWithApiAccessKey(t *testing.T) { "body %s should contain address %s", string(body), parentApiKey) } +func TestRequireWriteScope(t *testing.T) { + // requireWriteScope only reads c.Locals("oauthScope"), so no DB is needed. + app := &ApiServer{} + + // Create a dummy write route that chains requireWriteScope after a scope-setting middleware + testApp := fiber.New() + testApp.Post("/write", func(c *fiber.Ctx) error { + // Simulate what authMiddleware does: set oauthScope if a PKCE token was used + scope := c.Get("X-Test-Oauth-Scope") + if scope != "" { + c.Locals("oauthScope", scope) + } + return c.Next() + }, app.requireWriteScope, func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + // PKCE token with scope=read should be rejected (403) + t.Run("read scope rejected", func(t *testing.T) { + req := httptest.NewRequest("POST", "/write", nil) + req.Header.Set("X-Test-Oauth-Scope", "read") + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusForbidden, res.StatusCode) + }) + + // PKCE token with scope=write should be allowed (200) + t.Run("write scope allowed", func(t *testing.T) { + req := httptest.NewRequest("POST", "/write", nil) + req.Header.Set("X-Test-Oauth-Scope", "write") + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusOK, res.StatusCode) + }) + + // Non-OAuth auth (no oauthScope set) should pass through (200) + t.Run("non-oauth auth passes through", func(t *testing.T) { + req := httptest.NewRequest("POST", "/write", nil) + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusOK, res.StatusCode) + }) +} + // ensureApiKeysTables creates api_keys and api_access_keys if they do not exist. func ensureApiKeysTables(t *testing.T, app *ApiServer, ctx context.Context) { t.Helper() diff --git a/api/server.go b/api/server.go index 103e2ed8..5029e2d4 100644 --- a/api/server.go +++ b/api/server.go @@ -362,7 +362,7 @@ func NewApiServer(config config.Config) *ApiServer { for _, g := range []fiber.Router{v1, v1Full} { // Users g.Get("/users", app.v1Users) - g.Post("/users", app.requireAuthMiddleware, app.postV1Users) + g.Post("/users", app.requireAuthMiddleware, app.requireWriteScope, app.postV1Users) g.Get("/users/address", app.v1UserIdsByAddresses) g.Get("/users/search", app.v1UsersSearch) g.Get("/users/unclaimed_id", app.v1UsersUnclaimedId) @@ -395,11 +395,11 @@ func NewApiServer(config config.Config) *ApiServer { g.Get("/users/:userId/balance/history", app.v1UsersBalanceHistory) g.Get("/users/:userId/managers", app.v1UsersManagers) g.Get("/users/:userId/managed_users", app.v1UsersManagedUsers) - g.Post("/users/:userId/grants", app.requireAuthMiddleware, app.postV1UsersGrant) - g.Delete("/users/:userId/grants/:address", app.requireAuthMiddleware, app.deleteV1UsersGrant) - g.Post("/users/:userId/managers", app.requireAuthMiddleware, app.postV1UsersManager) - g.Delete("/users/:userId/managers/:managerUserId", app.requireAuthMiddleware, app.deleteV1UsersManager) - g.Post("/users/:userId/grants/approve", app.requireAuthMiddleware, app.postV1UsersApproveGrant) + g.Post("/users/:userId/grants", app.requireAuthMiddleware, app.requireWriteScope, app.postV1UsersGrant) + g.Delete("/users/:userId/grants/:address", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1UsersGrant) + g.Post("/users/:userId/managers", app.requireAuthMiddleware, app.requireWriteScope, app.postV1UsersManager) + g.Delete("/users/:userId/managers/:managerUserId", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1UsersManager) + g.Post("/users/:userId/grants/approve", app.requireAuthMiddleware, app.requireWriteScope, app.postV1UsersApproveGrant) g.Get("/users/:userId/mutuals", app.v1UsersMutuals) g.Get("/users/:userId/reposts", app.v1UsersReposts) g.Get("/users/:userId/related", app.v1UsersRelated) @@ -447,17 +447,17 @@ func NewApiServer(config config.Config) *ApiServer { g.Get("/users/:userId/developer-apps", app.v1UsersDeveloperApps) g.Get("/users/:userId/withdrawals/download", app.requireAuthForUserId, app.v1UsersWithdrawalsDownloadCsv) g.Get("/users/:userId/withdrawals/download/json", app.requireAuthForUserId, app.v1UsersWithdrawalsDownloadJson) - g.Post("/users/:userId/follow", app.requireAuthMiddleware, app.postV1UserFollow) - g.Delete("/users/:userId/follow", app.requireAuthMiddleware, app.deleteV1UserFollow) - g.Post("/users/:userId/subscribe", app.requireAuthMiddleware, app.postV1UserSubscribe) - g.Delete("/users/:userId/subscribe", app.requireAuthMiddleware, app.deleteV1UserSubscribe) - g.Post("/users/:userId/mute", app.requireAuthMiddleware, app.postV1UserMute) - g.Delete("/users/:userId/mute", app.requireAuthMiddleware, app.deleteV1UserMute) - g.Put("/users/:userId", app.requireAuthMiddleware, app.putV1User) + g.Post("/users/:userId/follow", app.requireAuthMiddleware, app.requireWriteScope, app.postV1UserFollow) + g.Delete("/users/:userId/follow", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1UserFollow) + g.Post("/users/:userId/subscribe", app.requireAuthMiddleware, app.requireWriteScope, app.postV1UserSubscribe) + g.Delete("/users/:userId/subscribe", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1UserSubscribe) + g.Post("/users/:userId/mute", app.requireAuthMiddleware, app.requireWriteScope, app.postV1UserMute) + g.Delete("/users/:userId/mute", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1UserMute) + g.Put("/users/:userId", app.requireAuthMiddleware, app.requireWriteScope, app.putV1User) // Tracks g.Get("/tracks", app.v1Tracks) - g.Post("/tracks", app.requireAuthMiddleware, app.postV1Tracks) + g.Post("/tracks", app.requireAuthMiddleware, app.requireWriteScope, app.postV1Tracks) g.Get("/tracks/search", app.v1TracksSearch) g.Get("/tracks/unclaimed_id", app.v1TracksUnclaimedId) @@ -482,16 +482,16 @@ func NewApiServer(config config.Config) *ApiServer { g.Get("/tracks/:trackId/access-info", app.v1TrackAccessInfo) g.Get("/tracks/:trackId/remixes", app.v1TrackRemixes) g.Get("/tracks/:trackId/reposts", app.v1TrackReposts) - g.Post("/tracks/:trackId/reposts", app.requireAuthMiddleware, app.postV1TrackRepost) - g.Delete("/tracks/:trackId/reposts", app.requireAuthMiddleware, app.deleteV1TrackRepost) + g.Post("/tracks/:trackId/reposts", app.requireAuthMiddleware, app.requireWriteScope, app.postV1TrackRepost) + g.Delete("/tracks/:trackId/reposts", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1TrackRepost) g.Get("/tracks/:trackId/stems", app.v1TrackStems) g.Get("/tracks/:trackId/favorites", app.v1TrackFavorites) - g.Post("/tracks/:trackId/favorites", app.requireAuthMiddleware, app.postV1TrackFavorite) - g.Delete("/tracks/:trackId/favorites", app.requireAuthMiddleware, app.deleteV1TrackFavorite) - g.Post("/tracks/:trackId/shares", app.requireAuthMiddleware, app.postV1TrackShare) - g.Post("/tracks/:trackId/downloads", app.requireAuthMiddleware, app.postV1TrackDownload) - g.Put("/tracks/:trackId", app.requireAuthMiddleware, app.putV1Track) - g.Delete("/tracks/:trackId", app.requireAuthMiddleware, app.deleteV1Track) + g.Post("/tracks/:trackId/favorites", app.requireAuthMiddleware, app.requireWriteScope, app.postV1TrackFavorite) + g.Delete("/tracks/:trackId/favorites", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1TrackFavorite) + g.Post("/tracks/:trackId/shares", app.requireAuthMiddleware, app.requireWriteScope, app.postV1TrackShare) + g.Post("/tracks/:trackId/downloads", app.requireAuthMiddleware, app.requireWriteScope, app.postV1TrackDownload) + g.Put("/tracks/:trackId", app.requireAuthMiddleware, app.requireWriteScope, app.putV1Track) + g.Delete("/tracks/:trackId", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1Track) g.Get("/tracks/:trackId/comments", app.v1TrackComments) g.Get("/tracks/:trackId/comment_count", app.v1TrackCommentCount) g.Get("/tracks/:trackId/comment-count", app.v1TrackCommentCount) @@ -504,7 +504,7 @@ func NewApiServer(config config.Config) *ApiServer { // Playlists g.Get("/playlists", app.v1Playlists) - g.Post("/playlists", app.requireAuthMiddleware, app.postV1Playlists) + g.Post("/playlists", app.requireAuthMiddleware, app.requireWriteScope, app.postV1Playlists) g.Get("/playlists/search", app.v1PlaylistsSearch) g.Get("/playlists/unclaimed_id", app.v1PlaylistsUnclaimedId) g.Get("/playlists/unclaimed-id", app.v1PlaylistsUnclaimedId) @@ -517,14 +517,14 @@ func NewApiServer(config config.Config) *ApiServer { g.Get("/playlists/:playlistId", app.v1Playlist) g.Get("/playlists/:playlistId/stream", app.v1PlaylistStream) g.Get("/playlists/:playlistId/reposts", app.v1PlaylistReposts) - g.Post("/playlists/:playlistId/reposts", app.requireAuthMiddleware, app.postV1PlaylistRepost) - g.Delete("/playlists/:playlistId/reposts", app.requireAuthMiddleware, app.deleteV1PlaylistRepost) + g.Post("/playlists/:playlistId/reposts", app.requireAuthMiddleware, app.requireWriteScope, app.postV1PlaylistRepost) + g.Delete("/playlists/:playlistId/reposts", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1PlaylistRepost) g.Get("/playlists/:playlistId/favorites", app.v1PlaylistFavorites) - g.Post("/playlists/:playlistId/favorites", app.requireAuthMiddleware, app.postV1PlaylistFavorite) - g.Delete("/playlists/:playlistId/favorites", app.requireAuthMiddleware, app.deleteV1PlaylistFavorite) - g.Post("/playlists/:playlistId/shares", app.requireAuthMiddleware, app.postV1PlaylistShare) - g.Put("/playlists/:playlistId", app.requireAuthMiddleware, app.putV1Playlist) - g.Delete("/playlists/:playlistId", app.requireAuthMiddleware, app.deleteV1Playlist) + g.Post("/playlists/:playlistId/favorites", app.requireAuthMiddleware, app.requireWriteScope, app.postV1PlaylistFavorite) + g.Delete("/playlists/:playlistId/favorites", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1PlaylistFavorite) + g.Post("/playlists/:playlistId/shares", app.requireAuthMiddleware, app.requireWriteScope, app.postV1PlaylistShare) + g.Put("/playlists/:playlistId", app.requireAuthMiddleware, app.requireWriteScope, app.putV1Playlist) + g.Delete("/playlists/:playlistId", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1Playlist) g.Get("/playlists/:playlistId/tracks", app.v1PlaylistTracks) // Explore @@ -538,18 +538,18 @@ func NewApiServer(config config.Config) *ApiServer { // Developer Apps g.Get("/developer_apps/:address", app.v1DeveloperApps) g.Get("/developer-apps/:address", app.v1DeveloperApps) - g.Post("/developer_apps", app.postV1UsersDeveloperApp) - g.Post("/developer-apps", app.postV1UsersDeveloperApp) - g.Put("/developer_apps/:address", app.putV1UsersDeveloperApp) - g.Put("/developer-apps/:address", app.putV1UsersDeveloperApp) - g.Delete("/developer_apps/:address", app.deleteV1UsersDeveloperApp) - g.Delete("/developer-apps/:address", app.deleteV1UsersDeveloperApp) - g.Post("/developer_apps/:address/access-keys/deactivate", app.postV1UsersDeveloperAppAccessKeyDeactivate) - g.Post("/developer-apps/:address/access-keys/deactivate", app.postV1UsersDeveloperAppAccessKeyDeactivate) - g.Post("/developer_apps/:address/register-api-key", app.requireAuthMiddleware, app.postV1UsersDeveloperAppRegisterApiKey) - g.Post("/developer-apps/:address/register-api-key", app.requireAuthMiddleware, app.postV1UsersDeveloperAppRegisterApiKey) - g.Post("/developer_apps/:address/access-keys", app.postV1UsersDeveloperAppAccessKey) - g.Post("/developer-apps/:address/access-keys", app.postV1UsersDeveloperAppAccessKey) + g.Post("/developer_apps", app.requireAuthMiddleware, app.requireWriteScope, app.postV1UsersDeveloperApp) + g.Post("/developer-apps", app.requireAuthMiddleware, app.requireWriteScope, app.postV1UsersDeveloperApp) + g.Put("/developer_apps/:address", app.requireAuthMiddleware, app.requireWriteScope, app.putV1UsersDeveloperApp) + g.Put("/developer-apps/:address", app.requireAuthMiddleware, app.requireWriteScope, app.putV1UsersDeveloperApp) + g.Delete("/developer_apps/:address", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1UsersDeveloperApp) + g.Delete("/developer-apps/:address", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1UsersDeveloperApp) + g.Post("/developer_apps/:address/access-keys/deactivate", app.requireAuthMiddleware, app.requireWriteScope, app.postV1UsersDeveloperAppAccessKeyDeactivate) + g.Post("/developer-apps/:address/access-keys/deactivate", app.requireAuthMiddleware, app.requireWriteScope, app.postV1UsersDeveloperAppAccessKeyDeactivate) + g.Post("/developer_apps/:address/register-api-key", app.requireAuthMiddleware, app.requireWriteScope, app.postV1UsersDeveloperAppRegisterApiKey) + g.Post("/developer-apps/:address/register-api-key", app.requireAuthMiddleware, app.requireWriteScope, app.postV1UsersDeveloperAppRegisterApiKey) + g.Post("/developer_apps/:address/access-keys", app.requireAuthMiddleware, app.requireWriteScope, app.postV1UsersDeveloperAppAccessKey) + g.Post("/developer-apps/:address/access-keys", app.requireAuthMiddleware, app.requireWriteScope, app.postV1UsersDeveloperAppAccessKey) // OAuth2 PKCE g.Get("/oauth/authorize", app.v1OAuthAuthorizeRedirect) @@ -573,15 +573,15 @@ func NewApiServer(config config.Config) *ApiServer { // Comments g.Get("/comments/unclaimed_id", app.v1CommentsUnclaimedId) g.Get("/comments/unclaimed-id", app.v1CommentsUnclaimedId) - g.Post("/comments", app.requireAuthMiddleware, app.postV1Comment) + g.Post("/comments", app.requireAuthMiddleware, app.requireWriteScope, app.postV1Comment) g.Get("/comments/:commentId", app.v1Comment) - g.Put("/comments/:commentId", app.requireAuthMiddleware, app.putV1Comment) - g.Delete("/comments/:commentId", app.requireAuthMiddleware, app.deleteV1Comment) - g.Post("/comments/:commentId/react", app.requireAuthMiddleware, app.postV1CommentReact) - g.Delete("/comments/:commentId/react", app.requireAuthMiddleware, app.deleteV1CommentReact) - g.Post("/comments/:commentId/pin", app.requireAuthMiddleware, app.postV1CommentPin) - g.Delete("/comments/:commentId/pin", app.requireAuthMiddleware, app.deleteV1CommentPin) - g.Post("/comments/:commentId/report", app.requireAuthMiddleware, app.postV1CommentReport) + g.Put("/comments/:commentId", app.requireAuthMiddleware, app.requireWriteScope, app.putV1Comment) + g.Delete("/comments/:commentId", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1Comment) + g.Post("/comments/:commentId/react", app.requireAuthMiddleware, app.requireWriteScope, app.postV1CommentReact) + g.Delete("/comments/:commentId/react", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1CommentReact) + g.Post("/comments/:commentId/pin", app.requireAuthMiddleware, app.requireWriteScope, app.postV1CommentPin) + g.Delete("/comments/:commentId/pin", app.requireAuthMiddleware, app.requireWriteScope, app.deleteV1CommentPin) + g.Post("/comments/:commentId/report", app.requireAuthMiddleware, app.requireWriteScope, app.postV1CommentReport) // Tips g.Get("/tips", app.v1Tips)