Skip to content

Commit 4ff1481

Browse files
committed
Final state of models (needs cleanup/review).
1 parent b4286ae commit 4ff1481

File tree

4 files changed

+4
-3
lines changed

4 files changed

+4
-3
lines changed

models/code2seq-merged/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ def adv_train(self, lamb=0.0):
360360
eval_input_tensors = self.sess.run(eval_input_tensors_data)
361361
if lamb>0 and transf==0:
362362
orig_tensors = eval_input_tensors
363+
continue
363364
eval_curr_loss = self.sess.run(eval_graph_loss, feed_dict={eval_target_index: eval_input_tensors[reader.TARGET_INDEX_KEY], eval_target_lengths: eval_input_tensors[reader.TARGET_LENGTH_KEY], eval_path_source_indices: eval_input_tensors[reader.PATH_SOURCE_INDICES_KEY], eval_node_indices: eval_input_tensors[reader.NODE_INDICES_KEY], eval_path_target_indices: eval_input_tensors[reader.PATH_TARGET_INDICES_KEY], eval_valid_context_mask: eval_input_tensors[reader.VALID_CONTEXT_MASK_KEY], eval_path_source_lengths: eval_input_tensors[reader.PATH_SOURCE_LENGTHS_KEY], eval_path_lengths: eval_input_tensors[reader.PATH_LENGTHS_KEY], eval_path_target_lengths: eval_input_tensors[reader.PATH_TARGET_LENGTHS_KEY]})
364365
# print(eval_curr_loss)
365366
if eval_curr_loss > worst_loss:

models/pytorch-seq2seq/evaluate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def parse_args():
3131
help='Path to experiment directory. If load_checkpoint is True, then path to checkpoint directory has to be provided')
3232
parser.add_argument('--load_checkpoint', action='store', dest='load_checkpoint', default='Best_F1',
3333
help='The name of the checkpoint to load, usually an encoded time string')
34-
parser.add_argument('--batch_size', action='store', dest='batch_size', default=128, type=int)
34+
parser.add_argument('--batch_size', action='store', dest='batch_size', default=32, type=int)
3535
parser.add_argument('--output_dir', action='store', dest='output_dir', default=None)
3636
parser.add_argument('--src_field_name', action='store', dest='src_field_name', default='src')
3737
parser.add_argument('--save', action='store_true', default=False)

models/pytorch-seq2seq/gradient_attack.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def parse_args():
3232
parser.add_argument('--distinct', action='store_true', dest='distinct', default=True)
3333
parser.add_argument('--no-distinct', action='store_false', dest='distinct')
3434
parser.add_argument('--no_gradient', action='store_true', dest='no_gradient', default=False)
35-
parser.add_argument('--batch_size', type=int, default=32)
35+
parser.add_argument('--batch_size', type=int, default=16)
3636
parser.add_argument('--save_path', default=None)
3737
parser.add_argument('--random', action='store_true', default=False, help='Also generate random attack')
3838
opt = parser.parse_args()

models/pytorch-seq2seq/train_adv.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def len_filter(example):
209209

210210

211211
# train with lamb*normal_loss + (1-lamb)*adv_loss
212-
lamb = np.linspace(opt.lamb, 0.0, opt.epochs)
212+
lamb = opt.lamb
213213
print(lamb)
214214

215215
load_checkpoint_path = None if opt.load_checkpoint is None else \

0 commit comments

Comments
 (0)