Matplotlib 画图

这是我学习matplotlib时记录的一些笔记 ,希望能对你有所帮助😊

Matplotlib 机器学习绘图笔记

在机器学习流程中,数据可视化是不可或缺的一环,它能帮助我们:

  1. **探索性数据分析 (EDA)**:理解数据分布、特征间的关系、发现异常值。

  2. 模型评估与诊断:可视化模型性能,如损失函数变化、预测结果与真实值的差异、ROC 曲线等。

  3. 结果展示:清晰地展示模型结果和发现。

Matplotlib 是 Python 中最基础、最强大的绘图库。掌握它的核心用法对每个机器学习工程师都至关重要。

基础:pyplot 与面向对象接口

Matplotlib 有两种常用的绘图接口:

  1. pyplot 接口:通过 import matplotlib.pyplot as plt 使用,调用如 plt.figure()plt.plot() 等函数,适合快速、简单的绘图。如果没有子图,那么也可以采用这种简便的方式。

  2. 面向对象 (OO) 接口:先创建一个 Figure 和一个或多个 Axes 对象(子图),然后调用这些对象的方法进行绘图,如 ax.plot()ax.set_title()这种方式更灵活,对复杂图形的控制力更强,是官方推荐的方式。

本笔记将主要使用面向对象接口

1
2
3
4
5
6
7
8
9
10
import matplotlib.pyplot as plt
import numpy as np

# 创建一个 Figure 和一个 Axes. fig 是整个画布,ax 是画布上的一个绘图区域(子图)
fig, ax = plt.subplots(figsize=(8, 6))

# 在 ax 上进行绘图...

# 显示图形
plt.show()

1. 散点图 (Scatter Plot)

散点图是探索两个数值变量之间关系的利器。在机器学习中,常用于:

  • 观察两个特征 (feature) 之间的相关性。

  • 在回归任务中,可视化模型的预测值与真实值的对比。

核心方法ax.scatter()

常用参数解释

  • x, y: 一维数组,分别代表横轴和纵轴的数据。

  • s: 点的大小 (scale)。可以是一个数值(所有点同样大小),也可以是一个数组(每个点大小不同)。

  • c: 点的颜色 (color)。可以是一个颜色字符串(如 'r', 'blue'),也可以是一个数值数组,配合 cmap 使用,实现颜色渐变。

  • cmap: 颜色映射表 (Colormap),如 'viridis', 'coolwarm', 'plasma'。当 c 是数值数组时生效。

  • alpha: 透明度,取值范围在 0 (完全透明) 到 1 (完全不透明) 之间,在数据点密集时非常有用。

  • marker: 点的形状,如 'o' (圆形), '^' (三角形), 's' (正方形), 'x'

  • label: 图例标签,需要配合 ax.legend() 使用。

  • edgecolors: 点的边缘颜色。

代码示例:特征关系探索

假设我们有两个特征 feature1feature2,以及一个目标类别 target

Python

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# --- 准备数据 ---
# 生成 150 个样本数据
np.random.seed(42)
feature1 = np.random.randn(150) * 10
feature2 = feature1 * 0.5 + np.random.randn(150) * 5
# 生成 3 个类别
target = np.random.randint(0, 3, 150)

# --- 绘图 ---
fig, ax = plt.subplots(figsize=(10, 7))

# 使用 scatter 方法绘图
# c=target 会根据 target 的值 (0, 1, 2) 自动选择颜色
# s=50 设置点的大小,alpha=0.7 设置透明度
scatter = ax.scatter(feature1, feature2, c=target, s=50, alpha=0.7, cmap='viridis')

# --- 美化图形 ---
ax.set_title('特征1 与 特征2 的关系', fontsize=16)
ax.set_xlabel('特征 1 (Feature 1)', fontsize=12)
ax.set_ylabel('特征 2 (Feature 2)', fontsize=12)

# 添加图例
legend1 = ax.legend(*scatter.legend_elements(), title="类别 (Target)")
ax.add_artist(legend1)

# 添加网格线,透明度 0.5
ax.grid(True, linestyle='--', alpha=0.5)

# 显示图形
plt.show()

效果如下:


2. 折线图 (Line Plot)

折线图常用于展示数据随某个序列(如时间、迭代次数)的变化趋势。在机器学习中,典型应用包括:

  • 绘制模型训练过程中的损失 (Loss) 和准确率 (Accuracy) 曲线。
  • 绘制学习曲线 (Learning Curve) 来诊断过拟合或欠拟合。

核心方法ax.plot()

常用参数解释

  • x, y: 横轴和纵轴的数据。

  • color: 线的颜色。

  • linestyle: 线的样式,如 '-' (实线), '--' (虚线), ':' (点线), '-.' (点划线)。

  • linewidth: 线的宽度。

  • marker: 数据点的标记样式,如 'o', '.', 's'

  • label: 图例标签。

代码示例:绘制训练与验证损失曲线

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# --- 准备数据 ---
# 模拟训练过程的 epoch 和 loss
epochs = np.arange(1, 51)
train_loss = 0.8 / epochs**0.5 + np.random.randn(50) * 0.05
val_loss = 0.9 / epochs**0.5 + np.random.randn(50) * 0.06 + 0.1

# --- 绘图 ---
fig, ax = plt.subplots(figsize=(10, 7))

# 绘制训练损失曲线
ax.plot(epochs, train_loss, color='blue', linestyle='-', marker='o', markersize=4, label='训练损失 (Train Loss)')

# 绘制验证损失曲线
ax.plot(epochs, val_loss, color='red', linestyle='--', marker='s', markersize=4, label='验证损失 (Validation Loss)')

# --- 美化图形 ---
ax.set_title('模型训练过程中的损失变化', fontsize=16)
ax.set_xlabel('迭代次数 (Epoch)', fontsize=12)
ax.set_ylabel('损失 (Loss)', fontsize=12)
ax.grid(True, linestyle='--', alpha=0.6)
ax.legend() # 显示图例

plt.show()

效果如下:


图片示例

  • CVer通常在加载数据集后,会查看一下前几张图片,用做一个示例,那么下面这套流程可以实现
1
2
3
4
5
6
7
temp_dataset = get_dataset(os.path.join(workspace_dir, 'faces'))

images = [temp_dataset[i] for i in range(4)]
grid_img = torchvision.utils.make_grid(images, nrow=4) #将多张图像拼接成网格形式
plt.figure(figsize=(10,10))
plt.imshow(grid_img.permute(1, 2, 0)) # imshow的输入为(H, W, C)
plt.show()


总结与通用美化技巧

无论绘制哪种图形,最后的美化步骤都至关重要。

通用美化函数 (都属于 Axes 对象的方法):

  • ax.set_title(): 设置标题。

  • ax.set_xlabel(), ax.set_ylabel(): 设置 x, y 轴标签。

  • ax.set_xlim(), ax.set_ylim(): 设置 x, y 轴的范围。

  • ax.set_xticks(), ax.set_yticks(): 设置 x, y 轴的刻度位置。

  • ax.set_xticklabels(), ax.set_yticklabels(): 设置刻度对应的标签。

  • ax.legend(): 显示图例。

  • ax.grid(True): 显示网格线。

  • plt.tight_layout(): 在 plt.show() 前调用,自动调整子图参数,使之填充整个图像区域,防止标签重叠。

  • plt.savefig('figure.png', dpi=300): 保存图像到文件,dpi 参数可以设置分辨率。


Matplotlib 画图
http://pzhwuhu.github.io/2025/09/15/Matplotlib 画图/
本文作者
pzhwuhu
发布于
2025年9月15日
更新于
2025年10月16日
许可协议