diff --git a/modules/birthday/birthdaybase.go b/modules/birthday/birthdaybase.go index fe40179..aba9dcc 100644 --- a/modules/birthday/birthdaybase.go +++ b/modules/birthday/birthdaybase.go @@ -43,12 +43,14 @@ type birthdayBase struct { } type birthdayEntry struct { - ID uint64 `database:"id"` - Day int `database:"day"` - Month int `database:"month"` - Year int `database:"year"` - Visible bool `database:"visible"` - time time.Time + ID uint64 `database:"id"` + Day int `database:"day"` + Month int `database:"month"` + Year int `database:"year"` + Visible bool `database:"visible"` + time time.Time + GuildIDsRaw string `database:"guilds"` + GuildIDs []string } // Returns a readable Form of the date @@ -108,15 +110,75 @@ func (b birthdayEntry) Age() int { return b.Next().Year() - b.Year - 1 } +// ParseGuildIDs splits the guild IDs into a slice and stores them in b.GuildIDs. +func (b *birthdayEntry) ParseGuildIDs() { + b.GuildIDs = strings.Split(b.GuildIDsRaw, ",") +} + +// IsInGuild returns true if the guildID is in b.GuildIDs. +// If guildID is empty, IsInGuild returns true. +func (b birthdayEntry) IsInGuild(guildID string) bool { + if guildID == "" { + return true + } + return util.ContainsString(b.GuildIDs, guildID) +} + +// SetGuild sets the guildID in the birthday entry. +func (b *birthdayEntry) SetGuild(guildID string) { + b.GuildIDsRaw += guildID + b.ParseGuildIDs() +} + +// AddGuild adds the guildID to the birthday entry. +func (b *birthdayEntry) AddGuild(guildID string) error { + if util.ContainsString(b.GuildIDs, guildID) { + return nil + } else if len(b.GuildIDs) >= 3 { + return fmt.Errorf("this entry already has %d guilds", len(b.GuildIDs)) + } + b.GuildIDsRaw += "," + guildID + b.GuildIDsRaw = strings.Trim(b.GuildIDsRaw, ", ") + b.ParseGuildIDs() + return nil +} + +// IsEqual returns true if b and b2 are equal. +// +// That is, if all of the following are true +// 1. They have the same user ID. +// 2. They are on the same date. +// 3. They have the same visibility. +// 4. They have the same guilds in (any order). +func (b birthdayEntry) IsEqual(b2 birthdayEntry) bool { + if b.ID != b2.ID || b.Day != b2.Day || b.Month != b2.Month || b.Year != b2.Year || b.Visible != b2.Visible { + return false + } + + // check for same guilds in any order + for _, guildID := range b.GuildIDs { + if !util.ContainsString(b2.GuildIDs, guildID) { + return false + } + } + for _, guildID := range b2.GuildIDs { + if !util.ContainsString(b.GuildIDs, guildID) { + return false + } + } + return true +} + // getBirthday copies all birthday fields into the struct pointed at by b. // // If the user from b.ID is not found it returns sql.ErrNoRows. func (cmd birthdayBase) getBirthday(b *birthdayEntry) (err error) { - row := database.QueryRow("SELECT day,month,year,visible FROM birthdays WHERE id=?", b.ID) - err = row.Scan(&b.Day, &b.Month, &b.Year, &b.Visible) + row := database.QueryRow("SELECT day,month,year,visible,guilds FROM birthdays WHERE id=?", b.ID) + err = row.Scan(&b.Day, &b.Month, &b.Year, &b.Visible, &b.GuildIDsRaw) if err != nil { return err } + b.ParseGuildIDs() return b.ParseTime() } @@ -127,27 +189,36 @@ func (cmd birthdayBase) hasBirthday(id uint64) (hasBirthday bool, err error) { } // setBirthday inserts a new database entry with the values from b. -func (cmd birthdayBase) setBirthday(b birthdayEntry) error { - _, err := database.Exec("INSERT INTO birthdays(id,day,month,year,visible) VALUES(?,?,?,?,?);", b.ID, b.Day, b.Month, b.Year, b.Visible) +func (cmd birthdayBase) setBirthday(b *birthdayEntry) (err error) { + b.SetGuild(cmd.Interaction.GuildID) + _, err = database.Exec("INSERT INTO birthdays(id,day,month,year,visible,guilds) VALUES(?,?,?,?,?);", b.ID, b.Day, b.Month, b.Year, b.Visible, b.GuildIDsRaw) return err } // updateBirthday updates an existing database entry with the values from b. -func (cmd birthdayBase) updateBirthday(b birthdayEntry) (before birthdayEntry, err error) { - err = b.ParseTime() - if err != nil { - return birthdayEntry{}, err - } +func (cmd birthdayBase) updateBirthday(b *birthdayEntry) (before birthdayEntry, err error) { before.ID = b.ID if err = cmd.getBirthday(&before); err != nil { return birthdayEntry{}, fmt.Errorf("trying to get old birthday: %v", err) } + b.GuildIDsRaw = before.GuildIDsRaw + b.ParseGuildIDs() + + err = b.AddGuild(cmd.Interaction.GuildID) + if err != nil { + return birthdayEntry{}, fmt.Errorf("adding guild '%s' to birthday entry: %v", cmd.Interaction.GuildID, err) + } + + // early return if nothing changed + if b.IsEqual(before) { + return before, nil + } var ( updateNames []string updateVars []any oldV reflect.Value = reflect.ValueOf(before) - v reflect.Value = reflect.ValueOf(b) + v reflect.Value = reflect.ValueOf(*b) ) for i := 0; i < v.NumField(); i++ { var ( @@ -160,11 +231,11 @@ func (cmd birthdayBase) updateBirthday(b birthdayEntry) (before birthdayEntry, e continue } + tag := v.Type().Field(i).Tag.Get("database") + if tag == "" { + continue + } if f.Interface() != oldF.Interface() { - tag := v.Type().Field(i).Tag.Get("database") - if tag == "" { - continue - } updateNames = append(updateNames, tag) updateVars = append(updateVars, f.Interface()) } @@ -193,8 +264,8 @@ func (cmd birthdayBase) removeBirthday(id uint64) (birthdayEntry, error) { return b, err } -// getBirthdaysMonth return a sorted slice of birthday entries that matches the given month. -func (cmd birthdayBase) getBirthdaysMonth(month int) (birthdays []birthdayEntry, err error) { +// getBirthdaysMonth return a sorted slice of birthday entries that matches the given guildID and month. +func (cmd birthdayBase) getBirthdaysMonth(guildID string, month int) (birthdays []birthdayEntry, err error) { var numOfEntries int64 err = database.QueryRow("SELECT COUNT(*) FROM birthdays WHERE month=?", month).Scan(&numOfEntries) if err != nil { @@ -206,7 +277,7 @@ func (cmd birthdayBase) getBirthdaysMonth(month int) (birthdays []birthdayEntry, return birthdays, nil } - rows, err := database.Query("SELECT id,day,year,visible FROM birthdays WHERE month=?", month) + rows, err := database.Query("SELECT id,day,year,visible,guilds FROM birthdays WHERE month=?", month) if err != nil { return birthdays, err } @@ -214,12 +285,13 @@ func (cmd birthdayBase) getBirthdaysMonth(month int) (birthdays []birthdayEntry, for rows.Next() { b := birthdayEntry{Month: month} - err = rows.Scan(&b.ID, &b.Day, &b.Year, &b.Visible) + err = rows.Scan(&b.ID, &b.Day, &b.Year, &b.Visible, &b.GuildIDsRaw) if err != nil { return birthdays, err } + b.ParseGuildIDs() - if !b.Visible { + if !b.Visible || !b.IsInGuild(guildID) { continue } @@ -238,8 +310,8 @@ func (cmd birthdayBase) getBirthdaysMonth(month int) (birthdays []birthdayEntry, return birthdays, nil } -// getBirthdaysDate return a slice of birthday entries that matches the given date. -func getBirthdaysDate(day int, month int) (birthdays []birthdayEntry, err error) { +// getBirthdaysDate return a slice of birthday entries that matches the given guildID and date. +func getBirthdaysDate(guildID string, day int, month int) (birthdays []birthdayEntry, err error) { var numOfEntries int64 err = database.QueryRow("SELECT COUNT(*) FROM birthdays WHERE day=? AND month=?", day, month).Scan(&numOfEntries) if err != nil { @@ -251,7 +323,7 @@ func getBirthdaysDate(day int, month int) (birthdays []birthdayEntry, err error) return birthdays, nil } - rows, err := database.Query("SELECT id,year,visible FROM birthdays WHERE day=? AND month=?", day, month) + rows, err := database.Query("SELECT id,year,visible,guilds FROM birthdays WHERE day=? AND month=?", day, month) if err != nil { return birthdays, err } @@ -259,12 +331,13 @@ func getBirthdaysDate(day int, month int) (birthdays []birthdayEntry, err error) for rows.Next() { b := birthdayEntry{Day: day, Month: month} - err = rows.Scan(&b.ID, &b.Year, &b.Visible) + err = rows.Scan(&b.ID, &b.Year, &b.Visible, &b.GuildIDsRaw) if err != nil { return birthdays, err } + b.ParseGuildIDs() - if !b.Visible { + if !b.Visible || !b.IsInGuild(guildID) { continue } diff --git a/modules/birthday/handleCheck.go b/modules/birthday/handleCheck.go index 30a7c8c..7969ca2 100644 --- a/modules/birthday/handleCheck.go +++ b/modules/birthday/handleCheck.go @@ -35,14 +35,6 @@ func Check(s *discordgo.Session) { defer rows.Close() now := time.Now() - birthdays, err := getBirthdaysDate(now.Day(), int(now.Month())) - if err != nil { - log.Printf("Error on getting todays birthdays from database: %v\n", err) - } - e, n := birthdayAnnounceEmbed(s, birthdays) - if n <= 0 { - return - } for rows.Next() { err = rows.Scan(&guildID, &channelID) @@ -50,6 +42,9 @@ func Check(s *discordgo.Session) { log.Printf("Error on scanning birthday channel ID from database %v\n", err) continue } + if channelID == 0 { + continue + } channel, err := s.Channel(fmt.Sprint(channelID)) if err != nil { @@ -61,6 +56,15 @@ func Check(s *discordgo.Session) { return } + birthdays, err := getBirthdaysDate(fmt.Sprint(guildID), now.Day(), int(now.Month())) + if err != nil { + log.Printf("Error on getting todays birthdays from guild %s from database: %v\n", fmt.Sprint(guildID), err) + } + e, n := birthdayAnnounceEmbed(s, fmt.Sprint(guildID), birthdays) + if n <= 0 { + return + } + // announce _, err = s.ChannelMessageSendEmbed(channel.ID, e) if err != nil { @@ -71,7 +75,7 @@ func Check(s *discordgo.Session) { // birthdayAnnounceEmbed returns the embed, that contains all birthdays and 'n' as the number of // birthdays, which is always len(b) -func birthdayAnnounceEmbed(s *discordgo.Session, b []birthdayEntry) (e *discordgo.MessageEmbed, n int) { +func birthdayAnnounceEmbed(s *discordgo.Session, guildID string, b []birthdayEntry) (e *discordgo.MessageEmbed, n int) { var title, fValue string switch len(b) { @@ -85,14 +89,17 @@ func birthdayAnnounceEmbed(s *discordgo.Session, b []birthdayEntry) (e *discordg } for _, b := range b { - mention := fmt.Sprintf("<@%d>", b.ID) + member := util.IsGuildMember(s, guildID, fmt.Sprint(b.ID)) + if member == nil { + continue + } if b.Year == 0 { - fValue += fmt.Sprintf("%s\n", mention) + fValue += fmt.Sprintf("%s\n", member.Mention()) } else { format := lang.Get(tp+"msg.announce.with_age", lang.FallbackLang()) format += "\n" - fValue += fmt.Sprintf(format, mention, fmt.Sprint(b.Age())) + fValue += fmt.Sprintf(format, member.Mention(), fmt.Sprint(b.Age())) } } diff --git a/modules/birthday/handlerSubcommandAnnounce.go b/modules/birthday/handlerSubcommandAnnounce.go index 82d5a33..2f50bc0 100644 --- a/modules/birthday/handlerSubcommandAnnounce.go +++ b/modules/birthday/handlerSubcommandAnnounce.go @@ -37,14 +37,14 @@ func (cmd Chat) subcommandAnnounce() subcommandAnnounce { func (cmd subcommandAnnounce) handler() { now := time.Now() - b, err := getBirthdaysDate(now.Day(), int(now.Month())) + b, err := getBirthdaysDate(cmd.Interaction.GuildID, now.Day(), int(now.Month())) if err != nil { - log.Printf("Error on announce birthday: %v\n", err) + log.Printf("Error on announce birthday in guild %s: %v\n", cmd.Interaction.GuildID, err) cmd.ReplyError() return } - e, n := birthdayAnnounceEmbed(cmd.Session, b) + e, n := birthdayAnnounceEmbed(cmd.Session, cmd.Interaction.GuildID, b) if n <= 0 { cmd.ReplyHiddenEmbed(e) diff --git a/modules/birthday/handlerSubcommandList.go b/modules/birthday/handlerSubcommandList.go index 238804f..40b4eb5 100644 --- a/modules/birthday/handlerSubcommandList.go +++ b/modules/birthday/handlerSubcommandList.go @@ -49,9 +49,9 @@ func (cmd subcommandList) handler() { } month := int(cmd.month.IntValue()) - birthdays, err := cmd.getBirthdaysMonth(month) + birthdays, err := cmd.getBirthdaysMonth(cmd.Interaction.GuildID, month) if err != nil { - log.Printf("Error on get birthdays by month: %v\n", err) + log.Printf("Error on get birthdays by month from guild %s: %v\n", cmd.Interaction.GuildID, err) cmd.ReplyError() return } diff --git a/modules/birthday/handlerSubcommandSet.go b/modules/birthday/handlerSubcommandSet.go index bd97d93..02e5fca 100644 --- a/modules/birthday/handlerSubcommandSet.go +++ b/modules/birthday/handlerSubcommandSet.go @@ -121,7 +121,7 @@ func (cmd subcommandSet) interactionHandler() { return } - b := birthdayEntry{ + b := &birthdayEntry{ ID: authorID, Day: int(cmd.day.IntValue()), Month: int(cmd.month.IntValue()), @@ -194,13 +194,13 @@ func (cmd subcommandSet) interactionHandler() { } // seperate handler for an update of the birthday -func (cmd subcommandSet) handleUpdate(b birthdayEntry, e *discordgo.MessageEmbed) error { +func (cmd subcommandSet) handleUpdate(b *birthdayEntry, e *discordgo.MessageEmbed) (err error) { before, err := cmd.updateBirthday(b) if err != nil { return err } - if b == before { + if b.IsEqual(before) { var age string if b.Year > 0 { age = fmt.Sprintf(" (%d)", b.Age()+1) diff --git a/util/discord.go b/util/discord.go index e0b6ef1..7b89314 100644 --- a/util/discord.go +++ b/util/discord.go @@ -437,3 +437,24 @@ func MessageComplexWebhookEdit(src any) *discordgo.WebhookEdit { panic("Given source type is not supported: " + fmt.Sprintf("%T", src)) } } + +// IsGuildMember returns the given user as a member of the given guild. If the +// user is not a member of the guild IsGuildMember returns nil. +func IsGuildMember(s *discordgo.Session, guildID, userID string) (member *discordgo.Member) { + member, err := s.State.Member(guildID, userID) + if err == nil { + return member + } else if err != discordgo.ErrStateNotFound { + log.Printf("ERROR: Failed to get guild member from cache (G: %s, U: %s): %v\n", guildID, userID, err) + } + member, err = s.GuildMember(guildID, userID) + if err == nil { + return member + } + + var restErr *discordgo.RESTError + if !errors.As(err, &restErr) || restErr.Response.StatusCode != http.StatusNotFound { + log.Printf("ERROR: Failed to get guild member from API (G: %s, U: %s): %v\n", guildID, userID, err) + } + return nil +}