"""[python] plt绘图前准备 (preparing matplotlib before plotting)

Source: 程序员文章站, 2022-05-19 13:32:55
...
"""
import matplotlib.pyplot as plt
# ======================== plot
def set_figure(font_size=10., tick_size=8., ms=7., lw=1.2, fig_w=8.,
               aspect=3 / 4, dpi=300, font_family='Arial'):
    """Apply a compact, publication-style look to subsequent matplotlib figures.

    Every setting is written into the global ``plt.rcParams``, so it affects
    all figures created after this call in the same process.

    Args:
        font_size: base font size in points (``font.size``).
        tick_size: font size of the x/y tick labels.
        ms: default marker size (``lines.markersize``).
        lw: width used for the axes frame, major tick marks and marker edges.
        fig_w: figure width in centimetres.
        aspect: figure height divided by width (default 3/4).
        dpi: value stored in ``figure.dpi`` (default 300).
        font_family: font family name (default 'Arial'; 'Times New Roman'
            is a common alternative).
    """
    # Tip: print(plt.rcParams.keys()) lists every available setting.
    cm_to_inch = 1 / 2.54  # 1 inch = 2.54 cm
    width = fig_w * cm_to_inch  # matplotlib sizes figures in inches
    height = width * aspect
    plt.rcParams.update({
        # Figure size (inches) and resolution.
        'figure.figsize': (width, height),
        'figure.dpi': dpi,
        # Global font; same effect as plt.rc('font', family=..., weight=..., size=...).
        'font.family': font_family,
        'font.weight': 'normal',
        'font.size': float(font_size),
        # Axes frame and marker edges share the same line width.
        'axes.linewidth': lw,
        'lines.markeredgewidth': lw,
        'lines.markersize': ms,
        # Ticks point inward.
        'xtick.direction': 'in',
        'xtick.labelsize': tick_size,
        'xtick.major.width': lw,
        'xtick.major.size': 2.5,  # tick length
        'ytick.direction': 'in',
        'ytick.labelsize': tick_size,
        'ytick.major.width': lw,
        'ytick.major.size': 2.5,
        # Legend: black, square-cornered, slightly transparent frame.
        'legend.frameon': True,
        'legend.framealpha': 0.8,  # frame/background opacity
        'legend.fancybox': False,  # False -> square corners instead of rounded
        'legend.edgecolor': 'k',
        # The spacing values below are in units of the font size.
        'legend.columnspacing': 1,
        'legend.labelspacing': 0.2,
        'legend.borderaxespad': 0.5,
        'legend.borderpad': 0.3,
    })
# Marker shapes, one per series: circle, triangle up, triangle down, square,
# thin diamond, triangle right, triangle left, hexagon.
marker_style = list('o^vsd><h')
# Marker face colours as RGB triples in [0, 1] (these appear to be MATLAB's
# default line colour order, rounded to two decimals).
marker_color = [
    [0.00, 0.45, 0.74],  # blue
    [0.93, 0.69, 0.13],  # yellow
    [0.85, 0.33, 0.10],  # orange-red
    [0.49, 0.18, 0.56],  # purple
    [0.47, 0.67, 0.19],  # green
    [0.30, 0.75, 0.93],  # cyan / light blue
    [0.64, 0.08, 0.18],  # dark red (maroon)
]
# Matplotlib line styles: solid, dashed, dash-dot, dotted.
line_style = '- -- -. :'.split()
def plot_acc(data, x_lim, y_lim, legend_list, save_path, x_name=None):
    """Plot accuracy against sample size (one marker line per series).

    The figure is shown at the end; before that the user is asked on stdin
    whether to save it.

    Args:
        data: sequence of series; each series holds one accuracy value per
            x tick (four values with the default ``x_name``).
        x_lim: x-axis limits passed to ``plt.xlim``.
        y_lim: y-axis limits passed to ``plt.ylim``.
        legend_list: legend labels, in the same order as ``data``.
        save_path: output path without extension; ``.svg`` and ``.png``
            files are written when the user answers ``y``/``Y``.
        x_name: x tick labels; defaults to ``['10', '30', '50', '100']``.
            Ticks are placed at positions 1 .. len(x_name).
    """
    if x_name is None:
        x_name = ['10', '30', '50', '100']
    lw = 1.2
    ms = 7
    set_figure(lw=lw, ms=ms, font_size=10, tick_size=10)
    # Start from a clean figure 1. If figure 1 from an earlier call is still
    # open (e.g. show() did not close it), reusing it would draw over the old
    # curves and keep the old figure size instead of the one just configured.
    plt.close(1)
    plt.figure(1)
    x = list(range(1, len(x_name) + 1))
    for i, series in enumerate(data):
        # Wrap the indices so that more series than palette entries cycle
        # through the styles instead of raising IndexError.
        plt.plot(x, series, lw=lw, color='k',
                 marker=marker_style[i % len(marker_style)],
                 markerfacecolor=marker_color[i % len(marker_color)],
                 markeredgecolor='k')
    # Axis limits, labels and ticks.
    plt.ylim(y_lim)
    plt.xlim(x_lim)
    plt.ylabel('Accuracy (%)')
    plt.xlabel('Sample size')
    plt.xticks(x, x_name)
    plt.legend(legend_list, ncol=2, loc="lower center").get_frame().set_linewidth(lw)
    # The prompt shows only the last 4 characters of save_path.
    order = input('Save fig' + save_path[-4:] + '? Y/N\n')
    if order.strip().lower() == 'y':
        # pad_inches=0.01 keeps a thin margin so the axes frame is not
        # clipped; bbox_inches='tight' crops, so the saved image ends up
        # smaller than the configured figure size.
        plt.savefig(save_path + '.svg', dpi=600, format='svg', bbox_inches='tight', pad_inches=0.01)
        plt.savefig(save_path + '.png', dpi=600, format='png', bbox_inches='tight', pad_inches=0.01)
    plt.show()
def plot_box_acc(data, legend_list, save_path):
    """Draw a box plot of accuracy, one box per entry of ``data``.

    Mean markers are shown and outliers (fliers) are hidden. The figure is
    shown at the end; before that the user is asked on stdin whether to
    save it.

    Args:
        data: values passed straight to ``plt.boxplot`` (one box per group).
        legend_list: x tick label for each box, in the same order as ``data``.
        save_path: output path without extension; ``.svg`` and ``.png``
            files are written when the user answers ``y``/``Y``.
    """
    lw = 1.2
    ms = 7
    set_figure(font_size=10, tick_size=10, lw=lw, ms=ms, fig_w=8)
    # Start from a clean figure 1 so a figure left open by an earlier call is
    # not drawn over (and its stale size is not reused).
    plt.close(1)
    plt.figure(1)
    # NOTE: Matplotlib 3.9 renamed boxplot's `labels` to `tick_labels`;
    # `labels` is kept so older Matplotlib versions keep working.
    plt.boxplot(data, showmeans=True, showfliers=False,
                labels=legend_list,
                meanprops=dict(markersize=6),
                whiskerprops=dict(linewidth=lw),
                boxprops=dict(linewidth=lw),
                capprops=dict(linewidth=lw),  # whisker end caps
                medianprops=dict(linewidth=lw),  # median line
                )
    plt.ylabel('Accuracy (%)')
    # The prompt shows only the last 4 characters of save_path.
    order = input('Save fig' + save_path[-4:] + '? Y/N\n')
    if order.strip().lower() == 'y':
        # pad_inches=0.01 keeps a thin margin so the axes frame is not
        # clipped; bbox_inches='tight' crops, so the saved image ends up
        # smaller than the configured figure size.
        plt.savefig(save_path + '.svg', dpi=600, format='svg', bbox_inches='tight', pad_inches=0.01)
        plt.savefig(save_path + '.png', dpi=600, format='png', bbox_inches='tight', pad_inches=0.01)
    plt.show()
def plot_acc_ablation(data, x_lim, y_lim, legend_list, save_path):
    """Plot five accuracy-vs-sample-size curves with fixed per-series styles.

    The first three series are drawn solid, the last two dashed. The figure
    is shown at the end; before that the user is asked on stdin whether to
    save it.

    Args:
        data: at least five series (data[0] .. data[4]; extra entries are
            ignored), each with one value per x tick (labels 10, 30, 50, 100).
        x_lim: x-axis limits passed to ``plt.xlim``.
        y_lim: y-axis limits passed to ``plt.ylim``.
        legend_list: legend labels, in the same order as ``data``.
        save_path: output path without extension; ``.svg`` and ``.png``
            files are written when the user answers ``y``/``Y``.
    """
    lw = 1.2
    ms = 7
    set_figure(lw=lw, ms=ms, font_size=10, tick_size=10)
    # Start from a clean figure 1 so a figure left open by an earlier call is
    # not drawn over (and its stale size is not reused).
    plt.close(1)
    plt.figure(1)
    x = [1, 2, 3, 4]
    x_name = ['10', '30', '50', '100']
    # (line style, index into marker_style, index into marker_color) for
    # data[0] .. data[4]. Indexing data[k] directly keeps the original
    # IndexError when fewer than five series are supplied.
    series_styles = [
        ('-', 0, 0),
        ('-', 1, 1),
        ('-', 5, 5),   # DASMN
        ('--', 0, 3),  # CNN
        ('--', 1, 4),  # ProtoNets
    ]
    for k, (style, marker_idx, color_idx) in enumerate(series_styles):
        plt.plot(x, data[k], linestyle=style, lw=lw, color='k',
                 marker=marker_style[marker_idx],
                 markerfacecolor=marker_color[color_idx],
                 markeredgecolor='k')
    # Axis limits, labels and ticks.
    plt.ylim(y_lim)
    plt.xlim(x_lim)
    plt.ylabel('Accuracy (%)')
    plt.xlabel('Sample size')
    plt.xticks(x, x_name)
    plt.legend(legend_list, ncol=2, loc="lower center",
               labelspacing=0.2).get_frame().set_linewidth(lw)
    # The prompt shows only the last 4 characters of save_path.
    order = input('Save fig' + save_path[-4:] + '? Y/N\n')
    if order.strip().lower() == 'y':
        # pad_inches=0.01 keeps a thin margin so the axes frame is not
        # clipped; bbox_inches='tight' crops, so the saved image ends up
        # smaller than the configured figure size.
        plt.savefig(save_path + '.svg', dpi=600, format='svg', bbox_inches='tight', pad_inches=0.01)
        plt.savefig(save_path + '.png', dpi=600, format='png', bbox_inches='tight', pad_inches=0.01)
    plt.show()
if __name__ == "__main__":
    pass  # Nothing runs directly; import this module and call the plot_* helpers.
# 上一篇 (previous post): plt.show仍然不能绘图 (plt.show still does not draw the figure)