forked from adampy/adambot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
adambot.py
280 lines (229 loc) · 12.4 KB
/
adambot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
import datetime
import json
import os
import time
from typing import Callable, Optional
import asyncpg
import discord
import pandas
import pytz
from discord import Interaction
from discord.app_commands import AppCommandError
from discord.ext import commands
from discord.ext.commands import Bot, when_mentioned_or, when_mentioned
from tzlocal import get_localzone
import libs.db.database_handle as database_handle # not strictly a lib rn but hopefully will be in the future
import libs.misc.utils as utils
from libs.misc.decorators import MissingStaffError, MissingDevError, MissingStaffSlashError, MissingDevSlashError
from libs.misc.utils import DefaultEmbedResponses, ContextTypes, get_context_type
from scripts.utils import cog_handler
class AdamTree(discord.app_commands.tree.CommandTree):
    """
    Application-command tree that lets callers route specific AppCommandError classes to custom handlers.

    Staff/dev permission errors are always answered with the default "invalid permissions" embed;
    other errors go to the handler registered via map() for their exact class, or are re-raised.
    """

    def __init__(self, client: discord.Client) -> None:
        self.client = client
        # Error class -> coroutine function awaited as handler(interaction, error)
        self.maps: dict[type[AppCommandError], Callable] = {}
        super().__init__(client)

    def map(self, error: type[AppCommandError], method: Callable) -> None:
        """
        Allows for mapping custom AppCommandErrors to custom handler methods.

        error: the AppCommandError subclass itself (not an instance). Lookup in on_error is by exact
            class, so a handler registered for a base class does not catch its subclasses.
        method: coroutine function awaited as method(interaction, error).
        """
        self.maps[error] = method

    async def on_error(self, interaction: Interaction, error: AppCommandError) -> None:
        """
        Custom error handler for AppCommandErrors.

        Permission errors get the default invalid-permissions response. Otherwise, if a callable
        handler is mapped to the error's exact class it is awaited; if not, the error is re-raised.
        """
        if isinstance(error, (MissingStaffSlashError, MissingDevSlashError)):
            await DefaultEmbedResponses.invalid_perms(self.client, interaction)
            return

        handler = self.maps.get(type(error))
        if not callable(handler):
            raise error
        await handler(interaction, error)
class AdamBot(Bot):
    """
    discord.py Bot subclass for AdamBot.

    Construction loads the internal JSON config, preloads cogs and builds the gateway intents.
    start_up() must then be awaited separately to load cogs, create the Postgres pool and log into Discord.
    """

    async def get_context(self, message: discord.Message, *, cls=commands.Context) -> Optional[commands.Context]:
        """
        Builds the invocation context for message using cls, or returns None when cls is falsy.
        """
        if not cls:
            return None
        return await super().get_context(message, cls=cls)

    async def determine_prefix(self, bot, message: discord.Message) -> list[str]:
        """
        Procedure that determines the prefix for a guild. Used when a fixed command_prefix was not given.

        "bot" is a required argument but also pointless since each AdamBot object isn't going to be trying
        to handle *other* AdamBot objects' prefixes.
        Returns mentions plus any string guild/global prefix, or mentions only when neither is available.
        """
        guild_prefix = None
        # Same guard as get_used_prefixes: get_config_key may not be attached yet
        if message.guild and hasattr(self, "get_config_key"):
            guild_prefix = await self.get_config_key(message, "prefix")

        prefixes = [prefix for prefix in (guild_prefix, self.global_prefix) if isinstance(prefix, str)]
        if prefixes:
            # internal conf prefix or guild conf prefix can be used
            return when_mentioned_or(*prefixes)(self, message)
        # Config tables aren't loaded yet or internal config doesn't specify another prefix: mentions only
        return when_mentioned(self, message)

    async def get_used_prefixes(self, ctx: commands.Context | discord.Interaction | discord.Message | discord.Guild) -> \
            list[str]:
        """
        Gets the prefixes that can be used to invoke a command in the guild where ctx is from.

        Returns an empty list while get_config_key is unavailable. Otherwise returns the bot mention
        (once logged in), the global prefix and the guild prefix, skipping empty/non-string values.
        """
        if not hasattr(self, "get_config_key"):
            return []  # config cog not loaded yet
        guild_prefix = await self.get_config_key(ctx, "prefix")
        candidates = [
            self.user.mention if self.user else None,  # self.user is None until logged in
            self.global_prefix or None,
            guild_prefix or None,
        ]
        return [prefix for prefix in candidates if isinstance(prefix, str)]

    def __init__(self, start_time: float, config_path: str = "config.json", command_prefix: str = "", *args,
                 **kwargs) -> None:
        """
        start_time: time.time() value captured at process start; used for the timing printouts.
        config_path: path to the bot's internal JSON config.
        command_prefix: if non-empty, a fixed prefix (plus mentions) replacing the per-guild lookup.
        kwargs: forwarded to Bot.__init__. "connections" is also read as the DB pool max size, and
            "token" is used as a last-resort login token by start_up().
        """
        self.ContextType = ContextTypes
        self.get_context_type = get_context_type
        self.internal_config = self.load_internal_config(config_path)
        self.cog_handler = cog_handler.CogHandler(self)
        self.kwargs = kwargs
        self.global_prefix = self.internal_config.get("global_prefix")
        # Written into kwargs so it reaches Bot.__init__ below
        self.kwargs["command_prefix"] = when_mentioned_or(command_prefix) if command_prefix else self.determine_prefix

        self.cog_handler.preload_core_cogs()
        # json_normalize flattens nested cog config into "."-separated keys; compute it once
        cog_records = pandas.json_normalize(self.internal_config.get("cogs", {}), sep=".").to_dict(orient="records")
        cog_dict = cog_records[0] if cog_records else {}
        if cog_dict:
            self.cog_handler.preload_cogs(cog_dict)
        else:
            print("[X] No cogs specified.")

        super().__init__(*args,
                         intents=self.cog_handler.make_intents(list(dict.fromkeys(self.cog_handler.intent_list))),
                         tree_cls=AdamTree, **kwargs)

        self.db_start = time.time()
        print("Creating DB pool...")
        self.db_url = self.internal_config.get("database_url", "") or os.environ.get("DATABASE_URL", "")
        self.connections = kwargs.get("connections", 10)  # Max DB pool connections
        self.online = False  # Start at False, changes to True once fully initialised
        self.LOCAL_HOST = not os.environ.get("REMOTE", None)
        self.display_timezone = pytz.timezone("Europe/London")
        self.timezone = get_localzone()
        self.ts_format = "%A %d/%m/%Y %H:%M:%S"
        self.start_time = start_time
        self._init_time = time.time()
        # guild id -> members ordered most-recently-active first (maintained in on_message)
        self.last_active = {}
        print(f"BOT INITIALISED {self._init_time - start_time} seconds")

    async def shutdown(self,
                       ctx: commands.Context | discord.Interaction = None) -> None:  # ctx = None because this is also called upon CTRL+C in command line
        """
        Procedure that closes down AdamBot, using the standard client.close() command, as well as some database handling methods.

        Progress messages are sent to ctx (if given) and printed.
        """
        ctx_type = self.get_context_type(ctx)
        # self.ContextType is set in __init__, so this also works if start_up() never ran
        context_types = self.ContextType
        self.online = False  # This is set to false to prevent DB things going on in the background once bot closed

        user = f"{self.user.mention} " if self.user else ""
        p_s = f"Beginning process of shutting {user}down. DB pool shutting down..."
        if ctx_type == context_types.Context:
            await ctx.send(p_s)
        elif ctx_type == context_types.Interaction:
            await ctx.response.send_message(p_s)
        print(p_s)

        if hasattr(self, "pool"):
            self.pool.terminate()  # TODO: Make this more graceful

        c_s = "Closing connection to Discord..."
        if ctx_type != context_types.Unknown:
            await ctx.channel.send(c_s)
        print(c_s)

        try:
            await self.change_presence(status=discord.Status.offline)
        except AttributeError:
            pass  # hasattr returns true but then you get yelled at if you use it
        await super().close()
        time.sleep(1)  # stops bs RuntimeError spam at the end
        print(f"Bot closed after {time.time() - self.start_time} seconds")

    @staticmethod
    def load_internal_config(config_path: str) -> dict:
        """
        Loads bot's internal config from specified location.

        Prints a diagnostic and exits the process (status 1) when the file cannot be opened or is not valid JSON.
        Perhaps in the future have a "default" config generated e.g. with all cogs auto-detected, rather than
        it being specifically included in the repo?
        """
        try:
            config_file = open(config_path, encoding="utf-8")  # JSON is UTF-8 by spec
        except Exception as e:
            print(f"Config is inaccessible! See the error below for more details\n{type(e).__name__}: {e}")
            # Previously fell through and crashed on None.read()
            raise SystemExit(1)

        with config_file:
            try:
                return json.load(config_file)
            except json.decoder.JSONDecodeError as e:
                print(f"The JSON in the config is invalid! See the error below for more details\n{type(e).__name__}: {e}")
                raise SystemExit(1)

    async def start_up(self) -> None:
        """
        Starts AdamBot. Not called by __init__; await it after constructing the bot.

        Copies utils into the instance, sets flag handlers, loads cogs, resolves the token
        (config, then TOKEN env var, then the "token" kwarg), creates the DB pool and logs in.
        Returns early without connecting if no token is found.
        """
        print("Loading utils into the bot instance...")
        self.__dict__.update(utils.__dict__)  # Bring all of utils into the bot - prevents referencing utils in cogs
        print("Setting flag handlers...")
        self.set_flag_handlers()
        print("Loading cogs...")
        await self.cog_handler.load_cogs()
        self.cog_load = time.time()
        print(
            f"\nLoaded all cogs in {self.cog_load - self._init_time} seconds ({self.cog_load - self.start_time} seconds total)")
        # Moved to here as it makes more sense to not load everything then tell the user they did an oopsies
        print(f"Bot fully setup! ({time.time() - self.start_time} seconds total)")
        print("Logging into Discord...")

        token = (self.internal_config.get("token", "")
                 or os.environ.get("TOKEN", "")
                 or self.kwargs.get("token", ""))
        if not token:
            print("No token provided!")
            return

        # Drop the config (it holds the token) but keep it a dict so later .get() calls still work
        self.internal_config = {}

        # Append sslmode without breaking a URL that already carries a query string
        separator = "&" if "?" in self.db_url else "?"
        self.pool: asyncpg.pool.Pool = await asyncpg.create_pool(f"{self.db_url}{separator}sslmode=require",
                                                                 max_size=self.connections)
        await database_handle.introduce_tables(self.pool, self.cog_handler.db_tables)
        await database_handle.insert_cog_db_columns_if_not_exists(self.pool, self.cog_handler.db_tables)
        print(f"DB took {time.time() - self.db_start} seconds to connect to")

        try:
            await self.start(token)
        except Exception as e:
            print(
                f"Something went wrong handling the token!\nThe error was {type(e).__name__}: {e}")  # overridden close cleans this up neatly

    async def on_ready(self) -> None:
        """
        Event that syncs the app command tree and sets the bot instance's status and online presence.
        """
        self.login_time = time.time()
        print(f"Bot logged into Discord ({self.login_time - self.start_time} seconds total)")
        await self.tree.sync()
        await self.change_presence(activity=discord.Game(name=f"in {len(self.guilds)} servers | Type `help` for help"),
                                   status=discord.Status.online)
        self.online = True

    async def on_message(self, message: discord.Message) -> None:
        """
        Event that ignores guild-less (DM) messages and bot authors, records the author as the guild's
        most recently active member, then processes commands.
        """
        # Everything below needs message.guild, so bail out on DMs and other guild-less channels
        if message.guild is None or message.author.bot:
            return

        recent = self.last_active.setdefault(message.guild.id, [])
        if message.author in recent:
            recent.remove(message.author)
        recent.insert(0, message.author)

        # Now run commands, due to overriding of default bot `on_message` doesn't do this automatically
        await self.process_commands(message)

    async def on_command_error(self, ctx: commands.Context, error) -> None:
        """
        Prints every prefix-command error and answers staff/dev permission errors with the
        default invalid-permissions response. Other errors are not re-raised.
        """
        print(error)  # added back for the sake of retaining sanity when debugging
        if isinstance(error, (MissingStaffError, MissingDevError)):
            await DefaultEmbedResponses.invalid_perms(ctx.bot, ctx)

    def set_flag_handlers(self) -> None:
        """
        Registers the "time" (-t) and "reason" (-r) command flags.

        NOTE(review): self.flags / self.flag_methods are presumably provided by utils via start_up's
        __dict__ update — confirm before calling this earlier.
        """
        self.flag_handler = self.flags()
        self.flag_handler.set_flag("time", {"flag": "t", "post_parse_handler": self.flag_methods.str_time_to_seconds})
        self.flag_handler.set_flag("reason", {"flag": "r"})

    def correct_time(self, conv_time: Optional[datetime.datetime] = None,
                     timezone_: str = "system") -> datetime.datetime:
        """
        Converts conv_time (default: now) into self.display_timezone.

        Aware datetimes are converted from their own zone. Naive datetimes are interpreted in the
        system zone (self.timezone) when timezone_ == "system", otherwise in the pytz zone named timezone_.
        """
        if conv_time is None:
            conv_time = datetime.datetime.now()

        if conv_time.tzinfo is not None:
            # Convert directly: tzname() abbreviations such as "BST" are not valid pytz zone names
            return conv_time.astimezone(self.display_timezone)

        tz_obj = self.timezone if timezone_ == "system" else pytz.timezone(timezone_)
        if hasattr(tz_obj, "localize"):
            # pytz zones must be attached with localize() to get the correct UTC offset
            aware = tz_obj.localize(conv_time)
        else:
            # zoneinfo zones (newer tzlocal) have no localize(); replace() is the correct way to attach them
            aware = conv_time.replace(tzinfo=tz_obj)
        return aware.astimezone(self.display_timezone)