theano.scan(fn, sequences=None, outputs_info=None, non_sequences=None, n_steps=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None, profile=False, allow_gc=None, strict=False)
主要参数的含义:
fn :一步 scan 所进行的操作sequences :输入的序列outputs_info:前一步输出结果的初始状态non_sequences:非序列参数,即每次迭代都要用到此值,若为矩阵,每次用到此矩阵,和序列不一样,每次只用到序列的一个值n_steps:迭代步数go_backwards:是否从后向前遍历
输出为一个元组 (outputs, updates):
outputs:从初始状态开始,每一步 fn 的输出结果updates:一个字典,用来记录 scan 过程中用到的共享变量更新规则,构造函数的时候,如果需要更新共享变量,将这个变量当作 updates 的参数传入。
fn 是一个函数句柄,对于这个函数句柄,它每一步接受的参数是由 sequences, outputs_info, non_sequence 这三个参数所决定的,默认情况下,在第 k 次迭代时,如果 sequences 和 outputs_info 中给定的值不是字典(dictionary)或者一个字典列表(list of dictionaries),那么
sequences 中的序列 seq 传入 fn 的是 seq[k] 的值
outputs_info 中的序列 output 传入 fn 的是 output[k-1] 的值
fn 的返回值有两部分 (outputs_list, update_dictionary),第一部分将作为序列,传入 outputs 中,与 outputs_info 中的初始输入值的维度一致(如果没有给定 outputs_info ,输出值可以任意。)
第二部分则是更新规则的字典,告诉我们如何对 scan 中使用到的一些共享的变量进行更新:
return [y1_t, y2_t], {x:x+1}
这两部分可以任意,即顺序既可以是 (outputs_list, update_dictionary), 也可以是 (update_dictionary, outputs_list),theano 会根据类型自动识别。
两部分只需要有一个存在即可,另一个可以为空。
具体例子见GitHub/loop with scan