Commit bbd5921 (1 parent: 8912225)

Update models and add DecoderOnly part

6 files changed (+159, -121 lines)

exp/exp_forecast.py (+3, -3)

@@ -70,7 +70,7 @@ def vali(self, vali_data, vali_loader, criterion, is_test=False):
                 batch_x_mark = batch_x_mark.float().to(self.device)
                 batch_y_mark = batch_y_mark.float().to(self.device)

-                outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark)
+                outputs = self.model(batch_x, batch_x_mark, batch_y_mark)
                 if is_test or self.args.nonautoregressive:
                     outputs = outputs[:, -self.args.output_token_len:, :]
                     batch_y = batch_y[:, -self.args.output_token_len:, :].to(self.device)
@@ -138,7 +138,7 @@ def train(self, setting):
                 batch_x_mark = batch_x_mark.float().to(self.device)
                 batch_y_mark = batch_y_mark.float().to(self.device)

-                outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark)
+                outputs = self.model(batch_x, batch_x_mark, batch_y_mark)
                 if self.args.dp:
                     torch.cuda.synchronize()
                 if self.args.nonautoregressive:
@@ -228,7 +228,7 @@ def test(self, setting, test=0):
                 for j in range(inference_steps):
                     if len(pred_y) != 0:
                         batch_x = torch.cat([batch_x[:, self.args.input_token_len:, :], pred_y[-1]], dim=1)
-                    outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark)
+                    outputs = self.model(batch_x, batch_x_mark, batch_y_mark)
                     pred_y.append(outputs[:, -self.args.output_token_len:, :])
                 pred_y = torch.cat(pred_y, dim=1)
                 if dis != 0:
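
Note on the change above: the experiment loops now call the model as model(x, x_mark, y_mark); the placeholder decoder input that used to be passed as None is gone. Below is a minimal sketch of the autoregressive rollout that test() performs with the new three-argument call; the default lengths are illustrative assumptions, not values taken from this repository's configs.

import torch

def rollout(model, batch_x, batch_x_mark, batch_y_mark,
            pred_len=192, input_token_len=96, output_token_len=96):
    # Roll the model forward token by token with the new forward(x, x_mark, y_mark) call.
    # pred_len and the token lengths are assumed example values.
    pred_y = []
    for _ in range(pred_len // output_token_len):
        if len(pred_y) != 0:
            # Slide the context window: drop the oldest input token, append the latest prediction.
            batch_x = torch.cat(
                [batch_x[:, input_token_len:, :], pred_y[-1]], dim=1)
        outputs = model(batch_x, batch_x_mark, batch_y_mark)
        pred_y.append(outputs[:, -output_token_len:, :])
    return torch.cat(pred_y, dim=1)  # [B, pred_len, C]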

layers/Transformer_EncDec.py (+99, -59)

@@ -2,27 +2,6 @@
 import torch.nn.functional as F


-class ConvLayer(nn.Module):
-    def __init__(self, c_in):
-        super(ConvLayer, self).__init__()
-        self.downConv = nn.Conv1d(in_channels=c_in,
-                                  out_channels=c_in,
-                                  kernel_size=3,
-                                  padding=2,
-                                  padding_mode='circular')
-        self.norm = nn.BatchNorm1d(c_in)
-        self.activation = nn.ELU()
-        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
-
-    def forward(self, x):
-        x = self.downConv(x.permute(0, 2, 1))
-        x = self.norm(x)
-        x = self.activation(x)
-        x = self.maxPool(x)
-        x = x.transpose(1, 2)
-        return x
-
-
 class EncoderLayer(nn.Module):
     def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
         super(EncoderLayer, self).__init__()
@@ -52,6 +31,73 @@ def forward(self, x, attn_mask=None, tau=None, delta=None):
         return self.norm2(x + y), attn


+class DecoderLayer(nn.Module):
+    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
+                 dropout=0.1, activation="relu"):
+        super(DecoderLayer, self).__init__()
+        d_ff = d_ff or 4 * d_model
+        self.self_attention = self_attention
+        self.cross_attention = cross_attention
+        self.conv1 = nn.Conv1d(in_channels=d_model,
+                               out_channels=d_ff, kernel_size=1)
+        self.conv2 = nn.Conv1d(
+            in_channels=d_ff, out_channels=d_model, kernel_size=1)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.activation = F.relu if activation == "relu" else F.gelu
+
+    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
+        x = x + self.dropout(self.self_attention(
+            x, x, x,
+            attn_mask=x_mask,
+            tau=tau, delta=None
+        )[0])
+        x = self.norm1(x)
+
+        x = x + self.dropout(self.cross_attention(
+            x, cross, cross,
+            attn_mask=cross_mask,
+            tau=tau, delta=delta
+        )[0])
+
+        y = x = self.norm2(x)
+        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+        y = self.dropout(self.conv2(y).transpose(-1, 1))
+
+        return self.norm3(x + y)
+
+
+class DecoderOnlyLayer(nn.Module):
+    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
+        super(DecoderOnlyLayer, self).__init__()
+        d_ff = d_ff or 4 * d_model
+        self.attention = attention
+        self.conv1 = nn.Conv1d(in_channels=d_model,
+                               out_channels=d_ff, kernel_size=1)
+        self.conv2 = nn.Conv1d(
+            in_channels=d_ff, out_channels=d_model, kernel_size=1)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.activation = F.relu if activation == "relu" else F.gelu
+
+    def forward(self, x, attn_mask=None, tau=None, delta=None):
+        new_x, attn = self.attention(
+            x, x, x,
+            attn_mask=attn_mask,
+            tau=tau, delta=delta
+        )
+        x = x + self.dropout(new_x)
+
+        y = x = self.norm1(x)
+        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+        y = self.dropout(self.conv2(y).transpose(-1, 1))
+
+        return self.norm2(x + y), attn
+
+
 class TimerLayer(nn.Module):
     def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
         super(TimerLayer, self).__init__()
@@ -115,44 +161,6 @@ def forward(self, x, attn_mask=None, tau=None, delta=None):
         return x, attns


-class DecoderLayer(nn.Module):
-    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
-                 dropout=0.1, activation="relu"):
-        super(DecoderLayer, self).__init__()
-        d_ff = d_ff or 4 * d_model
-        self.self_attention = self_attention
-        self.cross_attention = cross_attention
-        self.conv1 = nn.Conv1d(in_channels=d_model,
-                               out_channels=d_ff, kernel_size=1)
-        self.conv2 = nn.Conv1d(
-            in_channels=d_ff, out_channels=d_model, kernel_size=1)
-        self.norm1 = nn.LayerNorm(d_model)
-        self.norm2 = nn.LayerNorm(d_model)
-        self.norm3 = nn.LayerNorm(d_model)
-        self.dropout = nn.Dropout(dropout)
-        self.activation = F.relu if activation == "relu" else F.gelu
-
-    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
-        x = x + self.dropout(self.self_attention(
-            x, x, x,
-            attn_mask=x_mask,
-            tau=tau, delta=None
-        )[0])
-        x = self.norm1(x)
-
-        x = x + self.dropout(self.cross_attention(
-            x, cross, cross,
-            attn_mask=cross_mask,
-            tau=tau, delta=delta
-        )[0])
-
-        y = x = self.norm2(x)
-        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
-        y = self.dropout(self.conv2(y).transpose(-1, 1))
-
-        return self.norm3(x + y)
-
-
 class Decoder(nn.Module):
     def __init__(self, layers, norm_layer=None, projection=None):
         super(Decoder, self).__init__()
@@ -173,6 +181,38 @@ def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
         return x


+class DecoderOnly(nn.Module):
+    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
+        super(DecoderOnly, self).__init__()
+        self.attn_layers = nn.ModuleList(attn_layers)
+        self.conv_layers = nn.ModuleList(
+            conv_layers) if conv_layers is not None else None
+        self.norm = norm_layer
+
+    def forward(self, x, attn_mask=None, tau=None, delta=None):
+        # x [B, L, D]
+        attns = []
+        if self.conv_layers is not None:
+            for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
+                delta = delta if i == 0 else None
+                x, attn = attn_layer(
+                    x, attn_mask=attn_mask, tau=tau, delta=delta)
+                x = conv_layer(x)
+                attns.append(attn)
+            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
+            attns.append(attn)
+        else:
+            for attn_layer in self.attn_layers:
+                x, attn = attn_layer(
+                    x, attn_mask=attn_mask, tau=tau, delta=delta)
+                attns.append(attn)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x, attns
+
+
 class TimerBlock(nn.Module):
     def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
         super(TimerBlock, self).__init__()
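
The new DecoderOnlyLayer is a single-stream, self-attention-only layer that returns its attention map, and DecoderOnly stacks such layers with an optional final norm. A minimal assembly sketch, mirroring how models/timer.py wires these classes further down; d_model, n_heads, d_ff, the layer count, and the input shape are assumed example values.

import torch
from torch import nn
from layers.Transformer_EncDec import DecoderOnly, DecoderOnlyLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer

d_model, n_heads, d_ff, n_layers, dropout = 512, 8, 2048, 2, 0.1  # assumed values

blocks = DecoderOnly(
    [
        DecoderOnlyLayer(
            AttentionLayer(
                # FullAttention(True, ...) is the masked (causal) self-attention setup used by timer.py
                FullAttention(True, attention_dropout=dropout, output_attention=False),
                d_model, n_heads),
            d_model,
            d_ff,
            dropout=dropout,
            activation="gelu",
        )
        for _ in range(n_layers)
    ],
    norm_layer=nn.LayerNorm(d_model),
)

tokens = torch.randn(4 * 7, 10, d_model)  # [B * C, N, D] token embeddings (assumed shape)
out, attns = blocks(tokens)               # out: [B * C, N, D], attns: one entry per layer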

models/moirai.py (+14, -14)

@@ -28,24 +28,24 @@ def __init__(self, configs):
         )
         self.head = nn.Linear(configs.d_model, configs.input_token_len)

-    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
+    def forecast(self, x, x_mark, y_mark):
         if self.use_norm:
-            means = x_enc.mean(1, keepdim=True).detach()
-            x_enc = x_enc - means
+            means = x.mean(1, keepdim=True).detach()
+            x = x - means
             stdev = torch.sqrt(
-                torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
-            x_enc /= stdev
+                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
+            x /= stdev

-        B, _, C = x_enc.shape
-        padding = torch.zeros(B, self.input_token_len, C).to(x_enc.device)
-        x_enc = torch.cat([x_enc, padding], dim=1)
+        B, _, C = x.shape
+        padding = torch.zeros(B, self.input_token_len, C).to(x.device)
+        x = torch.cat([x, padding], dim=1)
         # [B, C, L]
-        x_enc = x_enc.permute(0, 2, 1)
+        x = x.permute(0, 2, 1)
         # [B, C, N, P]
-        x_enc = x_enc.unfold(dimension=-1, size=self.input_token_len, step=self.input_token_len)
-        N = x_enc.shape[2]
+        x = x.unfold(dimension=-1, size=self.input_token_len, step=self.input_token_len)
+        N = x.shape[2]
         # [B, C, N, D]
-        enc_out = self.embedding(x_enc)
+        enc_out = self.embedding(x)
         # [B, C * N, D]
         enc_out = enc_out.reshape(B, C * N, -1)
         enc_out, attns = self.encoder(enc_out, n_vars=C, n_tokens=N)
@@ -60,5 +60,5 @@ def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
         dec_out = dec_out * stdev + means
         return dec_out

-    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
-        return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
+    def forward(self, x, x_mark, y_mark):
+        return self.forecast(x, x_mark, y_mark)
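
The renamed forecast(x, x_mark, y_mark) keeps Moirai's padding-plus-unfold tokenization; a standalone shape check of those steps, with made-up sizes, is below.

import torch

B, L, C, P = 2, 192, 7, 96                       # assumed: batch, length, variates, input_token_len
x = torch.randn(B, L, C)
x = torch.cat([x, torch.zeros(B, P, C)], dim=1)  # append an all-zero token to forecast into
x = x.permute(0, 2, 1)                           # [B, C, L + P]
x = x.unfold(dimension=-1, size=P, step=P)       # [B, C, N, P] non-overlapping tokens
N = x.shape[2]
print(x.shape, N)                                # torch.Size([2, 7, 3, 96]) 3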

models/patchtst.py (+9, -9)

@@ -64,19 +64,19 @@ def __init__(self, configs):
         self.head_nf = configs.d_model * int((configs.seq_len - patch_len) / stride + 2)
         self.head = FlattenHead(self.head_nf, configs.test_pred_len, head_dropout=configs.dropout)

-    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
+    def forecast(self, x, x_mark, y_mark):
         if self.use_norm:
             # Normalization from Non-stationary Transformer
-            means = x_enc.mean(1, keepdim=True).detach()
-            x_enc = x_enc - means
+            means = x.mean(1, keepdim=True).detach()
+            x = x - means
             stdev = torch.sqrt(
-                torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
-            x_enc /= stdev
+                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
+            x /= stdev

         # do patching and embedding
-        x_enc = x_enc.permute(0, 2, 1)
+        x = x.permute(0, 2, 1)
         # u: [bs * nvars x patch_num x d_model]
-        enc_out, n_vars = self.patch_embedding(x_enc)
+        enc_out, n_vars = self.patch_embedding(x)

         # Encoder
         # z: [bs * nvars x patch_num x d_model]
@@ -99,6 +99,6 @@ def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
             (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
         return dec_out

-    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
-        dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
+    def forward(self, x, x_mark, y_mark):
+        dec_out = self.forecast(x, x_mark, y_mark)
         return dec_out[:, -self.pred_len:, :]  # [B, L, D]
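
For orientation while reading this file: head_nf above is d_model times the patch count implied by seq_len, patch_len, and stride. A quick check of that arithmetic under assumed hyperparameters:

seq_len, patch_len, stride, d_model = 96, 16, 8, 128  # assumed example hyperparameters
patch_num = int((seq_len - patch_len) / stride + 2)   # same formula as head_nf above
head_nf = d_model * patch_num
print(patch_num, head_nf)                             # 12 1536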

models/timer.py (+19, -21)

@@ -1,6 +1,6 @@
 import torch
 from torch import nn
-from layers.Transformer_EncDec import Encoder, EncoderLayer
+from layers.Transformer_EncDec import DecoderOnly, DecoderOnlyLayer
 from layers.SelfAttention_Family import FullAttention, AttentionLayer
 from layers.Embed import PositionalEmbedding

@@ -15,11 +15,9 @@ def __init__(self, configs):
         self.embedding = nn.Linear(self.input_token_len, configs.d_model, bias=False)
         self.position_embedding = PositionalEmbedding(configs.d_model)
         self.dropout = nn.Dropout(configs.dropout)
-
-        # Timer is a Decoder-only Transformer. Please refer to issue: https://github.com/thuml/Large-Time-Series-Model/issues/23
-        self.blocks = Encoder(
+        self.blocks = DecoderOnly(
             [
-                EncoderLayer(
+                DecoderOnlyLayer(
                     AttentionLayer(
                         FullAttention(True, attention_dropout=configs.dropout,
                                       output_attention=False), configs.d_model, configs.n_heads),
@@ -34,29 +32,29 @@ def __init__(self, configs):
         self.head = nn.Linear(configs.d_model, configs.output_token_len)
         self.use_norm = configs.use_norm

-    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
+    def forecast(self, x, x_mark, y_mark):
         if self.use_norm:
-            means = x_enc.mean(1, keepdim=True).detach()
-            x_enc = x_enc - means
+            means = x.mean(1, keepdim=True).detach()
+            x = x - means
             stdev = torch.sqrt(
-                torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
-            x_enc /= stdev
+                torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
+            x /= stdev
         # [B, L, C]
-        B, _, C = x_enc.shape
+        B, _, C = x.shape
         # [B, C, L]
-        x_enc = x_enc.permute(0, 2, 1)
+        x = x.permute(0, 2, 1)
         # [B, C, N, P]
-        x_enc = x_enc.unfold(
+        x = x.unfold(
             dimension=-1, size=self.input_token_len, step=self.input_token_len)
-        N = x_enc.shape[2]
+        N = x.shape[2]
         # [B * C, N, P]
-        x_enc = x_enc.reshape(B * C, N, -1)
+        x = x.reshape(B * C, N, -1)
         # [B * C, N, D]
-        enc_out = self.embedding(x_enc) + self.position_embedding(x_enc)
-        enc_out = self.dropout(enc_out)
-        enc_out, attns = self.blocks(enc_out)
+        embed_out = self.embedding(x) + self.position_embedding(x)
+        embed_out = self.dropout(embed_out)
+        embed_out, attns = self.blocks(embed_out)
         # [B * C, N, P]
-        dec_out = self.head(enc_out)
+        dec_out = self.head(embed_out)
         # [B, C, L]
         dec_out = dec_out.reshape(B, C, -1)
         # [B, L, C]
@@ -65,5 +63,5 @@ def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
         dec_out = dec_out * stdev + means
         return dec_out

-    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
-        return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
+    def forward(self, x, x_mark, y_mark):
+        return self.forecast(x, x_mark, y_mark)
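
To summarize the decoder-only path timer.py now follows, here is a standalone walk-through of the tensor shapes; the sizes, the bare nn.Linear stand-ins, and the omitted attention blocks are illustrative, and only the reshapes mirror the code above.

import torch
from torch import nn

B, L, C, P, D = 2, 192, 7, 96, 64                     # assumed: batch, length, variates, token length, d_model
x = torch.randn(B, L, C)

x = x.permute(0, 2, 1)                                # [B, C, L]
x = x.unfold(dimension=-1, size=P, step=P)            # [B, C, N, P]
N = x.shape[2]                                        # N = L // P tokens per variate
x = x.reshape(B * C, N, -1)                           # [B * C, N, P]: variates folded into the batch

embedding = nn.Linear(P, D, bias=False)               # stands in for self.embedding
head = nn.Linear(D, P)                                # stands in for self.head (output_token_len == P assumed)

tokens = embedding(x)                                 # [B * C, N, D]
# ... the DecoderOnly blocks would run on `tokens` here ...
dec_out = head(tokens)                                # [B * C, N, P]
dec_out = dec_out.reshape(B, C, -1).permute(0, 2, 1)  # back to [B, L, C]
print(dec_out.shape)                                  # torch.Size([2, 192, 7])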
