Skip to content

Commit

Permalink
Add test for match ai + group
Browse files Browse the repository at this point in the history
  • Loading branch information
elie222 committed Dec 20, 2024
1 parent 67c71fc commit 37d8662
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
7 changes: 7 additions & 0 deletions apps/web/utils/ai/choose-rule/ai-choose-rule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import { chatCompletionObject } from "@/utils/llms";
import type { User } from "@prisma/client";
import { stringifyEmail } from "@/utils/ai/choose-rule/stringify-email";
import type { EmailForLLM } from "@/utils/ai/choose-rule/stringify-email";
import { createScopedLogger } from "@/utils/logger";

const logger = createScopedLogger("ai-choose-rule");

type GetAiResponseOptions = {
email: {
Expand Down Expand Up @@ -56,6 +59,8 @@ ${stringifyEmail(email, 500)}
</email>
`;

logger.trace("AI choose rule prompt", { system, prompt });

const aiResponse = await chatCompletionObject({
userAi: user,
prompt,
Expand All @@ -68,6 +73,8 @@ ${stringifyEmail(email, 500)}
usageLabel: "Choose rule",
});

logger.trace("AI choose rule response", aiResponse.object);

return aiResponse.object;
}

Expand Down
41 changes: 40 additions & 1 deletion apps/web/utils/ai/choose-rule/match-rules.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { describe, it, expect, vi, afterEach } from "vitest";
import { describe, it, expect, vi, beforeEach } from "vitest";
import { findMatchingRule } from "./match-rules";
import {
type Category,
Expand All @@ -15,14 +15,22 @@ import type {
ParsedMessageHeaders,
} from "@/utils/types";
import prisma from "@/utils/__mocks__/prisma";
import { aiChooseRule } from "@/utils/ai/choose-rule/ai-choose-rule";

// Run with:
// pnpm test match-rules.test.ts

vi.mock("server-only", () => ({}));
vi.mock("@/utils/prisma");
vi.mock("@/utils/ai/choose-rule/ai-choose-rule", () => ({
aiChooseRule: vi.fn(),
}));

describe("findMatchingRule", () => {
beforeEach(() => {
vi.clearAllMocks();
});

it("matches a static rule", async () => {
const rule = getRule({ from: "[email protected]" });
const rules = [rule];
Expand Down Expand Up @@ -156,6 +164,37 @@ describe("findMatchingRule", () => {
expect(result.reason).toBeUndefined();
});

it("matches a rule with multiple conditions AND (category and AI)", async () => {
prisma.newsletter.findUnique.mockResolvedValue(
getNewsletter({ categoryId: "newsletterCategory" }),
);

const rule = getRule({
conditionalOperator: LogicalOperator.AND,
instructions: "Match if the email is an AI newsletter",
categoryFilters: [getCategory({ id: "newsletterCategory" })],
});

(aiChooseRule as ReturnType<typeof vi.fn>).mockImplementationOnce(() => {
return {
reason: "reason",
rule: { id: "r123" },
};
});

const rules = [rule];
const message = getMessage({
headers: getHeaders({ from: "[email protected]" }),
});
const user = getUser();

const result = await findMatchingRule(rules, message, user);

expect(result.rule?.id).toBe(rule.id);
expect(result.reason).toBeDefined();
expect(aiChooseRule).toHaveBeenCalledOnce();
});

it("doesn't match with only one of category or group", async () => {
prisma.newsletter.findUnique.mockResolvedValue(
getNewsletter({ categoryId: "category1" }),
Expand Down

0 comments on commit 37d8662

Please sign in to comment.