forked from modmail-dev/Modmail
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bot.py
1815 lines (1535 loc) · 67.8 KB
/
bot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Release version of this bot; logged by `startup`, parsed by the `version`
# property, and checked for a "dev" marker in `on_ready`.
__version__ = "4.0.1"
import asyncio
import copy
import hashlib
import logging
import os
import re
import string
import struct
import sys
import platform
import typing
from datetime import datetime, timezone
from subprocess import PIPE
from types import SimpleNamespace
import discord
import isodate
from aiohttp import ClientSession, ClientResponseError
from discord.ext import commands, tasks
from discord.ext.commands.view import StringView
from emoji import UNICODE_EMOJI
from pkg_resources import parse_version
try:
# noinspection PyUnresolvedReferences
from colorama import init
init()
except ImportError:
pass
from core import checks
from core.changelog import Changelog
from core.clients import ApiClient, MongoDBClient, PluginDatabaseClient
from core.config import ConfigManager
from core.models import (
DMDisabled,
HostingMethod,
InvalidConfigError,
PermissionLevel,
SafeFormatter,
configure_logging,
getLogger,
)
from core.thread import ThreadManager
from core.time import human_timedelta
from core.utils import extract_block_timestamp, normalize_alias, parse_alias, truncate, tryint
logger = getLogger(__name__)

# Scratch directory next to this file; the bot's log file is created here
# (see ``log_file_name`` in ``ModmailBot.__init__``).
temp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp")
# ``exist_ok=True`` replaces the check-then-create pair (``os.path.exists`` +
# ``os.mkdir``), which races if two processes start from the same checkout.
os.makedirs(temp_dir, exist_ok=True)

if sys.platform == "win32":
    try:
        asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
    except AttributeError:
        logger.error("Failed to use WindowsProactorEventLoopPolicy.", exc_info=True)
class ModmailBot(commands.Bot):
def __init__(self):
    """
    Build the bot: discord client state, config, thread manager and logging.

    Statement order matters: ``self.token`` (used for the log file name) reads
    ``self.config``, so the config manager must exist before it is accessed.
    """
    # Request every gateway intent; ``run`` retries with a narrower set if the
    # privileged intents are refused.
    intents = discord.Intents.all()
    super().__init__(command_prefix=None, intents=intents)  # implemented in `get_prefix`
    self.session = None  # aiohttp ClientSession, created in `run`
    self._api = None  # database client, created lazily by the `api` property
    self.formatter = SafeFormatter()
    # Extensions loaded on gateway connect (see `load_extensions`).
    self.loaded_cogs = ["cogs.modmail", "cogs.plugins", "cogs.utility"]
    self._connected = None  # asyncio.Event created in `run`, set at the end of `on_connect`
    self.start_time = discord.utils.utcnow()  # reference point for `uptime`
    self._started = False  # set by `on_ready` so later ready events skip startup work

    self.config = ConfigManager(self)
    self.config.populate_cache()

    self.threads = ThreadManager(self)

    # Log file named after the first dot-separated segment of the bot token.
    self.log_file_name = os.path.join(temp_dir, f"{self.token.split('.')[0]}.log")
    self._configure_logging()

    self.plugin_db = PluginDatabaseClient(self)  # Deprecated
    self.startup()
def _resolve_snippet(self, name: str) -> typing.Optional[str]:
"""
Get actual snippet names from direct aliases to snippets.
If the provided name is a snippet, it's returned unchanged.
If there is an alias by this name, it is parsed to see if it
refers only to a snippet, in which case that snippet name is
returned.
If no snippets were found, None is returned.
"""
if name in self.snippets:
return name
try:
(command,) = parse_alias(self.aliases[name])
except (KeyError, ValueError):
# There is either no alias by this name present or the
# alias has multiple steps.
pass
else:
if command in self.snippets:
return command
@property
def uptime(self) -> str:
    """Time since ``start_time`` as text, e.g. ``"3h 4m 5s"`` or ``"2d 3h 4m 5s"``."""
    elapsed = int((discord.utils.utcnow() - self.start_time).total_seconds())
    days, rest = divmod(elapsed, 86400)
    hours, rest = divmod(rest, 3600)
    minutes, seconds = divmod(rest, 60)
    # The day component is only shown once at least one full day has passed.
    template = ("{d}d " if days else "") + "{h}h {m}m {s}s"
    return self.formatter.format(template, d=days, h=hours, m=minutes, s=seconds)
@property
def hosting_method(self) -> HostingMethod:
    """
    Guess how the bot process is hosted by inspecting environment variables.

    Checks run in a fixed priority order and the first match wins;
    ``HostingMethod.OTHER`` is the fallback.
    """
    env = os.environ
    if ".heroku" in env.get("PYTHONHOME", ""):
        return HostingMethod.HEROKU

    markers = (
        ("pm_id", HostingMethod.PM2),
        ("INVOCATION_ID", HostingMethod.SYSTEMD),
        ("USING_DOCKER", HostingMethod.DOCKER),
        ("TERM", HostingMethod.SCREEN),
    )
    for variable, method in markers:
        if env.get(variable):
            return method
    return HostingMethod.OTHER
def startup(self):
    """Log the ASCII-art banner, bot version, authors and discord.py version."""
    banner = (
        "┌┬┐┌─┐┌┬┐┌┬┐┌─┐┬┬",
        "││││ │ │││││├─┤││",
        "┴ ┴└─┘─┴┘┴ ┴┴ ┴┴┴─┘",
    )
    logger.line()
    for row in banner:
        logger.info(row)
    logger.info("v%s", __version__)
    logger.info("Authors: kyb3r, fourjr, Taaku18")
    logger.line()

    logger.info("discord.py: v%s", discord.__version__)
    logger.line()
async def load_extensions(self):
    """
    Load every cog in ``loaded_cogs`` that is not already loaded.

    A cog that fails to load is logged with its traceback and skipped; the
    remaining cogs are still attempted.
    """
    for extension_name in self.loaded_cogs:
        # ``self.extensions`` is consulted afresh for each cog.
        if extension_name not in self.extensions:
            logger.debug("Loading %s.", extension_name)
            try:
                await self.load_extension(extension_name)
                logger.debug("Successfully loaded %s.", extension_name)
            except Exception:
                logger.exception("Failed to load %s.", extension_name)
            logger.line("debug")
def _configure_logging(self):
    """
    Configure logging from the ``log_level`` config value.

    Recognised levels (case-insensitive): CRITICAL, ERROR, WARNING, INFO, DEBUG.
    An unrecognised value is removed from config; the level then falls back to
    the value ``config.remove`` returns, resolved through the same table
    (``logging.INFO`` if it cannot be resolved either). Logging is configured to
    write to ``self.log_file_name`` with an integer level on both paths.
    """
    level_text = self.config["log_level"].upper()
    logging_levels = {
        "CRITICAL": logging.CRITICAL,
        "ERROR": logging.ERROR,
        "WARNING": logging.WARNING,
        "INFO": logging.INFO,
        "DEBUG": logging.DEBUG,
    }

    logger.line()

    log_level = logging_levels.get(level_text)
    if log_level is None:
        # ``config.remove`` hands back a replacement value — presumably the default
        # level *name* (NOTE(review): confirm in ConfigManager). Resolve it so
        # ``configure_logging`` receives an int level, as on the valid path.
        fallback = self.config.remove("log_level")
        if isinstance(fallback, int):
            log_level = fallback
        else:
            log_level = logging_levels.get(str(fallback).upper(), logging.INFO)
        logger.warning("Invalid logging level set: %s.", level_text)
        logger.warning("Using default logging level: %s.", logging.getLevelName(log_level))
    else:
        logger.info("Logging level: %s", level_text)

    logger.info("Log file: %s", self.log_file_name)
    configure_logging(self.log_file_name, log_level)
    logger.debug("Successfully configured logging.")
@property
def version(self):
    """The running bot version (``__version__``) as a comparable version object."""
    # NOTE(review): ``pkg_resources.parse_version`` is deprecated in setuptools;
    # ``packaging.version.Version`` is the usual replacement — confirm before switching.
    raw_version = __version__
    return parse_version(raw_version)
@property
def api(self) -> ApiClient:
    """
    Database API client, created on first access and cached in ``_api``.

    Only ``database_type`` "mongodb" (case-insensitive) is supported and yields a
    ``MongoDBClient``. Any other value is logged as critical and raises RuntimeError.
    """
    if self._api is None:
        database_type = self.config["database_type"]
        if database_type.lower() == "mongodb":
            self._api = MongoDBClient(self)
        else:
            logger.critical("Invalid database type.")
            # Include the offending value so the traceback is self-explanatory;
            # a bare ``RuntimeError`` gave no hint of what was wrong.
            raise RuntimeError(f"Invalid database type: {database_type!r}.")
    return self._api
@property
def db(self):
    """
    The ``db`` attribute of the API client.

    Deprecated; kept for backwards compatibility.
    """
    client = self.api
    return client.db
async def get_prefix(self, message=None):
    """
    Return the accepted command prefixes.

    ``message`` is unused: the configured prefix and both forms of the bot's
    mention (``<@id>`` and ``<@!id>``, each followed by a space) are always accepted.
    """
    configured = self.prefix
    bot_id = self.user.id
    return [configured, f"<@{bot_id}> ", f"<@!{bot_id}> "]
def run(self):
    """
    Start the bot and block until it stops.

    ``runner`` logs in and connects. If Discord rejects the privileged intents
    requested in ``__init__``, it retries once with only the members and
    message-content intents. Login and startup failures are logged as critical
    rather than raised. Afterwards ``_cancel_tasks`` runs in a second event loop.
    """

    async def runner():
        async with self:
            self._connected = asyncio.Event()
            # NOTE(review): aiohttp deprecates the ``loop`` argument of ClientSession;
            # confirm against the pinned aiohttp version.
            self.session = ClientSession(loop=self.loop)
            try:
                retry_intents = False
                try:
                    await self.start(self.token)
                except discord.PrivilegedIntentsRequired:
                    retry_intents = True
                if retry_intents:
                    # Tear down the failed connection before reconnecting.
                    await self.http.close()
                    if self.ws is not None and self.ws.open:
                        await self.ws.close(code=1000)
                    self._ready.clear()

                    intents = discord.Intents.default()
                    intents.members = True
                    intents.message_content = True
                    # Try again with members intent
                    self._connection._intents = intents
                    logger.warning(
                        "Attempting to login with only the server members and message content privileged intent. Some plugins might not work correctly."
                    )
                    await self.start(self.token)
            except discord.PrivilegedIntentsRequired:
                logger.critical(
                    "Privileged intents are not explicitly granted in the discord developers dashboard."
                )
            except discord.LoginFailure:
                logger.critical("Invalid token")
            except Exception:
                logger.critical("Fatal exception", exc_info=True)
            finally:
                if self.session:
                    await self.session.close()
                if not self.is_closed():
                    await self.close()

    async def _cancel_tasks():
        # Cancel every unfinished task except the one running this coroutine,
        # then report any that finished with an exception.
        # NOTE(review): this runs under a fresh ``asyncio.run`` loop, and
        # ``asyncio.all_tasks()`` only returns tasks of the running loop, so
        # tasks left over from ``runner``'s loop are presumably not seen — verify.
        async with self:
            task_retriever = asyncio.all_tasks
            loop = self.loop
            # ``cancel_tasks_coro`` is bound in the ``finally`` below before this runs.
            tasks = {t for t in task_retriever() if not t.done() and t.get_coro() != cancel_tasks_coro}

            if not tasks:
                return

            logger.info("Cleaning up after %d tasks.", len(tasks))
            for task in tasks:
                task.cancel()

            await asyncio.gather(*tasks, return_exceptions=True)
            logger.info("All tasks finished cancelling.")

            for task in tasks:
                try:
                    if task.exception() is not None:
                        loop.call_exception_handler(
                            {
                                "message": "Unhandled exception during Client.run shutdown.",
                                "exception": task.exception(),
                                "task": task,
                            }
                        )
                except (asyncio.InvalidStateError, asyncio.CancelledError):
                    pass

    try:
        asyncio.run(runner(), debug=bool(os.getenv("DEBUG_ASYNCIO")))
    except (KeyboardInterrupt, SystemExit):
        logger.info("Received signal to terminate bot and event loop.")
    finally:
        logger.info("Cleaning up tasks.")

        try:
            cancel_tasks_coro = _cancel_tasks()
            asyncio.run(cancel_tasks_coro)
        finally:
            logger.info("Closing the event loop.")
@property
def bot_owner_ids(self):
    """
    IDs of every user treated as a bot owner.

    The returned set is the union of:
    - the comma-separated ``owners`` config value (blank entries are ignored),
    - ``owner_id`` when it is set on the bot,
    - every ID listed under the ``OWNER`` level in ``level_permissions``.
    """
    # Always start from a set, even when ``owners`` is unset, so the ``.add``
    # calls below cannot hit ``None``.
    owner_ids = set()

    raw_owners = self.config["owners"]
    if raw_owners is not None:
        owner_ids.update(int(part) for part in str(raw_owners).split(",") if part.strip())

    if self.owner_id is not None:
        owner_ids.add(self.owner_id)

    for perm in self.config["level_permissions"].get(PermissionLevel.OWNER.name, []):
        owner_ids.add(int(perm))

    return owner_ids
async def is_owner(self, user: discord.User) -> bool:
    """Whether ``user`` is an owner: configured owner IDs first, then the library's own check."""
    if user.id not in self.bot_owner_ids:
        return await super().is_owner(user)
    return True
@property
def log_channel(self) -> typing.Optional[discord.TextChannel]:
    """
    The channel logs are posted to.

    Resolution order:
    1. ``log_channel_id`` from config when it resolves to a channel; a value
       that is not an integer or does not resolve is removed from config.
    2. The first channel of ``main_category``, whose ID is then stored as
       ``log_channel_id`` in config.
    3. None, after logging how to configure one.
    """
    configured_id = self.config["log_channel_id"]
    if configured_id is not None:
        try:
            found = self.get_channel(int(configured_id))
        except ValueError:
            found = None
        if found is not None:
            return found
        logger.debug("LOG_CHANNEL_ID was invalid, removed.")
        self.config.remove("log_channel_id")

    if self.main_category is not None:
        try:
            fallback = self.main_category.channels[0]
            self.config["log_channel_id"] = fallback.id
            logger.warning("No log channel set, setting #%s to be the log channel.", fallback.name)
            return fallback
        except IndexError:
            pass

    logger.warning(
        "No log channel set, set one with `%ssetup` or `%sconfig set log_channel_id <id>`.",
        self.prefix,
        self.prefix,
    )
    return None
@property
def mention_channel(self):
    """
    Channel resolved from ``mention_channel_id``, falling back to ``log_channel``.

    A configured value that is not an integer or does not resolve to a channel
    is removed from config before falling back.
    """
    raw_id = self.config["mention_channel_id"]
    if raw_id is None:
        return self.log_channel

    try:
        resolved = self.get_channel(int(raw_id))
    except ValueError:
        resolved = None

    if resolved is None:
        logger.debug("MENTION_CHANNEL_ID was invalid, removed.")
        self.config.remove("mention_channel_id")
        return self.log_channel
    return resolved
@property
def update_channel(self):
    """
    Channel resolved from ``update_channel_id``, falling back to ``log_channel``.

    A configured value that is not an integer or does not resolve to a channel
    is removed from config before falling back.
    """
    raw_id = self.config["update_channel_id"]
    if raw_id is None:
        return self.log_channel

    try:
        resolved = self.get_channel(int(raw_id))
    except ValueError:
        resolved = None

    if resolved is None:
        logger.debug("UPDATE_CHANNEL_ID was invalid, removed.")
        self.config.remove("update_channel_id")
        return self.log_channel
    return resolved
async def wait_for_connected(self) -> None:
    """
    Block until the bot is fully usable.

    Waits, in this order, for the client's ready state, the ``_connected``
    event (set at the end of ``on_connect``) and the config manager's ready state.
    """
    stages = (
        lambda: self.wait_until_ready(),
        lambda: self._connected.wait(),
        lambda: self.config.wait_until_ready(),
    )
    for begin_stage in stages:
        await begin_stage()
@property
def snippets(self) -> typing.Dict[str, str]:
    """Snippet mapping stored under the ``snippets`` config key."""
    config = self.config
    return config["snippets"]

@property
def aliases(self) -> typing.Dict[str, str]:
    """Alias mapping stored under the ``aliases`` config key."""
    config = self.config
    return config["aliases"]

@property
def auto_triggers(self) -> typing.Dict[str, str]:
    """Auto-trigger mapping stored under the ``auto_triggers`` config key."""
    config = self.config
    return config["auto_triggers"]
@property
def token(self) -> str:
    """
    The bot token from config.

    When no token is configured, a critical message is logged and the process
    exits via ``sys.exit(0)``.
    """
    configured = self.config["token"]
    if configured is not None:
        return configured
    logger.critical("TOKEN must be set, set this as bot token found on the Discord Developer Portal.")
    # NOTE(review): exit status 0 reports success to a process supervisor;
    # confirm that is intended for this misconfiguration.
    sys.exit(0)
@property
def guild_id(self) -> typing.Optional[int]:
    """
    The ``guild_id`` config value as an int.

    Returns None when it is unset (debug-logged) or not an integer; an invalid
    value is removed from config and logged as critical.
    """
    raw = self.config["guild_id"]
    if raw is None:
        logger.debug("No GUILD_ID set.")
        return None

    try:
        return int(str(raw))
    except ValueError:
        self.config.remove("guild_id")
        logger.critical("Invalid GUILD_ID set.")
        return None
@property
def guild(self) -> typing.Optional[discord.Guild]:
    """
    The guild that the bot is serving
    (the server where users message it from).

    Looked up among ``self.guilds`` by ``guild_id``; None when not found.
    """
    joined_guilds = self.guilds
    wanted_id = self.guild_id
    return discord.utils.get(joined_guilds, id=wanted_id)
@property
def modmail_guild(self) -> typing.Optional[discord.Guild]:
    """
    The guild that the bot is operating in
    (where the bot is creating threads).

    Uses ``modmail_guild_id`` when it matches a guild the bot is in. Falls back
    to ``guild`` when the setting is unset, or after removing a value that is
    not an integer or matches no joined guild (logged as critical).
    """
    raw_id = self.config["modmail_guild_id"]
    if raw_id is None:
        return self.guild

    try:
        staff_guild = discord.utils.get(self.guilds, id=int(raw_id))
    except ValueError:
        staff_guild = None
    if staff_guild is not None:
        return staff_guild

    self.config.remove("modmail_guild_id")
    logger.critical("Invalid MODMAIL_GUILD_ID set.")
    return self.guild
@property
def using_multiple_server_setup(self) -> bool:
    """True when ``modmail_guild`` (threads) differs from ``guild`` (users)."""
    staff_guild = self.modmail_guild
    user_guild = self.guild
    return staff_guild != user_guild
@property
def main_category(self) -> typing.Optional[discord.CategoryChannel]:
    """
    The category in ``modmail_guild`` used as the main modmail category.

    Uses ``main_category_id`` when it resolves; an invalid value is removed from
    config. Otherwise a category named "Modmail" is adopted and its ID stored as
    ``main_category_id``. None when there is no ``modmail_guild`` or no match.
    """
    if self.modmail_guild is None:
        return None

    configured_id = self.config["main_category_id"]
    if configured_id is not None:
        try:
            match = discord.utils.get(self.modmail_guild.categories, id=int(configured_id))
        except ValueError:
            match = None
        if match is not None:
            return match
        self.config.remove("main_category_id")
        logger.debug("MAIN_CATEGORY_ID was invalid, removed.")

    by_name = discord.utils.get(self.modmail_guild.categories, name="Modmail")
    if by_name is None:
        return None
    self.config["main_category_id"] = by_name.id
    logger.debug('No main category set explicitly, setting category "Modmail" as the main category.')
    return by_name
@property
def blocked_users(self) -> typing.Dict[str, str]:
    """Block reasons keyed by user ID string (``blocked`` config key)."""
    config = self.config
    return config["blocked"]

@property
def blocked_roles(self) -> typing.Dict[str, str]:
    """Block reasons keyed by role ID string (``blocked_roles`` config key)."""
    config = self.config
    return config["blocked_roles"]

@property
def blocked_whitelisted_users(self) -> typing.List[str]:
    """User ID strings exempt from blocking (``blocked_whitelist`` config key)."""
    config = self.config
    return config["blocked_whitelist"]
@property
def prefix(self) -> str:
    """The command prefix (``prefix`` config value), always as a string."""
    raw_prefix = self.config["prefix"]
    return str(raw_prefix)

@property
def mod_color(self) -> int:
    """The ``mod_color`` config value."""
    colour = self.config.get("mod_color")
    return colour

@property
def recipient_color(self) -> int:
    """The ``recipient_color`` config value."""
    colour = self.config.get("recipient_color")
    return colour

@property
def main_color(self) -> int:
    """The ``main_color`` config value."""
    colour = self.config.get("main_color")
    return colour

@property
def error_color(self) -> int:
    """The ``error_color`` config value (used for error embeds in this file)."""
    colour = self.config.get("error_color")
    return colour
def command_perm(self, command_name: str) -> PermissionLevel:
    """
    Permission level required to run ``command_name``.

    An entry in ``override_command_level`` takes precedence. An override that
    is not a ``PermissionLevel`` member name (case-insensitive) is logged,
    dropped from config and ignored. Otherwise the level comes from the first
    check on the command carrying a ``permission_level`` attribute.
    ``PermissionLevel.INVALID`` is returned for unknown commands and for
    commands without such a check.
    """
    override = self.config["override_command_level"].get(command_name)
    if override is not None:
        try:
            return PermissionLevel[override.upper()]
        except KeyError:
            logger.warning("Invalid override_command_level for command %s.", command_name)
            self.config["override_command_level"].pop(command_name)

    command = self.get_command(command_name)
    if command is None:
        logger.debug("Command %s not found.", command_name)
        return PermissionLevel.INVALID

    level = None
    for check in command.checks:
        if hasattr(check, "permission_level"):
            level = check.permission_level
            break

    if level is None:
        logger.debug("Command %s does not have a permission level.", command_name)
        return PermissionLevel.INVALID
    return level
async def on_connect(self):
    """
    Gateway-connect hook.

    Validates the database connection (closing the bot if validation raises),
    then refreshes config, sets up indexes, loads extensions and finally sets
    the ``_connected`` event.
    """
    try:
        await self.api.validate_database_connection()
    except Exception:
        logger.debug("Logging out due to failed database connection.")
        return await self.close()

    logger.debug("Connected to gateway.")
    for setup_step in (self.config.refresh, self.api.setup_indexes, self.load_extensions):
        await setup_step()
    self._connected.set()
async def on_ready(self):
    """
    Bot startup, sets uptime.

    Only the first ready event does the startup work. When ``_started`` is
    already set, the event is logged as a restart and ignored. Startup work:
    wait for ``on_connect`` and the config (closing the bot if ``guild`` does not
    resolve), log identity/configuration details, populate the thread cache,
    re-schedule thread closures stored in config, close open logs whose channel
    no longer exists, warn about unrelated guilds, and start the
    ``post_metadata`` / ``autoupdate`` tasks.
    """
    # Wait until config cache is populated with stuff from db and on_connect ran
    await self.wait_for_connected()

    if self.guild is None:
        logger.error("Logging out due to invalid GUILD_ID.")
        return await self.close()

    if self._started:
        # Bot has started before
        logger.line()
        logger.warning("Bot restarted due to internal discord reloading.")
        logger.line()
        return

    logger.line()
    logger.debug("Client ready.")
    logger.info("Logged in as: %s", self.user)
    logger.info("Bot ID: %s", self.user.id)
    owners = ", ".join(
        getattr(self.get_user(owner_id), "name", str(owner_id)) for owner_id in self.bot_owner_ids
    )
    logger.info("Owners: %s", owners)
    logger.info("Prefix: %s", self.prefix)
    logger.info("Guild Name: %s", self.guild.name)
    logger.info("Guild ID: %s", self.guild.id)
    if self.using_multiple_server_setup:
        logger.info("Receiving guild ID: %s", self.modmail_guild.id)
    logger.line()

    if "dev" in __version__:
        logger.warning(
            "You are running a developmental version. This should not be used in production. (v%s)",
            __version__,
        )
        logger.line()

    await self.threads.populate_cache()

    # Re-schedule thread closures that were pending when the bot last stopped.
    closures = self.config["closures"]
    logger.info("There are %d thread(s) pending to be closed.", len(closures))
    logger.line()

    for recipient_id, items in tuple(closures.items()):
        after = (
            datetime.fromisoformat(items["time"]).astimezone(timezone.utc) - discord.utils.utcnow()
        ).total_seconds()
        if after <= 0:
            logger.debug("Closing thread for recipient %s.", recipient_id)
            after = 0
        else:
            logger.debug("Thread for recipient %s will be closed after %s seconds.", recipient_id, after)

        thread = await self.threads.find(recipient_id=int(recipient_id))

        if not thread:
            # If the channel is deleted
            logger.debug("Failed to close thread for recipient %s.", recipient_id)
            self.config["closures"].pop(recipient_id)
            await self.config.update()
            continue

        await thread.close(
            closer=await self.get_or_fetch_user(items["closer_id"]),
            after=after,
            silent=items["silent"],
            delete_channel=items["delete_channel"],
            message=items["message"],
            auto_close=items.get("auto_close", False),
        )

    # Close open logs whose channel can no longer be resolved.
    for log in await self.api.get_open_logs():
        if self.get_channel(int(log["channel_id"])) is None:
            logger.debug("Unable to resolve thread with channel %s.", log["channel_id"])
            log_data = await self.api.post_log(
                log["channel_id"],
                {
                    "open": False,
                    "title": None,
                    "closed_at": str(discord.utils.utcnow()),
                    "close_message": "Channel has been deleted, no closer found.",
                    "closer": {
                        "id": str(self.user.id),
                        "name": self.user.name,
                        "discriminator": self.user.discriminator,
                        "avatar_url": self.user.display_avatar.url,
                        "mod": True,
                    },
                },
            )
            if log_data:
                logger.debug("Successfully closed thread with channel %s.", log["channel_id"])
            else:
                logger.debug("Failed to close thread with channel %s, skipping.", log["channel_id"])

    # Build the set once rather than per guild: ``guild`` and ``modmail_guild``
    # are property lookups, and ``modmail_guild`` can log and edit config.
    known_guilds = {self.guild, self.modmail_guild}
    other_guilds = [guild for guild in self.guilds if guild not in known_guilds]
    if other_guilds:
        logger.warning(
            "The bot is in more servers other than the main and staff server. "
            "This may cause data compromise (%s).",
            ", ".join(str(guild.name) for guild in other_guilds),
        )
        logger.warning("If the external servers are valid, you may ignore this message.")

    self.post_metadata.start()
    self.autoupdate.start()
    self._started = True
async def convert_emoji(self, name: str) -> str:
    """
    Resolve ``name`` to something usable as a reaction.

    Strings found in ``UNICODE_EMOJI["en"]`` are returned unchanged. Anything
    else has leading/trailing colons stripped and is converted with
    ``commands.EmojiConverter`` against ``modmail_guild``.

    Raises:
        commands.BadArgument: the converter could not resolve ``name`` (logged first).
    """
    # Minimal stand-in for a command Context: only ``bot`` and ``guild`` are set.
    fake_ctx = SimpleNamespace(bot=self, guild=self.modmail_guild)
    converter = commands.EmojiConverter()

    if name in UNICODE_EMOJI["en"]:
        return name

    try:
        return await converter.convert(fake_ctx, name.strip(":"))
    except commands.BadArgument as e:
        logger.warning("%s is not a valid emoji. %s.", name, e)
        raise
async def get_or_fetch_user(self, id: int) -> discord.User:
    """
    Retrieve a User based on their ID.

    The client cache (``get_user``) is consulted first; the API is queried with
    ``fetch_user`` only when the cache has no entry.
    """
    cached = self.get_user(id)
    if cached:
        return cached
    return await self.fetch_user(id)
async def retrieve_emoji(self) -> typing.Tuple[str, str]:
    """
    Return the resolved ``(sent_emoji, blocked_emoji)`` pair.

    The literal value ``"disable"`` passes through untouched. A configured emoji
    that fails conversion is logged and removed from config; the value returned
    by the removal is used instead, and the config is saved.
    """
    # Both settings are read up front, before either is converted.
    pending = [
        ("sent_emoji", self.config["sent_emoji"], "Removed sent emoji (%s)."),
        ("blocked_emoji", self.config["blocked_emoji"], "Removed blocked emoji (%s)."),
    ]
    resolved = []
    for key, value, removal_note in pending:
        if value != "disable":
            try:
                value = await self.convert_emoji(value)
            except commands.BadArgument:
                logger.warning(removal_note, value)
                value = self.config.remove(key)
                await self.config.update()
        resolved.append(value)
    return resolved[0], resolved[1]
def check_account_age(self, author: discord.Member) -> bool:
    """
    Return False when ``author``'s account is younger than ``account_age``.

    On failure a "System Message: New Account..." reason is recorded in
    ``blocked_users`` unless the user already has an entry. If adding
    ``account_age`` raises ValueError, the setting is removed from config and
    the value returned by the removal is used instead.
    """
    required_age = self.config.get("account_age")
    now = discord.utils.utcnow()
    try:
        eligible_at = author.created_at + required_age
    except ValueError:
        logger.warning("Error with 'account_age'.", exc_info=True)
        eligible_at = author.created_at + self.config.remove("account_age")

    if eligible_at <= now:
        return True

    # User account has not reached the required time
    wait_text = human_timedelta(eligible_at)
    logger.debug("Blocked due to account age, user %s.", author.name)
    user_key = str(author.id)
    if user_key not in self.blocked_users:
        self.blocked_users[user_key] = f"System Message: New Account. User can try again {wait_text}."
    return False
def check_guild_age(self, author: discord.Member) -> bool:
    """
    Return False when ``author`` joined the guild more recently than ``guild_age``.

    When the join time is unknown (``author`` has no ``joined_at`` attribute, or
    it is None) the check is skipped with a warning and True is returned.
    On failure a "System Message: Recently Joined..." reason is recorded in
    ``blocked_users`` unless the user already has an entry. If adding
    ``guild_age`` raises ValueError, the setting is removed from config and the
    value returned by the removal is used instead.
    """
    guild_age = self.config.get("guild_age")
    now = discord.utils.utcnow()

    # discord.py documents ``Member.joined_at`` as Optional: it can be None even
    # for a member, and ``None + guild_age`` would raise TypeError below.
    joined_at = getattr(author, "joined_at", None)
    if joined_at is None:
        logger.warning("Not in guild, cannot verify guild_age, %s.", author.name)
        return True

    try:
        min_guild_age = joined_at + guild_age
    except ValueError:
        logger.warning("Error with 'guild_age'.", exc_info=True)
        min_guild_age = joined_at + self.config.remove("guild_age")

    if min_guild_age > now:
        # User has not stayed in the guild for long enough
        delta = human_timedelta(min_guild_age)
        logger.debug("Blocked due to guild age, user %s.", author.name)
        if str(author.id) not in self.blocked_users:
            new_reason = f"System Message: Recently Joined. User can try again {delta}."
            self.blocked_users[str(author.id)] = new_reason
        return False

    return True
def check_manual_blocked_roles(self, author: discord.Member) -> bool:
    """
    Return False if any of ``author``'s roles is blocked in ``blocked_roles``.

    Only ``discord.Member`` authors have roles; anything else passes. For each
    blocked role, the stored reason is parsed with ``extract_block_timestamp``:
    - a ValueError from parsing counts as blocked;
    - when an end time was extracted and the remaining time (``after``) is <= 0,
      the block has expired: the role entry is removed from ``blocked_roles`` and
      the author's remaining roles are still checked;
    - any other blocked role blocks.
    """
    if not isinstance(author, discord.Member):
        return True

    for role in author.roles:
        role_key = str(role.id)
        if role_key not in self.blocked_roles:
            continue

        blocked_reason = self.blocked_roles.get(role_key) or ""
        try:
            end_time, after = extract_block_timestamp(blocked_reason, author.id)
        except ValueError:
            return False

        if end_time is not None and after <= 0:
            # No longer blocked by this role. Keep checking the other roles:
            # returning True here would let a still-active blocked role slip through.
            self.blocked_roles.pop(role_key)
            logger.debug("No longer blocked, role %s.", role.name)
            continue

        logger.debug("User blocked, role %s.", role.name)
        return False

    return True
def check_manual_blocked(self, author: discord.Member) -> bool:
    """
    Return False if ``author`` has an active entry in ``blocked_users``.

    - No entry: not blocked.
    - A "System Message:" entry is cleared: callers run the age checks first,
      so reaching this point means those limits are met.
    - A reason that ``extract_block_timestamp`` rejects with ValueError blocks.
    - An entry with an end time and remaining time (``after``) <= 0 has expired
      and is removed; anything else blocks.
    """
    user_key = str(author.id)
    if user_key not in self.blocked_users:
        return True

    reason = self.blocked_users.get(user_key) or ""
    if reason.startswith("System Message:"):
        # Met the limits already, otherwise it would've been caught by the previous checks
        logger.debug("No longer internally blocked, user %s.", author.name)
        self.blocked_users.pop(user_key)
        return True

    try:
        end_time, after = extract_block_timestamp(reason, author.id)
    except ValueError:
        return False

    if end_time is not None and after <= 0:
        # No longer blocked
        self.blocked_users.pop(user_key)
        logger.debug("No longer blocked, user %s.", author.name)
        return True

    logger.debug("User blocked, user %s.", author.name)
    return False
async def _process_blocked(self, message):
_, blocked_emoji = await self.retrieve_emoji()
if await self.is_blocked(message.author, channel=message.channel, send_message=True):
await self.add_reaction(message, blocked_emoji)
return True
return False
async def is_blocked(
    self,
    author: discord.User,
    *,
    channel: discord.TextChannel = None,
    send_message: bool = False,
) -> bool:
    """
    Decide whether messages from ``author`` should be rejected.

    ``author`` is replaced by a guild member object when one is found (the main
    guild first, then any guild the bot is in). Checks, in order:
    1. whitelisted users are never blocked (an existing block entry is removed
       and the config saved);
    2. account-age and guild-age limits; when they produce a new block reason
       and ``send_message`` is set, that reason is sent to ``channel``;
    3. manual user blocks, then manual role blocks.
    The config is saved when the user passes every check.
    """
    member = self.guild.get_member(author.id)
    if member is None:
        # try to find in other guilds
        candidates = (g.get_member(author.id) for g in self.guilds)
        member = next((m for m in candidates if m), None)
        if member is None:
            logger.debug("User not in guild, %s.", author.id)
    if member is not None:
        author = member

    user_key = str(author.id)
    if user_key in self.blocked_whitelisted_users:
        if user_key in self.blocked_users:
            self.blocked_users.pop(user_key)
            await self.config.update()
        return False

    reason_before = self.blocked_users.get(user_key) or ""
    if not (self.check_account_age(author) and self.check_guild_age(author)):
        reason_after = self.blocked_users.get(user_key)
        if send_message and reason_after != reason_before:
            await channel.send(
                embed=discord.Embed(
                    title="Message not sent!",
                    description=reason_after,
                    color=self.error_color,
                )
            )
        return True

    if not (self.check_manual_blocked(author) and self.check_manual_blocked_roles(author)):
        return True

    await self.config.update()
    return False
async def get_thread_cooldown(self, author: discord.Member):
    """
    Return a human-readable wait if ``author`` is still in thread cooldown, else None.

    The cooldown is ``thread_cooldown`` counted from ``closed_at`` of the user's
    latest log. None is returned when the cooldown equals an empty
    ``isodate.Duration()``, when no log exists, when the latest log has no
    ``closed_at``, or when the cooldown has already elapsed.
    """
    cooldown_length = self.config.get("thread_cooldown")
    now = discord.utils.utcnow()
    if cooldown_length == isodate.Duration():
        return None

    latest = await self.api.get_latest_user_logs(author.id)
    if latest is None:
        logger.debug("Last thread wasn't found, %s.", author.name)
        return None

    closed_at = latest.get("closed_at")
    if not closed_at:
        logger.debug("Last thread was not closed, %s.", author.name)
        return None

    try:
        ends_at = datetime.fromisoformat(closed_at).astimezone(timezone.utc) + cooldown_length
    except ValueError:
        logger.warning("Error with 'thread_cooldown'.", exc_info=True)
        ends_at = datetime.fromisoformat(closed_at).astimezone(timezone.utc) + self.config.remove(
            "thread_cooldown"
        )

    if ends_at <= now:
        return None

    # User messaged before thread cooldown ended
    remaining = human_timedelta(ends_at)
    logger.debug("Blocked due to thread cooldown, user %s.", author.name)
    return remaining
@staticmethod
async def add_reaction(
    msg, reaction: typing.Union[discord.Emoji, discord.Reaction, discord.PartialEmoji, str]
) -> bool:
    """
    Add ``reaction`` to ``msg`` unless it is the literal ``"disable"``.

    Returns False, after logging a warning, when the reaction could not be added
    (an HTTP error or an invalid emoji). Returns True otherwise, including when
    reactions are disabled.
    """
    if reaction == "disable":
        return True
    try:
        await msg.add_reaction(reaction)
    except (discord.HTTPException, TypeError) as e:
        # discord.py 2.x raises TypeError for an invalid emoji argument. The former
        # ``discord.BadArgument`` is not an attribute of the ``discord`` package, so
        # evaluating it here raised AttributeError and masked the real failure.
        logger.warning("Failed to add reaction %s: %s.", reaction, e)
        return False
    return True
async def process_dm_modmail(self, message: discord.Message) -> None:
    """
    Processes messages sent to the bot.

    Flow:
    - a blocked author gets the blocked reaction (via ``_process_blocked``) and nothing else;
    - only ``MessageType.default`` messages are relayed;
    - without an existing thread: enforce the thread cooldown, refuse when
      ``dm_disabled`` is NEW_THREADS or ALL_THREADS, otherwise create a thread;
    - with an existing thread: refuse when ``dm_disabled`` is ALL_THREADS;
    - unless the thread is cancelled, relay the message to the thread, copy it to
      the other recipients, react with the sent emoji and dispatch ``thread_reply``.
      A failed relay gets the blocked emoji instead.
    """
    blocked = await self._process_blocked(message)
    if blocked:
        return
    sent_emoji, blocked_emoji = await self.retrieve_emoji()

    if message.type != discord.MessageType.default:
        return

    def footer_icon_url():
        # ``Guild.icon`` is None for a guild without an icon, and ``None.url`` would
        # raise AttributeError; ``Embed.set_footer`` accepts ``icon_url=None``.
        icon = self.guild.icon
        return icon.url if icon is not None else None

    thread = await self.threads.find(recipient=message.author)
    if thread is None:
        delta = await self.get_thread_cooldown(message.author)
        if delta:
            await message.channel.send(
                embed=discord.Embed(
                    title=self.config["cooldown_thread_title"],
                    description=self.config["cooldown_thread_response"].format(delta=delta),
                    color=self.error_color,
                )
            )
            return

        if self.config["dm_disabled"] in (DMDisabled.NEW_THREADS, DMDisabled.ALL_THREADS):
            embed = discord.Embed(
                title=self.config["disabled_new_thread_title"],
                color=self.error_color,
                description=self.config["disabled_new_thread_response"],
            )
            embed.set_footer(text=self.config["disabled_new_thread_footer"], icon_url=footer_icon_url())
            logger.info("A new thread was blocked from %s due to disabled Modmail.", message.author)
            await self.add_reaction(message, blocked_emoji)
            return await message.channel.send(embed=embed)

        thread = await self.threads.create(message.author, message=message)
    else:
        if self.config["dm_disabled"] == DMDisabled.ALL_THREADS:
            embed = discord.Embed(
                title=self.config["disabled_current_thread_title"],
                color=self.error_color,
                description=self.config["disabled_current_thread_response"],
            )
            embed.set_footer(
                text=self.config["disabled_current_thread_footer"],
                icon_url=footer_icon_url(),
            )
            logger.info("A message was blocked from %s due to disabled Modmail.", message.author)
            await self.add_reaction(message, blocked_emoji)
            return await message.channel.send(embed=embed)

    if not thread.cancelled:
        try:
            await thread.send(message)
        except Exception:
            logger.error("Failed to send message:", exc_info=True)
            await self.add_reaction(message, blocked_emoji)
        else:
            for user in thread.recipients:
                # send to all other recipients
                if user != message.author:
                    try:
                        await thread.send(message, user)
                    except Exception:
                        # Logged only: one failed copy does not stop the others.
                        logger.error("Failed to send message:", exc_info=True)

            await self.add_reaction(message, sent_emoji)
            self.dispatch("thread_reply", thread, False, message, False, False)
def _get_snippet_command(self) -> commands.Command:
    """
    Get the correct reply command based on the snippet config.

    The name always starts with ``f``; ``p`` is added when ``plain_snippets`` is
    set and ``a`` when ``anonymous_snippets`` is set, followed by ``reply``
    (for example ``fpareply``).
    """
    flags = (
        "f",
        "p" if self.config["plain_snippets"] else "",
        "a" if self.config["anonymous_snippets"] else "",
    )
    return self.get_command("".join(flags) + "reply")
async def get_contexts(self, message, *, cls=commands.Context):
"""
Returns all invocation contexts from the message.
Supports getting the prefix from database as well as command aliases.
"""
view = StringView(message.content)
ctx = cls(prefix=self.prefix, view=view, bot=self, message=message)
thread = await self.threads.find(channel=ctx.channel)
if message.author.id == self.user.id: # type: ignore
return [ctx]
prefixes = await self.get_prefix()
invoked_prefix = discord.utils.find(view.skip_string, prefixes)
if invoked_prefix is None:
return [ctx]
invoker = view.get_word().lower()
# Check if a snippet is being called.
# This needs to be done before checking for aliases since
# snippets can have multiple words.
try:
# Use removeprefix once PY3.9+
snippet_text = self.snippets[message.content[len(invoked_prefix) :]]
except KeyError:
snippet_text = None
# Check if there is any aliases being called.
alias = self.aliases.get(invoker)
if alias is not None and snippet_text is None:
ctxs = []
aliases = normalize_alias(alias, message.content[len(f"{invoked_prefix}{invoker}") :])