@@ -396,6 +396,9 @@ def __init__(
396
396
self .test_dataloaders = None
397
397
self .val_dataloaders = None
398
398
399
+ # when .test() is called, it sets this
400
+ self .tested_ckpt_path = None
401
+
399
402
# training state
400
403
self .model = None
401
404
self .testing = False
@@ -965,6 +968,10 @@ def fit(
965
968
966
969
self .ddp_train (process_idx = task , q = None , model = model )
967
970
elif self .use_ddp :
971
+
972
+ # set testing if set in environ
973
+ self .testing = os .environ .get ('PL_TESTING_MODE' , self .testing )
974
+
968
975
if self .is_slurm_managing_tasks :
969
976
task = int (os .environ ['SLURM_LOCALID' ])
970
977
self .ddp_train (process_idx = task , q = None , model = model )
@@ -1058,7 +1065,7 @@ def __run_ddp_spawn(self, model, nprocs):
1058
1065
smp = mp .get_context ('spawn' )
1059
1066
q = smp .SimpleQueue ()
1060
1067
1061
- mp .spawn (self .ddp_train , nprocs = nprocs , args = (q , model ,))
1068
+ mp .spawn (self .ddp_train , nprocs = nprocs , args = (q , model , ))
1062
1069
1063
1070
# restore main state with best weights
1064
1071
best_path = q .get ()
@@ -1070,7 +1077,8 @@ def __run_ddp_spawn(self, model, nprocs):
1070
1077
1071
1078
# load last weights
1072
1079
if last_path is not None and not self .testing :
1073
- torch .load (last_path , map_location = lambda storage , loc : storage )
1080
+ ckpt = torch .load (last_path , map_location = lambda storage , loc : storage )
1081
+ model .load_state_dict (ckpt )
1074
1082
1075
1083
self .model = model
1076
1084
return results
@@ -1262,62 +1270,83 @@ def test(
1262
1270
# --------------------
1263
1271
# SETUP HOOK
1264
1272
# --------------------
1273
+ if self .global_rank != 0 :
1274
+ return
1275
+
1265
1276
self .setup ('test' )
1266
- model_ref = self .model if model is None else model
1267
- if self .is_function_implemented ('setup' , model_ref ):
1268
- model_ref .setup ('test' )
1277
+
1278
+ if model is not None :
1279
+ results = self .__test_given_model (model , test_dataloaders )
1280
+ else :
1281
+ results = self .__test_using_best_weights (ckpt_path , test_dataloaders )
1282
+
1283
+ self .teardown ('test' )
1284
+
1285
+ return results
1286
+
1287
def __test_using_best_weights(self, ckpt_path, test_dataloaders):
    """Run the test loop using checkpointed weights.

    Restores the requested checkpoint into the model attached during
    ``.fit()``, then re-enters ``fit`` with ``self.testing`` set so the
    training loop short-circuits straight to evaluation.

    Args:
        ckpt_path: ``'best'`` to use the checkpoint callback's best model,
            an explicit checkpoint path, or ``None`` to test with the
            model's current weights.
        test_dataloaders: optional dataloaders to attach before testing.

    Returns:
        The results produced by the test run.

    Raises:
        MisconfigurationException: if ``ckpt_path == 'best'`` but the
            checkpoint callback was not configured to save any model.
    """
    model = self.get_model()
    if self.is_function_implemented('setup', model):
        model.setup('test')

    # if user requests the best checkpoint but we don't have it, error
    if ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
        raise MisconfigurationException(
            'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.')

    # load best weights
    if ckpt_path is not None:
        # ckpt_path is 'best' so load the best model
        if ckpt_path == 'best':
            ckpt_path = self.checkpoint_callback.best_model_path

        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        model.load_state_dict(ckpt['state_dict'])

    # attach dataloaders
    if test_dataloaders is not None:
        self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)

    # run tests
    self.tested_ckpt_path = ckpt_path
    self.set_random_port(force=True)
    self.testing = True
    # PL_TESTING_MODE lets spawned DDP workers know they are testing.
    # Use try/finally so a failure inside fit() cannot leak the env var
    # or the testing flag (a bare `del` after fit() would be skipped on
    # exception, poisoning subsequent .fit() calls on this trainer).
    os.environ['PL_TESTING_MODE'] = '1'
    self.model = model
    try:
        results = self.fit(model)
    finally:
        self.testing = False
        os.environ.pop('PL_TESTING_MODE', None)

    # teardown
    if self.is_function_implemented('teardown'):
        model_ref = self.get_model()
        model_ref.teardown('test')

    return results
1320
1326
1327
def __test_given_model(self, model, test_dataloaders):
    """Run the test loop on an explicitly supplied model.

    Unlike the best-weights path, no checkpoint is loaded: the model is
    evaluated exactly as passed in.

    Args:
        model: the LightningModule to evaluate.
        test_dataloaders: optional dataloaders to attach before testing.

    Returns:
        The results produced by the test run.
    """
    # setup hook
    if self.is_function_implemented('setup', model):
        model.setup('test')

    # attach data
    if test_dataloaders is not None:
        self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)

    # run test
    # sets up testing so we short circuit to eval
    self.set_random_port(force=True)
    self.testing = True
    self.model = model
    # ensure the testing flag is cleared even if fit() raises, so the
    # trainer is not stuck in test mode for subsequent calls
    try:
        results = self.fit(model)
    finally:
        self.testing = False

    # teardown
    if self.is_function_implemented('teardown'):
        model.teardown('test')

    return results
1349
+
1321
1350
def check_model_configuration (self , model : LightningModule ):
1322
1351
r"""
1323
1352
Checks that the model is configured correctly before training or testing is started.
0 commit comments