-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkernel_browser_rollout_processor.py
More file actions
289 lines (245 loc) · 10.4 KB
/
kernel_browser_rollout_processor.py
File metadata and controls
289 lines (245 loc) · 10.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
"""
Kernel Browser Rollout Processor for Eval Protocol.

Runs multi-step browser episodes using Kernel browser pools.
This enables evaluating VLM agents on computer use tasks within the EP framework.

Usage:
    from kernel_browser_rollout_processor import KernelBrowserRolloutProcessor

    @evaluation_test(
        rollout_processor=KernelBrowserRolloutProcessor(pool_name="eval-browser-pool"),
        ...
    )
    def test_my_eval(row: EvaluationRow) -> EvaluationRow:
        ...
"""
import asyncio
import base64
import io
import logging
import os
from PIL import Image
from eval_protocol.models import EvaluationRow, Message, Status
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
# From kernel-tinker-rl (vendored core package)
from kernel import Kernel
from core.agent import AgentConfig, DEFAULT_MODEL, QwenAgent
from core.agent_loop import run_agent_loop
from core.browser import KernelBrowserAdapter
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
def encode_screenshots(images: list[Image.Image]) -> list[str]:
    """Serialize each screenshot as a base64-encoded PNG string.

    Args:
        images: PIL images to encode, in order.

    Returns:
        One base64 text string per input image (same order), suitable for
        storing inside JSON.
    """

    def _to_b64_png(image: Image.Image) -> str:
        # Render into an in-memory PNG, then base64 the raw bytes.
        with io.BytesIO() as png_bytes:
            image.save(png_bytes, format="PNG")
            raw = png_bytes.getvalue()
        return base64.b64encode(raw).decode("utf-8")

    return [_to_b64_png(image) for image in images]
def decode_screenshots(encoded: list[str]) -> list[Image.Image]:
    """Rebuild PIL images from base64 strings (the inverse of ``encode_screenshots``).

    Args:
        encoded: Base64 strings, one per image.

    Returns:
        The decoded images, in the same order as ``encoded``.
    """
    # The BytesIO buffers are intentionally left open: Image.open is lazy and
    # reads pixel data from its buffer on first access.
    return [Image.open(io.BytesIO(base64.b64decode(payload))) for payload in encoded]
class KernelBrowserRolloutProcessor(RolloutProcessor):
    """
    Rollout processor that runs multi-step browser episodes using Kernel.

    Each rollout:
    1. Acquires browser from Kernel pool
    2. Navigates to initial URL
    3. Runs VLM agent loop (screenshot → predict → execute → repeat)
    4. Captures trajectory (screenshots, action_history)
    5. Releases browser
    6. Returns EvaluationRow with trajectory data in execution_metadata.extra

    The evaluation function can then use WebJudge to score the trajectory.

    The Kernel client and browser adapter are used synchronously (their calls
    are not awaited). The slow calls — pool acquire, adapter construction,
    navigation, the agent loop, URL lookup — are therefore dispatched to
    worker threads with ``asyncio.to_thread``. This keeps one rollout that is
    waiting on the pool from stalling every other rollout on the same event
    loop.

    Args:
        pool_name: Name of the Kernel browser pool to use
        max_steps: Maximum number of agent steps per episode
        image_max_size: Max dimension for screenshot resizing
        acquire_timeout_seconds: Timeout for acquiring browser from pool
        system_prompt: Custom system prompt (defaults to agent_auth prompt)
        extra_actions: Additional action types to support
        base_url: API base URL for model inference
    """

    def __init__(
        self,
        pool_name: str = "eval-browser-pool",
        max_steps: int = 10,
        image_max_size: int = 512,
        acquire_timeout_seconds: int = 120,
        system_prompt: str | None = None,
        extra_actions: list | None = None,
        base_url: str = "https://api.fireworks.ai/inference/v1",
    ):
        self.pool_name = pool_name
        self.max_steps = max_steps
        self.image_max_size = image_max_size
        self.acquire_timeout_seconds = acquire_timeout_seconds
        self.system_prompt = system_prompt
        self.extra_actions = extra_actions or []
        self.base_url = base_url
        # Created by setup(), or lazily on the first rollout.
        self._kernel: Kernel | None = None

    @property
    def api_key(self) -> str | None:
        """Get API key at runtime (not init time) to support remote environments."""
        return os.environ.get("FIREWORKS_API_KEY")

    def setup(self) -> None:
        """Initialize Kernel client."""
        self._kernel = Kernel()
        logger.info(f"Kernel client initialized, using pool: {self.pool_name}")

    def __call__(
        self,
        rows: list[EvaluationRow],
        config: RolloutProcessorConfig,
    ) -> list[asyncio.Task[EvaluationRow]]:
        """Process evaluation rows and return async tasks.

        One task is created per row. Each task holds ``config.semaphore`` for
        the duration of its episode, which bounds how many episodes run at once.
        """
        semaphore = config.semaphore

        async def _sem_wrapper(row: EvaluationRow) -> EvaluationRow:
            async with semaphore:
                return await self._process_row(row, config)

        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]

    async def _process_row(
        self,
        row: EvaluationRow,
        config: RolloutProcessorConfig,
    ) -> EvaluationRow:
        """
        Run a single browser episode and return trajectory.

        The trajectory data is stored in row.execution_metadata.extra for
        the evaluation function to use with WebJudge. An ``Exception`` raised
        during the episode is recorded on the row (``rollout_status`` and
        ``extra["error"]``) rather than propagated.

        Any browser leased from the pool is released in ``finally``. It is
        offered for reuse only when the episode body ran to completion and
        the adapter did not set ``_should_not_reuse``.
        """
        if self._kernel is None:
            self._kernel = Kernel()
        kernel = self._kernel

        # Extract task info from row
        dataset_info = row.input_metadata.dataset_info or {}
        initial_url = dataset_info.get("initial_url", "")
        task = dataset_info.get("task", "")
        task_id = row.input_metadata.row_id or "unknown"

        # Get model from completion_params
        completion_params = config.completion_params or {}
        model = completion_params.get("model", DEFAULT_MODEL)
        # Strip the fireworks_ai/ prefix that the eval protocol adds for LiteLLM
        # routing — the direct OpenAI client expects raw Fireworks model names
        # like "accounts/fireworks/models/..." not "fireworks_ai/accounts/...".
        if isinstance(model, str):
            model = model.removeprefix("fireworks_ai/")
        temperature = completion_params.get("temperature", 0.0)
        max_tokens = completion_params.get("max_tokens", 512)

        # Episode state
        adapter: KernelBrowserAdapter | None = None
        # Session id taken from the acquire response. It is kept separately
        # so the lease can still be released if adapter construction fails.
        acquired_session_id: str | None = None
        error: str | None = None
        # Set only as the last statement of the try body. An exception or a
        # task cancellation leaves it False, so a browser that may be
        # mid-action (the agent-loop thread keeps running after cancellation)
        # is never offered for reuse.
        completed = False

        try:
            # Acquire browser from pool off the event loop. This may wait up
            # to acquire_timeout_seconds.
            browser = await self._acquire_browser(kernel)
            # NOTE(review): assumes the acquire response exposes `session_id`
            # (the id release() expects); getattr keeps this defensive.
            acquired_session_id = getattr(browser, "session_id", None)
            # reset_on_init=True presumably does browser I/O, so construct
            # the adapter off the event loop too.
            adapter = await asyncio.to_thread(
                KernelBrowserAdapter, kernel, browser, reset_on_init=True
            )

            # Navigate to initial URL
            initial_screenshot = await asyncio.to_thread(adapter.navigate, initial_url)

            # Create agent
            agent_config = AgentConfig(
                model=model,
                base_url=self.base_url,
                api_key=self.api_key,
                system_prompt=self.system_prompt,
                extra_actions=self.extra_actions,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            agent = QwenAgent(config=agent_config)

            # Run the multi-turn agent loop
            # This runs in a thread pool since it has blocking VLM calls
            loop_result = await asyncio.to_thread(
                run_agent_loop,
                agent=agent,
                adapter=adapter,
                task=task,
                initial_screenshot=initial_screenshot,
                max_steps=self.max_steps,
                image_max_size=self.image_max_size,
            )

            # Get final URL
            final_url = await asyncio.to_thread(adapter.get_current_url)

            # Get preserved message history from agent (includes tool_calls)
            row.messages = [
                Message(
                    role=m["role"],
                    content=m.get("content"),
                    tool_calls=m.get("tool_calls"),
                )
                for m in agent.get_messages()
            ]

            # PNG encoding is CPU-bound; keep it off the event loop as well.
            screenshots_b64 = await asyncio.to_thread(
                encode_screenshots, loop_result.screenshots
            )

            # Store trajectory data in extra (JSON-serializable)
            if row.execution_metadata.extra is None:
                row.execution_metadata.extra = {}
            row.execution_metadata.extra.update(
                {
                    # For WebJudge evaluation (base64 encoded for JSON serialization)
                    "screenshots_b64": screenshots_b64,
                    "action_history": loop_result.action_history,
                    # Task info (needed for Trajectory construction)
                    "task": task,
                    "task_id": task_id,
                    "initial_url": initial_url,
                    "final_url": final_url,
                    # Episode metadata
                    "termination_reason": loop_result.termination_reason,
                    "terminal_action": loop_result.terminal_action,
                    "steps_completed": loop_result.steps_completed,
                    # Step-by-step details (for debugging)
                    "step_results": [
                        {
                            "step": sr.step,
                            "action_desc": sr.action_desc,
                            "predict_time": sr.predict_time,
                            "exec_time": sr.exec_time,
                            "total_time": sr.total_time,
                            "is_terminal": sr.is_terminal,
                            "error": sr.error,
                        }
                        for sr in loop_result.step_results
                    ],
                }
            )

            # Set rollout status
            if loop_result.error:
                logger.error(f"Rollout error for {task_id}: {loop_result.error}")
                row.rollout_status = Status.error(loop_result.error)
            else:
                row.rollout_status = Status.rollout_finished()
            completed = True
        except Exception as e:
            error = str(e)
            # logger.exception keeps the traceback for post-mortem debugging.
            logger.exception(f"Rollout failed for {task_id}: {error}")
            if row.execution_metadata.extra is None:
                row.execution_metadata.extra = {}
            row.execution_metadata.extra.update(
                {
                    "task": task,
                    "task_id": task_id,
                    "initial_url": initial_url,
                    "error": error,
                    "screenshots_b64": [],
                    "action_history": [],
                    "model_responses": [],
                    "termination_reason": "error",
                    "terminal_action": None,
                    "steps_completed": 0,
                }
            )
            row.rollout_status = Status.error(error)
        finally:
            # Release browser back to pool. This stays synchronous so it runs
            # to completion even while the task is being cancelled. Prefer
            # the adapter's session id; fall back to the acquire response when
            # adapter construction failed.
            session_id = (
                adapter.session_id if adapter is not None else acquired_session_id
            )
            if session_id is not None:
                reuse = (
                    completed
                    and adapter is not None
                    and not adapter._should_not_reuse
                )
                self._release_browser(kernel, session_id, reuse=reuse)

        return row

    async def _acquire_browser(self, kernel: Kernel):
        """Acquire a browser from the pool in a worker thread.

        The blocking ``acquire`` call cannot be interrupted once started. If
        the awaiting task is cancelled, the in-flight call is shielded, and
        any lease it eventually obtains is released with ``reuse=False``
        instead of being orphaned in the pool.
        """
        pending = asyncio.ensure_future(
            asyncio.to_thread(
                kernel.browser_pools.acquire,
                self.pool_name,
                acquire_timeout_seconds=self.acquire_timeout_seconds,
            )
        )
        try:
            return await asyncio.shield(pending)
        except asyncio.CancelledError:
            pending.add_done_callback(
                lambda done: self._release_orphaned_lease(kernel, done)
            )
            raise

    def _release_orphaned_lease(self, kernel: Kernel, done: asyncio.Future) -> None:
        """Done-callback for a cancelled acquire: release the lease it obtained, if any."""
        # Checking exception() also marks a failed acquire's error as retrieved.
        if done.cancelled() or done.exception() is not None:
            return
        session_id = getattr(done.result(), "session_id", None)
        if session_id is not None:
            self._release_browser(kernel, session_id, reuse=False)

    def _release_browser(self, kernel: Kernel, session_id, *, reuse: bool) -> None:
        """Return a leased browser to the pool; failures are logged, not raised."""
        try:
            kernel.browser_pools.release(
                self.pool_name,
                session_id=session_id,
                reuse=reuse,
            )
        except Exception as e:
            logger.warning(f"Failed to release browser: {e}")