Merge pull request #1533 from pierotofy/bgremoval

Add AI background removal
pull/1538/head
Piero Toffanin 2022-09-19 19:47:49 -04:00 zatwierdzone przez GitHub
commit f6d6210827
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 4AEE18F83AFDEB23
6 zmienionych plików z 141 dodań i 3 usunięć

Wyświetl plik

@ -53,7 +53,7 @@ ExternalProject_Add(${_proj_name}
#--Download step-------------- #--Download step--------------
DOWNLOAD_DIR ${SB_DOWNLOAD_DIR} DOWNLOAD_DIR ${SB_DOWNLOAD_DIR}
GIT_REPOSITORY https://github.com/OpenDroneMap/openMVS GIT_REPOSITORY https://github.com/OpenDroneMap/openMVS
GIT_TAG 288 GIT_TAG 291
#--Update/Patch step---------- #--Update/Patch step----------
UPDATE_COMMAND "" UPDATE_COMMAND ""
#--Configure step------------- #--Configure step-------------

Wyświetl plik

@ -25,7 +25,7 @@ ExternalProject_Add(${_proj_name}
#--Download step-------------- #--Download step--------------
DOWNLOAD_DIR ${SB_DOWNLOAD_DIR} DOWNLOAD_DIR ${SB_DOWNLOAD_DIR}
GIT_REPOSITORY https://github.com/OpenDroneMap/OpenSfM/ GIT_REPOSITORY https://github.com/OpenDroneMap/OpenSfM/
GIT_TAG 290 GIT_TAG 291
#--Update/Patch step---------- #--Update/Patch step----------
UPDATE_COMMAND git submodule update --init --recursive UPDATE_COMMAND git submodule update --init --recursive
#--Configure step------------- #--Configure step-------------

Wyświetl plik

@ -1 +1 @@
2.9.0 2.9.1

90
opendm/bgfilter.py 100644
Wyświetl plik

@ -0,0 +1,90 @@
import time
import numpy as np
import cv2
import os
import onnxruntime as ort
from opendm import log
from threading import Lock
mutex = Lock()
# Implementation based on https://github.com/danielgatis/rembg by Daniel Gatis
# Use GPU if it is available, otherwise CPU
provider = "CUDAExecutionProvider" if "CUDAExecutionProvider" in ort.get_available_providers() else "CPUExecutionProvider"
class BgFilter():
def __init__(self, model):
self.model = model
log.ODM_INFO(' ?> Using provider %s' % provider)
self.load_model()
def load_model(self):
log.ODM_INFO(' -> Loading the model')
self.session = ort.InferenceSession(self.model, providers=[provider])
def normalize(self, img, mean, std, size):
im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
im_ary = np.array(im)
im_ary = im_ary / np.max(im_ary)
tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
tmpImg[:, :, 1] = (im_ary[:, :, 1] - mean[1]) / std[1]
tmpImg[:, :, 2] = (im_ary[:, :, 2] - mean[2]) / std[2]
tmpImg = tmpImg.transpose((2, 0, 1))
return {
self.session.get_inputs()[0]
.name: np.expand_dims(tmpImg, 0)
.astype(np.float32)
}
def get_mask(self, img):
height, width, c = img.shape
with mutex:
ort_outs = self.session.run(
None,
self.normalize(
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320) # <-- image size
),
)
pred = ort_outs[0][:, 0, :, :]
ma = np.max(pred)
mi = np.min(pred)
pred = (pred - mi) / (ma - mi)
pred = np.squeeze(pred)
pred *= 255
pred = pred.astype("uint8")
output = cv2.resize(pred, (width, height), interpolation=cv2.INTER_LANCZOS4)
output[output > 127] = 255
output[output <= 127] = 0
return output
def run_img(self, img_path, dest):
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
if img is None:
return None
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
mask = self.get_mask(img)
img_name = os.path.basename(img_path)
fpath = os.path.join(dest, img_name)
fname, _ = os.path.splitext(fpath)
mask_name = fname + '_mask.png'
cv2.imwrite(mask_name, mask)
return mask_name

Wyświetl plik

@ -243,6 +243,12 @@ def config(argv=None, parser=None):
default=False, default=False,
help='Automatically compute image masks using AI to remove the sky. Experimental. Default: %(default)s') help='Automatically compute image masks using AI to remove the sky. Experimental. Default: %(default)s')
parser.add_argument('--bg-removal',
action=StoreTrue,
nargs=0,
default=False,
help='Automatically compute image masks using AI to remove the background. Experimental. Default: %(default)s')
parser.add_argument('--use-3dmesh', parser.add_argument('--use-3dmesh',
action=StoreTrue, action=StoreTrue,
nargs=0, nargs=0,

Wyświetl plik

@ -13,6 +13,7 @@ from opendm import progress
from opendm import boundary from opendm import boundary
from opendm import ai from opendm import ai
from opendm.skyremoval.skyfilter import SkyFilter from opendm.skyremoval.skyfilter import SkyFilter
from opendm.bgfilter import BgFilter
from opendm.concurrency import parallel_map from opendm.concurrency import parallel_map
def save_images_database(photos, database_file): def save_images_database(photos, database_file):
@ -191,6 +192,47 @@ class ODMLoadDatasetStage(types.ODM_Stage):
# End sky removal # End sky removal
# Automatic background removal
if args.bg_removal:
# For each image that :
# - Doesn't already have a mask, AND
# - There are no spaces in the image filename (OpenSfM requirement)
# Generate list of sky images
bg_images = []
for p in photos:
if p.mask is None and (not " " in p.filename):
bg_images.append({'file': os.path.join(images_dir, p.filename), 'p': p})
if len(bg_images) > 0:
log.ODM_INFO("Automatically generating background masks for %s images" % len(bg_images))
model = ai.get_model("bgremoval", "https://github.com/OpenDroneMap/ODM/releases/download/v2.9.0/u2net.zip", "v2.9.0")
if model is not None:
bg = BgFilter(model=model)
def parallel_bg_filter(item):
try:
mask_file = bg.run_img(item['file'], images_dir)
# Check and set
if mask_file is not None and os.path.isfile(mask_file):
item['p'].set_mask(os.path.basename(mask_file))
log.ODM_INFO("Wrote %s" % os.path.basename(mask_file))
else:
log.ODM_WARNING("Cannot generate mask for %s" % img)
except Exception as e:
log.ODM_WARNING("Cannot generate mask for %s: %s" % (img, str(e)))
parallel_map(parallel_bg_filter, bg_images, max_workers=args.max_concurrency)
log.ODM_INFO("Background masks generation completed!")
else:
log.ODM_WARNING("Cannot load AI model (you might need to be connected to the internet?)")
else:
log.ODM_INFO("No background masks will be generated (masks already provided)")
# End bg removal
# Save image database for faster restart # Save image database for faster restart
save_images_database(photos, images_database_file) save_images_database(photos, images_database_file)
else: else: