From 1ed5f379b9f3e38b64cc9de9f418c164ce400be1 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 29 Nov 2024 09:53:49 -0800 Subject: [PATCH 1/3] Move GetFeeds to service layer (#32526) Move GetFeeds from models to service layer, no code change. --- models/activities/action.go | 54 +----- models/activities/action_list.go | 52 ++++++ models/activities/action_test.go | 149 ---------------- models/activities/user_heatmap.go | 2 +- routers/api/v1/org/org.go | 3 +- routers/api/v1/org/team.go | 3 +- routers/api/v1/repo/repo.go | 3 +- routers/api/v1/user/user.go | 3 +- routers/web/feed/profile.go | 3 +- routers/web/feed/repo.go | 3 +- routers/web/user/home.go | 3 +- routers/web/user/profile.go | 3 +- services/feed/feed.go | 15 ++ services/feed/feed_test.go | 165 ++++++++++++++++++ services/feed/{action.go => notifier.go} | 0 .../feed/{action_test.go => notifier_test.go} | 0 16 files changed, 250 insertions(+), 211 deletions(-) create mode 100644 services/feed/feed.go create mode 100644 services/feed/feed_test.go rename services/feed/{action.go => notifier.go} (100%) rename services/feed/{action_test.go => notifier_test.go} (100%) diff --git a/models/activities/action.go b/models/activities/action.go index 546d4340aedca..65d95fbe6676e 100644 --- a/models/activities/action.go +++ b/models/activities/action.go @@ -448,65 +448,13 @@ type GetFeedsOptions struct { Date string // the day we want activity for: YYYY-MM-DD } -// GetFeeds returns actions according to the provided options -func GetFeeds(ctx context.Context, opts GetFeedsOptions) (ActionList, int64, error) { - if opts.RequestedUser == nil && opts.RequestedTeam == nil && opts.RequestedRepo == nil { - return nil, 0, fmt.Errorf("need at least one of these filters: RequestedUser, RequestedTeam, RequestedRepo") - } - - cond, err := activityQueryCondition(ctx, opts) - if err != nil { - return nil, 0, err - } - - actions := make([]*Action, 0, opts.PageSize) - var count int64 - opts.SetDefaultValues() - - if opts.Page < 10 { // TODO: why it's 10 but other values? It's an experience value. - sess := db.GetEngine(ctx).Where(cond) - sess = db.SetSessionPagination(sess, &opts) - - count, err = sess.Desc("`action`.created_unix").FindAndCount(&actions) - if err != nil { - return nil, 0, fmt.Errorf("FindAndCount: %w", err) - } - } else { - // First, only query which IDs are necessary, and only then query all actions to speed up the overall query - sess := db.GetEngine(ctx).Where(cond).Select("`action`.id") - sess = db.SetSessionPagination(sess, &opts) - - actionIDs := make([]int64, 0, opts.PageSize) - if err := sess.Table("action").Desc("`action`.created_unix").Find(&actionIDs); err != nil { - return nil, 0, fmt.Errorf("Find(actionsIDs): %w", err) - } - - count, err = db.GetEngine(ctx).Where(cond). - Table("action"). - Cols("`action`.id").Count() - if err != nil { - return nil, 0, fmt.Errorf("Count: %w", err) - } - - if err := db.GetEngine(ctx).In("`action`.id", actionIDs).Desc("`action`.created_unix").Find(&actions); err != nil { - return nil, 0, fmt.Errorf("Find: %w", err) - } - } - - if err := ActionList(actions).LoadAttributes(ctx); err != nil { - return nil, 0, fmt.Errorf("LoadAttributes: %w", err) - } - - return actions, count, nil -} - // ActivityReadable return whether doer can read activities of user func ActivityReadable(user, doer *user_model.User) bool { return !user.KeepActivityPrivate || doer != nil && (doer.IsAdmin || user.ID == doer.ID) } -func activityQueryCondition(ctx context.Context, opts GetFeedsOptions) (builder.Cond, error) { +func ActivityQueryCondition(ctx context.Context, opts GetFeedsOptions) (builder.Cond, error) { cond := builder.NewCond() if opts.RequestedTeam != nil && opts.RequestedUser == nil { diff --git a/models/activities/action_list.go b/models/activities/action_list.go index aafb7f8a26c57..5f9acb8f2aa46 100644 --- a/models/activities/action_list.go +++ b/models/activities/action_list.go @@ -201,3 +201,55 @@ func (actions ActionList) LoadIssues(ctx context.Context) error { } return nil } + +// GetFeeds returns actions according to the provided options +func GetFeeds(ctx context.Context, opts GetFeedsOptions) (ActionList, int64, error) { + if opts.RequestedUser == nil && opts.RequestedTeam == nil && opts.RequestedRepo == nil { + return nil, 0, fmt.Errorf("need at least one of these filters: RequestedUser, RequestedTeam, RequestedRepo") + } + + cond, err := ActivityQueryCondition(ctx, opts) + if err != nil { + return nil, 0, err + } + + actions := make([]*Action, 0, opts.PageSize) + var count int64 + opts.SetDefaultValues() + + if opts.Page < 10 { // TODO: why it's 10 but other values? It's an experience value. + sess := db.GetEngine(ctx).Where(cond) + sess = db.SetSessionPagination(sess, &opts) + + count, err = sess.Desc("`action`.created_unix").FindAndCount(&actions) + if err != nil { + return nil, 0, fmt.Errorf("FindAndCount: %w", err) + } + } else { + // First, only query which IDs are necessary, and only then query all actions to speed up the overall query + sess := db.GetEngine(ctx).Where(cond).Select("`action`.id") + sess = db.SetSessionPagination(sess, &opts) + + actionIDs := make([]int64, 0, opts.PageSize) + if err := sess.Table("action").Desc("`action`.created_unix").Find(&actionIDs); err != nil { + return nil, 0, fmt.Errorf("Find(actionsIDs): %w", err) + } + + count, err = db.GetEngine(ctx).Where(cond). + Table("action"). + Cols("`action`.id").Count() + if err != nil { + return nil, 0, fmt.Errorf("Count: %w", err) + } + + if err := db.GetEngine(ctx).In("`action`.id", actionIDs).Desc("`action`.created_unix").Find(&actions); err != nil { + return nil, 0, fmt.Errorf("Find: %w", err) + } + } + + if err := ActionList(actions).LoadAttributes(ctx); err != nil { + return nil, 0, fmt.Errorf("LoadAttributes: %w", err) + } + + return actions, count, nil +} diff --git a/models/activities/action_test.go b/models/activities/action_test.go index 64330ebbb3e9a..9cfe98165686c 100644 --- a/models/activities/action_test.go +++ b/models/activities/action_test.go @@ -42,114 +42,6 @@ func TestAction_GetRepoLink(t *testing.T) { assert.Equal(t, comment.HTMLURL(db.DefaultContext), action.GetCommentHTMLURL(db.DefaultContext)) } -func TestGetFeeds(t *testing.T) { - // test with an individual user - assert.NoError(t, unittest.PrepareTestDatabase()) - user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) - - actions, count, err := activities_model.GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ - RequestedUser: user, - Actor: user, - IncludePrivate: true, - OnlyPerformedBy: false, - IncludeDeleted: true, - }) - assert.NoError(t, err) - if assert.Len(t, actions, 1) { - assert.EqualValues(t, 1, actions[0].ID) - assert.EqualValues(t, user.ID, actions[0].UserID) - } - assert.Equal(t, int64(1), count) - - actions, count, err = activities_model.GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ - RequestedUser: user, - Actor: user, - IncludePrivate: false, - OnlyPerformedBy: false, - }) - assert.NoError(t, err) - assert.Len(t, actions, 0) - assert.Equal(t, int64(0), count) -} - -func TestGetFeedsForRepos(t *testing.T) { - assert.NoError(t, unittest.PrepareTestDatabase()) - user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) - privRepo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 2}) - pubRepo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 8}) - - // private repo & no login - actions, count, err := activities_model.GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ - RequestedRepo: privRepo, - IncludePrivate: true, - }) - assert.NoError(t, err) - assert.Len(t, actions, 0) - assert.Equal(t, int64(0), count) - - // public repo & no login - actions, count, err = activities_model.GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ - RequestedRepo: pubRepo, - IncludePrivate: true, - }) - assert.NoError(t, err) - assert.Len(t, actions, 1) - assert.Equal(t, int64(1), count) - - // private repo and login - actions, count, err = activities_model.GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ - RequestedRepo: privRepo, - IncludePrivate: true, - Actor: user, - }) - assert.NoError(t, err) - assert.Len(t, actions, 1) - assert.Equal(t, int64(1), count) - - // public repo & login - actions, count, err = activities_model.GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ - RequestedRepo: pubRepo, - IncludePrivate: true, - Actor: user, - }) - assert.NoError(t, err) - assert.Len(t, actions, 1) - assert.Equal(t, int64(1), count) -} - -func TestGetFeeds2(t *testing.T) { - // test with an organization user - assert.NoError(t, unittest.PrepareTestDatabase()) - org := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 3}) - user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) - - actions, count, err := activities_model.GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ - RequestedUser: org, - Actor: user, - IncludePrivate: true, - OnlyPerformedBy: false, - IncludeDeleted: true, - }) - assert.NoError(t, err) - assert.Len(t, actions, 1) - if assert.Len(t, actions, 1) { - assert.EqualValues(t, 2, actions[0].ID) - assert.EqualValues(t, org.ID, actions[0].UserID) - } - assert.Equal(t, int64(1), count) - - actions, count, err = activities_model.GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ - RequestedUser: org, - Actor: user, - IncludePrivate: false, - OnlyPerformedBy: false, - IncludeDeleted: true, - }) - assert.NoError(t, err) - assert.Len(t, actions, 0) - assert.Equal(t, int64(0), count) -} - func TestActivityReadable(t *testing.T) { tt := []struct { desc string @@ -227,26 +119,6 @@ func TestNotifyWatchers(t *testing.T) { }) } -func TestGetFeedsCorrupted(t *testing.T) { - // Now we will not check for corrupted data in the feeds - // users should run doctor to fix their data - assert.NoError(t, unittest.PrepareTestDatabase()) - user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}) - unittest.AssertExistsAndLoadBean(t, &activities_model.Action{ - ID: 8, - RepoID: 1700, - }) - - actions, count, err := activities_model.GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ - RequestedUser: user, - Actor: user, - IncludePrivate: true, - }) - assert.NoError(t, err) - assert.Len(t, actions, 1) - assert.Equal(t, int64(1), count) -} - func TestConsistencyUpdateAction(t *testing.T) { if !setting.Database.Type.IsSQLite3() { t.Skip("Test is only for SQLite database.") @@ -322,24 +194,3 @@ func TestDeleteIssueActions(t *testing.T) { assert.NoError(t, activities_model.DeleteIssueActions(db.DefaultContext, issue.RepoID, issue.ID, issue.Index)) unittest.AssertCount(t, &activities_model.Action{}, 0) } - -func TestRepoActions(t *testing.T) { - assert.NoError(t, unittest.PrepareTestDatabase()) - repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1}) - _ = db.TruncateBeans(db.DefaultContext, &activities_model.Action{}) - for i := 0; i < 3; i++ { - _ = db.Insert(db.DefaultContext, &activities_model.Action{ - UserID: 2 + int64(i), - ActUserID: 2, - RepoID: repo.ID, - OpType: activities_model.ActionCommentIssue, - }) - } - count, _ := db.Count[activities_model.Action](db.DefaultContext, &db.ListOptions{}) - assert.EqualValues(t, 3, count) - actions, _, err := activities_model.GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ - RequestedRepo: repo, - }) - assert.NoError(t, err) - assert.Len(t, actions, 1) -} diff --git a/models/activities/user_heatmap.go b/models/activities/user_heatmap.go index 78fcd76d43ace..1f8f0f590e1ab 100644 --- a/models/activities/user_heatmap.go +++ b/models/activities/user_heatmap.go @@ -47,7 +47,7 @@ func getUserHeatmapData(ctx context.Context, user *user_model.User, team *organi groupByName = groupBy } - cond, err := activityQueryCondition(ctx, GetFeedsOptions{ + cond, err := ActivityQueryCondition(ctx, GetFeedsOptions{ RequestedUser: user, RequestedTeam: team, Actor: doer, diff --git a/routers/api/v1/org/org.go b/routers/api/v1/org/org.go index 9e5874627298d..3fb653bcb6d0c 100644 --- a/routers/api/v1/org/org.go +++ b/routers/api/v1/org/org.go @@ -19,6 +19,7 @@ import ( "code.gitea.io/gitea/routers/api/v1/utils" "code.gitea.io/gitea/services/context" "code.gitea.io/gitea/services/convert" + feed_service "code.gitea.io/gitea/services/feed" "code.gitea.io/gitea/services/org" user_service "code.gitea.io/gitea/services/user" ) @@ -447,7 +448,7 @@ func ListOrgActivityFeeds(ctx *context.APIContext) { ListOptions: listOptions, } - feeds, count, err := activities_model.GetFeeds(ctx, opts) + feeds, count, err := feed_service.GetFeeds(ctx, opts) if err != nil { ctx.Error(http.StatusInternalServerError, "GetFeeds", err) return diff --git a/routers/api/v1/org/team.go b/routers/api/v1/org/team.go index 20226b4d6b0ee..bc50960b61b11 100644 --- a/routers/api/v1/org/team.go +++ b/routers/api/v1/org/team.go @@ -22,6 +22,7 @@ import ( "code.gitea.io/gitea/routers/api/v1/utils" "code.gitea.io/gitea/services/context" "code.gitea.io/gitea/services/convert" + feed_service "code.gitea.io/gitea/services/feed" org_service "code.gitea.io/gitea/services/org" repo_service "code.gitea.io/gitea/services/repository" ) @@ -882,7 +883,7 @@ func ListTeamActivityFeeds(ctx *context.APIContext) { ListOptions: listOptions, } - feeds, count, err := activities_model.GetFeeds(ctx, opts) + feeds, count, err := feed_service.GetFeeds(ctx, opts) if err != nil { ctx.Error(http.StatusInternalServerError, "GetFeeds", err) return diff --git a/routers/api/v1/repo/repo.go b/routers/api/v1/repo/repo.go index 69a95dd5a58d9..40990a28cbdee 100644 --- a/routers/api/v1/repo/repo.go +++ b/routers/api/v1/repo/repo.go @@ -34,6 +34,7 @@ import ( actions_service "code.gitea.io/gitea/services/actions" "code.gitea.io/gitea/services/context" "code.gitea.io/gitea/services/convert" + feed_service "code.gitea.io/gitea/services/feed" "code.gitea.io/gitea/services/issue" repo_service "code.gitea.io/gitea/services/repository" ) @@ -1313,7 +1314,7 @@ func ListRepoActivityFeeds(ctx *context.APIContext) { ListOptions: listOptions, } - feeds, count, err := activities_model.GetFeeds(ctx, opts) + feeds, count, err := feed_service.GetFeeds(ctx, opts) if err != nil { ctx.Error(http.StatusInternalServerError, "GetFeeds", err) return diff --git a/routers/api/v1/user/user.go b/routers/api/v1/user/user.go index a9011427fb577..e668326861e1a 100644 --- a/routers/api/v1/user/user.go +++ b/routers/api/v1/user/user.go @@ -13,6 +13,7 @@ import ( "code.gitea.io/gitea/routers/api/v1/utils" "code.gitea.io/gitea/services/context" "code.gitea.io/gitea/services/convert" + feed_service "code.gitea.io/gitea/services/feed" ) // Search search users @@ -214,7 +215,7 @@ func ListUserActivityFeeds(ctx *context.APIContext) { ListOptions: listOptions, } - feeds, count, err := activities_model.GetFeeds(ctx, opts) + feeds, count, err := feed_service.GetFeeds(ctx, opts) if err != nil { ctx.Error(http.StatusInternalServerError, "GetFeeds", err) return diff --git a/routers/web/feed/profile.go b/routers/web/feed/profile.go index 47de7c089def1..4ec46e302a937 100644 --- a/routers/web/feed/profile.go +++ b/routers/web/feed/profile.go @@ -10,6 +10,7 @@ import ( "code.gitea.io/gitea/models/renderhelper" "code.gitea.io/gitea/modules/markup/markdown" "code.gitea.io/gitea/services/context" + feed_service "code.gitea.io/gitea/services/feed" "github.com/gorilla/feeds" ) @@ -28,7 +29,7 @@ func ShowUserFeedAtom(ctx *context.Context) { func showUserFeed(ctx *context.Context, formatType string) { includePrivate := ctx.IsSigned && (ctx.Doer.IsAdmin || ctx.Doer.ID == ctx.ContextUser.ID) - actions, _, err := activities_model.GetFeeds(ctx, activities_model.GetFeedsOptions{ + actions, _, err := feed_service.GetFeeds(ctx, activities_model.GetFeedsOptions{ RequestedUser: ctx.ContextUser, Actor: ctx.Doer, IncludePrivate: includePrivate, diff --git a/routers/web/feed/repo.go b/routers/web/feed/repo.go index bfcc3a37d6a9e..2e69fac758105 100644 --- a/routers/web/feed/repo.go +++ b/routers/web/feed/repo.go @@ -9,13 +9,14 @@ import ( activities_model "code.gitea.io/gitea/models/activities" repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/services/context" + feed_service "code.gitea.io/gitea/services/feed" "github.com/gorilla/feeds" ) // ShowRepoFeed shows user activity on the repo as RSS / Atom feed func ShowRepoFeed(ctx *context.Context, repo *repo_model.Repository, formatType string) { - actions, _, err := activities_model.GetFeeds(ctx, activities_model.GetFeedsOptions{ + actions, _, err := feed_service.GetFeeds(ctx, activities_model.GetFeedsOptions{ RequestedRepo: repo, Actor: ctx.Doer, IncludePrivate: true, diff --git a/routers/web/user/home.go b/routers/web/user/home.go index 6149ccb08d54e..0cf932ac03b60 100644 --- a/routers/web/user/home.go +++ b/routers/web/user/home.go @@ -33,6 +33,7 @@ import ( "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/routers/web/feed" "code.gitea.io/gitea/services/context" + feed_service "code.gitea.io/gitea/services/feed" issue_service "code.gitea.io/gitea/services/issue" pull_service "code.gitea.io/gitea/services/pull" @@ -113,7 +114,7 @@ func Dashboard(ctx *context.Context) { ctx.Data["HeatmapTotalContributions"] = activities_model.GetTotalContributionsInHeatmap(data) } - feeds, count, err := activities_model.GetFeeds(ctx, activities_model.GetFeedsOptions{ + feeds, count, err := feed_service.GetFeeds(ctx, activities_model.GetFeedsOptions{ RequestedUser: ctxUser, RequestedTeam: ctx.Org.Team, Actor: ctx.Doer, diff --git a/routers/web/user/profile.go b/routers/web/user/profile.go index 931af0a2839ec..c41030a5e2515 100644 --- a/routers/web/user/profile.go +++ b/routers/web/user/profile.go @@ -26,6 +26,7 @@ import ( "code.gitea.io/gitea/routers/web/org" shared_user "code.gitea.io/gitea/routers/web/shared/user" "code.gitea.io/gitea/services/context" + feed_service "code.gitea.io/gitea/services/feed" ) const ( @@ -167,7 +168,7 @@ func prepareUserProfileTabData(ctx *context.Context, showPrivate bool, profileDb case "activity": date := ctx.FormString("date") pagingNum = setting.UI.FeedPagingNum - items, count, err := activities_model.GetFeeds(ctx, activities_model.GetFeedsOptions{ + items, count, err := feed_service.GetFeeds(ctx, activities_model.GetFeedsOptions{ RequestedUser: ctx.ContextUser, Actor: ctx.Doer, IncludePrivate: showPrivate, diff --git a/services/feed/feed.go b/services/feed/feed.go new file mode 100644 index 0000000000000..93bf875fd04cd --- /dev/null +++ b/services/feed/feed.go @@ -0,0 +1,15 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package feed + +import ( + "context" + + activities_model "code.gitea.io/gitea/models/activities" +) + +// GetFeeds returns actions according to the provided options +func GetFeeds(ctx context.Context, opts activities_model.GetFeedsOptions) (activities_model.ActionList, int64, error) { + return activities_model.GetFeeds(ctx, opts) +} diff --git a/services/feed/feed_test.go b/services/feed/feed_test.go new file mode 100644 index 0000000000000..6f1cb9a969bd9 --- /dev/null +++ b/services/feed/feed_test.go @@ -0,0 +1,165 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package feed + +import ( + "testing" + + activities_model "code.gitea.io/gitea/models/activities" + "code.gitea.io/gitea/models/db" + repo_model "code.gitea.io/gitea/models/repo" + "code.gitea.io/gitea/models/unittest" + user_model "code.gitea.io/gitea/models/user" + + "github.com/stretchr/testify/assert" +) + +func TestGetFeeds(t *testing.T) { + // test with an individual user + assert.NoError(t, unittest.PrepareTestDatabase()) + user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) + + actions, count, err := GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ + RequestedUser: user, + Actor: user, + IncludePrivate: true, + OnlyPerformedBy: false, + IncludeDeleted: true, + }) + assert.NoError(t, err) + if assert.Len(t, actions, 1) { + assert.EqualValues(t, 1, actions[0].ID) + assert.EqualValues(t, user.ID, actions[0].UserID) + } + assert.Equal(t, int64(1), count) + + actions, count, err = GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ + RequestedUser: user, + Actor: user, + IncludePrivate: false, + OnlyPerformedBy: false, + }) + assert.NoError(t, err) + assert.Len(t, actions, 0) + assert.Equal(t, int64(0), count) +} + +func TestGetFeedsForRepos(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) + privRepo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 2}) + pubRepo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 8}) + + // private repo & no login + actions, count, err := GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ + RequestedRepo: privRepo, + IncludePrivate: true, + }) + assert.NoError(t, err) + assert.Len(t, actions, 0) + assert.Equal(t, int64(0), count) + + // public repo & no login + actions, count, err = GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ + RequestedRepo: pubRepo, + IncludePrivate: true, + }) + assert.NoError(t, err) + assert.Len(t, actions, 1) + assert.Equal(t, int64(1), count) + + // private repo and login + actions, count, err = GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ + RequestedRepo: privRepo, + IncludePrivate: true, + Actor: user, + }) + assert.NoError(t, err) + assert.Len(t, actions, 1) + assert.Equal(t, int64(1), count) + + // public repo & login + actions, count, err = GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ + RequestedRepo: pubRepo, + IncludePrivate: true, + Actor: user, + }) + assert.NoError(t, err) + assert.Len(t, actions, 1) + assert.Equal(t, int64(1), count) +} + +func TestGetFeeds2(t *testing.T) { + // test with an organization user + assert.NoError(t, unittest.PrepareTestDatabase()) + org := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 3}) + user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) + + actions, count, err := GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ + RequestedUser: org, + Actor: user, + IncludePrivate: true, + OnlyPerformedBy: false, + IncludeDeleted: true, + }) + assert.NoError(t, err) + assert.Len(t, actions, 1) + if assert.Len(t, actions, 1) { + assert.EqualValues(t, 2, actions[0].ID) + assert.EqualValues(t, org.ID, actions[0].UserID) + } + assert.Equal(t, int64(1), count) + + actions, count, err = GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ + RequestedUser: org, + Actor: user, + IncludePrivate: false, + OnlyPerformedBy: false, + IncludeDeleted: true, + }) + assert.NoError(t, err) + assert.Len(t, actions, 0) + assert.Equal(t, int64(0), count) +} + +func TestGetFeedsCorrupted(t *testing.T) { + // Now we will not check for corrupted data in the feeds + // users should run doctor to fix their data + assert.NoError(t, unittest.PrepareTestDatabase()) + user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}) + unittest.AssertExistsAndLoadBean(t, &activities_model.Action{ + ID: 8, + RepoID: 1700, + }) + + actions, count, err := GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ + RequestedUser: user, + Actor: user, + IncludePrivate: true, + }) + assert.NoError(t, err) + assert.Len(t, actions, 1) + assert.Equal(t, int64(1), count) +} + +func TestRepoActions(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1}) + _ = db.TruncateBeans(db.DefaultContext, &activities_model.Action{}) + for i := 0; i < 3; i++ { + _ = db.Insert(db.DefaultContext, &activities_model.Action{ + UserID: 2 + int64(i), + ActUserID: 2, + RepoID: repo.ID, + OpType: activities_model.ActionCommentIssue, + }) + } + count, _ := db.Count[activities_model.Action](db.DefaultContext, &db.ListOptions{}) + assert.EqualValues(t, 3, count) + actions, _, err := GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{ + RequestedRepo: repo, + }) + assert.NoError(t, err) + assert.Len(t, actions, 1) +} diff --git a/services/feed/action.go b/services/feed/notifier.go similarity index 100% rename from services/feed/action.go rename to services/feed/notifier.go diff --git a/services/feed/action_test.go b/services/feed/notifier_test.go similarity index 100% rename from services/feed/action_test.go rename to services/feed/notifier_test.go From fd3aa5bedb07d295d48b1f550c19ad1b387ba83f Mon Sep 17 00:00:00 2001 From: Zettat123 Date: Sat, 30 Nov 2024 04:32:10 +0800 Subject: [PATCH 2/3] Fix a bug in actions artifact test (#32672) This bug exists in `TestActionsArtifactDownload`. https://github.com/go-gitea/gitea/blob/a1f56f83bff56f86180e59742efd3748908b82c1/tests/integration/api_actions_artifact_test.go#L123-L134 We assert that `listResp.Count` is `2`, so `artifactIdx` could be `0` or `1`. https://github.com/go-gitea/gitea/blob/a1f56f83bff56f86180e59742efd3748908b82c1/tests/integration/api_actions_artifact_test.go#L144-L147 Then we assert that the length of `downloadResp.Value` is `1`. If `artifactIdx` is `1` at this point, the assertion on Line 147 will throw an `index out of range` error. --- tests/integration/api_actions_artifact_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/integration/api_actions_artifact_test.go b/tests/integration/api_actions_artifact_test.go index de5797f289da3..29e9930538e23 100644 --- a/tests/integration/api_actions_artifact_test.go +++ b/tests/integration/api_actions_artifact_test.go @@ -144,12 +144,12 @@ func TestActionsArtifactDownload(t *testing.T) { var downloadResp downloadArtifactResponse DecodeJSON(t, resp, &downloadResp) assert.Len(t, downloadResp.Value, 1) - assert.Equal(t, "artifact-download/abc.txt", downloadResp.Value[artifactIdx].Path) - assert.Equal(t, "file", downloadResp.Value[artifactIdx].ItemType) - assert.Contains(t, downloadResp.Value[artifactIdx].ContentLocation, "/api/actions_pipeline/_apis/pipelines/workflows/791/artifacts") + assert.Equal(t, "artifact-download/abc.txt", downloadResp.Value[0].Path) + assert.Equal(t, "file", downloadResp.Value[0].ItemType) + assert.Contains(t, downloadResp.Value[0].ContentLocation, "/api/actions_pipeline/_apis/pipelines/workflows/791/artifacts") - idx = strings.Index(downloadResp.Value[artifactIdx].ContentLocation, "/api/actions_pipeline/_apis/pipelines/") - url = downloadResp.Value[artifactIdx].ContentLocation[idx:] + idx = strings.Index(downloadResp.Value[0].ContentLocation, "/api/actions_pipeline/_apis/pipelines/") + url = downloadResp.Value[0].ContentLocation[idx:] req = NewRequest(t, "GET", url). AddTokenAuth("8061e833a55f6fc0157c98b883e91fcfeeb1a71a") resp = MakeRequest(t, req, http.StatusOK) From 79d593a9be48d8281ce9418906a540e1f98c2f7c Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 29 Nov 2024 17:15:41 -0800 Subject: [PATCH 3/3] Split mail sender sub package from mailer service package (#32618) Move all mail sender related codes into a sub package of services/mailer. Just move, no code change. Then we just have dependencies on go-mail package in the new sub package. We can use other package to replace it because it's unmaintainable. ref #18664 --- routers/private/mail.go | 3 +- services/mailer/mail.go | 19 +- services/mailer/mail_release.go | 5 +- services/mailer/mail_repo.go | 3 +- services/mailer/mail_team_invite.go | 3 +- services/mailer/mail_test.go | 5 +- services/mailer/mailer.go | 390 +----------------- services/mailer/sender/dummy.go | 26 ++ services/mailer/sender/message.go | 112 +++++ .../message_test.go} | 4 +- services/mailer/sender/sender.go | 27 ++ services/mailer/sender/sendmail.go | 76 ++++ services/mailer/sender/smtp.go | 150 +++++++ services/mailer/sender/smtp_auth.go | 69 ++++ tests/integration/incoming_email_test.go | 16 +- 15 files changed, 503 insertions(+), 405 deletions(-) create mode 100644 services/mailer/sender/dummy.go create mode 100644 services/mailer/sender/message.go rename services/mailer/{mailer_test.go => sender/message_test.go} (97%) create mode 100644 services/mailer/sender/sender.go create mode 100644 services/mailer/sender/sendmail.go create mode 100644 services/mailer/sender/smtp.go create mode 100644 services/mailer/sender/smtp_auth.go diff --git a/routers/private/mail.go b/routers/private/mail.go index cf3abb31c6e74..6c33467af7400 100644 --- a/routers/private/mail.go +++ b/routers/private/mail.go @@ -17,6 +17,7 @@ import ( "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/services/context" "code.gitea.io/gitea/services/mailer" + sender_service "code.gitea.io/gitea/services/mailer/sender" ) // SendEmail pushes messages to mail queue @@ -81,7 +82,7 @@ func SendEmail(ctx *context.PrivateContext) { func sendEmail(ctx *context.PrivateContext, subject, message string, to []string) { for _, email := range to { - msg := mailer.NewMessage(email, subject, message) + msg := sender_service.NewMessage(email, subject, message) mailer.SendAsync(msg) } diff --git a/services/mailer/mail.go b/services/mailer/mail.go index 8eee32a8c67ec..ee2c8c0963893 100644 --- a/services/mailer/mail.go +++ b/services/mailer/mail.go @@ -29,9 +29,8 @@ import ( "code.gitea.io/gitea/modules/timeutil" "code.gitea.io/gitea/modules/translation" incoming_payload "code.gitea.io/gitea/services/mailer/incoming/payload" + sender_service "code.gitea.io/gitea/services/mailer/sender" "code.gitea.io/gitea/services/mailer/token" - - "gopkg.in/gomail.v2" ) const ( @@ -60,7 +59,7 @@ func SendTestMail(email string) error { // No mail service configured return nil } - return gomail.Send(Sender, NewMessage(email, "Gitea Test Email!", "Gitea Test Email!").ToMessage()) + return sender_service.Send(sender, sender_service.NewMessage(email, "Gitea Test Email!", "Gitea Test Email!")) } // sendUserMail sends a mail to the user @@ -82,7 +81,7 @@ func sendUserMail(language string, u *user_model.User, tpl base.TplName, code, s return } - msg := NewMessage(u.EmailTo(), subject, content.String()) + msg := sender_service.NewMessage(u.EmailTo(), subject, content.String()) msg.Info = fmt.Sprintf("UID: %d, %s", u.ID, info) SendAsync(msg) @@ -130,7 +129,7 @@ func SendActivateEmailMail(u *user_model.User, email string) { return } - msg := NewMessage(email, locale.TrString("mail.activate_email"), content.String()) + msg := sender_service.NewMessage(email, locale.TrString("mail.activate_email"), content.String()) msg.Info = fmt.Sprintf("UID: %d, activate email", u.ID) SendAsync(msg) @@ -158,7 +157,7 @@ func SendRegisterNotifyMail(u *user_model.User) { return } - msg := NewMessage(u.EmailTo(), locale.TrString("mail.register_notify", setting.AppName), content.String()) + msg := sender_service.NewMessage(u.EmailTo(), locale.TrString("mail.register_notify", setting.AppName), content.String()) msg.Info = fmt.Sprintf("UID: %d, registration notify", u.ID) SendAsync(msg) @@ -189,13 +188,13 @@ func SendCollaboratorMail(u, doer *user_model.User, repo *repo_model.Repository) return } - msg := NewMessage(u.EmailTo(), subject, content.String()) + msg := sender_service.NewMessage(u.EmailTo(), subject, content.String()) msg.Info = fmt.Sprintf("UID: %d, add collaborator", u.ID) SendAsync(msg) } -func composeIssueCommentMessages(ctx *mailCommentContext, lang string, recipients []*user_model.User, fromMention bool, info string) ([]*Message, error) { +func composeIssueCommentMessages(ctx *mailCommentContext, lang string, recipients []*user_model.User, fromMention bool, info string) ([]*sender_service.Message, error) { var ( subject string link string @@ -304,9 +303,9 @@ func composeIssueCommentMessages(ctx *mailCommentContext, lang string, recipient return nil, err } - msgs := make([]*Message, 0, len(recipients)) + msgs := make([]*sender_service.Message, 0, len(recipients)) for _, recipient := range recipients { - msg := NewMessageFrom( + msg := sender_service.NewMessageFrom( recipient.Email, fromDisplayName(ctx.Doer), setting.MailService.FromEmail, diff --git a/services/mailer/mail_release.go b/services/mailer/mail_release.go index af1a7a266205b..1d73d77612255 100644 --- a/services/mailer/mail_release.go +++ b/services/mailer/mail_release.go @@ -15,6 +15,7 @@ import ( "code.gitea.io/gitea/modules/markup/markdown" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/translation" + sender_service "code.gitea.io/gitea/services/mailer/sender" ) const ( @@ -80,11 +81,11 @@ func mailNewRelease(ctx context.Context, lang string, tos []*user_model.User, re return } - msgs := make([]*Message, 0, len(tos)) + msgs := make([]*sender_service.Message, 0, len(tos)) publisherName := fromDisplayName(rel.Publisher) msgID := generateMessageIDForRelease(rel) for _, to := range tos { - msg := NewMessageFrom(to.EmailTo(), publisherName, setting.MailService.FromEmail, subject, mailBody.String()) + msg := sender_service.NewMessageFrom(to.EmailTo(), publisherName, setting.MailService.FromEmail, subject, mailBody.String()) msg.Info = subject msg.SetHeader("Message-ID", msgID) msgs = append(msgs, msg) diff --git a/services/mailer/mail_repo.go b/services/mailer/mail_repo.go index 7003584786aa3..5f80654bcdff6 100644 --- a/services/mailer/mail_repo.go +++ b/services/mailer/mail_repo.go @@ -13,6 +13,7 @@ import ( user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/translation" + sender_service "code.gitea.io/gitea/services/mailer/sender" ) // SendRepoTransferNotifyMail triggers a notification e-mail when a pending repository transfer was created @@ -79,7 +80,7 @@ func sendRepoTransferNotifyMailPerLang(lang string, newOwner, doer *user_model.U } for _, to := range emailTos { - msg := NewMessageFrom(to.EmailTo(), fromDisplayName(doer), setting.MailService.FromEmail, subject, content.String()) + msg := sender_service.NewMessageFrom(to.EmailTo(), fromDisplayName(doer), setting.MailService.FromEmail, subject, content.String()) msg.Info = fmt.Sprintf("UID: %d, repository pending transfer notification", newOwner.ID) SendAsync(msg) diff --git a/services/mailer/mail_team_invite.go b/services/mailer/mail_team_invite.go index ceecefa50fab4..4f2d5e4ca7f60 100644 --- a/services/mailer/mail_team_invite.go +++ b/services/mailer/mail_team_invite.go @@ -15,6 +15,7 @@ import ( "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/translation" + sender_service "code.gitea.io/gitea/services/mailer/sender" ) const ( @@ -67,7 +68,7 @@ func MailTeamInvite(ctx context.Context, inviter *user_model.User, team *org_mod return err } - msg := NewMessage(invite.Email, subject, mailBody.String()) + msg := sender_service.NewMessage(invite.Email, subject, mailBody.String()) msg.Info = subject SendAsync(msg) diff --git a/services/mailer/mail_test.go b/services/mailer/mail_test.go index 663ffa85ef377..42de7599ebd18 100644 --- a/services/mailer/mail_test.go +++ b/services/mailer/mail_test.go @@ -23,6 +23,7 @@ import ( user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/markup" "code.gitea.io/gitea/modules/setting" + sender_service "code.gitea.io/gitea/services/mailer/sender" "github.com/stretchr/testify/assert" ) @@ -167,7 +168,7 @@ func TestTemplateSelection(t *testing.T) { template.Must(bodyTemplates.New("pull/comment").Parse("pull/comment/body")) template.Must(bodyTemplates.New("issue/close").Parse("issue/close/body")) - expect := func(t *testing.T, msg *Message, expSubject, expBody string) { + expect := func(t *testing.T, msg *sender_service.Message, expSubject, expBody string) { subject := msg.ToMessage().GetHeader("Subject") msgbuf := new(bytes.Buffer) _, _ = msg.ToMessage().WriteTo(msgbuf) @@ -252,7 +253,7 @@ func TestTemplateServices(t *testing.T) { "//Re: //") } -func testComposeIssueCommentMessage(t *testing.T, ctx *mailCommentContext, recipients []*user_model.User, fromMention bool, info string) *Message { +func testComposeIssueCommentMessage(t *testing.T, ctx *mailCommentContext, recipients []*user_model.User, fromMention bool, info string) *sender_service.Message { msgs, err := composeIssueCommentMessages(ctx, "en-US", recipients, fromMention, info) assert.NoError(t, err) assert.Len(t, msgs, 1) diff --git a/services/mailer/mailer.go b/services/mailer/mailer.go index 5cb6d035213e3..bf4b5a43cb19f 100644 --- a/services/mailer/mailer.go +++ b/services/mailer/mailer.go @@ -5,391 +5,21 @@ package mailer import ( - "bytes" "context" - "crypto/tls" - "fmt" - "hash/fnv" - "io" - "net" - "net/smtp" - "os" - "os/exec" - "strings" - "time" - "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/graceful" "code.gitea.io/gitea/modules/log" - "code.gitea.io/gitea/modules/process" "code.gitea.io/gitea/modules/queue" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/templates" + sender_service "code.gitea.io/gitea/services/mailer/sender" notify_service "code.gitea.io/gitea/services/notify" - - ntlmssp "github.com/Azure/go-ntlmssp" - "github.com/jaytaylor/html2text" - "gopkg.in/gomail.v2" ) -// Message mail body and log info -type Message struct { - Info string // Message information for log purpose. - FromAddress string - FromDisplayName string - To string // Use only one recipient to prevent leaking of addresses - ReplyTo string - Subject string - Date time.Time - Body string - Headers map[string][]string -} - -// ToMessage converts a Message to gomail.Message -func (m *Message) ToMessage() *gomail.Message { - msg := gomail.NewMessage() - msg.SetAddressHeader("From", m.FromAddress, m.FromDisplayName) - msg.SetHeader("To", m.To) - if m.ReplyTo != "" { - msg.SetHeader("Reply-To", m.ReplyTo) - } - for header := range m.Headers { - msg.SetHeader(header, m.Headers[header]...) - } - - if setting.MailService.SubjectPrefix != "" { - msg.SetHeader("Subject", setting.MailService.SubjectPrefix+" "+m.Subject) - } else { - msg.SetHeader("Subject", m.Subject) - } - msg.SetDateHeader("Date", m.Date) - msg.SetHeader("X-Auto-Response-Suppress", "All") - - plainBody, err := html2text.FromString(m.Body) - if err != nil || setting.MailService.SendAsPlainText { - if strings.Contains(base.TruncateString(m.Body, 100), "") { - log.Warn("Mail contains HTML but configured to send as plain text.") - } - msg.SetBody("text/plain", plainBody) - } else { - msg.SetBody("text/plain", plainBody) - msg.AddAlternative("text/html", m.Body) - } - - if len(msg.GetHeader("Message-ID")) == 0 { - msg.SetHeader("Message-ID", m.generateAutoMessageID()) - } - - for k, v := range setting.MailService.OverrideHeader { - if len(msg.GetHeader(k)) != 0 { - log.Debug("Mailer override header '%s' as per config", k) - } - msg.SetHeader(k, v...) - } - - return msg -} - -// SetHeader adds additional headers to a message -func (m *Message) SetHeader(field string, value ...string) { - m.Headers[field] = value -} - -func (m *Message) generateAutoMessageID() string { - dateMs := m.Date.UnixNano() / 1e6 - h := fnv.New64() - if len(m.To) > 0 { - _, _ = h.Write([]byte(m.To)) - } - _, _ = h.Write([]byte(m.Subject)) - _, _ = h.Write([]byte(m.Body)) - return fmt.Sprintf("", dateMs, h.Sum64(), setting.Domain) -} - -// NewMessageFrom creates new mail message object with custom From header. -func NewMessageFrom(to, fromDisplayName, fromAddress, subject, body string) *Message { - log.Trace("NewMessageFrom (body):\n%s", body) - - return &Message{ - FromAddress: fromAddress, - FromDisplayName: fromDisplayName, - To: to, - Subject: subject, - Date: time.Now(), - Body: body, - Headers: map[string][]string{}, - } -} - -// NewMessage creates new mail message object with default From header. -func NewMessage(to, subject, body string) *Message { - return NewMessageFrom(to, setting.MailService.FromName, setting.MailService.FromEmail, subject, body) -} - -type loginAuth struct { - username, password string -} - -// LoginAuth SMTP AUTH LOGIN Auth Handler -func LoginAuth(username, password string) smtp.Auth { - return &loginAuth{username, password} -} - -// Start start SMTP login auth -func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { - return "LOGIN", []byte{}, nil -} - -// Next next step of SMTP login auth -func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { - if more { - switch string(fromServer) { - case "Username:": - return []byte(a.username), nil - case "Password:": - return []byte(a.password), nil - default: - return nil, fmt.Errorf("unknown fromServer: %s", string(fromServer)) - } - } - return nil, nil -} - -type ntlmAuth struct { - username, password, domain string - domainNeeded bool -} - -// NtlmAuth SMTP AUTH NTLM Auth Handler -func NtlmAuth(username, password string) smtp.Auth { - user, domain, domainNeeded := ntlmssp.GetDomain(username) - return &ntlmAuth{user, password, domain, domainNeeded} -} - -// Start starts SMTP NTLM Auth -func (a *ntlmAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { - negotiateMessage, err := ntlmssp.NewNegotiateMessage(a.domain, "") - return "NTLM", negotiateMessage, err -} - -// Next next step of SMTP ntlm auth -func (a *ntlmAuth) Next(fromServer []byte, more bool) ([]byte, error) { - if more { - if len(fromServer) == 0 { - return nil, fmt.Errorf("ntlm ChallengeMessage is empty") - } - authenticateMessage, err := ntlmssp.ProcessChallenge(fromServer, a.username, a.password, a.domainNeeded) - return authenticateMessage, err - } - return nil, nil -} - -// Sender SMTP mail sender -type smtpSender struct{} - -// Send send email -func (s *smtpSender) Send(from string, to []string, msg io.WriterTo) error { - opts := setting.MailService - - var network string - var address string - if opts.Protocol == "smtp+unix" { - network = "unix" - address = opts.SMTPAddr - } else { - network = "tcp" - address = net.JoinHostPort(opts.SMTPAddr, opts.SMTPPort) - } - - conn, err := net.Dial(network, address) - if err != nil { - return fmt.Errorf("failed to establish network connection to SMTP server: %w", err) - } - defer conn.Close() - - var tlsconfig *tls.Config - if opts.Protocol == "smtps" || opts.Protocol == "smtp+starttls" { - tlsconfig = &tls.Config{ - InsecureSkipVerify: opts.ForceTrustServerCert, - ServerName: opts.SMTPAddr, - } - - if opts.UseClientCert { - cert, err := tls.LoadX509KeyPair(opts.ClientCertFile, opts.ClientKeyFile) - if err != nil { - return fmt.Errorf("could not load SMTP client certificate: %w", err) - } - tlsconfig.Certificates = []tls.Certificate{cert} - } - } - - if opts.Protocol == "smtps" { - conn = tls.Client(conn, tlsconfig) - } - - host := "localhost" - if opts.Protocol == "smtp+unix" { - host = opts.SMTPAddr - } - client, err := smtp.NewClient(conn, host) - if err != nil { - return fmt.Errorf("could not initiate SMTP session: %w", err) - } - - if opts.EnableHelo { - hostname := opts.HeloHostname - if len(hostname) == 0 { - hostname, err = os.Hostname() - if err != nil { - return fmt.Errorf("could not retrieve system hostname: %w", err) - } - } - - if err = client.Hello(hostname); err != nil { - return fmt.Errorf("failed to issue HELO command: %w", err) - } - } - - if opts.Protocol == "smtp+starttls" { - hasStartTLS, _ := client.Extension("STARTTLS") - if hasStartTLS { - if err = client.StartTLS(tlsconfig); err != nil { - return fmt.Errorf("failed to start TLS connection: %w", err) - } - } else { - log.Warn("StartTLS requested, but SMTP server does not support it; falling back to regular SMTP") - } - } - - canAuth, options := client.Extension("AUTH") - if len(opts.User) > 0 { - if !canAuth { - return fmt.Errorf("SMTP server does not support AUTH, but credentials provided") - } - - var auth smtp.Auth - - if strings.Contains(options, "CRAM-MD5") { - auth = smtp.CRAMMD5Auth(opts.User, opts.Passwd) - } else if strings.Contains(options, "PLAIN") { - auth = smtp.PlainAuth("", opts.User, opts.Passwd, host) - } else if strings.Contains(options, "LOGIN") { - // Patch for AUTH LOGIN - auth = LoginAuth(opts.User, opts.Passwd) - } else if strings.Contains(options, "NTLM") { - auth = NtlmAuth(opts.User, opts.Passwd) - } - - if auth != nil { - if err = client.Auth(auth); err != nil { - return fmt.Errorf("failed to authenticate SMTP: %w", err) - } - } - } - - if opts.OverrideEnvelopeFrom { - if err = client.Mail(opts.EnvelopeFrom); err != nil { - return fmt.Errorf("failed to issue MAIL command: %w", err) - } - } else { - if err = client.Mail(from); err != nil { - return fmt.Errorf("failed to issue MAIL command: %w", err) - } - } - - for _, rec := range to { - if err = client.Rcpt(rec); err != nil { - return fmt.Errorf("failed to issue RCPT command: %w", err) - } - } - - w, err := client.Data() - if err != nil { - return fmt.Errorf("failed to issue DATA command: %w", err) - } else if _, err = msg.WriteTo(w); err != nil { - return fmt.Errorf("SMTP write failed: %w", err) - } else if err = w.Close(); err != nil { - return fmt.Errorf("SMTP close failed: %w", err) - } - - return client.Quit() -} - -// Sender sendmail mail sender -type sendmailSender struct{} - -// Send send email -func (s *sendmailSender) Send(from string, to []string, msg io.WriterTo) error { - var err error - var closeError error - var waitError error - - envelopeFrom := from - if setting.MailService.OverrideEnvelopeFrom { - envelopeFrom = setting.MailService.EnvelopeFrom - } - - args := []string{"-f", envelopeFrom, "-i"} - args = append(args, setting.MailService.SendmailArgs...) - args = append(args, to...) - log.Trace("Sending with: %s %v", setting.MailService.SendmailPath, args) - - desc := fmt.Sprintf("SendMail: %s %v", setting.MailService.SendmailPath, args) - - ctx, _, finished := process.GetManager().AddContextTimeout(graceful.GetManager().HammerContext(), setting.MailService.SendmailTimeout, desc) - defer finished() - - cmd := exec.CommandContext(ctx, setting.MailService.SendmailPath, args...) - pipe, err := cmd.StdinPipe() - if err != nil { - return err - } - process.SetSysProcAttribute(cmd) - - if err = cmd.Start(); err != nil { - _ = pipe.Close() - return err - } - - if setting.MailService.SendmailConvertCRLF { - buf := &strings.Builder{} - _, err = msg.WriteTo(buf) - if err == nil { - _, err = strings.NewReplacer("\r\n", "\n").WriteString(pipe, buf.String()) - } - } else { - _, err = msg.WriteTo(pipe) - } - - // we MUST close the pipe or sendmail will hang waiting for more of the message - // Also we should wait on our sendmail command even if something fails - closeError = pipe.Close() - waitError = cmd.Wait() - if err != nil { - return err - } else if closeError != nil { - return closeError - } - return waitError -} - -// Sender sendmail mail sender -type dummySender struct{} - -// Send send email -func (s *dummySender) Send(from string, to []string, msg io.WriterTo) error { - buf := bytes.Buffer{} - if _, err := msg.WriteTo(&buf); err != nil { - return err - } - log.Debug("Mail From: %s To: %v Body: %s", from, to, buf.String()) - return nil -} - -var mailQueue *queue.WorkerPoolQueue[*Message] +var mailQueue *queue.WorkerPoolQueue[*sender_service.Message] -// Sender sender for sending mail synchronously -var Sender gomail.Sender +// sender sender for sending mail synchronously +var sender sender_service.Sender // NewContext start mail queue service func NewContext(ctx context.Context) { @@ -406,20 +36,20 @@ func NewContext(ctx context.Context) { switch setting.MailService.Protocol { case "sendmail": - Sender = &sendmailSender{} + sender = &sender_service.SendmailSender{} case "dummy": - Sender = &dummySender{} + sender = &sender_service.DummySender{} default: - Sender = &smtpSender{} + sender = &sender_service.SMTPSender{} } subjectTemplates, bodyTemplates = templates.Mailer(ctx) - mailQueue = queue.CreateSimpleQueue(graceful.GetManager().ShutdownContext(), "mail", func(items ...*Message) []*Message { + mailQueue = queue.CreateSimpleQueue(graceful.GetManager().ShutdownContext(), "mail", func(items ...*sender_service.Message) []*sender_service.Message { for _, msg := range items { gomailMsg := msg.ToMessage() log.Trace("New e-mail sending request %s: %s", gomailMsg.GetHeader("To"), msg.Info) - if err := gomail.Send(Sender, gomailMsg); err != nil { + if err := sender_service.Send(sender, msg); err != nil { log.Error("Failed to send emails %s: %s - %v", gomailMsg.GetHeader("To"), msg.Info, err) } else { log.Trace("E-mails sent %s: %s", gomailMsg.GetHeader("To"), msg.Info) @@ -436,7 +66,7 @@ func NewContext(ctx context.Context) { // SendAsync send emails asynchronously (make it mockable) var SendAsync = sendAsync -func sendAsync(msgs ...*Message) { +func sendAsync(msgs ...*sender_service.Message) { if setting.MailService == nil { log.Error("Mailer: SendAsync is being invoked but mail service hasn't been initialized") return diff --git a/services/mailer/sender/dummy.go b/services/mailer/sender/dummy.go new file mode 100644 index 0000000000000..dd5f14abec232 --- /dev/null +++ b/services/mailer/sender/dummy.go @@ -0,0 +1,26 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package sender + +import ( + "bytes" + "io" + + "code.gitea.io/gitea/modules/log" +) + +// DummySender Sender sendmail mail sender +type DummySender struct{} + +var _ Sender = &DummySender{} + +// Send send email +func (s *DummySender) Send(from string, to []string, msg io.WriterTo) error { + buf := bytes.Buffer{} + if _, err := msg.WriteTo(&buf); err != nil { + return err + } + log.Debug("Mail From: %s To: %v Body: %s", from, to, buf.String()) + return nil +} diff --git a/services/mailer/sender/message.go b/services/mailer/sender/message.go new file mode 100644 index 0000000000000..a3255692f0798 --- /dev/null +++ b/services/mailer/sender/message.go @@ -0,0 +1,112 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package sender + +import ( + "fmt" + "hash/fnv" + "strings" + "time" + + "code.gitea.io/gitea/modules/base" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + + "github.com/jaytaylor/html2text" + "gopkg.in/gomail.v2" +) + +// Message mail body and log info +type Message struct { + Info string // Message information for log purpose. + FromAddress string + FromDisplayName string + To string // Use only one recipient to prevent leaking of addresses + ReplyTo string + Subject string + Date time.Time + Body string + Headers map[string][]string +} + +// ToMessage converts a Message to gomail.Message +func (m *Message) ToMessage() *gomail.Message { + msg := gomail.NewMessage() + msg.SetAddressHeader("From", m.FromAddress, m.FromDisplayName) + msg.SetHeader("To", m.To) + if m.ReplyTo != "" { + msg.SetHeader("Reply-To", m.ReplyTo) + } + for header := range m.Headers { + msg.SetHeader(header, m.Headers[header]...) + } + + if setting.MailService.SubjectPrefix != "" { + msg.SetHeader("Subject", setting.MailService.SubjectPrefix+" "+m.Subject) + } else { + msg.SetHeader("Subject", m.Subject) + } + msg.SetDateHeader("Date", m.Date) + msg.SetHeader("X-Auto-Response-Suppress", "All") + + plainBody, err := html2text.FromString(m.Body) + if err != nil || setting.MailService.SendAsPlainText { + if strings.Contains(base.TruncateString(m.Body, 100), "") { + log.Warn("Mail contains HTML but configured to send as plain text.") + } + msg.SetBody("text/plain", plainBody) + } else { + msg.SetBody("text/plain", plainBody) + msg.AddAlternative("text/html", m.Body) + } + + if len(msg.GetHeader("Message-ID")) == 0 { + msg.SetHeader("Message-ID", m.generateAutoMessageID()) + } + + for k, v := range setting.MailService.OverrideHeader { + if len(msg.GetHeader(k)) != 0 { + log.Debug("Mailer override header '%s' as per config", k) + } + msg.SetHeader(k, v...) + } + + return msg +} + +// SetHeader adds additional headers to a message +func (m *Message) SetHeader(field string, value ...string) { + m.Headers[field] = value +} + +func (m *Message) generateAutoMessageID() string { + dateMs := m.Date.UnixNano() / 1e6 + h := fnv.New64() + if len(m.To) > 0 { + _, _ = h.Write([]byte(m.To)) + } + _, _ = h.Write([]byte(m.Subject)) + _, _ = h.Write([]byte(m.Body)) + return fmt.Sprintf("", dateMs, h.Sum64(), setting.Domain) +} + +// NewMessageFrom creates new mail message object with custom From header. +func NewMessageFrom(to, fromDisplayName, fromAddress, subject, body string) *Message { + log.Trace("NewMessageFrom (body):\n%s", body) + + return &Message{ + FromAddress: fromAddress, + FromDisplayName: fromDisplayName, + To: to, + Subject: subject, + Date: time.Now(), + Body: body, + Headers: map[string][]string{}, + } +} + +// NewMessage creates new mail message object with default From header. +func NewMessage(to, subject, body string) *Message { + return NewMessageFrom(to, setting.MailService.FromName, setting.MailService.FromEmail, subject, body) +} diff --git a/services/mailer/mailer_test.go b/services/mailer/sender/message_test.go similarity index 97% rename from services/mailer/mailer_test.go rename to services/mailer/sender/message_test.go index 6d7c44f40c044..d47052685ef15 100644 --- a/services/mailer/mailer_test.go +++ b/services/mailer/sender/message_test.go @@ -1,7 +1,7 @@ -// Copyright 2021 The Gogs Authors. All rights reserved. +// Copyright 2024 The Gitea Authors. All rights reserved. // SPDX-License-Identifier: MIT -package mailer +package sender import ( "strings" diff --git a/services/mailer/sender/sender.go b/services/mailer/sender/sender.go new file mode 100644 index 0000000000000..bf317aa846295 --- /dev/null +++ b/services/mailer/sender/sender.go @@ -0,0 +1,27 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package sender + +import ( + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + + "gopkg.in/gomail.v2" +) + +type Sender gomail.Sender + +var Send = send + +func send(sender Sender, msgs ...*Message) error { + if setting.MailService == nil { + log.Error("Mailer: Send is being invoked but mail service hasn't been initialized") + return nil + } + goMsgs := []*gomail.Message{} + for _, msg := range msgs { + goMsgs = append(goMsgs, msg.ToMessage()) + } + return gomail.Send(sender, goMsgs...) +} diff --git a/services/mailer/sender/sendmail.go b/services/mailer/sender/sendmail.go new file mode 100644 index 0000000000000..64c7f8f0816b4 --- /dev/null +++ b/services/mailer/sender/sendmail.go @@ -0,0 +1,76 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package sender + +import ( + "fmt" + "io" + "os/exec" + "strings" + + "code.gitea.io/gitea/modules/graceful" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/process" + "code.gitea.io/gitea/modules/setting" +) + +// SendmailSender Sender sendmail mail sender +type SendmailSender struct{} + +var _ Sender = &SendmailSender{} + +// Send send email +func (s *SendmailSender) Send(from string, to []string, msg io.WriterTo) error { + var err error + var closeError error + var waitError error + + envelopeFrom := from + if setting.MailService.OverrideEnvelopeFrom { + envelopeFrom = setting.MailService.EnvelopeFrom + } + + args := []string{"-f", envelopeFrom, "-i"} + args = append(args, setting.MailService.SendmailArgs...) + args = append(args, to...) + log.Trace("Sending with: %s %v", setting.MailService.SendmailPath, args) + + desc := fmt.Sprintf("SendMail: %s %v", setting.MailService.SendmailPath, args) + + ctx, _, finished := process.GetManager().AddContextTimeout(graceful.GetManager().HammerContext(), setting.MailService.SendmailTimeout, desc) + defer finished() + + cmd := exec.CommandContext(ctx, setting.MailService.SendmailPath, args...) + pipe, err := cmd.StdinPipe() + if err != nil { + return err + } + process.SetSysProcAttribute(cmd) + + if err = cmd.Start(); err != nil { + _ = pipe.Close() + return err + } + + if setting.MailService.SendmailConvertCRLF { + buf := &strings.Builder{} + _, err = msg.WriteTo(buf) + if err == nil { + _, err = strings.NewReplacer("\r\n", "\n").WriteString(pipe, buf.String()) + } + } else { + _, err = msg.WriteTo(pipe) + } + + // we MUST close the pipe or sendmail will hang waiting for more of the message + // Also we should wait on our sendmail command even if something fails + closeError = pipe.Close() + waitError = cmd.Wait() + if err != nil { + return err + } else if closeError != nil { + return closeError + } + return waitError +} diff --git a/services/mailer/sender/smtp.go b/services/mailer/sender/smtp.go new file mode 100644 index 0000000000000..ab49b7e5f830c --- /dev/null +++ b/services/mailer/sender/smtp.go @@ -0,0 +1,150 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package sender + +import ( + "crypto/tls" + "fmt" + "io" + "net" + "net/smtp" + "os" + "strings" + + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" +) + +// SMTPSender Sender SMTP mail sender +type SMTPSender struct{} + +var _ Sender = &SMTPSender{} + +// Send send email +func (s *SMTPSender) Send(from string, to []string, msg io.WriterTo) error { + opts := setting.MailService + + var network string + var address string + if opts.Protocol == "smtp+unix" { + network = "unix" + address = opts.SMTPAddr + } else { + network = "tcp" + address = net.JoinHostPort(opts.SMTPAddr, opts.SMTPPort) + } + + conn, err := net.Dial(network, address) + if err != nil { + return fmt.Errorf("failed to establish network connection to SMTP server: %w", err) + } + defer conn.Close() + + var tlsconfig *tls.Config + if opts.Protocol == "smtps" || opts.Protocol == "smtp+starttls" { + tlsconfig = &tls.Config{ + InsecureSkipVerify: opts.ForceTrustServerCert, + ServerName: opts.SMTPAddr, + } + + if opts.UseClientCert { + cert, err := tls.LoadX509KeyPair(opts.ClientCertFile, opts.ClientKeyFile) + if err != nil { + return fmt.Errorf("could not load SMTP client certificate: %w", err) + } + tlsconfig.Certificates = []tls.Certificate{cert} + } + } + + if opts.Protocol == "smtps" { + conn = tls.Client(conn, tlsconfig) + } + + host := "localhost" + if opts.Protocol == "smtp+unix" { + host = opts.SMTPAddr + } + client, err := smtp.NewClient(conn, host) + if err != nil { + return fmt.Errorf("could not initiate SMTP session: %w", err) + } + + if opts.EnableHelo { + hostname := opts.HeloHostname + if len(hostname) == 0 { + hostname, err = os.Hostname() + if err != nil { + return fmt.Errorf("could not retrieve system hostname: %w", err) + } + } + + if err = client.Hello(hostname); err != nil { + return fmt.Errorf("failed to issue HELO command: %w", err) + } + } + + if opts.Protocol == "smtp+starttls" { + hasStartTLS, _ := client.Extension("STARTTLS") + if hasStartTLS { + if err = client.StartTLS(tlsconfig); err != nil { + return fmt.Errorf("failed to start TLS connection: %w", err) + } + } else { + log.Warn("StartTLS requested, but SMTP server does not support it; falling back to regular SMTP") + } + } + + canAuth, options := client.Extension("AUTH") + if len(opts.User) > 0 { + if !canAuth { + return fmt.Errorf("SMTP server does not support AUTH, but credentials provided") + } + + var auth smtp.Auth + + if strings.Contains(options, "CRAM-MD5") { + auth = smtp.CRAMMD5Auth(opts.User, opts.Passwd) + } else if strings.Contains(options, "PLAIN") { + auth = smtp.PlainAuth("", opts.User, opts.Passwd, host) + } else if strings.Contains(options, "LOGIN") { + // Patch for AUTH LOGIN + auth = LoginAuth(opts.User, opts.Passwd) + } else if strings.Contains(options, "NTLM") { + auth = NtlmAuth(opts.User, opts.Passwd) + } + + if auth != nil { + if err = client.Auth(auth); err != nil { + return fmt.Errorf("failed to authenticate SMTP: %w", err) + } + } + } + + if opts.OverrideEnvelopeFrom { + if err = client.Mail(opts.EnvelopeFrom); err != nil { + return fmt.Errorf("failed to issue MAIL command: %w", err) + } + } else { + if err = client.Mail(from); err != nil { + return fmt.Errorf("failed to issue MAIL command: %w", err) + } + } + + for _, rec := range to { + if err = client.Rcpt(rec); err != nil { + return fmt.Errorf("failed to issue RCPT command: %w", err) + } + } + + w, err := client.Data() + if err != nil { + return fmt.Errorf("failed to issue DATA command: %w", err) + } else if _, err = msg.WriteTo(w); err != nil { + return fmt.Errorf("SMTP write failed: %w", err) + } else if err = w.Close(); err != nil { + return fmt.Errorf("SMTP close failed: %w", err) + } + + return client.Quit() +} diff --git a/services/mailer/sender/smtp_auth.go b/services/mailer/sender/smtp_auth.go new file mode 100644 index 0000000000000..df65498a5a73c --- /dev/null +++ b/services/mailer/sender/smtp_auth.go @@ -0,0 +1,69 @@ +// Copyright 2024 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package sender + +import ( + "fmt" + "net/smtp" + + "github.com/Azure/go-ntlmssp" +) + +type loginAuth struct { + username, password string +} + +// LoginAuth SMTP AUTH LOGIN Auth Handler +func LoginAuth(username, password string) smtp.Auth { + return &loginAuth{username, password} +} + +// Start start SMTP login auth +func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { + return "LOGIN", []byte{}, nil +} + +// Next next step of SMTP login auth +func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { + if more { + switch string(fromServer) { + case "Username:": + return []byte(a.username), nil + case "Password:": + return []byte(a.password), nil + default: + return nil, fmt.Errorf("unknown fromServer: %s", string(fromServer)) + } + } + return nil, nil +} + +type ntlmAuth struct { + username, password, domain string + domainNeeded bool +} + +// NtlmAuth SMTP AUTH NTLM Auth Handler +func NtlmAuth(username, password string) smtp.Auth { + user, domain, domainNeeded := ntlmssp.GetDomain(username) + return &ntlmAuth{user, password, domain, domainNeeded} +} + +// Start starts SMTP NTLM Auth +func (a *ntlmAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { + negotiateMessage, err := ntlmssp.NewNegotiateMessage(a.domain, "") + return "NTLM", negotiateMessage, err +} + +// Next next step of SMTP ntlm auth +func (a *ntlmAuth) Next(fromServer []byte, more bool) ([]byte, error) { + if more { + if len(fromServer) == 0 { + return nil, fmt.Errorf("ntlm ChallengeMessage is empty") + } + authenticateMessage, err := ntlmssp.ProcessChallenge(fromServer, a.username, a.password, a.domainNeeded) + return authenticateMessage, err + } + return nil, nil +} diff --git a/tests/integration/incoming_email_test.go b/tests/integration/incoming_email_test.go index 88571303ac72b..e968a2956ee05 100644 --- a/tests/integration/incoming_email_test.go +++ b/tests/integration/incoming_email_test.go @@ -19,11 +19,11 @@ import ( "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/services/mailer/incoming" incoming_payload "code.gitea.io/gitea/services/mailer/incoming/payload" + sender_service "code.gitea.io/gitea/services/mailer/sender" token_service "code.gitea.io/gitea/services/mailer/token" "code.gitea.io/gitea/tests" "github.com/stretchr/testify/assert" - "gopkg.in/gomail.v2" ) func TestIncomingEmail(t *testing.T) { @@ -189,11 +189,15 @@ func TestIncomingEmail(t *testing.T) { token, err := token_service.CreateToken(token_service.ReplyHandlerType, user, payload) assert.NoError(t, err) - msg := gomail.NewMessage() - msg.SetHeader("To", strings.Replace(setting.IncomingEmail.ReplyToAddress, setting.IncomingEmail.TokenPlaceholder, token, 1)) - msg.SetHeader("From", user.Email) - msg.SetBody("text/plain", token) - err = gomail.Send(&smtpTestSender{}, msg) + msg := sender_service.NewMessageFrom( + strings.Replace(setting.IncomingEmail.ReplyToAddress, setting.IncomingEmail.TokenPlaceholder, token, 1), + "", + user.Email, + "", + token, + ) + + err = sender_service.Send(&smtpTestSender{}, msg) assert.NoError(t, err) assert.Eventually(t, func() bool {