import numpy
as np
import matplotlib.pyplot
as plt
import matplotlib.image
as mpimg
import caffe
%matplotlib inline
# Notebook-wide matplotlib defaults, applied in one call.
plt.rcParams.update({
    'figure.figsize': (8, 8),          # default figure size (width, height)
    'image.interpolation': 'nearest',  # show raw pixels, no smoothing
    'image.cmap': 'gray',              # default colormap for images
})
def show_data(data, padsize=1, padval=0):
    """Display a stack of images as a single, roughly square grid.

    Takes an array of shape (n, height, width) or (n, height, width, 3),
    rescales its values to the range [0, 1] and tiles the n images into a
    grid of about sqrt(n) by sqrt(n) cells, which is drawn with matplotlib.

    Args:
        data: array-like of shape (n, height, width) or
            (n, height, width, 3). The caller's array is never modified.
        padsize: number of padding pixels appended after the last row and
            the last column of every tile.
        padval: value used for that padding and for the unused grid cells.
            It is applied after rescaling, i.e. on the [0, 1] scale.

    Returns:
        None. The grid is drawn on a new matplotlib figure.
    """
    # Normalise a floating-point copy. Doing this in place would rescale the
    # caller's array (here: live network blobs/weights) and would also fail
    # for integer input, because in-place true division cannot store floats.
    data = np.array(data, dtype=float)
    data -= data.min()
    peak = data.max()
    if peak > 0:
        # A constant input has peak == 0; leave it at zero instead of
        # dividing by zero and filling the picture with NaNs.
        data /= peak

    # Smallest square grid that can hold all n tiles.
    n = int(np.ceil(np.sqrt(data.shape[0])))

    # Pad the tile axis up to n*n entries and add spacing after each tile's
    # rows and columns; any trailing (colour) axis is left untouched.
    padding = (
        ((0, n ** 2 - data.shape[0]), (0, padsize), (0, padsize))
        + ((0, 0),) * (data.ndim - 3)
    )
    data = np.pad(data, padding, mode='constant',
                  constant_values=(padval, padval))

    # (n*n, H, W, ...) -> (n, n, H, W, ...) -> (n, H, n, W, ...) so that
    # merging neighbouring axes yields one (n*H, n*W, ...) image.
    axis_order = (0, 2, 1, 3) + tuple(range(4, data.ndim + 1))
    data = data.reshape((n, n) + data.shape[1:]).transpose(axis_order)
    data = data.reshape(
        (n * data.shape[1], n * data.shape[3]) + data.shape[4:]
    )

    plt.figure()
    plt.imshow(data, cmap='gray')
    plt.axis('off')
# Output of the 'conv1' blob for entry 0 of the blob's first axis
# (presumably the first image of the batch -- confirm against how `net`
# was fed).
print(net.blobs['conv1'].data[0].shape)
show_data(net.blobs['conv1'].data[0])

# Parameters of 'conv1', entry 0 (presumably the filter weights, laid out
# as (filters, channels, height, width) -- TODO confirm).
conv1_weights = net.params['conv1'][0].data
print(conv1_weights.shape)
# Collapse the two leading axes so every 2-D kernel becomes its own tile.
# The size is derived from the array itself rather than hard-coding
# 32 * 3 kernels of 5 x 5; for a (32, 3, 5, 5) array the result is identical.
show_data(conv1_weights.reshape((-1,) + conv1_weights.shape[2:]))