Skip to content

Commit

Permalink
add bumblebee adapter for bonfire-networks/bonfire-app#852
Browse files Browse the repository at this point in the history
  • Loading branch information
mayel committed Sep 15, 2024
1 parent b7a055a commit 08feb88
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ defmodule Bonfire.Common.AntiSpam.Akismet do
def check_profile(text, context) do
check_content(
%{
#  contains name & bio
comment_content: text,
comment_type: "signup"
},
Expand Down
106 changes: 106 additions & 0 deletions lib/anti_spam/bumblebee_adapter.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
defmodule Bonfire.Common.AntiSpam.BumblebeeAdapter do
@moduledoc """
Integration with Bumblebee model(s) for anti-spam detection.
"""
use Bonfire.Common.Utils

alias Bonfire.Common.AntiSpam.Provider

@behaviour Provider

@impl Provider
def check_current_user(context) do
# TODO: check profile instead?
:ham
end

@impl Provider
def check_profile(text, context) do
check_content(
%{
comment_content: text
},
context
)
end

@impl Provider
def check_object(text, context) do
check_content(
%{
#  contains name & bio
comment_content: text
},
context
)
end

@impl Provider
def check_comment(
comment_body,
_is_reply?,
context
) do
check_content(
%{
comment_content: comment_body
},
context
)
end

defp check_content(comment, context) do
if Config.get(:env) != :test and ready?() do
# debug(context)
current_user = current_user(context)
serving = prepare_serving()

with %{
predictions: predictions
} =
"""
#{e(current_user, :profile, :username, nil)} (#{maybe_apply(Bonfire.Me.Characters, :display_username, [current_user, true, nil, ""], fallback_return: nil)}): #{comment[:comment_content]}
"""
|> debug("text to check")
|> Nx.Serving.run(serving, ...),
# |> debug("spam_result"),
%{"LABEL_0" => [ham_score], "LABEL_1" => [spam_score]} <-
Enum.group_by(predictions, &Map.get(&1, :label), &Map.get(&1, :score))
|> debug("result") do
cond do
spam_score > 0.90 and ham_score < 0.5 -> :spam
spam_score > 0.99 -> :spam
true -> :ham
end
else
other ->
error(other, "Could not recognise response from AI model")
:ham
end
else
:ham
end
end

defp prepare_serving(
model \\ "mrm8488/bert-tiny-finetuned-enron-spam-detection",
tokenizer \\ nil
) do
{:ok, model_info} =
Bumblebee.load_model({:hf, model})

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, tokenizer || model})

serving =
Bumblebee.Text.text_classification(model_info, tokenizer,
compile: [batch_size: 1, sequence_length: 100],
defn_options: [compiler: EXLA]
)
end

@impl Provider
def ready?, do: module_enabled?(Bumblebee)

defp log_response(res),
do: tap(res, fn res -> debug(res, "Return from Akismet is") end)
end
4 changes: 2 additions & 2 deletions lib/anti_spam/provider.ex
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ defmodule Bonfire.Common.AntiSpam.Provider do
@callback ready?() :: boolean()

@doc """
Check an user details
Check an user/account's details
"""
@callback check_current_user(context :: any()) ::
result()

@doc """
Check a profile details
Check profile info
"""
@callback check_profile(
summary :: String.t(),
Expand Down
6 changes: 4 additions & 2 deletions lib/config_settings/config.ex
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,16 @@ defmodule Bonfire.Common.Config do
if Application.compile_env!(:bonfire, :env) == :test do
# NOTE: enables using `ProcessTree` in test env, eg. `Process.put([:bonfire_common, :my_key], :value)`
def get(keys, default, otp_app) when is_list(keys),
do: ProcessTree.get([otp_app] ++ keys) || get_config(keys, default, otp_app)
do: get_for_process([otp_app] ++ keys) || get_config(keys, default, otp_app)

def get(key, default, otp_app),
do: ProcessTree.get([otp_app, key]) || get_config(key, default, otp_app)
do: get_for_process([otp_app, key]) || get_config(key, default, otp_app)
else
def get(keys, default, otp_app), do: get_config(keys, default, otp_app)
end

def get_for_process(keys), do: ProcessTree.get(keys)

defp get_config([key], default, otp_app), do: get_config(key, default, otp_app)

defp get_config([parent_key | keys], default, otp_app) do
Expand Down
3 changes: 2 additions & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ defmodule Bonfire.Common.MixProject do
optional: true},
{:text, "~> 0.2.0", optional: true},
{:text_corpus_udhr, "~> 0.1.0", optional: true},
{:bumblebee, "~> 0.5.0", optional: true},
# needed for graphql client, eg github for changelog
{:neuron, "~> 5.0", optional: true},
# for extension install + mix tasks that do patching
{:igniter, "~> 0.3", optional: true}
{:igniter, "~> 0.3", optional: true}
])
]
end
Expand Down

0 comments on commit 08feb88

Please sign in to comment.