diff --git a/lib/teiserver/o_auth/queries/application_query.ex b/lib/teiserver/o_auth/queries/application_query.ex index a225a2009..0a22de451 100644 --- a/lib/teiserver/o_auth/queries/application_query.ex +++ b/lib/teiserver/o_auth/queries/application_query.ex @@ -1,7 +1,7 @@ defmodule Teiserver.OAuth.ApplicationQueries do use TeiserverWeb, :queries - alias Teiserver.OAuth.Application + alias Teiserver.OAuth.{Application, TokenQueries, CodeQueries, CredentialQueries} @doc """ Returns the application corresponding to the given uid/client id @@ -37,18 +37,47 @@ defmodule Teiserver.OAuth.ApplicationQueries do base_query() |> preload(:owner) |> Repo.all() end + @doc """ + returns the number of authorisation codes, authentication token and + client credentials for the given applications + """ + @spec get_stats(Application.id() | [Application.id()]) :: [ + %{ + code_count: non_neg_integer(), + token_count: non_neg_integer(), + credential_count: non_neg_integer() + } + ] + def get_stats(app_ids) when not is_list(app_ids), do: get_stats([app_ids]) + + def get_stats(app_ids) do + code_counts = CodeQueries.count_per_apps(app_ids) + token_counts = TokenQueries.count_per_apps(app_ids) + cred_counts = CredentialQueries.count_per_apps(app_ids) + + List.foldr(app_ids, [], fn app_id, acc -> + elem = %{ + code_count: Map.get(code_counts, app_id, 0), + token_count: Map.get(token_counts, app_id, 0), + credential_count: Map.get(cred_counts, app_id, 0) + } + + [elem | acc] + end) + end + def base_query() do from app in Application, as: :app end def where_id(query, id) do - from e in query, - where: e.id == ^id + from [app: app] in query, + where: app.id == ^id end def where_uid(query, uid) do - from e in query, - where: e.uid == ^uid + from [app: app] in query, + where: app.uid == ^uid end def join_application(query, name \\ :application) do diff --git a/lib/teiserver/o_auth/queries/code_query.ex b/lib/teiserver/o_auth/queries/code_query.ex index 00cf6166c..67d52cf7c 100644 --- a/lib/teiserver/o_auth/queries/code_query.ex +++ b/lib/teiserver/o_auth/queries/code_query.ex @@ -21,4 +21,31 @@ defmodule Teiserver.OAuth.CodeQueries do from e in query, where: e.value == ^value end + + def where_app_ids(query, app_ids) do + from [code: code] in query, + where: code.application_id in ^app_ids + end + + def not_expired(query, as_at \\ nil) do + as_at = as_at || DateTime.utc_now() + + from [code: code] in query, + where: code.expires_at > ^as_at + end + + @spec count_per_apps([Application.id()], DateTime.t() | nil) :: %{Application.id() => non_neg_integer()} + def count_per_apps(app_ids, as_at \\ nil) do + query = + base_query() + |> not_expired(as_at) + |> where_app_ids(app_ids) + + from([code: code] in query, + group_by: code.application_id, + select: {code.application_id, count(code.id)} + ) + |> Repo.all() + |> Enum.into(%{}) + end end diff --git a/lib/teiserver/o_auth/queries/credential_query.ex b/lib/teiserver/o_auth/queries/credential_query.ex index 61ce909cc..ec06266c6 100644 --- a/lib/teiserver/o_auth/queries/credential_query.ex +++ b/lib/teiserver/o_auth/queries/credential_query.ex @@ -17,7 +17,26 @@ defmodule Teiserver.OAuth.CredentialQueries do end def where_client_id(query, client_id) do - from credential in query, + from [credential: credential] in query, where: credential.client_id == ^client_id end + + @spec count_per_apps([Application.id()]) :: %{Application.id() => non_neg_integer()} + def count_per_apps(app_ids) do + query = + base_query() + |> where_app_ids(app_ids) + + from([credential: credential] in query, + group_by: credential.application_id, + select: {credential.application_id, count(credential.id)} + ) + |> Repo.all() + |> Enum.into(%{}) + end + + def where_app_ids(query, app_ids) do + from [credential: credential] in query, + where: credential.application_id in ^app_ids + end end diff --git a/lib/teiserver/o_auth/queries/token_query.ex b/lib/teiserver/o_auth/queries/token_query.ex index 4a56f6899..75041a9d1 100644 --- a/lib/teiserver/o_auth/queries/token_query.ex +++ b/lib/teiserver/o_auth/queries/token_query.ex @@ -1,6 +1,6 @@ defmodule Teiserver.OAuth.TokenQueries do use TeiserverWeb, :queries - alias Teiserver.OAuth.Token + alias Teiserver.OAuth.{Application, Token} @doc """ Return the db object corresponding to the given token. @@ -17,8 +17,7 @@ defmodule Teiserver.OAuth.TokenQueries do def base_query() do from token in Token, - as: :token, - preload: [refresh_token: token] + as: :token end def where_token(query, value) do @@ -26,6 +25,18 @@ defmodule Teiserver.OAuth.TokenQueries do where: e.value == ^value end + def where_app_ids(query, app_ids) do + from [token: token] in query, + where: token.application_id in ^app_ids + end + + def not_expired(query, as_at \\ nil) do + as_at = as_at || DateTime.utc_now() + + from [token: token] in query, + where: token.expires_at > ^as_at + end + @doc """ given a refresh token, deletes it and its potential associated token """ @@ -33,4 +44,21 @@ defmodule Teiserver.OAuth.TokenQueries do from(tok in Token, where: tok.id == ^token.id or tok.refresh_token_id == ^token.id) |> Repo.delete_all() end + + @spec count_per_apps([Application.id()], DateTime.t() | nil) :: %{ + Application.id() => non_neg_integer() + } + def count_per_apps(app_ids, as_at \\ nil) do + query = + base_query() + |> not_expired(as_at) + |> where_app_ids(app_ids) + + from([token: token] in query, + group_by: token.application_id, + select: {token.application_id, count(token.id)} + ) + |> Repo.all() + |> Enum.into(%{}) + end end diff --git a/test/support/fixtures/o_auth/o_auth_fixtures.ex b/test/support/fixtures/o_auth/o_auth_fixtures.ex new file mode 100644 index 000000000..66eab5839 --- /dev/null +++ b/test/support/fixtures/o_auth/o_auth_fixtures.ex @@ -0,0 +1,69 @@ +defmodule Teiserver.OAuthFixtures do + alias Teiserver.OAuth.{Application, Code, Token, Credential} + alias Teiserver.Repo + + def app_attrs(owner_id) do + %{ + name: "fixture app", + uid: "fixture_app", + owner_id: owner_id, + scopes: ["tachyon.lobby"], + redirect_uris: ["http://localhost/foo"], + description: "app created as part of a test" + } + end + + def create_app(attrs) do + %Application{} |> Application.changeset(attrs) |> Repo.insert!() + end + + def code_attrs(user_id, app) do + now = DateTime.utc_now() + + %{ + value: Base.hex_encode32(:crypto.strong_rand_bytes(32)), + owner_id: user_id, + application_id: app.id, + scopes: app.scopes, + expires_at: Timex.add(now, Timex.Duration.from_minutes(5)), + redirect_uri: hd(app.redirect_uris), + challenge: "TODO", + challenge_method: :plain + } + end + + def create_code(attrs) do + %Code{} |> Code.changeset(attrs) |> Repo.insert!() + end + + def token_attrs(user_id, application) do + now = DateTime.utc_now() + + %{ + value: Base.hex_encode32(:crypto.strong_rand_bytes(32), padding: false), + owner_id: user_id, + application_id: application.id, + scopes: application.scopes, + expires_at: Timex.add(now, Timex.Duration.from_days(60)), + type: :access, + refresh_token: nil + } + end + + def create_token(attrs) do + %Token{} |> Token.changeset(attrs) |> Repo.insert!() + end + + def credential_attrs(autohost, app_id) do + %{ + application_id: app_id, + autohost_id: autohost.id, + client_id: UUID.uuid4(), + hashed_secret: UUID.uuid4() + } + end + + def create_credential(attrs) do + %Credential{} |> Credential.changeset(attrs) |> Repo.insert!() + end +end diff --git a/test/teiserver/o_auth/application_query_test.exs b/test/teiserver/o_auth/application_query_test.exs new file mode 100644 index 000000000..dc4a64935 --- /dev/null +++ b/test/teiserver/o_auth/application_query_test.exs @@ -0,0 +1,111 @@ +defmodule Teiserver.OAuth.ApplicationQueryTest do + use Teiserver.DataCase + alias Teiserver.Repo + + alias Teiserver.OAuth.{Application, Code, ApplicationQueries, Token, Credential} + alias Teiserver.OAuthFixtures + + defp setup_app(_context) do + user = Teiserver.TeiserverTestLib.new_user() + + app = OAuthFixtures.app_attrs(user.id) |> OAuthFixtures.create_app() + + %{user: user, app: app} + end + + defp setup_autohost(_context) do + alias Teiserver.Autohost.Autohost + + autohost = + %Autohost{} + |> Autohost.changeset(%{name: "fixture autohost"}) + |> Repo.insert!() + + %{autohost: autohost} + end + + describe "app stats" do + setup [:setup_app, :setup_autohost] + + test "nothing associated", %{app: app} do + assert [ + %{ + code_count: 0, + token_count: 0, + credential_count: 0 + } + ] == ApplicationQueries.get_stats(app.id) + end + + test "bit of everything", %{app: app, autohost: autohost} do + OAuthFixtures.code_attrs(app.owner_id, app) + |> OAuthFixtures.create_code() + + Enum.each(1..2, fn _ -> + OAuthFixtures.token_attrs(app.owner_id, app) + |> OAuthFixtures.create_token() + end) + + Enum.each(1..3, fn _ -> + OAuthFixtures.credential_attrs(autohost, app.id) + |> OAuthFixtures.create_credential() + end) + + assert [ + %{ + code_count: 1, + token_count: 2, + credential_count: 3 + } + ] == ApplicationQueries.get_stats(app.id) + end + + test "select only correct applications", %{user: user, app: app} do + other_app = + OAuthFixtures.app_attrs(user.id) + |> Map.merge(%{name: "other app", uid: "other_app"}) + |> OAuthFixtures.create_app() + + OAuthFixtures.code_attrs(user.id, other_app) |> OAuthFixtures.create_code() + + assert [%{code_count: 1}] = ApplicationQueries.get_stats(other_app.id) + end + + test "ignore expired code and tokens", %{user: user, app: app} do + yesterday = DateTime.utc_now() |> Timex.subtract(Timex.Duration.from_days(1)) + + OAuthFixtures.code_attrs(user.id, app) + |> Map.put(:expires_at, yesterday) + |> OAuthFixtures.create_code() + + OAuthFixtures.token_attrs(user.id, app) + |> Map.put(:expires_at, yesterday) + |> OAuthFixtures.create_token() + + assert [%{code_count: 0, token_count: 0}] = ApplicationQueries.get_stats(app.id) + end + + test "don't mix up different applications", %{user: user, app: app} do + other_app = + OAuthFixtures.app_attrs(user.id) + |> Map.merge(%{name: "other app", uid: "other_app"}) + |> OAuthFixtures.create_app() + + OAuthFixtures.code_attrs(user.id, app) + |> OAuthFixtures.create_code() + + Enum.each(1..2, fn i -> + OAuthFixtures.code_attrs(user.id, other_app) + |> Map.put(:value, "value_#{i}") + |> OAuthFixtures.create_code() + end) + + assert [%{code_count: 2}, %{code_count: 1}] = + ApplicationQueries.get_stats([other_app.id, app.id]) + + # check that it also works with different id ordering + assert [%{code_count: 1}, %{code_count: 2}] = + ApplicationQueries.get_stats([app.id, other_app.id]) + end + end +end