Add AI background removal

pull/1533/head
Piero Toffanin 2022-09-19 15:33:59 -04:00
rodzic deb53279e3
commit cd720006f6
3 zmienionych plików z 138 dodań i 0 usunięć

90
opendm/bgfilter.py 100644
Wyświetl plik

@ -0,0 +1,90 @@
import time
import numpy as np
import cv2
import os
import onnxruntime as ort
from opendm import log
from threading import Lock
mutex = Lock()
# Implementation based on https://github.com/danielgatis/rembg by Daniel Gatis
# Use GPU if it is available, otherwise CPU
provider = "CUDAExecutionProvider" if "CUDAExecutionProvider" in ort.get_available_providers() else "CPUExecutionProvider"
class BgFilter():
def __init__(self, model):
self.model = model
log.ODM_INFO(' ?> Using provider %s' % provider)
self.load_model()
def load_model(self):
log.ODM_INFO(' -> Loading the model')
self.session = ort.InferenceSession(self.model, providers=[provider])
def normalize(self, img, mean, std, size):
im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
im_ary = np.array(im)
im_ary = im_ary / np.max(im_ary)
tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
tmpImg[:, :, 1] = (im_ary[:, :, 1] - mean[1]) / std[1]
tmpImg[:, :, 2] = (im_ary[:, :, 2] - mean[2]) / std[2]
tmpImg = tmpImg.transpose((2, 0, 1))
return {
self.session.get_inputs()[0]
.name: np.expand_dims(tmpImg, 0)
.astype(np.float32)
}
def get_mask(self, img):
height, width, c = img.shape
with mutex:
ort_outs = self.session.run(
None,
self.normalize(
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320) # <-- image size
),
)
pred = ort_outs[0][:, 0, :, :]
ma = np.max(pred)
mi = np.min(pred)
pred = (pred - mi) / (ma - mi)
pred = np.squeeze(pred)
pred *= 255
pred = pred.astype("uint8")
output = cv2.resize(pred, (width, height), interpolation=cv2.INTER_LANCZOS4)
output[output > 127] = 255
output[output <= 127] = 0
return output
def run_img(self, img_path, dest):
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
if img is None:
return None
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
mask = self.get_mask(img)
img_name = os.path.basename(img_path)
fpath = os.path.join(dest, img_name)
fname, _ = os.path.splitext(fpath)
mask_name = fname + '_mask.png'
cv2.imwrite(mask_name, mask)
return mask_name

Wyświetl plik

@ -242,6 +242,12 @@ def config(argv=None, parser=None):
nargs=0,
default=False,
help='Automatically compute image masks using AI to remove the sky. Experimental. Default: %(default)s')
parser.add_argument('--bg-removal',
action=StoreTrue,
nargs=0,
default=False,
help='Automatically compute image masks using AI to remove the background. Experimental. Default: %(default)s')
parser.add_argument('--use-3dmesh',
action=StoreTrue,

Wyświetl plik

@ -13,6 +13,7 @@ from opendm import progress
from opendm import boundary
from opendm import ai
from opendm.skyremoval.skyfilter import SkyFilter
from opendm.bgfilter import BgFilter
from opendm.concurrency import parallel_map
def save_images_database(photos, database_file):
@ -191,6 +192,47 @@ class ODMLoadDatasetStage(types.ODM_Stage):
# End sky removal
# Automatic background removal
if args.bg_removal:
# For each image that :
# - Doesn't already have a mask, AND
# - There are no spaces in the image filename (OpenSfM requirement)
# Generate list of sky images
bg_images = []
for p in photos:
if p.mask is None and (not " " in p.filename):
bg_images.append({'file': os.path.join(images_dir, p.filename), 'p': p})
if len(bg_images) > 0:
log.ODM_INFO("Automatically generating background masks for %s images" % len(bg_images))
model = ai.get_model("bgremoval", "https://github.com/OpenDroneMap/ODM/releases/download/v2.9.0/u2net.zip", "v2.9.0")
if model is not None:
bg = BgFilter(model=model)
def parallel_bg_filter(item):
try:
mask_file = bg.run_img(item['file'], images_dir)
# Check and set
if mask_file is not None and os.path.isfile(mask_file):
item['p'].set_mask(os.path.basename(mask_file))
log.ODM_INFO("Wrote %s" % os.path.basename(mask_file))
else:
log.ODM_WARNING("Cannot generate mask for %s" % img)
except Exception as e:
log.ODM_WARNING("Cannot generate mask for %s: %s" % (img, str(e)))
parallel_map(parallel_bg_filter, bg_images, max_workers=args.max_concurrency)
log.ODM_INFO("Background masks generation completed!")
else:
log.ODM_WARNING("Cannot load AI model (you might need to be connected to the internet?)")
else:
log.ODM_INFO("No background masks will be generated (masks already provided)")
# End bg removal
# Save image database for faster restart
save_images_database(photos, images_database_file)
else: