From 7cc4cfe1b1fa988ccc8d1cdb5a336c04c0a37a61 Mon Sep 17 00:00:00 2001 From: Oleh Astappiev Date: Tue, 3 Oct 2023 12:06:01 +0200 Subject: [PATCH] feat: use first message as title if missing, secure retrieval by uuid --- .../de/l3s/interweb/server/PanacheUtils.java | 20 ++++++ .../de/l3s/interweb/server/chat/Chat.java | 30 ++++++++- .../interweb/server/chat/ChatResource.java | 65 +++++++++++++------ 3 files changed, 93 insertions(+), 22 deletions(-) create mode 100644 interweb-server/src/main/java/de/l3s/interweb/server/PanacheUtils.java diff --git a/interweb-server/src/main/java/de/l3s/interweb/server/PanacheUtils.java b/interweb-server/src/main/java/de/l3s/interweb/server/PanacheUtils.java new file mode 100644 index 00000000..3f547ff9 --- /dev/null +++ b/interweb-server/src/main/java/de/l3s/interweb/server/PanacheUtils.java @@ -0,0 +1,20 @@ +package de.l3s.interweb.server; + +import io.quarkus.panache.common.Sort; + +public final class PanacheUtils { + public static Sort createSort(String order) { + Sort sort = Sort.empty(); + if (order != null) { + String[] tokens = order.split(","); + for (String token : tokens) { + if (token.startsWith("-")) { + sort = sort.and(token.substring(1), Sort.Direction.Descending); + } else { + sort = sort.and(token, Sort.Direction.Ascending); + } + } + } + return sort; + } +} diff --git a/interweb-server/src/main/java/de/l3s/interweb/server/chat/Chat.java b/interweb-server/src/main/java/de/l3s/interweb/server/chat/Chat.java index 33819190..2f46b727 100644 --- a/interweb-server/src/main/java/de/l3s/interweb/server/chat/Chat.java +++ b/interweb-server/src/main/java/de/l3s/interweb/server/chat/Chat.java @@ -10,6 +10,8 @@ import jakarta.validation.constraints.Size; import io.quarkus.hibernate.reactive.panache.PanacheEntityBase; +import io.quarkus.panache.common.Parameters; +import io.quarkus.panache.common.Sort; import io.smallrye.mutiny.Uni; import org.hibernate.annotations.ColumnDefault; import org.hibernate.annotations.CreationTimestamp; @@ -17,6 +19,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore; +import de.l3s.interweb.server.PanacheUtils; import de.l3s.interweb.server.principal.Consumer; @Entity @@ -77,7 +80,30 @@ public void addCosts(int tokens, double cost) { this.estimatedCost += cost; } - public static Uni findById(UUID id) { - return find("id", id).firstResult(); + public Uni updateTitle() { + return update("title = ?1 where id = ?2", title, id); + } + + public Uni updateTitleAndUsage() { + return update("title = ?1, usedTokens = ?2, estimatedCost = ?3 where id = ?4", title, usedTokens, estimatedCost, id); + } + + public static Uni> listByUser(Consumer consumer, String user, String order, int page, int perPage) { + String query = "consumer.id = :id AND usedTokens != 0"; + Parameters params = Parameters.with("id", consumer.id); + Sort sort = PanacheUtils.createSort(order); + + if (user == null) { + query += " AND user IS NULL"; + } else { + params.and("user", user); + query += " AND user = :user"; + } + + return find(query, sort, params).page(page - 1, perPage).list(); + } + + public static Uni findById(Consumer consumer, UUID id) { + return find("consumer.id = ?1 AND id = ?2", consumer.id, id).firstResult(); } } diff --git a/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatResource.java b/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatResource.java index d7ca78a5..8f13d331 100644 --- a/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatResource.java +++ b/interweb-server/src/main/java/de/l3s/interweb/server/chat/ChatResource.java @@ -13,10 +13,12 @@ import io.quarkus.hibernate.reactive.panache.Panache; import io.quarkus.security.Authenticated; import io.quarkus.security.identity.SecurityIdentity; +import io.smallrye.mutiny.Multi; import io.smallrye.mutiny.Uni; import org.hibernate.reactive.mutiny.Mutiny; import de.l3s.interweb.core.completion.*; +import de.l3s.interweb.core.util.StringUtils; import de.l3s.interweb.server.principal.Consumer; @Path("/chat") @@ -30,20 +32,48 @@ public class ChatResource { @GET @Authenticated - public Uni> chats(@QueryParam("user") String user) { + @Path("/models") + public Map models() { + return chatService.getModels(); + } + + @GET + @Authenticated + public Uni> chats( + @QueryParam("user") String user, + @QueryParam("sort") @DefaultValue("-created") String order, + @QueryParam("page") @DefaultValue("1") Integer page, + @QueryParam("perPage") @DefaultValue("20") Integer perPage + ) { Consumer consumer = securityIdentity.getCredential(Consumer.class); - if (user == null) { - return Chat.list("consumer.id = ?1 AND user IS NULL ORDER BY created DESC LIMIT 20", consumer.id); - } else { - return Chat.list("consumer.id = ?1 AND user = ?2 ORDER BY created DESC LIMIT 20", consumer.id, user); - } + + return Chat.listByUser(consumer, user, order, page, perPage).flatMap(chats -> Multi.createFrom().iterable(chats).call(chat -> { + if (chat.title == null) { + return Mutiny.fetch(chat.getMessages()).call(() -> { + if (!chat.getMessages().isEmpty()) { + for (ChatMessage message : chat.getMessages()) { + if (message.role == Message.Role.user) { + chat.title = StringUtils.shorten(message.content, 120); + return chat.updateTitle(); + } + } + } + + return Uni.createFrom().voidItem(); + }); + } + + return Uni.createFrom().voidItem(); + }).collect().asList()); } @GET @Authenticated @Path("{uuid}") public Uni chat(@PathParam("uuid") UUID id) { - return Chat.findById(id).call(chat -> Mutiny.fetch(chat.getMessages())).map(chat -> { + Consumer consumer = securityIdentity.getCredential(Consumer.class); + + return Chat.findById(consumer, id).call(chat -> Mutiny.fetch(chat.getMessages())).map(chat -> { Conversation conversation = new Conversation(); conversation.setId(chat.id); conversation.setTitle(chat.title); @@ -55,19 +85,14 @@ public Uni chat(@PathParam("uuid") UUID id) { }); } - @GET - @Authenticated - @Path("/models") - public Map models() { - return chatService.getModels(); - } - @POST @Authenticated @Path("/completions") public Uni completions(@Valid CompletionQuery query) { - return getOrCreateChat(query).flatMap(chat -> { - //noinspection CodeBlock2Expr + Consumer consumer = securityIdentity.getCredential(Consumer.class); + + return getOrCreateChat(query, consumer).flatMap(chat -> { + // noinspection CodeBlock2Expr return chatService.completions(query).call(results -> { results.setChatId(chat.id); return persistMessages(chat, query.getMessages(), results); @@ -86,17 +111,17 @@ public Uni completions(@Valid CompletionQuery query) { } else { return Uni.createFrom().voidItem(); } - }).call(() -> Chat.update("title = ?1, usedTokens = ?2, estimatedCost = ?3 where id = ?4", chat.title, chat.usedTokens, chat.estimatedCost, chat.id)); + }).call(chat::updateTitleAndUsage); }); } - private Uni getOrCreateChat(CompletionQuery query) { + private Uni getOrCreateChat(CompletionQuery query, Consumer consumer) { if (query.getId() != null) { - return Chat.findById(query.getId()).call(chat -> Mutiny.fetch(chat.getMessages())); + return Chat.findById(consumer, query.getId()).call(chat -> Mutiny.fetch(chat.getMessages())); } Chat chat = new Chat(); - chat.consumer = securityIdentity.getCredential(Consumer.class); + chat.consumer = consumer; chat.model = query.getModel(); chat.user = query.getUser(); return Panache.withTransaction(chat::persist);