
Commit bcd4a8e

Add files via upload
1 parent b3e9d71 commit bcd4a8e

File tree: 2 files changed, +419 −0 lines changed


model.py (+90)

@@ -0,0 +1,90 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from transformers import BertTokenizer, BertModel
from transformers import AdamW, BertForSequenceClassification, XLMRobertaForSequenceClassification

class CNN(nn.Module):
    def __init__(self, max_len=30, word_dim=300, class_size=2, size='normal'):
        super(CNN, self).__init__()

        self.MAX_SENT_LEN = max_len
        self.WORD_DIM = word_dim
        self.CLASS_SIZE = class_size
        print("size=", size)
        if size == 'normal':
            print("Init Normal")
            self.FILTERS = [2, 3, 4]
            self.FILTER_NUM = [100, 100, 100]
            self.fc = nn.Linear(sum(self.FILTER_NUM), self.CLASS_SIZE)
        elif size == 'tiny':
            print("Tiny Size")
            self.FILTERS = [3]
            self.FILTER_NUM = [20]
            self.fc = nn.Linear(sum(self.FILTER_NUM), self.CLASS_SIZE)
        self.DROPOUT_PROB = 0.5
        self.IN_CHANNEL = 1

        assert len(self.FILTERS) == len(self.FILTER_NUM)

        # One Conv1d per filter width; the kernel and stride are multiples of
        # WORD_DIM because the input is a flattened [1, WORD_DIM * MAX_SENT_LEN] row.
        for i in range(len(self.FILTERS)):
            conv = nn.Conv1d(self.IN_CHANNEL, self.FILTER_NUM[i], self.WORD_DIM * self.FILTERS[i], stride=self.WORD_DIM)
            setattr(self, f'conv_{i}', conv)

    def get_conv(self, i):
        return getattr(self, f'conv_{i}')

    def forward(self, inp):
        # [B, 1, WORD_DIM * MAX_SENT_LEN]
        x = inp.view(-1, 1, self.WORD_DIM * self.MAX_SENT_LEN)
        # Convolve, ReLU, then max-pool over time for each filter width.
        conv_results = [
            F.max_pool1d(F.relu(self.get_conv(i)(x)), self.MAX_SENT_LEN - self.FILTERS[i] + 1)
            .view(-1, self.FILTER_NUM[i])
            for i in range(len(self.FILTERS))]

        x = torch.cat(conv_results, 1)
        x = F.dropout(x, p=self.DROPOUT_PROB, training=self.training)
        x = self.fc(x)
        # Raw logits; softmax is left to the loss function.
        return x

# Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification
class BLSTMATT(nn.Module):
    def __init__(self, max_len=30, word_dim=300, class_size=2):
        super(BLSTMATT, self).__init__()
        self.hidden_dim = 50
        self.emb_dim = word_dim
        self.dropout_p = 0.3
        self.encoder = nn.LSTM(self.emb_dim, self.hidden_dim, num_layers=2, bidirectional=True, dropout=self.dropout_p)
        self.fc = nn.Linear(self.hidden_dim, class_size)
        self.dropout = nn.Dropout(self.dropout_p)

    def attnetwork(self, encoder_out, final_hidden):
        # encoder_out: [B, T, H]; final_hidden: [1, B, H]
        hidden = final_hidden.squeeze(0)
        # Dot-product attention between each timestep and the final hidden state.
        attn_weights = torch.bmm(encoder_out, hidden.unsqueeze(2)).squeeze(2)
        soft_attn_weights = F.softmax(attn_weights, 1)
        # Attention-weighted sum of encoder outputs -> [B, H]
        new_hidden = torch.bmm(encoder_out.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
        return new_hidden

    def forward(self, sequence):
        # sequence: pre-computed embeddings, [T, B, emb_dim]
        inputx = self.dropout(sequence)
        output, (hn, cn) = self.encoder(inputx)
        fbout = output[:, :, :self.hidden_dim] + output[:, :, self.hidden_dim:]  # sum bidir outputs F+B
        fbout = fbout.permute(1, 0, 2)  # -> [B, T, H]
        fbhn = (hn[-2, :, :] + hn[-1, :, :]).unsqueeze(0)  # sum last forward/backward hidden states
        attn_out = self.attnetwork(fbout, fbhn)
        logits = self.fc(attn_out)
        return logits
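
A minimal usage sketch for the CNN above, assuming the model receives pre-computed word embeddings flattened to one row per sentence; the batch size of 4 and the random tensors are illustrative, not from the commit:

import torch

model = CNN(max_len=30, word_dim=300, class_size=2, size='normal')
model.eval()
# Illustrative stand-in for real embeddings: [batch, MAX_SENT_LEN * WORD_DIM]
x = torch.randn(4, 30 * 300)
with torch.no_grad():
    logits = model(x)  # -> shape [4, 2]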

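And a matching sketch for BLSTMATT, assuming sequence-first input of shape [seq_len, batch, emb_dim] (the nn.LSTM default, since batch_first is not set); sizes are again illustrative:

model = BLSTMATT(max_len=30, word_dim=300, class_size=2)
model.eval()
# Illustrative embedded batch: [seq_len, batch, emb_dim]
seq = torch.randn(30, 4, 300)
with torch.no_grad():
    logits = model(seq)  # -> shape [4, 2]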