diff --git a/go/internal/e2e/multi_client_test.go b/go/internal/e2e/multi_client_test.go index 3c7dc34c3..406f118ce 100644 --- a/go/internal/e2e/multi_client_test.go +++ b/go/internal/e2e/multi_client_test.go @@ -16,11 +16,8 @@ import ( func TestMultiClient(t *testing.T) { // Use TCP mode so a second client can connect to the same CLI process ctx := testharness.NewTestContext(t) - client1 := copilot.NewClient(&copilot.ClientOptions{ - CLIPath: ctx.CLIPath, - Cwd: ctx.WorkDir, - Env: ctx.Env(), - UseStdio: copilot.Bool(false), + client1 := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) }) t.Cleanup(func() { client1.ForceStop() }) diff --git a/go/internal/e2e/testharness/context.go b/go/internal/e2e/testharness/context.go index b9edab1e5..1ec68d77e 100644 --- a/go/internal/e2e/testharness/context.go +++ b/go/internal/e2e/testharness/context.go @@ -158,7 +158,8 @@ func (c *TestContext) Env() []string { } // NewClient creates a CopilotClient configured for this test context. -func (c *TestContext) NewClient() *copilot.Client { +// Optional overrides can be applied to the default ClientOptions via the opts function. +func (c *TestContext) NewClient(opts ...func(*copilot.ClientOptions)) *copilot.Client { options := &copilot.ClientOptions{ CLIPath: c.CLIPath, Cwd: c.WorkDir, @@ -170,6 +171,10 @@ func (c *TestContext) NewClient() *copilot.Client { options.GitHubToken = "fake-token-for-e2e-tests" } + for _, opt := range opts { + opt(options) + } + return copilot.NewClient(options) } diff --git a/go/internal/jsonrpc2/jsonrpc2.go b/go/internal/jsonrpc2/jsonrpc2.go index fbc5b931c..8cf01e35a 100644 --- a/go/internal/jsonrpc2/jsonrpc2.go +++ b/go/internal/jsonrpc2/jsonrpc2.go @@ -214,9 +214,15 @@ func (c *Client) Request(method string, params any) (json.RawMessage, error) { } } - paramsData, err := json.Marshal(params) - if err != nil { - return nil, fmt.Errorf("failed to marshal params: %w", err) + var paramsData json.RawMessage + if params == nil { + paramsData = json.RawMessage("{}") + } else { + var err error + paramsData, err = json.Marshal(params) + if err != nil { + return nil, fmt.Errorf("failed to marshal params: %w", err) + } } // Send request @@ -224,7 +230,7 @@ func (c *Client) Request(method string, params any) (json.RawMessage, error) { JSONRPC: "2.0", ID: json.RawMessage(`"` + requestID + `"`), Method: method, - Params: json.RawMessage(paramsData), + Params: paramsData, } if err := c.sendMessage(request); err != nil { @@ -261,15 +267,19 @@ func (c *Client) Request(method string, params any) (json.RawMessage, error) { // Notify sends a JSON-RPC notification (no response expected) func (c *Client) Notify(method string, params any) error { - paramsData, err := json.Marshal(params) - if err != nil { - return fmt.Errorf("failed to marshal params: %w", err) + var paramsData json.RawMessage + if params != nil { + var err error + paramsData, err = json.Marshal(params) + if err != nil { + return fmt.Errorf("failed to marshal params: %w", err) + } } notification := Request{ JSONRPC: "2.0", Method: method, - Params: json.RawMessage(paramsData), + Params: paramsData, } return c.sendMessage(notification) } diff --git a/go/rpc/generated_rpc.go b/go/rpc/generated_rpc.go index b9ba408b5..f6232399c 100644 --- a/go/rpc/generated_rpc.go +++ b/go/rpc/generated_rpc.go @@ -102,7 +102,7 @@ type Tool struct { // tools) NamespacedName *string `json:"namespacedName,omitempty"` // JSON Schema for the tool's input parameters - Parameters map[string]interface{} `json:"parameters,omitempty"` + Parameters map[string]any `json:"parameters,omitempty"` } type ToolsListParams struct { @@ -453,10 +453,10 @@ type SessionToolsHandlePendingToolCallParams struct { } type ResultResult struct { - Error *string `json:"error,omitempty"` - ResultType *string `json:"resultType,omitempty"` - TextResultForLlm string `json:"textResultForLlm"` - ToolTelemetry map[string]interface{} `json:"toolTelemetry,omitempty"` + Error *string `json:"error,omitempty"` + ResultType *string `json:"resultType,omitempty"` + TextResultForLlm string `json:"textResultForLlm"` + ToolTelemetry map[string]any `json:"toolTelemetry,omitempty"` } type SessionCommandsHandlePendingCommandResult struct { @@ -539,11 +539,11 @@ type SessionPermissionsHandlePendingPermissionRequestParams struct { } type SessionPermissionsHandlePendingPermissionRequestParamsResult struct { - Kind Kind `json:"kind"` - Rules []interface{} `json:"rules,omitempty"` - Feedback *string `json:"feedback,omitempty"` - Message *string `json:"message,omitempty"` - Path *string `json:"path,omitempty"` + Kind Kind `json:"kind"` + Rules []any `json:"rules,omitempty"` + Feedback *string `json:"feedback,omitempty"` + Message *string `json:"message,omitempty"` + Path *string `json:"path,omitempty"` } type SessionLogResult struct { @@ -712,12 +712,14 @@ type Content struct { StringArray []string } -type ServerModelsRpcApi struct { +type serverApi struct { client *jsonrpc2.Client } -func (a *ServerModelsRpcApi) List(ctx context.Context) (*ModelsListResult, error) { - raw, err := a.client.Request("models.list", map[string]interface{}{}) +type ServerModelsApi serverApi + +func (a *ServerModelsApi) List(ctx context.Context) (*ModelsListResult, error) { + raw, err := a.client.Request("models.list", nil) if err != nil { return nil, err } @@ -728,11 +730,9 @@ func (a *ServerModelsRpcApi) List(ctx context.Context) (*ModelsListResult, error return &result, nil } -type ServerToolsRpcApi struct { - client *jsonrpc2.Client -} +type ServerToolsApi serverApi -func (a *ServerToolsRpcApi) List(ctx context.Context, params *ToolsListParams) (*ToolsListResult, error) { +func (a *ServerToolsApi) List(ctx context.Context, params *ToolsListParams) (*ToolsListResult, error) { raw, err := a.client.Request("tools.list", params) if err != nil { return nil, err @@ -744,12 +744,10 @@ func (a *ServerToolsRpcApi) List(ctx context.Context, params *ToolsListParams) ( return &result, nil } -type ServerAccountRpcApi struct { - client *jsonrpc2.Client -} +type ServerAccountApi serverApi -func (a *ServerAccountRpcApi) GetQuota(ctx context.Context) (*AccountGetQuotaResult, error) { - raw, err := a.client.Request("account.getQuota", map[string]interface{}{}) +func (a *ServerAccountApi) GetQuota(ctx context.Context) (*AccountGetQuotaResult, error) { + raw, err := a.client.Request("account.getQuota", nil) if err != nil { return nil, err } @@ -762,14 +760,15 @@ func (a *ServerAccountRpcApi) GetQuota(ctx context.Context) (*AccountGetQuotaRes // ServerRpc provides typed server-scoped RPC methods. type ServerRpc struct { - client *jsonrpc2.Client - Models *ServerModelsRpcApi - Tools *ServerToolsRpcApi - Account *ServerAccountRpcApi + common serverApi // Reuse a single struct instead of allocating one for each service on the heap. + + Models *ServerModelsApi + Tools *ServerToolsApi + Account *ServerAccountApi } func (a *ServerRpc) Ping(ctx context.Context, params *PingParams) (*PingResult, error) { - raw, err := a.client.Request("ping", params) + raw, err := a.common.client.Request("ping", params) if err != nil { return nil, err } @@ -781,20 +780,23 @@ func (a *ServerRpc) Ping(ctx context.Context, params *PingParams) (*PingResult, } func NewServerRpc(client *jsonrpc2.Client) *ServerRpc { - return &ServerRpc{client: client, - Models: &ServerModelsRpcApi{client: client}, - Tools: &ServerToolsRpcApi{client: client}, - Account: &ServerAccountRpcApi{client: client}, - } + r := &ServerRpc{} + r.common = serverApi{client: client} + r.Models = (*ServerModelsApi)(&r.common) + r.Tools = (*ServerToolsApi)(&r.common) + r.Account = (*ServerAccountApi)(&r.common) + return r } -type ModelRpcApi struct { +type sessionApi struct { client *jsonrpc2.Client sessionID string } -func (a *ModelRpcApi) GetCurrent(ctx context.Context) (*SessionModelGetCurrentResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +type ModelApi sessionApi + +func (a *ModelApi) GetCurrent(ctx context.Context) (*SessionModelGetCurrentResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.model.getCurrent", req) if err != nil { return nil, err @@ -806,8 +808,8 @@ func (a *ModelRpcApi) GetCurrent(ctx context.Context) (*SessionModelGetCurrentRe return &result, nil } -func (a *ModelRpcApi) SwitchTo(ctx context.Context, params *SessionModelSwitchToParams) (*SessionModelSwitchToResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *ModelApi) SwitchTo(ctx context.Context, params *SessionModelSwitchToParams) (*SessionModelSwitchToResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["modelId"] = params.ModelID if params.ReasoningEffort != nil { @@ -825,13 +827,10 @@ func (a *ModelRpcApi) SwitchTo(ctx context.Context, params *SessionModelSwitchTo return &result, nil } -type ModeRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +type ModeApi sessionApi -func (a *ModeRpcApi) Get(ctx context.Context) (*SessionModeGetResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *ModeApi) Get(ctx context.Context) (*SessionModeGetResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.mode.get", req) if err != nil { return nil, err @@ -843,8 +842,8 @@ func (a *ModeRpcApi) Get(ctx context.Context) (*SessionModeGetResult, error) { return &result, nil } -func (a *ModeRpcApi) Set(ctx context.Context, params *SessionModeSetParams) (*SessionModeSetResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *ModeApi) Set(ctx context.Context, params *SessionModeSetParams) (*SessionModeSetResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["mode"] = params.Mode } @@ -859,13 +858,10 @@ func (a *ModeRpcApi) Set(ctx context.Context, params *SessionModeSetParams) (*Se return &result, nil } -type PlanRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +type PlanApi sessionApi -func (a *PlanRpcApi) Read(ctx context.Context) (*SessionPlanReadResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *PlanApi) Read(ctx context.Context) (*SessionPlanReadResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.plan.read", req) if err != nil { return nil, err @@ -877,8 +873,8 @@ func (a *PlanRpcApi) Read(ctx context.Context) (*SessionPlanReadResult, error) { return &result, nil } -func (a *PlanRpcApi) Update(ctx context.Context, params *SessionPlanUpdateParams) (*SessionPlanUpdateResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *PlanApi) Update(ctx context.Context, params *SessionPlanUpdateParams) (*SessionPlanUpdateResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["content"] = params.Content } @@ -893,8 +889,8 @@ func (a *PlanRpcApi) Update(ctx context.Context, params *SessionPlanUpdateParams return &result, nil } -func (a *PlanRpcApi) Delete(ctx context.Context) (*SessionPlanDeleteResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *PlanApi) Delete(ctx context.Context) (*SessionPlanDeleteResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.plan.delete", req) if err != nil { return nil, err @@ -906,13 +902,10 @@ func (a *PlanRpcApi) Delete(ctx context.Context) (*SessionPlanDeleteResult, erro return &result, nil } -type WorkspaceRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +type WorkspaceApi sessionApi -func (a *WorkspaceRpcApi) ListFiles(ctx context.Context) (*SessionWorkspaceListFilesResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *WorkspaceApi) ListFiles(ctx context.Context) (*SessionWorkspaceListFilesResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.workspace.listFiles", req) if err != nil { return nil, err @@ -924,8 +917,8 @@ func (a *WorkspaceRpcApi) ListFiles(ctx context.Context) (*SessionWorkspaceListF return &result, nil } -func (a *WorkspaceRpcApi) ReadFile(ctx context.Context, params *SessionWorkspaceReadFileParams) (*SessionWorkspaceReadFileResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *WorkspaceApi) ReadFile(ctx context.Context, params *SessionWorkspaceReadFileParams) (*SessionWorkspaceReadFileResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["path"] = params.Path } @@ -940,8 +933,8 @@ func (a *WorkspaceRpcApi) ReadFile(ctx context.Context, params *SessionWorkspace return &result, nil } -func (a *WorkspaceRpcApi) CreateFile(ctx context.Context, params *SessionWorkspaceCreateFileParams) (*SessionWorkspaceCreateFileResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *WorkspaceApi) CreateFile(ctx context.Context, params *SessionWorkspaceCreateFileParams) (*SessionWorkspaceCreateFileResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["path"] = params.Path req["content"] = params.Content @@ -957,14 +950,11 @@ func (a *WorkspaceRpcApi) CreateFile(ctx context.Context, params *SessionWorkspa return &result, nil } -// Experimental: FleetRpcApi contains experimental APIs that may change or be removed. -type FleetRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +// Experimental: FleetApi contains experimental APIs that may change or be removed. +type FleetApi sessionApi -func (a *FleetRpcApi) Start(ctx context.Context, params *SessionFleetStartParams) (*SessionFleetStartResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *FleetApi) Start(ctx context.Context, params *SessionFleetStartParams) (*SessionFleetStartResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { if params.Prompt != nil { req["prompt"] = *params.Prompt @@ -981,14 +971,11 @@ func (a *FleetRpcApi) Start(ctx context.Context, params *SessionFleetStartParams return &result, nil } -// Experimental: AgentRpcApi contains experimental APIs that may change or be removed. -type AgentRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +// Experimental: AgentApi contains experimental APIs that may change or be removed. +type AgentApi sessionApi -func (a *AgentRpcApi) List(ctx context.Context) (*SessionAgentListResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *AgentApi) List(ctx context.Context) (*SessionAgentListResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.agent.list", req) if err != nil { return nil, err @@ -1000,8 +987,8 @@ func (a *AgentRpcApi) List(ctx context.Context) (*SessionAgentListResult, error) return &result, nil } -func (a *AgentRpcApi) GetCurrent(ctx context.Context) (*SessionAgentGetCurrentResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *AgentApi) GetCurrent(ctx context.Context) (*SessionAgentGetCurrentResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.agent.getCurrent", req) if err != nil { return nil, err @@ -1013,8 +1000,8 @@ func (a *AgentRpcApi) GetCurrent(ctx context.Context) (*SessionAgentGetCurrentRe return &result, nil } -func (a *AgentRpcApi) Select(ctx context.Context, params *SessionAgentSelectParams) (*SessionAgentSelectResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *AgentApi) Select(ctx context.Context, params *SessionAgentSelectParams) (*SessionAgentSelectResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["name"] = params.Name } @@ -1029,8 +1016,8 @@ func (a *AgentRpcApi) Select(ctx context.Context, params *SessionAgentSelectPara return &result, nil } -func (a *AgentRpcApi) Deselect(ctx context.Context) (*SessionAgentDeselectResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *AgentApi) Deselect(ctx context.Context) (*SessionAgentDeselectResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.agent.deselect", req) if err != nil { return nil, err @@ -1042,8 +1029,8 @@ func (a *AgentRpcApi) Deselect(ctx context.Context) (*SessionAgentDeselectResult return &result, nil } -func (a *AgentRpcApi) Reload(ctx context.Context) (*SessionAgentReloadResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *AgentApi) Reload(ctx context.Context) (*SessionAgentReloadResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.agent.reload", req) if err != nil { return nil, err @@ -1055,14 +1042,11 @@ func (a *AgentRpcApi) Reload(ctx context.Context) (*SessionAgentReloadResult, er return &result, nil } -// Experimental: SkillsRpcApi contains experimental APIs that may change or be removed. -type SkillsRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +// Experimental: SkillsApi contains experimental APIs that may change or be removed. +type SkillsApi sessionApi -func (a *SkillsRpcApi) List(ctx context.Context) (*SessionSkillsListResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *SkillsApi) List(ctx context.Context) (*SessionSkillsListResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.skills.list", req) if err != nil { return nil, err @@ -1074,8 +1058,8 @@ func (a *SkillsRpcApi) List(ctx context.Context) (*SessionSkillsListResult, erro return &result, nil } -func (a *SkillsRpcApi) Enable(ctx context.Context, params *SessionSkillsEnableParams) (*SessionSkillsEnableResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *SkillsApi) Enable(ctx context.Context, params *SessionSkillsEnableParams) (*SessionSkillsEnableResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["name"] = params.Name } @@ -1090,8 +1074,8 @@ func (a *SkillsRpcApi) Enable(ctx context.Context, params *SessionSkillsEnablePa return &result, nil } -func (a *SkillsRpcApi) Disable(ctx context.Context, params *SessionSkillsDisableParams) (*SessionSkillsDisableResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *SkillsApi) Disable(ctx context.Context, params *SessionSkillsDisableParams) (*SessionSkillsDisableResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["name"] = params.Name } @@ -1106,8 +1090,8 @@ func (a *SkillsRpcApi) Disable(ctx context.Context, params *SessionSkillsDisable return &result, nil } -func (a *SkillsRpcApi) Reload(ctx context.Context) (*SessionSkillsReloadResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *SkillsApi) Reload(ctx context.Context) (*SessionSkillsReloadResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.skills.reload", req) if err != nil { return nil, err @@ -1119,14 +1103,11 @@ func (a *SkillsRpcApi) Reload(ctx context.Context) (*SessionSkillsReloadResult, return &result, nil } -// Experimental: McpRpcApi contains experimental APIs that may change or be removed. -type McpRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +// Experimental: McpApi contains experimental APIs that may change or be removed. +type McpApi sessionApi -func (a *McpRpcApi) List(ctx context.Context) (*SessionMCPListResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *McpApi) List(ctx context.Context) (*SessionMCPListResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.mcp.list", req) if err != nil { return nil, err @@ -1138,8 +1119,8 @@ func (a *McpRpcApi) List(ctx context.Context) (*SessionMCPListResult, error) { return &result, nil } -func (a *McpRpcApi) Enable(ctx context.Context, params *SessionMCPEnableParams) (*SessionMCPEnableResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *McpApi) Enable(ctx context.Context, params *SessionMCPEnableParams) (*SessionMCPEnableResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["serverName"] = params.ServerName } @@ -1154,8 +1135,8 @@ func (a *McpRpcApi) Enable(ctx context.Context, params *SessionMCPEnableParams) return &result, nil } -func (a *McpRpcApi) Disable(ctx context.Context, params *SessionMCPDisableParams) (*SessionMCPDisableResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *McpApi) Disable(ctx context.Context, params *SessionMCPDisableParams) (*SessionMCPDisableResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["serverName"] = params.ServerName } @@ -1170,8 +1151,8 @@ func (a *McpRpcApi) Disable(ctx context.Context, params *SessionMCPDisableParams return &result, nil } -func (a *McpRpcApi) Reload(ctx context.Context) (*SessionMCPReloadResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *McpApi) Reload(ctx context.Context) (*SessionMCPReloadResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.mcp.reload", req) if err != nil { return nil, err @@ -1183,14 +1164,11 @@ func (a *McpRpcApi) Reload(ctx context.Context) (*SessionMCPReloadResult, error) return &result, nil } -// Experimental: PluginsRpcApi contains experimental APIs that may change or be removed. -type PluginsRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +// Experimental: PluginsApi contains experimental APIs that may change or be removed. +type PluginsApi sessionApi -func (a *PluginsRpcApi) List(ctx context.Context) (*SessionPluginsListResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *PluginsApi) List(ctx context.Context) (*SessionPluginsListResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.plugins.list", req) if err != nil { return nil, err @@ -1202,14 +1180,11 @@ func (a *PluginsRpcApi) List(ctx context.Context) (*SessionPluginsListResult, er return &result, nil } -// Experimental: ExtensionsRpcApi contains experimental APIs that may change or be removed. -type ExtensionsRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +// Experimental: ExtensionsApi contains experimental APIs that may change or be removed. +type ExtensionsApi sessionApi -func (a *ExtensionsRpcApi) List(ctx context.Context) (*SessionExtensionsListResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *ExtensionsApi) List(ctx context.Context) (*SessionExtensionsListResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.extensions.list", req) if err != nil { return nil, err @@ -1221,8 +1196,8 @@ func (a *ExtensionsRpcApi) List(ctx context.Context) (*SessionExtensionsListResu return &result, nil } -func (a *ExtensionsRpcApi) Enable(ctx context.Context, params *SessionExtensionsEnableParams) (*SessionExtensionsEnableResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *ExtensionsApi) Enable(ctx context.Context, params *SessionExtensionsEnableParams) (*SessionExtensionsEnableResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["id"] = params.ID } @@ -1237,8 +1212,8 @@ func (a *ExtensionsRpcApi) Enable(ctx context.Context, params *SessionExtensions return &result, nil } -func (a *ExtensionsRpcApi) Disable(ctx context.Context, params *SessionExtensionsDisableParams) (*SessionExtensionsDisableResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *ExtensionsApi) Disable(ctx context.Context, params *SessionExtensionsDisableParams) (*SessionExtensionsDisableResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["id"] = params.ID } @@ -1253,8 +1228,8 @@ func (a *ExtensionsRpcApi) Disable(ctx context.Context, params *SessionExtension return &result, nil } -func (a *ExtensionsRpcApi) Reload(ctx context.Context) (*SessionExtensionsReloadResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *ExtensionsApi) Reload(ctx context.Context) (*SessionExtensionsReloadResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.extensions.reload", req) if err != nil { return nil, err @@ -1266,14 +1241,11 @@ func (a *ExtensionsRpcApi) Reload(ctx context.Context) (*SessionExtensionsReload return &result, nil } -// Experimental: CompactionRpcApi contains experimental APIs that may change or be removed. -type CompactionRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +// Experimental: CompactionApi contains experimental APIs that may change or be removed. +type CompactionApi sessionApi -func (a *CompactionRpcApi) Compact(ctx context.Context) (*SessionCompactionCompactResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *CompactionApi) Compact(ctx context.Context) (*SessionCompactionCompactResult, error) { + req := map[string]any{"sessionId": a.sessionID} raw, err := a.client.Request("session.compaction.compact", req) if err != nil { return nil, err @@ -1285,13 +1257,10 @@ func (a *CompactionRpcApi) Compact(ctx context.Context) (*SessionCompactionCompa return &result, nil } -type ToolsRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +type ToolsApi sessionApi -func (a *ToolsRpcApi) HandlePendingToolCall(ctx context.Context, params *SessionToolsHandlePendingToolCallParams) (*SessionToolsHandlePendingToolCallResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *ToolsApi) HandlePendingToolCall(ctx context.Context, params *SessionToolsHandlePendingToolCallParams) (*SessionToolsHandlePendingToolCallResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["requestId"] = params.RequestID if params.Result != nil { @@ -1312,13 +1281,10 @@ func (a *ToolsRpcApi) HandlePendingToolCall(ctx context.Context, params *Session return &result, nil } -type CommandsRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +type CommandsApi sessionApi -func (a *CommandsRpcApi) HandlePendingCommand(ctx context.Context, params *SessionCommandsHandlePendingCommandParams) (*SessionCommandsHandlePendingCommandResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *CommandsApi) HandlePendingCommand(ctx context.Context, params *SessionCommandsHandlePendingCommandParams) (*SessionCommandsHandlePendingCommandResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["requestId"] = params.RequestID if params.Error != nil { @@ -1336,13 +1302,10 @@ func (a *CommandsRpcApi) HandlePendingCommand(ctx context.Context, params *Sessi return &result, nil } -type UiRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +type UiApi sessionApi -func (a *UiRpcApi) Elicitation(ctx context.Context, params *SessionUIElicitationParams) (*SessionUIElicitationResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *UiApi) Elicitation(ctx context.Context, params *SessionUIElicitationParams) (*SessionUIElicitationResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["message"] = params.Message req["requestedSchema"] = params.RequestedSchema @@ -1358,13 +1321,10 @@ func (a *UiRpcApi) Elicitation(ctx context.Context, params *SessionUIElicitation return &result, nil } -type PermissionsRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +type PermissionsApi sessionApi -func (a *PermissionsRpcApi) HandlePendingPermissionRequest(ctx context.Context, params *SessionPermissionsHandlePendingPermissionRequestParams) (*SessionPermissionsHandlePendingPermissionRequestResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *PermissionsApi) HandlePendingPermissionRequest(ctx context.Context, params *SessionPermissionsHandlePendingPermissionRequestParams) (*SessionPermissionsHandlePendingPermissionRequestResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["requestId"] = params.RequestID req["result"] = params.Result @@ -1380,13 +1340,10 @@ func (a *PermissionsRpcApi) HandlePendingPermissionRequest(ctx context.Context, return &result, nil } -type ShellRpcApi struct { - client *jsonrpc2.Client - sessionID string -} +type ShellApi sessionApi -func (a *ShellRpcApi) Exec(ctx context.Context, params *SessionShellExecParams) (*SessionShellExecResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *ShellApi) Exec(ctx context.Context, params *SessionShellExecParams) (*SessionShellExecResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["command"] = params.Command if params.Cwd != nil { @@ -1407,8 +1364,8 @@ func (a *ShellRpcApi) Exec(ctx context.Context, params *SessionShellExecParams) return &result, nil } -func (a *ShellRpcApi) Kill(ctx context.Context, params *SessionShellKillParams) (*SessionShellKillResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} +func (a *ShellApi) Kill(ctx context.Context, params *SessionShellKillParams) (*SessionShellKillResult, error) { + req := map[string]any{"sessionId": a.sessionID} if params != nil { req["processId"] = params.ProcessID if params.Signal != nil { @@ -1428,28 +1385,28 @@ func (a *ShellRpcApi) Kill(ctx context.Context, params *SessionShellKillParams) // SessionRpc provides typed session-scoped RPC methods. type SessionRpc struct { - client *jsonrpc2.Client - sessionID string - Model *ModelRpcApi - Mode *ModeRpcApi - Plan *PlanRpcApi - Workspace *WorkspaceRpcApi - Fleet *FleetRpcApi - Agent *AgentRpcApi - Skills *SkillsRpcApi - Mcp *McpRpcApi - Plugins *PluginsRpcApi - Extensions *ExtensionsRpcApi - Compaction *CompactionRpcApi - Tools *ToolsRpcApi - Commands *CommandsRpcApi - Ui *UiRpcApi - Permissions *PermissionsRpcApi - Shell *ShellRpcApi + common sessionApi // Reuse a single struct instead of allocating one for each service on the heap. + + Model *ModelApi + Mode *ModeApi + Plan *PlanApi + Workspace *WorkspaceApi + Fleet *FleetApi + Agent *AgentApi + Skills *SkillsApi + Mcp *McpApi + Plugins *PluginsApi + Extensions *ExtensionsApi + Compaction *CompactionApi + Tools *ToolsApi + Commands *CommandsApi + Ui *UiApi + Permissions *PermissionsApi + Shell *ShellApi } func (a *SessionRpc) Log(ctx context.Context, params *SessionLogParams) (*SessionLogResult, error) { - req := map[string]interface{}{"sessionId": a.sessionID} + req := map[string]any{"sessionId": a.common.sessionID} if params != nil { req["message"] = params.Message if params.Level != nil { @@ -1462,7 +1419,7 @@ func (a *SessionRpc) Log(ctx context.Context, params *SessionLogParams) (*Sessio req["url"] = *params.URL } } - raw, err := a.client.Request("session.log", req) + raw, err := a.common.client.Request("session.log", req) if err != nil { return nil, err } @@ -1474,22 +1431,23 @@ func (a *SessionRpc) Log(ctx context.Context, params *SessionLogParams) (*Sessio } func NewSessionRpc(client *jsonrpc2.Client, sessionID string) *SessionRpc { - return &SessionRpc{client: client, sessionID: sessionID, - Model: &ModelRpcApi{client: client, sessionID: sessionID}, - Mode: &ModeRpcApi{client: client, sessionID: sessionID}, - Plan: &PlanRpcApi{client: client, sessionID: sessionID}, - Workspace: &WorkspaceRpcApi{client: client, sessionID: sessionID}, - Fleet: &FleetRpcApi{client: client, sessionID: sessionID}, - Agent: &AgentRpcApi{client: client, sessionID: sessionID}, - Skills: &SkillsRpcApi{client: client, sessionID: sessionID}, - Mcp: &McpRpcApi{client: client, sessionID: sessionID}, - Plugins: &PluginsRpcApi{client: client, sessionID: sessionID}, - Extensions: &ExtensionsRpcApi{client: client, sessionID: sessionID}, - Compaction: &CompactionRpcApi{client: client, sessionID: sessionID}, - Tools: &ToolsRpcApi{client: client, sessionID: sessionID}, - Commands: &CommandsRpcApi{client: client, sessionID: sessionID}, - Ui: &UiRpcApi{client: client, sessionID: sessionID}, - Permissions: &PermissionsRpcApi{client: client, sessionID: sessionID}, - Shell: &ShellRpcApi{client: client, sessionID: sessionID}, - } + r := &SessionRpc{} + r.common = sessionApi{client: client, sessionID: sessionID} + r.Model = (*ModelApi)(&r.common) + r.Mode = (*ModeApi)(&r.common) + r.Plan = (*PlanApi)(&r.common) + r.Workspace = (*WorkspaceApi)(&r.common) + r.Fleet = (*FleetApi)(&r.common) + r.Agent = (*AgentApi)(&r.common) + r.Skills = (*SkillsApi)(&r.common) + r.Mcp = (*McpApi)(&r.common) + r.Plugins = (*PluginsApi)(&r.common) + r.Extensions = (*ExtensionsApi)(&r.common) + r.Compaction = (*CompactionApi)(&r.common) + r.Tools = (*ToolsApi)(&r.common) + r.Commands = (*CommandsApi)(&r.common) + r.Ui = (*UiApi)(&r.common) + r.Permissions = (*PermissionsApi)(&r.common) + r.Shell = (*ShellApi)(&r.common) + return r } diff --git a/scripts/codegen/go.ts b/scripts/codegen/go.ts index 59abee298..5c6a71b23 100644 --- a/scripts/codegen/go.ts +++ b/scripts/codegen/go.ts @@ -8,16 +8,16 @@ import { execFile } from "child_process"; import fs from "fs/promises"; -import { promisify } from "util"; import type { JSONSchema7 } from "json-schema"; import { FetchingJSONSchemaStore, InputData, JSONSchemaInput, quicktype } from "quicktype-core"; +import { promisify } from "util"; import { - getSessionEventsSchemaPath, getApiSchemaPath, + getSessionEventsSchemaPath, + isNodeFullyExperimental, + isRpcMethod, postProcessSchema, writeGeneratedFile, - isRpcMethod, - isNodeFullyExperimental, type ApiSchema, type RpcMethod, } from "./utils.js"; @@ -261,6 +261,8 @@ async function generateRpc(schemaPath?: string): Promise { } // Remove trailing blank lines from quicktype output before appending qtCode = qtCode.replace(/\n+$/, ""); + // Replace interface{} with any (quicktype emits the pre-1.18 form) + qtCode = qtCode.replace(/\binterface\{\}/g, "any"); // Build method wrappers const lines: string[] = []; @@ -301,9 +303,17 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio const topLevelMethods = Object.entries(node).filter(([, v]) => isRpcMethod(v)); const wrapperName = isSession ? "SessionRpc" : "ServerRpc"; - const apiSuffix = "RpcApi"; + const apiSuffix = "Api"; + const serviceName = isSession ? "sessionApi" : "serverApi"; + + // Emit the common service struct (unexported, shared by all API groups via type cast) + lines.push(`type ${serviceName} struct {`); + lines.push(`\tclient *jsonrpc2.Client`); + if (isSession) lines.push(`\tsessionID string`); + lines.push(`}`); + lines.push(``); - // Emit API structs for groups + // Emit API types for groups for (const [groupName, groupNode] of groups) { const prefix = isSession ? "" : "Server"; const apiName = prefix + toPascalCase(groupName) + apiSuffix; @@ -311,14 +321,7 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio if (groupExperimental) { lines.push(`// Experimental: ${apiName} contains experimental APIs that may change or be removed.`); } - lines.push(`type ${apiName} struct {`); - if (isSession) { - lines.push(`\tclient *jsonrpc2.Client`); - lines.push(`\tsessionID string`); - } else { - lines.push(`\tclient *jsonrpc2.Client`); - } - lines.push(`}`); + lines.push(`type ${apiName} ${serviceName}`); lines.push(``); for (const [key, value] of Object.entries(groupNode as Record)) { if (!isRpcMethod(value)) continue; @@ -328,15 +331,15 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio // Compute field name lengths for gofmt-compatible column alignment const groupPascalNames = groups.map(([g]) => toPascalCase(g)); - const allFieldNames = isSession ? ["client", "sessionID", ...groupPascalNames] : ["client", ...groupPascalNames]; + const allFieldNames = isSession ? ["common", ...groupPascalNames] : ["common", ...groupPascalNames]; const maxFieldLen = Math.max(...allFieldNames.map((n) => n.length)); const pad = (name: string) => name.padEnd(maxFieldLen); // Emit wrapper struct lines.push(`// ${wrapperName} provides typed ${isSession ? "session" : "server"}-scoped RPC methods.`); lines.push(`type ${wrapperName} struct {`); - lines.push(`\t${pad("client")} *jsonrpc2.Client`); - if (isSession) lines.push(`\t${pad("sessionID")} string`); + lines.push(`\t${pad("common")} ${serviceName} // Reuse a single struct instead of allocating one for each service on the heap.`); + lines.push(``); for (const [groupName] of groups) { const prefix = isSession ? "" : "Server"; lines.push(`\t${pad(toPascalCase(groupName))} *${prefix}${toPascalCase(groupName)}${apiSuffix}`); @@ -344,34 +347,31 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio lines.push(`}`); lines.push(``); - // Top-level methods (server only) + // Top-level methods on the wrapper use the common service fields for (const [key, value] of topLevelMethods) { if (!isRpcMethod(value)) continue; - emitMethod(lines, wrapperName, key, value, isSession, resolveType, fieldNames, false); + emitMethod(lines, wrapperName, key, value, isSession, resolveType, fieldNames, false, true); } - // Compute key alignment for constructor composite literal (gofmt aligns key: value) - const maxKeyLen = Math.max(...groupPascalNames.map((n) => n.length + 1)); // +1 for colon - const padKey = (name: string) => (name + ":").padEnd(maxKeyLen + 1); // +1 for min trailing space - // Constructor const ctorParams = isSession ? "client *jsonrpc2.Client, sessionID string" : "client *jsonrpc2.Client"; - const ctorFields = isSession ? "client: client, sessionID: sessionID," : "client: client,"; lines.push(`func New${wrapperName}(${ctorParams}) *${wrapperName} {`); - lines.push(`\treturn &${wrapperName}{${ctorFields}`); + lines.push(`\tr := &${wrapperName}{}`); + if (isSession) { + lines.push(`\tr.common = ${serviceName}{client: client, sessionID: sessionID}`); + } else { + lines.push(`\tr.common = ${serviceName}{client: client}`); + } for (const [groupName] of groups) { const prefix = isSession ? "" : "Server"; - const apiInit = isSession - ? `&${toPascalCase(groupName)}${apiSuffix}{client: client, sessionID: sessionID}` - : `&${prefix}${toPascalCase(groupName)}${apiSuffix}{client: client}`; - lines.push(`\t\t${padKey(toPascalCase(groupName))}${apiInit},`); + lines.push(`\tr.${toPascalCase(groupName)} = (*${prefix}${toPascalCase(groupName)}${apiSuffix})(&r.common)`); } - lines.push(`\t}`); + lines.push(`\treturn r`); lines.push(`}`); lines.push(``); } -function emitMethod(lines: string[], receiver: string, name: string, method: RpcMethod, isSession: boolean, resolveType: (name: string) => string, fieldNames: Map>, groupExperimental = false): void { +function emitMethod(lines: string[], receiver: string, name: string, method: RpcMethod, isSession: boolean, resolveType: (name: string) => string, fieldNames: Map>, groupExperimental = false, isWrapper = false): void { const methodName = toPascalCase(name); const resultType = resolveType(toPascalCase(method.rpcMethod) + "Result"); @@ -381,6 +381,10 @@ function emitMethod(lines: string[], receiver: string, name: string, method: Rpc const hasParams = isSession ? nonSessionParams.length > 0 : Object.keys(paramProps).length > 0; const paramsType = hasParams ? resolveType(toPascalCase(method.rpcMethod) + "Params") : ""; + // For wrapper-level methods, access fields through a.common; for service type aliases, use a directly + const clientRef = isWrapper ? "a.common.client" : "a.client"; + const sessionIDRef = isWrapper ? "a.common.sessionID" : "a.sessionID"; + if (method.stability === "experimental" && !groupExperimental) { lines.push(`// Experimental: ${methodName} is an experimental API and may change or be removed in future versions.`); } @@ -391,7 +395,7 @@ function emitMethod(lines: string[], receiver: string, name: string, method: Rpc lines.push(sig + ` {`); if (isSession) { - lines.push(`\treq := map[string]interface{}{"sessionId": a.sessionID}`); + lines.push(`\treq := map[string]any{"sessionId": ${sessionIDRef}}`); if (hasParams) { lines.push(`\tif params != nil {`); for (const pName of nonSessionParams) { @@ -408,10 +412,10 @@ function emitMethod(lines: string[], receiver: string, name: string, method: Rpc } lines.push(`\t}`); } - lines.push(`\traw, err := a.client.Request("${method.rpcMethod}", req)`); + lines.push(`\traw, err := ${clientRef}.Request("${method.rpcMethod}", req)`); } else { - const arg = hasParams ? "params" : "map[string]interface{}{}"; - lines.push(`\traw, err := a.client.Request("${method.rpcMethod}", ${arg})`); + const arg = hasParams ? "params" : "nil"; + lines.push(`\traw, err := ${clientRef}.Request("${method.rpcMethod}", ${arg})`); } lines.push(`\tif err != nil {`);