Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

22 add more cdi features #72

Closed
wants to merge 13 commits into from
Prev Previous commit
Next Next commit
Updated to support CDI creation for RAG support.
Buhake Sindi committed Dec 13, 2024
commit 59a91524b5be67bd5cc3fcbeeb12dbb84419edcb
Original file line number Diff line number Diff line change
@@ -1,127 +1,103 @@
package io.smallrye.llm.aiservice;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.List;

import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.literal.NamedLiteral;

import org.jboss.logging.Logger;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.Moderate;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import io.smallrye.llm.core.langchain4j.core.config.spi.ChatMemoryFactoryProvider;
import io.smallrye.llm.spi.RegisterAIService;
import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.literal.NamedLiteral;

public class CommonAIServiceCreator {

private static final Logger LOGGER = Logger.getLogger(CommonAIServiceCreator.class);

@SuppressWarnings("unchecked")
public static <X> X create(Instance<Object> lookup, Class<X> interfaceClass) {
RegisterAIService annotation = interfaceClass.getAnnotation(RegisterAIService.class);
Instance<ChatLanguageModel> chatLanguageModel = getInstance(lookup, ChatLanguageModel.class,
annotation.chatLanguageModelName());
Instance<StreamingChatLanguageModel> streamingChatLanguageModel = getInstance(lookup, StreamingChatLanguageModel.class,
annotation.streamingChatLanguageModelName());
Instance<ContentRetriever> contentRetriever = getInstance(lookup, ContentRetriever.class,
annotation.contentRetrieverName());
Instance<RetrievalAugmentor> retrievalAugmentor = getInstance(lookup, RetrievalAugmentor.class,
annotation.retrievalAugmentorName());
try {
AiServices<?> aiServices = AiServices.builder(interfaceClass);
if (chatLanguageModel.isResolvable()) {
AiServices<X> aiServices = AiServices.builder(interfaceClass);
if (chatLanguageModel != null && chatLanguageModel.isResolvable()) {
LOGGER.info("ChatLanguageModel " + chatLanguageModel.get());
aiServices.chatLanguageModel(chatLanguageModel.get());
}
if (contentRetriever.isResolvable()) {
if (streamingChatLanguageModel != null && streamingChatLanguageModel.isResolvable()) {
LOGGER.info("StreamingChatLanguageModel " + streamingChatLanguageModel.get());
aiServices.streamingChatLanguageModel(streamingChatLanguageModel.get());
}
if (contentRetriever != null && contentRetriever.isResolvable()) {
LOGGER.info("ContentRetriever " + contentRetriever.get());
aiServices.contentRetriever(contentRetriever.get());
}
if (retrievalAugmentor != null && retrievalAugmentor.isResolvable()) {
LOGGER.info("RetrievalAugmentor " + retrievalAugmentor.get());
aiServices.retrievalAugmentor(retrievalAugmentor.get());
}
if (annotation.tools() != null && annotation.tools().length > 0) {
List<Object> tools = new ArrayList<>(annotation.tools().length);
for (Class<?> toolClass : annotation.tools()) {
try {
tools.add(toolClass.getConstructor(null).newInstance(null));
tools.add(toolClass.getConstructor((Class<?>[])null).newInstance((Object[])null));
} catch (NoSuchMethodException | SecurityException | InstantiationException | IllegalAccessException
| IllegalArgumentException | InvocationTargetException ex) {
}
}
aiServices.tools(tools);
}

Instance<ChatMemory> chatMemory = getInstance(lookup, ChatMemory.class,
annotation.chatMemoryName());
if (chatMemory != null && chatMemory.isResolvable()) {
LOGGER.info("ChatMemory " + chatMemory.get());
aiServices.chatMemory(chatMemory.get());
}

ChatMemoryProvider chatMemoryProvider = createChatMemoryProvider(lookup, interfaceClass, annotation);
if (chatMemoryProvider != null) {
aiServices.chatMemoryProvider(chatMemoryProvider);
} else {
aiServices.chatMemory(
ChatMemoryFactoryProvider.getChatMemoryFactory().getChatMemory(lookup,
annotation.chatMemoryMaxMessages()));
}
Instance<ChatMemoryProvider> chatMemoryProvider = getInstance(lookup, ChatMemoryProvider.class,
annotation.chatMemoryProviderName());
if (chatMemoryProvider != null && chatMemoryProvider.isResolvable()) {
LOGGER.info("ChatMemoryProvider " + chatMemoryProvider.get());
aiServices.chatMemoryProvider(chatMemoryProvider.get());
}

ModerationModel moderationModel = findModerationModel(lookup, interfaceClass, annotation);
if (moderationModel != null) {
aiServices.moderationModel(moderationModel);
Instance<ModerationModel> moderationModelInstance = getInstance(lookup, ModerationModel.class,
annotation.moderationModelName());
if (moderationModelInstance != null && moderationModelInstance.isResolvable()) {
LOGGER.info("ModerationModel " + moderationModelInstance.get());
aiServices.moderationModel(moderationModelInstance.get());
}
return (X) aiServices.build();

return aiServices.build();
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private static <X> Instance<X> getInstance(Instance<Object> lookup, Class<X> type, String name) {
LOGGER.info("Getinstance of '" + type + "' with name '" + name + "'");
if (name == null || name.isBlank()) {
return lookup.select(type);
}
return lookup.select(type, NamedLiteral.of(name));
}

private static ModerationModel findModerationModel(Instance<Object> lookup, Class<?> interfaceClass,
RegisterAIService registerAIService) {
//Get all methods.
for (Method method : interfaceClass.getMethods()) {
Moderate moderate = method.getAnnotation(Moderate.class);
if (moderate != null) {
Instance<ModerationModel> moderationModelInstance = getInstance(lookup, ModerationModel.class,
registerAIService.moderationModelName());
if (moderationModelInstance != null && moderationModelInstance.isResolvable())
return moderationModelInstance.get();
}
LOGGER.info("CDI get instance of type '" + type + "' with name '" + name + "'");
if (name != null && !name.isBlank()) {
if ("#default".equals(name))
return lookup.select(type);

return lookup.select(type, NamedLiteral.of(name));
}

return null;
}

private static ChatMemoryProvider createChatMemoryProvider(Instance<Object> lookup, Class<?> interfaceClass,
RegisterAIService registerAIService) {
//Get all methods.
for (Method method : interfaceClass.getMethods()) {
for (Parameter parameter : method.getParameters()) {
MemoryId memoryIdAnnotation = parameter.getAnnotation(MemoryId.class);
if (memoryIdAnnotation != null) {
Instance<ChatMemoryStore> chatMemoryStore = getInstance(lookup, ChatMemoryStore.class,
registerAIService.chatMemoryStoreName());
if (chatMemoryStore == null || !chatMemoryStore.isResolvable()) {
throw new IllegalStateException("Unable to resolve a ChatMemoryStore for your ChatMemoryProvider.");
}

ChatMemoryProvider chatMemoryProvider = memoryId -> MessageWindowChatMemory.builder()
.id(memoryId)
.maxMessages(registerAIService.chatMemoryMaxMessages())
.chatMemoryStore(chatMemoryStore.get())
.build();
return chatMemoryProvider;
}
}
}


return null;
}
}
Original file line number Diff line number Diff line change
@@ -16,19 +16,19 @@
import java.util.function.Function;
import java.util.stream.Collectors;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.literal.NamedLiteral;
import jakarta.enterprise.util.TypeLiteral;

import org.jboss.logging.Logger;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.smallrye.llm.core.langchain4j.core.config.spi.LLMConfig;
import io.smallrye.llm.core.langchain4j.core.config.spi.LLMConfigProvider;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.literal.NamedLiteral;
import jakarta.enterprise.util.TypeLiteral;

/*
smallrye.llm.plugin.content-retriever.class=dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever
@@ -149,7 +149,7 @@ public static Object create(Instance<Object> lookup, String beanName, Class<?> t
for (Method methodToCall : methodsToCall) {
Class<?> parameterType = methodToCall.getParameterTypes()[0];
if ("listeners".equals(property)) {
Class<?> typeParameterClass = ChatLanguageModel.class.isAssignableFrom(targetClass)
Class<?> typeParameterClass = ChatLanguageModel.class.isAssignableFrom(targetClass) || StreamingChatLanguageModel.class.isAssignableFrom(targetClass)
? ChatModelListener.class
: parameterType.getTypeParameters()[0].getGenericDeclaration();
List<Object> listeners = (List<Object>) Collections.checkedList(new ArrayList<>(),
Original file line number Diff line number Diff line change
@@ -19,19 +19,19 @@

Class<?>[] tools() default {};

String chatLanguageModelName() default "";
String chatLanguageModelName() default "#default";

String streamingChatLanguageModelName() default "";

String contentRetrieverModelName() default "";

int chatMemoryMaxMessages() default 10;

String embeddingModelName() default "";

String embeddingStoreName() default "";

String contentRetrieverName() default "";

String moderationModelName() default "";

String chatMemoryStoreName() default "";
String chatMemoryName() default "";

String chatMemoryProviderName() default "";

String retrievalAugmentorName() default "";
}