diff --git a/pytorch_ner/train.py b/pytorch_ner/train.py index 60e1d52..df70f71 100644 --- a/pytorch_ner/train.py +++ b/pytorch_ner/train.py @@ -49,8 +49,8 @@ def train_loop( for tokens, labels, lengths in dataloader: tokens, labels, lengths = ( - tokens.to(device), - labels.to(device), + tokens.to(device).to(torch.int64), + labels.to(device).to(torch.int64), lengths.to(device), ) @@ -103,8 +103,8 @@ def validate_loop( for tokens, labels, lengths in dataloader: tokens, labels, lengths = ( - tokens.to(device), - labels.to(device), + tokens.to(device).to(torch.int64), + labels.to(device).to(torch.int64), lengths.to(device), )