欢迎您访问程序员文章站！本站旨在为大家提供、分享程序员计算机编程知识！
您现在的位置是: 首页

[python] plt绘图前准备

程序员文章站 2022-05-19 13:32:55
...
import matplotlib.pyplot as plt
# ======================== plot
def set_figure(font_size=10., tick_size=8., ms=7., lw=1.2, fig_w=8., aspect=3 / 4):
    """Configure global Matplotlib rcParams for compact, publication-style figures.

    Every setting is written to ``plt.rcParams`` (or via ``plt.rc``), so it
    applies to figures created *after* this call in the current process;
    figures that already exist keep their size.

    Args:
        font_size: Base font size in points (``font.size``).
        tick_size: Font size of the x/y tick labels.
        ms: Default marker size (``lines.markersize``).
        lw: Stroke width shared by the axes frame, marker edges and major ticks.
        fig_w: Figure width in centimetres.
        aspect: Figure height divided by width. The default 3/4 reproduces the
            previously hard-coded 4:3 layout.

    Raises:
        ValueError: If ``fig_w`` or ``aspect`` is not positive.
    """
    # Fail here rather than later, when a figure is created with a bad size.
    if fig_w <= 0 or aspect <= 0:
        raise ValueError('fig_w and aspect must be positive, got fig_w=%r, aspect=%r'
                         % (fig_w, aspect))
    # Tip: print(plt.rcParams.keys()) lists every configurable key.
    cm_to_inch = 1 / 2.54  # 1 inch = 2.54 cm
    w = fig_w * cm_to_inch  # rcParams expects inches
    h = w * aspect
    plt.rcParams['figure.figsize'] = (w, h)
    plt.rcParams['figure.dpi'] = 300

    # Arial by default; swap `family` for e.g. 'Times New Roman' or 'Helvetica'.
    plt.rc('font', family='Arial', weight='normal', size=float(font_size))

    plt.rcParams.update({
        # Axes frame and marker outlines share one stroke width.
        'axes.linewidth': lw,
        'lines.markeredgewidth': lw,
        'lines.markersize': ms,
        # Major ticks point inward and are 2.5 pt long.
        'xtick.direction': 'in',
        'xtick.labelsize': tick_size,
        'xtick.major.width': lw,
        'xtick.major.size': 2.5,
        'ytick.direction': 'in',
        'ytick.labelsize': tick_size,
        'ytick.major.width': lw,
        'ytick.major.size': 2.5,
        # Legend: black frame at 80% opacity with tight spacing.
        'legend.frameon': True,
        'legend.framealpha': 0.8,
        'legend.fancybox': False,  # False -> square (not rounded) frame corners
        'legend.edgecolor': 'k',
        'legend.columnspacing': 1,  # measured in font-size units
        'legend.labelspacing': 0.2,
        'legend.borderaxespad': 0.5,
        'legend.borderpad': 0.3,
    })


# Marker shapes, handed out to data series in this order.
marker_style = list('o^vsd><h')
# Marker face colours as RGB triples with components in [0, 1].
marker_color = [
    [0.00, 0.45, 0.74],  # blue
    [0.93, 0.69, 0.13],  # yellow
    [0.85, 0.33, 0.10],  # orange-red
    [0.49, 0.18, 0.56],  # purple
    [0.47, 0.67, 0.19],  # green
    [0.30, 0.75, 0.93],  # cyan
    [0.64, 0.08, 0.18],  # dark red / brown
]
# Line dash patterns: solid, dashed, dash-dot, dotted.
line_style = '- -- -. :'.split()


def plot_acc(data, x_lim, y_lim, legend_list, save_path,
             x_name=('10', '30', '50', '100')):
    """Plot accuracy-vs-sample-size curves, optionally save them, then show.

    Args:
        data: Sequence of y-series, one per model; each needs one value per
            entry of ``x_name``.
        x_lim: Limits passed to ``plt.xlim``.
        y_lim: Limits passed to ``plt.ylim``.
        legend_list: Legend labels, in the same order as ``data``.
        save_path: Output path without extension; '.svg' and '.png' are
            appended when the user confirms saving at the prompt.
        x_name: Tick labels for the evenly spaced x positions 1..len(x_name).
            Defaults to the original sample sizes.
    """
    lw = 1.2
    ms = 7
    set_figure(lw=lw, ms=ms, font_size=10, tick_size=10)
    # A fresh figure picks up the figsize just written to rcParams; reusing
    # figure 1 would keep a stale size and stack curves from earlier calls on
    # backends where show() does not close the figure.
    plt.figure()
    x = list(range(1, len(x_name) + 1))

    for i, y in enumerate(data):
        # Cycle styles so more series than available markers/colours do not
        # raise IndexError (marker_color is shorter than marker_style).
        plt.plot(x, y, lw=lw, color='k',
                 marker=marker_style[i % len(marker_style)],
                 markerfacecolor=marker_color[i % len(marker_color)],
                 markeredgecolor='k')

    plt.ylim(y_lim)
    plt.xlim(x_lim)
    plt.ylabel('Accuracy (%)')
    plt.xlabel('Sample size')
    plt.xticks(x, x_name)
    plt.legend(legend_list, ncol=2, loc="lower center").get_frame().set_linewidth(lw)

    order = input('Save fig' + save_path[-4:] + '? Y/N\n')
    if order.strip().lower() == 'y':
        # pad_inches keeps a 0.01-inch margin so the tight bounding box does
        # not clip the frame lines; the cropped image ends up smaller than
        # the configured figure size.
        for ext in ('svg', 'png'):
            plt.savefig(save_path + '.' + ext, dpi=600, format=ext,
                        bbox_inches='tight', pad_inches=0.01)

    plt.show()


def plot_box_acc(data, legend_list, save_path):
    """Draw a box plot of accuracies per model, optionally save it, then show.

    Outliers are hidden (``showfliers=False``) and the mean is marked in
    addition to the median.

    Args:
        data: Anything ``plt.boxplot`` accepts; one dataset per box.
        legend_list: One x tick label per box, or None for numeric labels.
        save_path: Output path without extension; '.svg' and '.png' are
            appended when the user confirms saving at the prompt.

    Raises:
        ValueError: If ``legend_list`` length differs from the number of boxes.
    """
    lw = 1.2
    ms = 7
    set_figure(font_size=10, tick_size=10, lw=lw, ms=ms, fig_w=8)
    # Fresh figure so the rcParams figsize applies and earlier plots are not reused.
    plt.figure()

    artists = plt.boxplot(data, showmeans=True, showfliers=False,
                          meanprops=dict(markersize=6),
                          whiskerprops=dict(linewidth=lw),
                          boxprops=dict(linewidth=lw),
                          capprops=dict(linewidth=lw),     # whisker end caps
                          medianprops=dict(linewidth=lw),  # median line
                          )
    if legend_list is not None:
        n_boxes = len(artists['boxes'])
        if len(legend_list) != n_boxes:
            raise ValueError('legend_list has %d labels but there are %d boxes'
                             % (len(legend_list), n_boxes))
        # Boxes sit at x = 1..n by default. Labelling through xticks avoids
        # boxplot's `labels` keyword, deprecated in Matplotlib 3.9 (renamed
        # `tick_labels`), while still working on older releases.
        plt.xticks(list(range(1, n_boxes + 1)), legend_list)
    plt.ylabel('Accuracy (%)')

    order = input('Save fig' + save_path[-4:] + '? Y/N\n')
    if order.strip().lower() == 'y':
        # 0.01-inch margin keeps the frame lines from being clipped by the
        # tight bounding box; the cropped image is smaller than figsize.
        for ext in ('svg', 'png'):
            plt.savefig(save_path + '.' + ext, dpi=600, format=ext,
                        bbox_inches='tight', pad_inches=0.01)

    plt.show()


def plot_acc_ablation(data, x_lim, y_lim, legend_list, save_path,
                      x_name=('10', '30', '50', '100')):
    """Plot the five ablation-study accuracy curves, optionally save, then show.

    Series 0-2 use the default (solid) line style and series 3-4 are dashed;
    the trailing comments in the style table name the model each row was
    drawn for.

    Args:
        data: At least five y-series, each with one value per entry of ``x_name``.
        x_lim: Limits passed to ``plt.xlim``.
        y_lim: Limits passed to ``plt.ylim``.
        legend_list: Legend labels, in the same order as the plotted series.
        save_path: Output path without extension; '.svg' and '.png' are
            appended when the user confirms saving at the prompt.
        x_name: Tick labels for the evenly spaced x positions 1..len(x_name).
            Defaults to the original sample sizes.
    """
    lw = 1.2
    ms = 7
    set_figure(lw=lw, ms=ms, font_size=10, tick_size=10)
    # Fresh figure so the rcParams figsize applies and earlier curves are not reused.
    plt.figure()
    x = list(range(1, len(x_name) + 1))

    # (series index, linestyle, marker_style index, marker_color index).
    # linestyle None keeps Matplotlib's rc default line style.
    series_styles = (
        (0, None, 0, 0),
        (1, None, 1, 1),
        (2, None, 5, 5),  # DASMN
        (3, '--', 0, 3),  # CNN
        (4, '--', 1, 4),  # ProtoNets
    )
    for idx, ls, m, c in series_styles:
        plt.plot(x, data[idx], linestyle=ls, lw=lw, color='k',
                 marker=marker_style[m], markerfacecolor=marker_color[c],
                 markeredgecolor='k')

    plt.ylim(y_lim)
    plt.xlim(x_lim)
    plt.ylabel('Accuracy (%)')
    plt.xlabel('Sample size')
    plt.xticks(x, x_name)
    plt.legend(legend_list, ncol=2, loc="lower center",
               labelspacing=0.2).get_frame().set_linewidth(lw)

    order = input('Save fig' + save_path[-4:] + '? Y/N\n')
    if order.strip().lower() == 'y':
        # 0.01-inch margin keeps the frame lines from being clipped by the
        # tight bounding box; the cropped image is smaller than figsize.
        for ext in ('svg', 'png'):
            plt.savefig(save_path + '.' + ext, dpi=600, format=ext,
                        bbox_inches='tight', pad_inches=0.01)

    plt.show()


if __name__ == '__main__':
    # No script entry point is defined: running this module directly does
    # nothing; the plotting helpers above must be called from other code.
    pass