diff --git a/build.gradle b/build.gradle
index aa002ab263..2bad3671ea 100644
--- a/build.gradle
+++ b/build.gradle
@@ -20,6 +20,18 @@ project.configurations {
 
 
 project.dependencies {
+    // Source: https://mvnrepository.com/artifact/ai.djl/api
+    implementation("ai.djl:api:0.36.0")
+
+    // Source: https://mvnrepository.com/artifact/ai.djl/model-zoo
+    implementation("ai.djl:model-zoo:0.36.0")
+
+    // Source: https://mvnrepository.com/artifact/ai.djl.huggingface/tokenizers
+    implementation("ai.djl.huggingface:tokenizers:0.36.0")
+
+    // Source: https://mvnrepository.com/artifact/ai.djl.pytorch/pytorch-engine
+    implementation("ai.djl.pytorch:pytorch-engine:0.36.0")
+
     implementation "org.xerial:sqlite-jdbc:${sqliteJdbcVersion}" // declaring SQLite here to be used in TargetedMS test
     api("junit:junit:${junitVersion}")
     api("org.seleniumhq.selenium:selenium-api:${seleniumVersion}")
diff --git a/src/org/labkey/test/pages/TestChatPage.java b/src/org/labkey/test/pages/TestChatPage.java
new file mode 100644
index 0000000000..8653798c68
--- /dev/null
+++ b/src/org/labkey/test/pages/TestChatPage.java
@@ -0,0 +1,158 @@
+package org.labkey.test.pages;
+
+import org.labkey.test.BootstrapLocators;
+import org.labkey.test.Locator;
+import org.labkey.test.WebDriverWrapper;
+import org.labkey.test.WebTestHelper;
+import org.openqa.selenium.Keys;
+import org.openqa.selenium.NoSuchElementException;
+import org.openqa.selenium.StaleElementReferenceException;
+import org.openqa.selenium.TimeoutException;
+import org.openqa.selenium.WebDriver;
+import org.openqa.selenium.WebElement;
+import org.openqa.selenium.interactions.Actions;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.labkey.test.util.selenium.WebElementUtils.getTextContent;
+
+/**
+ * Page object for the chat page reached through the "test" controller's "chat" action.
+ */
+public class TestChatPage extends LabKeyPage<TestChatPage.ElementCache>
+{
+    // Number of response blocks found on the page the last time they were counted.
+    private int _numOfResponses = 0;
+
+    public TestChatPage(WebDriver driver)
+    {
+        super(driver);
+    }
+
+    /**
+     * Navigates to the chat page and wraps it in a page object.
+     */
+    public static TestChatPage beginAt(WebDriverWrapper driver)
+    {
+        driver.beginAt(WebTestHelper.buildURL("test", "chat"));
+        return new TestChatPage(driver.getDriver());
+    }
+
+    /**
+     * Waits until no spinner is visible, the prompt is displayed and exactly one response block exists.
+     */
+    @Override
+    protected void waitForPage()
+    {
+        waitFor(() -> {
+            try
+            {
+                return !BootstrapLocators.loadingSpinner.areAnyVisible(getDriver())
+                        && Locator.tagWithId("textarea", "chatPrompt")
+                                .refindWhenNeeded(getDriver()).isDisplayed()
+                        && findResponses().size() == 1;
+            }
+            catch (NoSuchElementException | StaleElementReferenceException | TimeoutException retry)
+            {
+                return false;
+            }
+        }, "There is a problem loading the chat page.", 30_000);
+    }
+
+    /**
+     * Records the current response count, then types the prompt and sends Shift+Enter.
+     */
+    public void enterPrompt(String prompt)
+    {
+        _numOfResponses = findResponses().size();
+        log("enterPrompt: Num of responses: " + _numOfResponses);
+
+        elementCache().chatPrompt.click();
+
+        new Actions(getDriver())
+                .sendKeys(prompt)
+                .keyDown(Keys.SHIFT)
+                .keyDown(Keys.ENTER)
+                .keyUp(Keys.ENTER)
+                .keyUp(Keys.SHIFT)
+                .build()
+                .perform();
+
+        // NOTE(review): presumably gives the page time to start processing the prompt - confirm.
+        sleep(500);
+    }
+
+    /**
+     * Waits until the page is idle and shows more response blocks than {@link #enterPrompt(String)} counted,
+     * then returns the text of the last block.
+     */
+    public String getMostRecentResponse()
+    {
+        log("getResponse: Current num of responses: " + findResponses().size());
+
+        waitFor(() -> {
+            try
+            {
+                return isReadyForInput() && findResponses().size() > _numOfResponses;
+            }
+            catch (NoSuchElementException | StaleElementReferenceException | TimeoutException retry)
+            {
+                return false;
+            }
+        }, "I haven't seen a new response.", 120_000);
+
+        // Look the responses up once so the count and the returned text come from the same list.
+        List<WebElement> responses = findResponses();
+        _numOfResponses = responses.size();
+        log("getResponse: Num of responses: " + _numOfResponses);
+
+        return responses.getLast().getText();
+    }
+
+    /**
+     * Waits for the page to be idle and returns the trimmed text of every response block.
+     */
+    public List<String> getAllResponses()
+    {
+        waitFor(() -> {
+            try
+            {
+                return isReadyForInput();
+            }
+            catch (NoSuchElementException | StaleElementReferenceException | TimeoutException retry)
+            {
+                return false;
+            }
+        }, "Timed out waiting for the current process to stop.", 120_000);
+
+        return findResponses().stream()
+                .map(el -> getTextContent(el).trim())
+                .collect(Collectors.toList());
+    }
+
+    // True when no loading spinner is visible and the prompt text area is displayed.
+    private boolean isReadyForInput()
+    {
+        return !BootstrapLocators.loadingSpinner.areAnyVisible(getDriver())
+                && elementCache().chatPrompt.isDisplayed();
+    }
+
+    // Finds every response block currently on the page.
+    private List<WebElement> findResponses()
+    {
+        return Locator.tagWithClass("div", "genaiResponse").findElements(getDriver());
+    }
+
+    @Override
+    protected ElementCache newElementCache()
+    {
+        return new ElementCache();
+    }
+
+    protected class ElementCache extends LabKeyPage<?>.ElementCache
+    {
+        WebElement chatPrompt = Locator.tagWithId("textarea", "chatPrompt")
+                .refindWhenNeeded(this);
+    }
+}
diff --git a/src/org/labkey/test/tests/testChatTest.java b/src/org/labkey/test/tests/testChatTest.java
new file mode 100644
index 0000000000..02930e9cc4
--- /dev/null
+++ b/src/org/labkey/test/tests/testChatTest.java
@@ -0,0 +1,237 @@
+package org.labkey.test.tests;
+
+import ai.djl.MalformedModelException;
+import ai.djl.inference.Predictor;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ModelNotFoundException;
+import ai.djl.repository.zoo.ZooModel;
+import ai.djl.translate.TranslateException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.labkey.test.BaseWebDriverTest;
+import org.labkey.test.pages.TestChatPage;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Compares chat responses by embedding them with a sentence-transformer model and measuring cosine distance.
+ */
+public class testChatTest extends BaseWebDriverTest
+{
+    static final String PROJ_NAME = "SomeSillyProject";
+    static final String LOG_FORMAT_STRING = "\n\nTest %s\nString 1: %s\nString 2: %s\nCosine Similarity (1.0-identical, 0.0-unrelated, negative-opposite): %e\nCosine Distance / Deviation (1.0-orthogonal, 0.0-no deviation): %e\n";
+    // Cosine similarity from the most recent calculateCosineDistance call; read only for logging.
+    static double _cosine_diff = 0.0;
+
+    Criteria<String, float[]> _criteria = null;
+    // Loading the model is expensive. Could / should pool it.
+    ZooModel<String, float[]> _model = null;
+    // Predictor is not thread safe.
+    Predictor<String, float[]> _predictor = null;
+
+    @Override
+    public BrowserType bestBrowser()
+    {
+        return BrowserType.CHROME;
+    }
+
+    @Override
+    public List<String> getAssociatedModules()
+    {
+        return Arrays.asList("experiment", "issues");
+    }
+
+    @Override
+    protected String getProjectName()
+    {
+        return PROJ_NAME;
+    }
+
+    /**
+     * Loads the embedding model and creates a predictor if this instance has not done so yet.
+     */
+    @Before
+    public void buildStuffIfNeeded()
+    {
+        if (null != _criteria)
+        {
+            return;
+        }
+
+        _criteria = Criteria.builder()
+                .setTypes(String.class, float[].class)
+                // Force the PyTorch engine and specify the Hugging Face path
+                .optEngine("PyTorch")
+                // Load the model: sentence-transformers/all-MiniLM-L6-v2
+                // This is the MiniLM embedding model:
+                //   384-dimensional output vectors
+                //   Optimized for semantic similarity
+                // all-MiniLM-L6-v2:
+                //   all -> The model was trained on a massive, diverse dataset.
+                //   L6 -> Depth of the neural network. This is a 6-layer model (faster than L12).
+                //   v2 -> Second version of the model.
+                .optModelUrls("djl://ai.djl.huggingface.pytorch/sentence-transformers/all-MiniLM-L6-v2")
+                // This translator is often required to bridge the gap between String and the Model's Tensor input.
+                // Sentence transformer models require tokenization before inference.
+                .optArgument("tokenizer", "sentence-transformers/all-MiniLM-L6-v2")
+                .build();
+
+        try
+        {
+            _model = _criteria.loadModel();
+            _predictor = _model.newPredictor();
+        }
+        catch (IOException | ModelNotFoundException | MalformedModelException e)
+        {
+            throw new RuntimeException("Unable to load the sentence embedding model.", e);
+        }
+    }
+
+    /**
+     * Closes the predictor and the model created by {@link #buildStuffIfNeeded()}, if any.
+     */
+    @After
+    public void closeModel()
+    {
+        if (null != _predictor)
+        {
+            _predictor.close();
+            _predictor = null;
+        }
+        if (null != _model)
+        {
+            _model.close();
+            _model = null;
+        }
+        _criteria = null;
+    }
+
+    @Test
+    public void testBaseLine()
+    {
+        log("Check the math.");
+        String string01 = "ABC";
+
+        double deviation = calculateDeviation(string01, string01);
+        log(String.format(LOG_FORMAT_STRING,
+                "Same Response String", string01, string01, _cosine_diff, deviation));
+        assertEquals("Identical strings should have no deviation.", 0.0, deviation, 1e-6);
+
+        String string02 = "123456789";
+
+        deviation = calculateDeviation(string01, string02);
+        log(String.format(LOG_FORMAT_STRING,
+                "Different Check", string01, string02, _cosine_diff, deviation));
+
+        string01 = "Patient is Healthy";
+        string02 = "Patient is Dead";
+
+        deviation = calculateDeviation(string01, string02);
+        log(String.format(LOG_FORMAT_STRING,
+                "Two Different Meanings", string01, string02, _cosine_diff, deviation));
+
+        string01 = "The quick brown fox jumped over the lazy dog.";
+        string02 = ".dog lazy the over jumped fox brown quick The";
+
+        deviation = calculateDeviation(string01, string02);
+        log(String.format(LOG_FORMAT_STRING,
+                "Same Words Different Order", string01, string02, _cosine_diff, deviation));
+    }
+
+    @Test
+    public void testChat()
+    {
+        log("Test the chat app.");
+
+        TestChatPage testChatPage = TestChatPage.beginAt(this);
+        testChatPage.enterPrompt("Tell me about SampleManager.");
+        String response1 = testChatPage.getMostRecentResponse();
+
+        log("Ask the same question again.");
+        testChatPage.enterPrompt("Tell me about SampleManager.");
+        String response2 = testChatPage.getMostRecentResponse();
+        double deviation = calculateDeviation(response1, response2);
+
+        log(String.format(LOG_FORMAT_STRING,
+                "Ask The Question Again", response1, response2, _cosine_diff, deviation));
+
+        log("Now sign out and sign back in to try and change the response.");
+        signOut();
+        signIn();
+
+        testChatPage = TestChatPage.beginAt(this);
+        testChatPage.enterPrompt("Tell me about SampleManager.");
+        response2 = testChatPage.getMostRecentResponse();
+        deviation = calculateDeviation(response1, response2);
+
+        log(String.format(LOG_FORMAT_STRING,
+                "Log Out and Back In", response1, response2, _cosine_diff, deviation));
+    }
+
+    // Embeds both strings with the predictor and returns the cosine distance between the two vectors.
+    private double calculateDeviation(String str01, String str02)
+    {
+        try
+        {
+            float[] vector1 = _predictor.predict(str01);
+            float[] vector2 = _predictor.predict(str02);
+            return calculateCosineDistance(vector1, vector2);
+        }
+        catch (TranslateException e)
+        {
+            throw new RuntimeException("Unable to embed the strings being compared.", e);
+        }
+    }
+
+    /**
+     * Calculates the cosine distance (1 - cosine similarity) between two vectors. The similarity itself is
+     * stored in {@link #_cosine_diff} so callers can log it.
+     *
+     * @param vectorA first vector
+     * @param vectorB second vector; must have the same length as vectorA
+     * @return 1.0 minus the cosine similarity, or 1.0 when either vector has zero magnitude
+     * @throws IllegalArgumentException if the vectors have different lengths
+     */
+    public static double calculateCosineDistance(float[] vectorA, float[] vectorB)
+    {
+        if (vectorA.length != vectorB.length)
+        {
+            throw new IllegalArgumentException("Vectors must have the same dimension.");
+        }
+
+        double dotProduct = 0.0;
+        double normA = 0.0;
+        double normB = 0.0;
+
+        for (int i = 0; i < vectorA.length; i++)
+        {
+            // Widen to double before multiplying so the products are not rounded to float precision.
+            double a = vectorA[i];
+            double b = vectorB[i];
+            dotProduct += a * b;
+            normA += a * a;
+            normB += b * b;
+        }
+
+        double magnitude = Math.sqrt(normA) * Math.sqrt(normB);
+        if (magnitude == 0.0)
+        {
+            // Reset the logged similarity so a value from an earlier call is not reported for this pair.
+            _cosine_diff = 0.0;
+            return 1.0; // A zero-length vector is treated as completely dissimilar.
+        }
+
+        // Cosine similarity: dot(A, B) / (||A|| * ||B||), the cosine of the angle between the two vectors.
+        // Ranges from -1 to 1: 1.0 means identical, 0.0 unrelated, negative is opposite.
+        double similarity = dotProduct / magnitude;
+        _cosine_diff = similarity;
+
+        // Cosine distance (deviation): 0.0 means identical, 1.0 means orthogonal (completely different).
+        return 1.0 - similarity;
+    }
+}