Skip to content

Commit

Permalink
Drop usage of #symmetric_topics_similarity_search. Don't rely on stubs
Browse files Browse the repository at this point in the history
  • Loading branch information
romanrizzi committed Dec 12, 2024
1 parent 75ad967 commit a5a09fa
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 67 deletions.
6 changes: 4 additions & 2 deletions lib/embeddings/semantic_related.rb
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ def related_topic_ids_for(topic)
Discourse
.cache
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
vector_rep
.symmetric_topics_similarity_search(topic)
DiscourseAi::Embeddings::Schema
.for(Topic, vector: vector_rep)
.symmetric_similarity_search(topic)
.map(&:topic_id)
.tap do |candidate_ids|
# Happens when the topic doesn't have any embeddings
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
Expand Down
48 changes: 0 additions & 48 deletions lib/embeddings/vector_representations/base.rb
Original file line number Diff line number Diff line change
Expand Up @@ -132,54 +132,6 @@ def post_id_from_representation(raw_vector)
SQL
end

def symmetric_topics_similarity_search(topic)
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
WITH le_target AS (
SELECT
embeddings
FROM
#{topic_table_name}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id} AND
topic_id = :topic_id
LIMIT 1
)
SELECT topic_id FROM (
SELECT
topic_id, embeddings
FROM
#{topic_table_name}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id}
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> (
SELECT
binary_quantize(embeddings)::bit(#{dimensions})
FROM
le_target
LIMIT 1
)
LIMIT 200
) AS widenet
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} (
SELECT
embeddings::halfvec(#{dimensions})
FROM
le_target
LIMIT 1
)
LIMIT 100;
SQL
rescue PG::Error => e
Rails.logger.error(
"Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
)
raise MissingEmbeddingError
end

def topic_table_name
"ai_topic_embeddings"
end
Expand Down
4 changes: 3 additions & 1 deletion spec/lib/modules/ai_bot/personas/persona_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,9 @@ def stub_fragments(fragment_count, persona: ai_persona)
end

it "uses the re-ranker to reorder the fragments and pick the top 10 candidates" do
expected_reranked = (0..14).to_a.reverse.map { |idx| { index: idx } }
# The re-ranker reverses the similarity search, but return less results
# to act as a limit for test-purposes.
expected_reranked = (4..14).to_a.reverse.map { |idx| { index: idx } }

WebMock.stub_request(:post, "https://test.reranker.com/rerank").to_return(
status: 200,
Expand Down
28 changes: 12 additions & 16 deletions spec/lib/modules/embeddings/semantic_topic_query_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,19 @@

fab!(:target) { Fabricate(:topic) }

def stub_semantic_search_with(results)
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn
.any_instance
.expects(:symmetric_topics_similarity_search)
.returns(results.concat([target.id]))
def seed_embeddings(topics)
schema = DiscourseAi::Embeddings::Schema.for(Topic)

embeddings = [1] * 1024
(topics << target).each { |t| schema.store(t, embeddings, "digest") }
end

after { DiscourseAi::Embeddings::SemanticRelated.clear_cache_for(target) }

context "when the semantic search returns an unlisted topic" do
fab!(:unlisted_topic) { Fabricate(:topic, visible: false) }

before { stub_semantic_search_with([unlisted_topic.id]) }
before { seed_embeddings([unlisted_topic]) }

it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
Expand All @@ -31,7 +31,7 @@ def stub_semantic_search_with(results)
context "when the semantic search returns a private topic" do
fab!(:private_topic) { Fabricate(:private_message_topic) }

before { stub_semantic_search_with([private_topic.id]) }
before { seed_embeddings([private_topic]) }

it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
Expand All @@ -43,7 +43,7 @@ def stub_semantic_search_with(results)
fab!(:category) { Fabricate(:private_category, group: group) }
fab!(:secured_category_topic) { Fabricate(:topic, category: category) }

before { stub_semantic_search_with([secured_category_topic.id]) }
before { seed_embeddings([secured_category_topic]) }

it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
Expand All @@ -63,7 +63,7 @@ def stub_semantic_search_with(results)

before do
SiteSetting.ai_embeddings_semantic_related_include_closed_topics = false
stub_semantic_search_with([closed_topic.id])
seed_embeddings([closed_topic])
end

it "filters it out" do
Expand All @@ -80,7 +80,7 @@ def stub_semantic_search_with(results)
category_id: category.id,
notification_level: CategoryUser.notification_levels[:muted],
)
stub_semantic_search_with([topic.id])
seed_embeddings([topic])
expect(topic_query.list_semantic_related_topics(target).topics).not_to include(topic)
end
end
Expand All @@ -91,11 +91,7 @@ def stub_semantic_search_with(results)
fab!(:normal_topic_3) { Fabricate(:topic) }
fab!(:closed_topic) { Fabricate(:topic, closed: true) }

before do
stub_semantic_search_with(
[closed_topic.id, normal_topic_1.id, normal_topic_2.id, normal_topic_3.id],
)
end
before { seed_embeddings([closed_topic, normal_topic_1, normal_topic_2, normal_topic_3]) }

it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to eq(
Expand All @@ -117,7 +113,7 @@ def stub_semantic_search_with(results)
fab!(:included_topic) { Fabricate(:topic) }
fab!(:excluded_topic) { Fabricate(:topic) }

before { stub_semantic_search_with([included_topic.id, excluded_topic.id]) }
before { seed_embeddings([included_topic, excluded_topic]) }

let(:modifier_block) { Proc.new { |query| query.where.not(id: excluded_topic.id) } }

Expand Down

0 comments on commit a5a09fa

Please sign in to comment.