From eb29b3b9370ba9a322d9b0edcdbdb2c53ad406c2 Mon Sep 17 00:00:00 2001 From: labkey-danield Date: Thu, 12 Feb 2026 05:57:18 -0800 Subject: [PATCH 1/3] Experiment using DJL as a way to measure the deviation of answers from AI chat. --- build.gradle | 12 ++ src/org/labkey/test/pages/TestChatPage.java | 138 +++++++++++++++++++ src/org/labkey/test/tests/testChatTest.java | 145 ++++++++++++++++++++ 3 files changed, 295 insertions(+) create mode 100644 src/org/labkey/test/pages/TestChatPage.java create mode 100644 src/org/labkey/test/tests/testChatTest.java diff --git a/build.gradle b/build.gradle index aa002ab263..2bad3671ea 100644 --- a/build.gradle +++ b/build.gradle @@ -20,6 +20,18 @@ project.configurations { project.dependencies { + // Source: https://mvnrepository.com/artifact/ai.djl/api + implementation("ai.djl:api:0.36.0") + + // Source: https://mvnrepository.com/artifact/ai.djl/model-zoo + implementation("ai.djl:model-zoo:0.36.0") + + // Source: https://mvnrepository.com/artifact/ai.djl.huggingface/tokenizers + implementation("ai.djl.huggingface:tokenizers:0.36.0") + + // Source: https://mvnrepository.com/artifact/ai.djl.pytorch/pytorch-engine + implementation("ai.djl.pytorch:pytorch-engine:0.36.0") + implementation "org.xerial:sqlite-jdbc:${sqliteJdbcVersion}" // declaring SQLite here to be used in TargetedMS test api("junit:junit:${junitVersion}") api("org.seleniumhq.selenium:selenium-api:${seleniumVersion}") diff --git a/src/org/labkey/test/pages/TestChatPage.java b/src/org/labkey/test/pages/TestChatPage.java new file mode 100644 index 0000000000..8653798c68 --- /dev/null +++ b/src/org/labkey/test/pages/TestChatPage.java @@ -0,0 +1,138 @@ +package org.labkey.test.pages; + +import org.labkey.test.BootstrapLocators; +import org.labkey.test.Locator; +import org.labkey.test.WebDriverWrapper; +import org.labkey.test.WebTestHelper; +import org.openqa.selenium.Keys; +import org.openqa.selenium.NoSuchElementException; +import org.openqa.selenium.StaleElementReferenceException; +import org.openqa.selenium.TimeoutException; +import org.openqa.selenium.WebDriver; +import org.openqa.selenium.WebElement; +import org.openqa.selenium.interactions.Actions; + +import java.util.List; +import java.util.stream.Collectors; + +import static org.labkey.test.util.selenium.WebElementUtils.getTextContent; + +public class TestChatPage extends LabKeyPage +{ + + private int _numOfResponses = 0; + + public TestChatPage(WebDriver driver) + { + super(driver); + } + + public static TestChatPage beginAt(WebDriverWrapper driver) + { + driver.beginAt(WebTestHelper.buildURL("test", "chat")); + return new TestChatPage(driver.getDriver()); + } + + @Override + protected void waitForPage() + { + waitFor(() -> { + try + { + return !BootstrapLocators.loadingSpinner.areAnyVisible(getDriver()) + && Locator.tagWithId("textarea", "chatPrompt") + .refindWhenNeeded(getDriver()).isDisplayed() + && Locator.tagWithClass("div", "genaiResponse") + .findElements(getDriver()).size() == 1; + } + catch (NoSuchElementException | StaleElementReferenceException | TimeoutException retry) + { + return false; + } + }, "There is a problem loading the chat page.", 30_000); + + } + + public void enterPrompt(String prompt) + { + + _numOfResponses = Locator.tagWithClass("div", "genaiResponse") + .findElements(getDriver()).size(); + log("enterPrompt: Num of responses: " + _numOfResponses); + + elementCache().chatPrompt.click(); + + Actions actions = new Actions(getDriver()); + actions.sendKeys(prompt) + .keyDown(Keys.SHIFT) + .keyDown(Keys.ENTER) + .keyUp(Keys.ENTER) + .keyUp(Keys.SHIFT) + .build() + .perform(); + + sleep(500); + + log("enterPrompt: Num of responses: " + _numOfResponses); + + } + + public String getMostRecentResponse() + { + log("getResponse: Current num of responses: " + + Locator.tagWithClass("div", "genaiResponse").findElements(getDriver()).size()); + + waitFor(() -> { + try + { + return !BootstrapLocators.loadingSpinner.areAnyVisible(getDriver()) + && elementCache().chatPrompt.isDisplayed() + && Locator.tagWithClass("div", "genaiResponse") + .findElements(getDriver()).size() > _numOfResponses; + } + catch (NoSuchElementException | StaleElementReferenceException | TimeoutException retry) + { + return false; + } + }, "I haven't seen a new response.", 120_000); + + _numOfResponses = Locator.tagWithClass("div", "genaiResponse") + .findElements(getDriver()).size(); + + log("getResponse: Num of responses: " + _numOfResponses); + + return Locator.tagWithClass("div", "genaiResponse") + .findElements(getDriver()).getLast().getText(); + } + + public List getAllResponses() + { + waitFor(() -> { + try + { + return !BootstrapLocators.loadingSpinner.areAnyVisible(getDriver()) + && elementCache().chatPrompt.isDisplayed(); + } + catch (NoSuchElementException | StaleElementReferenceException | TimeoutException retry) + { + return false; + } + }, "Timed out waiting for the current process to stop.", 120_000); + + List responses = Locator.tagWithClass("div", "genaiResponse").findElements(getDriver()); + return responses.stream().map(el -> getTextContent(el).trim()).collect(Collectors.toList()); + + } + + @Override + protected ElementCache newElementCache() + { + return new ElementCache(); + } + + protected class ElementCache extends LabKeyPage.ElementCache + { + WebElement chatPrompt = Locator.tagWithId("textarea", "chatPrompt") + .refindWhenNeeded(this); + } +} diff --git a/src/org/labkey/test/tests/testChatTest.java b/src/org/labkey/test/tests/testChatTest.java new file mode 100644 index 0000000000..e40c7d3b98 --- /dev/null +++ b/src/org/labkey/test/tests/testChatTest.java @@ -0,0 +1,145 @@ +package org.labkey.test.tests; + + +import ai.djl.MalformedModelException; +import ai.djl.inference.Predictor; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.TranslateException; +import org.junit.Test; +import org.labkey.test.BaseWebDriverTest; +import org.labkey.test.pages.TestChatPage; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +public class testChatTest extends BaseWebDriverTest +{ + static final String PROJ_NAME = "SomeSillyProject"; + + @Override + public BrowserType bestBrowser() + { + return BrowserType.CHROME; + } + + @Override + public List getAssociatedModules() + { + return Arrays.asList("experiment", "issues"); + } + @Override + protected String getProjectName() + { + return PROJ_NAME; + } + + @Test + public void testChat() + { + + String logFormatString = "\nTest %s\nResponse 1: %s\nResponse 2: %s\nDeviation: %e"; + + log("First, check your math."); + String s1 = "Hello world!"; + String s2 = "Did you see the parade today?"; + + double derivation = calculateDeviation(s1, s2); + log(String.format(logFormatString, + "Sanity Check", s1, s2, derivation)); + + TestChatPage testChatPage = TestChatPage.beginAt(this); + testChatPage.enterPrompt("Tell me about SampleManager."); + String response1 = testChatPage.getMostRecentResponse(); + String response2 = testChatPage.getAllResponses().getLast(); + derivation = calculateDeviation(response1, response2); + + log(String.format(logFormatString, + "Same Response String", response1, response2, derivation)); + + log("Ask the same question again."); + testChatPage.enterPrompt("Tell me about SampleManager."); + response2 = testChatPage.getMostRecentResponse(); + derivation = calculateDeviation(response1, response2); + + log(String.format(logFormatString, + "Ask The Question Again", response1, response2, derivation)); + + log("Now sign out and sign back in to try and change the response."); + signOut(); + signIn(); + + testChatPage = TestChatPage.beginAt(this); + testChatPage.enterPrompt("Tell me about SampleManager."); + response2 = testChatPage.getMostRecentResponse(); + derivation = calculateDeviation(response1, response2); + + log(String.format(logFormatString, + "Log Out and Back In", response1, response2, derivation)); + + } + + private double calculateDeviation(String str01, String str02) + { + double deviation; + + // Conceptual snippet using DJL for Semantic Similarity +// Criteria criteria = Criteria.builder() +// .setTypes(String.class, float[].class) +// .optModelUrls("djl://ai.djl.huggingface.pytorch/sentence-transformers/all-MiniLM-L6-v2") +// .build(); + + Criteria criteria = Criteria.builder() + .setTypes(String.class, float[].class) + // Force the PyTorch engine and specify the Hugging Face path + .optEngine("PyTorch") + .optModelUrls("djl://ai.djl.huggingface.pytorch/sentence-transformers/all-MiniLM-L6-v2") + // This translator is often required to bridge the gap between String and the Model's Tensor input + .optArgument("tokenizer", "sentence-transformers/all-MiniLM-L6-v2") + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + float[] vector1 = predictor.predict(str01); + float[] vector2 = predictor.predict(str02); + deviation = calculateCosineDistance(vector1, vector2); + + } + catch (TranslateException | IOException | ModelNotFoundException | MalformedModelException e) + { + throw new RuntimeException(e); + } + + return deviation; + + } + + // Linear algebra method that calculates the Cosine Similarity (how similar they are) and then converts it to + // Cosine Distance (the "deviation" or how far apart they are). + public static double calculateCosineDistance(float[] vectorA, float[] vectorB) { + + if (vectorA.length != vectorB.length) { + throw new IllegalArgumentException("Vectors must have the same dimension."); + } + + double dotProduct = 0.0; + double normA = 0.0; + double normB = 0.0; + + for (int i = 0; i < vectorA.length; i++) { + dotProduct += vectorA[i] * vectorB[i]; + normA += Math.pow(vectorA[i], 2); + normB += Math.pow(vectorB[i], 2); + } + + // Cosine Similarity Formula + double similarity = dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + + // Return Cosine Distance (Deviation) + // 0.0 means identical, 1.0 means orthogonal (completely different) + return 1.0 - similarity; + } + +} From 973c23befe6f21351f9b1094a6c29622bf1002eb Mon Sep 17 00:00:00 2001 From: labkey-danield Date: Thu, 12 Feb 2026 12:30:56 -0800 Subject: [PATCH 2/3] Cleaning up the code a bit. --- src/org/labkey/test/tests/testChatTest.java | 76 ++++++++++++++------- 1 file changed, 53 insertions(+), 23 deletions(-) diff --git a/src/org/labkey/test/tests/testChatTest.java b/src/org/labkey/test/tests/testChatTest.java index e40c7d3b98..c4cecfd0e0 100644 --- a/src/org/labkey/test/tests/testChatTest.java +++ b/src/org/labkey/test/tests/testChatTest.java @@ -18,6 +18,7 @@ public class testChatTest extends BaseWebDriverTest { static final String PROJ_NAME = "SomeSillyProject"; + static double _cosine_diff = 0.0; @Override public BrowserType bestBrowser() @@ -40,24 +41,38 @@ protected String getProjectName() public void testChat() { - String logFormatString = "\nTest %s\nResponse 1: %s\nResponse 2: %s\nDeviation: %e"; + String logFormatString = "\nTest %s\nResponse 1: %s\nResponse 2: %s\nCosine Similarity (1.0-identical, 0.0-unrelated, negative-opposite): %e\nCosine Distance / Deviation (1.0-orthogonal, 0.0-no deviation): %e"; log("First, check your math."); - String s1 = "Hello world!"; - String s2 = "Did you see the parade today?"; + String response1 = "ABC"; + String response2 = "123456789"; - double derivation = calculateDeviation(s1, s2); + double derivation = calculateDeviation(response1, response2); log(String.format(logFormatString, - "Sanity Check", s1, s2, derivation)); + "Sanity Check", response1, response2, _cosine_diff, derivation)); + + response1 = "Patient is Healthy"; + response2 = "Patient is Dead"; + + derivation = calculateDeviation(response1, response2); + log(String.format(logFormatString, + "Two Different Meanings", response1, response2, _cosine_diff, derivation)); + + response1 = "The quick brown fox jumped over the lazy dog."; + response2 = ".dog lazy the over jumped fox brown quick The"; + + derivation = calculateDeviation(response1, response2); + log(String.format(logFormatString, + "Same Words Different Order", response1, response2, _cosine_diff, derivation)); TestChatPage testChatPage = TestChatPage.beginAt(this); testChatPage.enterPrompt("Tell me about SampleManager."); - String response1 = testChatPage.getMostRecentResponse(); - String response2 = testChatPage.getAllResponses().getLast(); + response1 = testChatPage.getMostRecentResponse(); + response2 = testChatPage.getAllResponses().getLast(); derivation = calculateDeviation(response1, response2); log(String.format(logFormatString, - "Same Response String", response1, response2, derivation)); + "Same Response String", response1, response2, _cosine_diff, derivation)); log("Ask the same question again."); testChatPage.enterPrompt("Tell me about SampleManager."); @@ -65,7 +80,7 @@ public void testChat() derivation = calculateDeviation(response1, response2); log(String.format(logFormatString, - "Ask The Question Again", response1, response2, derivation)); + "Ask The Question Again", response1, response2, _cosine_diff, derivation)); log("Now sign out and sign back in to try and change the response."); signOut(); @@ -77,7 +92,7 @@ public void testChat() derivation = calculateDeviation(response1, response2); log(String.format(logFormatString, - "Log Out and Back In", response1, response2, derivation)); + "Log Out and Back In", response1, response2, _cosine_diff, derivation)); } @@ -85,21 +100,26 @@ private double calculateDeviation(String str01, String str02) { double deviation; - // Conceptual snippet using DJL for Semantic Similarity -// Criteria criteria = Criteria.builder() -// .setTypes(String.class, float[].class) -// .optModelUrls("djl://ai.djl.huggingface.pytorch/sentence-transformers/all-MiniLM-L6-v2") -// .build(); - Criteria criteria = Criteria.builder() .setTypes(String.class, float[].class) // Force the PyTorch engine and specify the Hugging Face path .optEngine("PyTorch") + //Load the model: sentence-transformers/all-MiniLM-L6-v2 + //This is the MiniLM embedding model: + // 384-dimensional output vectors + // Optimized for semantic similarity + //all-MiniLM-L6-v2: + // all -> The model was trained on a massive, diverse dataset. + // L6 -> Depth of the neural network. This is a 6-layer model (faster than L12). + // v2 -> Second version of the model. .optModelUrls("djl://ai.djl.huggingface.pytorch/sentence-transformers/all-MiniLM-L6-v2") - // This translator is often required to bridge the gap between String and the Model's Tensor input + // This translator is often required to bridge the gap between String and the Model's Tensor input. + // Sentence transformer models require tokenization before inference. .optArgument("tokenizer", "sentence-transformers/all-MiniLM-L6-v2") .build(); + // Loading the model is expensive. Could / should pool it. + // Predictor is not thread safe. try (ZooModel model = criteria.loadModel(); Predictor predictor = model.newPredictor()) { float[] vector1 = predictor.predict(str01); @@ -116,8 +136,8 @@ private double calculateDeviation(String str01, String str02) } - // Linear algebra method that calculates the Cosine Similarity (how similar they are) and then converts it to - // Cosine Distance (the "deviation" or how far apart they are). + // Linear algebra method that calculates the Cosine Similarity (how relevant they are to each other) and then + // converts it to Cosine Distance (the "deviation" or how far apart they are). public static double calculateCosineDistance(float[] vectorA, float[] vectorB) { if (vectorA.length != vectorB.length) { @@ -130,12 +150,22 @@ public static double calculateCosineDistance(float[] vectorA, float[] vectorB) { for (int i = 0; i < vectorA.length; i++) { dotProduct += vectorA[i] * vectorB[i]; - normA += Math.pow(vectorA[i], 2); - normB += Math.pow(vectorB[i], 2); + normA += vectorA[i] * vectorA[i]; + normB += vectorB[i] * vectorB[i]; } - // Cosine Similarity Formula - double similarity = dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + double magnitude = (Math.sqrt(normA) * Math.sqrt(normB)); + if (magnitude == 0.0) + { + return 1.0; // Maximum distance, completely dismilar. + } + // Cosine Similarity Formula. + // Measures the cosine of the angle between two vectors (the text responses converted into numbers). + // dot(A, B) / (||A|| * ||B||) + // Dot Product of A & B divided by the magnitude, Euclidean Norms (lengths) of the vectors multiplied together. + // Range from -1 to 1. 1.0 means identical, 0.0 unrelated, negative is opposite. + double similarity = dotProduct / magnitude; + _cosine_diff = similarity; // Return Cosine Distance (Deviation) // 0.0 means identical, 1.0 means orthogonal (completely different) From 58d2aa8fbce78a70323da3d3d85e0ccf1b93e85a Mon Sep 17 00:00:00 2001 From: labkey-danield Date: Thu, 12 Feb 2026 15:39:45 -0800 Subject: [PATCH 3/3] Checking in changes before changing branch name. --- src/org/labkey/test/tests/testChatTest.java | 138 ++++++++++++-------- 1 file changed, 84 insertions(+), 54 deletions(-) diff --git a/src/org/labkey/test/tests/testChatTest.java b/src/org/labkey/test/tests/testChatTest.java index c4cecfd0e0..02930e9cc4 100644 --- a/src/org/labkey/test/tests/testChatTest.java +++ b/src/org/labkey/test/tests/testChatTest.java @@ -7,6 +7,7 @@ import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ZooModel; import ai.djl.translate.TranslateException; +import org.junit.Before; import org.junit.Test; import org.labkey.test.BaseWebDriverTest; import org.labkey.test.pages.TestChatPage; @@ -18,8 +19,15 @@ public class testChatTest extends BaseWebDriverTest { static final String PROJ_NAME = "SomeSillyProject"; + static final String LOG_FORMAT_STRING = "\n\nTest %s\nString 1: %s\nString 2: %s\nCosine Similarity (1.0-identical, 0.0-unrelated, negative-opposite): %e\nCosine Distance / Deviation (1.0-orthogonal, 0.0-no deviation): %e\n"; static double _cosine_diff = 0.0; + Criteria _criteria = null; + // Loading the model is expensive. Could / should pool it. + ZooModel _model = null; + // Predictor is not thread safe. + Predictor _predictor = null; + @Override public BrowserType bestBrowser() { @@ -37,49 +45,92 @@ protected String getProjectName() return PROJ_NAME; } + @Before + public void buildStuffIfNeeded() + { + + if (null == _criteria) + { + _criteria = Criteria.builder() + .setTypes(String.class, float[].class) + // Force the PyTorch engine and specify the Hugging Face path + .optEngine("PyTorch") + //Load the model: sentence-transformers/all-MiniLM-L6-v2 + //This is the MiniLM embedding model: + // 384-dimensional output vectors + // Optimized for semantic similarity + //all-MiniLM-L6-v2: + // all -> The model was trained on a massive, diverse dataset. + // L6 -> Depth of the neural network. This is a 6-layer model (faster than L12). + // v2 -> Second version of the model. + .optModelUrls("djl://ai.djl.huggingface.pytorch/sentence-transformers/all-MiniLM-L6-v2") + // This translator is often required to bridge the gap between String and the Model's Tensor input. + // Sentence transformer models require tokenization before inference. + .optArgument("tokenizer", "sentence-transformers/all-MiniLM-L6-v2") + .build(); + + try + { + _model = _criteria.loadModel(); + _predictor = _model.newPredictor(); + } + catch (IOException | ModelNotFoundException | MalformedModelException e) + { + throw new RuntimeException(e); + } + + } + + } + @Test - public void testChat() + public void testBaseLine() { + log("Check the math."); + String string01 = "ABC"; - String logFormatString = "\nTest %s\nResponse 1: %s\nResponse 2: %s\nCosine Similarity (1.0-identical, 0.0-unrelated, negative-opposite): %e\nCosine Distance / Deviation (1.0-orthogonal, 0.0-no deviation): %e"; + double derivation = calculateDeviation(string01, string01); + log(String.format(LOG_FORMAT_STRING, + "Same Response String", string01, string01, _cosine_diff, derivation)); - log("First, check your math."); - String response1 = "ABC"; - String response2 = "123456789"; + String string02 = "123456789"; - double derivation = calculateDeviation(response1, response2); - log(String.format(logFormatString, - "Sanity Check", response1, response2, _cosine_diff, derivation)); + derivation = calculateDeviation(string01, string02); + log(String.format(LOG_FORMAT_STRING, + "Different Check", string01, string02, _cosine_diff, derivation)); - response1 = "Patient is Healthy"; - response2 = "Patient is Dead"; + string01 = "Patient is Healthy"; + string02 = "Patient is Dead"; - derivation = calculateDeviation(response1, response2); - log(String.format(logFormatString, - "Two Different Meanings", response1, response2, _cosine_diff, derivation)); + derivation = calculateDeviation(string01, string02); + log(String.format(LOG_FORMAT_STRING, + "Two Different Meanings", string01, string02, _cosine_diff, derivation)); - response1 = "The quick brown fox jumped over the lazy dog."; - response2 = ".dog lazy the over jumped fox brown quick The"; + string01 = "The quick brown fox jumped over the lazy dog."; + string02 = ".dog lazy the over jumped fox brown quick The"; - derivation = calculateDeviation(response1, response2); - log(String.format(logFormatString, - "Same Words Different Order", response1, response2, _cosine_diff, derivation)); + derivation = calculateDeviation(string01, string02); + log(String.format(LOG_FORMAT_STRING, + "Same Words Different Order", string01, string02, _cosine_diff, derivation)); + + } + + @Test + public void testChat() + { + + log("Test the chat app."); TestChatPage testChatPage = TestChatPage.beginAt(this); testChatPage.enterPrompt("Tell me about SampleManager."); - response1 = testChatPage.getMostRecentResponse(); - response2 = testChatPage.getAllResponses().getLast(); - derivation = calculateDeviation(response1, response2); - - log(String.format(logFormatString, - "Same Response String", response1, response2, _cosine_diff, derivation)); + String response1 = testChatPage.getMostRecentResponse(); log("Ask the same question again."); testChatPage.enterPrompt("Tell me about SampleManager."); - response2 = testChatPage.getMostRecentResponse(); - derivation = calculateDeviation(response1, response2); + String response2 = testChatPage.getMostRecentResponse(); + double derivation = calculateDeviation(response1, response2); - log(String.format(logFormatString, + log(String.format(LOG_FORMAT_STRING, "Ask The Question Again", response1, response2, _cosine_diff, derivation)); log("Now sign out and sign back in to try and change the response."); @@ -91,7 +142,7 @@ public void testChat() response2 = testChatPage.getMostRecentResponse(); derivation = calculateDeviation(response1, response2); - log(String.format(logFormatString, + log(String.format(LOG_FORMAT_STRING, "Log Out and Back In", response1, response2, _cosine_diff, derivation)); } @@ -100,34 +151,13 @@ private double calculateDeviation(String str01, String str02) { double deviation; - Criteria criteria = Criteria.builder() - .setTypes(String.class, float[].class) - // Force the PyTorch engine and specify the Hugging Face path - .optEngine("PyTorch") - //Load the model: sentence-transformers/all-MiniLM-L6-v2 - //This is the MiniLM embedding model: - // 384-dimensional output vectors - // Optimized for semantic similarity - //all-MiniLM-L6-v2: - // all -> The model was trained on a massive, diverse dataset. - // L6 -> Depth of the neural network. This is a 6-layer model (faster than L12). - // v2 -> Second version of the model. - .optModelUrls("djl://ai.djl.huggingface.pytorch/sentence-transformers/all-MiniLM-L6-v2") - // This translator is often required to bridge the gap between String and the Model's Tensor input. - // Sentence transformer models require tokenization before inference. - .optArgument("tokenizer", "sentence-transformers/all-MiniLM-L6-v2") - .build(); - - // Loading the model is expensive. Could / should pool it. - // Predictor is not thread safe. - try (ZooModel model = criteria.loadModel(); - Predictor predictor = model.newPredictor()) { - float[] vector1 = predictor.predict(str01); - float[] vector2 = predictor.predict(str02); + try { + float[] vector1 = _predictor.predict(str01); + float[] vector2 = _predictor.predict(str02); deviation = calculateCosineDistance(vector1, vector2); } - catch (TranslateException | IOException | ModelNotFoundException | MalformedModelException e) + catch (TranslateException e) { throw new RuntimeException(e); } @@ -136,7 +166,7 @@ private double calculateDeviation(String str01, String str02) } - // Linear algebra method that calculates the Cosine Similarity (how relevant they are to each other) and then + // Method using linear algebra that calculates the Cosine Similarity (how relevant they are to each other) and then // converts it to Cosine Distance (the "deviation" or how far apart they are). public static double calculateCosineDistance(float[] vectorA, float[] vectorB) {