kopia lustrzana https://github.com/OpenDroneMap/ODM
AI-powered automatic sky removal
rodzic
d61d0e0cbe
commit
b584459fc9
|
@ -27,3 +27,4 @@ settings.yaml
|
|||
.setupdevenv
|
||||
__pycache__
|
||||
*.snap
|
||||
storage/
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
import os
|
||||
from opendm.net import download
|
||||
from opendm import log
|
||||
import zipfile
|
||||
import time
|
||||
|
||||
def get_model(namespace, url, version, name = "model.onnx"):
|
||||
version = version.replace(".", "_")
|
||||
|
||||
base_dir = os.path.join(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")), "storage", "models")
|
||||
namespace_dir = os.path.join(base_dir, namespace)
|
||||
versioned_dir = os.path.join(namespace_dir, version)
|
||||
|
||||
if not os.path.isdir(versioned_dir):
|
||||
os.makedirs(versioned_dir, exist_ok=True)
|
||||
|
||||
# Check if we need to download it
|
||||
model_file = os.path.join(versioned_dir, name)
|
||||
if not os.path.isfile(model_file):
|
||||
log.ODM_INFO("Downloading AI model from %s ..." % url)
|
||||
|
||||
last_update = 0
|
||||
|
||||
def callback(progress):
|
||||
nonlocal last_update
|
||||
|
||||
time_has_elapsed = time.time() - last_update >= 2
|
||||
|
||||
if time_has_elapsed or int(progress) == 100:
|
||||
log.ODM_INFO("Downloading: %s%%" % int(progress))
|
||||
last_update = time.time()
|
||||
|
||||
try:
|
||||
downloaded_file = download(url, versioned_dir, progress_callback=callback)
|
||||
except Exception as e:
|
||||
log.ODM_WARNING("Cannot download %s: %s" % (url, str(e)))
|
||||
return None
|
||||
|
||||
if os.path.basename(downloaded_file).lower().endswith(".zip"):
|
||||
log.ODM_INFO("Extracting %s ..." % downloaded_file)
|
||||
with zipfile.ZipFile(downloaded_file, 'r') as z:
|
||||
z.extractall(versioned_dir)
|
||||
os.remove(downloaded_file)
|
||||
|
||||
if not os.path.isfile(model_file):
|
||||
log.ODM_WARNING("Cannot find %s (is the URL to the AI model correct?)" % model_file)
|
||||
return None
|
||||
else:
|
||||
return model_file
|
||||
else:
|
||||
return model_file
|
|
@ -25,7 +25,7 @@ def get_max_memory_mb(minimum = 100, use_at_most = 0.5):
|
|||
"""
|
||||
return max(minimum, (virtual_memory().available / 1024 / 1024) * use_at_most)
|
||||
|
||||
def parallel_map(func, items, max_workers=1, single_thread_fallback=True):
|
||||
def parallel_map(func, items, max_workers=1, single_thread_fallback=True, copy_queue_items=True):
|
||||
"""
|
||||
Our own implementation for parallel processing
|
||||
which handles gracefully CTRL+C and reverts to
|
||||
|
@ -66,7 +66,10 @@ def parallel_map(func, items, max_workers=1, single_thread_fallback=True):
|
|||
|
||||
i = 1
|
||||
for t in items:
|
||||
pq.put((i, t.copy()))
|
||||
if copy_queue_items:
|
||||
pq.put((i, t.copy()))
|
||||
else:
|
||||
pq.put((i, t))
|
||||
i += 1
|
||||
|
||||
def stop_workers():
|
||||
|
|
|
@ -237,6 +237,12 @@ def config(argv=None, parser=None):
|
|||
'Can be one of: %(choices)s. Default: '
|
||||
'%(default)s'))
|
||||
|
||||
parser.add_argument('--sky-removal',
|
||||
action=StoreTrue,
|
||||
nargs=0,
|
||||
default=False,
|
||||
help='Automatically compute image masks using AI to remove the sky. Default: %(default)s')
|
||||
|
||||
parser.add_argument('--use-3dmesh',
|
||||
action=StoreTrue,
|
||||
nargs=0,
|
||||
|
|
|
@ -0,0 +1,164 @@
|
|||
import requests
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
try:
|
||||
import queue
|
||||
except ImportError:
|
||||
import Queue as queue
|
||||
import threading
|
||||
from pyodm.utils import AtomicCounter
|
||||
from pyodm.exceptions import RangeNotAvailableError, OdmError
|
||||
from urllib3.exceptions import ReadTimeoutError
|
||||
|
||||
def download(url, destination, progress_callback=None, parallel_downloads=16, parallel_chunks_size=10, timeout=30):
|
||||
"""Download files in parallel (download accelerator)
|
||||
|
||||
Args:
|
||||
url (str): URL to download
|
||||
destination (str): directory where to download file. If the directory does not exist, it will be created.
|
||||
progress_callback (function): an optional callback with one parameter, the download progress percentage.
|
||||
parallel_downloads (int): maximum number of parallel downloads if the node supports http range.
|
||||
parallel_chunks_size (int): size in MB of chunks for parallel downloads
|
||||
timeout (int): seconds before timing out
|
||||
Returns:
|
||||
str: path to file
|
||||
"""
|
||||
if not os.path.exists(destination):
|
||||
os.makedirs(destination, exist_ok=True)
|
||||
|
||||
try:
|
||||
|
||||
download_stream = requests.get(url, timeout=timeout, stream=True)
|
||||
headers = download_stream.headers
|
||||
|
||||
output_path = os.path.join(destination, os.path.basename(url))
|
||||
|
||||
# Keep track of download progress (if possible)
|
||||
content_length = download_stream.headers.get('content-length')
|
||||
total_length = int(content_length) if content_length is not None else None
|
||||
downloaded = 0
|
||||
chunk_size = int(parallel_chunks_size * 1024 * 1024)
|
||||
use_fallback = False
|
||||
accept_ranges = headers.get('accept-ranges')
|
||||
|
||||
# Can we do parallel downloads?
|
||||
if accept_ranges is not None and accept_ranges.lower() == 'bytes' and total_length is not None and total_length > chunk_size and parallel_downloads > 1:
|
||||
num_chunks = int(math.ceil(total_length / float(chunk_size)))
|
||||
num_workers = parallel_downloads
|
||||
|
||||
class nonloc:
|
||||
completed_chunks = AtomicCounter(0)
|
||||
merge_chunks = [False] * num_chunks
|
||||
error = None
|
||||
|
||||
def merge():
|
||||
current_chunk = 0
|
||||
|
||||
with open(output_path, "wb") as out_file:
|
||||
while current_chunk < num_chunks and nonloc.error is None:
|
||||
if nonloc.merge_chunks[current_chunk]:
|
||||
chunk_file = "%s.part%s" % (output_path, current_chunk)
|
||||
with open(chunk_file, "rb") as fd:
|
||||
out_file.write(fd.read())
|
||||
|
||||
os.unlink(chunk_file)
|
||||
|
||||
current_chunk += 1
|
||||
else:
|
||||
time.sleep(0.1)
|
||||
|
||||
def worker():
|
||||
while True:
|
||||
task = q.get()
|
||||
part_num, bytes_range = task
|
||||
if bytes_range is None or nonloc.error is not None:
|
||||
q.task_done()
|
||||
break
|
||||
|
||||
try:
|
||||
# Download chunk
|
||||
res = requests.get(url, stream=True, timeout=timeout, headers={'Range': 'bytes=%s-%s' % bytes_range})
|
||||
if res.status_code == 206:
|
||||
with open("%s.part%s" % (output_path, part_num), 'wb') as fd:
|
||||
bytes_written = 0
|
||||
try:
|
||||
for chunk in res.iter_content(4096):
|
||||
bytes_written += fd.write(chunk)
|
||||
except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
|
||||
raise OdmError(str(e))
|
||||
|
||||
if bytes_written != (bytes_range[1] - bytes_range[0] + 1):
|
||||
# Process again
|
||||
q.put((part_num, bytes_range))
|
||||
return
|
||||
|
||||
with nonloc.completed_chunks.lock:
|
||||
nonloc.completed_chunks.value += 1
|
||||
|
||||
if progress_callback is not None:
|
||||
progress_callback(100.0 * nonloc.completed_chunks.value / num_chunks)
|
||||
|
||||
nonloc.merge_chunks[part_num] = True
|
||||
else:
|
||||
nonloc.error = RangeNotAvailableError()
|
||||
except OdmError as e:
|
||||
time.sleep(5)
|
||||
q.put((part_num, bytes_range))
|
||||
except Exception as e:
|
||||
nonloc.error = e
|
||||
finally:
|
||||
q.task_done()
|
||||
|
||||
q = queue.PriorityQueue()
|
||||
threads = []
|
||||
for i in range(num_workers):
|
||||
t = threading.Thread(target=worker)
|
||||
t.start()
|
||||
threads.append(t)
|
||||
|
||||
merge_thread = threading.Thread(target=merge)
|
||||
merge_thread.start()
|
||||
|
||||
range_start = 0
|
||||
|
||||
for i in range(num_chunks):
|
||||
range_end = min(range_start + chunk_size - 1, total_length - 1)
|
||||
q.put((i, (range_start, range_end)))
|
||||
range_start = range_end + 1
|
||||
|
||||
# block until all tasks are done
|
||||
while not all(nonloc.merge_chunks) and nonloc.error is None:
|
||||
time.sleep(0.1)
|
||||
|
||||
# stop workers
|
||||
for i in range(len(threads)):
|
||||
q.put((-1, None))
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
merge_thread.join()
|
||||
|
||||
if nonloc.error is not None:
|
||||
if isinstance(nonloc.error, RangeNotAvailableError):
|
||||
use_fallback = True
|
||||
else:
|
||||
raise nonloc.error
|
||||
else:
|
||||
use_fallback = True
|
||||
|
||||
if use_fallback:
|
||||
# Single connection, boring download
|
||||
with open(output_path, 'wb') as fd:
|
||||
for chunk in download_stream.iter_content(4096):
|
||||
downloaded += len(chunk)
|
||||
|
||||
if progress_callback is not None and total_length is not None:
|
||||
progress_callback((100.0 * float(downloaded) / total_length))
|
||||
|
||||
fd.write(chunk)
|
||||
|
||||
except (requests.exceptions.Timeout, requests.exceptions.ConnectionError, ReadTimeoutError) as e:
|
||||
raise OdmError(e)
|
||||
|
||||
return output_path
|
|
@ -0,0 +1,37 @@
|
|||
import numpy as np
|
||||
|
||||
# Based on Fast Guided Filter
|
||||
# Kaiming He, Jian Sun
|
||||
# https://arxiv.org/abs/1505.00996
|
||||
|
||||
def box(img, radius):
|
||||
dst = np.zeros_like(img)
|
||||
(r, c) = img.shape
|
||||
|
||||
s = [radius, 1]
|
||||
c_sum = np.cumsum(img, 0)
|
||||
dst[0:radius+1, :, ...] = c_sum[radius:2*radius+1, :, ...]
|
||||
dst[radius+1:r-radius, :, ...] = c_sum[2*radius+1:r, :, ...] - c_sum[0:r-2*radius-1, :, ...]
|
||||
dst[r-radius:r, :, ...] = np.tile(c_sum[r-1:r, :, ...], s) - c_sum[r-2*radius-1:r-radius-1, :, ...]
|
||||
|
||||
s = [1, radius]
|
||||
c_sum = np.cumsum(dst, 1)
|
||||
dst[:, 0:radius+1, ...] = c_sum[:, radius:2*radius+1, ...]
|
||||
dst[:, radius+1:c-radius, ...] = c_sum[:, 2*radius+1 : c, ...] - c_sum[:, 0 : c-2*radius-1, ...]
|
||||
dst[:, c-radius: c, ...] = np.tile(c_sum[:, c-1:c, ...], s) - c_sum[:, c-2*radius-1 : c-radius-1, ...]
|
||||
|
||||
return dst
|
||||
|
||||
|
||||
def guided_filter(img, guide, radius, eps):
|
||||
(r, c) = img.shape
|
||||
|
||||
CNT = box(np.ones([r, c]), radius)
|
||||
|
||||
mean_img = box(img, radius) / CNT
|
||||
mean_guide = box(guide, radius) / CNT
|
||||
|
||||
a = ((box(img * guide, radius) / CNT) - mean_img * mean_guide) / (((box(img * img, radius) / CNT) - mean_img * mean_img) + eps)
|
||||
b = mean_guide - a * mean_img
|
||||
|
||||
return (box(a, radius) / CNT) * img + (box(b, radius) / CNT)
|
|
@ -0,0 +1,103 @@
|
|||
|
||||
import time
|
||||
import numpy as np
|
||||
import cv2
|
||||
import os
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
from .guidedfilter import guided_filter
|
||||
from opendm import log
|
||||
from threading import Lock
|
||||
|
||||
mutex = Lock()
|
||||
|
||||
# Use GPU if it is available, otherwise CPU
|
||||
provider = "CUDAExecutionProvider" if "CUDAExecutionProvider" in ort.get_available_providers() else "CPUExecutionProvider"
|
||||
|
||||
class SkyFilter():
|
||||
|
||||
def __init__(self, model, width = 384, height = 384):
|
||||
|
||||
self.model = model
|
||||
self.width, self.height = width, height
|
||||
|
||||
log.ODM_INFO(' ?> Using provider %s' % provider)
|
||||
self.load_model()
|
||||
|
||||
|
||||
def load_model(self):
|
||||
log.ODM_INFO(' -> Loading the model')
|
||||
onnx_model = onnx.load(self.model)
|
||||
|
||||
# Check the model
|
||||
try:
|
||||
onnx.checker.check_model(onnx_model)
|
||||
except onnx.checker.ValidationError as e:
|
||||
log.ODM_INFO(' !> The model is invalid: %s' % e)
|
||||
raise
|
||||
else:
|
||||
log.ODM_INFO(' ?> The model is valid!')
|
||||
|
||||
self.session = ort.InferenceSession(self.model, providers=[provider])
|
||||
|
||||
|
||||
def get_mask(self, img):
|
||||
|
||||
height, width, c = img.shape
|
||||
|
||||
# Resize image to fit the model input
|
||||
new_img = cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_AREA)
|
||||
new_img = np.array(new_img, dtype=np.float32)
|
||||
|
||||
# Input vector for onnx model
|
||||
input_v = np.expand_dims(new_img.transpose((2, 0, 1)), axis=0)
|
||||
ort_inputs = {self.session.get_inputs()[0].name: input_v}
|
||||
|
||||
# Run the model
|
||||
with mutex:
|
||||
ort_outs = self.session.run(None, ort_inputs)
|
||||
|
||||
# Get the output
|
||||
output = np.array(ort_outs)
|
||||
output = output[0][0].transpose((1, 2, 0))
|
||||
output = cv2.resize(output, (width, height), interpolation=cv2.INTER_LANCZOS4)
|
||||
output = np.array([output, output, output]).transpose((1, 2, 0))
|
||||
output = np.clip(output, a_max=1.0, a_min=0.0)
|
||||
|
||||
return self.refine(output, img)
|
||||
|
||||
|
||||
def refine(self, pred, img):
|
||||
guided_filter_radius, guided_filter_eps = 20, 0.01
|
||||
refined = guided_filter(img[:,:,2], pred[:,:,0], guided_filter_radius, guided_filter_eps)
|
||||
|
||||
res = np.clip(refined, a_min=0, a_max=1)
|
||||
|
||||
# Convert res to CV_8UC1
|
||||
res = np.array(res * 255., dtype=np.uint8)
|
||||
|
||||
# Thresholding
|
||||
res = cv2.threshold(res, 127, 255, cv2.THRESH_BINARY_INV)[1]
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def run_img(self, img_path, dest):
|
||||
|
||||
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
return None
|
||||
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img = np.array(img / 255., dtype=np.float32)
|
||||
|
||||
mask = self.get_mask(img)
|
||||
|
||||
img_name = os.path.basename(img_path)
|
||||
fpath = os.path.join(dest, img_name)
|
||||
|
||||
fname, _ = os.path.splitext(fpath)
|
||||
mask_name = fname + '_mask.png'
|
||||
cv2.imwrite(mask_name, mask)
|
||||
|
||||
return mask_name
|
|
@ -29,3 +29,5 @@ scipy==1.5.4
|
|||
xmltodict==0.12.0
|
||||
fpdf2==2.4.6
|
||||
Shapely==1.7.1
|
||||
onnx==1.12.0
|
||||
onnxruntime==1.11.1
|
|
@ -11,6 +11,9 @@ from opendm.geo import GeoFile
|
|||
from shutil import copyfile
|
||||
from opendm import progress
|
||||
from opendm import boundary
|
||||
from opendm import ai
|
||||
from opendm.skyremoval.skyfilter import SkyFilter
|
||||
from opendm.concurrency import parallel_map
|
||||
|
||||
def save_images_database(photos, database_file):
|
||||
with open(database_file, 'w') as f:
|
||||
|
@ -113,7 +116,7 @@ class ODMLoadDatasetStage(types.ODM_Stage):
|
|||
try:
|
||||
p = types.ODM_Photo(f)
|
||||
p.set_mask(find_mask(f, masks))
|
||||
photos += [p]
|
||||
photos.append(p)
|
||||
dataset_list.write(photos[-1].filename + '\n')
|
||||
except PhotoCorruptedException:
|
||||
log.ODM_WARNING("%s seems corrupted and will not be used" % os.path.basename(f))
|
||||
|
@ -145,6 +148,49 @@ class ODMLoadDatasetStage(types.ODM_Stage):
|
|||
for p in photos:
|
||||
p.override_camera_projection(args.camera_lens)
|
||||
|
||||
# Automatic sky removal
|
||||
if args.sky_removal:
|
||||
# For each image that :
|
||||
# - Doesn't already have a mask, AND
|
||||
# - Is not nadir (or if orientation info is missing), AND
|
||||
# - There are no spaces in the image filename (OpenSfM requirement)
|
||||
# Automatically generate a sky mask
|
||||
|
||||
# Generate list of sky images
|
||||
sky_images = []
|
||||
for p in photos:
|
||||
if p.mask is None and (p.pitch is None or (-10 > p.pitch > 10)) and (not " " in p.filename):
|
||||
sky_images.append({'file': os.path.join(images_dir, p.filename), 'p': p})
|
||||
|
||||
if len(sky_images) > 0:
|
||||
log.ODM_INFO("Automatically generating sky masks for %s images" % len(sky_images))
|
||||
model = ai.get_model("skyremoval", "https://github.com/OpenDroneMap/SkyRemoval/releases/download/v1.0.5/model.zip", "v1.0.5")
|
||||
if model is not None:
|
||||
sf = SkyFilter(model=model)
|
||||
|
||||
def parallel_sky_filter(item):
|
||||
try:
|
||||
mask_file = sf.run_img(item['file'], images_dir)
|
||||
|
||||
# Check and set
|
||||
if mask_file is not None and os.path.isfile(mask_file):
|
||||
item['p'].set_mask(os.path.basename(mask_file))
|
||||
log.ODM_INFO("Wrote %s" % os.path.basename(mask_file))
|
||||
else:
|
||||
log.ODM_WARNING("Cannot generate mask for %s" % img)
|
||||
except Exception as e:
|
||||
log.ODM_WARNING("Cannot generate mask for %s: %s" % (img, str(e)))
|
||||
|
||||
parallel_map(parallel_sky_filter, sky_images, max_workers=args.max_concurrency)
|
||||
|
||||
log.ODM_INFO("Sky masks generation completed!")
|
||||
else:
|
||||
log.ODM_WARNING("Cannot load AI model (you might need to be connected to the internet?)")
|
||||
else:
|
||||
log.ODM_WARNING("No images suitable for sky mask generation detected (are they all nadir?)")
|
||||
|
||||
# End sky removal
|
||||
|
||||
# Save image database for faster restart
|
||||
save_images_database(photos, images_database_file)
|
||||
else:
|
||||
|
|
Ładowanie…
Reference in New Issue