From cab21861d1bdbed04473afab055746a238d0a1dd Mon Sep 17 00:00:00 2001 From: David Gageot Date: Sat, 21 Mar 2026 20:12:01 +0100 Subject: [PATCH 1/4] Refactor RAG from agent-level config to standard toolset type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the special-cased RAG configuration (AgentConfig.RAG []string referencing top-level Config.RAG map) with a standard toolset type (type: rag) that follows the same patterns as MCP and other toolsets. Key changes: - Add 'type: rag' to toolset registry with ref support for shared definitions - Introduce RAGToolset wrapper (mirrors MCPToolset) for top-level rag section - Add RAGConfig field to Toolset for inline/resolved RAG configuration - Add resolveRAGDefinitions() mirroring resolveMCPDefinitions() - Extract rag.NewManager() for per-toolset manager creation - Implement tools.Startable on RAGTool for lazy init and file watching - Remove RAG special-casing from Team, LocalRuntime, and teamloader - Add v6→v7 migration for old rag agent field to toolset entries - Update schema, docs, and all example YAML files Assisted-By: docker-agent --- agent-schema.json | 70 +++++++++++++--- docs/features/rag/index.md | 4 +- examples/rag.yaml | 8 +- examples/rag/bm25.yaml | 6 +- examples/rag/hybrid.yaml | 5 +- examples/rag/reranking.yaml | 5 +- examples/rag/semantic_embeddings.yaml | 8 +- pkg/config/config.go | 4 + pkg/config/latest/parse.go | 29 +++++++ pkg/config/latest/types.go | 115 +++++++++++++++++++++++--- pkg/config/latest/validate.go | 12 ++- pkg/config/overrides.go | 7 +- pkg/config/rags.go | 54 ++++++++++++ pkg/rag/builder.go | 96 ++++++++++++--------- pkg/runtime/rag.go | 108 ++---------------------- pkg/runtime/runtime.go | 2 - pkg/runtime/runtime_test.go | 100 ---------------------- pkg/team/team.go | 19 ----- pkg/teamloader/registry.go | 23 ++++++ pkg/teamloader/teamloader.go | 59 +------------ pkg/tools/builtin/rag.go | 27 ++++++ 21 files changed, 407 insertions(+), 354 deletions(-) create mode 100644 pkg/config/rags.go diff --git a/agent-schema.json b/agent-schema.json index fd2bbc76b..7622d8b86 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -59,9 +59,9 @@ }, "rag": { "type": "object", - "description": "Map of RAG (Retrieval-Augmented Generation) configurations", + "description": "Map of reusable RAG source definitions. Define RAG sources here and reference them by name from agent toolsets to avoid duplication.", "additionalProperties": { - "$ref": "#/definitions/RAGConfig" + "$ref": "#/definitions/RAGToolset" } }, "metadata": { @@ -299,13 +299,6 @@ ], "additionalProperties": false }, - "rag": { - "type": "array", - "description": "List of RAG sources to use for this agent", - "items": { - "type": "string" - } - }, "add_description_parameter": { "type": "boolean", "description": "Whether to add a 'description' parameter to tool calls, allowing the LLM to provide context about why it is calling a tool" @@ -807,6 +800,51 @@ ], "additionalProperties": false }, + "RAGToolset": { + "type": "object", + "description": "Reusable RAG source definition. Define once at the top level and reference by name from agent toolsets. RAG config fields (tool, docs, strategies, results, respect_vcs) are specified directly alongside toolset fields.", + "allOf": [ + { + "$ref": "#/definitions/RAGConfig" + }, + { + "type": "object", + "properties": { + "instruction": { + "type": "string", + "description": "Custom instruction for this RAG source" + }, + "tools": { + "type": "array", + "description": "Optional list of tools to expose", + "items": { + "type": "string" + } + }, + "name": { + "type": "string", + "description": "Optional display name override for the RAG tool" + }, + "defer": { + "description": "Deferred loading configuration", + "oneOf": [ + { + "type": "boolean", + "description": "Set to true to defer all tools" + }, + { + "type": "array", + "description": "Array of tool names to defer", + "items": { + "type": "string" + } + } + ] + } + } + } + ] + }, "Toolset": { "type": "object", "description": "Tool configuration", @@ -830,7 +868,8 @@ "user_prompt", "openapi", "model_picker", - "background_agents" + "background_agents", + "rag" ] }, "instruction": { @@ -910,6 +949,10 @@ "$ref": "#/definitions/ApiConfig", "description": "API tool configuration" }, + "rag_config": { + "$ref": "#/definitions/RAGConfig", + "description": "RAG configuration for type: rag toolsets" + }, "ignore_vcs": { "type": "boolean", "description": "Whether to ignore VCS files (.git directories and .gitignore patterns) in filesystem operations. Default: true", @@ -1119,6 +1162,13 @@ "const": "background_agents" } } + }, + { + "properties": { + "type": { + "const": "rag" + } + } } ] }, diff --git a/docs/features/rag/index.md b/docs/features/rag/index.md index 0172dd71f..011b5f180 100644 --- a/docs/features/rag/index.md +++ b/docs/features/rag/index.md @@ -36,7 +36,9 @@ agents: model: openai/gpt-4o instruction: | You have access to a knowledge base. Use it to answer questions. - rag: [my_docs] + toolsets: + - type: rag + ref: my_docs ``` ## Retrieval Strategies diff --git a/examples/rag.yaml b/examples/rag.yaml index 5b1b423df..52955f10f 100644 --- a/examples/rag.yaml +++ b/examples/rag.yaml @@ -1,3 +1,4 @@ + agents: root: model: gpt-5-minimal @@ -7,8 +8,9 @@ agents: can use when it makes sense to do so, based on the user's question. If you receive sources from the knowledge base, always include them as a markdown list of links to local files at the very end of your response. - rag: - - blork_knowledge_base + toolsets: + - type: rag + ref: blork_knowledge_base models: gpt-5-minimal: @@ -27,4 +29,4 @@ rag: - type: chunked-embeddings embedding_model: openai/text-embedding-3-small database: ./rag/chunked_embeddings.db - vector_dimensions: 1536 \ No newline at end of file + vector_dimensions: 1536 diff --git a/examples/rag/bm25.yaml b/examples/rag/bm25.yaml index a79200cc4..441dd1ab7 100644 --- a/examples/rag/bm25.yaml +++ b/examples/rag/bm25.yaml @@ -1,3 +1,4 @@ + agents: root: model: openai/gpt-5-mini @@ -5,8 +6,9 @@ agents: instruction: | You are a helpful assistant that uses BM25 keyword-based search to find relevant information in documents. - rag: - - blork_knowledge_base + toolsets: + - type: rag + ref: blork_knowledge_base rag: blork_knowledge_base: diff --git a/examples/rag/hybrid.yaml b/examples/rag/hybrid.yaml index 2cde29acb..16e163914 100644 --- a/examples/rag/hybrid.yaml +++ b/examples/rag/hybrid.yaml @@ -13,8 +13,9 @@ agents: instruction: | You are a helpful assistant with access to hybrid retrieval combining semantic and keyword search for comprehensive results. - rag: - - knowledge_base + toolsets: + - type: rag + ref: knowledge_base rag: knowledge_base: diff --git a/examples/rag/reranking.yaml b/examples/rag/reranking.yaml index 5ee88d675..4503dec56 100644 --- a/examples/rag/reranking.yaml +++ b/examples/rag/reranking.yaml @@ -18,8 +18,9 @@ agents: instruction: | You are a helpful assistant with access to hybrid retrieval combining semantic and keyword search for comprehensive results. - rag: - - knowledge_base + toolsets: + - type: rag + ref: knowledge_base rag: knowledge_base: diff --git a/examples/rag/semantic_embeddings.yaml b/examples/rag/semantic_embeddings.yaml index 2f2a8ea6b..e754046e8 100644 --- a/examples/rag/semantic_embeddings.yaml +++ b/examples/rag/semantic_embeddings.yaml @@ -17,8 +17,9 @@ agents: instruction: | You are a helpful coding assistant with access to semantic code search. Use the search tool to find relevant code based on meaning, not just keywords. - rag: - - codebase + toolsets: + - type: rag + ref: codebase rag: codebase: @@ -78,7 +79,7 @@ rag: chunking: size: 1000 respect_word_boundaries: true - code_aware: true # Use tree-sitter for AST-aware chunking + code_aware: true # Use tree-sitter for AST-based chunking results: # Optional: rerank results using an LLM for better relevance @@ -94,4 +95,3 @@ rag: deduplicate: true return_full_content: false # return full document content instead of just the matched chunks limit: 5 - diff --git a/pkg/config/config.go b/pkg/config/config.go index a618fc7ec..ea5d15a22 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -124,6 +124,10 @@ func validateConfig(cfg *latest.Config) error { return err } + if err := resolveRAGDefinitions(cfg); err != nil { + return err + } + allNames := map[string]bool{} for _, agent := range cfg.Agents { allNames[agent.Name] = true diff --git a/pkg/config/latest/parse.go b/pkg/config/latest/parse.go index 595422201..4b3af7edb 100644 --- a/pkg/config/latest/parse.go +++ b/pkg/config/latest/parse.go @@ -26,5 +26,34 @@ func upgradeIfNeeded(c any, _ []byte) (any, error) { var config Config types.CloneThroughJSON(old, &config) + + // Migrate AgentConfig.RAG []string → toolsets with type: rag + ref + for i, agent := range old.Agents { + if len(agent.RAG) == 0 { + continue + } + for _, ragName := range agent.RAG { + config.Agents[i].Toolsets = append(config.Agents[i].Toolsets, Toolset{ + Type: "rag", + Ref: ragName, + }) + } + } + + // Migrate top-level RAG map from RAGConfig to RAGToolset + if len(old.RAG) > 0 && config.RAG == nil { + config.RAG = make(map[string]RAGToolset) + } + for name, oldRAG := range old.RAG { + var ragCfg RAGConfig + types.CloneThroughJSON(oldRAG, &ragCfg) + config.RAG[name] = RAGToolset{ + Toolset: Toolset{ + Type: "rag", + RAGConfig: &ragCfg, + }, + } + } + return config, nil } diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index 53eca84b5..2396939a1 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -25,7 +25,7 @@ type Config struct { Providers map[string]ProviderConfig `json:"providers,omitempty"` Models map[string]ModelConfig `json:"models,omitempty"` MCPs map[string]MCPToolset `json:"mcps,omitempty"` - RAG map[string]RAGConfig `json:"rag,omitempty"` + RAG map[string]RAGToolset `json:"rag,omitempty"` Metadata Metadata `json:"metadata"` Permissions *PermissionsConfig `json:"permissions,omitempty"` } @@ -52,6 +52,96 @@ func (m *MCPToolset) UnmarshalYAML(unmarshal func(any) error) error { return m.validate() } +// RAGToolset is a reusable RAG source definition stored in the top-level +// "rag" section. It is identical to a Toolset but skips the normal +// Toolset.validate() call during YAML unmarshaling because the "type" +// field is implicit (always "rag") and the RAG config is validated +// during config resolution. +type RAGToolset struct { + Toolset `json:",inline" yaml:",inline"` +} + +func (r RAGToolset) MarshalYAML() (any, error) { + // Flatten RAGConfig fields alongside toolset fields into a single map. + result := make(map[string]any) + + if r.Instruction != "" { + result["instruction"] = r.Instruction + } + if len(r.Tools) > 0 { + result["tools"] = r.Tools + } + if r.Name != "" { + result["name"] = r.Name + } + if !r.Defer.IsEmpty() { + result["defer"] = r.Defer + } + + if r.RAGConfig != nil { + cfg := r.RAGConfig + result["tool"] = cfg.Tool + if len(cfg.Docs) > 0 { + result["docs"] = cfg.Docs + } + if cfg.RespectVCS != nil { + result["respect_vcs"] = *cfg.RespectVCS + } + if len(cfg.Strategies) > 0 { + result["strategies"] = cfg.Strategies + } + result["results"] = cfg.Results + } + + return result, nil +} + +func (r *RAGToolset) UnmarshalYAML(unmarshal func(any) error) error { + // RAGToolset flattens RAGConfig fields directly at the top level, + // so users write tool/docs/strategies alongside toolset fields + // (instruction, tools, name, defer) without a rag_config wrapper. + // + // We unmarshal into a raw map first to avoid strict-mode errors + // from fields that belong to RAGConfig but not Toolset. + var raw map[string]any + if err := unmarshal(&raw); err != nil { + return err + } + + // Extract toolset-level fields + var tf Toolset + tf.Type = "rag" + if v, ok := raw["instruction"].(string); ok { + tf.Instruction = v + } + if v, ok := raw["name"].(string); ok { + tf.Name = v + } + if v, ok := raw["tools"]; ok { + if arr, ok := v.([]any); ok { + for _, item := range arr { + if s, ok := item.(string); ok { + tf.Tools = append(tf.Tools, s) + } + } + } + } + if v, ok := raw["defer"]; ok { + data, _ := yaml.Marshal(v) + _ = yaml.Unmarshal(data, &tf.Defer) + } + + // Unmarshal RAGConfig from the same map (it has its own UnmarshalYAML) + var ragCfg RAGConfig + if err := unmarshal(&ragCfg); err != nil { + return err + } + + tf.RAGConfig = &ragCfg + r.Toolset = tf + return nil +} + type Agents []AgentConfig func (c *Agents) UnmarshalYAML(unmarshal func(any) error) error { @@ -236,16 +326,16 @@ func (d Duration) MarshalJSON() ([]byte, error) { // AgentConfig represents a single agent configuration type AgentConfig struct { - Name string - Model string `json:"model,omitempty"` - Fallback *FallbackConfig `json:"fallback,omitempty"` - Description string `json:"description,omitempty"` - WelcomeMessage string `json:"welcome_message,omitempty"` - Toolsets []Toolset `json:"toolsets,omitempty"` - Instruction string `json:"instruction,omitempty"` - SubAgents []string `json:"sub_agents,omitempty"` - Handoffs []string `json:"handoffs,omitempty"` - RAG []string `json:"rag,omitempty"` + Name string + Model string `json:"model,omitempty"` + Fallback *FallbackConfig `json:"fallback,omitempty"` + Description string `json:"description,omitempty"` + WelcomeMessage string `json:"welcome_message,omitempty"` + Toolsets []Toolset `json:"toolsets,omitempty"` + Instruction string `json:"instruction,omitempty"` + SubAgents []string `json:"sub_agents,omitempty"` + Handoffs []string `json:"handoffs,omitempty"` + AddDate bool `json:"add_date,omitempty"` AddEnvironmentInfo bool `json:"add_environment_info,omitempty"` CodeModeTools bool `json:"code_mode_tools,omitempty"` @@ -607,6 +697,9 @@ type Toolset struct { // For the `fetch` tool Timeout int `json:"timeout,omitempty"` + // For the `rag` tool + RAGConfig *RAGConfig `json:"rag_config,omitempty" yaml:"rag_config,omitempty"` + // For the `model_picker` tool Models []string `json:"models,omitempty"` } diff --git a/pkg/config/latest/validate.go b/pkg/config/latest/validate.go index 87f4fecc5..6d2131418 100644 --- a/pkg/config/latest/validate.go +++ b/pkg/config/latest/validate.go @@ -90,8 +90,8 @@ func (t *Toolset) validate() error { if len(t.Args) > 0 && t.Type != "mcp" && t.Type != "lsp" { return errors.New("args can only be used with type 'mcp' or 'lsp'") } - if t.Ref != "" && t.Type != "mcp" { - return errors.New("ref can only be used with type 'mcp'") + if t.Ref != "" && t.Type != "mcp" && t.Type != "rag" { + return errors.New("ref can only be used with type 'mcp' or 'rag'") } if (t.Remote.URL != "" || t.Remote.TransportType != "") && t.Type != "mcp" { return errors.New("remote can only be used with type 'mcp'") @@ -111,6 +111,9 @@ func (t *Toolset) validate() error { if t.Name != "" && (t.Type != "mcp" && t.Type != "a2a") { return errors.New("name can only be used with type 'mcp' or 'a2a'") } + if t.RAGConfig != nil && t.Type != "rag" { + return errors.New("rag_config can only be used with type 'rag'") + } switch t.Type { case "shell": @@ -152,6 +155,11 @@ func (t *Toolset) validate() error { if len(t.Models) == 0 { return errors.New("model_picker toolset requires at least one model in the 'models' list") } + case "rag": + // rag toolset requires either a ref or inline rag_config + if t.Ref == "" && t.RAGConfig == nil { + return errors.New("rag toolset requires either ref or rag_config") + } case "background_agents": // no additional validation needed } diff --git a/pkg/config/overrides.go b/pkg/config/overrides.go index 3bb0afad6..842027ccd 100644 --- a/pkg/config/overrides.go +++ b/pkg/config/overrides.go @@ -110,8 +110,11 @@ func ensureModelsExist(cfg *latest.Config) error { } // Ensure models referenced by RAG strategies exist - for ragName, ragCfg := range cfg.RAG { - for _, stratCfg := range ragCfg.Strategies { + for ragName, ragToolset := range cfg.RAG { + if ragToolset.RAGConfig == nil { + continue + } + for _, stratCfg := range ragToolset.RAGConfig.Strategies { rawModel, ok := stratCfg.Params["model"] if !ok { continue diff --git a/pkg/config/rags.go b/pkg/config/rags.go new file mode 100644 index 000000000..4cc6a3c18 --- /dev/null +++ b/pkg/config/rags.go @@ -0,0 +1,54 @@ +package config + +import ( + "fmt" + + "github.com/docker/docker-agent/pkg/config/latest" +) + +// resolveRAGDefinitions resolves RAG definition references in agent toolsets. +// When an agent toolset of type "rag" has a ref that matches a key in the +// top-level rag section, the toolset is expanded with the definition's properties. +// Any properties set directly on the toolset override the definition properties. +func resolveRAGDefinitions(cfg *latest.Config) error { + for i := range cfg.Agents { + agent := &cfg.Agents[i] + for j := range agent.Toolsets { + ts := &agent.Toolsets[j] + if ts.Type != "rag" || ts.Ref == "" { + continue + } + + def, ok := cfg.RAG[ts.Ref] + if !ok { + return fmt.Errorf("agent '%s' references non-existent RAG definition '%s'", agent.Name, ts.Ref) + } + + applyRAGDefaults(ts, &def.Toolset) + } + } + + return nil +} + +// applyRAGDefaults fills empty fields in ts from def. Toolset values win. +func applyRAGDefaults(ts, def *latest.Toolset) { + // Clear the ref since it's been resolved + ts.Ref = "" + + if ts.RAGConfig == nil { + ts.RAGConfig = def.RAGConfig + } + if ts.Instruction == "" { + ts.Instruction = def.Instruction + } + if len(ts.Tools) == 0 { + ts.Tools = def.Tools + } + if ts.Defer.IsEmpty() { + ts.Defer = def.Defer + } + if ts.Name == "" { + ts.Name = def.Name + } +} diff --git a/pkg/rag/builder.go b/pkg/rag/builder.go index dc265670d..cfd9a0b07 100644 --- a/pkg/rag/builder.go +++ b/pkg/rag/builder.go @@ -25,6 +25,8 @@ type ManagersBuildConfig struct { } // NewManagers constructs all RAG managers defined in the config. +// +// Deprecated: Use NewManager for per-toolset creation instead. func NewManagers(ctx context.Context, cfg *latest.Config, buildCfg ManagersBuildConfig) ([]*Manager, error) { if len(cfg.RAG) == 0 { return nil, nil @@ -32,52 +34,72 @@ func NewManagers(ctx context.Context, cfg *latest.Config, buildCfg ManagersBuild var managers []*Manager - for ragName, ragCfg := range cfg.RAG { - // Validate that we have at least one strategy - if len(ragCfg.Strategies) == 0 { - return nil, fmt.Errorf("no strategies configured for RAG %q", ragName) + for ragName, ragToolset := range cfg.RAG { + if ragToolset.RAGConfig == nil { + continue } - - // Build context for strategy builders - strategyBuildCtx := strategy.BuildContext{ - RAGName: ragName, - ParentDir: buildCfg.ParentDir, - SharedDocs: GetAbsolutePaths(buildCfg.ParentDir, ragCfg.Docs), - Models: buildCfg.Models, - Env: buildCfg.Env, - ModelsGateway: buildCfg.ModelsGateway, - RespectVCS: ragCfg.GetRespectVCS(), - } - - strategyConfigs, strategyEvents, err := buildStrategyConfigs(ctx, ragCfg, strategyBuildCtx, ragName) + mgr, err := NewManager(ctx, ragName, ragToolset.RAGConfig, buildCfg) if err != nil { - return nil, fmt.Errorf("failed to build strategy configs for RAG %q: %w", ragName, err) + return nil, err } + managers = append(managers, mgr) + } - managerCfg, err := buildManagerConfig(ctx, ragCfg, buildCfg, strategyConfigs) - if err != nil { - return nil, fmt.Errorf("failed to build manager config for RAG %q: %w", ragName, err) - } + return managers, nil +} - // The strategyEvents channel is so the manager can convert strategy events to RAG events. - manager, err := New(ctx, ragName, managerCfg, strategyEvents) - if err != nil { - return nil, fmt.Errorf("failed to create RAG manager %q: %w", ragName, err) - } +// NewManager constructs a single RAG manager from a RAGConfig. +func NewManager( + ctx context.Context, + ragName string, + ragCfg *latest.RAGConfig, + buildCfg ManagersBuildConfig, +) (*Manager, error) { + if ragCfg == nil { + return nil, fmt.Errorf("nil RAG config for %q", ragName) + } - managers = append(managers, manager) + // Validate that we have at least one strategy + if len(ragCfg.Strategies) == 0 { + return nil, fmt.Errorf("no strategies configured for RAG %q", ragName) + } - strategyNames := make([]string, len(strategyConfigs)) - for i, sc := range strategyConfigs { - strategyNames[i] = sc.Name - } - slog.Debug("Created RAG manager", - "name", ragName, - "strategies", strategyNames, - "docs", len(managerCfg.Docs)) + // Build context for strategy builders + strategyBuildCtx := strategy.BuildContext{ + RAGName: ragName, + ParentDir: buildCfg.ParentDir, + SharedDocs: GetAbsolutePaths(buildCfg.ParentDir, ragCfg.Docs), + Models: buildCfg.Models, + Env: buildCfg.Env, + ModelsGateway: buildCfg.ModelsGateway, + RespectVCS: ragCfg.GetRespectVCS(), } - return managers, nil + strategyConfigs, strategyEvents, err := buildStrategyConfigs(ctx, *ragCfg, strategyBuildCtx, ragName) + if err != nil { + return nil, fmt.Errorf("failed to build strategy configs for RAG %q: %w", ragName, err) + } + + managerCfg, err := buildManagerConfig(ctx, *ragCfg, buildCfg, strategyConfigs) + if err != nil { + return nil, fmt.Errorf("failed to build manager config for RAG %q: %w", ragName, err) + } + + manager, err := New(ctx, ragName, managerCfg, strategyEvents) + if err != nil { + return nil, fmt.Errorf("failed to create RAG manager %q: %w", ragName, err) + } + + strategyNames := make([]string, len(strategyConfigs)) + for i, sc := range strategyConfigs { + strategyNames[i] = sc.Name + } + slog.Debug("Created RAG manager", + "name", ragName, + "strategies", strategyNames, + "docs", len(managerCfg.Docs)) + + return manager, nil } // buildManagerConfig constructs a rag.Manager Config from the configuration and strategies. diff --git a/pkg/runtime/rag.go b/pkg/runtime/rag.go index d1e87bd04..813d1219b 100644 --- a/pkg/runtime/rag.go +++ b/pkg/runtime/rag.go @@ -1,105 +1,15 @@ package runtime -import ( - "context" - "fmt" - "log/slog" +import "context" - "github.com/docker/docker-agent/pkg/rag" - "github.com/docker/docker-agent/pkg/rag/types" -) - -// StartBackgroundRAGInit initializes RAG in background and forwards events -// Should be called early (e.g., by App) to start indexing before RunStream -func (r *LocalRuntime) StartBackgroundRAGInit(ctx context.Context, sendEvent func(Event)) { - if r.ragInitialized.Swap(true) { - return - } - - ragManagers := r.team.RAGManagers() - if len(ragManagers) == 0 { - return - } - - // Set up event forwarding BEFORE starting initialization - r.forwardRAGEvents(ctx, ragManagers, sendEvent) - initializeRAG(ctx, ragManagers) - startRAGFileWatchers(ctx, ragManagers) -} - -// forwardRAGEvents forwards RAG manager events to the given callback -// Consolidates duplicated event forwarding logic -func (r *LocalRuntime) forwardRAGEvents(ctx context.Context, ragManagers []*rag.Manager, sendEvent func(Event)) { - for _, mgr := range ragManagers { - go func() { - ragName := mgr.Name() - slog.Debug("Starting RAG event forwarder goroutine", "rag", ragName) - for { - select { - case <-ctx.Done(): - slog.Debug("RAG event forwarder stopped", "rag", ragName) - return - case ragEvent, ok := <-mgr.Events(): - if !ok { - slog.Debug("RAG events channel closed", "rag", ragName) - return - } - - agentName := r.CurrentAgentName() - slog.Debug("Forwarding RAG event", "type", ragEvent.Type, "rag", ragName, "agent", agentName) - - switch ragEvent.Type { - case types.EventTypeIndexingStarted: - sendEvent(RAGIndexingStarted(ragName, ragEvent.StrategyName)) - case types.EventTypeIndexingProgress: - if ragEvent.Progress != nil { - sendEvent(RAGIndexingProgress(ragName, ragEvent.StrategyName, ragEvent.Progress.Current, ragEvent.Progress.Total, agentName)) - } - case types.EventTypeIndexingComplete: - sendEvent(RAGIndexingCompleted(ragName, ragEvent.StrategyName)) - case types.EventTypeUsage: - // Convert RAG usage to TokenUsageEvent so TUI displays it - sendEvent(NewTokenUsageEvent("", agentName, &Usage{ - InputTokens: ragEvent.TotalTokens, - ContextLength: ragEvent.TotalTokens, - Cost: ragEvent.Cost, - })) - case types.EventTypeError: - if ragEvent.Error != nil { - sendEvent(Error(fmt.Sprintf("RAG %s error: %v", ragName, ragEvent.Error))) - } - default: - // Log unhandled events for debugging - slog.Debug("Unhandled RAG event type", "type", ragEvent.Type, "rag", ragName) - } - } - } - }() - } -} - -// InitializeRAG initializes all RAG managers in the background -func initializeRAG(ctx context.Context, ragManagers []*rag.Manager) { - for _, mgr := range ragManagers { - go func() { - slog.Debug("Starting RAG manager initialization goroutine", "rag", mgr.Name()) - if err := mgr.Initialize(ctx); err != nil { - slog.Error("Failed to initialize RAG manager", "rag", mgr.Name(), "error", err) - } else { - slog.Info("RAG manager initialized successfully", "rag", mgr.Name()) - } - }() - } +// StartBackgroundRAGInit is a no-op. RAG initialization is now handled +// per-toolset via the tools.Startable interface. +func (r *LocalRuntime) StartBackgroundRAGInit(_ context.Context, _ func(Event)) { + // RAG toolsets are initialized lazily when first used. } -// StartRAGFileWatchers starts file watchers for all RAG managers -func startRAGFileWatchers(ctx context.Context, ragManagers []*rag.Manager) { - for _, mgr := range ragManagers { - go func() { - slog.Debug("Starting RAG file watcher goroutine", "rag", mgr.Name()) - if err := mgr.StartFileWatcher(ctx); err != nil { - slog.Error("Failed to start RAG file watcher", "rag", mgr.Name(), "error", err) - } - }() - } +// InitializeRAG is a no-op. RAG initialization is now handled +// per-toolset via the tools.Startable interface. +func (r *LocalRuntime) InitializeRAG(_ context.Context, _ chan Event) { + // RAG toolsets are initialized lazily when first used. } diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index a0e95b422..03feeb013 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -8,7 +8,6 @@ import ( "maps" "strings" "sync" - "sync/atomic" "time" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -184,7 +183,6 @@ type LocalRuntime struct { elicitationRequestCh chan ElicitationResult // Channel for receiving elicitation responses elicitationEventsChannel chan Event // Current events channel for sending elicitation requests elicitationEventsChannelMux sync.RWMutex // Protects elicitationEventsChannel - ragInitialized atomic.Bool sessionStore session.Store workingDir string // Working directory for hooks execution env []string // Environment variables for hooks execution diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index cd9d17a66..4bb92cb9e 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -7,7 +7,6 @@ import ( "reflect" "sync" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,10 +18,6 @@ import ( "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/permissions" - "github.com/docker/docker-agent/pkg/rag" - "github.com/docker/docker-agent/pkg/rag/database" - "github.com/docker/docker-agent/pkg/rag/strategy" - ragtypes "github.com/docker/docker-agent/pkg/rag/types" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/team" "github.com/docker/docker-agent/pkg/tools" @@ -501,101 +496,6 @@ func TestContextCancellation(t *testing.T) { require.IsType(t, &StreamStoppedEvent{}, events[len(events)-1]) } -// stubRAGStrategy is a minimal implementation of strategy.Strategy for testing RAG initialization. -type stubRAGStrategy struct{} - -func (s *stubRAGStrategy) Initialize(_ context.Context, _ []string, _ strategy.ChunkingConfig) error { - return nil -} - -func (s *stubRAGStrategy) Query(_ context.Context, _ string, _ int, _ float64) ([]database.SearchResult, error) { - return nil, nil -} - -func (s *stubRAGStrategy) CheckAndReindexChangedFiles(_ context.Context, _ []string, _ strategy.ChunkingConfig) error { - return nil -} - -func (s *stubRAGStrategy) StartFileWatcher(_ context.Context, _ []string, _ strategy.ChunkingConfig) error { - return nil -} - -func (s *stubRAGStrategy) Close() error { return nil } - -func TestStartBackgroundRAGInit_StopsForwardingAfterContextCancel(t *testing.T) { - t.Parallel() - - baseCtx := t.Context() - ctx, cancel := context.WithCancel(baseCtx) - defer cancel() - - // Build a RAG manager with a stub strategy and a controllable event channel. - strategyEvents := make(chan ragtypes.Event, 10) - mgr, err := rag.New( - ctx, - "test-rag", - rag.Config{ - StrategyConfigs: []strategy.Config{ - { - Name: "stub", - Strategy: &stubRAGStrategy{}, - Docs: nil, - }, - }, - }, - strategyEvents, - ) - require.NoError(t, err) - defer func() { - _ = mgr.Close() - }() - - rt := &LocalRuntime{ - team: team.New(team.WithRAGManagers([]*rag.Manager{mgr})), - currentAgent: "root", - } - - eventsCh := make(chan Event, 10) - - // Start background RAG init with event forwarding. - rt.StartBackgroundRAGInit(ctx, func(ev Event) { - eventsCh <- ev - }) - - // Emit an "indexing_completed" event and ensure it is forwarded. - strategyEvents <- ragtypes.Event{ - Type: ragtypes.EventTypeIndexingComplete, - StrategyName: "stub", - } - - select { - case <-eventsCh: - // ok: at least one event forwarded - case <-time.After(100 * time.Millisecond): - t.Fatalf("expected RAG event to be forwarded before cancellation") - } - - // Cancel the context and ensure no further events are forwarded. - cancel() - - // Brief yield to allow the forwarder goroutine to observe cancellation. - // This is a timing-based negative test: we verify no event is forwarded. - time.Sleep(10 * time.Millisecond) - - // Emit another event; it should NOT be forwarded. - strategyEvents <- ragtypes.Event{ - Type: ragtypes.EventTypeIndexingComplete, - StrategyName: "stub", - } - - select { - case ev := <-eventsCh: - t.Fatalf("expected no events after cancellation, got %T", ev) - case <-time.After(20 * time.Millisecond): - // success: no events forwarded - } -} - func TestToolCallVariations(t *testing.T) { tests := []struct { name string diff --git a/pkg/team/team.go b/pkg/team/team.go index 7e4e34a5e..df68ec8bc 100644 --- a/pkg/team/team.go +++ b/pkg/team/team.go @@ -4,18 +4,15 @@ import ( "context" "errors" "fmt" - "log/slog" "strings" "github.com/docker/docker-agent/pkg/agent" "github.com/docker/docker-agent/pkg/config/types" "github.com/docker/docker-agent/pkg/permissions" - "github.com/docker/docker-agent/pkg/rag" ) type Team struct { agents []*agent.Agent - ragManagers []*rag.Manager permissions *permissions.Checker } @@ -27,12 +24,6 @@ func WithAgents(agents ...*agent.Agent) Opt { } } -func WithRAGManagers(managers []*rag.Manager) Opt { - return func(t *Team) { - t.ragManagers = managers - } -} - func WithPermissions(checker *permissions.Checker) Opt { return func(t *Team) { t.permissions = checker @@ -127,20 +118,10 @@ func (t *Team) StopToolSets(ctx context.Context) error { return fmt.Errorf("failed to stop tool sets: %w", err) } } - for name, mgr := range t.ragManagers { - if err := mgr.Close(); err != nil { - slog.Error("Failed to close RAG manager", "name", name, "error", err) - } - } return nil } -// RAGManagers returns the RAG managers for this team -func (t *Team) RAGManagers() []*rag.Manager { - return t.ragManagers -} - // Permissions returns the permission checker for this team. // Returns nil if no permissions are configured. func (t *Team) Permissions() *permissions.Checker { diff --git a/pkg/teamloader/registry.go b/pkg/teamloader/registry.go index 59222c587..3cf729053 100644 --- a/pkg/teamloader/registry.go +++ b/pkg/teamloader/registry.go @@ -1,6 +1,7 @@ package teamloader import ( + "cmp" "context" "errors" "fmt" @@ -16,6 +17,7 @@ import ( "github.com/docker/docker-agent/pkg/memory/database/sqlite" "github.com/docker/docker-agent/pkg/path" "github.com/docker/docker-agent/pkg/paths" + "github.com/docker/docker-agent/pkg/rag" "github.com/docker/docker-agent/pkg/toolinstall" "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/tools/a2a" @@ -79,6 +81,7 @@ func NewDefaultToolsetRegistry() *ToolsetRegistry { r.Register("openapi", createOpenAPITool) r.Register("model_picker", createModelPickerTool) r.Register("background_agents", createBackgroundAgentsTool) + r.Register("rag", createRAGTool) return r } @@ -353,3 +356,23 @@ func createModelPickerTool(_ context.Context, toolset latest.Toolset, _ string, func createBackgroundAgentsTool(_ context.Context, _ latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { return agenttool.NewToolSet(), nil } + +func createRAGTool(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { + if toolset.RAGConfig == nil { + return nil, errors.New("rag toolset requires rag_config (should have been resolved from ref)") + } + + ragName := cmp.Or(toolset.Name, "rag") + + mgr, err := rag.NewManager(ctx, ragName, toolset.RAGConfig, rag.ManagersBuildConfig{ + ParentDir: parentDir, + ModelsGateway: runConfig.ModelsGateway, + Env: runConfig.EnvProvider(), + }) + if err != nil { + return nil, fmt.Errorf("failed to create RAG manager: %w", err) + } + + toolName := cmp.Or(mgr.ToolName(), ragName) + return builtin.NewRAGTool(mgr, toolName), nil +} diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 55071373f..459c04064 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -22,7 +22,6 @@ import ( "github.com/docker/docker-agent/pkg/model/provider/options" "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/permissions" - "github.com/docker/docker-agent/pkg/rag" "github.com/docker/docker-agent/pkg/skills" "github.com/docker/docker-agent/pkg/team" "github.com/docker/docker-agent/pkg/tools" @@ -122,20 +121,9 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c return nil, err } - // Create RAG managers + // Load agents parentDir := cmp.Or(agentSource.ParentDir(), runConfig.WorkingDir) configName := configNameFromSource(agentSource.Name()) - ragManagers, err := rag.NewManagers(ctx, cfg, rag.ManagersBuildConfig{ - ParentDir: parentDir, - ModelsGateway: runConfig.ModelsGateway, - Env: env, - Models: cfg.Models, - }) - if err != nil { - return nil, fmt.Errorf("failed to create RAG managers: %w", err) - } - - // Load agents var agents []*agent.Agent agentsByName := make(map[string]*agent.Agent) @@ -211,12 +199,6 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c opts = append(opts, agent.WithLoadTimeWarnings(warnings)) } - // Add RAG tools if agent has RAG sources - if len(agentConfig.RAG) > 0 { - ragTools := createRAGToolsForAgent(&agentConfig, ragManagers) - agentTools = append(agentTools, ragTools...) - } - // Add skills toolset if skills are enabled if agentConfig.Skills.Enabled() { loadedSkills := skills.Load(agentConfig.Skills.Sources) @@ -275,7 +257,6 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c return &LoadResult{ Team: team.New( team.WithAgents(agents...), - team.WithRAGManagers(ragManagers), team.WithPermissions(permChecker), ), Models: cfg.Models, @@ -602,41 +583,3 @@ func externalDepthFromContext(ctx context.Context) int { func contextWithExternalDepth(ctx context.Context, depth int) context.Context { return context.WithValue(ctx, externalDepthKey, depth) } - -// createRAGToolsForAgent creates RAG tools for an agent, one for each referenced RAG source -func createRAGToolsForAgent(agentConfig *latest.AgentConfig, ragManagers []*rag.Manager) []tools.ToolSet { - if len(agentConfig.RAG) == 0 { - return nil - } - - var ragTools []tools.ToolSet - - for _, ragName := range agentConfig.RAG { - idx := slices.IndexFunc(ragManagers, func(m *rag.Manager) bool { - return m.Name() == ragName - }) - if idx == -1 { - slog.Error("RAG source not found", "rag_source", ragName) - continue - } - - mgr := ragManagers[idx] - - // Use custom tool name if configured, otherwise use the RAG source name - toolName := cmp.Or(mgr.ToolName(), ragName) - - // Create a separate tool for this RAG source - ragTool := builtin.NewRAGTool(mgr, toolName) - - ragTools = append(ragTools, ragTool) - - slog.Debug("Created RAG tool for agent", - "rag_source", ragName, - "tool_name", toolName, - "manager_name", mgr.Name(), - "description", mgr.Description(), - "instruction", mgr.ToolInstruction()) - } - - return ragTools -} diff --git a/pkg/tools/builtin/rag.go b/pkg/tools/builtin/rag.go index d35878d2f..1e0b9d745 100644 --- a/pkg/tools/builtin/rag.go +++ b/pkg/tools/builtin/rag.go @@ -23,6 +23,7 @@ type RAGTool struct { var ( _ tools.ToolSet = (*RAGTool)(nil) _ tools.Instructable = (*RAGTool)(nil) + _ tools.Startable = (*RAGTool)(nil) ) // NewRAGTool creates a new RAG tool for a single RAG manager @@ -45,6 +46,32 @@ type QueryResult struct { ChunkIndex int `json:"chunk_index" jsonschema:"Index of the chunk within the source document"` } +// Start initializes the RAG manager (indexes documents). +func (t *RAGTool) Start(ctx context.Context) error { + if t.manager == nil { + return nil + } + slog.Debug("Starting RAG tool initialization", "tool", t.toolName) + if err := t.manager.Initialize(ctx); err != nil { + return fmt.Errorf("failed to initialize RAG manager %q: %w", t.toolName, err) + } + // Start file watcher in background + go func() { + if err := t.manager.StartFileWatcher(ctx); err != nil { + slog.Error("Failed to start RAG file watcher", "tool", t.toolName, "error", err) + } + }() + return nil +} + +// Stop closes the RAG manager and releases resources. +func (t *RAGTool) Stop(_ context.Context) error { + if t.manager == nil { + return nil + } + return t.manager.Close() +} + func (t *RAGTool) Instructions() string { if t.manager != nil { instruction := t.manager.ToolInstruction() From 9ff467941e3b0efaf38a1dd686b7318d899abfb6 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 23 Mar 2026 15:04:54 +0100 Subject: [PATCH 2/4] Restore RAG indexing event forwarding to TUI after toolset refactor The PR #2210 moved RAG from agent-level config to standard toolset type (tools.Startable) but removed the event forwarding that showed indexing progress in the TUI sidebar. This restores event forwarding by: - Adding an EventCallback to RAGTool that forwards rag.Manager events during Start() initialization - Having StartBackgroundRAGInit discover RAG tools from agent toolsets and wire up the event callback before initialization happens - Converting RAG manager events (indexing started/progress/completed, usage, errors) back to runtime events for the TUI Assisted-By: docker-agent --- pkg/runtime/rag.go | 66 ++++++++++++++++++++++++++++++++++------ pkg/tools/builtin/rag.go | 41 +++++++++++++++++++++++-- 2 files changed, 96 insertions(+), 11 deletions(-) diff --git a/pkg/runtime/rag.go b/pkg/runtime/rag.go index 813d1219b..80f5bf927 100644 --- a/pkg/runtime/rag.go +++ b/pkg/runtime/rag.go @@ -1,15 +1,63 @@ package runtime -import "context" +import ( + "context" + "fmt" + "log/slog" -// StartBackgroundRAGInit is a no-op. RAG initialization is now handled -// per-toolset via the tools.Startable interface. -func (r *LocalRuntime) StartBackgroundRAGInit(_ context.Context, _ func(Event)) { - // RAG toolsets are initialized lazily when first used. + ragtypes "github.com/docker/docker-agent/pkg/rag/types" + "github.com/docker/docker-agent/pkg/tools" + "github.com/docker/docker-agent/pkg/tools/builtin" +) + +// StartBackgroundRAGInit discovers RAG toolsets from agents and wires up event +// forwarding so the TUI can display indexing progress. Actual initialization +// happens lazily when the tool is first used (via tools.Startable). +func (r *LocalRuntime) StartBackgroundRAGInit(ctx context.Context, sendEvent func(Event)) { + for _, name := range r.team.AgentNames() { + a, err := r.team.Agent(name) + if err != nil { + continue + } + for _, ts := range a.ToolSets() { + ragTool, ok := tools.As[*builtin.RAGTool](ts) + if !ok { + continue + } + ragTool.SetEventCallback(ragEventForwarder(ctx, ragTool.Name(), r, sendEvent)) + } + } } -// InitializeRAG is a no-op. RAG initialization is now handled -// per-toolset via the tools.Startable interface. -func (r *LocalRuntime) InitializeRAG(_ context.Context, _ chan Event) { - // RAG toolsets are initialized lazily when first used. +// ragEventForwarder returns a callback that converts RAG manager events to runtime events. +func ragEventForwarder(ctx context.Context, ragName string, r *LocalRuntime, sendEvent func(Event)) builtin.RAGEventCallback { + return func(ragEvent ragtypes.Event) { + agentName := r.CurrentAgentName() + slog.Debug("Forwarding RAG event", "type", ragEvent.Type, "rag", ragName, "agent", agentName) + + switch ragEvent.Type { + case ragtypes.EventTypeIndexingStarted: + sendEvent(RAGIndexingStarted(ragName, ragEvent.StrategyName)) + case ragtypes.EventTypeIndexingProgress: + if ragEvent.Progress != nil { + sendEvent(RAGIndexingProgress(ragName, ragEvent.StrategyName, ragEvent.Progress.Current, ragEvent.Progress.Total, agentName)) + } + case ragtypes.EventTypeIndexingComplete: + sendEvent(RAGIndexingCompleted(ragName, ragEvent.StrategyName)) + case ragtypes.EventTypeUsage: + sendEvent(NewTokenUsageEvent("", agentName, &Usage{ + InputTokens: ragEvent.TotalTokens, + ContextLength: ragEvent.TotalTokens, + Cost: ragEvent.Cost, + })) + case ragtypes.EventTypeError: + if ragEvent.Error != nil { + sendEvent(Error(fmt.Sprintf("RAG %s error: %v", ragName, ragEvent.Error))) + } + default: + slog.Debug("Unhandled RAG event type", "type", ragEvent.Type, "rag", ragName) + } + + _ = ctx // available for future use + } } diff --git a/pkg/tools/builtin/rag.go b/pkg/tools/builtin/rag.go index 1e0b9d745..465f1a0a8 100644 --- a/pkg/tools/builtin/rag.go +++ b/pkg/tools/builtin/rag.go @@ -10,13 +10,18 @@ import ( "slices" "github.com/docker/docker-agent/pkg/rag" + ragtypes "github.com/docker/docker-agent/pkg/rag/types" "github.com/docker/docker-agent/pkg/tools" ) +// RAGEventCallback is called to forward RAG manager events during initialization. +type RAGEventCallback func(event ragtypes.Event) + // RAGTool provides document querying capabilities for a single RAG source type RAGTool struct { - manager *rag.Manager - toolName string + manager *rag.Manager + toolName string + eventCallback RAGEventCallback } // Verify interface compliance @@ -35,6 +40,11 @@ func NewRAGTool(manager *rag.Manager, toolName string) *RAGTool { } } +// Name returns the tool name for this RAG source. +func (t *RAGTool) Name() string { + return t.toolName +} + type QueryRAGArgs struct { Query string `json:"query" jsonschema:"Search query"` } @@ -46,12 +56,24 @@ type QueryResult struct { ChunkIndex int `json:"chunk_index" jsonschema:"Index of the chunk within the source document"` } +// SetEventCallback sets a callback to receive RAG manager events during initialization. +// This must be called before Start() to receive indexing progress events. +func (t *RAGTool) SetEventCallback(cb RAGEventCallback) { + t.eventCallback = cb +} + // Start initializes the RAG manager (indexes documents). func (t *RAGTool) Start(ctx context.Context) error { if t.manager == nil { return nil } slog.Debug("Starting RAG tool initialization", "tool", t.toolName) + + // Forward RAG manager events if a callback is set + if t.eventCallback != nil { + go t.forwardEvents(ctx) + } + if err := t.manager.Initialize(ctx); err != nil { return fmt.Errorf("failed to initialize RAG manager %q: %w", t.toolName, err) } @@ -64,6 +86,21 @@ func (t *RAGTool) Start(ctx context.Context) error { return nil } +// forwardEvents reads events from the RAG manager and forwards them via the callback. +func (t *RAGTool) forwardEvents(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case event, ok := <-t.manager.Events(): + if !ok { + return + } + t.eventCallback(event) + } + } +} + // Stop closes the RAG manager and releases resources. func (t *RAGTool) Stop(_ context.Context) error { if t.manager == nil { From 10dcb3c2e9548c0fbac505088d4919ed3faa6b3b Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 23 Mar 2026 18:40:50 +0100 Subject: [PATCH 3/4] Add a welcome message Signed-off-by: David Gageot --- examples/rag.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/rag.yaml b/examples/rag.yaml index 52955f10f..409a5b0ce 100644 --- a/examples/rag.yaml +++ b/examples/rag.yaml @@ -8,6 +8,7 @@ agents: can use when it makes sense to do so, based on the user's question. If you receive sources from the knowledge base, always include them as a markdown list of links to local files at the very end of your response. + welcome_message: Ask me anything about Blorks. toolsets: - type: rag ref: blork_knowledge_base From fdf60288c52459f66f1c8af51271a2e2598ac96e Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 23 Mar 2026 18:48:44 +0100 Subject: [PATCH 4/4] Simplify RAG event forwarding and clean up RAGTool - Remove RAGInitializer interface and StartBackgroundRAGInit indirection. RAG event callbacks are now wired in configureToolsetHandlers alongside other handler setup, using the same pattern as Elicitable/OAuthCapable. - Remove deprecated NewManagers wrapper (no callers after toolset refactor). - Clean up RAGTool: unexport internal types (QueryRAGArgs, QueryResult), inline sortResults, remove verbose debug logging from Tools(), simplify handleQueryRAG. Assisted-By: docker-agent --- pkg/app/app.go | 18 ----- pkg/rag/builder.go | 24 ------- pkg/runtime/loop.go | 10 +-- pkg/runtime/rag.go | 25 +------ pkg/tools/builtin/rag.go | 130 +++++++++++++--------------------- pkg/tools/builtin/rag_test.go | 19 ++--- 6 files changed, 61 insertions(+), 165 deletions(-) diff --git a/pkg/app/app.go b/pkg/app/app.go index 08b5132e5..b8c70acaf 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -30,12 +30,6 @@ import ( "github.com/docker/docker-agent/pkg/tui/messages" ) -// RAGInitializer is implemented by runtimes that support background RAG initialization. -// Local runtimes use this to start indexing early; remote runtimes typically do not. -type RAGInitializer interface { - StartBackgroundRAGInit(ctx context.Context, sendEvent func(runtime.Event)) -} - type App struct { runtime runtime.Runtime session *session.Session @@ -122,18 +116,6 @@ func New(ctx context.Context, rt runtime.Runtime, sess *session.Session, opts .. } }() - // If the runtime supports background RAG initialization, start it - // and forward events to the TUI. Remote runtimes typically handle RAG server-side - // and won't implement this optional interface. - if ragRuntime, ok := rt.(RAGInitializer); ok { - go ragRuntime.StartBackgroundRAGInit(ctx, func(event runtime.Event) { - select { - case app.events <- event: - case <-ctx.Done(): - } - }) - } - // Subscribe to tool list changes so the sidebar updates immediately // when an MCP server adds or removes tools (outside of a RunStream). if tcs, ok := rt.(runtime.ToolsChangeSubscriber); ok { diff --git a/pkg/rag/builder.go b/pkg/rag/builder.go index cfd9a0b07..a243513ec 100644 --- a/pkg/rag/builder.go +++ b/pkg/rag/builder.go @@ -24,30 +24,6 @@ type ManagersBuildConfig struct { Models map[string]latest.ModelConfig // Model configurations from config } -// NewManagers constructs all RAG managers defined in the config. -// -// Deprecated: Use NewManager for per-toolset creation instead. -func NewManagers(ctx context.Context, cfg *latest.Config, buildCfg ManagersBuildConfig) ([]*Manager, error) { - if len(cfg.RAG) == 0 { - return nil, nil - } - - var managers []*Manager - - for ragName, ragToolset := range cfg.RAG { - if ragToolset.RAGConfig == nil { - continue - } - mgr, err := NewManager(ctx, ragName, ragToolset.RAGConfig, buildCfg) - if err != nil { - return nil, err - } - managers = append(managers, mgr) - } - - return managers, nil -} - // NewManager constructs a single RAG manager from a RAGConfig. func NewManager( ctx context.Context, diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index 69caa7eb6..e6a643d8b 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -96,11 +96,6 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c // Emit team information events <- TeamInfo(r.agentDetailsFromTeam(), a.Name()) - // Initialize RAG and forward events - r.StartBackgroundRAGInit(ctx, func(event Event) { - events <- event - }) - r.emitAgentWarnings(a, chanSend(events)) r.configureToolsetHandlers(a, events) @@ -534,6 +529,11 @@ func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events chan Even func() { events <- Authorization(tools.ElicitationActionAccept, a.Name()) }, r.managedOAuth, ) + + // Wire RAG event forwarding so the TUI shows indexing progress. + if ragTool, ok := tools.As[*builtin.RAGTool](toolset); ok { + ragTool.SetEventCallback(ragEventForwarder(ragTool.Name(), r, chanSend(events))) + } } } diff --git a/pkg/runtime/rag.go b/pkg/runtime/rag.go index 80f5bf927..ba875717d 100644 --- a/pkg/runtime/rag.go +++ b/pkg/runtime/rag.go @@ -1,36 +1,15 @@ package runtime import ( - "context" "fmt" "log/slog" ragtypes "github.com/docker/docker-agent/pkg/rag/types" - "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/tools/builtin" ) -// StartBackgroundRAGInit discovers RAG toolsets from agents and wires up event -// forwarding so the TUI can display indexing progress. Actual initialization -// happens lazily when the tool is first used (via tools.Startable). -func (r *LocalRuntime) StartBackgroundRAGInit(ctx context.Context, sendEvent func(Event)) { - for _, name := range r.team.AgentNames() { - a, err := r.team.Agent(name) - if err != nil { - continue - } - for _, ts := range a.ToolSets() { - ragTool, ok := tools.As[*builtin.RAGTool](ts) - if !ok { - continue - } - ragTool.SetEventCallback(ragEventForwarder(ctx, ragTool.Name(), r, sendEvent)) - } - } -} - // ragEventForwarder returns a callback that converts RAG manager events to runtime events. -func ragEventForwarder(ctx context.Context, ragName string, r *LocalRuntime, sendEvent func(Event)) builtin.RAGEventCallback { +func ragEventForwarder(ragName string, r *LocalRuntime, sendEvent func(Event)) builtin.RAGEventCallback { return func(ragEvent ragtypes.Event) { agentName := r.CurrentAgentName() slog.Debug("Forwarding RAG event", "type", ragEvent.Type, "rag", ragName, "agent", agentName) @@ -57,7 +36,5 @@ func ragEventForwarder(ctx context.Context, ragName string, r *LocalRuntime, sen default: slog.Debug("Unhandled RAG event type", "type", ragEvent.Type, "rag", ragName) } - - _ = ctx // available for future use } } diff --git a/pkg/tools/builtin/rag.go b/pkg/tools/builtin/rag.go index 465f1a0a8..f37bf4539 100644 --- a/pkg/tools/builtin/rag.go +++ b/pkg/tools/builtin/rag.go @@ -17,22 +17,21 @@ import ( // RAGEventCallback is called to forward RAG manager events during initialization. type RAGEventCallback func(event ragtypes.Event) -// RAGTool provides document querying capabilities for a single RAG source +// RAGTool provides document querying capabilities for a single RAG source. type RAGTool struct { manager *rag.Manager toolName string eventCallback RAGEventCallback } -// Verify interface compliance +// Verify interface compliance. var ( _ tools.ToolSet = (*RAGTool)(nil) _ tools.Instructable = (*RAGTool)(nil) _ tools.Startable = (*RAGTool)(nil) ) -// NewRAGTool creates a new RAG tool for a single RAG manager -// toolName is the name to use for the tool (typically from config or manager name) +// NewRAGTool creates a new RAG tool for a single RAG manager. func NewRAGTool(manager *rag.Manager, toolName string) *RAGTool { return &RAGTool{ manager: manager, @@ -45,31 +44,20 @@ func (t *RAGTool) Name() string { return t.toolName } -type QueryRAGArgs struct { - Query string `json:"query" jsonschema:"Search query"` -} - -type QueryResult struct { - SourcePath string `json:"source_path" jsonschema:"Path to the source document"` - Content string `json:"content" jsonschema:"Relevant document chunk content"` - Similarity float64 `json:"similarity" jsonschema:"Similarity score (0-1)"` - ChunkIndex int `json:"chunk_index" jsonschema:"Index of the chunk within the source document"` -} - -// SetEventCallback sets a callback to receive RAG manager events during initialization. -// This must be called before Start() to receive indexing progress events. +// SetEventCallback sets a callback to receive RAG manager events during +// initialization. Must be called before Start(). func (t *RAGTool) SetEventCallback(cb RAGEventCallback) { t.eventCallback = cb } -// Start initializes the RAG manager (indexes documents). +// Start initializes the RAG manager (indexes documents) and starts a +// file watcher for incremental updates. func (t *RAGTool) Start(ctx context.Context) error { if t.manager == nil { return nil } - slog.Debug("Starting RAG tool initialization", "tool", t.toolName) - // Forward RAG manager events if a callback is set + // Forward RAG manager events if a callback is set. if t.eventCallback != nil { go t.forwardEvents(ctx) } @@ -77,7 +65,7 @@ func (t *RAGTool) Start(ctx context.Context) error { if err := t.manager.Initialize(ctx); err != nil { return fmt.Errorf("failed to initialize RAG manager %q: %w", t.toolName, err) } - // Start file watcher in background + go func() { if err := t.manager.StartFileWatcher(ctx); err != nil { slog.Error("Failed to start RAG file watcher", "tool", t.toolName, "error", err) @@ -86,6 +74,14 @@ func (t *RAGTool) Start(ctx context.Context) error { return nil } +// Stop closes the RAG manager and releases resources. +func (t *RAGTool) Stop(_ context.Context) error { + if t.manager == nil { + return nil + } + return t.manager.Close() +} + // forwardEvents reads events from the RAG manager and forwards them via the callback. func (t *RAGTool) forwardEvents(ctx context.Context) { for { @@ -101,27 +97,27 @@ func (t *RAGTool) forwardEvents(ctx context.Context) { } } -// Stop closes the RAG manager and releases resources. -func (t *RAGTool) Stop(_ context.Context) error { - if t.manager == nil { - return nil - } - return t.manager.Close() -} - func (t *RAGTool) Instructions() string { if t.manager != nil { - instruction := t.manager.ToolInstruction() - if instruction != "" { + if instruction := t.manager.ToolInstruction(); instruction != "" { return instruction } } - - // Default instruction if none provided return fmt.Sprintf("Search documents in %s to find relevant code or documentation. "+ "Provide a clear search query describing what you need.", t.toolName) } +type queryRAGArgs struct { + Query string `json:"query" jsonschema:"Search query"` +} + +type queryResult struct { + SourcePath string `json:"source_path" jsonschema:"Path to the source document"` + Content string `json:"content" jsonschema:"Relevant document chunk content"` + Similarity float64 `json:"similarity" jsonschema:"Similarity score (0-1)"` + ChunkIndex int `json:"chunk_index" jsonschema:"Index of the chunk within the source document"` +} + func (t *RAGTool) Tools(context.Context) ([]tools.Tool, error) { var description string if t.manager != nil { @@ -131,83 +127,53 @@ func (t *RAGTool) Tools(context.Context) ([]tools.Tool, error) { "Provide a natural language query describing what you need. "+ "Returns the most relevant document chunks with file paths.", t.toolName)) - paramsSchema := tools.MustSchemaFor[QueryRAGArgs]() - outputSchema := tools.MustSchemaFor[[]QueryResult]() - - // Log schemas for debugging - if paramsJSON, err := json.Marshal(paramsSchema); err == nil { - slog.Debug("RAG tool parameters schema", - "tool_name", t.toolName, - "schema", string(paramsJSON)) - } - if outputJSON, err := json.Marshal(outputSchema); err == nil { - slog.Debug("RAG tool output schema", - "tool_name", t.toolName, - "schema", string(outputJSON)) - } - - tool := tools.Tool{ + return []tools.Tool{{ Name: t.toolName, Category: "knowledge", Description: description, - Parameters: paramsSchema, - OutputSchema: outputSchema, + Parameters: tools.MustSchemaFor[queryRAGArgs](), + OutputSchema: tools.MustSchemaFor[[]queryResult](), Handler: tools.NewHandler(t.handleQueryRAG), Annotations: tools.ToolAnnotations{ ReadOnlyHint: true, Title: "Query " + t.toolName, }, - } - - slog.Debug("RAG tool registered", - "tool_name", tool.Name, - "category", tool.Category, - "description", description, - "title", tool.Annotations.Title, - "read_only", tool.Annotations.ReadOnlyHint) - - return []tools.Tool{tool}, nil + }}, nil } -func (t *RAGTool) handleQueryRAG(ctx context.Context, args QueryRAGArgs) (*tools.ToolCallResult, error) { +func (t *RAGTool) handleQueryRAG(ctx context.Context, args queryRAGArgs) (*tools.ToolCallResult, error) { if args.Query == "" { return nil, errors.New("query cannot be empty") } results, err := t.manager.Query(ctx, args.Query) if err != nil { - slog.Error("RAG query failed", "rag", t.manager.Name(), "error", err) return nil, fmt.Errorf("RAG query failed: %w", err) } - allResults := make([]QueryResult, 0, len(results)) - for _, result := range results { - allResults = append(allResults, QueryResult{ - SourcePath: result.Document.SourcePath, - Content: result.Document.Content, - Similarity: result.Similarity, - ChunkIndex: result.Document.ChunkIndex, + out := make([]queryResult, 0, len(results)) + for _, r := range results { + out = append(out, queryResult{ + SourcePath: r.Document.SourcePath, + Content: r.Document.Content, + Similarity: r.Similarity, + ChunkIndex: r.Document.ChunkIndex, }) } - sortResults(allResults) + slices.SortFunc(out, func(a, b queryResult) int { + return cmp.Compare(b.Similarity, a.Similarity) + }) - maxResults := 10 - if len(allResults) > maxResults { - allResults = allResults[:maxResults] + const maxResults = 10 + if len(out) > maxResults { + out = out[:maxResults] } - resultJSON, err := json.Marshal(allResults) + resultJSON, err := json.Marshal(out) if err != nil { return nil, fmt.Errorf("failed to marshal results: %w", err) } return tools.ResultSuccess(string(resultJSON)), nil } - -// sortResults sorts query results by similarity in descending order -func sortResults(results []QueryResult) { - slices.SortFunc(results, func(a, b QueryResult) int { - return cmp.Compare(b.Similarity, a.Similarity) // Descending order - }) -} diff --git a/pkg/tools/builtin/rag_test.go b/pkg/tools/builtin/rag_test.go index 9ba3de0b1..031037088 100644 --- a/pkg/tools/builtin/rag_test.go +++ b/pkg/tools/builtin/rag_test.go @@ -1,6 +1,8 @@ package builtin import ( + "cmp" + "slices" "testing" "github.com/stretchr/testify/assert" @@ -29,7 +31,7 @@ func TestRAGTool_ToolName(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tool := &RAGTool{ toolName: tt.toolName, - manager: nil, // We don't need a real manager for name tests + manager: nil, } tools, err := tool.Tools(t.Context()) @@ -50,28 +52,21 @@ func TestRAGTool_DefaultDescription(t *testing.T) { tools, err := tool.Tools(t.Context()) require.NoError(t, err) require.Len(t, tools, 1) - - // Should contain the tool name in the description assert.Contains(t, tools[0].Description, "test_docs") } func TestRAGTool_SortResults(t *testing.T) { - results := []QueryResult{ + results := []queryResult{ {SourcePath: "a.txt", Similarity: 0.5}, {SourcePath: "b.txt", Similarity: 0.9}, {SourcePath: "c.txt", Similarity: 0.3}, {SourcePath: "d.txt", Similarity: 0.7}, } - sortResults(results) - - // Should be sorted by similarity in descending order - assert.InDelta(t, 0.9, results[0].Similarity, 0.001) - assert.InDelta(t, 0.7, results[1].Similarity, 0.001) - assert.InDelta(t, 0.5, results[2].Similarity, 0.001) - assert.InDelta(t, 0.3, results[3].Similarity, 0.001) + slices.SortFunc(results, func(a, b queryResult) int { + return cmp.Compare(b.Similarity, a.Similarity) + }) - // Verify the source paths match assert.Equal(t, "b.txt", results[0].SourcePath) assert.Equal(t, "d.txt", results[1].SourcePath) assert.Equal(t, "a.txt", results[2].SourcePath)