Skip to content

Commit

Permalink
feat: use first message as title if missing, secure retrieval by uuid
Browse files Browse the repository at this point in the history
  • Loading branch information
astappiev committed Oct 3, 2023
1 parent fa3ada8 commit 7cc4cfe
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package de.l3s.interweb.server;

import io.quarkus.panache.common.Sort;

public final class PanacheUtils {
public static Sort createSort(String order) {
Sort sort = Sort.empty();
if (order != null) {
String[] tokens = order.split(",");
for (String token : tokens) {
if (token.startsWith("-")) {
sort = sort.and(token.substring(1), Sort.Direction.Descending);
} else {
sort = sort.and(token, Sort.Direction.Ascending);
}
}
}
return sort;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
import jakarta.validation.constraints.Size;

import io.quarkus.hibernate.reactive.panache.PanacheEntityBase;
import io.quarkus.panache.common.Parameters;
import io.quarkus.panache.common.Sort;
import io.smallrye.mutiny.Uni;
import org.hibernate.annotations.ColumnDefault;
import org.hibernate.annotations.CreationTimestamp;
import org.hibernate.annotations.UuidGenerator;

import com.fasterxml.jackson.annotation.JsonIgnore;

import de.l3s.interweb.server.PanacheUtils;
import de.l3s.interweb.server.principal.Consumer;

@Entity
Expand Down Expand Up @@ -77,7 +80,30 @@ public void addCosts(int tokens, double cost) {
this.estimatedCost += cost;
}

public static Uni<Chat> findById(UUID id) {
return find("id", id).firstResult();
public Uni<Integer> updateTitle() {
return update("title = ?1 where id = ?2", title, id);
}

public Uni<Integer> updateTitleAndUsage() {
return update("title = ?1, usedTokens = ?2, estimatedCost = ?3 where id = ?4", title, usedTokens, estimatedCost, id);
}

public static Uni<List<Chat>> listByUser(Consumer consumer, String user, String order, int page, int perPage) {
String query = "consumer.id = :id AND usedTokens != 0";
Parameters params = Parameters.with("id", consumer.id);
Sort sort = PanacheUtils.createSort(order);

if (user == null) {
query += " AND user IS NULL";
} else {
params.and("user", user);
query += " AND user = :user";
}

return find(query, sort, params).page(page - 1, perPage).list();
}

public static Uni<Chat> findById(Consumer consumer, UUID id) {
return find("consumer.id = ?1 AND id = ?2", consumer.id, id).firstResult();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
import io.quarkus.hibernate.reactive.panache.Panache;
import io.quarkus.security.Authenticated;
import io.quarkus.security.identity.SecurityIdentity;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
import org.hibernate.reactive.mutiny.Mutiny;

import de.l3s.interweb.core.completion.*;
import de.l3s.interweb.core.util.StringUtils;
import de.l3s.interweb.server.principal.Consumer;

@Path("/chat")
Expand All @@ -30,20 +32,48 @@ public class ChatResource {

@GET
@Authenticated
public Uni<List<Chat>> chats(@QueryParam("user") String user) {
@Path("/models")
public Map<String, UsagePrice> models() {
return chatService.getModels();
}

@GET
@Authenticated
public Uni<List<Chat>> chats(
@QueryParam("user") String user,
@QueryParam("sort") @DefaultValue("-created") String order,
@QueryParam("page") @DefaultValue("1") Integer page,
@QueryParam("perPage") @DefaultValue("20") Integer perPage
) {
Consumer consumer = securityIdentity.getCredential(Consumer.class);
if (user == null) {
return Chat.list("consumer.id = ?1 AND user IS NULL ORDER BY created DESC LIMIT 20", consumer.id);
} else {
return Chat.list("consumer.id = ?1 AND user = ?2 ORDER BY created DESC LIMIT 20", consumer.id, user);
}

return Chat.listByUser(consumer, user, order, page, perPage).flatMap(chats -> Multi.createFrom().iterable(chats).call(chat -> {
if (chat.title == null) {
return Mutiny.fetch(chat.getMessages()).call(() -> {
if (!chat.getMessages().isEmpty()) {
for (ChatMessage message : chat.getMessages()) {
if (message.role == Message.Role.user) {
chat.title = StringUtils.shorten(message.content, 120);
return chat.updateTitle();
}
}
}

return Uni.createFrom().voidItem();
});
}

return Uni.createFrom().voidItem();
}).collect().asList());
}

@GET
@Authenticated
@Path("{uuid}")
public Uni<Conversation> chat(@PathParam("uuid") UUID id) {
return Chat.findById(id).call(chat -> Mutiny.fetch(chat.getMessages())).map(chat -> {
Consumer consumer = securityIdentity.getCredential(Consumer.class);

return Chat.findById(consumer, id).call(chat -> Mutiny.fetch(chat.getMessages())).map(chat -> {
Conversation conversation = new Conversation();
conversation.setId(chat.id);
conversation.setTitle(chat.title);
Expand All @@ -55,19 +85,14 @@ public Uni<Conversation> chat(@PathParam("uuid") UUID id) {
});
}

@GET
@Authenticated
@Path("/models")
public Map<String, UsagePrice> models() {
return chatService.getModels();
}

@POST
@Authenticated
@Path("/completions")
public Uni<CompletionResults> completions(@Valid CompletionQuery query) {
return getOrCreateChat(query).flatMap(chat -> {
//noinspection CodeBlock2Expr
Consumer consumer = securityIdentity.getCredential(Consumer.class);

return getOrCreateChat(query, consumer).flatMap(chat -> {
// noinspection CodeBlock2Expr
return chatService.completions(query).call(results -> {
results.setChatId(chat.id);
return persistMessages(chat, query.getMessages(), results);
Expand All @@ -86,17 +111,17 @@ public Uni<CompletionResults> completions(@Valid CompletionQuery query) {
} else {
return Uni.createFrom().voidItem();
}
}).call(() -> Chat.update("title = ?1, usedTokens = ?2, estimatedCost = ?3 where id = ?4", chat.title, chat.usedTokens, chat.estimatedCost, chat.id));
}).call(chat::updateTitleAndUsage);
});
}

private Uni<Chat> getOrCreateChat(CompletionQuery query) {
private Uni<Chat> getOrCreateChat(CompletionQuery query, Consumer consumer) {
if (query.getId() != null) {
return Chat.findById(query.getId()).call(chat -> Mutiny.fetch(chat.getMessages()));
return Chat.findById(consumer, query.getId()).call(chat -> Mutiny.fetch(chat.getMessages()));
}

Chat chat = new Chat();
chat.consumer = securityIdentity.getCredential(Consumer.class);
chat.consumer = consumer;
chat.model = query.getModel();
chat.user = query.getUser();
return Panache.withTransaction(chat::persist);
Expand Down

0 comments on commit 7cc4cfe

Please sign in to comment.