import numpy as np
import matplotlib
import matplotlib.pyplot as pltvegetables = ["cucumber", "tomato", "lettuce", "asparagus","potato", "wheat", "barley"]
#蔬菜类
farmers = ["Farmer Joe", "Upland Bros.", "Smith Gardening","Agrifun", "Organiculture", "BioGoods Ltd.", "Cornylee Corp."]
#农民类harvest = np.array([[0.8, 2.4, 2.5, 3.9, 0.0, 4.0, 0.0],[2.4, 0.0, 4.0, 1.0, 2.7, 0.0, 0.0],[1.1, 2.4, 0.8, 4.3, 1.9, 4.4, 0.0],[0.6, 0.0, 0.3, 0.0, 3.1, 0.0, 0.0],[0.7, 1.7, 0.6, 2.6, 2.2, 6.2, 0.0],[1.3, 1.2, 0.0, 0.0, 0.0, 3.2, 5.1],[0.1, 2.0, 0.0, 1.4, 0.0, 1.9, 6.3]])
#定义热力图数据fig, ax = plt.subplots()
#将元组分解为fig和ax两个变量
im = ax.imshow(harvest)
#显示图片ax.set_xticks(np.arange(len(farmers)))
#设置x轴刻度间隔
ax.set_yticks(np.arange(len(vegetables)))
#设置y轴刻度间隔
ax.set_xticklabels(farmers)
#设置x轴标签'''
ax.set_yticklabels(vegetables)
#设置y轴标签'''plt.setp(ax.get_xticklabels(), rotation=45, ha="right",rotation_mode="anchor")
#设置标签 旋转45度 ha有三个选择:right,center,left(对齐方式)for i in range(len(vegetables)):for j in range(len(farmers)):text = ax.text(j, i, harvest[i, j],ha="center", va="center", color="w")
'''
画图
j,i:表示坐标值上的值
harvest[i, j]表示内容
ha有三个选择:right,center,left(对齐方式)
va有四个选择:'top', 'bottom', 'center', 'baseline'(对齐方式)
color:设置颜色
'''ax.set_title("Harvest of local farmers (in tons/year)")
#设置题目
fig.tight_layout() #自动调整子图参数,使之填充整个图像区域。
plt.show() #图像展示
def heatmap(data, row_labels, col_labels, ax=None,cbar_kw={}, cbarlabel="", **kwargs):"""从一个numpy数组和两个标签列表创建一个热图。data形状为(N,M)的2D numpy数组。row_labels长度为N且带有行标签的列表或数组。col_labels长度为M的列表或数组,带有列的标签。ax绘制热图的`matplotlib.axes.Axes`实例。 如果未提供,请使用当前轴或创建一个新轴。 (可选的。)cbar_kw带有`matplotlib.Figure.colorbar`参数的字典。 可选的。cbarlabel颜色条的标签。 可选的。**kwargs所有其他参数都转发给“imshow”。"""if not ax:ax = plt.gca()#如果不在ax,挪动坐标轴im = ax.imshow(data, **kwargs)#画图#创造彩条cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)'''im:一个可以映射颜色的对象ax=ax:指示im得到的对象在哪里展示(整体)'''cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")'''设置y轴标签'''ax.set_xticks(np.arange(data.shape[1]))ax.set_yticks(np.arange(data.shape[0]))#设置x,y轴刻度间隔ax.set_xticklabels(col_labels)ax.set_yticklabels(row_labels)#设置横,纵轴#让水平轴标签显示在顶部ax.tick_params(top=True, bottom=False,labeltop=True, labelbottom=False)#旋转刻度线标签并设置其对齐方式。plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",rotation_mode="anchor")#关闭spines并创建白色网格。#spines是连接轴刻度标记的线,而且标明了数据区域的边界for edge, spine in ax.spines.items():spine.set_visible(False)ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)##设置x,y轴刻度间隔ax.grid(which="minor", color="w", linestyle='-', linewidth=3)#设置边框主刻度线,颜色为白色,线条格式为'-',线的宽度为3ax.tick_params(which="minor", bottom=False, left=False)#设置主刻度线,参数bottom, top, left, right的值为布尔值,分别代表设置绘图区四个边框线上的的刻度线是否显示return im, cbardef annotate_heatmap(im, data=None, valfmt="{x:.2f}",textcolors=["black", "white"],threshold=None, **textkw):'''im要标记的AxesImage。data用于注释的数据。 如果为None,则使用图像数据。 (可选的。)valfmt热图内注释的格式。 这应该使用字符串格式方法,例如 “ $ {x:.2f}”,或成为`matplotlib.ticker.Formatter`。 (可选的。)textcolors两种颜色规格的列表或数组。 第一个代表值低于阈值,第二个代表高于阈值的值。 (可选的。)threshold以数据单位表示的值,根据该值,textcolors中的颜色是应用。 如果为None(默认),则将颜色图的中间用作分离。( 可选的。)**kwargs所有其他参数都转发给用于创建的每个`text`调用。文字标签。'''if not isinstance(data, (list, np.ndarray)):data = im.get_array()#保证data是一个list类型#将阈值标准化为图像颜色范围。if threshold is not None:threshold = im.norm(threshold)else:threshold = im.norm(data.max())/2.#将默认对齐方式设置为居中,但允许将其设置为居中#被textkw覆盖。kw = dict(horizontalalignment="center",verticalalignment="center")kw.update(textkw)#获取格式化程序(如果提供了字符串)if isinstance(valfmt, str):valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)#给热力图标注文本设置格式#遍历数据并为每个“pixel”创建一个“Text”。#根据数据更改文本的颜色。texts = []for i in range(data.shape[0]):for j in range(data.shape[1]):kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)texts.append(text)return texts
fig, ax = plt.subplots()
#将元组分解为fig和ax两个变量 im, cbar = heatmap(harvest, vegetables, farmers, ax=ax,cmap="YlGn", cbarlabel="harvest [t/year]")
"""从一个numpy数组和两个标签列表创建一个热图。data形状为(N,M)的2D numpy数组。row_labels长度为N且带有行标签的列表或数组。col_labels长度为M的列表或数组,带有列的标签。ax绘制热图的`matplotlib.axes.Axes`实例。 如果未提供,请使用当前轴或创建一个新轴。 (可选的。)cbar_kw带有`matplotlib.Figure.colorbar`参数的字典。 可选的。cbarlabel颜色条的标签。 可选的。**kwargs所有其他参数都转发给“imshow”。
"""
texts = annotate_heatmap(im, valfmt="{x:.1f} t")
'''im要标记的AxesImage。data用于注释的数据。 如果为None,则使用图像数据。 (可选的。)valfmt热图内注释的格式。 这应该使用字符串格式方法,例如 “ $ {x:.2f}”,或成为`matplotlib.ticker.Formatter`。 (可选的。)textcolors两种颜色规格的列表或数组。 第一个代表值低于阈值,第二个代表高于阈值的值。 (可选的。)threshold以数据单位表示的值,根据该值,textcolors中的颜色是应用。 如果为None(默认),则将颜色图的中间用作分离。( 可选的。)**kwargs所有其他参数都转发给用于创建的每个`text`调用。文字标签。'''fig.tight_layout() #自动调整子图参数,使之填充整个图像区域。
plt.show() #图像展示
np.random.seed(19680801)fig, ((ax, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6))#使用不同的字体大小和颜色图复制上面的示例。im, _ = heatmap(harvest, vegetables, farmers, ax=ax,cmap="Wistia", cbarlabel="harvest [t/year]")
annotate_heatmap(im, valfmt="{x:.1f}", size=7)#创建一些新数据,为imshow(vmin)提供更多参数,
#在注释上使用整数格式,并提供一些颜色。
data = np.random.randint(2, 100, size=(7, 7))
y = ["Book {}".format(i) for i in range(1, 8)]
x = ["Store {}".format(i) for i in list("ABCDEFG")]
im, _ = heatmap(data, y, x, ax=ax2, vmin=0,cmap="magma_r", cbarlabel="weekly sold copies")
annotate_heatmap(im, valfmt="{x:d}", size=7, threshold=20,textcolors=["red", "white"])#有时甚至数据本身也是分类的。 在这里我们使用
#:class:`matplotlib.colors.BoundaryNorm`将数据放入类中
#并使用它为图着色,也可以获取类
#来自一组类的标签。
data = np.random.randn(6, 6)
y = ["Prod. {}".format(i) for i in range(10, 70, 10)]
x = ["Cycle {}".format(i) for i in range(1, 7)]qrates = np.array(list("ABCDEFG"))
norm = matplotlib.colors.BoundaryNorm(np.linspace(-3.5, 3.5, 8), 7)
fmt = matplotlib.ticker.FuncFormatter(lambda x, pos: qrates[::-1][norm(x)])im, _ = heatmap(data, y, x, ax=ax3,cmap=plt.get_cmap("PiYG", 7), norm=norm,cbar_kw=dict(ticks=np.arange(-3, 4), format=fmt),cbarlabel="Quality Rating")annotate_heatmap(im, valfmt=fmt, size=9, fontweight="bold", threshold=-1,textcolors=["red", "black"])#我们可以很好地绘制一个相关矩阵。 由于这受-1和1约束,
#我们将它们用作vmin和vmax。 我们可能还会删除前导零并隐藏
#通过使用a对角元素(全为1)
#:class:`matplotlib.ticker.FuncFormatter`.corr_matrix = np.corrcoef(np.random.rand(6, 5))
im, _ = heatmap(corr_matrix, vegetables, vegetables, ax=ax4,cmap="PuOr", vmin=-1, vmax=1,cbarlabel="correlation coeff.")def func(x, pos):return "{:.2f}".format(x).replace("0.", ".").replace("1.00", "")annotate_heatmap(im, valfmt=matplotlib.ticker.FuncFormatter(func), size=7)plt.tight_layout()
plt.show()
于是就大功告成啦!