From cd720006f6d14b5d9780e4e847decc5fc8b224ff Mon Sep 17 00:00:00 2001 From: Piero Toffanin Date: Mon, 19 Sep 2022 15:33:59 -0400 Subject: [PATCH] Add AI background removal --- opendm/bgfilter.py | 90 ++++++++++++++++++++++++++++++++++++++++++++++ opendm/config.py | 6 ++++ stages/dataset.py | 42 ++++++++++++++++++++++ 3 files changed, 138 insertions(+) create mode 100644 opendm/bgfilter.py diff --git a/opendm/bgfilter.py b/opendm/bgfilter.py new file mode 100644 index 00000000..3dfd818e --- /dev/null +++ b/opendm/bgfilter.py @@ -0,0 +1,90 @@ + +import time +import numpy as np +import cv2 +import os +import onnxruntime as ort +from opendm import log +from threading import Lock + +mutex = Lock() + +# Implementation based on https://github.com/danielgatis/rembg by Daniel Gatis + +# Use GPU if it is available, otherwise CPU +provider = "CUDAExecutionProvider" if "CUDAExecutionProvider" in ort.get_available_providers() else "CPUExecutionProvider" + +class BgFilter(): + def __init__(self, model): + self.model = model + + log.ODM_INFO(' ?> Using provider %s' % provider) + self.load_model() + + + def load_model(self): + log.ODM_INFO(' -> Loading the model') + + self.session = ort.InferenceSession(self.model, providers=[provider]) + + def normalize(self, img, mean, std, size): + im = cv2.resize(img, size, interpolation=cv2.INTER_AREA) + im_ary = np.array(im) + im_ary = im_ary / np.max(im_ary) + + tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3)) + tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0] + tmpImg[:, :, 1] = (im_ary[:, :, 1] - mean[1]) / std[1] + tmpImg[:, :, 2] = (im_ary[:, :, 2] - mean[2]) / std[2] + + tmpImg = tmpImg.transpose((2, 0, 1)) + + return { + self.session.get_inputs()[0] + .name: np.expand_dims(tmpImg, 0) + .astype(np.float32) + } + + def get_mask(self, img): + height, width, c = img.shape + + with mutex: + ort_outs = self.session.run( + None, + self.normalize( + img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320) # <-- image size + ), + ) + + pred = ort_outs[0][:, 0, :, :] + + ma = np.max(pred) + mi = np.min(pred) + + pred = (pred - mi) / (ma - mi) + pred = np.squeeze(pred) + + pred *= 255 + pred = pred.astype("uint8") + output = cv2.resize(pred, (width, height), interpolation=cv2.INTER_LANCZOS4) + output[output > 127] = 255 + output[output <= 127] = 0 + + return output + + def run_img(self, img_path, dest): + img = cv2.imread(img_path, cv2.IMREAD_COLOR) + if img is None: + return None + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + mask = self.get_mask(img) + + img_name = os.path.basename(img_path) + fpath = os.path.join(dest, img_name) + + fname, _ = os.path.splitext(fpath) + mask_name = fname + '_mask.png' + cv2.imwrite(mask_name, mask) + + return mask_name diff --git a/opendm/config.py b/opendm/config.py index d0450167..eccef616 100755 --- a/opendm/config.py +++ b/opendm/config.py @@ -242,6 +242,12 @@ def config(argv=None, parser=None): nargs=0, default=False, help='Automatically compute image masks using AI to remove the sky. Experimental. Default: %(default)s') + + parser.add_argument('--bg-removal', + action=StoreTrue, + nargs=0, + default=False, + help='Automatically compute image masks using AI to remove the background. Experimental. Default: %(default)s') parser.add_argument('--use-3dmesh', action=StoreTrue, diff --git a/stages/dataset.py b/stages/dataset.py index 529cee18..f8354aaa 100644 --- a/stages/dataset.py +++ b/stages/dataset.py @@ -13,6 +13,7 @@ from opendm import progress from opendm import boundary from opendm import ai from opendm.skyremoval.skyfilter import SkyFilter +from opendm.bgfilter import BgFilter from opendm.concurrency import parallel_map def save_images_database(photos, database_file): @@ -191,6 +192,47 @@ class ODMLoadDatasetStage(types.ODM_Stage): # End sky removal + # Automatic background removal + if args.bg_removal: + # For each image that : + # - Doesn't already have a mask, AND + # - There are no spaces in the image filename (OpenSfM requirement) + + # Generate list of sky images + bg_images = [] + for p in photos: + if p.mask is None and (not " " in p.filename): + bg_images.append({'file': os.path.join(images_dir, p.filename), 'p': p}) + + if len(bg_images) > 0: + log.ODM_INFO("Automatically generating background masks for %s images" % len(bg_images)) + model = ai.get_model("bgremoval", "https://github.com/OpenDroneMap/ODM/releases/download/v2.9.0/u2net.zip", "v2.9.0") + if model is not None: + bg = BgFilter(model=model) + + def parallel_bg_filter(item): + try: + mask_file = bg.run_img(item['file'], images_dir) + + # Check and set + if mask_file is not None and os.path.isfile(mask_file): + item['p'].set_mask(os.path.basename(mask_file)) + log.ODM_INFO("Wrote %s" % os.path.basename(mask_file)) + else: + log.ODM_WARNING("Cannot generate mask for %s" % img) + except Exception as e: + log.ODM_WARNING("Cannot generate mask for %s: %s" % (img, str(e))) + + parallel_map(parallel_bg_filter, bg_images, max_workers=args.max_concurrency) + + log.ODM_INFO("Background masks generation completed!") + else: + log.ODM_WARNING("Cannot load AI model (you might need to be connected to the internet?)") + else: + log.ODM_INFO("No background masks will be generated (masks already provided)") + + # End bg removal + # Save image database for faster restart save_images_database(photos, images_database_file) else: