diff --git a/fawkes/protection.py b/fawkes/protection.py index d13eef4..108698e 100644 --- a/fawkes/protection.py +++ b/fawkes/protection.py @@ -208,10 +208,17 @@ def main(*argv): image_paths = [path for path in image_paths if "_cloaked" not in path.split("/")[-1]] protector = Fawkes(args.feature_extractor, args.gpu, args.batch_size) - protector.run_protection(image_paths, mode=args.mode, th=args.th, sd=args.sd, lr=args.lr, - max_step=args.max_step, - batch_size=args.batch_size, format=args.format, - separate_target=args.separate_target, debug=args.debug, no_align=args.no_align) + if args.mode == 'all': + for mode in ['min', 'low', 'mid', 'high']: + protector.run_protection(image_paths, mode=mode, th=args.th, sd=args.sd, lr=args.lr, + max_step=args.max_step, + batch_size=args.batch_size, format=args.format, + separate_target=args.separate_target, debug=args.debug, no_align=args.no_align) + else: + protector.run_protection(image_paths, mode=args.mode, th=args.th, sd=args.sd, lr=args.lr, + max_step=args.max_step, + batch_size=args.batch_size, format=args.format, + separate_target=args.separate_target, debug=args.debug, no_align=args.no_align) if __name__ == '__main__':