kopia lustrzana https://github.com/OpenDroneMap/ODM
Add AI background removal
rodzic
deb53279e3
commit
cd720006f6
|
@ -0,0 +1,90 @@
|
||||||
|
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import os
|
||||||
|
import onnxruntime as ort
|
||||||
|
from opendm import log
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
# Module-level lock shared by every BgFilter instance; inference calls
# acquire it so that only one thread runs the model at a time.
mutex = Lock()

# Implementation based on https://github.com/danielgatis/rembg by Daniel Gatis

# Pick the execution provider once at import time: CUDA when ONNX Runtime
# reports it as available, plain CPU otherwise.
if "CUDAExecutionProvider" in ort.get_available_providers():
    provider = "CUDAExecutionProvider"
else:
    provider = "CPUExecutionProvider"
|
||||||
|
|
||||||
|
class BgFilter():
    """Generate binary background masks for images using an ONNX model.

    The model is loaded once at construction time and reused for every
    image. Implementation based on https://github.com/danielgatis/rembg
    by Daniel Gatis.
    """

    def __init__(self, model):
        """Remember the model and create the inference session right away.

        :param model: model reference handed unchanged to
            ``ort.InferenceSession`` (see ``load_model``).
        """
        self.model = model

        log.ODM_INFO(' ?> Using provider %s' % provider)
        self.load_model()

    def load_model(self):
        """Create ``self.session`` for ``self.model`` on the selected provider."""
        log.ODM_INFO(' -> Loading the model')

        self.session = ort.InferenceSession(self.model, providers=[provider])

    def normalize(self, img, mean, std, size):
        """Turn an image into the feed dictionary expected by the session.

        The image is resized to ``size``, scaled by its own maximum value,
        standardised per channel with ``mean``/``std`` and rearranged from
        (rows, cols, channels) to (1, channels, rows, cols) as float32.

        :param img: image array with at least three channels on the last axis.
        :param mean: per-channel means (first three entries are used).
        :param std: per-channel standard deviations (first three entries are used).
        :param size: ``(width, height)`` passed to ``cv2.resize``.
        :return: dict mapping the session's first input name to the tensor.
        """
        im = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
        im_ary = np.asarray(im, dtype=np.float64)

        # An all-black image has a maximum of 0; dividing by it would fill
        # the tensor with NaN, so leave such an image as zeros instead.
        peak = np.max(im_ary)
        if peak != 0:
            im_ary = im_ary / peak

        # Broadcast the three channel statistics over the image in one step.
        mean_ary = np.asarray(mean[:3], dtype=np.float64)
        std_ary = np.asarray(std[:3], dtype=np.float64)
        tmpImg = (im_ary[:, :, :3] - mean_ary) / std_ary

        # Channels first, then a leading batch axis of length one.
        tmpImg = tmpImg.transpose((2, 0, 1))
        input_name = self.session.get_inputs()[0].name

        return {input_name: np.expand_dims(tmpImg, 0).astype(np.float32)}

    def get_mask(self, img):
        """Run the model on ``img`` and return a mask at the image's size.

        :param img: image array whose first two dimensions are height and width.
        :return: uint8 array of shape (height, width) holding only 0 and 255.
        """
        height, width = img.shape[:2]

        # The shared module-level mutex lets one thread at a time
        # preprocess and run inference.
        with mutex:
            ort_outs = self.session.run(
                None,
                self.normalize(
                    img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)  # <-- image size
                ),
            )

        pred = ort_outs[0][:, 0, :, :]

        ma = np.max(pred)
        mi = np.min(pred)
        span = ma - mi

        # Min-max rescale to [0, 1]. A flat (or NaN) prediction has no usable
        # range; dividing would give NaN and an undefined uint8 cast, so it
        # is mapped to the minimum value (0) deterministically instead.
        if span > 0:
            pred = (pred - mi) / span
        else:
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        pred = (pred * 255).astype("uint8")
        output = cv2.resize(pred, (width, height), interpolation=cv2.INTER_LANCZOS4)

        # Binarise: everything above mid-grey becomes 255, the rest 0.
        output = np.where(output > 127, 255, 0).astype(np.uint8)

        return output

    def run_img(self, img_path, dest):
        """Compute the mask for one image file and save it as a PNG.

        The mask is written to ``dest`` as ``<image name without
        extension>_mask.png``.

        :param img_path: path of the image to read.
        :param dest: directory in which the mask file is written.
        :return: path of the written mask, or ``None`` when the image cannot
            be read or the mask cannot be written.
        """
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            return None

        # OpenCV loads BGR; the mask is computed on the RGB ordering.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = self.get_mask(img)

        img_name = os.path.basename(img_path)
        fpath = os.path.join(dest, img_name)

        fname, _ = os.path.splitext(fpath)
        mask_name = fname + '_mask.png'

        # imwrite reports failure through its return value rather than an
        # exception; do not hand back a path that was never written.
        if not cv2.imwrite(mask_name, mask):
            return None

        return mask_name
|
|
@ -242,6 +242,12 @@ def config(argv=None, parser=None):
|
||||||
nargs=0,
|
nargs=0,
|
||||||
default=False,
|
default=False,
|
||||||
help='Automatically compute image masks using AI to remove the sky. Experimental. Default: %(default)s')
|
help='Automatically compute image masks using AI to remove the sky. Experimental. Default: %(default)s')
|
||||||
|
|
||||||
|
parser.add_argument('--bg-removal',
|
||||||
|
action=StoreTrue,
|
||||||
|
nargs=0,
|
||||||
|
default=False,
|
||||||
|
help='Automatically compute image masks using AI to remove the background. Experimental. Default: %(default)s')
|
||||||
|
|
||||||
parser.add_argument('--use-3dmesh',
|
parser.add_argument('--use-3dmesh',
|
||||||
action=StoreTrue,
|
action=StoreTrue,
|
||||||
|
|
|
@ -13,6 +13,7 @@ from opendm import progress
|
||||||
from opendm import boundary
|
from opendm import boundary
|
||||||
from opendm import ai
|
from opendm import ai
|
||||||
from opendm.skyremoval.skyfilter import SkyFilter
|
from opendm.skyremoval.skyfilter import SkyFilter
|
||||||
|
from opendm.bgfilter import BgFilter
|
||||||
from opendm.concurrency import parallel_map
|
from opendm.concurrency import parallel_map
|
||||||
|
|
||||||
def save_images_database(photos, database_file):
|
def save_images_database(photos, database_file):
|
||||||
|
@ -191,6 +192,47 @@ class ODMLoadDatasetStage(types.ODM_Stage):
|
||||||
|
|
||||||
# End sky removal
|
# End sky removal
|
||||||
|
|
||||||
|
# Automatic background removal
|
||||||
|
if args.bg_removal:
|
||||||
|
# For each image that :
|
||||||
|
# - Doesn't already have a mask, AND
|
||||||
|
# - There are no spaces in the image filename (OpenSfM requirement)
|
||||||
|
|
||||||
|
# Generate list of images that need a background mask
|
||||||
|
bg_images = []
|
||||||
|
for p in photos:
|
||||||
|
if p.mask is None and (not " " in p.filename):
|
||||||
|
bg_images.append({'file': os.path.join(images_dir, p.filename), 'p': p})
|
||||||
|
|
||||||
|
if len(bg_images) > 0:
|
||||||
|
log.ODM_INFO("Automatically generating background masks for %s images" % len(bg_images))
|
||||||
|
model = ai.get_model("bgremoval", "https://github.com/OpenDroneMap/ODM/releases/download/v2.9.0/u2net.zip", "v2.9.0")
|
||||||
|
if model is not None:
|
||||||
|
bg = BgFilter(model=model)
|
||||||
|
|
||||||
|
def parallel_bg_filter(item):
    """Generate the background mask for one queued image.

    ``item`` is a dict with the image path under 'file' and the photo
    object under 'p'. On success the mask's file name is registered on the
    photo; any failure is logged as a warning and not re-raised, so one bad
    image does not stop the remaining ones.
    """
    try:
        mask_file = bg.run_img(item['file'], images_dir)

        # Check and set
        if mask_file is not None and os.path.isfile(mask_file):
            item['p'].set_mask(os.path.basename(mask_file))
            log.ODM_INFO("Wrote %s" % os.path.basename(mask_file))
        else:
            # Report the image being processed (the original formatted an
            # undefined name `img` here).
            log.ODM_WARNING("Cannot generate mask for %s" % item['file'])
    except Exception as e:
        log.ODM_WARNING("Cannot generate mask for %s: %s" % (item['file'], str(e)))
|
||||||
|
|
||||||
|
parallel_map(parallel_bg_filter, bg_images, max_workers=args.max_concurrency)
|
||||||
|
|
||||||
|
log.ODM_INFO("Background masks generation completed!")
|
||||||
|
else:
|
||||||
|
log.ODM_WARNING("Cannot load AI model (you might need to be connected to the internet?)")
|
||||||
|
else:
|
||||||
|
log.ODM_INFO("No background masks will be generated (masks already provided)")
|
||||||
|
|
||||||
|
# End bg removal
|
||||||
|
|
||||||
# Save image database for faster restart
|
# Save image database for faster restart
|
||||||
save_images_database(photos, images_database_file)
|
save_images_database(photos, images_database_file)
|
||||||
else:
|
else:
|
||||||
|
|
Ładowanie…
Reference in New Issue