Skip to content

Commit

Permalink
Update.
Browse files Browse the repository at this point in the history
  • Loading branch information
Buhake Sindi committed Dec 14, 2024
1 parent 0cca08b commit 0e91b10
Showing 1 changed file with 27 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,27 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.jboss.logging.Logger;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.smallrye.llm.core.langchain4j.core.config.spi.LLMConfig;
import io.smallrye.llm.core.langchain4j.core.config.spi.LLMConfigProvider;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.literal.NamedLiteral;
import jakarta.enterprise.inject.spi.CDI;
import jakarta.enterprise.util.TypeLiteral;

/*
smallrye.llm.plugin.content-retriever.class=dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever
Expand All @@ -37,6 +41,13 @@
public class CommonLLMPluginCreator {

public static final Logger LOGGER = Logger.getLogger(CommonLLMPluginCreator.class);

private static final Map<Class<?>, TypeLiteral<?>> TYPE_LITERALS = new HashMap<>();

static {
TYPE_LITERALS.put(EmbeddingStore.class, new TypeLiteral<EmbeddingStore<TextSegment>>() {
});
}

@SuppressWarnings("unchecked")
public static void createAllLLMBeans(LLMConfig llmConfig, Consumer<BeanData> beanBuilder) throws ClassNotFoundException {
Expand Down Expand Up @@ -165,12 +176,9 @@ public static Object create(Instance<Object> lookup, String beanName, Class<?> t
LOGGER.info("Lookup " + lookupableBean + " " + parameterType);
Instance<?> inst;
if ("default".equals(lookupableBean)) {
inst = lookup.select(parameterType);
if (!inst.isResolvable()) {
inst = CDI.current().select(parameterType);
}
inst = getInstance(lookup, parameterType);
} else {
inst = lookup.select(parameterType, NamedLiteral.of(lookupableBean));
inst = getInstance(lookup, parameterType, lookupableBean);
}
methodToCall.invoke(builder, inst.get());
break;
Expand All @@ -195,4 +203,17 @@ public static Object create(Instance<Object> lookup, String beanName, Class<?> t
private static Class<?> loadClass(String scopeClassName) throws ClassNotFoundException {
return Thread.currentThread().getContextClassLoader().loadClass(scopeClassName);
}

@SuppressWarnings("unchecked")
private static <T> Instance<T> getInstance(Instance<Object> lookup, Class<T> clazz) {
if (TYPE_LITERALS.containsKey(clazz))
return (Instance<T>) lookup.select(TYPE_LITERALS.get(clazz));
return lookup.select(clazz);
}

private static <T> Instance<T> getInstance(Instance<Object> lookup, Class<T> clazz, String lookupName) {
if (lookupName == null || lookupName.isBlank())
return getInstance(lookup, clazz);
return lookup.select(clazz, NamedLiteral.of(lookupName));
}
}

0 comments on commit 0e91b10

Please sign in to comment.