import torch.nn.functional as F


-class ConvLayer(nn.Module):
-    def __init__(self, c_in):
-        super(ConvLayer, self).__init__()
-        self.downConv = nn.Conv1d(in_channels=c_in,
-                                  out_channels=c_in,
-                                  kernel_size=3,
-                                  padding=2,
-                                  padding_mode='circular')
-        self.norm = nn.BatchNorm1d(c_in)
-        self.activation = nn.ELU()
-        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
-
-    def forward(self, x):
-        x = self.downConv(x.permute(0, 2, 1))
-        x = self.norm(x)
-        x = self.activation(x)
-        x = self.maxPool(x)
-        x = x.transpose(1, 2)
-        return x
-
-
class EncoderLayer(nn.Module):
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
@@ -52,6 +31,73 @@ def forward(self, x, attn_mask=None, tau=None, delta=None):
        return self.norm2(x + y), attn


+class DecoderLayer(nn.Module):
+    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
+                 dropout=0.1, activation="relu"):
+        super(DecoderLayer, self).__init__()
+        d_ff = d_ff or 4 * d_model
+        self.self_attention = self_attention
+        self.cross_attention = cross_attention
+        self.conv1 = nn.Conv1d(in_channels=d_model,
+                               out_channels=d_ff, kernel_size=1)
+        self.conv2 = nn.Conv1d(
+            in_channels=d_ff, out_channels=d_model, kernel_size=1)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.activation = F.relu if activation == "relu" else F.gelu
+
+    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
+        x = x + self.dropout(self.self_attention(
+            x, x, x,
+            attn_mask=x_mask,
+            tau=tau, delta=None
+        )[0])
+        x = self.norm1(x)
+
+        x = x + self.dropout(self.cross_attention(
+            x, cross, cross,
+            attn_mask=cross_mask,
+            tau=tau, delta=delta
+        )[0])
+
+        y = x = self.norm2(x)
+        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+        y = self.dropout(self.conv2(y).transpose(-1, 1))
+
+        return self.norm3(x + y)
+
+
+class DecoderOnlyLayer(nn.Module):
+    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
+        super(DecoderOnlyLayer, self).__init__()
+        d_ff = d_ff or 4 * d_model
+        self.attention = attention
+        self.conv1 = nn.Conv1d(in_channels=d_model,
+                               out_channels=d_ff, kernel_size=1)
+        self.conv2 = nn.Conv1d(
+            in_channels=d_ff, out_channels=d_model, kernel_size=1)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.activation = F.relu if activation == "relu" else F.gelu
+
+    def forward(self, x, attn_mask=None, tau=None, delta=None):
+        new_x, attn = self.attention(
+            x, x, x,
+            attn_mask=attn_mask,
+            tau=tau, delta=delta
+        )
+        x = x + self.dropout(new_x)
+
+        y = x = self.norm1(x)
+        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+        y = self.dropout(self.conv2(y).transpose(-1, 1))
+
+        return self.norm2(x + y), attn
+
+
class TimerLayer(nn.Module):
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(TimerLayer, self).__init__()
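A minimal usage sketch for the new DecoderOnlyLayer added in the hunk above (not part of the commit). It assumes DecoderOnlyLayer is importable from this module; ToySelfAttention is a hypothetical stand-in that only matches the (queries, keys, values, attn_mask, tau, delta) calling convention the layer expects, whereas the repository presumably passes its own attention wrapper here.

import torch
import torch.nn as nn

# DecoderOnlyLayer is assumed importable from this module; the exact import
# path depends on the repository layout and is omitted here.


class ToySelfAttention(nn.Module):
    # Hypothetical stand-in matching the interface DecoderOnlyLayer expects:
    # called as attention(q, k, v, attn_mask=..., tau=..., delta=...) and
    # returning (output, attention_weights).
    def __init__(self, d_model, n_heads=4):
        super().__init__()
        self.inner = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, queries, keys, values, attn_mask=None, tau=None, delta=None):
        out, attn = self.inner(queries, keys, values, attn_mask=attn_mask)
        return out, attn


d_model = 64
layer = DecoderOnlyLayer(ToySelfAttention(d_model), d_model, d_ff=128, dropout=0.1)

x = torch.randn(8, 96, d_model)   # [batch, length, d_model]
out, attn = layer(x)              # out keeps the input shape [8, 96, 64]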
@@ -115,44 +161,6 @@ def forward(self, x, attn_mask=None, tau=None, delta=None):
        return x, attns


-class DecoderLayer(nn.Module):
-    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
-                 dropout=0.1, activation="relu"):
-        super(DecoderLayer, self).__init__()
-        d_ff = d_ff or 4 * d_model
-        self.self_attention = self_attention
-        self.cross_attention = cross_attention
-        self.conv1 = nn.Conv1d(in_channels=d_model,
-                               out_channels=d_ff, kernel_size=1)
-        self.conv2 = nn.Conv1d(
-            in_channels=d_ff, out_channels=d_model, kernel_size=1)
-        self.norm1 = nn.LayerNorm(d_model)
-        self.norm2 = nn.LayerNorm(d_model)
-        self.norm3 = nn.LayerNorm(d_model)
-        self.dropout = nn.Dropout(dropout)
-        self.activation = F.relu if activation == "relu" else F.gelu
-
-    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
-        x = x + self.dropout(self.self_attention(
-            x, x, x,
-            attn_mask=x_mask,
-            tau=tau, delta=None
-        )[0])
-        x = self.norm1(x)
-
-        x = x + self.dropout(self.cross_attention(
-            x, cross, cross,
-            attn_mask=cross_mask,
-            tau=tau, delta=delta
-        )[0])
-
-        y = x = self.norm2(x)
-        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
-        y = self.dropout(self.conv2(y).transpose(-1, 1))
-
-        return self.norm3(x + y)
-
-
class Decoder(nn.Module):
    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
@@ -173,6 +181,38 @@ def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
        return x


+class DecoderOnly(nn.Module):
+    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
+        super(DecoderOnly, self).__init__()
+        self.attn_layers = nn.ModuleList(attn_layers)
+        self.conv_layers = nn.ModuleList(
+            conv_layers) if conv_layers is not None else None
+        self.norm = norm_layer
+
+    def forward(self, x, attn_mask=None, tau=None, delta=None):
+        # x [B, L, D]
+        attns = []
+        if self.conv_layers is not None:
+            for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
+                delta = delta if i == 0 else None
+                x, attn = attn_layer(
+                    x, attn_mask=attn_mask, tau=tau, delta=delta)
+                x = conv_layer(x)
+                attns.append(attn)
+            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
+            attns.append(attn)
+        else:
+            for attn_layer in self.attn_layers:
+                x, attn = attn_layer(
+                    x, attn_mask=attn_mask, tau=tau, delta=delta)
+                attns.append(attn)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x, attns
+
+
class TimerBlock(nn.Module):
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(TimerBlock, self).__init__()
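A companion sketch (again not part of the commit) showing how the new DecoderOnly container might stack DecoderOnlyLayer modules when no conv_layers are supplied, under the same assumptions as the sketch after the previous hunk.

# Reusing the hypothetical ToySelfAttention stand-in defined in the sketch above.
d_model, n_layers = 64, 2
decoder = DecoderOnly(
    [DecoderOnlyLayer(ToySelfAttention(d_model), d_model) for _ in range(n_layers)],
    norm_layer=nn.LayerNorm(d_model),
)

x = torch.randn(8, 96, d_model)   # [B, L, D]
out, attns = decoder(x)           # out: [8, 96, 64]; attns holds one entry per layer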