Add AI background removal

pull/1533/head
Piero Toffanin 2022-09-19 15:33:59 -04:00
rodzic deb53279e3
commit cd720006f6
3 zmienionych plików z 138 dodań i 0 usunięć

90
opendm/bgfilter.py 100644
Wyświetl plik

@ -0,0 +1,90 @@
import time
import numpy as np
import cv2
import os
import onnxruntime as ort
from opendm import log
from threading import Lock
# Module-wide lock; BgFilter.get_mask holds it while calling into the
# ONNX Runtime session.
mutex = Lock()

# Implementation based on https://github.com/danielgatis/rembg by Daniel Gatis

# Use GPU if it is available, otherwise CPU
if "CUDAExecutionProvider" in ort.get_available_providers():
    provider = "CUDAExecutionProvider"
else:
    provider = "CPUExecutionProvider"
class BgFilter():
    """Compute binary image masks with an ONNX segmentation model.

    The model file is loaded once at construction time with the
    module-level execution ``provider``; afterwards ``run_img`` can be
    called once per image to write a ``<name>_mask.png`` file.
    """

    def __init__(self, model):
        """Store the model path and load it right away.

        :param model: path to the ONNX model, handed to
            ``ort.InferenceSession``.
        """
        self.model = model

        log.ODM_INFO(' ?> Using provider %s' % provider)
        self.load_model()

    def load_model(self):
        """Create the ONNX Runtime inference session for ``self.model``."""
        log.ODM_INFO(' -> Loading the model')

        self.session = ort.InferenceSession(self.model, providers=[provider])

    def normalize(self, img, mean, std, size):
        """Turn an image into the input feed expected by the session.

        :param img: image array with at least 3 channels on the last axis.
        :param mean: per-channel means subtracted after scaling (3 values).
        :param std: per-channel standard deviations to divide by (3 values).
        :param size: target size passed to ``cv2.resize``.
        :return: dict mapping the name of the session's first input to a
            float32 array of shape (1, 3, rows, cols).
        """
        im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)

        im_ary = np.array(im)
        # Scale by the brightest value; the floor keeps an all-black image
        # from producing a 0/0 division (NaN input to the model).
        im_ary = im_ary / max(np.max(im_ary), 1e-6)

        tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
        for ch in range(3):
            tmpImg[:, :, ch] = (im_ary[:, :, ch] - mean[ch]) / std[ch]

        # Channels-last -> channels-first
        tmpImg = tmpImg.transpose((2, 0, 1))

        input_name = self.session.get_inputs()[0].name
        return {input_name: np.expand_dims(tmpImg, 0).astype(np.float32)}

    @staticmethod
    def _rescale_unit(values):
        """Linearly rescale ``values`` so that min -> 0 and max -> 1.

        A constant array has no range to stretch; it is mapped to all
        zeros instead of dividing by zero.
        """
        lo = np.min(values)
        span = np.max(values) - lo
        if span == 0:
            return np.zeros(np.shape(values), dtype=np.float32)
        return (values - lo) / span

    def get_mask(self, img):
        """Run the model on ``img`` and return a binary mask.

        :param img: 3-dimensional image array (rows, cols, channels).
        :return: uint8 array of shape (rows, cols) containing only 0 and 255.
        """
        height, width, _ = img.shape

        # Preprocessing and inference run under the module-level lock.
        with mutex:
            ort_outs = self.session.run(
                None,
                self.normalize(
                    img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)  # <-- image size
                ),
            )

        # First output, first channel
        pred = ort_outs[0][:, 0, :, :]
        pred = self._rescale_unit(pred)
        pred = np.squeeze(pred)

        pred *= 255
        pred = pred.astype("uint8")

        # Back to the input resolution, then threshold at mid-gray
        output = cv2.resize(pred, (width, height), interpolation=cv2.INTER_LANCZOS4)
        output[output > 127] = 255
        output[output <= 127] = 0

        return output

    def run_img(self, img_path, dest):
        """Generate a mask for the image at ``img_path`` and save it in ``dest``.

        :param img_path: path of the image to read.
        :param dest: directory where ``<image name>_mask.png`` is written.
        :return: path of the written mask, or None if the image could not
            be read or the mask could not be written.
        """
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            return None

        # OpenCV loads BGR; convert to RGB before masking
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = self.get_mask(img)

        img_name = os.path.basename(img_path)
        fpath = os.path.join(dest, img_name)

        fname, _ = os.path.splitext(fpath)
        mask_name = fname + '_mask.png'

        # imwrite reports failure through its return value; don't hand back
        # a path that was not actually (re)written.
        if not cv2.imwrite(mask_name, mask):
            return None

        return mask_name

Wyświetl plik

@ -242,6 +242,12 @@ def config(argv=None, parser=None):
nargs=0, nargs=0,
default=False, default=False,
help='Automatically compute image masks using AI to remove the sky. Experimental. Default: %(default)s') help='Automatically compute image masks using AI to remove the sky. Experimental. Default: %(default)s')
# Opt-in switch for AI-based background masking; defaults to False.
parser.add_argument('--bg-removal',
action=StoreTrue,
nargs=0,
default=False,
help='Automatically compute image masks using AI to remove the background. Experimental. Default: %(default)s')
parser.add_argument('--use-3dmesh', parser.add_argument('--use-3dmesh',
action=StoreTrue, action=StoreTrue,

Wyświetl plik

@ -13,6 +13,7 @@ from opendm import progress
from opendm import boundary from opendm import boundary
from opendm import ai from opendm import ai
from opendm.skyremoval.skyfilter import SkyFilter from opendm.skyremoval.skyfilter import SkyFilter
from opendm.bgfilter import BgFilter
from opendm.concurrency import parallel_map from opendm.concurrency import parallel_map
def save_images_database(photos, database_file): def save_images_database(photos, database_file):
@ -191,6 +192,47 @@ class ODMLoadDatasetStage(types.ODM_Stage):
# End sky removal # End sky removal
# Automatic background removal
if args.bg_removal:
    # Collect the images to process. An image qualifies when it:
    #  - doesn't already have a mask, AND
    #  - has no spaces in its filename (OpenSfM requirement)
    bg_images = [
        {'file': os.path.join(images_dir, p.filename), 'p': p}
        for p in photos
        if p.mask is None and " " not in p.filename
    ]

    if bg_images:
        log.ODM_INFO("Automatically generating background masks for %s images" % len(bg_images))
        model = ai.get_model("bgremoval", "https://github.com/OpenDroneMap/ODM/releases/download/v2.9.0/u2net.zip", "v2.9.0")
        if model is not None:
            bg = BgFilter(model=model)

            def parallel_bg_filter(item):
                # Best effort per image: a failure is logged and the image
                # simply stays without a mask.
                try:
                    mask_file = bg.run_img(item['file'], images_dir)

                    # Check and set
                    if mask_file is not None and os.path.isfile(mask_file):
                        item['p'].set_mask(os.path.basename(mask_file))
                        log.ODM_INFO("Wrote %s" % os.path.basename(mask_file))
                    else:
                        log.ODM_WARNING("Cannot generate mask for %s" % item['file'])
                except Exception as e:
                    log.ODM_WARNING("Cannot generate mask for %s: %s" % (item['file'], str(e)))

            parallel_map(parallel_bg_filter, bg_images, max_workers=args.max_concurrency)

            log.ODM_INFO("Background masks generation completed!")
        else:
            log.ODM_WARNING("Cannot load AI model (you might need to be connected to the internet?)")
    else:
        log.ODM_INFO("No background masks will be generated (masks already provided)")
# End bg removal
# Save image database for faster restart # Save image database for faster restart
save_images_database(photos, images_database_file) save_images_database(photos, images_database_file)
else: else: