Skip to content

Commit

Permalink
rename apis
Browse files Browse the repository at this point in the history
  • Loading branch information
ain-soph committed Apr 28, 2021
1 parent dad8df3 commit 69d1dbb
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 21 deletions.
37 changes: 18 additions & 19 deletions alpsplot/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

import matplotlib.ticker as ticker
from matplotlib import rc
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from matplotlib.axes import Axes
Expand All @@ -15,9 +15,10 @@
import seaborn

add_palatino()
rc('svg', image_inline=True, fonttype='none')
rc('pdf', fonttype=42)
rc('ps', fonttype=42)
matplotlib.rc('mathtext', fontset='cm')
matplotlib.rc('pdf', fonttype=42)
matplotlib.rc('ps', fonttype=42)
matplotlib.rc('svg', image_inline=True, fonttype='none')


# Markers
Expand Down Expand Up @@ -51,18 +52,16 @@
# ':' dotted line style

class Figure:
def __init__(self, name: str, folder_path: str = None, fig: Figure = None, ax: Axes = None, figsize: tuple[float, float] = (5, 2.5), tex=False):
def __init__(self, name: str, folder_path: str = None, fig: Figure = None, ax: Axes = None, figsize: tuple[float, float] = (5, 2.5), **kwargs):
super(Figure, self).__init__()
if tex:
rc('text', usetex=True)
self.name: str = name
self.folder_path: str = folder_path
if folder_path is None:
self.folder_path = './output/'
self.fig: Figure = fig
self.ax: Axes = ax
if fig is None and ax is None:
self.fig: Figure = plt.figure(figsize=figsize)
self.fig: Figure = plt.figure(figsize=figsize, **kwargs)
self.ax = self.fig.add_subplot(1, 1, 1)
self.ax.spines['top'].set_visible(False)
self.ax.spines['bottom'].set_visible(True)
Expand All @@ -75,24 +74,24 @@ def __init__(self, name: str, folder_path: str = None, fig: Figure = None, ax: A
self.ax.set_ylim([0.0, 1.0])

def set_legend(self, *args, fontsize=11, frameon: bool = True, edgecolor='white', framealpha=1.0,
font='Palatino', fontstyle='italic', fontweight='bold', math_fontfamily='cm', **kwargs) -> None:
fontproperties='Palatino', fontstyle='italic', fontweight='bold', math_fontfamily='cm', **kwargs) -> None:
self.ax.legend(*args, frameon=frameon, edgecolor=edgecolor,
framealpha=framealpha, **kwargs)
plt.setp(self.ax.get_legend().get_texts(), fontsize=fontsize,
font=font, fontstyle=fontstyle, fontweight=fontweight, math_fontfamily=math_fontfamily)
fontproperties=fontproperties, fontstyle=fontstyle, fontweight=fontweight, math_fontfamily=math_fontfamily)

def set_axis_label(self, axis: str, text: str, fontsize: int = 12,
font='Palatino', fontweight='bold', math_fontfamily='cm', **kwargs):
fontproperties='Palatino', fontweight='bold', math_fontfamily='cm', **kwargs):
func = getattr(self.ax, f'set_{axis}label')
func(text, fontsize=fontsize, font=font,
func(text, fontsize=fontsize, fontproperties=fontproperties,
fontweight=fontweight, math_fontfamily=math_fontfamily, **kwargs)

def set_title(self, text: str = None, fontsize: int = 16,
font='Palatino', fontweight: str = 'bold', math_fontfamily='cm') -> None:
fontproperties='Palatino', fontweight: str = 'bold', math_fontfamily='cm') -> None:
if text is None:
text = self.name
self.ax.set_title(text, fontsize=fontsize,
font=font, fontweight=fontweight, math_fontfamily=math_fontfamily)
fontproperties=fontproperties, fontweight=fontweight, math_fontfamily=math_fontfamily)

def save(self, path: str = None, folder_path: str = None, ext: str = '.pdf') -> None:
if path is None:
Expand All @@ -107,7 +106,7 @@ def save(self, path: str = None, folder_path: str = None, ext: str = '.pdf') ->

def set_axis_lim(self, axis: str, lim: list[float] = [0.0, 1.0], margin: list[float] = [0.0, 0.0],
piece: int = 10, _format: str = '%.1f', fontsize: int = 11,
font='Palatino', fontweight: str = 'bold', math_fontfamily='cm') -> None:
fontproperties='Palatino', fontweight: str = 'bold', math_fontfamily='cm') -> None:
if _format == 'integer':
_format = '%d'
lim_func = getattr(self.ax, f'set_{axis}lim')
Expand All @@ -124,10 +123,10 @@ def format_func(_str):
ticks = getattr(self.ax, f'get_{axis}ticks')()
set_ticklabels_func = getattr(self.ax, f'set_{axis}ticklabels')
set_ticklabels_func(ticks, fontsize=fontsize,
font=font, fontweight=fontweight, math_fontfamily=math_fontfamily)
fontproperties=fontproperties, fontweight=fontweight, math_fontfamily=math_fontfamily)
format_func(_format)

def curve(self, x: np.ndarray, y: np.ndarray, color: str = 'black', linewidth: int = 2,
def plot(self, x: np.ndarray, y: np.ndarray, color: str = 'black', linewidth: int = 2,
label: str = None, markerfacecolor: str = 'white', linestyle: str = '-', zorder: int = 1, **kwargs) -> Line2D:
# linestyle marker markeredgecolor markeredgewidth markerfacecolor markersize alpha
ax = seaborn.lineplot(x=x, y=y, ax=self.ax, color=color, linewidth=linewidth,
Expand Down Expand Up @@ -172,7 +171,7 @@ def hist(self, x: np.ndarray, bins: list[float] = None, normed: bool = True, **k
return self.ax.hist(x, bins=bins, normed=normed, **kwargs)

def autolabel(self, rects: BarContainer, above: bool = True, fontsize: int = 6,
font='Palatino', fontweight: str = 'bold', math_fontfamily='cm'):
fontproperties='Palatino', fontweight: str = 'bold', math_fontfamily='cm') -> None:
"""Attach a text label above each bar in *rects*, displaying its height."""
for rect in rects:
height = int(rect.get_height())
Expand All @@ -182,4 +181,4 @@ def autolabel(self, rects: BarContainer, above: bool = True, fontsize: int = 6,
xytext=(0, offset), # 3 points vertical offset
textcoords="offset points",
ha='center', va='bottom', fontsize=fontsize,
font=font, fontweight=fontweight, math_fontfamily=math_fontfamily)
fontproperties=fontproperties, fontweight=fontweight, math_fontfamily=math_fontfamily)
2 changes: 1 addition & 1 deletion examples/figure_attack_asr_alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@
y_grid = monotone(y_grid, increase=False)
y_grid = avg_smooth(y_grid, window=100)
y_grid[0] = y_list[0]
fig.curve(x_grid, y_grid, color=color_dict[key])
fig.plot(x_grid, y_grid, color=color_dict[key])
fig.scatter(
x_list, y_list, color=color_dict[key], marker=mark_dict[key], label=attack_mapping[key])
if dataset == 'cifar10':
Expand Down
2 changes: 1 addition & 1 deletion examples/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
fig.set_axis_lim('y', lim=[0, 100], piece=5, margin=[1.0, 1.0],
_format='%d')

fig.curve(x=x, y=y, color=ting_color['red'])
fig.plot(x=x, y=y, color=ting_color['red'])
fig.scatter(x=x, y=y, color=ting_color['red'],
marker='H', label='resnet')
fig.set_legend()
Expand Down

0 comments on commit 69d1dbb

Please sign in to comment.