diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/Langchain4jModelAdapter.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/Langchain4jModelAdapter.java
new file mode 100644
index 000000000000..3d7f1d801491
--- /dev/null
+++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/Langchain4jModelAdapter.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.languagemodels.textvectorisation.model;
+
+import dev.langchain4j.data.embedding.Embedding;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+import java.lang.invoke.MethodHandles;
+import java.lang.reflect.Method;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import org.apache.solr.common.SolrException;
+import org.apache.solr.core.SolrResourceLoader;
+import org.apache.solr.languagemodels.textvectorisation.store.TextToVectorModelException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Adapter that wraps a LangChain4j {@link EmbeddingModel} to implement the Solr-native {@link
+ * TextToVectorModel} interface.
+ *
+ * <p>This provides backward compatibility for existing LangChain4j model configurations (OpenAI,
+ * Cohere, HuggingFace, MistralAI).
+ */
+public class Langchain4jModelAdapter implements TextToVectorModel {
+  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
+
+  private static final String TIMEOUT_PARAM = "timeout";
+  private static final String MAX_SEGMENTS_PER_BATCH_PARAM = "maxSegmentsPerBatch";
+  private static final String MAX_RETRIES_PARAM = "maxRetries";
+
+  private final EmbeddingModel delegate;
+
+  /**
+   * Wraps an already-built LangChain4j model.
+   *
+   * @param delegate the model that {@link #vectorise(String)} forwards to
+   */
+  public Langchain4jModelAdapter(EmbeddingModel delegate) {
+    this.delegate = delegate;
+  }
+
+  /**
+   * Creates an adapter by instantiating the given {@link EmbeddingModel} class through its static
+   * {@code builder()} method, using reflection.
+   *
+   * <p>The concrete model class is not known beforehand, so every entry of {@code params} is
+   * applied by calling the builder method that has the same name as the parameter.
+   *
+   * @param solrResourceLoader loader used to resolve {@code className}
+   * @param className fully qualified name of the {@link EmbeddingModel} implementation
+   * @param params builder parameters keyed by builder method name; may be {@code null}
+   * @return an adapter around the built model
+   * @throws TextToVectorModelException if the class cannot be loaded, a parameter cannot be
+   *     applied, or the builder fails
+   */
+  public static Langchain4jModelAdapter create(
+      SolrResourceLoader solrResourceLoader, String className, Map<String, Object> params)
+      throws TextToVectorModelException {
+    try {
+      Class<?> modelClass = solrResourceLoader.findClass(className, EmbeddingModel.class);
+      Object builder = modelClass.getMethod("builder").invoke(null);
+
+      if (params != null) {
+        for (Map.Entry<String, Object> param : params.entrySet()) {
+          applyParam(builder, param.getKey(), param.getValue(), className);
+        }
+      }
+
+      EmbeddingModel embeddingModel =
+          (EmbeddingModel) builder.getClass().getMethod("build").invoke(builder);
+      return new Langchain4jModelAdapter(embeddingModel);
+    } catch (final Exception e) {
+      throw new TextToVectorModelException("LangChain4j model loading failed for " + className, e);
+    }
+  }
+
+  /**
+   * Applies one configuration parameter to the builder.
+   *
+   * <p>Non-primitive parameters must be instantiated explicitly before the builder method is
+   * called. N.B. when adding support for new models, check every parameter they accept: some may
+   * need a dedicated switch case here.
+   */
+  private static void applyParam(
+      Object builder, String paramName, Object paramValue, String className)
+      throws ReflectiveOperationException {
+    Class<?> builderClass = builder.getClass();
+    switch (paramName) {
+      case TIMEOUT_PARAM -> {
+        Duration timeout = Duration.ofSeconds(toLong(paramName, paramValue));
+        builderClass.getMethod(paramName, Duration.class).invoke(builder, timeout);
+      }
+      case MAX_SEGMENTS_PER_BATCH_PARAM, MAX_RETRIES_PARAM -> {
+        // toIntExact rejects values that do not fit in an int instead of silently wrapping them.
+        Integer count = Math.toIntExact(toLong(paramName, paramValue));
+        builderClass.getMethod(paramName, Integer.class).invoke(builder, count);
+      }
+      default -> {
+        // For all other params: if exactly one single-argument builder method carries the param
+        // name, call it with the converted value; otherwise default to the String overload.
+        List<Method> candidates = new ArrayList<>();
+        for (Method method : builderClass.getMethods()) {
+          if (paramName.equals(method.getName()) && method.getParameterCount() == 1) {
+            candidates.add(method);
+          }
+        }
+        if (candidates.size() == 1) {
+          Method setter = candidates.getFirst();
+          Class<?> paramType = setter.getParameterTypes()[0];
+          setter.invoke(builder, ModelConfigUtils.convertValue(paramValue, paramType));
+        } else {
+          try {
+            String text = paramValue == null ? null : paramValue.toString();
+            builderClass.getMethod(paramName, String.class).invoke(builder, text);
+          } catch (NoSuchMethodException e) {
+            log.error("Parameter {} not supported by model {}", paramName, className);
+            throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e);
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Reads a numeric parameter as a {@code long}, accepting any {@link Number} or a numeric string
+   * rather than only {@link Long}.
+   */
+  private static long toLong(String paramName, Object paramValue) {
+    if (paramValue instanceof Number number) {
+      return number.longValue();
+    }
+    if (paramValue instanceof String text) {
+      return Long.parseLong(text.trim());
+    }
+    throw new IllegalArgumentException(
+        "Parameter " + paramName + " must be a number but was: " + paramValue);
+  }
+
+  @Override
+  public float[] vectorise(String text) {
+    Embedding vector = delegate.embed(text).content();
+    return vector.vector();
+  }
+}
diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/ModelConfigUtils.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/ModelConfigUtils.java
new file mode 100644
index 000000000000..84c19e94b01b
--- /dev/null
+++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/ModelConfigUtils.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.languagemodels.textvectorisation.model;
+
+/** Static helpers for converting JSON-parsed model configuration values. */
+public final class ModelConfigUtils {
+
+  private ModelConfigUtils() {}
+
+  /**
+   * Converts a JSON-parsed value to the type expected by a builder or setter parameter.
+   *
+   * <p>A value already assignable to {@code targetType} is returned as is. Otherwise numeric,
+   * boolean and {@code String} targets are converted from the {@link Long}, {@link Integer},
+   * {@link Number} or {@link String} values JSON parsing commonly produces. When no conversion
+   * applies the value is returned unchanged.
+   *
+   * @param value the value to convert; may be {@code null}
+   * @param targetType the parameter type the value should be converted to
+   * @return the converted value, the unchanged value if no conversion applies, or {@code null}
+   *     when {@code value} is {@code null}
+   * @throws NumberFormatException if a {@code String} cannot be parsed as the numeric target
+   * @throws ArithmeticException if a {@code Long} does not fit in an {@code int} target
+   */
+  public static Object convertValue(Object value, Class<?> targetType) {
+    if (value == null) {
+      return null;
+    }
+    if (targetType.isAssignableFrom(value.getClass())) {
+      return value;
+    }
+
+    // Handle common type conversions from JSON parsing
+    if (targetType == int.class || targetType == Integer.class) {
+      if (value instanceof Long longValue) {
+        // Fail loudly rather than silently wrapping a value that does not fit in an int.
+        return Math.toIntExact(longValue);
+      }
+      if (value instanceof String text) {
+        return Integer.parseInt(text);
+      }
+    } else if (targetType == long.class || targetType == Long.class) {
+      if (value instanceof Integer intValue) {
+        return intValue.longValue();
+      }
+      if (value instanceof String text) {
+        return Long.parseLong(text);
+      }
+    } else if (targetType == double.class || targetType == Double.class) {
+      if (value instanceof Number number) {
+        return number.doubleValue();
+      }
+      if (value instanceof String text) {
+        return Double.parseDouble(text);
+      }
+    } else if (targetType == float.class || targetType == Float.class) {
+      if (value instanceof Number number) {
+        return number.floatValue();
+      }
+      if (value instanceof String text) {
+        return Float.parseFloat(text);
+      }
+    } else if (targetType == boolean.class || targetType == Boolean.class) {
+      if (value instanceof String text) {
+        return Boolean.parseBoolean(text);
+      }
+    } else if (targetType == String.class) {
+      return value.toString();
+    }
+
+    return value;
+  }
+}
diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/SolrTextToVectorModel.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/SolrTextToVectorModel.java
index 21f7f8035be7..403375b6d9b3 100644
--- a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/SolrTextToVectorModel.java
+++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/SolrTextToVectorModel.java
@@ -16,17 +16,12 @@
*/
package org.apache.solr.languagemodels.textvectorisation.model;
-import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import java.lang.invoke.MethodHandles;
-import java.lang.reflect.Method;
-import java.time.Duration;
-import java.util.ArrayList;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
-import org.apache.solr.common.SolrException;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.languagemodels.textvectorisation.store.TextToVectorModelException;
import org.apache.solr.languagemodels.textvectorisation.store.rest.ManagedTextToVectorModelStore;
@@ -34,21 +29,26 @@
import org.slf4j.LoggerFactory;
/**
- * This object wraps a {@link dev.langchain4j.model.embedding.EmbeddingModel} to encode text to
- * vector. It's meant to be used as a managed resource with the {@link
- * ManagedTextToVectorModelStore}
+ * This object wraps a {@link TextToVectorModel} to encode text to vector. It's meant to be used as
+ * a managed resource with the {@link ManagedTextToVectorModelStore}.
+ *
+ * <p>Supports two types of model implementations:
+ *
+ * <ul>
+ *   <li>Solr-native {@link TextToVectorModel} implementations for custom integrations
+ *   <li>LangChain4j {@link EmbeddingModel} implementations (wrapped via {@link
+ *       Langchain4jModelAdapter})
+ * </ul>
*/
public class SolrTextToVectorModel implements Accountable {
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
private static final long BASE_RAM_BYTES =
RamUsageEstimator.shallowSizeOfInstance(SolrTextToVectorModel.class);
- private static final String TIMEOUT_PARAM = "timeout";
- private static final String MAX_SEGMENTS_PER_BATCH_PARAM = "maxSegmentsPerBatch";
- private static final String MAX_RETRIES_PARAM = "maxRetries";
private final String name;
+ private final String className;
private final Map params;
- private final EmbeddingModel textToVector;
+ private final TextToVectorModel textToVector;
private final int hashCode;
public static SolrTextToVectorModel getInstance(
@@ -57,89 +57,117 @@ public static SolrTextToVectorModel getInstance(
String name,
Map params)
throws TextToVectorModelException {
+
+ TextToVectorModel textToVector = createModel(solrResourceLoader, className, params);
+ return new SolrTextToVectorModel(name, className, textToVector, params);
+ }
+
+ /**
+ * Create a TextToVectorModel instance from the given class name. First tries to load as a
+ * Solr-native TextToVectorModel, then falls back to LangChain4j EmbeddingModel wrapped in an
+ * adapter.
+ */
+ private static TextToVectorModel createModel(
+ SolrResourceLoader solrResourceLoader, String className, Map params)
+ throws TextToVectorModelException {
+
+ // First, try to load as a Solr-native TextToVectorModel
try {
- /*
- * The idea here is to build a {@link dev.langchain4j.model.embedding.EmbeddingModel} using inversion
- * of control.
- * Each model has its own list of parameters we don't know beforehand, but each {@link dev.langchain4j.model.embedding.EmbeddingModel} class
- * has its own builder that uses setters with the same name of the parameter in input.
- * */
- EmbeddingModel textToVector;
- Class> modelClass = solrResourceLoader.findClass(className, EmbeddingModel.class);
- var builder = modelClass.getMethod("builder").invoke(null);
- if (params != null) {
- /*
- * This block of code has the responsibility of instantiate a {@link
- * dev.langchain4j.model.embedding.EmbeddingModel} using the params provided.classes have
- * params of The specific implementation of {@link
- * dev.langchain4j.model.embedding.EmbeddingModel} is not known beforehand. So we benefit of
- * the design choice in langchain4j that each subclass implementing {@link
- * dev.langchain4j.model.embedding.EmbeddingModel} uses setters with the same name of the
- * param.
- */
- for (String paramName : params.keySet()) {
- /*
- * When a param is not primitive, we need to instantiate the object explicitly and then call the
- * setter method.
- * N.B. when adding support to new models, pay attention to all the parameters they
- * support, some of them may require to be handled in here as separate switch cases
- */
- switch (paramName) {
- case TIMEOUT_PARAM -> {
- Duration timeOut = Duration.ofSeconds((Long) params.get(paramName));
- builder.getClass().getMethod(paramName, Duration.class).invoke(builder, timeOut);
- }
- case MAX_SEGMENTS_PER_BATCH_PARAM, MAX_RETRIES_PARAM -> builder
- .getClass()
- .getMethod(paramName, Integer.class)
- .invoke(builder, ((Long) params.get(paramName)).intValue());
-
- /*
- * For primitive params if there's only one setter available, we call it.
- * If there's choice we default to the string one
- */
- default -> {
- ArrayList paramNameMatches = new ArrayList<>();
+ Class> clazz = solrResourceLoader.findClass(className, Object.class);
+
+ if (TextToVectorModel.class.isAssignableFrom(clazz)) {
+ log.info("Loading Solr-native TextToVectorModel: {}", className);
+ return createSolrNativeModel(solrResourceLoader, className, params);
+ }
+
+ if (EmbeddingModel.class.isAssignableFrom(clazz)) {
+ log.info("Loading LangChain4j EmbeddingModel via adapter: {}", className);
+ return Langchain4jModelAdapter.create(solrResourceLoader, className, params);
+ }
+
+ throw new TextToVectorModelException(
+ "Class "
+ + className
+ + " must implement either "
+ + TextToVectorModel.class.getName()
+ + " or "
+ + EmbeddingModel.class.getName());
+
+ } catch (TextToVectorModelException e) {
+ throw e;
+ } catch (Exception e) {
+ throw new TextToVectorModelException("Model loading failed for " + className, e);
+ }
+ }
+
+ /**
+ * Create a Solr-native TextToVectorModel using either (1) the builder pattern, if a static
+ * {@code builder()} method exists, or (2) the no-arg constructor plus {@code init(params)}.
+ */
+ private static TextToVectorModel createSolrNativeModel(
+ SolrResourceLoader solrResourceLoader, String className, Map params)
+ throws TextToVectorModelException {
+ try {
+ Class> modelClass = solrResourceLoader.findClass(className, TextToVectorModel.class);
+
+ TextToVectorModel model;
+
+ // Try builder pattern first
+ try {
+ var builderMethod = modelClass.getMethod("builder");
+ var builder = builderMethod.invoke(null);
+
+ // Apply params to builder using reflection
+ if (params != null) {
+ for (Map.Entry entry : params.entrySet()) {
+ String paramName = entry.getKey();
+ Object paramValue = entry.getValue();
+ try {
+ // Find matching setter method on builder
for (var method : builder.getClass().getMethods()) {
if (paramName.equals(method.getName()) && method.getParameterCount() == 1) {
- paramNameMatches.add(method);
- }
- }
- if (paramNameMatches.size() == 1) {
- paramNameMatches.getFirst().invoke(builder, params.get(paramName));
- } else {
- try {
- builder
- .getClass()
- .getMethod(paramName, String.class)
- .invoke(builder, params.get(paramName).toString());
- } catch (NoSuchMethodException e) {
- log.error("Parameter {} not supported by model {}", paramName, className);
- throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e);
+ Class> paramType = method.getParameterTypes()[0];
+ Object convertedValue = ModelConfigUtils.convertValue(paramValue, paramType);
+ method.invoke(builder, convertedValue);
+ break;
}
}
+ } catch (Exception e) {
+ log.warn("Could not set parameter {} on builder for {}", paramName, className, e);
}
}
}
+
+ model = (TextToVectorModel) builder.getClass().getMethod("build").invoke(builder);
+ log.debug("Created {} using builder pattern", className);
+
+ } catch (NoSuchMethodException e) {
+ // Fall back to no-arg constructor + init()
+ model = (TextToVectorModel) modelClass.getDeclaredConstructor().newInstance();
+ if (params != null) {
+ model.init(params);
+ }
+ log.debug("Created {} using no-arg constructor + init()", className);
}
- textToVector = (EmbeddingModel) builder.getClass().getMethod("build").invoke(builder);
- return new SolrTextToVectorModel(name, textToVector, params);
- } catch (final Exception e) {
- throw new TextToVectorModelException("Model loading failed for " + className, e);
+
+ return model;
+
+ } catch (Exception e) {
+ throw new TextToVectorModelException("Failed to create Solr-native model: " + className, e);
}
}
public SolrTextToVectorModel(
- String name, EmbeddingModel textToVector, Map params) {
+ String name, String className, TextToVectorModel textToVector, Map params) {
this.name = name;
+ this.className = className;
this.textToVector = textToVector;
this.params = params;
this.hashCode = calculateHashCode();
}
public float[] vectorise(String text) {
- Embedding vector = textToVector.embed(text).content();
- return vector.vector();
+ return textToVector.vectorise(text);
}
@Override
@@ -180,7 +208,7 @@ public String getName() {
}
public String getEmbeddingModelClassName() {
- return textToVector.getClass().getName();
+ return className;
}
public Map getParams() {
diff --git a/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/TextToVectorModel.java b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/TextToVectorModel.java
new file mode 100644
index 000000000000..1bbcca7dfb83
--- /dev/null
+++ b/solr/modules/language-models/src/java/org/apache/solr/languagemodels/textvectorisation/model/TextToVectorModel.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.languagemodels.textvectorisation.model;
+
+import java.util.Map;
+
+/**
+ * Interface for text-to-vector embedding models.
+ *
+ * <p>Implement this interface to create custom embedding model integrations (e.g., Triton Inference
+ * Server, TensorFlow Serving, or any custom HTTP endpoint).
+ *
+ * <p>Implementations must have either:
+ *
+ * <ul>
+ *   <li>A public no-arg constructor, or
+ *   <li>A static {@code builder()} method returning a builder with a {@code build()} method
+ * </ul>
+ *
+ * <p>Example usage in model configuration:
+ *
+ * <pre>
+ * {
+ *   "class": "com.example.MyCustomEmbeddingModel",
+ *   "name": "my-model",
+ *   "params": {
+ *     "endpoint": "http://my-server:8000/embed",
+ *     "api_key": "my-api-key",
+ *     "dimension": 384
+ *   }
+ * }
+ * </pre>
+ */
+public interface TextToVectorModel {
+  /**
+   * Convert text to a vector embedding.
+   *
+   * @param text the input text to vectorise
+   * @return the embedding vector as a float array
+   */
+  float[] vectorise(String text);
+
+  /**
+   * Initialize the model with configuration parameters.
+   *
+   * <p>{@code SolrTextToVectorModel} calls this right after the no-arg constructor, and only when
+   * the configured params are non-null. Models created through a static {@code builder()} get
+   * their params through same-named builder methods instead, and this method is not called for
+   * them. The default implementation does nothing.
+   *
+   * @param params the configuration parameters from the model JSON
+   */
+  default void init(Map<String, Object> params) {}
+}
diff --git a/solr/modules/language-models/src/test-files/modelExamples/dummy-custom-model.json b/solr/modules/language-models/src/test-files/modelExamples/dummy-custom-model.json
new file mode 100644
index 000000000000..dc85289e7f13
--- /dev/null
+++ b/solr/modules/language-models/src/test-files/modelExamples/dummy-custom-model.json
@@ -0,0 +1,7 @@
+{
+ "class": "org.apache.solr.languagemodels.textvectorisation.model.DummyTextToVectorModel",
+ "name": "dummy-1",
+ "params": {
+ "embedding": [1.0, 2.0, 3.0, 4.0]
+ }
+}
diff --git a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/model/DummyTextToVectorModel.java b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/model/DummyTextToVectorModel.java
new file mode 100644
index 000000000000..fd9af60916e9
--- /dev/null
+++ b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/model/DummyTextToVectorModel.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.languagemodels.textvectorisation.model;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Test double for {@link TextToVectorModel}: every text is mapped to the fixed vector configured
+ * through the {@code embedding} param.
+ */
+public class DummyTextToVectorModel implements TextToVectorModel {
+  private float[] vector;
+
+  /** No-arg constructor used by reflective model loading; call {@link #init(Map)} afterwards. */
+  public DummyTextToVectorModel() {}
+
+  /**
+   * Returns the configured vector regardless of {@code text}.
+   *
+   * @return the vector set by {@link #init(Map)}, or {@code null} if it was never called
+   */
+  @Override
+  public float[] vectorise(String text) {
+    return vector;
+  }
+
+  /**
+   * Reads the fixed vector from the {@code embedding} param.
+   *
+   * @param params must map {@code embedding} to a list of numbers
+   * @throws IllegalArgumentException if {@code embedding} is missing or is not a list of numbers
+   */
+  @Override
+  public void init(Map<String, Object> params) {
+    Object embedding = params.get("embedding");
+    if (!(embedding instanceof List<?> values)) {
+      throw new IllegalArgumentException(
+          "Param 'embedding' must be a list of numbers but was: " + embedding);
+    }
+    float[] floatArray = new float[values.size()];
+    for (int i = 0; i < floatArray.length; i++) {
+      if (!(values.get(i) instanceof Number number)) {
+        throw new IllegalArgumentException(
+            "Param 'embedding' must contain only numbers but found: " + values.get(i));
+      }
+      floatArray[i] = number.floatValue();
+    }
+    this.vector = floatArray;
+  }
+}
diff --git a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/model/DummyTextToVectorModelTest.java b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/model/DummyTextToVectorModelTest.java
new file mode 100644
index 000000000000..8ed068f5c233
--- /dev/null
+++ b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/model/DummyTextToVectorModelTest.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.languagemodels.textvectorisation.model;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import org.apache.solr.SolrTestCase;
+import org.junit.Test;
+
+/** Unit test for {@link DummyTextToVectorModel}. */
+public class DummyTextToVectorModelTest extends SolrTestCase {
+  /** The numbers handed to init as a list must come back from vectorise as the vector. */
+  @Test
+  public void testVectorise() {
+    DummyTextToVectorModel model = new DummyTextToVectorModel();
+    // init casts the "embedding" param to a List, so passing a float[] would fail on that cast.
+    model.init(Map.of("embedding", List.of(1.0, 2.0, 3.0)));
+    float[] vector = model.vectorise("test");
+    assertEquals("[1.0, 2.0, 3.0]", Arrays.toString(vector));
+  }
+}
diff --git a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorFactoryTest.java b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorFactoryTest.java
index 5ccb9d95e605..007592dcab25 100644
--- a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorFactoryTest.java
+++ b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorFactoryTest.java
@@ -188,7 +188,7 @@ private UpdateRequestProcessor createUpdateProcessor(
NamedList args = new NamedList<>();
ManagedTextToVectorModelStore.getManagedModelStore(collection1)
- .addModel(new SolrTextToVectorModel(modelName, null, null));
+ .addModel(new SolrTextToVectorModel(modelName, null, null, null));
args.add("inputField", inputFieldName);
args.add("outputField", outputFieldName);
args.add("model", modelName);
diff --git a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorTest.java b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorTest.java
index 2dd9cda1a585..7178eda3fd59 100644
--- a/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorTest.java
+++ b/solr/modules/language-models/src/test/org/apache/solr/languagemodels/textvectorisation/update/processor/TextToVectorUpdateProcessorTest.java
@@ -51,7 +51,16 @@ public void afterEachTest() throws Exception {
@Test
public void processAdd_inputField_shouldVectoriseInputField() throws Exception {
- loadModel("dummy-model.json"); // preparation
+ assertVectorisationWithModel("dummy-model.json");
+ }
+
+ @Test
+ public void processAdd_customModel_shouldVectoriseInputField() throws Exception {
+ assertVectorisationWithModel("dummy-custom-model.json");
+ }
+
+ private void assertVectorisationWithModel(String modelJsonFile) throws Exception {
+ loadModel(modelJsonFile);
addWithChain(sdoc("id", "99", "_text_", "Vegeta is the saiyan prince."), "textToVector");
addWithChain(
@@ -68,8 +77,6 @@ public void processAdd_inputField_shouldVectoriseInputField() throws Exception {
"/response/docs/[0]/vector==[1.0, 2.0, 3.0, 4.0]",
"/response/docs/[1]/id=='98'",
"/response/docs/[1]/vector==[1.0, 2.0, 3.0, 4.0]");
-
- restTestHarness.delete(ManagedTextToVectorModelStore.REST_END_POINT + "/dummy-1"); // clean up
}
private SolrQuery getSolrQuery() {