From e3b31b844ef21a48cc199e84ba0caa60c242498f Mon Sep 17 00:00:00 2001 From: Pavel Salikov Date: Tue, 2 Feb 2021 19:14:47 +0300 Subject: [PATCH] Fix bug with Pytorch CPU version --- pytorch_ner/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_ner/train.py b/pytorch_ner/train.py index 60e1d52..df70f71 100644 --- a/pytorch_ner/train.py +++ b/pytorch_ner/train.py @@ -49,8 +49,8 @@ def train_loop( for tokens, labels, lengths in dataloader: tokens, labels, lengths = ( - tokens.to(device), - labels.to(device), + tokens.to(device).to(torch.int64), + labels.to(device).to(torch.int64), lengths.to(device), ) @@ -103,8 +103,8 @@ def validate_loop( for tokens, labels, lengths in dataloader: tokens, labels, lengths = ( - tokens.to(device), - labels.to(device), + tokens.to(device).to(torch.int64), + labels.to(device).to(torch.int64), lengths.to(device), )