diff --git a/alpsplot/figure.py b/alpsplot/figure.py index 35ce04c..3a796bd 100644 --- a/alpsplot/figure.py +++ b/alpsplot/figure.py @@ -6,7 +6,7 @@ import numpy as np import matplotlib.ticker as ticker -from matplotlib import rc +import matplotlib from matplotlib import pyplot as plt from matplotlib.figure import Figure from matplotlib.axes import Axes @@ -15,9 +15,10 @@ import seaborn add_palatino() -rc('svg', image_inline=True, fonttype='none') -rc('pdf', fonttype=42) -rc('ps', fonttype=42) +matplotlib.rc('mathtext', fontset='cm') +matplotlib.rc('pdf', fonttype=42) +matplotlib.rc('ps', fonttype=42) +matplotlib.rc('svg', image_inline=True, fonttype='none') # Markers @@ -51,10 +52,8 @@ # ':' dotted line style class Figure: - def __init__(self, name: str, folder_path: str = None, fig: Figure = None, ax: Axes = None, figsize: tuple[float, float] = (5, 2.5), tex=False): + def __init__(self, name: str, folder_path: str = None, fig: Figure = None, ax: Axes = None, figsize: tuple[float, float] = (5, 2.5), **kwargs): super(Figure, self).__init__() - if tex: - rc('text', usetex=True) self.name: str = name self.folder_path: str = folder_path if folder_path is None: @@ -62,7 +61,7 @@ def __init__(self, name: str, folder_path: str = None, fig: Figure = None, ax: A self.fig: Figure = fig self.ax: Axes = ax if fig is None and ax is None: - self.fig: Figure = plt.figure(figsize=figsize) + self.fig: Figure = plt.figure(figsize=figsize, **kwargs) self.ax = self.fig.add_subplot(1, 1, 1) self.ax.spines['top'].set_visible(False) self.ax.spines['bottom'].set_visible(True) @@ -75,24 +74,24 @@ def __init__(self, name: str, folder_path: str = None, fig: Figure = None, ax: A self.ax.set_ylim([0.0, 1.0]) def set_legend(self, *args, fontsize=11, frameon: bool = True, edgecolor='white', framealpha=1.0, - font='Palatino', fontstyle='italic', fontweight='bold', math_fontfamily='cm', **kwargs) -> None: + fontproperties='Palatino', fontstyle='italic', fontweight='bold', math_fontfamily='cm', **kwargs) -> None: self.ax.legend(*args, frameon=frameon, edgecolor=edgecolor, framealpha=framealpha, **kwargs) plt.setp(self.ax.get_legend().get_texts(), fontsize=fontsize, - font=font, fontstyle=fontstyle, fontweight=fontweight, math_fontfamily=math_fontfamily) + fontproperties=fontproperties, fontstyle=fontstyle, fontweight=fontweight, math_fontfamily=math_fontfamily) def set_axis_label(self, axis: str, text: str, fontsize: int = 12, - font='Palatino', fontweight='bold', math_fontfamily='cm', **kwargs): + fontproperties='Palatino', fontweight='bold', math_fontfamily='cm', **kwargs): func = getattr(self.ax, f'set_{axis}label') - func(text, fontsize=fontsize, font=font, + func(text, fontsize=fontsize, fontproperties=fontproperties, fontweight=fontweight, math_fontfamily=math_fontfamily, **kwargs) def set_title(self, text: str = None, fontsize: int = 16, - font='Palatino', fontweight: str = 'bold', math_fontfamily='cm') -> None: + fontproperties='Palatino', fontweight: str = 'bold', math_fontfamily='cm') -> None: if text is None: text = self.name self.ax.set_title(text, fontsize=fontsize, - font=font, fontweight=fontweight, math_fontfamily=math_fontfamily) + fontproperties=fontproperties, fontweight=fontweight, math_fontfamily=math_fontfamily) def save(self, path: str = None, folder_path: str = None, ext: str = '.pdf') -> None: if path is None: @@ -107,7 +106,7 @@ def save(self, path: str = None, folder_path: str = None, ext: str = '.pdf') -> def set_axis_lim(self, axis: str, lim: list[float] = [0.0, 1.0], margin: list[float] = [0.0, 0.0], piece: int = 10, _format: str = '%.1f', fontsize: int = 11, - font='Palatino', fontweight: str = 'bold', math_fontfamily='cm') -> None: + fontproperties='Palatino', fontweight: str = 'bold', math_fontfamily='cm') -> None: if _format == 'integer': _format = '%d' lim_func = getattr(self.ax, f'set_{axis}lim') @@ -124,10 +123,10 @@ def format_func(_str): ticks = getattr(self.ax, f'get_{axis}ticks')() set_ticklabels_func = getattr(self.ax, f'set_{axis}ticklabels') set_ticklabels_func(ticks, fontsize=fontsize, - font=font, fontweight=fontweight, math_fontfamily=math_fontfamily) + fontproperties=fontproperties, fontweight=fontweight, math_fontfamily=math_fontfamily) format_func(_format) - def curve(self, x: np.ndarray, y: np.ndarray, color: str = 'black', linewidth: int = 2, + def plot(self, x: np.ndarray, y: np.ndarray, color: str = 'black', linewidth: int = 2, label: str = None, markerfacecolor: str = 'white', linestyle: str = '-', zorder: int = 1, **kwargs) -> Line2D: # linestyle marker markeredgecolor markeredgewidth markerfacecolor markersize alpha ax = seaborn.lineplot(x=x, y=y, ax=self.ax, color=color, linewidth=linewidth, @@ -172,7 +171,7 @@ def hist(self, x: np.ndarray, bins: list[float] = None, normed: bool = True, **k return self.ax.hist(x, bins=bins, normed=normed, **kwargs) def autolabel(self, rects: BarContainer, above: bool = True, fontsize: int = 6, - font='Palatino', fontweight: str = 'bold', math_fontfamily='cm'): + fontproperties='Palatino', fontweight: str = 'bold', math_fontfamily='cm') -> None: """Attach a text label above each bar in *rects*, displaying its height.""" for rect in rects: height = int(rect.get_height()) @@ -182,4 +181,4 @@ def autolabel(self, rects: BarContainer, above: bool = True, fontsize: int = 6, xytext=(0, offset), # 3 points vertical offset textcoords="offset points", ha='center', va='bottom', fontsize=fontsize, - font=font, fontweight=fontweight, math_fontfamily=math_fontfamily) + fontproperties=fontproperties, fontweight=fontweight, math_fontfamily=math_fontfamily) diff --git a/examples/figure_attack_asr_alpha.py b/examples/figure_attack_asr_alpha.py index c7edffa..2f8ba05 100644 --- a/examples/figure_attack_asr_alpha.py +++ b/examples/figure_attack_asr_alpha.py @@ -208,7 +208,7 @@ y_grid = monotone(y_grid, increase=False) y_grid = avg_smooth(y_grid, window=100) y_grid[0] = y_list[0] - fig.curve(x_grid, y_grid, color=color_dict[key]) + fig.plot(x_grid, y_grid, color=color_dict[key]) fig.scatter( x_list, y_list, color=color_dict[key], marker=mark_dict[key], label=attack_mapping[key]) if dataset == 'cifar10': diff --git a/examples/plot.py b/examples/plot.py index d837ce0..3d8d917 100644 --- a/examples/plot.py +++ b/examples/plot.py @@ -14,7 +14,7 @@ fig.set_axis_lim('y', lim=[0, 100], piece=5, margin=[1.0, 1.0], _format='%d') - fig.curve(x=x, y=y, color=ting_color['red']) + fig.plot(x=x, y=y, color=ting_color['red']) fig.scatter(x=x, y=y, color=ting_color['red'], marker='H', label='resnet') fig.set_legend()