Skip to content

Commit

Permalink
chore: update test
Browse files Browse the repository at this point in the history
  • Loading branch information
appflowy committed Jan 5, 2025
1 parent bb590c0 commit 8aa17ac
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 95 deletions.
41 changes: 40 additions & 1 deletion libs/appflowy-ai-client/src/dto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::str::FromStr;

pub const STREAM_METADATA_KEY: &str = "0";
pub const STREAM_ANSWER_KEY: &str = "1";
pub const STREAM_IMAGE_KEY: &str = "2";
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SummarizeRowResponse {
pub text: String,
Expand All @@ -31,7 +32,7 @@ pub struct ChatQuestion {
pub struct ResponseFormat {
pub output_layout: OutputLayout,
pub output_content: OutputContent,
pub output_content_metadata: Option<serde_json::Value>,
pub output_content_metadata: Option<OutputContentMetadata>,
}

#[derive(Clone, Debug, Default, Serialize_repr, Deserialize_repr)]
Expand All @@ -53,6 +54,44 @@ pub enum OutputContent {
RichTextImage = 2,
}

/// Client-supplied tuning options for image generation, carried inside
/// `ResponseFormat::output_content_metadata` when the requested output
/// content includes images.
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct OutputContentMetadata {
    /// Custom prompt for image generation.
    /// Omitted from the serialized payload when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub custom_image_prompt: Option<String>,

    /// The image model to use for generation (default: "dall-e-2").
    #[serde(default = "default_image_model")]
    pub image_model: String,

    /// Size of the image (default: "256x256").
    // NOTE(review): `skip_serializing_if` drops a `None` on serialize, while
    // `default` restores `Some("256x256")` on deserialize — so serializing a
    // `None` and reading it back yields `Some(..)`. Round-trip is not
    // identity for this field; confirm consumers expect that.
    #[serde(
        default = "default_image_size",
        skip_serializing_if = "Option::is_none"
    )]
    pub size: Option<String>,

    /// Quality of the image (default: "standard").
    // NOTE(review): same serialize/deserialize asymmetry as the field above.
    #[serde(
        default = "default_image_quality",
        skip_serializing_if = "Option::is_none"
    )]
    pub quality: Option<String>,
}

// Default values for the fields
/// Serde fallback for `OutputContentMetadata::image_model`: the model name
/// used when the incoming payload omits the field.
fn default_image_model() -> String {
    String::from("dall-e-2")
}

/// Serde fallback for `OutputContentMetadata::size`: the image dimensions
/// used when the incoming payload omits the field.
fn default_image_size() -> Option<String> {
    Some(String::from("256x256"))
}

/// Serde fallback for `OutputContentMetadata::quality`: the rendering quality
/// used when the incoming payload omits the field.
fn default_image_quality() -> Option<String> {
    Some(String::from("standard"))
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MessageData {
pub content: String,
Expand Down
38 changes: 0 additions & 38 deletions libs/appflowy-ai-client/tests/chat_test/qa_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,44 +25,6 @@ async fn qa_test() {
assert_eq!(questions.len(), 3)
}

/// Integration test: streams an answer for a one-off question from a live
/// AppFlowy AI service and concatenates the streamed answer fragments.
/// Requires the AI service to be reachable; fails fast on the health check
/// otherwise.
#[tokio::test]
async fn stream_test() {
    let client = appflowy_ai_client();
    client.health_check().await.expect("Health check failed");
    // Each run uses a fresh chat id so the service has no prior context.
    let chat_id = uuid::Uuid::new_v4().to_string();
    let stream = client
        .stream_question_v2(
            &chat_id,
            1,
            "I feel hungry",
            None,
            vec![],
            &AIModel::GPT4oMini,
        )
        .await
        .expect("Failed to initiate question stream");

    // Wrap the raw byte stream so each item is parsed as a JSON value;
    // parse failures surface as `AIError` items.
    let json_stream = JsonStream::<serde_json::Value, _, AIError>::new(stream);

    // Keep only the answer fragments (keyed by STREAM_ANSWER_KEY); other
    // keys (e.g. metadata) and stream errors are dropped after logging.
    let messages: Vec<String> = json_stream
        .filter_map(|item| async {
            match item {
                Ok(value) => value
                    .get(STREAM_ANSWER_KEY)
                    .and_then(|s| s.as_str().map(ToString::to_string)),
                Err(err) => {
                    eprintln!("Error during streaming: {:?}", err); // Log the error for better debugging
                    None
                },
            }
        })
        .collect()
        .await;

    println!("final answer: {}", messages.join(""));
}
#[tokio::test]
async fn download_package_test() {
let client = appflowy_ai_client();
Expand Down
9 changes: 8 additions & 1 deletion libs/client-api/src/http_chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use reqwest::Method;
use serde_json::Value;
use shared_entity::dto::ai_dto::{
CalculateSimilarityParams, ChatQuestionQuery, RepeatedRelatedQuestion, SimilarityResponse,
STREAM_ANSWER_KEY, STREAM_METADATA_KEY,
STREAM_ANSWER_KEY, STREAM_IMAGE_KEY, STREAM_METADATA_KEY,
};
use shared_entity::dto::chat_dto::{ChatSettings, UpdateChatParams};
use shared_entity::response::{AppResponse, AppResponseError};
Expand Down Expand Up @@ -387,6 +387,13 @@ impl Stream for QuestionStream {
return Poll::Ready(Some(Ok(QuestionStreamValue::Answer { value: answer })));
}

if let Some(image) = value
.remove(STREAM_IMAGE_KEY)
.and_then(|s| s.as_str().map(ToString::to_string))
{
return Poll::Ready(Some(Ok(QuestionStreamValue::Answer { value: image })));
}

error!("Invalid streaming value: {:?}", value);
Poll::Ready(None)
},
Expand Down
113 changes: 58 additions & 55 deletions tests/ai_test/chat_test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::ai_test::util::read_text_from_asset;

use appflowy_ai_client::dto::{ChatQuestionQuery, OutputContent, OutputLayout, ResponseFormat};
use appflowy_ai_client::dto::{
ChatQuestionQuery, OutputContent, OutputContentMetadata, OutputLayout, ResponseFormat,
};
use assert_json_diff::{assert_json_eq, assert_json_include};
use client_api::entity::{QuestionStream, QuestionStreamValue};
use client_api_test::{ai_test_enabled, TestClient};
Expand Down Expand Up @@ -295,60 +297,6 @@ async fn generate_chat_message_answer_test() {
assert!(!answer.is_empty());
}

// #[tokio::test]
// async fn update_chat_message_test() {
// if !ai_test_enabled() {
// return;
// }

// let test_client = TestClient::new_user_without_ws_conn().await;
// let workspace_id = test_client.workspace_id().await;
// let chat_id = uuid::Uuid::new_v4().to_string();
// let params = CreateChatParams {
// chat_id: chat_id.clone(),
// name: "my second chat".to_string(),
// rag_ids: vec![],
// };

// test_client
// .api_client
// .create_chat(&workspace_id, params)
// .await
// .unwrap();

// let params = CreateChatMessageParams::new_user("where is singapore?");
// let stream = test_client
// .api_client
// .create_chat_message(&workspace_id, &chat_id, params)
// .await
// .unwrap();
// let messages: Vec<ChatMessage> = stream.map(|message| message.unwrap()).collect().await;
// assert_eq!(messages.len(), 2);

// let params = UpdateChatMessageContentParams {
// chat_id: chat_id.clone(),
// message_id: messages[0].message_id,
// content: "where is China?".to_string(),
// };
// test_client
// .api_client
// .update_chat_message(&workspace_id, &chat_id, params)
// .await
// .unwrap();

// let remote_messages = test_client
// .api_client
// .get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 2)
// .await
// .unwrap()
// .messages;
// assert_eq!(remote_messages[0].content, "where is China?");
// assert_eq!(remote_messages.len(), 2);

// // when the question was updated, the answer should be different
// assert_ne!(remote_messages[1].content, messages[1].content);
// }

#[tokio::test]
async fn get_format_question_message_test() {
if !ai_test_enabled() {
Expand Down Expand Up @@ -399,6 +347,61 @@ async fn get_format_question_message_test() {
assert!(!answer.is_empty());
}

/// Integration test: asks a question in a fresh chat, requesting a
/// rich-text-with-image answer generated by the "dall-e-3" model, and
/// verifies a non-empty streamed answer comes back.
#[tokio::test]
async fn get_text_with_image_message_test() {
    // Skipped entirely unless the AI test environment is configured.
    if !ai_test_enabled() {
        return;
    }

    let client = TestClient::new_user_without_ws_conn().await;
    let workspace_id = client.workspace_id().await;
    let chat_id = uuid::Uuid::new_v4().to_string();

    // Create a fresh chat to host the question/answer exchange.
    client
        .api_client
        .create_chat(
            &workspace_id,
            CreateChatParams {
                chat_id: chat_id.clone(),
                name: "my ai chat".to_string(),
                rag_ids: vec![],
            },
        )
        .await
        .unwrap();

    // Post the user question and keep its id for the answer request below.
    let question = client
        .api_client
        .create_question(
            &workspace_id,
            &chat_id,
            CreateChatMessageParams::new_user(
                "I have a little cat. It is black with big eyes, short legs and a long tail",
            ),
        )
        .await
        .unwrap();

    // Request a rich-text answer that embeds images rendered by dall-e-3,
    // leaving size/quality at their service-side defaults.
    let format = ResponseFormat {
        output_layout: OutputLayout::SimpleTable,
        output_content: OutputContent::RichTextImage,
        output_content_metadata: Some(OutputContentMetadata {
            custom_image_prompt: None,
            image_model: "dall-e-3".to_string(),
            size: None,
            quality: None,
        }),
    };
    let query = ChatQuestionQuery {
        chat_id,
        question_id: question.message_id,
        format,
    };

    let answer_stream = client
        .api_client
        .stream_answer_v3(&workspace_id, query)
        .await
        .unwrap();
    let answer = collect_answer(answer_stream).await;
    println!("answer:\n{}", answer);
    assert!(!answer.is_empty());
}

#[tokio::test]
async fn get_question_message_test() {
if !ai_test_enabled() {
Expand Down

0 comments on commit 8aa17ac

Please sign in to comment.