Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions api/src/org/labkey/api/mcp/AbstractAgentAction.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import com.google.genai.errors.ClientException;
import com.google.genai.errors.ServerException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpSession;
import org.apache.commons.lang3.StringUtils;
import org.json.JSONObject;
import org.labkey.api.action.ReadOnlyApiAction;
import org.labkey.api.security.MethodsAllowed;
import org.labkey.api.util.HtmlString;
import org.labkey.api.util.HttpUtil.Method;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.validation.BindException;

Expand All @@ -15,11 +18,11 @@
import static org.apache.commons.lang3.StringUtils.isNotBlank;

/**
* "agent" it is too strong a word, but if you want to create a tools specific chat endpoint then
* start here.
* First implement getServicePrompt() to tell your "agent its mission. You can also listen in on the
* conversation to help you user get the right results.
* If you want to create a tools-specific chat endpoint, then start here.
* First implement getServicePrompt() to tell your "agent" its mission. You can also listen in on the
* conversation to help the user get the right results.
*/
@MethodsAllowed({Method.POST})
public abstract class AbstractAgentAction<F extends PromptForm> extends ReadOnlyApiAction<F>
{
protected abstract String getAgentName();
Expand All @@ -28,9 +31,11 @@ public abstract class AbstractAgentAction<F extends PromptForm> extends ReadOnly

protected ChatClient getChat(boolean create)
{
HttpSession session = getViewContext().getRequest().getSession(true);
ChatClient chatSession = McpService.get().getChat(session, getAgentName(), this::getServicePrompt, create);
return chatSession;
HttpServletRequest request = getViewContext().getRequest();
if (request == null)
throw new IllegalStateException("No request");
HttpSession session = request.getSession(true);
return McpService.get().getChat(session, getAgentName(), this::getServicePrompt, create);
}

protected String handleEscape(String prompt)
Expand All @@ -50,9 +55,9 @@ protected String handleEscape(String prompt)
}

@Override
public Object execute(PromptForm form, BindException errors) throws Exception
public Object execute(F form, BindException errors) throws Exception
{
try (var mcpPush = McpContext.withContext(getViewContext()))
try (var _ = McpContext.withContext(getViewContext()))
{
String prompt = form.getPrompt();

Expand Down Expand Up @@ -101,11 +106,10 @@ else if (isNotBlank(response.text()))
}
catch (ClientException ex)
{
var ret = new JSONObject(Map.of(
return new JSONObject(Map.of(
"text", ex.getMessage(),
"user", getViewContext().getUser().getName(),
"success", Boolean.FALSE));
return ret;
}
}
}
5 changes: 3 additions & 2 deletions api/src/org/labkey/api/mcp/McpContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.labkey.api.view.UnauthorizedException;
import org.labkey.api.writer.ContainerUser;
import org.springframework.ai.chat.model.ToolContext;

import java.util.Map;

/**
Expand Down Expand Up @@ -57,7 +58,7 @@ public User getUser()
// researched if there are other ways to pass context around to Tools registerd by McpService
//

private static final ThreadLocal<McpContext> contexts = new ThreadLocal();
private static final ThreadLocal<McpContext> contexts = new ThreadLocal<>();

public static @NotNull McpContext get()
{
Expand All @@ -67,7 +68,7 @@ public User getUser()
return ret;
}

public static AutoCloseable withContext(ContainerUser ctx)
public static AutoCloseable withContext(ContainerUser ctx)
{
return with(new McpContext(ctx));
}
Expand Down
3 changes: 2 additions & 1 deletion devtools/src/org/labkey/devtools/TestController.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.labkey.api.data.ContainerManager;
import org.labkey.api.mcp.AbstractAgentAction;
import org.labkey.api.mcp.McpService;
import org.labkey.api.mcp.PromptForm;
import org.labkey.api.security.CSRF;
import org.labkey.api.security.MethodsAllowed;
import org.labkey.api.security.RequiresLogin;
Expand Down Expand Up @@ -1308,7 +1309,7 @@ public void addNavTrail(NavTree root)


@RequiresLogin
public static class ChatEndpointAction extends AbstractAgentAction
public static class ChatEndpointAction extends AbstractAgentAction<PromptForm>
{
@Override
protected String getAgentName()
Expand Down
18 changes: 8 additions & 10 deletions devtools/src/org/labkey/devtools/view/chat.jsp
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
<%@ page import="org.labkey.api.util.DOM" %>
<%@ page import="java.util.stream.Stream" %>
<%@ page import="static org.labkey.api.util.DOM.*" %>
<%@ page import="static org.labkey.api.util.DOM.Attribute.*" %>
<%@ page extends="org.labkey.api.jsp.JspBase" %>
Expand Down Expand Up @@ -75,9 +73,8 @@ function startChatting(chatEndpoint)
scrollToBottom();
}

function handleChatResponse(event)
function handleChatResponse(req)
{
const req = event.target;
if (req.readyState === 4) {
if (req.status >= 200 && req.status < 300)
{
Expand All @@ -95,12 +92,13 @@ function startChatting(chatEndpoint)

function sendMessage(prompt)
{
var url = new URL(chatEndpoint);
url.searchParams.set('prompt', prompt);
var req = new XMLHttpRequest();
req.open('GET', url.toString(), true);
req.onreadystatechange = handleChatResponse;
req.send();
LABKEY.Ajax.request({
url: chatEndpoint,
method: 'POST',
params: {prompt: prompt},
success: handleChatResponse,
failure: handleChatResponse
});
const loadingSpinner = document.querySelector('.loading-spinner');
loadingSpinner.classList.remove('loading-spinner--hidden');
}
Expand Down
7 changes: 2 additions & 5 deletions query/src/org/labkey/query/controllers/QueryController.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.genai.Chat;
import com.google.genai.errors.ClientException;
import com.google.genai.errors.ServerException;
import jakarta.servlet.ServletException;
Expand All @@ -28,7 +27,6 @@
import jakarta.servlet.http.HttpSession;
import org.antlr.runtime.tree.Tree;
import org.apache.commons.beanutils.ConversionException;
import org.apache.commons.beanutils.ConvertUtils;
import org.apache.commons.collections4.MultiValuedMap;
import org.apache.commons.collections4.multimap.ArrayListValuedHashMap;
import org.apache.commons.collections4.multimap.HashSetValuedHashMap;
Expand Down Expand Up @@ -8879,7 +8877,7 @@ public Object execute(SqlPromptForm form, BindException errors) throws Exception
// save form here for context in getServicePrompt()
_form = form;

try (var mcpPush = McpContext.withContext(getViewContext()))
try (var _ = McpContext.withContext(getViewContext()))
{
String prompt = form.getPrompt();

Expand Down Expand Up @@ -8961,11 +8959,10 @@ public Object execute(SqlPromptForm form, BindException errors) throws Exception
}
catch (ClientException ex)
{
var ret = new JSONObject(Map.of(
return new JSONObject(Map.of(
"text", ex.getMessage(),
"user", getViewContext().getUser().getName(),
"success", Boolean.FALSE));
return ret;
}
}
}
Expand Down
37 changes: 21 additions & 16 deletions query/src/org/labkey/query/view/sourceQuery.jsp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@
* limitations under the License.
*/
%>
<%@ page import="org.labkey.api.mcp.McpService"%>
<%@ page import="org.labkey.api.query.QueryAction"%>
<%@ page import="org.labkey.api.query.QueryDefinition"%>
<%@ page import="org.labkey.api.query.QueryDefinition" %>
<%@ page import="org.labkey.api.util.HelpTopic" %>
<%@ page import="org.labkey.api.util.JavaScriptFragment" %>
<%@ page import="org.labkey.api.view.ActionURL" %>
<%@ page import="org.labkey.api.view.HttpView" %>
<%@ page import="org.labkey.api.view.template.ClientDependencies" %>
<%@ page import="org.labkey.query.controllers.QueryController" %>
<%@ page import="org.labkey.api.mcp.McpService" %>
<%@ page import="org.labkey.api.util.JavaScriptFragment" %>
<%@ page import="java.lang.Exception" %>
<%@ page import="java.lang.Override" %>
<%@ page import="java.lang.String" %>
<%@ taglib prefix="labkey" uri="http://www.labkey.org/taglib" %>
<%@ page extends="org.labkey.api.jsp.JspBase" %>
<%!
Expand Down Expand Up @@ -363,11 +366,14 @@
if (initPrompt)
{
var url = new URL('./query-queryagent.api', window.location.href);
url.searchParams.set('schemaName', schemaName || '');
url.searchParams.set('prompt', initPrompt);
var req = new XMLHttpRequest();
req.open('GET', url.toString(), true);
req.send();
LABKEY.Ajax.request({
url: url,
method: 'POST',
params: {
prompt: initPrompt,
schemaName: schemaName || ''
}
});
}
}

Expand All @@ -392,12 +398,12 @@
// TODO waiting/thinking UI
// Build URL with same base as current document, endpoint /query-queryagent.api and prompt parameter
var url = new URL('./query-queryagent.api', window.location.href);
url.searchParams.set('prompt', prompt);
var req = new XMLHttpRequest();
req.open('GET', url.toString(), true);
req.onreadystatechange = function () {
if (req.readyState === 4) {
if (req.status >= 200 && req.status < 300) {
LABKEY.Ajax.request({
url: url,
method: 'POST',
params: {prompt: prompt},
callback: function (config, success, req) {
if (success) {
var responseJson = JSON.parse(req.responseText);
var responseText = responseJson['text'];
var responseHtml = responseJson['html'];
Expand All @@ -416,8 +422,7 @@
appendTextResponse('Request failed: ' + req.status + ' ' + (req.statusText || ''));
}
}
};
req.send();
});
ev.preventDefault();
ev.stopPropagation();
return false;
Expand Down