变长循环神经网络 [Pytorch]

LSTM
在使用循环神经网络时,经常碰到可变长数据,就是每一个样本的时间步是不一样的。这里总结一下pytorch里面的处理方法。

1
2
3
4
5
6
'''导入相关包'''
import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import DataLoader
import torch.utils.data as data
1
2
3
4
5
6
7
8
9
10
11
12
'''
人工给出一小部分数据做示例
这里,一共包含7个数据样本,每个样本的长度不一样,分别是:7,6,5,...,1
数据的维度为1维
'''
train_x = [torch.Tensor([1, 1, 1, 1, 1, 1, 1]),
torch.Tensor([2, 2, 2, 2, 2, 2]),
torch.Tensor([3, 3, 3, 3, 3]),
torch.Tensor([4, 4, 4, 4]),
torch.Tensor([5, 5, 5]),
torch.Tensor([6, 6]),
torch.Tensor([7])]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
'''
长度不一样,在代码中没法处理,因此必须将短的数据增长到与最长的数据长度一样
一般都是补0,在pytorch里面有专门的函数: pad_sequence
这样一来,数据将变为:
[torch.Tensor([1, 1, 1, 1, 1, 1, 1]),
torch.Tensor([2, 2, 2, 2, 2, 2, 0]),
torch.Tensor([3, 3, 3, 3, 3, 0, 0]),
torch.Tensor([4, 4, 4, 4, 0, 0, 0]),
torch.Tensor([5, 5, 5, 0, 0, 0, 0]),
torch.Tensor([6, 6, 0, 0, 0, 0, 0]),
torch.Tensor([7, 0, 0, 0, 0, 0, 0])]
这里我们只是给出例子,看看pad_sequence的作用,为了能够训练模型,需要将其封装到数据集Dataset里面,方便DataLoader
'''
x = rnn_utils.pad_sequence(train_x, batch_first=True)
print(x)
tensor([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 2.,  2.,  2.,  2.,  2.,  2.,  0.],
        [ 3.,  3.,  3.,  3.,  3.,  0.,  0.],
        [ 4.,  4.,  4.,  4.,  0.,  0.,  0.],
        [ 5.,  5.,  5.,  0.,  0.,  0.,  0.],
        [ 6.,  6.,  0.,  0.,  0.,  0.,  0.],
        [ 7.,  0.,  0.,  0.,  0.,  0.,  0.]])
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
'''
封装到数据集MyData
实际上是定义了一个collate_fn函数,它是用来控制在DataLoader中,load数据的时候,返回数据的格式
'''

class MyData(data.Dataset):
def __init__(self, data_seq):
self.data_seq = data_seq

def __len__(self):
return len(self.data_seq)

def __getitem__(self, idx):
return self.data_seq[idx]


def collate_fn(data):
data.sort(key=lambda x: len(x), reverse=True) #将数据按数据长度从大到小排列
data_length = [len(sq) for sq in data]
data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0)
return data.unsqueeze(-1), data_length #这里的unsqueeze是为了将数据从7x7变为7x7x1,符合模型的输入数据格式,共7个样本,维度为1,pad之后的"长度"为7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
'''
测试
'''
if __name__=='__main__':

data = MyData(train_x) #将原始数据通过MyData封装

data_loader = DataLoader(data, batch_size=3, shuffle=True,collate_fn=collate_fn) #封装进DataLoader,通过collate_fn控制返回数据的格式,即对每一个样本都进行pad

batch_x, batch_x_len = iter(data_loader).next()

'''
这里batch_size=3,因此batch_x的格式为:
[[[ 1.], [1.], [1.], [1.], [1.], [1.], [1.]],
[[ 3.], [3.], [3.], [3.], [3.], [0.], [0.]],
[[ 6.], [6.], [0.], [0.], [0.], [0.], [0.]]]

为了不让后面的0参与运算,即运算完所有的非0数据后就停止
pytorch里面需要对batch_x进行pack,即压缩
压缩后的数据格式,以及为设么这么压缩,和pytorch里面循环神经网络的工作原理有关,这里不做详细介绍
'''

batch_x_pack = rnn_utils.pack_padded_sequence(batch_x, batch_x_len, batch_first=True)


'''
构建模型,初始化h0,c0
'''
net = nn.LSTM(1, 10, 2, batch_first=True)
h0 = torch.rand(2, 3, 10)
c0 = torch.rand(2, 3, 10)

'''通过LSTM计算'''
out, (h1, c1) = net(batch_x_pack, (h0, c0))

'''这里,模型输出的数据和输入的数据格式一样,都是被压缩过的,这里需要将其还原成正常的矩阵形式,使用pad_packed_sequence函数'''
out_pad, out_len = rnn_utils.pad_packed_sequence(out, batch_first=True)

print(out_pad)
print('END')
tensor([[[ 0.0490,  0.1100,  0.0750,  0.0998,  0.3622, -0.0891,  0.1562, 0.0122, -0.1227,  0.0537],
         [ 0.0113,  0.0290,  0.0375,  0.1051,  0.2836, -0.1209,  0.1458, -0.0848, -0.1296, -0.0788],
         [ 0.0048, -0.0075,  0.0173,  0.0673,  0.2713, -0.1431,  0.1392, -0.1227, -0.1214, -0.1425],
         [ 0.0102, -0.0246,  0.0090,  0.0291,  0.2600, -0.1579,  0.1342, -0.1392, -0.1135, -0.1720],
         [ 0.0177, -0.0328,  0.0058, -0.0016,  0.2524, -0.1678,  0.1313, -0.1480, -0.1091, -0.1863],
         [ 0.0240, -0.0371,  0.0048, -0.0244,  0.2475, -0.1742,  0.1300, -0.1542, -0.1074, -0.1939],
         [ 0.0286, -0.0398,  0.0047, -0.0408,  0.2445, -0.1782,  0.1298, -0.1592, -0.1071, -0.1983]],

        [[ 0.2161, -0.0114,  0.1041,  0.0859,  0.0486, -0.0040,  0.2124, 0.1901,  0.1633,  0.1281],
         [ 0.1089, -0.0357,  0.0969,  0.0758,  0.1588, -0.1085,  0.1722, 0.0266, -0.0036, -0.0442],
         [ 0.0438, -0.0525,  0.0566,  0.0414,  0.2152, -0.1581,  0.1527, -0.0615, -0.0681, -0.1262],
         [ 0.0199, -0.0631,  0.0303,  0.0025,  0.2376, -0.1802,  0.1418, -0.1058, -0.0890, -0.1643],
         [ 0.0137, -0.0696,  0.0154, -0.0301,  0.2450, -0.1906,  0.1362, -0.1309, -0.0954, -0.1821],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000,  0.0000]],

        [[-0.0518,  0.0285,  0.1449,  0.1145,  0.2673, -0.0603,  0.1157, .1086, -0.0482, -0.0614],
         [-0.0578, -0.0089,  0.0664,  0.0769,  0.2777, -0.1151,  0.1328, -0.0094, -0.0666, -0.1451],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, 0.0000,  0.0000,  0.0000]]])
END


~赞~

么么哒,请我喝杯咖啡吧~

支付宝
微信