关于tf.reverse_sequence()简述

 更新时间:2020-01-26 11:00:55   作者:佚名   我要评论(0)

tf.reverse_sequence()简述
在看bidirectional_dynamic_rnn()的源码的时候,看到了代码中有调用 reverse_sequence()这一方法,于是又回去看了下这个函数的用法,发现

tf.reverse_sequence()简述

在看bidirectional_dynamic_rnn()的源码的时候,看到了代码中有调用 reverse_sequence()这一方法,于是又回去看了下这个函数的用法,发现还是有点意思的。根据名字就可以能看得出,这个方法主要是用来翻转序列的,就像双线LSTM中在反向传播那里需要从下文往上文处理一样,需要对序列做一个镜像的翻转处理。

先来看一下这个方法的定义:

reverse_sequence(
  input,
  seq_lengths,
  seq_axis=None,
  batch_axis=None,
  name=None,
  seq_dim=None,
  batch_dim=None)

其中input是输入的需要翻转的目标张量,seq_lengths是一个张量;

其元素是input中每一处需要翻转时翻转的长度,在双向LSTM中这个值统一被设为输入语句的长度,代表着整句话都需要被翻转,而实际上张量中的元素值可以是不同的,下面的例子中就可以看出;

seq_axis和seq_dim的关系,在源码中做了如下操作:

seq_axis = deprecation.deprecated_argument_lookup("seq_axis", seq_axis,
                          "seq_dim", seq_dim)

返回中return gen_array_ops.reverse_sequence(..., seq_dim=seq_axis,...),同理,对于batch_axis和batch_dim也是相同的处理。意义上来说,按照官方给出的解释,“此操作首先沿着维度batch_axis对input进行分割,并且对于每个切片 i,将前 seq_lengths 元素沿维度 seq_axis 反转”。实际上通俗来理解,就是对于张量input中的第batch_axis维中的每一个子张量,在这个子张量的第seq_axis维上进行翻转,翻转的长度为 seq_lengths 张量中对应的数值。

举个例子,如果 batch_axis=0,seq_axis=1,则代表我希望每一行为单位分开处理,对于每一行中的每一列进行翻转。相反的,如果 batch_axis=1,seq_axis=0,则是以列为单位,对于每一列的张量,进行相应行的翻转。回头去看双向RNN的源码,就可以理解当time_major这一属性不同时,time_dim 和 batch_dim 这一对组合的取值为什么恰好是相反的了。

写一个简单的测试代码:

a = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
l = tf.constant([1,2,3],tf.int64) # 每一次翻转长度分别为1,2,3.由于a是(3,3)维的,所以l中数值最大只能是3
x = tf.reverse_sequence(a,seq_lengths=l,seq_axis = 0,batch_axis= 1) # 以列为单位进行翻转,翻转的是每一行的元素
y = tf.reverse_sequence(a,seq_lengths=l,seq_axis = 1,batch_axis= 0) # 以行为单位进行翻转,翻转的是每一列的元素
with tf.Session() as sess:
  print(sess.run(x))
  print(sess.run(y))

结果如下:

# 每一列上的元素种类没有发生变化,但是从每一行来看,行的顺序分别翻转了前1,前2,前3个元素
[[1 5 9]
 [4 2 6]
 [7 8 3]]
# 每一行上的元素种类没有发生变化,但是从每一列来看,列的顺序分别翻转了前1,前2,前3个元素
[[1 2 3]
 [5 4 6]
 [9 8 7]]

以上这篇关于tf.reverse_sequence()简述就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:

  • 详解Python3中的Sequence type的使用

相关文章

  • 关于tf.reverse_sequence()简述

    关于tf.reverse_sequence()简述

    tf.reverse_sequence()简述 在看bidirectional_dynamic_rnn()的源码的时候,看到了代码中有调用 reverse_sequence()这一方法,于是又回去看了下这个函数的用法,发现
    2020-01-26
  • 双向RNN:bidirectional_dynamic_rnn()函数的使用详解

    双向RNN:bidirectional_dynamic_rnn()函数的使用详解

    双向RNN:bidirectional_dynamic_rnn()函数的使用详解 先说下为什么要使用到双向RNN,在读一篇文章的时候,上文提到的信息十分的重要,但这些信息是不足以捕捉文章信
    2020-01-26
  • 关于tf.nn.dynamic_rnn返回值详解

    关于tf.nn.dynamic_rnn返回值详解

    函数原型 tf.nn.dynamic_rnn( cell, inputs, sequence_length=None, initial_state=None, dtype=None, parallel_iterations=None, swap_memory=Fal
    2020-01-26
  • python机器学习库xgboost的使用

    python机器学习库xgboost的使用

    1.数据读取 利用原生xgboost库读取libsvm数据 import xgboost as xgb data = xgb.DMatrix(libsvm文件) 使用sklearn读取libsvm数据 from sklearn.
    2020-01-26
  • 浅谈Tensorflow 动态双向RNN的输出问题

    浅谈Tensorflow 动态双向RNN的输出问题

    tf.nn.bidirectional_dynamic_rnn() 函数: def bidirectional_dynamic_rnn( cell_fw, # 前向RNN cell_bw, # 后向RNN inputs, # 输入 sequence_length=No
    2020-01-26
  • python爬取本站电子书信息并入库的实现代码

    python爬取本站电子书信息并入库的实现代码

    入门级爬虫:只抓取书籍名称,信息及下载地址并存储到数据库 数据库工具类:DBUtil.py import pymysql class DBUtils(object): def connDB(self):
    2020-01-26
  • TFRecord格式存储数据与队列读取实例

    TFRecord格式存储数据与队列读取实例

    Tensor Flow官方网站上提供三种读取数据的方法 1. 预加载数据:在Tensor Flow图中定义常量或变量来保存所有数据,将数据直接嵌到数据图中,当训练数据较大时,很消耗
    2020-01-26
  • 使用 tf.nn.dynamic_rnn 展开时间维度方式

    使用 tf.nn.dynamic_rnn 展开时间维度方式

    对于单个的 RNNCell , 使用色的 call 函数进行运算时 ,只是在序列时间上前进了一步 。 如使用 x1、 ho 得到此h1, 通过 x2 、 h1 得到 h2 等 。 tf.nn.dynamic_r
    2020-01-26
  • tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例

    tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例

    ckpt from tensorflow.python import pywrap_tensorflow checkpoint_path = 'model.ckpt-8000' reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_pa
    2020-01-26
  • tensorflow模型继续训练 fineturn实例

    tensorflow模型继续训练 fineturn实例

    解决tensoflow如何在已训练模型上继续训练fineturn的问题。 训练代码 任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。 # -*- coding: u
    2020-01-26

最新评论