import time
import numpy as np
import cv2
import os
import onnxruntime as ort
from .guidedfilter import guided_filter
from opendm import log
from threading import Lock

# Module-wide lock held around every session.run() call in
# SkyFilter.get_mask, so only one thread at a time runs inference.
mutex = Lock()

# Use GPU if it is available, otherwise CPU
provider = (
    "CUDAExecutionProvider"
    if "CUDAExecutionProvider" in ort.get_available_providers()
    else "CPUExecutionProvider"
)


class SkyFilter:
    """Generate binary mask images from an ONNX segmentation model.

    The model is loaded once at construction time. ``run_img`` is the
    file-based entry point; ``get_mask`` and ``refine`` operate on arrays.
    """

    def __init__(self, model, width=384, height=384):
        """Store the configuration and load the model.

        Args:
            model: Model passed as-is to ``ort.InferenceSession``
                (presumably a path to an .onnx file -- confirm with callers).
            width: Width the image is resized to before inference.
            height: Height the image is resized to before inference.
        """
        self.model = model
        self.width, self.height = width, height

        log.ODM_INFO(' ?> Using provider %s' % provider)
        self.load_model()

    def load_model(self):
        """Create the ONNX Runtime session for ``self.model``."""
        log.ODM_INFO(' -> Loading the model')
        self.session = ort.InferenceSession(self.model, providers=[provider])

    def get_mask(self, img):
        """Run the model on ``img`` and return the refined binary mask.

        Args:
            img: Image array whose first two dimensions are height and
                width and whose last dimension is the channel axis.
                ``run_img`` passes a float32 RGB image scaled to [0, 1].

        Returns:
            The uint8 mask produced by ``refine`` at the resolution of
            ``img``.
        """
        # Only the spatial size is needed, to scale the prediction back up.
        height, width = img.shape[:2]

        # Resize image to fit the model input
        new_img = cv2.resize(
            img, (self.width, self.height), interpolation=cv2.INTER_AREA
        )
        new_img = np.array(new_img, dtype=np.float32)

        # Input vector for onnx model: channels-first with a leading
        # batch axis of size 1.
        input_v = np.expand_dims(new_img.transpose((2, 0, 1)), axis=0)
        ort_inputs = {self.session.get_inputs()[0].name: input_v}

        # Run the model (one thread at a time)
        with mutex:
            ort_outs = self.session.run(None, ort_inputs)

        # Get the output: first output, first batch item, moved to
        # channels-last.
        # NOTE(review): assumes the first output is (batch, channels, H, W)
        # -- confirm against the model.
        output = np.array(ort_outs)
        output = output[0][0].transpose((1, 2, 0))
        output = cv2.resize(
            output, (width, height), interpolation=cv2.INTER_LANCZOS4
        )
        # Stack three copies into an H x W x 3 array. The 3-axis transpose
        # requires the resized prediction to be 2-D at this point.
        output = np.array([output, output, output]).transpose((1, 2, 0))
        # Clamp to [0, 1] before refinement.
        output = np.clip(output, a_max=1.0, a_min=0.0)

        return self.refine(output, img)

    def refine(self, pred, img):
        """Refine a prediction with a guided filter and binarize it.

        Args:
            pred: H x W x C prediction; only channel 0 is used.
            img: H x W x C image; only channel 2 is used (the blue channel
                for the RGB images supplied by ``run_img``).

        Returns:
            A uint8 array in which filtered values above 127 (on a 0-255
            scale) become 0 and all other values become 255.
        """
        guided_filter_radius, guided_filter_eps = 20, 0.01
        # NOTE(review): argument roles of guided_filter are defined in
        # .guidedfilter -- presumably (guide, source, radius, eps); confirm.
        refined = guided_filter(
            img[:, :, 2], pred[:, :, 0], guided_filter_radius, guided_filter_eps
        )

        res = np.clip(refined, a_min=0, a_max=1)

        # Convert res to CV_8UC1
        res = np.array(res * 255., dtype=np.uint8)

        # Thresholding (inverted: high values -> 0, low values -> 255)
        res = cv2.threshold(res, 127, 255, cv2.THRESH_BINARY_INV)[1]

        return res

    def run_img(self, img_path, dest):
        """Compute the mask for an image file and write it as a PNG.

        The mask is saved in ``dest`` as ``<image base name>_mask.png``.

        Args:
            img_path: Path of the image to process.
            dest: Directory in which the mask file is written.

        Returns:
            The path of the written mask, or ``None`` when the image
            cannot be read or the mask cannot be written.
        """
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            return None

        # OpenCV loads BGR; convert to RGB and normalize to [0, 1].
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = np.array(img / 255., dtype=np.float32)

        mask = self.get_mask(img)

        img_name = os.path.basename(img_path)
        fpath = os.path.join(dest, img_name)

        fname, _ = os.path.splitext(fpath)
        mask_name = fname + '_mask.png'

        # cv2.imwrite returns a boolean status; do not hand back the path
        # of a file that was never written.
        if not cv2.imwrite(mask_name, mask):
            return None

        return mask_name