一、环境
TensorFlow API r1.14
CUDA 9.0 V9.0.176
Python 3.7.3
二、官方说明
一个或多个矩阵的奇异值分解
https://tensorflow.google.cn/api_docs/python/tf/linalg/svd
tf.linalg.svd(tensor,full_matrices=False,compute_uv=True,name=None
)
参数:
tensor:形状为 […, M, N] 的张量。设 P 为 M 和 N 的最小值
full_matrices:布尔型参数,True 表示计算全尺寸的 u 和 v;默认为 False 表示仅计算前导 P 个奇异向量。如果 compute_uv 是 False,则忽略。
compute_uv:布尔型参数,默认为 True 表示计算左、右奇异向量并分别返回到 u、v 中,False 表示计算奇异值,会大幅度提高速度。
name:字符型参数,可选参数,设置操作的名称。
返回:
s:奇异值,形状为 […, P],数值会倒序排列,所以 s[…, 0] 是最大值,s[…, 1] 是次大值。
u:左奇异向量,如果参数 full_matrix 为默认值 False,则 u 的形状为 […, M, P];如果参数 full_matrix 为True,则 u 的形状为 […, M, M]。如果 compute_uv 为 False,则不返回 u 值;
v:右奇异向量,如果参数 full_matrix 为默认值 True,则 v 的形状为 […, N, P]。如果参数 full_matrix 为True,则 u 的形状为 […, N, N]。如果 compute_uv 为 False,则不返回 v 值;
三、实例
>>> matrix = tf.constant([1,2,3,4,5,6,7,8,9], shape=[3,3])
>>> matrix_float = tf.dtypes.cast(matrix, tf.float32) # 将矩阵的类型转换为 tensorflow 支持的 float32 形式,否则会报错
>>> s, u, v = tf.linalg.svd(matrix_float)
>>> s
<tf.Tensor: id=7, shape=(3,), dtype=float32, numpy=array([1.6848103e+01, 1.0683696e+00, 2.8763120e-07], dtype=float32)>
>>> u
<tf.Tensor: id=8, shape=(3, 3), dtype=float32, numpy=
array([[ 0.21483716, 0.8872305 , -0.40824857],[ 0.5205872 , 0.24964423, 0.8164965 ],[ 0.8263376 , -0.3879429 , -0.4082481 ]], dtype=float32)>
>>> v
<tf.Tensor: id=9, shape=(3, 3), dtype=float32, numpy=
array([[ 0.47967106, -0.77669096, 0.40824836],[ 0.5723676 , -0.07568647, -0.81649655],[ 0.6650643 , 0.62531805, 0.40824822]], dtype=float32)>
需要注意的是必须将矩阵的数据类型设为浮点类型,否则会报错:
>>> s, u, v = tf.linalg.svd(matrix)
Traceback (most recent call last):File "<stdin>", line 1, in <module>File "/******/Anaconda/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/linalg_ops.py", line 418, in svdtensor, compute_uv=compute_uv, full_matrices=full_matrices, name=name)File "/******/Anaconda/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/gen_linalg_ops.py", line 2265, in svd_six.raise_from(_core._status_to_exception(e.code, message), None)File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InternalError: Could not find valid device for node.
Node: {
{
node Svd}}
All kernels registered for op Svd :device='CPU'; T in [DT_COMPLEX128]device='CPU'; T in [DT_COMPLEX64]device='CPU'; T in [DT_DOUBLE]device='CPU'; T in [DT_FLOAT][Op:Svd]