#lstm #pytorch #candle

candle-birnn

Implement PyTorch LSTM and bidirectional LSTM with Candle

5 releases

0.2.3 Sep 25, 2024
0.2.2 Aug 17, 2024
0.2.1 Aug 13, 2024
0.2.0 Aug 12, 2024
0.1.0 Aug 11, 2024


Custom license

42KB
295 lines

Contains bi_lstm_test.pt (Zip file, 29KB) and lstm_test.pt (Zip file, 16KB)

Candle BiRNN

Implements PyTorch LSTM inference in Candle, including inference for bidirectional LSTMs.
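For reference, the computation such a port has to reproduce can be sketched in plain Python. This is a minimal sketch, not the crate's Rust code. It assumes PyTorch's documented gate order (input, forget, cell, output) packed row-wise into `weight_ih_l0` and `weight_hh_l0`, a single batch item, and zero initial states. The helper names (`lstm_cell`, `lstm_sequence`, `bi_lstm_sequence`) are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]  # row-major: one inner list per row

# Parameters of one direction, in PyTorch naming order (assumed layout):
# (weight_ih, weight_hh, bias_ih, bias_hh)
Params = Tuple[Matrix, Matrix, Vector, Vector]


def sigmoid(x: float) -> float:
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def matvec(w: Matrix, x: Vector) -> Vector:
    """Return w @ x for a row-major matrix w."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def lstm_cell(x: Vector, h: Vector, c: Vector, params: Params) -> Tuple[Vector, Vector]:
    """One time step; weight rows are packed as [input | forget | cell | output] gates."""
    w_ih, w_hh, b_ih, b_hh = params
    n = len(h)  # hidden_size
    pre = [a + b + d + e for a, b, d, e in zip(matvec(w_ih, x), b_ih, matvec(w_hh, h), b_hh)]
    i = [sigmoid(v) for v in pre[0:n]]
    f = [sigmoid(v) for v in pre[n:2 * n]]
    g = [math.tanh(v) for v in pre[2 * n:3 * n]]
    o = [sigmoid(v) for v in pre[3 * n:4 * n]]
    c_new = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c, i, g)]
    h_new = [ok * math.tanh(ck) for ok, ck in zip(o, c_new)]
    return h_new, c_new


def lstm_sequence(xs: Sequence[Vector], params: Params, reverse: bool = False):
    """Run one direction over xs (one batch item) from a zero initial state.

    Returns (outputs, h_n, c_n); outputs[t] is the hidden state at time t.
    """
    hidden = len(params[2]) // 4  # bias_ih has 4 * hidden_size entries
    h, c = [0.0] * hidden, [0.0] * hidden
    outputs: List[Vector] = [[] for _ in xs]
    steps = range(len(xs) - 1, -1, -1) if reverse else range(len(xs))
    for t in steps:
        h, c = lstm_cell(xs[t], h, c, params)
        outputs[t] = h
    return outputs, h, c


def bi_lstm_sequence(xs: Sequence[Vector], fwd: Params, bwd: Params):
    """One bidirectional layer: output[t] is forward h_t followed by backward h_t."""
    out_f, h_f, c_f = lstm_sequence(xs, fwd)
    out_b, h_b, c_b = lstm_sequence(xs, bwd, reverse=True)
    outputs = [hf + hb for hf, hb in zip(out_f, out_b)]
    # Final states stacked [forward, backward], like PyTorch's h_n[0] and h_n[1]
    return outputs, [h_f, h_b], [c_f, c_b]
```

The bidirectional helper runs the reverse-direction weights (PyTorch's `*_reverse` parameters) over the sequence back to front and concatenates both hidden states at each time step, which is why the bidirectional output is twice the hidden size wide.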

Test Data

  1. lstm_test.pt: the weights of a single-layer LSTM, a random input, and the resulting output, hn, and cn, saved by the following PyTorch program:

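    import torch
    import torch.nn as nn
    
    # input_size=10, hidden_size=20, num_layers=1
    rnn = nn.LSTM(10, 20, 1)
    # Shape (seq_len=5, batch=3, input_size=10); batch_first defaults to False
    input = torch.randn(5, 3, 10)
    # output: (5, 3, 20); hn and cn: (1, 3, 20)
    output, (hn, cn) = rnn(input)
    
    # Save the weights together with the reference input and outputs
    state_dict = rnn.state_dict()
    state_dict['input'] = input
    state_dict['output'] = output
    state_dict['hn'] = hn
    state_dict['cn'] = cn
    torch.save(state_dict, "lstm_test.pt")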
    
  2. bi_lstm_test.pt: the same data for a single-layer bidirectional LSTM, saved by the following PyTorch program:

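    import torch
    import torch.nn as nn
    
    # bidirectional=True adds *_reverse weights for the backward direction
    rnn = nn.LSTM(10, 20, 1, bidirectional=True)
    input = torch.randn(5, 3, 10)
    # output: (5, 3, 40), forward and backward states concatenated;
    # hn and cn: (2, 3, 20), ordered [forward, backward]
    output, (hn, cn) = rnn(input)
    
    # Save the weights together with the reference input and outputs
    state_dict = rnn.state_dict()
    state_dict['input'] = input
    state_dict['output'] = output
    state_dict['hn'] = hn
    state_dict['cn'] = cn
    torch.save(state_dict, "bi_lstm_test.pt")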
    
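Both files are a plain `state_dict` plus four extra tensors, so a Candle-side loader has to expect specific key names and shapes. The helper below lists them for the configurations used by the two scripts (input size 10, hidden size 20, one layer, sequence length 5, batch 3). It is a sketch based on PyTorch's documented `nn.LSTM` parameter layout, assuming `batch_first=False` and no projection (`proj_size=0`); the function name `expected_lstm_shapes` is hypothetical and not part of candle-birnn.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def expected_lstm_shapes(input_size: int, hidden_size: int, num_layers: int = 1,
                         bidirectional: bool = False, seq_len: int = 5,
                         batch: int = 3) -> Dict[str, Shape]:
    """Expected keys and shapes in the saved .pt file (batch_first=False, proj_size=0)."""
    dirs = 2 if bidirectional else 1
    suffixes = [""] + (["_reverse"] if bidirectional else [])
    shapes: Dict[str, Shape] = {}
    for layer in range(num_layers):
        # Layers after the first consume the (possibly concatenated) previous output
        layer_in = input_size if layer == 0 else dirs * hidden_size
        for sfx in suffixes:
            shapes[f"weight_ih_l{layer}{sfx}"] = (4 * hidden_size, layer_in)
            shapes[f"weight_hh_l{layer}{sfx}"] = (4 * hidden_size, hidden_size)
            shapes[f"bias_ih_l{layer}{sfx}"] = (4 * hidden_size,)
            shapes[f"bias_hh_l{layer}{sfx}"] = (4 * hidden_size,)
    # Extra tensors added by the test-data scripts
    shapes["input"] = (seq_len, batch, input_size)
    shapes["output"] = (seq_len, batch, dirs * hidden_size)
    shapes["hn"] = (dirs * num_layers, batch, hidden_size)
    shapes["cn"] = (dirs * num_layers, batch, hidden_size)
    return shapes


for name, shape in expected_lstm_shapes(10, 20, bidirectional=True).items():
    print(name, shape)
```

The factor of 4 comes from packing the four gate matrices into one tensor per weight, which is the layout a loader has to split when mapping these tensors onto Candle.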

Dependencies

~9–19MB
~333K SLoC