What is LSTM learning?


LSTM

LSTM is a special RNN structure that is very popular at present. It effectively mitigates the vanishing/exploding gradient problem of ordinary RNNs and their difficulty in remembering long sequences.

Advantages

By introducing a forget gate, an input gate, and an output gate, the LSTM can selectively remember and forget particular features, which gives it better processing and memory of sequence data.
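To make the idea of a "gate" concrete, here is a minimal illustration (my own sketch, with made-up numbers): a gate is just an elementwise factor in (0, 1) produced by a sigmoid, which decides how much of each feature passes through.

import torch

# A "gate" is a vector of values in (0, 1) obtained from a sigmoid.
# Multiplying a feature vector by the gate keeps some features
# (gate close to 1) and forgets others (gate close to 0).
features = torch.tensor([2.0, -1.0, 0.5])
gate = torch.sigmoid(torch.tensor([4.0, -4.0, 0.0]))  # ~[0.98, 0.02, 0.50]
gated = gate * features  # ~[1.96, -0.02, 0.25]: 1st kept, 2nd forgotten
print(gated)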

Schematic diagram: (LSTM cell structure; image not reproduced here)

Summary formulas (reconstructed from the implementation below):

$$
\begin{aligned}
i_t &= \sigma(x_t U_i + h_{t-1} V_i + b_i) \\
f_t &= \sigma(x_t U_f + h_{t-1} V_f + b_f) \\
g_t &= \tanh(x_t U_c + h_{t-1} V_c + b_c) \\
o_t &= \sigma(x_t U_o + h_{t-1} V_o + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

Simply put, an LSTM has three gates: an input gate, a forget gate, and an output gate.
i_t, f_t, and o_t are the gating coefficients (how far each gate opens), and g_t is the ordinary RNN transformation of the input.
As the formulas show, the LSTM produces two outputs: the cell state c_t and the hidden state h_t.
c_t is the sum of the gated input (i_t * g_t) and the forget-gated previous cell state (f_t * c_{t-1}), i.e. the content the cell itself carries. Passing it through the output gate gives h_t, i.e. what the cell chooses to present to the next unit.
So in practice we usually do not care about the raw cell state itself; we take the state the cell presents,
h_t, as the final output.
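For comparison, PyTorch's built-in nn.LSTM returns the same two kinds of state: the full sequence of hidden states plus the final (h, c) pair. A small example (the sizes are just illustrative):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
x = torch.randn(4, 7, 10)        # (batch, seq_len, input_size)
output, (h_n, c_n) = lstm(x)
print(output.shape)              # torch.Size([4, 7, 20]) - h_t for every step
print(h_n.shape, c_n.shape)      # torch.Size([1, 4, 20]) each - final states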

Implementation

Implement an LSTM manually with PyTorch.

Constructor: define the weight matrices and biases for each gate, following the formulas above.

import math

import torch
import torch.nn as nn

class myLstm(nn.Module):
  def __init__(self,input_sz,hidden_sz):
    super().__init__()
    self.input_size=input_sz
    self.hidden_size=hidden_sz
    #i_t
    self.U_i = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
    self.V_i = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
    self.b_i = nn.Parameter(torch.Tensor(hidden_sz))

    #f_t
    self.U_f = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
    self.V_f = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
    self.b_f = nn.Parameter(torch.Tensor(hidden_sz))

    #c_t
    self.U_c = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
    self.V_c = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
    self.b_c = nn.Parameter(torch.Tensor(hidden_sz))

    #o_t
    self.U_o = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
    self.V_o = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
    self.b_o = nn.Parameter(torch.Tensor(hidden_sz))

    self.init_weights()

  def init_weights(self):
    # Uniform initialization in [-k, k] with k = 1/sqrt(hidden_size),
    # the same scheme nn.LSTM uses by default
    stdv = 1.0 / math.sqrt(self.hidden_size)
    for weight in self.parameters():
      weight.data.uniform_(-stdv, stdv)

  def forward(self, x, init_states=None):
      # x: (batch, seq_len, input_size)
      bs, seq_sz, _ = x.size()
      hidden_seq=[]

      if init_states is None:
        h_t,c_t=(
            torch.zeros(bs,self.hidden_size).to(x.device),
            torch.zeros(bs,self.hidden_size).to(x.device)
        )
      else:
        h_t, c_t = init_states
      for t in range(seq_sz):
        x_t = x[:, t, :]

        i_t = torch.sigmoid(x_t @ self.U_i + h_t @ self.V_i + self.b_i)
        f_t = torch.sigmoid(x_t @ self.U_f + h_t @ self.V_f + self.b_f)
        g_t = torch.tanh(x_t @ self.U_c + h_t @ self.V_c + self.b_c)
        o_t = torch.sigmoid(x_t @ self.U_o + h_t @ self.V_o + self.b_o)
        c_t = f_t * c_t + i_t * g_t
        h_t = o_t * torch.tanh(c_t)

        hidden_seq.append(h_t.unsqueeze(0))

      # Stack per-step hidden states: (seq_len, batch, hidden) -> (batch, seq_len, hidden)
      hidden_seq = torch.cat(hidden_seq, dim=0)
      hidden_seq = hidden_seq.transpose(0, 1).contiguous()
      return hidden_seq, (h_t, c_t)
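
A quick sanity check (a usage sketch, assuming the class above is in scope and the sizes are arbitrary):

# Feed a random batch through the hand-written LSTM cell.
model = myLstm(input_sz=10, hidden_sz=20)
x = torch.randn(4, 7, 10)            # (batch, seq_len, input_size)
hidden_seq, (h_t, c_t) = model(x)
print(hidden_seq.shape)              # torch.Size([4, 7, 20])
print(h_t.shape, c_t.shape)          # torch.Size([4, 20]) each

The output shapes match what nn.LSTM with batch_first=True would give for the hidden sequence, except that the final states here are per-batch (no layer dimension), since this is a single-layer, single-direction cell.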