diff --git a/awpy/types.py b/awpy/types.py index ec2de03c9..503b21b88 100644 --- a/awpy/types.py +++ b/awpy/types.py @@ -10,7 +10,7 @@ class PlotPosition: """Class to store information needed for plotting a position.""" - position: tuple[float, float] + position: tuple[float, float, float] color: str marker: str alpha: float | None = None diff --git a/awpy/visualization/plot.py b/awpy/visualization/plot.py index 8512db15f..9a397dfb5 100644 --- a/awpy/visualization/plot.py +++ b/awpy/visualization/plot.py @@ -62,8 +62,36 @@ def with_tmp_dir() -> Generator[None, None, None]: shutil.rmtree(AWPY_TMP_FOLDER) +def _get_image_with_size( + base_path: str, *, qualifier: str = "", split: bool = False +) -> tuple[np.ndarray, tuple[float, float]]: + """Get image and size of image. + + Args: + base_path (str): Path to image + qualifier (str): Additional qualifier to the image. + Valid are: ("", "_dark", "_light"). Defaults to "". + split (bool): Whether to split levels for maps with multiple levels. + Defaults to False. + + Returns: + tuple[np.ndarray, tuple[float, float]]: image and size of image + """ + figsize = (6.4, 4.8) + map_bg = imageio.imread(f"{base_path}{qualifier}.png") + if split: + map_bg_lower = imageio.imread(f"{base_path}_lower{qualifier}.png") + map_bg = np.concatenate([map_bg, map_bg_lower]) + figsize = (figsize[0], figsize[1] * 2) + return map_bg, figsize + + def plot_map( - map_name: str = "de_dust2", map_type: str = "original", *, dark: bool = False + map_name: str = "de_dust2", + map_type: str = "original", + *, + dark: bool = False, + split_levels: bool = True, ) -> tuple[Figure, Axes]: """Plots a blank map. @@ -73,29 +101,24 @@ def plot_map( dark (bool, optional): Only for use with map_type="simpleradar". Indicates if you want to use the SimpleRadar dark map type Defaults to False + split_levels (bool): Whether to split levels for maps with multiple levels. + Defaults to True Returns: matplotlib fig and ax """ base_path = os.path.join(os.path.dirname(__file__), f"""../data/map/{map_name}""") + split = split_levels and map_name in MAP_DATA and "z_cutoff" in MAP_DATA[map_name] if map_type == "original": - map_bg = imageio.imread(f"{base_path}.png") - if map_name in MAP_DATA and "z_cutoff" in MAP_DATA[map_name]: - map_bg_lower = imageio.imread(f"{base_path}_lower.png") - map_bg = np.concatenate([map_bg, map_bg_lower]) + map_bg, figsize = _get_image_with_size(base_path, split=split) else: try: - col = "dark" if dark else "light" - map_bg = imageio.imread(f"{base_path}_{col}.png") - if map_name in MAP_DATA and "z_cutoff" in MAP_DATA[map_name]: - map_bg_lower = imageio.imread(f"{base_path}_lower_{col}.png") - map_bg = np.concatenate([map_bg, map_bg_lower]) + map_bg, figsize = _get_image_with_size( + base_path, qualifier="_dark" if dark else "_light", split=split + ) except FileNotFoundError: - map_bg = imageio.imread(f"{base_path}.png") - if map_name in MAP_DATA and "z_cutoff" in MAP_DATA[map_name]: - map_bg_lower = imageio.imread(f"{base_path}_lower.png") - map_bg = np.concatenate([map_bg, map_bg_lower]) - figure, axes = plt.subplots() + map_bg, figsize = _get_image_with_size(base_path, split=split) + figure, axes = plt.subplots(figsize=figsize) axes.imshow(map_bg, zorder=0) return figure, axes @@ -191,12 +214,10 @@ def plot_positions( figure, axes = plot_map(map_name=map_name, map_type=map_type, dark=dark) for position in positions: if apply_transformation: - x = position_transform(map_name, position.position[0], "x") - y = position_transform(map_name, position.position[1], "y") + x, y, _ = position_transform_all(map_name, position.position) else: - x = position.position[0] - y = position.position[1] + x, y, _ = position.position axes.scatter( x=x, y=y, @@ -223,10 +244,7 @@ def _get_plot_position_for_player( Returns: PlotPosition: Information needed to plot the player. """ - pos = ( - position_transform(map_name, player["x"], "x"), - position_transform(map_name, player["y"], "y"), - ) + pos = position_transform_all(map_name, (player["x"], player["y"], player["z"])) color = "cyan" if side == "ct" else "red" marker = "x" if player["hp"] == 0 else "." return PlotPosition(position=pos, color=color, marker=marker) @@ -242,10 +260,7 @@ def _get_plot_position_for_bomb(bomb: BombInfo, map_name: str) -> PlotPosition: Returns: PlotPosition: Information needed to plot the bomb. """ - pos = ( - position_transform(map_name, bomb["x"], "x"), - position_transform(map_name, bomb["y"], "y"), - ) + pos = position_transform_all(map_name, (bomb["x"], bomb["y"], bomb["z"])) color = "orange" marker = "8" return PlotPosition(position=pos, color=color, marker=marker) @@ -331,8 +346,9 @@ def plot_nades( """ if nades is None: nades = [] - - figure, axes = plot_map(map_name=map_name, map_type=map_type, dark=dark) + figure, axes = plot_map( + map_name=map_name, map_type=map_type, dark=dark, split_levels=False + ) for game_round in rounds: if game_round["grenades"] is None: continue diff --git a/tests/test_vis.py b/tests/test_vis.py index fcddd52d5..48cbeef8a 100644 --- a/tests/test_vis.py +++ b/tests/test_vis.py @@ -74,8 +74,8 @@ def test_plot_map(self): @patch("awpy.visualization.plot.Axes.scatter") def test_plot_positions(self, scatter_mock: MagicMock): """Test plot positions.""" - pos1 = PlotPosition((1, 2), "red", "X", 1.0, 1.0) - pos2 = PlotPosition((2, 1), "blue", "8", 0.4, 0.3) + pos1 = PlotPosition((1, 2, 0), "red", "X", 1.0, 1.0) + pos2 = PlotPosition((2, 1, 0), "blue", "8", 0.4, 0.3) fig, axis = plot_positions( positions=[pos1, pos2], apply_transformation=True, @@ -100,11 +100,11 @@ def test_plot_round(self): { "bomb": {"x": 1890, "y": 74, "z": 1613.03125}, "t": {"players": []}, - "ct": {"players": [{"hp": 100, "x": 0, "y": 0}]}, + "ct": {"players": [{"hp": 100, "x": 0, "y": 0, "z": 0}]}, }, { "bomb": {"x": 1890, "y": 74, "z": 1613.03125}, - "t": {"players": [{"hp": 0, "x": 0, "y": 0}]}, + "t": {"players": [{"hp": 0, "x": 0, "y": 0, "z": 0}]}, "ct": {"players": []}, }, ] @@ -121,6 +121,7 @@ def test_plot_round(self): ( position_transform("de_ancient", 1890, "x"), position_transform("de_ancient", 74, "y"), + 1613.03125, ), "orange", "8", @@ -129,6 +130,7 @@ def test_plot_round(self): ( position_transform("de_ancient", 0, "x"), position_transform("de_ancient", 0, "y"), + 0, ), "red", "x",