diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/Langchain4jModelAdapter.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/Langchain4jModelAdapter.java new file mode 100644 index 000000000000..3d7f1d801491 --- /dev/null +++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/Langchain4jModelAdapter.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.languagemodels.textvectorisation.model; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.model.embedding.EmbeddingModel; +import java.lang.invoke.MethodHandles; +import java.lang.reflect.Method; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Map; +import org.apache.solr.common.SolrException; +import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.languagemodels.textvectorisation.store.TextToVectorModelException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Adapter that wraps a LangChain4j {@link EmbeddingModel} to implement the Solr-native {@link + * TextToVectorModel} interface. + * + *

This provides backward compatibility for existing LangChain4j model configurations (OpenAI, + * Cohere, HuggingFace, MistralAI). + */ +public class Langchain4jModelAdapter implements TextToVectorModel { + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + private static final String TIMEOUT_PARAM = "timeout"; + private static final String MAX_SEGMENTS_PER_BATCH_PARAM = "maxSegmentsPerBatch"; + private static final String MAX_RETRIES_PARAM = "maxRetries"; + + private final EmbeddingModel delegate; + + public Langchain4jModelAdapter(EmbeddingModel delegate) { + this.delegate = delegate; + } + + /** + * Create a Langchain4jModelAdapter by instantiating the specified EmbeddingModel class using + * reflection and the builder pattern. + */ + public static Langchain4jModelAdapter create( + SolrResourceLoader solrResourceLoader, String className, Map params) + throws TextToVectorModelException { + try { + Class modelClass = solrResourceLoader.findClass(className, EmbeddingModel.class); + var builder = modelClass.getMethod("builder").invoke(null); + + if (params != null) { + /* + * This block of code has the responsibility of instantiate a {@link + * dev.langchain4j.model.embedding.EmbeddingModel} using the params provided.classes have + * params of The specific implementation of {@link + * dev.langchain4j.model.embedding.EmbeddingModel} is not known beforehand. So we benefit of + * the design choice in langchain4j that each subclass implementing {@link + * dev.langchain4j.model.embedding.EmbeddingModel} uses setters with the same name of the + * param. + */ + for (String paramName : params.keySet()) { + /* + * When a param is not primitive, we need to instantiate the object explicitly and then call the + * setter method. + * N.B. when adding support to new models, pay attention to all the parameters they + * support, some of them may require to be handled in here as separate switch cases + */ + switch (paramName) { + case TIMEOUT_PARAM -> { + Duration timeOut = Duration.ofSeconds((Long) params.get(paramName)); + builder.getClass().getMethod(paramName, Duration.class).invoke(builder, timeOut); + } + case MAX_SEGMENTS_PER_BATCH_PARAM, MAX_RETRIES_PARAM -> builder + .getClass() + .getMethod(paramName, Integer.class) + .invoke(builder, ((Long) params.get(paramName)).intValue()); + + /* + * For primitive params if there's only one setter available, we call it. + * If there's choice we default to the string one + */ + default -> { + ArrayList paramNameMatches = new ArrayList<>(); + for (var method : builder.getClass().getMethods()) { + if (paramName.equals(method.getName()) && method.getParameterCount() == 1) { + paramNameMatches.add(method); + } + } + if (paramNameMatches.size() == 1) { + Method method = paramNameMatches.getFirst(); + Class paramType = method.getParameterTypes()[0]; + Object convertedValue = + ModelConfigUtils.convertValue(params.get(paramName), paramType); + method.invoke(builder, convertedValue); + } else { + try { + builder + .getClass() + .getMethod(paramName, String.class) + .invoke(builder, params.get(paramName).toString()); + } catch (NoSuchMethodException e) { + log.error("Parameter {} not supported by model {}", paramName, className); + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e); + } + } + } + } + } + } + + EmbeddingModel embeddingModel = + (EmbeddingModel) builder.getClass().getMethod("build").invoke(builder); + return new Langchain4jModelAdapter(embeddingModel); + } catch (final Exception e) { + throw new TextToVectorModelException("LangChain4j model loading failed for " + className, e); + } + } + + @Override + public float[] vectorise(String text) { + Embedding vector = delegate.embed(text).content(); + return vector.vector(); + } +} diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/ModelConfigUtils.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/ModelConfigUtils.java new file mode 100644 index 000000000000..84c19e94b01b --- /dev/null +++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/ModelConfigUtils.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.languagemodels.textvectorisation.model; + +public class ModelConfigUtils { + + /** Convert JSON-parsed values to the expected parameter type */ + public static Object convertValue(Object value, Class targetType) { + if (value == null) return null; + + if (targetType.isAssignableFrom(value.getClass())) { + return value; + } + + // Handle common type conversions from JSON parsing + if (targetType == int.class || targetType == Integer.class) { + if (value instanceof Long) return ((Long) value).intValue(); + if (value instanceof String) return Integer.parseInt((String) value); + } + if (targetType == long.class || targetType == Long.class) { + if (value instanceof Integer) return ((Integer) value).longValue(); + if (value instanceof String) return Long.parseLong((String) value); + } + if (targetType == double.class || targetType == Double.class) { + if (value instanceof Number) return ((Number) value).doubleValue(); + if (value instanceof String) return Double.parseDouble((String) value); + } + if (targetType == float.class || targetType == Float.class) { + if (value instanceof Number) return ((Number) value).floatValue(); + if (value instanceof String) return Float.parseFloat((String) value); + } + if (targetType == boolean.class || targetType == Boolean.class) { + if (value instanceof String) return Boolean.parseBoolean((String) value); + } + if (targetType == String.class) { + return value.toString(); + } + + return value; + } +} diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/SolrTextToVectorModel.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/SolrTextToVectorModel.java index 21f7f8035be7..403375b6d9b3 100644 --- a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/SolrTextToVectorModel.java +++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/SolrTextToVectorModel.java @@ -16,17 +16,12 @@ */ package org.apache.solr.languagemodels.textvectorisation.model; -import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.model.embedding.EmbeddingModel; import java.lang.invoke.MethodHandles; -import java.lang.reflect.Method; -import java.time.Duration; -import java.util.ArrayList; import java.util.Map; import java.util.Objects; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.RamUsageEstimator; -import org.apache.solr.common.SolrException; import org.apache.solr.core.SolrResourceLoader; import org.apache.solr.languagemodels.textvectorisation.store.TextToVectorModelException; import org.apache.solr.languagemodels.textvectorisation.store.rest.ManagedTextToVectorModelStore; @@ -34,21 +29,26 @@ import org.slf4j.LoggerFactory; /** - * This object wraps a {@link dev.langchain4j.model.embedding.EmbeddingModel} to encode text to - * vector. It's meant to be used as a managed resource with the {@link - * ManagedTextToVectorModelStore} + * This object wraps a {@link TextToVectorModel} to encode text to vector. It's meant to be used as + * a managed resource with the {@link ManagedTextToVectorModelStore}. + * + *

Supports two types of model implementations: + * + *

*/ public class SolrTextToVectorModel implements Accountable { private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); private static final long BASE_RAM_BYTES = RamUsageEstimator.shallowSizeOfInstance(SolrTextToVectorModel.class); - private static final String TIMEOUT_PARAM = "timeout"; - private static final String MAX_SEGMENTS_PER_BATCH_PARAM = "maxSegmentsPerBatch"; - private static final String MAX_RETRIES_PARAM = "maxRetries"; private final String name; + private final String className; private final Map params; - private final EmbeddingModel textToVector; + private final TextToVectorModel textToVector; private final int hashCode; public static SolrTextToVectorModel getInstance( @@ -57,89 +57,117 @@ public static SolrTextToVectorModel getInstance( String name, Map params) throws TextToVectorModelException { + + TextToVectorModel textToVector = createModel(solrResourceLoader, className, params); + return new SolrTextToVectorModel(name, className, textToVector, params); + } + + /** + * Create a TextToVectorModel instance from the given class name. First tries to load as a + * Solr-native TextToVectorModel, then falls back to LangChain4j EmbeddingModel wrapped in an + * adapter. + */ + private static TextToVectorModel createModel( + SolrResourceLoader solrResourceLoader, String className, Map params) + throws TextToVectorModelException { + + // First, try to load as a Solr-native TextToVectorModel try { - /* - * The idea here is to build a {@link dev.langchain4j.model.embedding.EmbeddingModel} using inversion - * of control. - * Each model has its own list of parameters we don't know beforehand, but each {@link dev.langchain4j.model.embedding.EmbeddingModel} class - * has its own builder that uses setters with the same name of the parameter in input. - * */ - EmbeddingModel textToVector; - Class modelClass = solrResourceLoader.findClass(className, EmbeddingModel.class); - var builder = modelClass.getMethod("builder").invoke(null); - if (params != null) { - /* - * This block of code has the responsibility of instantiate a {@link - * dev.langchain4j.model.embedding.EmbeddingModel} using the params provided.classes have - * params of The specific implementation of {@link - * dev.langchain4j.model.embedding.EmbeddingModel} is not known beforehand. So we benefit of - * the design choice in langchain4j that each subclass implementing {@link - * dev.langchain4j.model.embedding.EmbeddingModel} uses setters with the same name of the - * param. - */ - for (String paramName : params.keySet()) { - /* - * When a param is not primitive, we need to instantiate the object explicitly and then call the - * setter method. - * N.B. when adding support to new models, pay attention to all the parameters they - * support, some of them may require to be handled in here as separate switch cases - */ - switch (paramName) { - case TIMEOUT_PARAM -> { - Duration timeOut = Duration.ofSeconds((Long) params.get(paramName)); - builder.getClass().getMethod(paramName, Duration.class).invoke(builder, timeOut); - } - case MAX_SEGMENTS_PER_BATCH_PARAM, MAX_RETRIES_PARAM -> builder - .getClass() - .getMethod(paramName, Integer.class) - .invoke(builder, ((Long) params.get(paramName)).intValue()); - - /* - * For primitive params if there's only one setter available, we call it. - * If there's choice we default to the string one - */ - default -> { - ArrayList paramNameMatches = new ArrayList<>(); + Class clazz = solrResourceLoader.findClass(className, Object.class); + + if (TextToVectorModel.class.isAssignableFrom(clazz)) { + log.info("Loading Solr-native TextToVectorModel: {}", className); + return createSolrNativeModel(solrResourceLoader, className, params); + } + + if (EmbeddingModel.class.isAssignableFrom(clazz)) { + log.info("Loading LangChain4j EmbeddingModel via adapter: {}", className); + return Langchain4jModelAdapter.create(solrResourceLoader, className, params); + } + + throw new TextToVectorModelException( + "Class " + + className + + " must implement either " + + TextToVectorModel.class.getName() + + " or " + + EmbeddingModel.class.getName()); + + } catch (TextToVectorModelException e) { + throw e; + } catch (Exception e) { + throw new TextToVectorModelException("Model loading failed for " + className, e); + } + } + + /** + * Create a Solr-native TextToVectorModel using either: 1. Builder pattern (if static builder() + * method exists) 2. No-arg constructor + init(params) + */ + private static TextToVectorModel createSolrNativeModel( + SolrResourceLoader solrResourceLoader, String className, Map params) + throws TextToVectorModelException { + try { + Class modelClass = solrResourceLoader.findClass(className, TextToVectorModel.class); + + TextToVectorModel model; + + // Try builder pattern first + try { + var builderMethod = modelClass.getMethod("builder"); + var builder = builderMethod.invoke(null); + + // Apply params to builder using reflection + if (params != null) { + for (Map.Entry entry : params.entrySet()) { + String paramName = entry.getKey(); + Object paramValue = entry.getValue(); + try { + // Find matching setter method on builder for (var method : builder.getClass().getMethods()) { if (paramName.equals(method.getName()) && method.getParameterCount() == 1) { - paramNameMatches.add(method); - } - } - if (paramNameMatches.size() == 1) { - paramNameMatches.getFirst().invoke(builder, params.get(paramName)); - } else { - try { - builder - .getClass() - .getMethod(paramName, String.class) - .invoke(builder, params.get(paramName).toString()); - } catch (NoSuchMethodException e) { - log.error("Parameter {} not supported by model {}", paramName, className); - throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e); + Class paramType = method.getParameterTypes()[0]; + Object convertedValue = ModelConfigUtils.convertValue(paramValue, paramType); + method.invoke(builder, convertedValue); + break; } } + } catch (Exception e) { + log.warn("Could not set parameter {} on builder for {}", paramName, className, e); } } } + + model = (TextToVectorModel) builder.getClass().getMethod("build").invoke(builder); + log.debug("Created {} using builder pattern", className); + + } catch (NoSuchMethodException e) { + // Fall back to no-arg constructor + init() + model = (TextToVectorModel) modelClass.getDeclaredConstructor().newInstance(); + if (params != null) { + model.init(params); + } + log.debug("Created {} using no-arg constructor + init()", className); } - textToVector = (EmbeddingModel) builder.getClass().getMethod("build").invoke(builder); - return new SolrTextToVectorModel(name, textToVector, params); - } catch (final Exception e) { - throw new TextToVectorModelException("Model loading failed for " + className, e); + + return model; + + } catch (Exception e) { + throw new TextToVectorModelException("Failed to create Solr-native model: " + className, e); } } public SolrTextToVectorModel( - String name, EmbeddingModel textToVector, Map params) { + String name, String className, TextToVectorModel textToVector, Map params) { this.name = name; + this.className = className; this.textToVector = textToVector; this.params = params; this.hashCode = calculateHashCode(); } public float[] vectorise(String text) { - Embedding vector = textToVector.embed(text).content(); - return vector.vector(); + return textToVector.vectorise(text); } @Override @@ -180,7 +208,7 @@ public String getName() { } public String getEmbeddingModelClassName() { - return textToVector.getClass().getName(); + return className; } public Map getParams() { diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/TextToVectorModel.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/TextToVectorModel.java new file mode 100644 index 000000000000..1bbcca7dfb83 --- /dev/null +++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/TextToVectorModel.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.languagemodels.textvectorisation.model; + +import java.util.Map; + +/** + * Interface for text-to-vector embedding models. + * + *

Implement this interface to create custom embedding model integrations (e.g., Triton Inference + * Server, TensorFlow Serving, or any custom HTTP endpoint). + * + *

Implementations must have either: + * + *

    + *
  • A public no-arg constructor, or + *
  • A static {@code builder()} method returning a builder with a {@code build()} method + *
+ * + *

Example usage in model configuration: + * + *

+ * {
+ *   "class": "com.example.MyCustomEmbeddingModel",
+ *   "name": "my-model",
+ *   "params": {
+ *     "endpoint": "http://my-server:8000/embed",
+ *     "api_key": "my-api-key",
+ *     "dimension": 384
+ *   }
+ * }
+ * 
+ */ +public interface TextToVectorModel { + /** + * Convert text to a vector embedding. + * + * @param text the input text to vectorise + * @return the embedding vector as a float array + */ + float[] vectorise(String text); + + /** + * Initialize the model with configuration parameters. Called after construction with the params + * from the model configuration. + * + * @param params the configuration parameters from the model JSON + */ + default void init(Map params) {} +} diff --git a/solr/modules/language-models/src/test-files/modelExamples/dummy-custom-model.json b/solr/modules/language-models/src/test-files/modelExamples/dummy-custom-model.json new file mode 100644 index 000000000000..dc85289e7f13 --- /dev/null +++ b/solr/modules/language-models/src/test-files/modelExamples/dummy-custom-model.json @@ -0,0 +1,7 @@ +{ + "class": "org.apache.solr.languagemodels.textvectorisation.model.DummyTextToVectorModel", + "name": "dummy-1", + "params": { + "embedding": [1.0, 2.0, 3.0, 4.0] + } +} diff --git a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/model/DummyTextToVectorModel.java b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/model/DummyTextToVectorModel.java new file mode 100644 index 000000000000..fd9af60916e9 --- /dev/null +++ b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/model/DummyTextToVectorModel.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.languagemodels.textvectorisation.model; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class DummyTextToVectorModel implements TextToVectorModel { + private float[] vector; + + public DummyTextToVectorModel() { + } + + @Override + public float[] vectorise(String text) { + return vector; + } + + @Override + public void init(Map params) { + List embeddings = (List) params.get("embedding"); + float[] floatArray = new float[embeddings.size()]; + for (int i = 0; i < embeddings.size(); i++) { + floatArray[i] = ((Number) embeddings.get(i)).floatValue(); + } + this.vector = floatArray; + } +} diff --git a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/model/DummyTextToVectorModelTest.java b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/model/DummyTextToVectorModelTest.java new file mode 100644 index 000000000000..8ed068f5c233 --- /dev/null +++ b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/model/DummyTextToVectorModelTest.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.languagemodels.textvectorisation.model; + +import java.util.Arrays; +import java.util.Map; +import org.apache.solr.SolrTestCase; +import org.junit.Test; + +public class DummyTextToVectorModelTest extends SolrTestCase { + @Test + public void testVectorise() { + DummyTextToVectorModel model = new DummyTextToVectorModel(); + model.init(Map.of("embedding", new float[] {1, 2, 3})); + float[] vector = model.vectorise("test"); + assertEquals("[1.0, 2.0, 3.0]", Arrays.toString(vector)); + } +} diff --git a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorFactoryTest.java b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorFactoryTest.java index 5ccb9d95e605..007592dcab25 100644 --- a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorFactoryTest.java +++ b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorFactoryTest.java @@ -188,7 +188,7 @@ private UpdateRequestProcessor createUpdateProcessor( NamedList args = new NamedList<>(); ManagedTextToVectorModelStore.getManagedModelStore(collection1) - .addModel(new SolrTextToVectorModel(modelName, null, null)); + .addModel(new SolrTextToVectorModel(modelName, null, null, null)); args.add("inputField", inputFieldName); args.add("outputField", outputFieldName); args.add("model", modelName); diff --git a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorTest.java b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorTest.java index 2dd9cda1a585..7178eda3fd59 100644 --- a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorTest.java +++ b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorTest.java @@ -51,7 +51,16 @@ public void afterEachTest() throws Exception { @Test public void processAdd_inputField_shouldVectoriseInputField() throws Exception { - loadModel("dummy-model.json"); // preparation + assertVectorisationWithModel("dummy-model.json"); + } + + @Test + public void processAdd_customModel_shouldVectoriseInputField() throws Exception { + assertVectorisationWithModel("dummy-custom-model.json"); + } + + private void assertVectorisationWithModel(String modelJsonFile) throws Exception { + loadModel(modelJsonFile); addWithChain(sdoc("id", "99", "_text_", "Vegeta is the saiyan prince."), "textToVector"); addWithChain( @@ -68,8 +77,6 @@ public void processAdd_inputField_shouldVectoriseInputField() throws Exception { "/response/docs/[0]/vector==[1.0, 2.0, 3.0, 4.0]", "/response/docs/[1]/id=='98'", "/response/docs/[1]/vector==[1.0, 2.0, 3.0, 4.0]"); - - restTestHarness.delete(ManagedTextToVectorModelStore.REST_END_POINT + "/dummy-1"); // clean up } private SolrQuery getSolrQuery() {