kopia lustrzana https://github.com/lzzcd001/MeshDiffusion
first commit
commit
64edf07e23
|
@ -0,0 +1,133 @@
|
|||
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2023 Zliu
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,106 @@
|
|||
# MeshDiffusion: Score-based Generative 3D Mesh Modeling (ICLR 2023 Spotlight)
|
||||
|
||||
This is the official implementation of MeshDiffusion.
|
||||
|
||||
MeshDiffusion is a diffusion model for generating 3D meshes with a direct parametrization of deep marching tetrahedra (DMTet). Please refer to https://meshdiffusion.github.io for more details.
|
||||
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Requirements
|
||||
|
||||
- Python >= 3.8
|
||||
- CUDA 11.6
|
||||
- Pytorch >= 1.6
|
||||
- Pytorch3D
|
||||
|
||||
|
||||
Install https://github.com/NVlabs/nvdiffrec
|
||||
|
||||
### Pretrained Models
|
||||
|
||||
Download the pretrained model checkpoints from the link provided on the project page (https://meshdiffusion.github.io).
|
||||
|
||||
## Inference
|
||||
|
||||
### Unconditional Generation
|
||||
|
||||
Run the following
|
||||
|
||||
```
|
||||
python main_diffusion.py --config=$DIFFUSION_CONFIG --mode=uncond_gen \
|
||||
--config.eval.eval_dir=$OUTPUT_PATH \
|
||||
--config.eval.ckpt_path=$CKPT_PATH
|
||||
```
|
||||
|
||||
Later run
|
||||
|
||||
```
|
||||
cd nvdiffrec
|
||||
python eval.py --config $DMTET_CONFIG --sample-path $SAMPLE_PATH
|
||||
```
|
||||
|
||||
where `$SAMPLE_PATH` is the generated sample `.npy` file in `$OUTPUT_PATH`
|
||||
|
||||
|
||||
### Single-view Conditional Generation
|
||||
|
||||
First fit a DMTet from a single view of a mesh
|
||||
|
||||
```
|
||||
cd nvdiffrec
|
||||
python fit_singleview.py --mesh-path $MESH_PATH --angle-ind $ANGLE_IND --out-dir $OUT_DIR --validate $VALIDATE
|
||||
```
|
||||
|
||||
Then use the trained diffusion model to complete the occluded regions
|
||||
|
||||
```
|
||||
cd ..
|
||||
python main_diffusion.py --config=$DIFFUSION_CONFIG --mode=cond_gen \
|
||||
--config.eval.eval_dir=$EVAL_DIR \
|
||||
--config.eval.ckpt_path=$CKPT_PATH \
|
||||
--config.eval.partial_dmtet_path=$OUT_DIR/tets/dmtet.pt \
|
||||
--config.eval.tet_path=$TET_PATH \
|
||||
--config.eval.batch_size=$EVAL_BATCH_SIZE
|
||||
```
|
||||
|
||||
Now visualize the completed meshes
|
||||
|
||||
```
|
||||
cd nvdiffrec
|
||||
python eval.py --config $DMTET_CONFIG --sample-path $SAMPLE_PATH
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
For ShapeNet, first create a list of paths of all ground-truth meshes and store them as a json file under `./nvdiffrec/data/shapenet_json`.
|
||||
|
||||
Then run the following
|
||||
|
||||
```
|
||||
cd nvdiffrec
|
||||
python fit_dmtets.py
|
||||
```
|
||||
|
||||
Create a meta file for diffusion model training:
|
||||
|
||||
```
|
||||
cd ../metadata/
|
||||
python save_meta.py
|
||||
```
|
||||
|
||||
Train a diffusion model
|
||||
|
||||
```
|
||||
cd ..
|
||||
python main_diffusion.py --config=$DIFFUSION_CONFIG --mode=train
|
||||
|
||||
```
|
||||
|
||||
## Texture Completion
|
||||
|
||||
Follow the instructions in https://github.com/TEXTurePaper/TEXTurePaper and create text-conditioned textures for the generated meshes.
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
This repo is adapted from https://github.com/NVlabs/nvdiffrec and https://github.com/yang-song/score_sde_pytorch.
|
|
@ -0,0 +1,88 @@
|
|||
import ml_collections
|
||||
import torch
|
||||
|
||||
|
||||
def get_default_configs():
    """Build the base ConfigDict shared by all diffusion experiment configs.

    Sections: training, sampling, eval, data, model, optim, render, plus the
    top-level `seed` and `device`. Values set to "PLACEHOLDER" must be
    overridden per experiment or on the command line.
    """
    cfg = ml_collections.ConfigDict()

    # training
    train = ml_collections.ConfigDict()
    cfg.training = train
    train.batch_size = 64
    train.n_iters = 2400001
    train.snapshot_freq = 50000
    train.log_freq = 50
    train.eval_freq = 100
    # store additional checkpoints for preemption in cloud computing environments
    train.snapshot_freq_for_preemption = 5000
    # produce samples at each snapshot
    train.snapshot_sampling = True
    train.likelihood_weighting = False
    train.continuous = True
    train.reduce_mean = False
    train.iter_size = 1
    train.loss_type = 'l2'
    train.train_dir = "PLACEHOLDER"

    # sampling
    samp = ml_collections.ConfigDict()
    cfg.sampling = samp
    samp.n_steps_each = 1
    samp.noise_removal = True
    samp.probability_flow = False
    samp.snr = 0.075

    # evaluation
    ev = ml_collections.ConfigDict()
    cfg.eval = ev
    ev.begin_ckpt = 50
    ev.end_ckpt = 96
    ev.batch_size = 512
    ev.enable_sampling = True
    ev.num_samples = 50000
    ev.enable_loss = True
    ev.enable_bpd = False
    ev.bpd_dataset = 'test'
    ev.ckpt_path = "PLACEHOLDER"
    ev.partial_dmtet_path = "PLACEHOLDER"
    ev.tet_path = "PLACEHOLDER"
    ev.freeze_iters = 950

    # data
    dat = ml_collections.ConfigDict()
    cfg.data = dat
    dat.dataset = 'LSUN'
    dat.image_size = 256
    dat.random_flip = True
    dat.uniform_dequantization = False
    dat.centered = False
    dat.num_channels = 3
    dat.num_workers = 4
    dat.normalize_sdf = True
    dat.meta_path = "PLACEHOLDER"  # metadata for all dataset files
    dat.filter_meta_path = "PLACEHOLDER"  # metadata for the list of training samples

    # model
    mdl = ml_collections.ConfigDict()
    cfg.model = mdl
    mdl.sigma_max = 378
    mdl.sigma_min = 0.01
    mdl.num_scales = 2000
    mdl.beta_min = 0.1
    mdl.beta_max = 20.
    mdl.dropout = 0.
    mdl.embedding_type = 'fourier'
    mdl.deform_scale = 1.0

    # optimization
    opt = ml_collections.ConfigDict()
    cfg.optim = opt
    opt.weight_decay = 0
    opt.optimizer = 'Adam'
    opt.lr = 2e-4
    opt.beta1 = 0.9
    opt.eps = 1e-8
    opt.warmup = 5000
    opt.grad_clip = 1.

    cfg.seed = 42
    cfg.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # rendering (fields filled in elsewhere)
    cfg.render = ml_collections.ConfigDict()

    return cfg
|
|
@ -0,0 +1,78 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Config file for reproducing the results of DDPM on bedrooms."""
|
||||
|
||||
from configs.default_configs import get_default_configs
|
||||
|
||||
|
||||
def get_config():
    """Config for the 128-resolution DMTet diffusion model (ShapeNet)."""
    config = get_default_configs()

    # training
    train = config.training
    train.sde = 'vpsde'
    train.continuous = False
    train.reduce_mean = True
    train.batch_size = 8
    train.lip_scale = None
    train.iter_size = 4
    train.snapshot_freq_for_preemption = 1000

    # sampling: predictor-corrector with ancestral sampling and no corrector
    samp = config.sampling
    samp.method = 'pc'
    samp.predictor = 'ancestral_sampling'
    samp.corrector = 'none'

    # data
    dat = config.data
    dat.dataset = 'ShapeNet'
    dat.centered = True
    dat.image_size = 128
    dat.num_channels = 4
    dat.meta_path = "PLACEHOLDER"  # metadata for all dataset files
    dat.filter_meta_path = "PLACEHOLDER"  # metadata for the list of training samples
    dat.num_workers = 8
    dat.aug = True

    # model
    mdl = config.model
    mdl.name = 'ddpm_res128_v2'
    mdl.scale_by_sigma = False
    mdl.num_scales = 1000
    mdl.ema_rate = 0.9999
    mdl.normalization = 'GroupNorm'
    mdl.nonlinearity = 'swish'
    mdl.nf = 128
    mdl.ch_mult = (1, 1, 2, 4, 4, 4)
    mdl.num_res_blocks_first = 2
    mdl.num_res_blocks = 2
    mdl.attn_resolutions = (16,)
    mdl.resamp_with_conv = True
    mdl.conditional = True
    mdl.dropout = 0.1

    # optim: base lr rescaled for the gradient-accumulation factor above
    config.optim.lr = 7e-5 / train.iter_size * 2.0

    config.eval.batch_size = 7
    config.seed = 42

    return config
|
|
@ -0,0 +1,79 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Config file for reproducing the results of DDPM on bedrooms."""
|
||||
|
||||
from configs.default_configs import get_default_configs
|
||||
|
||||
|
||||
def get_config():
    """Config for the 64-resolution DMTet diffusion model (ShapeNet)."""
    config = get_default_configs()

    # training
    train = config.training
    train.sde = 'vpsde'
    train.continuous = False
    train.reduce_mean = True
    train.batch_size = 48
    train.lip_scale = None
    train.snapshot_freq_for_preemption = 1000

    # sampling: predictor-corrector with ancestral sampling and no corrector
    samp = config.sampling
    samp.method = 'pc'
    samp.predictor = 'ancestral_sampling'
    samp.corrector = 'none'

    # data
    dat = config.data
    dat.dataset = 'ShapeNet'
    dat.centered = True
    dat.image_size = 64
    dat.num_channels = 4
    dat.meta_path = "PLACEHOLDER"  # metadata for all dataset files
    dat.filter_meta_path = "PLACEHOLDER"  # metadata for the list of training samples
    dat.num_workers = 4
    dat.aug = True

    # model
    mdl = config.model
    mdl.name = 'ddpm_res64'
    mdl.scale_by_sigma = False
    mdl.num_scales = 1000
    mdl.ema_rate = 0.9999
    mdl.normalization = 'GroupNorm'
    mdl.nonlinearity = 'swish'
    mdl.nf = 128
    mdl.ch_mult = (1, 1, 2, 4, 4)
    mdl.num_res_blocks_first = 2
    mdl.num_res_blocks = 3
    mdl.attn_resolutions = (16,)
    mdl.resamp_with_conv = True
    mdl.conditional = True
    mdl.dropout = 0.1

    # optim
    config.optim.lr = 2e-5

    config.eval.batch_size = 4
    config.eval.eval_dir = "PLACEHOLDER"

    config.seed = 42

    return config
|
|
@ -0,0 +1,37 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import tqdm
|
||||
import argparse
|
||||
|
||||
def tet_to_grids(vertices, grid_size):
    """Rasterize integer vertex coordinates into a dense occupancy grid.

    Args:
        vertices: (N, 3) integer tensor of lattice coordinates, each value
            expected in [0, grid_size).
        grid_size: edge length of the cubic grid.

    Returns:
        A (grid_size, grid_size, grid_size) float tensor with 1.0 at every
        cell occupied by at least one vertex, 0.0 elsewhere.
    """
    grid = torch.zeros(grid_size, grid_size, grid_size, device=vertices.device)
    with torch.no_grad():
        # Vectorized scatter replaces the original per-vertex Python loop:
        # identical result (duplicate vertices simply write 1.0 again),
        # but a single indexing op instead of N item assignments.
        grid[vertices[:, 0], vertices[:, 1], vertices[:, 2]] = 1.0
    return grid
|
||||
|
||||
|
||||
if __name__ == "__main__":

    # CLI: discretize the vertex coordinates of a cropped tet grid and save
    # a binary occupancy mask of the occupied lattice cells.
    parser = argparse.ArgumentParser()
    parser.add_argument('--resolution', type=int)
    parser.add_argument('--tet_folder', type=str, default='../nvdiffrec/data/tets/')
    args = parser.parse_args()

    tet_path = f'{args.tet_folder}/{args.resolution}_tets_cropped.npz'
    tet = np.load(tet_path)

    vertices = torch.tensor(tet['vertices'])
    # Vertices lie on a regular lattice; the gap between the first two
    # unique coordinate values gives the lattice step dx.
    vertices_unique = vertices[:].unique()
    dx = vertices_unique[1] - vertices_unique[0]

    # Shift to the origin and snap each coordinate to integer lattice indices.
    vertices_discretized = (torch.round(
        (vertices - vertices.min()) / dx)
    ).long()

    grid = tet_to_grids(vertices_discretized, args.resolution)
    torch.save(grid, f'grid_mask_{args.resolution}.pt')
|
Plik binarny nie jest wyświetlany.
Plik binarny nie jest wyświetlany.
|
@ -0,0 +1,49 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
import tqdm
|
||||
import argparse
|
||||
|
||||
def tet_to_grids(vertices, values_list, grid_size):
    """Scatter per-vertex DMTet values into a dense 4-channel grid.

    Channel 0 receives the first entry of ``values_list`` (squeezed to 1-D);
    channels 1..3 receive the second entry transposed to (3, N).

    Args:
        vertices: (N, 3) integer tensor of lattice coordinates.
        values_list: per-vertex value tensors; entry 0 shaped (N, 1),
            entry 1 shaped (N, 3).
        grid_size: edge length of the cubic grid.

    Returns:
        A (4, grid_size, grid_size, grid_size) float tensor.
    """
    xs, ys, zs = vertices[:, 0], vertices[:, 1], vertices[:, 2]
    grid = torch.zeros(4, grid_size, grid_size, grid_size, device=vertices.device)
    with torch.no_grad():
        for channel, values in enumerate(values_list):
            if channel == 0:
                # First entry fills the leading channel.
                grid[0, xs, ys, zs] = values.squeeze()
            else:
                # Second entry fills the remaining three channels at once.
                grid[1:, xs, ys, zs] = values.transpose(0, 1)
    return grid
|
||||
|
||||
if __name__ == "__main__":
    # CLI: convert fitted DMTet checkpoints (per-vertex sdf + deform) into
    # dense 4-channel grid tensors. Work is sharded: this process handles
    # global file indices [index * split_size, (index + 1) * split_size).
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('-res', '--resolution', type=int)
    parser.add_argument('-ss', '--split-size', type=int, default=int(1e8))
    parser.add_argument('-ind', '--index', type=int)
    parser.add_argument('-r', '--root', type=str)
    parser.add_argument('-s', '--source', type=str)
    parser.add_argument('-t', '--target', type=str)
    FLAGS = parser.parse_args()

    tet_path = f'../nvdiffrec/data/tets/{FLAGS.resolution}_tets_cropped.npz'
    tet = np.load(tet_path)
    vertices = torch.tensor(tet['vertices'])
    # Vertices lie on a regular lattice; the gap between the first two
    # unique coordinate values gives the lattice step dx.
    vertices_unique = vertices[:].unique()
    dx = vertices_unique[1] - vertices_unique[0]
    # Snap each vertex to integer lattice indices.
    vertices_discretized = (torch.round(
        (vertices - vertices.min()) / dx)
    ).long()

    save_folder = FLAGS.root

    grid_folder = os.path.join(save_folder, FLAGS.target)
    os.makedirs(grid_folder, exist_ok=True)

    tets_folder = os.path.join(save_folder, FLAGS.source)

    for k in tqdm.trange(FLAGS.split_size):
        global_index = k + FLAGS.index * FLAGS.split_size
        tet_path = os.path.join(tets_folder, 'dmt_dict_{:05d}.pt'.format(global_index))
        # Missing indices are skipped silently, so split_size may safely
        # exceed the actual number of checkpoint files.
        if os.path.exists(tet_path):
            tet = torch.load(tet_path, map_location="cpu")
            grid = tet_to_grids(vertices_discretized, (tet['sdf'].unsqueeze(-1), tet['deform']), FLAGS.resolution)
            torch.save(grid, os.path.join(grid_folder, 'grid_{:05d}.pt'.format(global_index)))
|
|
@ -0,0 +1,43 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Training and evaluation"""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from ml_collections.config_flags import config_flags
|
||||
|
||||
import lib.diffusion.trainer as trainer
|
||||
import lib.diffusion.evaler as evaler
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS

# --config points at an ml_collections config file (see ./configs);
# lock_config=False so command-line flags may add or override fields.
config_flags.DEFINE_config_file(
    "config", None, "diffusion configs", lock_config=False)
# --mode selects training or one of the two generation entry points.
flags.DEFINE_enum("mode", None, ["train", "uncond_gen", "cond_gen"], "Running mode")
flags.mark_flags_as_required(["config", "mode"])
|
||||
|
||||
|
||||
def main(argv):
    """Dispatch to the trainer or an evaluation entry point based on --mode."""
    handlers = {
        'train': trainer.train,
        'uncond_gen': evaler.uncond_gen,
        'cond_gen': evaler.cond_gen,
    }
    handler = handlers.get(FLAGS.mode)
    if handler is not None:
        handler(FLAGS.config)


if __name__ == "__main__":
    app.run(main)
|
|
@ -0,0 +1,14 @@
|
|||
import os
|
||||
import json
|
||||
import argparse
|
||||
|
||||
if __name__ == "__main__":
    # Collect every DMTet grid file (*.pt) under --data_path and save the
    # sorted path list as a JSON metadata file at --json_path.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str)
    parser.add_argument('--json_path', type=str)
    args = parser.parse_args()

    # BUG FIX: the original called os.listdir(root) with an undefined name
    # `root`; the directory being listed is args.data_path.
    fpath_list = sorted(
        os.path.join(args.data_path, fname)
        for fname in os.listdir(args.data_path)
        if fname.endswith('.pt')
    )

    # BUG FIX: args.json_path names the output *file*. The original created a
    # directory at that exact path, which made the subsequent open() fail.
    # Create the parent directory instead, and close the file deterministically.
    parent_dir = os.path.dirname(args.json_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(args.json_path, 'w') as f:
        json.dump(fpath_list, f)
|
||||
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,18 @@
|
|||
{
|
||||
"random_textures": true,
|
||||
"iter": 5000,
|
||||
"save_interval": 100,
|
||||
"texture_res": [ 2048, 2048 ],
|
||||
"train_res": [1000, 1000],
|
||||
"batch": 4,
|
||||
"learning_rate": [0.01, 0.003],
|
||||
"ks_min" : [0, 0.08, 0.0],
|
||||
"dmtet_grid" : 128,
|
||||
"mesh_scale" : 1.1,
|
||||
"laplace_scale" : 10000,
|
||||
"display": [{"bsdf" : "kd"}, {"bsdf" : "ks"}, {"bsdf" : "normal"}, {"depth": true}],
|
||||
"background" : "white",
|
||||
"envmap": "./data/irrmaps/aerodynamics_workshop_2k.hdr",
|
||||
"tet_path": "./data/tets/128_tets_cropped.npz",
|
||||
"cropped": true
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
{
|
||||
"random_textures": true,
|
||||
"iter": 5000,
|
||||
"save_interval": 100,
|
||||
"texture_res": [ 2048, 2048 ],
|
||||
"train_res": [1000, 1000],
|
||||
"batch": 4,
|
||||
"learning_rate": [0.01, 0.003],
|
||||
"ks_min" : [0, 0.08, 0.0],
|
||||
"dmtet_grid" : 64,
|
||||
"mesh_scale" : 1.1,
|
||||
"laplace_scale" : 10000,
|
||||
"display": [{"bsdf" : "kd"}, {"bsdf" : "ks"}, {"bsdf" : "normal"}, {"depth": true}],
|
||||
"background" : "white",
|
||||
"envmap": "./data/irrmaps/aerodynamics_workshop_2k.hdr",
|
||||
"tet_path": "./data/tets/64_tets_cropped.npz",
|
||||
"cropped": true
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
The aerodynamics_workshop_2k.hdr HDR probe is from https://polyhaven.com/a/aerodynamics_workshop
|
||||
CC0 License.
|
||||
|
Plik binarny nie jest wyświetlany.
Plik binarny nie jest wyświetlany.
Plik diff jest za duży
Load Diff
Plik diff jest za duży
Load Diff
Plik binarny nie jest wyświetlany.
Plik binarny nie jest wyświetlany.
|
@ -0,0 +1,6 @@
|
|||
Place the tet grid files in this folder.
|
||||
We provide a few example grids. See the main README.md for a download link.
|
||||
|
||||
You can also generate your own grids using https://github.com/crawforddoran/quartet
|
||||
Please see the `generate_tets.py` script for an example.
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def crop_tets(vertices, indices):
    """Remove boundary vertices of a tet grid and drop tets touching them.

    A vertex is a boundary vertex when any of its coordinates equals the
    minimum or maximum value along that axis. Surviving vertices are
    re-indexed densely (in their original order) and only tetrahedra whose
    four corners all survive are kept, with remapped indices.

    Args:
        vertices: (N, 3) float array of tet-grid vertex positions.
        indices: (M, 4) int array of vertex indices, one row per tetrahedron.

    Returns:
        Tuple (vertices_cropped, indices_cropped): the (N', 3) interior
        vertices and the (M', 4) int32 remapped tetrahedra.
    """
    assert indices.shape[1] == 4

    # Interior mask: strictly inside the bounding box along every axis.
    # BUG FIX: the original used np.bool, which was removed in NumPy 1.24;
    # a plain bool-dtype mask is used instead.
    mask = np.ones(vertices.shape[0], dtype=bool)
    for k in range(3):
        col = vertices[:, k]
        mask &= (col != np.min(col)) & (col != np.max(col))
    print(f"remaining: {mask.sum()} out of {vertices.shape[0]}")

    vertices_cropped = vertices[mask]
    print(f"{(~mask).sum()} out of {vertices.shape[0]}")

    # Dense remapping old index -> new index (-1 marks removed vertices),
    # replacing the original per-vertex Python loop over a defaultdict.
    remapping = np.full(vertices.shape[0], -1, dtype=np.int64)
    remapping[mask] = np.arange(int(mask.sum()))

    # Keep only tets whose four corners are all interior, then remap.
    keep = mask[indices].all(axis=1)
    indices_cropped = remapping[indices[keep]].astype(np.int32)

    if indices_cropped.size:
        print(vertices_cropped.shape[0], np.min(indices_cropped), np.max(indices_cropped))
    return vertices_cropped, indices_cropped
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--resolution', type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
resolution = args.resolution
|
||||
npzfile = f'{resolution}_tets.npz'
|
||||
data = np.load(npzfile)
|
||||
new_verts, new_inds = crop_tets(data['vertices'], data['indices'])
|
||||
np.savez_compressed(f'{resolution}_tets_cropped.npz', vertices=new_verts, indices=new_inds)
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
'''
|
||||
This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet,
|
||||
to generate a tet grid
|
||||
1) Download, compile and run Quartet as described in the link above. Example usage `quartet meshes/cube.obj 0.5 cube_5.tet`
|
||||
2) Run the function below to generate a file `cube_32_tet.tet`
|
||||
'''
|
||||
|
||||
def generate_tetrahedron_grid_file(res=32, root='..'):
    """Invoke quartet to build a uniform tet grid of the unit cube.

    Args:
        res: integer grid resolution; the tet edge length passed to quartet
            is 1/res.
        root: directory containing the compiled `quartet` checkout.

    Produces `meshes/cube_<res>_tet.tet` (plus a boundary mesh) inside the
    quartet directory.
    """
    frac = 1.0 / res
    # BUG FIX: `res` is an integer and the documented output name is
    # `cube_32_tet.tet`; the original formatted it with %f, producing names
    # like `cube_32.000000_tet.tet`. Use %d for the resolution.
    command = 'cd %s/quartet; ' % (root) + \
        './quartet meshes/cube.obj %f meshes/cube_%d_tet.tet -s meshes/cube_boundary_%d.obj' % (frac, res, res)
    # NOTE(review): os.system with an interpolated string is shell-injection
    # prone; acceptable for a local tool, but keep `root` trusted.
    os.system(command)
|
||||
|
||||
|
||||
'''
|
||||
This code segment shows how to convert from a quartet .tet file to compressed npz file
|
||||
'''
|
||||
def convert_from_quartet_to_npz(quartetfile = 'cube_32_tet.tet', npzfile = '32_tets'):
    """Convert a quartet .tet file into a compressed .npz archive.

    The .tet format starts with a header line `tet <numvertices> <numtets>`,
    followed by the vertex rows and then the tetrahedron index rows.

    Args:
        quartetfile: path to the quartet output file.
        npzfile: output archive name; numpy appends the `.npz` suffix.
    """
    # BUG FIX: the original opened the file and never closed it; read only
    # the header here (np.loadtxt re-opens the file by path below).
    with open(quartetfile, 'r') as header_file:
        header = header_file.readline()
    numvertices = int(header.split(" ")[1])
    numtets = int(header.split(" ")[2])
    print(numvertices, numtets)

    # load vertices
    vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices)
    print(vertices.shape)

    # load indices
    indices = np.loadtxt(quartetfile, dtype=int, skiprows=1+numvertices, max_rows=numtets)
    print(indices.shape)

    np.savez_compressed(npzfile, vertices=vertices, indices=indices)
|
|
@ -0,0 +1,452 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
|
||||
import glob
|
||||
import tqdm
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import nvdiffrast.torch as dr
|
||||
import xatlas
|
||||
|
||||
# Import topology / geometry trainers
|
||||
from lib.geometry.dmtet import DMTetGeometry
|
||||
|
||||
import lib.render.renderutils as ru
|
||||
from lib.render import material
|
||||
from lib.render import util
|
||||
from lib.render import mesh
|
||||
from lib.render import texture
|
||||
from lib.render import mlptexture
|
||||
from lib.render import light
|
||||
from lib.render import render
|
||||
|
||||
from pytorch3d.io import save_obj
|
||||
|
||||
import pymeshlab
|
||||
|
||||
RADIUS = 3.0
|
||||
|
||||
# Enable to debug back-prop anomalies
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
###############################################################################
|
||||
# Loss setup
|
||||
###############################################################################
|
||||
|
||||
@torch.no_grad()
def createLoss(FLAGS):
    """Return an image-loss callable selected by FLAGS.loss.

    Each returned function maps (img, ref) to ru.image_loss with the
    loss / tonemapper pair associated with the configured name. An unknown
    name fails the assertion, matching the original behavior.
    """
    recipes = {
        "smape":  ("smape",  "none"),
        "mse":    ("mse",    "none"),
        "logl1":  ("l1",     "log_srgb"),
        "logl2":  ("mse",    "log_srgb"),
        "relmse": ("relmse", "none"),
    }
    assert FLAGS.loss in recipes
    loss_name, tonemapper = recipes[FLAGS.loss]
    return lambda img, ref: ru.image_loss(img, ref, loss=loss_name, tonemapper=tonemapper)
|
||||
|
||||
###############################################################################
|
||||
# Mix background into a dataset image
|
||||
###############################################################################
|
||||
|
||||
@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Attach a background image to a dataset batch and move it to CUDA.

    Args:
        target: batch dict; reads 'img'/'resolution' (depending on bg_type)
            and 'mv', 'mvp', 'campos'.
        bg_type: one of 'checker', 'black', 'white', 'reference', 'random'.

    Returns:
        The same dict with 'mv'/'mvp'/'campos' moved to the GPU and a new
        'background' entry of shape (1, H, W, 3).
    """
    if bg_type == 'checker':
        board = util.checkerboard(target['img'].shape[1:3], 8)
        background = torch.tensor(board, dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'reference':
        # Reuse the reference image itself as background.
        background = target['img'][..., 0:3]
    elif bg_type in ('black', 'white', 'random'):
        maker = {'black': torch.zeros, 'white': torch.ones, 'random': torch.rand}[bg_type]
        shape = (1, target['resolution'][0], target['resolution'][1], 3)
        background = maker(shape, dtype=torch.float32, device='cuda')
    else:
        assert False, "Unknown background type %s" % bg_type

    for key in ('mv', 'mvp', 'campos'):
        target[key] = target[key].cuda()
    target['background'] = background

    return target
|
||||
|
||||
###############################################################################
|
||||
# UV - map geometry & convert to a mesh
|
||||
###############################################################################
|
||||
|
||||
@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """Bake the material of `geometry` into 2D textures on a fresh UV atlas.

    Extracts the current mesh, parametrizes it with xatlas, renders the
    kd/ks/normal maps into texture space, and returns a new mesh whose
    material holds the baked Texture2D maps.

    Args:
        glctx: rasterizer context passed through to render.render_uv.
        geometry: object exposing getMesh(mat); the result provides
            v_pos / t_pos_idx tensors and a material dict.
        mat: current material dict (its 'bsdf' entry is carried over).
        FLAGS: config with texture_res, layers, and kd/ks/nrm min/max clamps.

    Returns:
        A mesh.Mesh with UV coordinates and a baked material.Material.
    """
    eval_mesh = geometry.getMesh(mat)

    # Create uvs with xatlas
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Convert to tensors
    # Reinterpret xatlas' unsigned indices bitwise as int64 for torch.
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)

    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    # Rasterize the 'kd_ks_normal' material into the new UV layout.
    mask, kd, ks, normal = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks_normal'])

    if FLAGS.layers > 1:
        # Append a random extra channel to kd when layered materials are used.
        kd = torch.cat((kd, torch.rand_like(kd[...,0:1])), dim=-1)

    # Per-channel clamp ranges for the baked textures, taken from the config.
    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    new_mesh.material = material.Material({
        'bsdf' : mat['bsdf'],
        'kd' : texture.Texture2D(kd, min_max=[kd_min, kd_max]),
        'ks' : texture.Texture2D(ks, min_max=[ks_min, ks_max]),
        'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max])
    })

    return new_mesh
|
||||
|
||||
###############################################################################
|
||||
# Utility functions for material
|
||||
###############################################################################
|
||||
|
||||
def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Build an initial material for optimization.

    With ``mlp`` truthy, a single volumetric MLP texture predicting the nine
    concatenated kd/ks/normal channels is created. Otherwise trainable 2D
    textures are set up — randomly initialized, or copied from ``init_mat``
    when it is given and random textures are not requested.
    """
    kd_min = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda')
    kd_max = torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda')
    ks_max = torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    nrm_min = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda')
    nrm_max = torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    if mlp:
        # One MLP predicts kd (3) + ks (3) + normal (3) = 9 channels.
        mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=9, min_max=[mlp_min, mlp_max])
        mat = material.Material({'kd_ks_normal' : mlp_map_opt})
    else:
        # Kd (albedo) and Ks (x, roughness, metalness) textures.
        if FLAGS.random_textures or init_mat is None:
            channels = 4 if FLAGS.layers > 1 else 3
            span = (kd_max - kd_min)[None, None, 0:channels]
            lo = kd_min[None, None, 0:channels]
            kd_init = torch.rand(size=FLAGS.texture_res + [channels], device='cuda') * span + lo
            kd_map_opt = texture.create_trainable(kd_init, FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])

            # Per-channel uniform noise for ks; the first channel stays near zero.
            ks_r = np.random.uniform(size=FLAGS.texture_res + [1], low=0.0, high=0.01)
            ks_g = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu())
            ks_b = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu())
            ks_map_opt = texture.create_trainable(np.concatenate((ks_r, ks_g, ks_b), axis=2), FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])
        else:
            kd_map_opt = texture.create_trainable(init_mat['kd'], FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])
            ks_map_opt = texture.create_trainable(init_mat['ks'], FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])

        # Normal map: flat (0, 0, 1) unless the initial material supplies one.
        if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
            normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
        else:
            normal_map_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

        mat = material.Material({
            'kd'     : kd_map_opt,
            'ks'     : ks_map_opt,
            'normal' : normal_map_opt
        })

    # Inherit the BSDF from the initial material when available.
    mat['bsdf'] = init_mat['bsdf'] if init_mat is not None else 'pbr'

    return mat
|
||||
|
||||
###############################################################################
|
||||
# Validation & testing
|
||||
###############################################################################
|
||||
|
||||
def rotate_scene(FLAGS, itr):
    """Camera pose for frame ``itr`` of a smooth 50-step turntable orbit.

    Returns a target dict (mv / mvp / campos on CUDA, plus spp and
    resolution) compatible with prepare_batch / validate_itr.
    """
    fovy = np.deg2rad(45)
    proj_mtx = util.perspective(fovy, FLAGS.display_res[1] / FLAGS.display_res[0], FLAGS.cam_near_far[0], FLAGS.cam_near_far[1])

    # One full revolution every 50 iterations, with a slight downward tilt.
    angle = (itr / 50) * np.pi * 2
    mv = util.translate(0, 0, -RADIUS) @ (util.rotate_x(-0.4) @ util.rotate_y(angle))
    mvp = proj_mtx @ mv
    campos = torch.linalg.inv(mv)[:3, 3]

    return {
        'mv': mv[None, ...].cuda(),
        'mvp': mvp[None, ...].cuda(),
        'campos': campos[None, ...].cuda(),
        'spp': 1,
        'resolution': FLAGS.display_res,
    }
|
||||
|
||||
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS):
    """Render one validation view without gradients.

    Returns ``(image, result_dict)`` where ``result_dict`` currently holds a
    single 'opt' entry — the sRGB-converted rendering of the first batch item.
    """
    with torch.no_grad():
        lgt.build_mips()
        if FLAGS.camera_space_light:
            lgt.xfm(target['mv'])

        buffers = geometry.render(glctx, target, lgt, opt_material)

        rendered = util.rgb_to_srgb(buffers['shaded'][..., 0:3])[0]
        result_dict = {'opt': rendered}

    return rendered, result_dict
|
||||
|
||||
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS):
    """Run the validation set through the renderer, save each named buffer as
    a PNG and write per-view MSE/PSNR metrics to ``<out_dir>/metrics.txt``.

    Returns the average PSNR over the dataset.
    """
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(dataloader_validate):
            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS)

            # Compute metrics.
            # NOTE(review): validate_itr() in this file only populates 'opt';
            # a 'ref' entry is expected here, so this raises KeyError unless
            # validate_itr is extended to also return the reference image.
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)

            # size_average/reduce are deprecated aliases; reduction='mean' is
            # already the default behavior.
            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            fout.write("%d, %1.8f, %1.8f\n" % (it, mse, psnr))

            for name, img in result_dict.items():
                np_img = img.detach().cpu().numpy()
                util.save_image(os.path.join(out_dir, 'val_%06d_%s.png' % (it, name)), np_img)

        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))
        fout.write("AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr))
        print("MSE, PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr
|
||||
|
||||
###############################################################################
|
||||
# Main shape fitter function / optimization loop
|
||||
###############################################################################
|
||||
|
||||
class Trainer(torch.nn.Module):
    """Bundles geometry, material and lighting so a single ``forward`` call
    produces the optimization loss for one batch via ``geometry.tick``.

    Exposes ``params`` (appearance: material + optionally light) and
    ``geo_params`` (geometry) so the caller can build separate optimizers.
    """

    def __init__(self, glctx, geometry, lgt, mat, optimize_geometry, optimize_light, image_loss_fn, FLAGS):
        super(Trainer, self).__init__()

        self.glctx = glctx
        self.geometry = geometry
        self.light = lgt
        self.material = mat
        self.optimize_geometry = optimize_geometry
        self.optimize_light = optimize_light
        self.image_loss_fn = image_loss_fn
        self.FLAGS = FLAGS

        # If the light is fixed, its mip pyramid only needs to be built once.
        if not self.optimize_light:
            with torch.no_grad():
                self.light.build_mips()

        self.params = list(self.material.parameters())
        self.params += list(self.light.parameters()) if optimize_light else []
        self.geo_params = list(self.geometry.parameters()) if optimize_geometry else []

    def forward(self, target, it):
        if self.optimize_light:
            # The light changed last step; rebuild its mips every iteration.
            self.light.build_mips()
            if self.FLAGS.camera_space_light:
                self.light.xfm(target['mv'])

        # Bug fix: use the rasterization context passed to the constructor
        # instead of silently relying on a module-level ``glctx`` global.
        return self.geometry.tick(self.glctx, target, self.light, self.material, self.image_loss_fn, it)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Main function.
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
    # Visualization entry point: load diffusion-generated SDF/deformation
    # grids, write them into a DMTet geometry, and render one turntable view
    # per sample into <out_dir>/viz.
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default=None, help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-b', '--batch', type=int, default=1)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    parser.add_argument('-dr', '--display-res', type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('-o', '--out-dir', type=str, default='./viz_tet')
    parser.add_argument('-sp', '--sample-path', type=str, default=None)
    parser.add_argument('-bm', '--base-mesh', type=str, default=None)
    parser.add_argument('-ds', '--deform-scale', type=float, default=2.0)
    parser.add_argument('-vn', '--viz-name', type=str, default='viz')
    parser.add_argument('--unnormalized_sdf', action="store_true")
    # NOTE(review): argparse with type=bool treats any non-empty string
    # (including "False") as True; action='store_true' would be safer if this
    # flag is meant to be toggled from the command line.
    parser.add_argument('--validate', type=bool, default=True)

    FLAGS = parser.parse_args()

    # Hard-coded defaults not exposed on the command line (overridable via
    # the JSON --config file below).
    FLAGS.mtl_override = None  # Override material of model
    FLAGS.dmtet_grid = 64  # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet
    FLAGS.mesh_scale = 2.1  # Scale of tet grid box. Adjust to cover the model
    FLAGS.env_scale = 1.0  # Env map intensity multiplier
    FLAGS.envmap = None  # HDR environment probe
    FLAGS.display = None  # Conf validation window/display. E.g. [{"relight" : <path to envlight>}]
    FLAGS.camera_space_light = False  # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.
    FLAGS.lock_light = False  # Disable light optimization in the second pass
    FLAGS.lock_pos = False  # Disable vertex position optimization in the second pass
    FLAGS.sdf_regularizer = 0.2  # Weight for sdf regularizer (see paper for details)
    FLAGS.laplace = "relative"  # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale = 10000.0  # Weight for sdf regularizer. Default is relative with large weight
    FLAGS.pre_load = True  # Pre-load entire dataset into memory for faster training
    FLAGS.kd_min = [ 0.0, 0.0, 0.0, 0.0]  # Limits for kd
    FLAGS.kd_max = [ 1.0, 1.0, 1.0, 1.0]
    FLAGS.ks_min = [ 0.0, 0.08, 0.0]  # Limits for ks
    FLAGS.ks_max = [ 1.0, 1.0, 1.0]
    FLAGS.nrm_min = [-1.0, -1.0, 0.0]  # Limits for normal map
    FLAGS.nrm_max = [ 1.0, 1.0, 1.0]
    FLAGS.cam_near_far = [0.1, 1000.0]
    FLAGS.learn_light = False
    FLAGS.cropped = False
    FLAGS.random_lgt = False

    # Config-file values override both argparse results and the defaults
    # above. NOTE(review): the file handle from open() is never closed.
    if FLAGS.config is not None:
        data = json.load(open(FLAGS.config, 'r'))
        for key in data:
            FLAGS.__dict__[key] = data[key]

    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    # Output layout: <out_dir>/viz for renders, <out_dir>/mesh for OBJs.
    os.makedirs(FLAGS.out_dir, exist_ok=True)
    viz_path = os.path.join(FLAGS.out_dir, 'viz')
    mesh_path = os.path.join(FLAGS.out_dir, 'mesh')
    os.makedirs(viz_path, exist_ok=True)
    os.makedirs(mesh_path, exist_ok=True)

    # OpenGL rasterization context (requires a GL-capable display/GPU).
    glctx = dr.RasterizeGLContext()

    # ==============================================================================================
    #  Create env light with trainable parameters
    # ==============================================================================================

    lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)

    # ==============================================================================================
    #  If no initial guess, use DMtets to create geometry
    # ==============================================================================================

    # Setup geometry for optimization
    resolution = FLAGS.dmtet_grid
    geometry = DMTetGeometry(resolution, FLAGS.mesh_scale, FLAGS)
    geometry.deform_scale = FLAGS.deform_scale

    # NOTE(review): ``mask`` is loaded but never used below — confirm whether
    # it can be removed or was meant to gate the grid samples.
    mask = torch.load(f'../data/grid_mask_{resolution}.pt').view(1, resolution, resolution, resolution).to("cuda")

    ### compute the mapping from tet indices to 3D cubic grid vertex indices
    # NOTE(review): FLAGS.tet_path is not an argparse option — presumably it
    # must be supplied via the JSON --config file; verify before running.
    tet_path = FLAGS.tet_path
    tet = np.load(tet_path)
    vertices = torch.tensor(tet['vertices'])
    vertices_unique = vertices[:].unique()
    # Grid spacing, assuming vertices lie on a regular lattice.
    dx = vertices_unique[1] - vertices_unique[0]

    # Snap each tet vertex to integer lattice coordinates.
    vertices_discretized = (torch.round(
        (vertices - vertices.min()) / dx)
    ).long()

    # Samples are stored as [N, 4, res, res, res]: channel 0 = SDF,
    # channels 1-3 = per-vertex deformation (inferred from indexing below).
    data_all = np.load(FLAGS.sample_path)
    print('shape of generated data', data_all.shape)

    for no_data in tqdm.trange(data_all.shape[0]):
        grid = torch.tensor(data_all[no_data])
        if FLAGS.unnormalized_sdf:
            raise NotImplementedError
            # Unreachable until the branch above is implemented: copy raw SDF
            # values gathered at the discretized vertex coordinates.
            geometry.sdf.data[:] = (
                grid[0, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]]
            ).cuda()
        else:
            # Only the SDF sign matters for DMTet topology extraction.
            geometry.sdf.data[:] = torch.sign(
                grid[0, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]]
            ).cuda()
        # Deformation channels, transposed to [num_vertices, 3].
        geometry.deform.data[:] = (
            grid[1:, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]]
        ).cuda().transpose(0, 1)

        geometry.deform.data[:] = geometry.deform.data[:].clip(-1.0, 1.0)

        ### mtl for visualization
        opt_material = {
            'name' : '_default_mat',
            # 'bsdf' : 'pbr',
            'bsdf' : 'diffuse',
            'kd'   : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda')),
            'ks'   : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
        }

        ### create and optimize mesh
        base_mesh = geometry.getMesh(opt_material)

        ### save image (before post-processing)
        v_pose = rotate_scene(FLAGS, 25)  ## pick a pose (pose # from 0 to 50)
        result_image, _ = validate_itr(glctx, prepare_batch(v_pose, FLAGS.background), geometry, opt_material, lgt, FLAGS)
        result_image = result_image.detach().cpu().numpy()
        util.save_image(os.path.join(viz_path, ('%s_%06d.png' % (FLAGS.viz_name, no_data))), result_image)

        # ### save post-processed mesh
        # mesh_savepath = os.path.join(mesh_path, '{:06d}.obj'.format(no_data))
        # save_obj(
        #     verts=base_mesh.v_pos,
        #     faces=base_mesh.t_pos_idx,
        #     f=mesh_savepath
        # )

        # ms = pymeshlab.MeshSet()
        # ms.load_new_mesh(mesh_savepath)
        # ms.meshing_isotropic_explicit_remeshing()
        # ms.apply_coord_laplacian_smoothing(stepsmoothnum=3, cotangentweight=False)
        # # ms.apply_coord_laplacian_smoothing(stepsmoothnum=3, cotangentweight=True) ## for smoother surface
        # ms.meshing_isotropic_explicit_remeshing()
        # ms.apply_filter_script()
        # ms.save_current_mesh(mesh_savepath)
|
|
@ -0,0 +1,451 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
|
||||
import glob
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import nvdiffrast.torch as dr
|
||||
import xatlas
|
||||
|
||||
# Import topology / geometry trainers
|
||||
from lib.geometry.dmtet import DMTetGeometry
|
||||
|
||||
import lib.render.renderutils as ru
|
||||
from lib.render import material
|
||||
from lib.render import util
|
||||
from lib.render import mesh
|
||||
from lib.render import texture
|
||||
from lib.render import mlptexture
|
||||
from lib.render import light
|
||||
from lib.render import render
|
||||
|
||||
from pytorch3d.io import save_obj
|
||||
|
||||
import pymeshlab
|
||||
|
||||
RADIUS = 3.0
|
||||
|
||||
# Enable to debug back-prop anomalies
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
###############################################################################
|
||||
# Loss setup
|
||||
###############################################################################
|
||||
|
||||
@torch.no_grad()
def createLoss(FLAGS):
    """Return an image-space loss function ``f(img, ref)`` selected by
    ``FLAGS.loss``.

    Supported names: 'smape', 'mse', 'logl1', 'logl2', 'relmse'. Raises
    ValueError for anything else — the previous ``assert False`` would be
    silently stripped under ``python -O``.
    """
    # (loss, tonemapper) arguments to ru.image_loss for each named config.
    configs = {
        "smape":  ('smape',  'none'),
        "mse":    ('mse',    'none'),
        "logl1":  ('l1',     'log_srgb'),
        "logl2":  ('mse',    'log_srgb'),
        "relmse": ('relmse', 'none'),
    }
    try:
        loss, tonemapper = configs[FLAGS.loss]
    except KeyError:
        raise ValueError("Unknown loss: %s" % FLAGS.loss) from None
    return lambda img, ref: ru.image_loss(img, ref, loss=loss, tonemapper=tonemapper)
|
||||
|
||||
###############################################################################
|
||||
# Mix background into a dataset image
|
||||
###############################################################################
|
||||
|
||||
@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Move a dataset batch to the GPU and attach a background image.

    Mutates ``target`` in place: 'mv', 'mvp' and 'campos' are moved to CUDA
    and a 'background' tensor of the requested type is added. Returns the
    same dict for convenience.
    """
    # assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"
    if bg_type == 'checker':
        checker = util.checkerboard(target['img'].shape[1:3], 8)
        background = torch.tensor(checker, dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'reference':
        # Reuse the reference image's RGB channels as the background.
        background = target['img'][..., 0:3]
    elif bg_type in ('black', 'white', 'random'):
        shape = (1, target['resolution'][0], target['resolution'][1], 3)
        factory = {'black': torch.zeros, 'white': torch.ones, 'random': torch.rand}[bg_type]
        background = factory(shape, dtype=torch.float32, device='cuda')
    else:
        assert False, "Unknown background type %s" % bg_type

    for key in ('mv', 'mvp', 'campos'):
        target[key] = target[key].cuda()
    target['background'] = background

    return target
|
||||
|
||||
###############################################################################
|
||||
# UV - map geometry & convert to a mesh
|
||||
###############################################################################
|
||||
|
||||
@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """UV-parametrize the geometry with xatlas and bake the volumetric MLP
    material into 2D kd / ks / normal textures on the resulting mesh.
    """
    base = geometry.getMesh(mat)

    # xatlas runs on CPU numpy copies of the mesh.
    np_pos = base.v_pos.detach().cpu().numpy()
    np_tris = base.t_pos_idx.detach().cpu().numpy()
    vmapping, atlas_tris, atlas_uvs = xatlas.parametrize(np_pos, np_tris)

    # xatlas returns unsigned indices; reinterpret the bits as int64 for torch.
    tri_idx = atlas_tris.astype(np.uint64, casting='same_kind').view(np.int64)

    uv_tensor = torch.tensor(atlas_uvs, dtype=torch.float32, device='cuda')
    face_tensor = torch.tensor(tri_idx, dtype=torch.int64, device='cuda')

    uv_mesh = mesh.Mesh(v_tex=uv_tensor, t_tex_idx=face_tensor, base=base)

    # Bake the MLP-predicted material into per-texel maps in UV space.
    mask, kd, ks, normal = render.render_uv(glctx, uv_mesh, FLAGS.texture_res, base.material['kd_ks_normal'])

    if FLAGS.layers > 1:
        # Multi-layer rendering expects an alpha channel on kd.
        kd = torch.cat((kd, torch.rand_like(kd[..., 0:1])), dim=-1)

    kd_min = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda')
    kd_max = torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda')
    ks_max = torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    nrm_min = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda')
    nrm_max = torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    uv_mesh.material = material.Material({
        'bsdf'   : mat['bsdf'],
        'kd'     : texture.Texture2D(kd, min_max=[kd_min, kd_max]),
        'ks'     : texture.Texture2D(ks, min_max=[ks_min, ks_max]),
        'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max])
    })

    return uv_mesh
|
||||
|
||||
###############################################################################
|
||||
# Utility functions for material
|
||||
###############################################################################
|
||||
|
||||
def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Build an initial material for optimization.

    With ``mlp`` truthy, a single volumetric MLP texture predicting the nine
    concatenated kd/ks/normal channels is created. Otherwise trainable 2D
    textures are set up — randomly initialized, or copied from ``init_mat``
    when it is given and random textures are not requested.
    """
    kd_min = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda')
    kd_max = torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda')
    ks_max = torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    nrm_min = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda')
    nrm_max = torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    if mlp:
        # One MLP predicts kd (3) + ks (3) + normal (3) = 9 channels.
        mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=9, min_max=[mlp_min, mlp_max])
        mat = material.Material({'kd_ks_normal' : mlp_map_opt})
    else:
        # Kd (albedo) and Ks (x, roughness, metalness) textures.
        if FLAGS.random_textures or init_mat is None:
            channels = 4 if FLAGS.layers > 1 else 3
            span = (kd_max - kd_min)[None, None, 0:channels]
            lo = kd_min[None, None, 0:channels]
            kd_init = torch.rand(size=FLAGS.texture_res + [channels], device='cuda') * span + lo
            kd_map_opt = texture.create_trainable(kd_init, FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])

            # Per-channel uniform noise for ks; the first channel stays near zero.
            ks_r = np.random.uniform(size=FLAGS.texture_res + [1], low=0.0, high=0.01)
            ks_g = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu())
            ks_b = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu())
            ks_map_opt = texture.create_trainable(np.concatenate((ks_r, ks_g, ks_b), axis=2), FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])
        else:
            kd_map_opt = texture.create_trainable(init_mat['kd'], FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])
            ks_map_opt = texture.create_trainable(init_mat['ks'], FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])

        # Normal map: flat (0, 0, 1) unless the initial material supplies one.
        if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
            normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
        else:
            normal_map_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

        mat = material.Material({
            'kd'     : kd_map_opt,
            'ks'     : ks_map_opt,
            'normal' : normal_map_opt
        })

    # Inherit the BSDF from the initial material when available.
    mat['bsdf'] = init_mat['bsdf'] if init_mat is not None else 'pbr'

    return mat
|
||||
|
||||
###############################################################################
|
||||
# Validation & testing
|
||||
###############################################################################
|
||||
|
||||
def rotate_scene(FLAGS, itr):
    """Camera pose for frame ``itr`` of a smooth 50-step turntable orbit.

    Returns a target dict (mv / mvp / campos on CUDA, plus spp and
    resolution) compatible with prepare_batch / validate_itr.
    """
    fovy = np.deg2rad(45)
    proj_mtx = util.perspective(fovy, FLAGS.display_res[1] / FLAGS.display_res[0], FLAGS.cam_near_far[0], FLAGS.cam_near_far[1])

    # One full revolution every 50 iterations, with a slight downward tilt.
    angle = (itr / 50) * np.pi * 2
    mv = util.translate(0, 0, -RADIUS) @ (util.rotate_x(-0.4) @ util.rotate_y(angle))
    mvp = proj_mtx @ mv
    campos = torch.linalg.inv(mv)[:3, 3]

    return {
        'mv': mv[None, ...].cuda(),
        'mvp': mvp[None, ...].cuda(),
        'campos': campos[None, ...].cuda(),
        'spp': 1,
        'resolution': FLAGS.display_res,
    }
|
||||
|
||||
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS):
    """Render one validation view without gradients.

    Returns ``(image, result_dict)`` where ``result_dict`` currently holds a
    single 'opt' entry — the sRGB-converted rendering of the first batch item.
    """
    with torch.no_grad():
        lgt.build_mips()
        if FLAGS.camera_space_light:
            lgt.xfm(target['mv'])

        buffers = geometry.render(glctx, target, lgt, opt_material)

        rendered = util.rgb_to_srgb(buffers['shaded'][..., 0:3])[0]
        result_dict = {'opt': rendered}

    return rendered, result_dict
|
||||
|
||||
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS):
    """Run the validation set through the renderer, save each named buffer as
    a PNG and write per-view MSE/PSNR metrics to ``<out_dir>/metrics.txt``.

    Returns the average PSNR over the dataset.
    """
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(dataloader_validate):
            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS)

            # Compute metrics.
            # NOTE(review): validate_itr() in this file only populates 'opt';
            # a 'ref' entry is expected here, so this raises KeyError unless
            # validate_itr is extended to also return the reference image.
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)

            # size_average/reduce are deprecated aliases; reduction='mean' is
            # already the default behavior.
            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            fout.write("%d, %1.8f, %1.8f\n" % (it, mse, psnr))

            for name, img in result_dict.items():
                np_img = img.detach().cpu().numpy()
                util.save_image(os.path.join(out_dir, 'val_%06d_%s.png' % (it, name)), np_img)

        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))
        fout.write("AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr))
        print("MSE, PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr
|
||||
|
||||
###############################################################################
|
||||
# Main shape fitter function / optimization loop
|
||||
###############################################################################
|
||||
|
||||
class Trainer(torch.nn.Module):
    """Bundles geometry, material and lighting so a single ``forward`` call
    produces the optimization loss for one batch via ``geometry.tick``.

    Exposes ``params`` (appearance: material + optionally light) and
    ``geo_params`` (geometry) so the caller can build separate optimizers.
    """

    def __init__(self, glctx, geometry, lgt, mat, optimize_geometry, optimize_light, image_loss_fn, FLAGS):
        super(Trainer, self).__init__()

        self.glctx = glctx
        self.geometry = geometry
        self.light = lgt
        self.material = mat
        self.optimize_geometry = optimize_geometry
        self.optimize_light = optimize_light
        self.image_loss_fn = image_loss_fn
        self.FLAGS = FLAGS

        # If the light is fixed, its mip pyramid only needs to be built once.
        if not self.optimize_light:
            with torch.no_grad():
                self.light.build_mips()

        self.params = list(self.material.parameters())
        self.params += list(self.light.parameters()) if optimize_light else []
        self.geo_params = list(self.geometry.parameters()) if optimize_geometry else []

    def forward(self, target, it):
        if self.optimize_light:
            # The light changed last step; rebuild its mips every iteration.
            self.light.build_mips()
            if self.FLAGS.camera_space_light:
                self.light.xfm(target['mv'])

        # Bug fix: use the rasterization context passed to the constructor
        # instead of silently relying on a module-level ``glctx`` global.
        return self.geometry.tick(self.glctx, target, self.light, self.material, self.image_loss_fn, it)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Main function.
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
    # ---------------- command-line interface ----------------
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default=None, help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-b', '--batch', type=int, default=1)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    parser.add_argument('-dr', '--display-res', type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('-o', '--out-dir', type=str, default='./viz_tet_traj')
    parser.add_argument('-sf', '--sample-folder', type=str, default=None)
    parser.add_argument('-bm', '--base-mesh', type=str, default=None)
    parser.add_argument('-ds', '--deform-scale', type=float, default=2.0)
    parser.add_argument('-vn', '--viz-name', type=str, default='viz')
    parser.add_argument('--unnormalized_sdf', action="store_true")
    parser.add_argument('--validate', type=bool, default=True)

    FLAGS = parser.parse_args()

    # ---------------- fixed settings not exposed on the CLI ----------------
    FLAGS.mtl_override = None         # Override material of model
    FLAGS.dmtet_grid = 64             # Tet-grid resolution (64 and 128 grids are provided)
    FLAGS.mesh_scale = 2.1            # Scale of the tet grid box; adjust to cover the model
    FLAGS.env_scale = 1.0             # Env-map intensity multiplier
    FLAGS.envmap = None               # HDR environment probe
    FLAGS.display = None              # Validation window/display conf, e.g. [{"relight": <path to envlight>}]
    FLAGS.camera_space_light = False  # Fixed light in camera space (e.g. object rotating on a stand)
    FLAGS.lock_light = False          # Disable light optimization in the second pass
    FLAGS.lock_pos = False            # Disable vertex position optimization in the second pass
    FLAGS.sdf_regularizer = 0.2       # Weight for the SDF regularizer
    FLAGS.laplace = "relative"        # Mesh Laplacian mode: "absolute" or "relative"
    FLAGS.laplace_scale = 10000.0     # Laplacian weight; relative mode wants a large value
    FLAGS.pre_load = True             # Pre-load the entire dataset into memory
    FLAGS.kd_min = [0.0, 0.0, 0.0, 0.0]   # Limits for kd
    FLAGS.kd_max = [1.0, 1.0, 1.0, 1.0]
    FLAGS.ks_min = [0.0, 0.08, 0.0]       # Limits for ks
    FLAGS.ks_max = [1.0, 1.0, 1.0]
    FLAGS.nrm_min = [-1.0, -1.0, 0.0]     # Limits for the normal map
    FLAGS.nrm_max = [1.0, 1.0, 1.0]
    FLAGS.cam_near_far = [0.1, 1000.0]
    FLAGS.learn_light = False
    FLAGS.cropped = False
    FLAGS.random_lgt = False

    # Values from the JSON config override the defaults above.
    if FLAGS.config is not None:
        with open(FLAGS.config, 'r') as cfg_file:
            cfg = json.load(cfg_file)
        for cfg_key in cfg:
            FLAGS.__dict__[cfg_key] = cfg[cfg_key]

    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    # Output folders: rendered previews under viz/, extracted meshes under mesh/.
    os.makedirs(FLAGS.out_dir, exist_ok=True)
    viz_path = os.path.join(FLAGS.out_dir, 'viz')
    mesh_path = os.path.join(FLAGS.out_dir, 'mesh')
    os.makedirs(viz_path, exist_ok=True)
    os.makedirs(mesh_path, exist_ok=True)

    glctx = dr.RasterizeGLContext()

    # Environment light used only for rendering the previews.
    lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)

    # DMTet geometry container the generated samples are written into.
    resolution = FLAGS.dmtet_grid
    geometry = DMTetGeometry(resolution, FLAGS.mesh_scale, FLAGS)
    geometry.deform_scale = FLAGS.deform_scale

    mask = torch.load(f'../data/grid_mask_{resolution}.pt').view(1, resolution, resolution, resolution).to("cuda")

    # Map tet vertex coordinates onto integer indices of the 3D cubic grid.
    # NOTE(review): FLAGS.tet_path is expected to come from the JSON config — confirm.
    tet_path = FLAGS.tet_path
    tet = np.load(tet_path)
    vertices = torch.tensor(tet['vertices'])
    vertices_unique = vertices[:].unique()
    dx = vertices_unique[1] - vertices_unique[0]  # grid spacing along one axis

    vertices_discretized = (torch.round((vertices - vertices.min()) / dx)).long()

    filelist = sorted([x for x in glob.glob(os.path.join(FLAGS.sample_folder, "*.npy"))])

    for k, fpath in enumerate(filelist):
        data_all = np.load(fpath)
        print('shape of generated data', data_all.shape)

        for no_data in range(data_all.shape[0]):
            # Channel 0 carries the SDF samples, channels 1:4 the vertex offsets.
            grid = torch.tensor(data_all[no_data])
            if FLAGS.unnormalized_sdf:
                raise NotImplementedError
                geometry.sdf.data[:] = (
                    grid[0, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]]
                ).cuda()
            else:
                geometry.sdf.data[:] = torch.sign(
                    grid[0, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]]
                ).cuda()
            geometry.deform.data[:] = (
                grid[1:, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]]
            ).cuda().transpose(0, 1)

            geometry.deform.data[:] = geometry.deform.data[:].clip(-1.0, 1.0)

            # Flat default material, for visualization only.
            opt_material = {
                'name': '_default_mat',
                # 'bsdf' : 'pbr',
                'bsdf': 'diffuse',
                'kd': texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda')),
                'ks': texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
            }

            # Extract the mesh and render one turntable preview.
            base_mesh = geometry.getMesh(opt_material)

            v_pose = rotate_scene(FLAGS, 30)
            result_image, _ = validate_itr(glctx, prepare_batch(v_pose, FLAGS.background), geometry, opt_material, lgt, FLAGS)
            result_image = result_image.detach().cpu().numpy()
            util.save_image(os.path.join(viz_path, ('%s_%03d_time%03d.png' % (FLAGS.viz_name, no_data, k))), result_image)

            # Save the raw mesh, then post-process it in-place with pymeshlab.
            mesh_savepath = os.path.join(mesh_path, '%s_%03d_time%03d.obj' % (FLAGS.viz_name, no_data, k))
            save_obj(
                verts=base_mesh.v_pos,
                faces=base_mesh.t_pos_idx,
                f=mesh_savepath
            )

            ms = pymeshlab.MeshSet()
            ms.load_new_mesh(mesh_savepath)
            ms.meshing_isotropic_explicit_remeshing()
            ms.apply_coord_laplacian_smoothing(stepsmoothnum=3, cotangentweight=False)
            # ms.apply_coord_laplacian_smoothing(stepsmoothnum=3, cotangentweight=True) ## for smoother surface
            ms.meshing_isotropic_explicit_remeshing()
            ms.apply_filter_script()
            ms.save_current_mesh(mesh_savepath)
|
@ -0,0 +1,820 @@
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import nvdiffrast.torch as dr
|
||||
import xatlas
|
||||
|
||||
# Import data readers / generators
|
||||
from lib.dataset.dataset_mesh import DatasetMesh
|
||||
from lib.dataset.dataset_shapenet import ShapeNetDataset
|
||||
|
||||
# Import topology / geometry trainers
|
||||
from lib.geometry.dmtet import DMTetGeometry
|
||||
from lib.geometry.dmtet_fixedtopo import DMTetGeometryFixedTopo
|
||||
|
||||
import lib.render.renderutils as ru
|
||||
from lib.render import obj
|
||||
from lib.render import material
|
||||
from lib.render import util
|
||||
from lib.render import mesh
|
||||
from lib.render import texture
|
||||
from lib.render import mlptexture
|
||||
from lib.render import light
|
||||
from lib.render import render
|
||||
|
||||
import traceback
|
||||
|
||||
|
||||
RADIUS = 2.0

# Debug aid: uncomment to catch autograd anomalies during back-prop.
# torch.autograd.set_detect_anomaly(True)

# Color stops for a heatmap lookup table (OpenCV uses BGR channel order).
color1 = (0, 0, 255)     # red
color2 = (0, 165, 255)   # orange
color3 = (0, 255, 255)   # yellow
color4 = (255, 255, 0)   # cyan
color5 = (255, 0, 0)     # blue
color6 = (128, 64, 64)   # violet
colorArr = np.array([[color1, color2, color3, color4, color5, color6]], dtype=np.uint8)

# Linearly interpolate the six stops into a 256-entry LUT.
lut = cv2.resize(colorArr, (256, 1), interpolation=cv2.INTER_LINEAR)
||||
|
||||
###############################################################################
|
||||
# Loss setup
|
||||
###############################################################################
|
||||
|
||||
def createLoss(FLAGS):
    """Build the per-image loss selected by ``FLAGS.loss``.

    Returns a callable ``loss(img, ref)`` delegating to ``ru.image_loss``
    with the matching loss/tonemapper configuration.

    Raises:
        ValueError: if ``FLAGS.loss`` is not a known loss name.

    Note: the original carried a ``@torch.no_grad()`` decorator, which only
    wrapped this factory's execution — the returned loss closures were (and
    are) evaluated with gradients enabled — so it was dropped as misleading.
    The former ``assert False`` fallback is now a real ``ValueError`` so it
    survives ``python -O``.
    """
    # loss name -> (ru.image_loss `loss` kwarg, `tonemapper` kwarg)
    _LOSSES = {
        "smape":  ("smape",  "none"),
        "mse":    ("mse",    "none"),
        "logl1":  ("l1",     "log_srgb"),
        "logl2":  ("mse",    "log_srgb"),
        "relmse": ("relmse", "none"),
    }
    try:
        loss_name, tonemapper = _LOSSES[FLAGS.loss]
    except KeyError:
        raise ValueError("Unknown loss: %s" % FLAGS.loss) from None
    return lambda img, ref: ru.image_loss(img, ref, loss=loss_name, tonemapper=tonemapper)
||||
###############################################################################
|
||||
# Mix background into a dataset image
|
||||
###############################################################################
|
||||
|
||||
@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Move one dataset batch to the GPU and composite it over a background.

    Mutates and returns `target`: the RGB channels of 'img' are alpha-blended
    with the chosen background, alpha itself is kept, and 'background' is
    stored alongside.
    """
    img = target['img']
    assert len(img.shape) == 4, "Image shape should be [n, h, w, c]"

    if bg_type == 'checker':
        background = torch.tensor(util.checkerboard(img.shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'black':
        background = torch.zeros(img.shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'white':
        background = torch.ones(img.shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'reference':
        background = img[..., 0:3]
    elif bg_type == 'random':
        background = torch.rand(img.shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    else:
        assert False, "Unknown background type %s" % bg_type

    for field in ('mv', 'mvp', 'campos', 'img'):
        target[field] = target[field].cuda()
    target['background'] = background

    # Blend the background into the RGB channels; the alpha channel is preserved.
    target['img'] = torch.cat((torch.lerp(background, target['img'][..., 0:3], target['img'][..., 3:4]), target['img'][..., 3:4]), dim=-1)

    target['spts'] = target['spts'].cuda()
    target['vpts'] = target['vpts'].cuda()
    return target
||||
|
||||
###############################################################################
|
||||
# UV - map geometry & convert to a mesh
|
||||
###############################################################################
|
||||
|
||||
@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """UV-unwrap the current mesh with xatlas and bake kd/ks/normal textures.

    Returns a new mesh sharing the input geometry but carrying UVs and a
    2D-texture material sampled from the MLP material `mat['kd_ks_normal']`.
    """
    src_mesh = geometry.getMesh(mat)

    # xatlas operates on numpy arrays.
    pos_np = src_mesh.v_pos.detach().cpu().numpy()
    tri_np = src_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(pos_np, tri_np)

    # xatlas hands back uint32 indices; reinterpret the bits as int64 for torch.
    faces_np = indices.astype(np.uint64, casting='same_kind').view(np.int64)

    tex_coords = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    tex_faces = torch.tensor(faces_np, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=tex_coords, t_tex_idx=tex_faces, base=src_mesh)

    # Bake the volumetric MLP material into 2D textures over the new UVs.
    mask, kd, ks, normal = render.render_uv(glctx, new_mesh, FLAGS.texture_res, src_mesh.material['kd_ks_normal'])

    if FLAGS.layers > 1:
        kd = torch.cat((kd, torch.rand_like(kd[..., 0:1])), dim=-1)

    def _limits(lo, hi):
        # Per-channel clamp range as a pair of cuda tensors.
        return [torch.tensor(lo, dtype=torch.float32, device='cuda'),
                torch.tensor(hi, dtype=torch.float32, device='cuda')]

    new_mesh.material = material.Material({
        'bsdf'   : mat['bsdf'],
        'kd'     : texture.Texture2D(kd, min_max=_limits(FLAGS.kd_min, FLAGS.kd_max)),
        'ks'     : texture.Texture2D(ks, min_max=_limits(FLAGS.ks_min, FLAGS.ks_max)),
        'normal' : texture.Texture2D(normal, min_max=_limits(FLAGS.nrm_min, FLAGS.nrm_max)),
    })

    return new_mesh
|
||||
@torch.no_grad()
def xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS):
    """UV-unwrap the current mesh with xatlas and bake only the normal map.

    kd and ks are reused from `mat` unchanged; only `mat['normal']` is
    sampled into a 2D texture over the new UVs.
    """
    eval_mesh = geometry.getMesh(mat)

    # Create uvs with xatlas (numpy in, numpy out).
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # xatlas returns uint32 indices; reinterpret the bits as int64 for torch.
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)

    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    mask, normal = render.render_uv_nrm(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['normal'])

    # Bug fix: removed the `if FLAGS.layers > 1: kd = torch.cat(...)` branch
    # copied from xatlas_uvmap(). It referenced a local `kd` that is never
    # defined in this function (guaranteed NameError when layers > 1), and
    # its result was unused here anyway.

    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    new_mesh.material = material.Material({
        'bsdf'   : mat['bsdf'],
        'kd'     : mat['kd'],
        'ks'     : mat['ks'],
        'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max])
    })

    return new_mesh
||||
|
||||
###############################################################################
|
||||
# Utility functions for material
|
||||
###############################################################################
|
||||
|
||||
def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Create the initial trainable material.

    With ``mlp`` truthy, a single 9-channel MLP texture predicts
    kd (3) + ks (3) + normal (3). Otherwise separate trainable 2D textures
    are built, either randomly initialized or seeded from ``init_mat``.
    """
    def _pair(lo, hi):
        # Per-channel clamp range as a pair of cuda tensors.
        return (torch.tensor(lo, dtype=torch.float32, device='cuda'),
                torch.tensor(hi, dtype=torch.float32, device='cuda'))

    kd_min, kd_max = _pair(FLAGS.kd_min, FLAGS.kd_max)
    ks_min, ks_max = _pair(FLAGS.ks_min, FLAGS.ks_max)
    nrm_min, nrm_max = _pair(FLAGS.nrm_min, FLAGS.nrm_max)

    if mlp:
        # One volumetric MLP over the geometry AABB covers all three maps.
        mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=9, min_max=[mlp_min, mlp_max])
        mat = material.Material({'kd_ks_normal': mlp_map_opt})
    else:
        # Kd (albedo, + optional alpha layer) and ks (x, roughness, metalness).
        if FLAGS.random_textures or init_mat is None:
            num_channels = 4 if FLAGS.layers > 1 else 3
            kd_init = torch.rand(size=FLAGS.texture_res + [num_channels], device='cuda') * (kd_max - kd_min)[None, None, 0:num_channels] + kd_min[None, None, 0:num_channels]
            kd_map_opt = texture.create_trainable(kd_init, FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])

            # ks red channel is kept near zero; green/blue span their limits.
            ksR = np.random.uniform(size=FLAGS.texture_res + [1], low=0.0, high=0.01)
            ksG = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu())
            ksB = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu())

            ks_map_opt = texture.create_trainable(np.concatenate((ksR, ksG, ksB), axis=2), FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])
        else:
            kd_map_opt = texture.create_trainable(init_mat['kd'], FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])
            ks_map_opt = texture.create_trainable(init_mat['ks'], FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])

        # Normal map: flat +Z default unless seeded from init_mat.
        if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
            normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
        else:
            normal_map_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

        mat = material.Material({
            'kd'     : kd_map_opt,
            'ks'     : ks_map_opt,
            'normal' : normal_map_opt
        })

    mat['bsdf'] = init_mat['bsdf'] if init_mat is not None else 'pbr'

    return mat
|
||||
def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
    """Create a material that keeps kd/ks fixed from `init_mat` and makes
    only the normal map trainable (MLP texture or plain 2D texture)."""
    nrm_min, nrm_max = (torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'),
                        torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda'))

    if mlp:
        # Normals predicted by a coordinate MLP over the geometry's AABB.
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=3, min_max=[nrm_min, nrm_max])
        # mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=3, min_max=None)
        mat = material.Material({
            'kd'     : init_mat['kd'],
            'ks'     : init_mat['ks'],
            'normal' : mlp_map_opt,
        })
    else:
        # Trainable 2D normal texture: flat +Z default unless seeded from init_mat.
        if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
            normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
        else:
            normal_map_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

        mat = material.Material({
            'kd'     : init_mat['kd'],
            'ks'     : init_mat['ks'],
            'normal' : normal_map_opt
        })

    mat['bsdf'] = init_mat['bsdf'] if init_mat is not None else 'pbr'

    return mat
||||
|
||||
###############################################################################
|
||||
# Validation & testing
|
||||
###############################################################################
|
||||
|
||||
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS):
    """Render one validation view.

    Returns ``(result_image, result_dict)`` where result_image is the
    rendered and reference images concatenated side by side, and
    result_dict holds 'opt' and 'ref' sRGB images.
    """
    result_dict = {}
    with torch.no_grad():
        lgt.build_mips()
        if FLAGS.camera_space_light:
            lgt.xfm(target['mv'])
        # NOTE(review): placement relative to the camera_space_light guard
        # reconstructed from a whitespace-mangled source — confirm.
        lgt.xfm(target['envlight_transform'])

        # Some geometry classes accept an `ema` flag (render with EMA-smoothed
        # parameters); fall back when the signature lacks it. Bug fix: this
        # was a bare `except:`, which also swallowed genuine rendering errors.
        try:
            buffers = geometry.render(glctx, target, lgt, opt_material, ema=True, xfm_lgt=target['envlight_transform'])
        except TypeError:
            buffers = geometry.render(glctx, target, lgt, opt_material, xfm_lgt=target['envlight_transform'])

        result_dict['ref'] = util.rgb_to_srgb(target['img'][..., 0:3])[0]
        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][..., 0:3])[0]
        result_image = torch.cat([result_dict['opt'], result_dict['ref']], axis=1)

        return result_image, result_dict
||||
|
||||
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS):
    """Run the validation loop.

    Renders every view in `dataset_validate`, writes per-view MSE/PSNR to
    ``<out_dir>/metrics.txt`` plus the rendered images, and returns the
    average PSNR.
    """
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(dataloader_validate):

            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS)

            # Compute metrics on clamped sRGB images.
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)

            # Fix: dropped the deprecated `size_average=None, reduce=None`
            # arguments — `reduction='mean'` alone is the supported spelling
            # and is what passing None for both already meant.
            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            fout.write("%d, %1.8f, %1.8f\n" % (it, mse, psnr))

            for k in result_dict.keys():
                np_img = result_dict[k].detach().cpu().numpy()
                util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (it, k)), np_img)

        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))
        fout.write("AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr))
        print("MSE, PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr
||||
|
||||
###############################################################################
|
||||
# Main shape fitter function / optimization loop
|
||||
###############################################################################
|
||||
|
||||
class Trainer(torch.nn.Module):
    """Bundles geometry, material and light so one forward() call produces
    the training loss terms for a single batch.

    Exposes optimizer parameter groups: `params` (material + optionally
    light), `geo_params`, and the split `sdf_params` / `deform_params`.
    """

    def __init__(self, glctx, geometry, lgt, mat, optimize_geometry, optimize_light, image_loss_fn, FLAGS):
        super(Trainer, self).__init__()

        self.glctx = glctx  # nvdiffrast rasterization context
        self.geometry = geometry
        self.light = lgt
        self.material = mat
        self.optimize_geometry = optimize_geometry
        self.optimize_light = optimize_light
        self.image_loss_fn = image_loss_fn
        self.FLAGS = FLAGS

        # A frozen light never changes, so its mip pyramid can be built once up front.
        if not self.optimize_light:
            with torch.no_grad():
                self.light.build_mips()

        self.params = list(self.material.parameters())
        self.params += list(self.light.parameters()) if optimize_light else []
        self.geo_params = list(self.geometry.parameters()) if optimize_geometry else []
        # Some geometry variants expose no `sdf` tensor. Bug fix: this was a
        # bare `except:`, which also swallowed unrelated errors.
        try:
            self.sdf_params = [self.geometry.sdf]
        except AttributeError:
            self.sdf_params = []
        self.deform_params = [self.geometry.deform]

    def forward(self, target, it):
        if self.optimize_light:
            self.light.build_mips()
            if self.FLAGS.camera_space_light:
                self.light.xfm(target['mv'])
            # NOTE(review): placement relative to the optimize_light guard
            # reconstructed from a whitespace-mangled source — confirm.
            self.light.xfm(target['envlight_transform'])

        # Bug fix: pass the context stored on this instance instead of relying
        # on a module-level `glctx` global being in scope.
        return self.geometry.tick(self.glctx, target, self.light, self.material, self.image_loss_fn, it, xfm_lgt=target['envlight_transform'])
||||
|
||||
def optimize_mesh(
    glctx,
    geometry,
    opt_material,
    lgt,
    dataset_train,
    dataset_validate,
    FLAGS,
    warmup_iter=0,
    log_interval=10,
    pass_idx=0,
    pass_name="",
    optimize_light=True,
    optimize_geometry=True,
    ):
    """Run one optimization pass over geometry and/or material parameters.

    Args:
        glctx: nvdiffrast rasterizer context.
        geometry: DMTet geometry object providing `tick`/params; updated in place.
        opt_material: material dict being optimized (may hold MLP textures).
        lgt: environment light (may be None or non-trainable).
        dataset_train: training dataset; must expose a `collate` function.
        dataset_validate: validation dataset. NOTE(review): unused in this body.
        FLAGS: parsed config; reads learning_rate, batch, normal_only, iter, local_rank, multi_gpu.
        warmup_iter: iterations of linear LR warmup before exponential decay.
        log_interval: iterations between console log lines.
        pass_idx: index into a per-pass FLAGS.learning_rate list.
        pass_name: label for this pass. NOTE(review): unused in this body.
        optimize_light: whether to update light parameters (gates the grad boost).
        optimize_geometry: whether to create/step the geometry optimizer.

    Returns:
        (geometry, opt_material), updated in place.
    """

    # ==============================================================================================
    # Setup torch optimizer
    # ==============================================================================================

    # FLAGS.learning_rate may be: a scalar; a per-pass list/tuple indexed by
    # pass_idx; and each per-pass entry may itself be a (pos_lr, mat_lr) pair.
    learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate
    learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
    learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate

    def lr_schedule(iter, fraction):
        # Linear warmup to 1.0, then exponential decay.
        # NOTE(review): `fraction` is accepted but never used, and `iter`
        # shadows the builtin; kept as-is to preserve the schedule exactly.
        if iter < warmup_iter:
            return iter / warmup_iter
        return max(0.0, 10**(-(iter - warmup_iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs.

    # ==============================================================================================
    # Image loss
    # ==============================================================================================
    image_loss_fn = createLoss(FLAGS)

    trainer_noddp = Trainer(glctx, geometry, lgt, opt_material, optimize_geometry, optimize_light, image_loss_fn, FLAGS)

    if FLAGS.multi_gpu:
        raise NotImplementedError
        # Multi GPU training mode -- everything below the raise is dead code,
        # kept as a sketch of the intended apex/DDP setup.
        import apex
        from apex.parallel import DistributedDataParallel as DDP

        trainer = DDP(trainer_noddp)
        trainer.train()
        if optimize_geometry:
            optimizer_mesh = apex.optimizers.FusedAdam(trainer_noddp.geo_params, lr=learning_rate_pos)
            scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9))

        optimizer = apex.optimizers.FusedAdam(trainer_noddp.params, lr=learning_rate_mat)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))
    else:
        # Single GPU training mode
        trainer = trainer_noddp
        if optimize_geometry:
            # optimizer_mesh = torch.optim.Adam(trainer_noddp.geo_params, lr=learning_rate_pos)
            # SDF and deformation parameters share the same LR but live in
            # separate param groups so they could be tuned independently.
            optimizer_mesh = torch.optim.Adam([
                {'params': trainer_noddp.sdf_params, 'lr': learning_rate_pos},
                {'params': trainer_noddp.deform_params, 'lr': learning_rate_pos},
            ])
            # optimizer_mesh = torch.optim.Adam(trainer_noddp.geo_params, lr=learning_rate_pos, betas=(0.2, 0.999), eps=1e-5)
            scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9))

        optimizer = torch.optim.Adam(trainer_noddp.params, lr=learning_rate_mat)
        # optimizer = torch.optim.Adam(trainer_noddp.params, lr=learning_rate_mat, betas=(0.2, 0.999), eps=1e-5)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))

    # ==============================================================================================
    # Training loop
    # ==============================================================================================
    img_cnt = 0  # NOTE(review): never used below.
    img_loss_vec = []
    reg_loss_vec = []
    iter_dur_vec = []

    dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=FLAGS.batch, collate_fn=dataset_train.collate, shuffle=True)

    print("Start training loop...")
    sys.stdout.flush()

    for it, target in enumerate(dataloader_train):

        # Mix randomized background into dataset image
        target = prepare_batch(target, 'random')

        iter_start_time = time.time()

        # ==============================================================================================
        # Zero gradients
        # ==============================================================================================
        optimizer.zero_grad()
        if optimize_geometry:
            optimizer_mesh.zero_grad()

        # ==============================================================================================
        # Training
        # ==============================================================================================
        img_loss, reg_loss = trainer(target, it)

        # ==============================================================================================
        # Final loss
        # ==============================================================================================
        total_loss = img_loss + reg_loss

        img_loss_vec.append(img_loss.item())
        reg_loss_vec.append(reg_loss.item())

        # ==============================================================================================
        # Backpropagate
        # ==============================================================================================
        total_loss.backward()

        # Hand-tuned gradient rescaling: boost the env-light gradient, damp the
        # MLP texture encoder gradients. Constants are empirical.
        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:
            lgt.base.grad *= 64
        if 'kd_ks_normal' in opt_material:
            opt_material['kd_ks_normal'].encoder.params.grad /= 8.0
        if 'normal' in opt_material and FLAGS.normal_only:
            try:
                opt_material['normal'].encoder.params.grad /= 8.0
            except:  # NOTE(review): bare except; presumably skips non-MLP normal textures that lack .encoder -- consider `except AttributeError`.
                pass

        optimizer.step()
        scheduler.step()

        if optimize_geometry:
            optimizer_mesh.step()
            scheduler_mesh.step()

            # Keep deformations in range and refresh the SDF EMA after each step.
            geometry.clamp_deform()
            geometry.update_ema()

        # ==============================================================================================
        # Clamp trainables to reasonable range
        # ==============================================================================================
        with torch.no_grad():
            if 'kd' in opt_material:
                opt_material['kd'].clamp_()
            if 'ks' in opt_material:
                opt_material['ks'].clamp_()
            if 'normal' in opt_material and not FLAGS.normal_only:
                opt_material['normal'].clamp_()
                opt_material['normal'].normalize_()
            if lgt is not None:
                lgt.clamp_(min=0.0)

        # Synchronize so the iteration timing below measures real GPU work.
        torch.cuda.current_stream().synchronize()
        iter_dur_vec.append(time.time() - iter_start_time)

        # ==============================================================================================
        # Logging
        # ==============================================================================================
        if it % log_interval == 0 and FLAGS.local_rank == 0:
            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))
            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))
            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))

            remaining_time = (FLAGS.iter-it)*iter_dur_avg
            print("iter=%5d, img_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s" %
                (it, img_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))
            sys.stdout.flush()

    return geometry, opt_material
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Main function.
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
    # sleep(randint(0,15))

    # Command-line interface. Values here may later be overridden by the JSON
    # config file (--config) and by the hard-coded FLAGS assignments below.
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default='./configs/res64.json', help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-b', '--batch', type=int, default=1)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    parser.add_argument('-dr', '--display-res', type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])
    parser.add_argument('-o', '--out-dir', type=str, default='./dmtet_results')
    parser.add_argument('-bm', '--base-mesh', type=str, default=None)
    # NOTE(review): `type=bool` args below are an argparse pitfall -- any
    # non-empty string (including "False") parses as True.
    parser.add_argument('--validate', type=bool, default=True)
    parser.add_argument('-ind', '--index', type=int)  # shard index; combined with --split-size to pick objects
    parser.add_argument('-ss', '--split-size', type=int, default=10)
    parser.add_argument('--cropped', type=bool, default=True)
    parser.add_argument('-no', '--normal-only', type=bool, default=True)
    parser.add_argument('--meta-folder', type=str, default='./data/shapenet_json')
    parser.add_argument('--cat-name', type=str, default='chair')
    parser.add_argument('-rp', '--resume-path', type=str, default=None)
    parser.add_argument('-ema', '--use-ema', action="store_true")

    FLAGS = parser.parse_args()
    print(f"parsed arguments")
|
||||
    # NOTE(review): recomputed per-object inside the loop; this value is unused.
    global_index = FLAGS.index * FLAGS.split_size

    # Hard-coded configuration (applied after argparse, before the JSON config).
    FLAGS.mtl_override = None                 # Override material of model
    FLAGS.dmtet_grid = 64                     # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet
    FLAGS.mesh_scale = 1.0                    # Scale of tet grid box. Adjust to cover the model
    FLAGS.env_scale = 1.0                     # Env map intensity multiplier
    FLAGS.envmap = None                       # HDR environment probe
    FLAGS.display = None                      # Conf validation window/display. E.g. [{"relight" : <path to envlight>}]
    FLAGS.camera_space_light = False          # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.
    FLAGS.lock_light = False                  # Disable light optimization in the second pass
    FLAGS.lock_pos = False                    # Disable vertex position optimization in the second pass
    FLAGS.sdf_regularizer = 0.2               # Weight for sdf regularizer (see paper for details)
    FLAGS.laplace = "relative"                # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale = 10000.0             # Weight for sdf regularizer. Default is relative with large weight
    FLAGS.pre_load = True                     # Pre-load entire dataset into memory for faster training
    FLAGS.kd_min = [ 0.0, 0.0, 0.0, 0.0]      # Limits for kd
    FLAGS.kd_max = [ 1.0, 1.0, 1.0, 1.0]
    FLAGS.ks_min = [ 0.0, 0.08, 0.0]          # Limits for ks
    FLAGS.ks_max = [ 1.0, 1.0, 1.0]
    FLAGS.nrm_min = [-1.0, -1.0, 0.0]         # Limits for normal map
    FLAGS.nrm_max = [ 1.0, 1.0, 1.0]
    FLAGS.cam_near_far = [0.1, 1000.0]
    FLAGS.learn_light = False
    # NOTE(review): these overwrite the --cropped and -ema CLI options parsed above.
    FLAGS.cropped = True
    FLAGS.use_ema = False
    FLAGS.random_lgt = True
    FLAGS.dataset_flat_shading = False

    # Distributed setup: WORLD_SIZE (set by torch launchers) enables multi-GPU.
    FLAGS.local_rank = 0
    FLAGS.multi_gpu = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    if FLAGS.multi_gpu:
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = 'localhost'
        if "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = '23456'

        FLAGS.local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(FLAGS.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
|
||||
|
||||
    # JSON config overrides everything set so far (CLI and hard-coded flags).
    if FLAGS.config is not None:
        data = json.load(open(FLAGS.config, 'r'))
        for key in data:
            FLAGS.__dict__[key] = data[key]

    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    # Outputs are grouped per ShapeNet category.
    FLAGS.out_dir = os.path.join(FLAGS.out_dir, FLAGS.cat_name)

    if FLAGS.local_rank == 0:
        print("Config / Flags:")
        print("---------")
        for key in FLAGS.__dict__.keys():
            print(key, FLAGS.__dict__[key])
        print("---------")

        # Output layout: val_viz (validation renders), tets (final pass-2
        # checkpoints), tets_pre (pass-1 checkpoints used for resuming).
        os.makedirs(FLAGS.out_dir, exist_ok=True)
        os.makedirs(os.path.join(FLAGS.out_dir, 'val_viz'), exist_ok=True)
        os.makedirs(os.path.join(FLAGS.out_dir, 'tets'), exist_ok=True)
        os.makedirs(os.path.join(FLAGS.out_dir, 'tets_pre'), exist_ok=True)

    print(f"Using dmt grid of resolution {FLAGS.dmtet_grid}")

    # OpenGL rasterizer context (requires a GL-capable display/EGL setup).
    glctx = dr.RasterizeGLContext()

    ### Default mtl: flat diffuse material used when the mesh supplies none.
    mtl_default = {
        'name' : '_default_mat',
        'bsdf': 'diffuse',
        'uniform': True,
        'kd' : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda'), trainable=False),
        'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'), trainable=False)
    }

    print(f"meta json path {os.path.join(FLAGS.meta_folder, f'{FLAGS.cat_name}.json')}")
    # ShapeNet v1 is only used for cars; all other categories load as v2.
    shapenet_dataset = ShapeNetDataset(
        os.path.join(FLAGS.meta_folder, f'{FLAGS.cat_name}.json'),
        shapenet_v1=(FLAGS.cat_name == 'car')
    )

    print("Start iterating through objects")
    sys.stdout.flush()
||||
|
||||
    # Fit each object in this worker's shard. Objects are indexed as
    # k + FLAGS.index * FLAGS.split_size, so each (index, split_size) pair
    # covers a disjoint slice of the dataset.
    if len(shapenet_dataset) > 0:
        for k in range(FLAGS.split_size):
            # ==============================================================================================
            # Create data pipeline
            # ==============================================================================================

            global_index = k + FLAGS.index * FLAGS.split_size

            print("file path to save: {:s}".format(os.path.join(FLAGS.out_dir, 'tets/dmt_dict_{:05d}.pt'.format(global_index))))

            # Skip objects whose final (pass-2) checkpoint already exists,
            # making the script safely re-runnable.
            skip_if_exists = True
            if skip_if_exists and os.path.exists(os.path.join(FLAGS.out_dir, 'tets/dmt_dict_{:05d}.pt'.format(global_index))):
                continue

            # Per-object failures (bad meshes, OOM, ...) are logged and skipped
            # so one object cannot kill the whole shard.
            try:
                global_index = k + FLAGS.index * FLAGS.split_size

                if global_index >= len(shapenet_dataset):
                    break
                mesh_fname = shapenet_dataset[global_index]

                print(f"Loading mesh: {mesh_fname}")
                sys.stdout.flush()
                ref_mesh = mesh.load_mesh(mesh_fname, FLAGS.mtl_override, mtl_default, use_default=FLAGS.normal_only, no_additional=True)
                ref_mesh = mesh.center_by_reference(ref_mesh, mesh.aabb_clean(ref_mesh), 1.0)

                a = ref_mesh.v_nrm.clone()  # NOTE(review): `a` is never used afterwards.
                ref_mesh = mesh.auto_normals(ref_mesh) ### important

                print("Loading dataset")
                sys.stdout.flush()
                if FLAGS.cat_name == 'car':
                    RADIUS = 2.0  # camera orbit radius override for cars
                dataset_train = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=False)
                dataset_validate = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=True)
                print("Dataset loaded")
                sys.stdout.flush()

                # ==============================================================================================
                # Create env light with trainable parameters
                # ==============================================================================================

                if FLAGS.learn_light:
                    lgt = light.create_trainable_env_rnd(512, scale=0.0, bias=0.5)
                else:
                    lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale, trainable=False)

                # ==============================================================================================
                # If no initial guess, use DMtets to create geometry
                # ==============================================================================================

                # Setup geometry for optimization
                geometry = DMTetGeometry(FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS)

                # Setup textures, make initial guess from reference if possible
                if not FLAGS.normal_only:
                    mat = initial_guess_material(geometry, True, FLAGS, mtl_default)
                else:
                    mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)

                print("Start optimization")
                sys.stdout.flush()

                if FLAGS.resume_path is None:
                    # Pass 1: optimize topology + deformation from scratch,
                    # then checkpoint into tets_pre/.
                    geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate,
                                    FLAGS, pass_idx=0, pass_name="dmtet_pass1", optimize_light=FLAGS.learn_light)

                    base_mesh = geometry.getMesh(mat)

                    # Mask keeps only deformations on vertices that touch the
                    # extracted surface; others are zeroed in the checkpoint.
                    vert_mask = torch.zeros_like(geometry.sdf).long().cuda().view(-1, 1)
                    vert_mask[geometry.getValidVertsIdx()] = 1

                    # Free temporaries / cached memory
                    torch.cuda.empty_cache() ### may slow down training

                    torch.save({
                        'sdf': geometry.sdf.cpu().detach(),
                        'sdf_ema': geometry.sdf_ema.cpu().detach(),
                        'deform': (geometry.deform * vert_mask).cpu().detach(),
                        'deform_unmasked': geometry.deform.cpu().detach(),
                    }, os.path.join(FLAGS.out_dir, 'tets_pre/dmt_dict_{:05d}.pt'.format(global_index)))

                    old_geometry = geometry
                else:
                    # Resume pass 1 results from a previous run's tets_pre/.
                    dmt_dict = torch.load(os.path.join(FLAGS.resume_path, 'tets_pre/dmt_dict_{:05d}.pt'.format(global_index)))
                    if FLAGS.use_ema:
                        geometry.sdf.data[:] = dmt_dict['sdf_ema']
                    else:
                        geometry.sdf.data[:] = dmt_dict['sdf']
                    geometry.deform.data[:] = dmt_dict['deform']
                    old_geometry = geometry

                # Create textured mesh from result
                if FLAGS.normal_only:
                    base_mesh = xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS)
                else:
                    base_mesh = xatlas_uvmap(glctx, geometry, mat, FLAGS)

                # # ==============================================================================================
                # # Pass 2: Finetune deformation with fixed topology
                # # ==============================================================================================
                geometry = DMTetGeometryFixedTopo(geometry, base_mesh, FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS)

                # Freeze the SDF (topology); only vertex deformation is trained.
                geometry.sdf_sign.requires_grad = False
                geometry.sdf_abs.requires_grad = False
                geometry.deform.requires_grad = True

                # geometry.deform.data[:] = geometry.deform * 2.0 / 3.0
                # geometry.deform_scale = 3.0

                # Rescale deformations into the fixed-topology parameterization
                # (stored values shrink by 0.45/1.5 while deform_scale grows).
                geometry.deform.data[:] = geometry.deform * 0.45 / 1.5
                geometry.deform_scale = 1.5

                if FLAGS.use_ema:
                    geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf_ema)
                else:
                    geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf)

                geometry.set_init_v_pos()

                geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate, FLAGS,
                            pass_idx=1, pass_name="mesh_pass", warmup_iter=100, optimize_light=FLAGS.learn_light and not FLAGS.lock_light,
                            optimize_geometry=not FLAGS.lock_pos)

                vert_mask = torch.zeros_like(geometry.sdf_sign).long().cuda().view(-1, 1)
                vert_mask[geometry.getValidVertsIdx()] = 1

                # Final checkpoint ('sdf' here holds the *sign* field).
                torch.save({
                    'sdf': geometry.sdf_sign.cpu().detach(),
                    'deform': (geometry.deform * vert_mask).cpu().detach(),
                    'deform_unmasked': geometry.deform.cpu().detach(),
                    },
                    os.path.join(FLAGS.out_dir, 'tets/dmt_dict_{:05d}.pt'.format(global_index))
                )

                # NOTE(review): `validate` is defined elsewhere in this file.
                if FLAGS.local_rank == 0 and FLAGS.validate:
                    validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, f"val_viz/dmtet_validate_{FLAGS.index}_{k}_{FLAGS.split_size}"), FLAGS)

                # Free temporaries / cached memory
                del geometry
                del ref_mesh
                del dataset_train
                del dataset_validate
                torch.cuda.empty_cache() ### may slow down training

                print(f"\n\n============ {FLAGS.index}_{k}/{FLAGS.split_size} finished ============\n\n")
            except Exception as err:
                print(f"\n\n============ {FLAGS.index}_{k}/{FLAGS.split_size} Failed ============\n\n")
                print(traceback.format_exc())
                print("\n\n")
                continue
|
||||
|
||||
|
||||
|
|
# ---- second source file begins here (diff hunk header removed) ----
|
|||
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import nvdiffrast.torch as dr
|
||||
import xatlas
|
||||
|
||||
# Import data readers / generators
|
||||
from lib.dataset.dataset_mesh import DatasetMesh
|
||||
from lib.dataset.dataset_shapenet import ShapeNetDataset
|
||||
|
||||
# Import topology / geometry trainers
|
||||
from lib.geometry.dmtet_singleview import DMTetGeometry
|
||||
from lib.geometry.dmtet_fixedtopo import DMTetGeometryFixedTopo
|
||||
|
||||
import lib.render.renderutils as ru
|
||||
from lib.render import obj
|
||||
from lib.render import material
|
||||
from lib.render import util
|
||||
from lib.render import mesh
|
||||
from lib.render import texture
|
||||
from lib.render import mlptexture
|
||||
from lib.render import light
|
||||
from lib.render import render
|
||||
from random import randint
|
||||
from time import sleep
|
||||
|
||||
import traceback
|
||||
|
||||
# Default camera orbit radius for dataset rendering (overridden per category
# in the main loop).
RADIUS = 2.0

# define colors -- BGR triples (OpenCV channel order) forming a heatmap ramp.
color1 = (0, 0, 255)     #red
color2 = (0, 165, 255)   #orange
color3 = (0, 255, 255)   #yellow
color4 = (255, 255, 0)   #cyan
color5 = (255, 0, 0)     #blue
color6 = (128, 64, 64)   #violet
colorArr = np.array([[color1, color2, color3, color4, color5, color6]], dtype=np.uint8)

# resize lut to 256 (or more) values: 1x256x3 color lookup table obtained by
# linearly interpolating between the six anchor colors above.
lut = cv2.resize(colorArr, (256,1), interpolation = cv2.INTER_LINEAR)
|
||||
|
||||
###############################################################################
|
||||
# Loss setup
|
||||
###############################################################################
|
||||
|
||||
def createLoss(FLAGS):
    """Build the image-space loss function selected by ``FLAGS.loss``.

    Note: the original carried a ``@torch.no_grad()`` decorator; it was a
    no-op here (this factory creates no tensors, and the context does not
    affect the returned callable), so it has been dropped.

    Args:
        FLAGS: parsed configuration; only ``FLAGS.loss`` is read. One of
            'smape', 'mse', 'logl1', 'logl2', 'relmse'.

    Returns:
        A callable ``(img, ref) -> loss`` delegating to ``ru.image_loss``
        with the matching (loss, tonemapper) pair.

    Raises:
        ValueError: if ``FLAGS.loss`` names an unknown loss.
    """
    # (loss, tonemapper) arguments for ru.image_loss, keyed by FLAGS.loss.
    _LOSS_TABLE = {
        "smape":  ("smape",  "none"),
        "mse":    ("mse",    "none"),
        "logl1":  ("l1",     "log_srgb"),
        "logl2":  ("mse",    "log_srgb"),
        "relmse": ("relmse", "none"),
    }
    if FLAGS.loss not in _LOSS_TABLE:
        # Was `assert False`: asserts are stripped under `python -O`, which
        # would make this silently return None; raise a named error instead.
        raise ValueError("Unknown loss: %s" % FLAGS.loss)
    loss, tonemapper = _LOSS_TABLE[FLAGS.loss]
    return lambda img, ref: ru.image_loss(img, ref, loss=loss, tonemapper=tonemapper)
|
||||
|
||||
###############################################################################
|
||||
# Mix background into a dataset image
|
||||
###############################################################################
|
||||
|
||||
@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Upload a dataset batch to the GPU and composite a background.

    The batch's RGBA image is alpha-blended over the requested background,
    the background itself is stored under ``target['background']``, and the
    camera/point tensors are moved to CUDA. The dict is mutated in place and
    also returned.

    Args:
        target: batch dict with 'img' (n, h, w, 4 RGBA), 'mv', 'mvp',
            'campos', 'spts', 'vpts'.
        bg_type: one of 'checker', 'black', 'white', 'reference', 'random'.

    Returns:
        The same dict, on the GPU, with composited 'img' and 'background'.
    """
    assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"

    # Build the background first (before any device moves), matching the
    # original evaluation order.
    nhw = target['img'].shape[0:3]
    if bg_type == 'checker':
        background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'black':
        background = torch.zeros(nhw + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'white':
        background = torch.ones(nhw + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'reference':
        background = target['img'][..., 0:3]
    elif bg_type == 'random':
        background = torch.rand(nhw + (3,), dtype=torch.float32, device='cuda')
    else:
        assert False, "Unknown background type %s" % bg_type

    # Move camera matrices and the reference image to the GPU.
    for key in ('mv', 'mvp', 'campos', 'img'):
        target[key] = target[key].cuda()
    target['background'] = background

    # Alpha-composite RGB over the background, keeping the alpha channel.
    alpha = target['img'][..., 3:4]
    rgb = torch.lerp(background, target['img'][..., 0:3], alpha)
    target['img'] = torch.cat((rgb, alpha), dim=-1)

    target['spts'] = target['spts'].cuda()
    target['vpts'] = target['vpts'].cuda()
    return target
|
||||
|
||||
###############################################################################
|
||||
# UV - map geometry & convert to a mesh
|
||||
###############################################################################
|
||||
|
||||
@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """Extract the current mesh, UV-parameterize it with xatlas, and bake the
    MLP material ('kd_ks_normal') into fixed kd/ks/normal 2D textures.

    Args:
        glctx: nvdiffrast context used for UV-space rasterization.
        geometry: geometry object; `getMesh(mat)` yields the mesh to bake.
        mat: material dict; only `mat['bsdf']` is carried over.
        FLAGS: reads texture_res, layers, kd/ks/nrm min-max limits.

    Returns:
        A new mesh whose material holds baked Texture2D maps.
    """
    eval_mesh = geometry.getMesh(mat)

    # Create uvs with xatlas (CPU-side; xatlas wants numpy arrays).
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Convert to tensors; xatlas returns uint32 indices, reinterpret as int64
    # for torch without changing the bit pattern width semantics.
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)

    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    # Rasterize in UV space to sample the MLP texture into 2D maps.
    mask, kd, ks, normal = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks_normal'])

    if FLAGS.layers > 1:
        # Layered materials get a random initial alpha channel on kd.
        kd = torch.cat((kd, torch.rand_like(kd[...,0:1])), dim=-1)

    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    new_mesh.material = material.Material({
        'bsdf'   : mat['bsdf'],
        'kd'     : texture.Texture2D(kd, min_max=[kd_min, kd_max]),
        'ks'     : texture.Texture2D(ks, min_max=[ks_min, ks_max]),
        'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max])
    })

    return new_mesh
|
||||
|
||||
@torch.no_grad()
def xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS):
    """Normal-only variant of `xatlas_uvmap`: UV-parameterize the mesh and
    bake only the normal map; kd/ks are passed through from `mat` unchanged.

    Args:
        glctx: nvdiffrast context used for UV-space rasterization.
        geometry: geometry object; `getMesh(mat)` yields the mesh to bake.
        mat: material dict; 'bsdf', 'kd', 'ks' are reused as-is.
        FLAGS: reads texture_res and nrm_min/nrm_max limits.

    Returns:
        A new mesh whose material holds the baked normal Texture2D.
    """
    eval_mesh = geometry.getMesh(mat)

    # Create uvs with xatlas (CPU-side; xatlas wants numpy arrays).
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Convert to tensors
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)

    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    mask, normal = render.render_uv_nrm(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['normal'])

    # BUG FIX: the original carried a copy-pasted `if FLAGS.layers > 1:`
    # branch from xatlas_uvmap that did `kd = torch.cat((kd, ...))`, but `kd`
    # is never defined in this function, so that branch raised NameError.
    # No kd texture is baked here, so the branch is removed.

    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    new_mesh.material = material.Material({
        'bsdf'   : mat['bsdf'],
        'kd'     : mat['kd'],
        'ks'     : mat['ks'],
        'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max])
    })

    return new_mesh
|
||||
|
||||
###############################################################################
|
||||
# Utility functions for material
|
||||
###############################################################################
|
||||
|
||||
def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Create the initial trainable material (full kd/ks/normal model).

    Args:
        geometry: provides `getAABB()` bounds for the MLP texture field.
        mlp: if True, use a single 9-channel MLPTexture3D encoding
            kd(3)+ks(3)+normal(3); otherwise build separate 2D textures.
        FLAGS: reads kd/ks/nrm limits, texture_res, layers, random_textures,
            custom_mip.
        init_mat: optional reference material to copy textures/bsdf from.

    Returns:
        A `material.Material` ready for optimization.
    """
    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')
    if mlp:
        # Single MLP field outputs all 9 channels; clamp limits concatenated
        # in the same kd|ks|normal order.
        mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=9, min_max=[mlp_min, mlp_max])
        mat = material.Material({'kd_ks_normal' : mlp_map_opt})
    else:
        # Setup Kd (albedo) and Ks (x, roughness, metalness) textures
        if FLAGS.random_textures or init_mat is None:
            # NOTE: random init draws from torch then numpy RNGs; call order
            # is part of reproducibility, do not reorder.
            num_channels = 4 if FLAGS.layers > 1 else 3
            kd_init = torch.rand(size=FLAGS.texture_res + [num_channels], device='cuda') * (kd_max - kd_min)[None, None, 0:num_channels] + kd_min[None, None, 0:num_channels]
            kd_map_opt = texture.create_trainable(kd_init , FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])

            # ks red channel is kept near zero; green/blue within FLAGS limits.
            ksR = np.random.uniform(size=FLAGS.texture_res + [1], low=0.0, high=0.01)
            ksG = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu())
            ksB = np.random.uniform(size=FLAGS.texture_res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu())

            ks_map_opt = texture.create_trainable(np.concatenate((ksR, ksG, ksB), axis=2), FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])
        else:
            kd_map_opt = texture.create_trainable(init_mat['kd'], FLAGS.texture_res, not FLAGS.custom_mip, [kd_min, kd_max])
            ks_map_opt = texture.create_trainable(init_mat['ks'], FLAGS.texture_res, not FLAGS.custom_mip, [ks_min, ks_max])

        # Setup normal map: default is the flat +Z normal.
        if FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
            normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
        else:
            normal_map_opt = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

        mat = material.Material({
            'kd'     : kd_map_opt,
            'ks'     : ks_map_opt,
            'normal' : normal_map_opt
        })

    if init_mat is not None:
        mat['bsdf'] = init_mat['bsdf']
    else:
        mat['bsdf'] = 'pbr'

    return mat
|
||||
|
||||
def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
    """Create an initial material where kd/ks are fixed (taken from
    ``init_mat``) and only the normal map is trainable.

    Args:
        geometry: provides `getAABB()` bounds for the MLP texture field.
        mlp: if True, the normal map is a 3-channel MLPTexture3D; otherwise a
            trainable 2D texture.
        FLAGS: reads nrm limits, texture_res, random_textures, custom_mip.
        init_mat: material supplying 'kd', 'ks' (and 'bsdf' when present).
            NOTE(review): dereferenced unconditionally -- a None init_mat
            raises, same as the original.

    Returns:
        A `material.Material` whose only trainable map is 'normal'.
    """
    nrm_min = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda')
    nrm_max = torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    if mlp:
        # Volumetric MLP normal field clamped to the normal-map limits.
        mlp_min = nrm_min
        mlp_max = nrm_max
        normal_tex = mlptexture.MLPTexture3D(geometry.getAABB(), channels=3, min_max=[mlp_min, mlp_max])
        # mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=3, min_max=None)
    elif FLAGS.random_textures or init_mat is None or 'normal' not in init_mat:
        # Default to a flat +Z normal when no reference normal map exists.
        normal_tex = texture.create_trainable(np.array([0, 0, 1]), FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])
    else:
        normal_tex = texture.create_trainable(init_mat['normal'], FLAGS.texture_res, not FLAGS.custom_mip, [nrm_min, nrm_max])

    mat = material.Material({
        'kd'     : init_mat['kd'],
        'ks'     : init_mat['ks'],
        'normal' : normal_tex,
    })

    mat['bsdf'] = init_mat['bsdf'] if init_mat is not None else 'pbr'

    return mat
|
||||
|
||||
###############################################################################
|
||||
# Validation & testing
|
||||
###############################################################################
|
||||
|
||||
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS):
    """Render one validation batch and return (side-by-side image, layer dict).

    The returned image is the optimized render concatenated horizontally with
    the reference; extra layers configured in FLAGS.display (latlong light map,
    relight, per-bsdf buffers, depth) are appended to the right.
    """
    result_dict = {}
    with torch.no_grad():
        lgt.build_mips()
        if FLAGS.camera_space_light:
            lgt.xfm(target['mv'])
        # Apply the per-sample environment-light transform.
        lgt.xfm(target['envlight_transform'])

        # Prefer the EMA weights when the geometry supports them; fall back to
        # a plain render otherwise (bare except kept: any failure triggers the
        # non-EMA path).
        try:
            buffers = geometry.render(glctx, target, lgt, opt_material, ema=True, xfm_lgt=target['envlight_transform'])
        except:
            buffers = geometry.render(glctx, target, lgt, opt_material, xfm_lgt=target['envlight_transform'])

        # First batch element only; convert linear RGB to sRGB for display.
        result_dict['ref'] = util.rgb_to_srgb(target['img'][...,0:3])[0]
        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][...,0:3])[0]
        result_image = torch.cat([result_dict['opt'], result_dict['ref']], axis=1)

        if FLAGS.display is not None:
            white_bg = torch.ones_like(target['background'])  # NOTE(review): unused
            for layer in FLAGS.display:
                if 'latlong' in layer and layer['latlong']:
                    if isinstance(lgt, light.EnvironmentLight):
                        result_dict['light_image'] = util.cubemap_to_latlong(lgt.base, FLAGS.display_res)
                    # NOTE(review): if lgt is not an EnvironmentLight and no
                    # earlier layer set 'light_image', this raises KeyError.
                    result_image = torch.cat([result_image, result_dict['light_image']], axis=1)
                elif 'relight' in layer:
                    # Lazily load the relight environment on first use.
                    if not isinstance(layer['relight'], light.EnvironmentLight):
                        layer['relight'] = light.load_env(layer['relight'])
                    img = geometry.render(glctx, target, layer['relight'], opt_material)
                    result_dict['relight'] = util.rgb_to_srgb(img[..., 0:3])[0]
                    result_image = torch.cat([result_image, result_dict['relight']], axis=1)
                elif 'bsdf' in layer:
                    # Re-render with a forced BSDF to visualize a single channel.
                    buffers = geometry.render(glctx, target, lgt, opt_material, bsdf=layer['bsdf'])
                    if layer['bsdf'] == 'kd':
                        result_dict[layer['bsdf']] = util.rgb_to_srgb(buffers['shaded'][0, ..., 0:3])
                    elif layer['bsdf'] == 'normal':
                        # Map normals from [-1, 1] to [0, 1] for display.
                        result_dict[layer['bsdf']] = (buffers['shaded'][0, ..., 0:3] + 1) * 0.5
                    else:
                        result_dict[layer['bsdf']] = buffers['shaded'][0, ..., 0:3]
                    result_image = torch.cat([result_image, result_dict[layer['bsdf']]], axis=1)
                elif "depth" in layer:
                    # NOTE(review): reads `buffers` from whichever render ran
                    # last (initial render, or a preceding 'bsdf' layer).
                    depth = buffers['depth'][:, :, :, 0].squeeze().unsqueeze(-1).expand(-1, -1, 1)
                    mask = (depth != 0).float()
                    # Min over foreground pixels only (background pushed to 1e3).
                    depth_min = ((1 - mask) * 1e3 + depth).min()
                    depth_max = depth.max()
                    depth = (depth - depth_min) / (depth_max - depth_min + 1e-8)
                    depth = depth * mask + (1 - mask) * depth_min
                    depth = depth.expand(-1, -1, 3)
                    # Colorize via a module-level OpenCV LUT (`lut` defined
                    # elsewhere in this file — not visible in this chunk).
                    depth = cv2.LUT(np.array(depth.detach().cpu().numpy() * 255.0, dtype=np.uint8), lut)
                    result_dict['depth'] = depth = (torch.tensor(depth, device=mask.device).float() / 255.0 * mask) + 255. * (1 - mask)
                    result_image = torch.cat([result_image, depth], axis=1)

        # Final plain render for the geometric-normal visualization.
        buffers = geometry.render(glctx, target, lgt, opt_material)

        camera = target['geo_viewdir'][:, :, :, :3]

        # |n . v| per pixel: foreshortening of the geometric normal.
        result_dict['geo_normal'] = (util.safe_normalize(buffers['geo_normal'][:, :, :, :3]) * camera).sum(-1, keepdim=False).abs()[0]

        mask = buffers['mask'][0].expand(-1, -1, 3)
        result_image = torch.cat([result_image, mask], axis=1)

    return result_image, result_dict
|
||||
|
||||
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS):
    """Run the full validation loop over ``dataset_validate``.

    Writes per-view MSE/PSNR to ``out_dir/metrics.txt``, saves every rendered
    layer as ``val_%06d_%s.png``, and returns the average PSNR.
    """

    # ==============================================================================================
    # Validation loop
    # ==============================================================================================
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(dataloader_validate):

            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS)

            # Compute metrics on clamped sRGB images.
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)

            # reduction='mean' alone is equivalent to the deprecated
            # size_average/reduce kwargs the original passed explicitly.
            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            line = "%d, %1.8f, %1.8f\n" % (it, mse, psnr)
            fout.write(line)

            # Dump every rendered layer for this view.
            for k in result_dict.keys():
                np_img = result_dict[k].detach().cpu().numpy()
                util.save_image(os.path.join(out_dir, 'val_%06d_%s.png' % (it, k)), np_img)

        # NOTE: np.mean of an empty list yields nan — an empty validation set
        # produces nan averages, matching the original behavior.
        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))
        line = "AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr)
        fout.write(line)
        print("MSE, PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr
|
||||
|
||||
###############################################################################
|
||||
# Main shape fitter function / optimization loop
|
||||
###############################################################################
|
||||
|
||||
class Trainer(torch.nn.Module):
    """Bundles geometry, material and light into one module so that a single
    forward pass computes the (image loss, regularization loss) pair used by
    the optimization loop.

    Fix vs. original: ``forward`` previously called ``geometry.tick`` with the
    module-global ``glctx`` (defined in ``__main__``) even though the context
    is stored on ``self.glctx`` — it now uses the stored context. The bare
    ``except:`` around ``self.geometry.sdf`` is narrowed to AttributeError,
    which is what a geometry lacking an ``sdf`` attribute raises.
    """

    def __init__(self, glctx, geometry, lgt, mat, optimize_geometry, optimize_light, image_loss_fn, FLAGS):
        super(Trainer, self).__init__()

        self.glctx = glctx
        self.geometry = geometry
        self.light = lgt
        self.material = mat
        self.optimize_geometry = optimize_geometry
        self.optimize_light = optimize_light
        self.image_loss_fn = image_loss_fn
        self.FLAGS = FLAGS

        if not self.optimize_light:
            # Light is frozen: build its mip stack once up front.
            with torch.no_grad():
                self.light.build_mips()

        # Material params always trained; light params only when requested.
        self.params = list(self.material.parameters())
        self.params += list(self.light.parameters()) if optimize_light else []
        self.geo_params = list(self.geometry.parameters()) if optimize_geometry else []
        try:
            self.sdf_params = [self.geometry.sdf]
        except AttributeError:
            # Fixed-topology geometries expose no free SDF parameter.
            self.sdf_params = []
        self.deform_params = [self.geometry.deform]

    def forward(self, target, it):
        """Return geometry.tick's losses for one batch at iteration ``it``."""
        if self.optimize_light:
            # Light changed last step: rebuild mips and re-apply transforms.
            self.light.build_mips()
            if self.FLAGS.camera_space_light:
                self.light.xfm(target['mv'])
            self.light.xfm(target['envlight_transform'])

        return self.geometry.tick(self.glctx, target, self.light, self.material, self.image_loss_fn, it, xfm_lgt=target['envlight_transform'], no_depth_thin=False)
|
||||
|
||||
def optimize_mesh(
    glctx,
    geometry,
    opt_material,
    lgt,
    dataset_train,
    dataset_validate,
    FLAGS,
    warmup_iter=0,
    log_interval=10,
    pass_idx=0,
    pass_name="",
    optimize_light=True,
    optimize_geometry=True,
    ):
    """Optimize geometry/material/light on a single cached validation view.

    Returns the updated (geometry, opt_material). ``pass_name`` is accepted
    for API symmetry with multi-pass callers but unused here.
    """

    # ==============================================================================================
    # Setup torch optimizer
    # ==============================================================================================

    # FLAGS.learning_rate may be a scalar, a per-pass list, or a per-pass
    # (position_lr, material_lr) pair.
    learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate
    learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
    learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate

    def lr_schedule(iter, fraction):
        # Linear warmup, then exponential decay. `fraction` is unused.
        if iter < warmup_iter:
            return iter / warmup_iter
        return max(0.0, 10**(-(iter - warmup_iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs.

    # ==============================================================================================
    # Image loss
    # ==============================================================================================
    image_loss_fn = createLoss(FLAGS)

    trainer_noddp = Trainer(glctx, geometry, lgt, opt_material, optimize_geometry, optimize_light, image_loss_fn, FLAGS)

    # Single GPU training mode
    trainer = trainer_noddp
    if optimize_geometry:
        # SDF and deformation get their own optimizer so they can use the
        # position learning rate and a separate schedule.
        optimizer_mesh = torch.optim.Adam([
            {'params': trainer_noddp.sdf_params, 'lr': learning_rate_pos},
            {'params': trainer_noddp.deform_params, 'lr': learning_rate_pos},
            ]
        )
        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9))

    optimizer = torch.optim.Adam(trainer_noddp.params, lr=learning_rate_mat)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))

    # ==============================================================================================
    # Training loop
    # ==============================================================================================
    img_cnt = 0  # NOTE(review): unused
    img_loss_vec = []
    reg_loss_vec = []
    iter_dur_vec = []

    # NOTE(review): uses dataset_train.collate on the validation set —
    # presumably both datasets share the same collate; verify.
    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)

    def cycle(iterable):
        # Endless iterator that restarts the underlying iterable on exhaustion.
        iterator = iter(iterable)
        while True:
            try:
                yield next(iterator)
            except StopIteration:
                iterator = iter(iterable)

    v_it = cycle(dataloader_validate)
    # v_iter_no = 25
    v_iter_no = 10

    print("Start training loop...")
    sys.stdout.flush()
    # Advance the cyclic iterator v_iter_no times; only the LAST batch is kept
    # and reused for every training iteration below (single-view fitting).
    for _ in range(v_iter_no):
        v_curr = next(v_it)

    # NOTE(review): iteration count is hard-coded to 5000; the remaining-time
    # estimate below uses FLAGS.iter instead — confirm they are meant to match.
    for it in range(5000):
        # Mix randomized background into dataset image
        target = prepare_batch(v_curr, 'random')

        ### for robustness, we take the easy way of initializing the tet grid with the gt depth image
        if it < 300 and it % 10 == 0:
            gt_visible_triangles = target['rast_triangle_id'].long()
            gt_verts, gt_faces = target['vpts'], target['faces']
            surface_faces = gt_faces[gt_visible_triangles]
            campos = target['campos'][0]
            # Best-effort initialization; bare except kept so geometries
            # without this capability (or failing inits) are simply skipped.
            try:
                geometry.init_with_gt_surface(gt_verts, surface_faces, campos)
            except:
                pass

        iter_start_time = time.time()

        # ==============================================================================================
        # Zero gradients
        # ==============================================================================================
        optimizer.zero_grad()
        if optimize_geometry:
            optimizer_mesh.zero_grad()

        # ==============================================================================================
        # Training
        # ==============================================================================================
        img_loss, reg_loss = trainer(target, it)

        # ==============================================================================================
        # Final loss
        # ==============================================================================================
        total_loss = img_loss + reg_loss

        img_loss_vec.append(img_loss.item())
        reg_loss_vec.append(reg_loss.item())

        # ==============================================================================================
        # Backpropagate
        # ==============================================================================================
        total_loss.backward()

        # Manual gradient rescaling: boost the env-light gradient, damp the
        # MLP texture encoders.
        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:
            lgt.base.grad *= 64
        if 'kd_ks_normal' in opt_material:
            opt_material['kd_ks_normal'].encoder.params.grad /= 8.0
        if 'normal' in opt_material and FLAGS.normal_only:
            # Only MLP normal textures have .encoder; bare except kept so 2D
            # textures fall through untouched.
            try:
                opt_material['normal'].encoder.params.grad /= 8.0
            except:
                pass

        optimizer.step()
        scheduler.step()

        if optimize_geometry:
            optimizer_mesh.step()
            scheduler_mesh.step()

        geometry.clamp_deform()
        if FLAGS.use_ema:
            raise NotImplementedError
            geometry.update_ema()  # NOTE(review): unreachable after the raise

        # ==============================================================================================
        # Clamp trainables to reasonable range
        # ==============================================================================================
        with torch.no_grad():
            if 'kd' in opt_material:
                opt_material['kd'].clamp_()
            if 'ks' in opt_material:
                opt_material['ks'].clamp_()
            if 'normal' in opt_material and not FLAGS.normal_only:
                opt_material['normal'].clamp_()
                opt_material['normal'].normalize_()
            if lgt is not None:
                lgt.clamp_(min=0.0)

        # Synchronize so the measured duration covers the GPU work.
        torch.cuda.current_stream().synchronize()
        iter_dur_vec.append(time.time() - iter_start_time)

        # ==============================================================================================
        # Logging
        # ==============================================================================================
        if it % log_interval == 0 and FLAGS.local_rank == 0:
            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))
            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))
            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))

            remaining_time = (FLAGS.iter-it)*iter_dur_avg
            print("iter=%5d, img_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s" %
                (it, img_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))
            sys.stdout.flush()

    return geometry, opt_material
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Main function.
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='nvdiffrec')
|
||||
parser.add_argument('--config', type=str, default='./configs/res64.json', help='Config file')
|
||||
parser.add_argument('-i', '--iter', type=int, default=5000)
|
||||
parser.add_argument('-s', '--spp', type=int, default=1)
|
||||
parser.add_argument('-l', '--layers', type=int, default=1)
|
||||
parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
|
||||
parser.add_argument('-dr', '--display-res', type=int, default=None)
|
||||
parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
|
||||
parser.add_argument('-di', '--display-interval', type=int, default=0)
|
||||
parser.add_argument('-si', '--save-interval', type=int, default=1000)
|
||||
parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
|
||||
parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
|
||||
parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
|
||||
parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
|
||||
parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
|
||||
parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])
|
||||
parser.add_argument('-o', '--out-dir', type=str, default='./dmtet_results_singleview')
|
||||
parser.add_argument('--validate', type=bool, default=True)
|
||||
parser.add_argument('-no', '--normal-only', type=bool, default=True)
|
||||
parser.add_argument('-ema', '--use-ema', action="store_true")
|
||||
parser.add_argument('-rp', '--resume-path', type=str, default=None)
|
||||
parser.add_argument('-mp', '--mesh-path', type=str)
|
||||
parser.add_argument('-an', '--angle-ind', type=int, help='angle index from 0 to 50')
|
||||
|
||||
FLAGS = parser.parse_args()
|
||||
print(f"parsed arguments")
|
||||
|
||||
    # Hard-coded pipeline defaults; keys present in the JSON --config file
    # (loaded below) will override these.
    FLAGS.mtl_override = None        # Override material of model
    FLAGS.dmtet_grid = 64            # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet
    FLAGS.mesh_scale = 1.0           # Scale of tet grid box. Adjust to cover the model
    FLAGS.env_scale = 1.0            # Env map intensity multiplier
    FLAGS.envmap = None              # HDR environment probe
    FLAGS.display = None             # Conf validation window/display. E.g. [{"relight" : <path to envlight>}]
    FLAGS.camera_space_light = False # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.
    FLAGS.lock_light = False         # Disable light optimization in the second pass
    FLAGS.lock_pos = False           # Disable vertex position optimization in the second pass
    FLAGS.sdf_regularizer = 0.2      # Weight for sdf regularizer (see paper for details)
    FLAGS.laplace = "relative"       # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale = 10000.0    # Weight for sdf regularizer. Default is relative with large weight
    FLAGS.pre_load = True            # Pre-load entire dataset into memory for faster training
    FLAGS.kd_min = [ 0.0, 0.0, 0.0, 0.0] # Limits for kd
    FLAGS.kd_max = [ 1.0, 1.0, 1.0, 1.0]
    FLAGS.ks_min = [ 0.0, 0.08, 0.0] # Limits for ks
    FLAGS.ks_max = [ 1.0, 1.0, 1.0]
    FLAGS.nrm_min = [-1.0, -1.0, 0.0] # Limits for normal map
    FLAGS.nrm_max = [ 1.0, 1.0, 1.0]
    FLAGS.cam_near_far = [0.1, 1000.0]
    FLAGS.learn_light = False
    # NOTE(review): unconditionally resets use_ema, silently discarding the
    # --use-ema CLI flag parsed above — confirm this is intended.
    FLAGS.use_ema = False
    FLAGS.random_lgt = True
    FLAGS.dataset_flat_shading = False

    FLAGS.local_rank = 0  # single-GPU script; rank 0 gates logging/validation
|
||||
    # Overlay the JSON config onto FLAGS (config keys win over CLI/defaults).
    if FLAGS.config is not None:
        data = json.load(open(FLAGS.config, 'r'))
        for key in data:
            FLAGS.__dict__[key] = data[key]

    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    print(f"Out dir: {FLAGS.out_dir}")

    # Dump the fully-resolved configuration for reproducibility.
    if FLAGS.local_rank == 0:
        print("Config / Flags:")
        print("---------")
        for key in FLAGS.__dict__.keys():
            print(key, FLAGS.__dict__[key])
        print("---------")

    # Output layout: val_viz[_pre] for rendered validation images,
    # tets[_pre] for saved tet-grid checkpoints ("_pre" = pass 1).
    os.makedirs(FLAGS.out_dir, exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'val_viz'), exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'val_viz_pre'), exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'tets'), exist_ok=True)
    os.makedirs(os.path.join(FLAGS.out_dir, 'tets_pre'), exist_ok=True)


    print(f"Using dmtet grid of resolution {FLAGS.dmtet_grid}")
|
||||
|
||||
    # OpenGL rasterization context used by every render call below.
    glctx = dr.RasterizeGLContext()

    ### Default mtl
    # Fixed (non-trainable) diffuse material used when the mesh provides none.
    mtl_default = {
        'name' : '_default_mat',
        'bsdf': 'diffuse',
        'uniform': True,
        'kd' : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda'), trainable=False),
        'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'), trainable=False)
    }

    print(f"Loading mesh: {FLAGS.mesh_path}")
    sys.stdout.flush()
    ref_mesh = mesh.load_mesh(FLAGS.mesh_path, FLAGS.mtl_override, mtl_default, use_default=FLAGS.normal_only, no_additional=True)
    # Re-center the reference mesh by its cleaned AABB.
    ref_mesh = mesh.center_by_reference(ref_mesh, mesh.aabb_clean(ref_mesh), 1.0)
    print("Loading dataset")
    sys.stdout.flush()
    # Synthetic datasets rendered from the reference mesh at radius RADIUS
    # (module-level constant, defined earlier in this file).
    dataset_train = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=False)
    dataset_validate = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=True)
    print("Dataset loaded")
    sys.stdout.flush()
|
||||
|
||||
|
||||
    # ==============================================================================================
    # Create env light with trainable parameters
    # ==============================================================================================

    if FLAGS.learn_light:
        lgt = light.create_trainable_env_rnd(512, scale=0.0, bias=0.5)
    else:
        # NOTE(review): FLAGS.envmap defaults to None — when learn_light is
        # False the config file is expected to supply an envmap path; verify.
        lgt = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale, trainable=False)


    # ==============================================================================================
    # If no initial guess, use DMtets to create geometry
    # ==============================================================================================

    # Setup geometry for optimization
    geometry = DMTetGeometry(FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS)

    # Setup textures, make initial guess from reference if possible
    if not FLAGS.normal_only:
        mat = initial_guess_material(geometry, True, FLAGS, mtl_default)
    else:
        # kd/ks fixed to the defaults; only the normal map is trainable.
        mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)

    print("Start optimization")
    sys.stdout.flush()
|
||||
|
||||
    # Pass 1: optimize topology from scratch, or reload a saved tet grid.
    if FLAGS.resume_path is None:
        # Run optimization
        geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate,
                        FLAGS, pass_idx=0, pass_name="dmtet_pass1", optimize_light=FLAGS.learn_light)

        base_mesh = geometry.getMesh(mat)

        # Mask so only deformations of vertices on valid tets are saved.
        vert_mask = torch.zeros_like(geometry.sdf).long().cuda().view(-1, 1)
        vert_mask[geometry.getValidVertsIdx()] = 1

        # Free temporaries / cached memory
        torch.cuda.empty_cache() ### may slow down training

        # NOTE(review): `global_index` is not defined in this part of the
        # file — presumably a module-level counter defined earlier; verify,
        # otherwise this raises NameError at save time.
        torch.save({
            'sdf': geometry.sdf.cpu().detach(),
            'sdf_ema': geometry.sdf_ema.cpu().detach(),
            'deform': (geometry.deform * vert_mask).cpu().detach(),
            'deform_unmasked': geometry.deform.cpu().detach(),
        }, os.path.join(FLAGS.out_dir, 'tets_pre/dmt_dict_{:05d}.pt'.format(global_index)))

        old_geometry = geometry

        if FLAGS.local_rank == 0 and FLAGS.validate:
            # NOTE(review): `FLAGS.index`, `k` and `FLAGS.split_size` are not
            # defined in this chunk — `k` in particular looks like a leftover
            # loop variable; confirm these names exist at this point.
            validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, f"val_viz_pre/dmtet_validate_{FLAGS.index}_{k}_{FLAGS.split_size}"), FLAGS)

    else:
        # Resume: load the pass-1 tet-grid checkpoint into the fresh geometry.
        dmt_dict = torch.load(os.path.join(FLAGS.resume_path, 'tets_pre/dmt_dict_{:05d}.pt'.format(global_index)))
        if FLAGS.use_ema:
            geometry.sdf.data[:] = dmt_dict['sdf_ema']
        else:
            geometry.sdf.data[:] = dmt_dict['sdf']
        geometry.deform.data[:] = dmt_dict['deform']
        old_geometry = geometry
|
||||
|
||||
    # Create textured mesh from result
    if FLAGS.normal_only:
        base_mesh = xatlas_uvmap_nrm(glctx, geometry, mat, FLAGS)
    else:
        base_mesh = xatlas_uvmap(glctx, geometry, mat, FLAGS)


    # Freeze topology: pass 2 optimizes vertex deformation only.
    geometry = DMTetGeometryFixedTopo(geometry, base_mesh, FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS)


    geometry.sdf_sign.requires_grad = False
    geometry.sdf_abs.requires_grad = False
    geometry.deform.requires_grad = True

    # Rescale the stored deformation to the enlarged deform range
    # (deform_scale 3.0 vs. the pass-1 scale — values shrunk by 2/3 so the
    # effective deformation is unchanged).
    geometry.deform.data[:] = geometry.deform * 2.0 / 3.0
    geometry.deform_scale = 3.0

    # Fix SDF signs (i.e. the topology) from the pass-1 result.
    if FLAGS.use_ema:
        geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf_ema) ### use ema
    else:
        geometry.sdf_sign.data[:] = torch.sign(old_geometry.sdf)

    geometry.set_init_v_pos()
|
||||
|
||||
|
||||
    # ==============================================================================================
    # Pass 2: Train with fixed topology (mesh)
    # ==============================================================================================
    geometry, mat = optimize_mesh(glctx, geometry, mat, lgt, dataset_train, dataset_validate, FLAGS,
                    pass_idx=1, pass_name="mesh_pass", warmup_iter=100, optimize_light=FLAGS.learn_light and not FLAGS.lock_light,
                    optimize_geometry=not FLAGS.lock_pos)

    ##### Process single-view tet grid
    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)

    # Advance to the view selected by --angle-ind; only the last batch is kept.
    # NOTE(review): if angle_ind == 0 the loop never runs and `v_curr` is
    # undefined below — confirm angle_ind is always >= 1.
    v_it = iter(dataloader_validate)
    for _ in range(FLAGS.angle_ind):
        v_curr = next(v_it)

    target = prepare_batch(v_curr, 'random')
|
||||
|
||||
    # ==============================================================================================
    # Infer occluded regions
    # ==============================================================================================
    valid_tet_idx = geometry.getValidTetIdx().long()

    buffers = geometry.render(glctx, target, lgt, mat, get_visible_tets=True)

    ## visible tets (except for rasterized ones)
    visible_tets = torch.zeros(geometry.indices.size(0)).cuda()
    visible_tets[buffers['visible_tet_id'].long()] = 1

    ## to include the rasterized tetrahedra
    visible_and_rast_tets = visible_tets.clone()
    rast_tet_id = valid_tet_idx[buffers['rast_triangle_id'].long()].unique()
    visible_and_rast_tets[rast_tet_id] = 1

    # Convert 0/1 indicator tensors to boolean masks over tets.
    visible_tets = (visible_tets == 1)
    visible_and_rast_tets = (visible_and_rast_tets == 1)

    ## label all tetrahedral vertices associated with any visible tets
    visible_verts = torch.zeros(geometry.verts.size(0))
    tet_inds = torch.arange(geometry.indices.size(0))  # NOTE(review): unused
    vis_vert_inds = geometry.indices[visible_tets].unique()
    visible_verts[vis_vert_inds] = 1

    visible_and_rast_verts = visible_verts.clone()
    vis_and_rast_vert_inds = geometry.indices[visible_and_rast_tets].unique()
    visible_and_rast_verts[vis_and_rast_vert_inds] = 1
    visible_and_rast_verts = visible_and_rast_verts.bool()

    # Save the fixed-topology grid plus the visibility labels.
    # NOTE(review): 'tets/dmtet.pt' has no {} placeholder, so the
    # .format(global_index) call is a no-op — likely meant to be
    # 'tets/dmtet_{:05d}.pt'; confirm intended filename.
    torch.save({
        'sdf': geometry.sdf_sign.cpu().detach(),
        'deform': geometry.deform.cpu().detach(),
        'vis': visible_verts.cpu().detach(),
        'vis_rast': visible_and_rast_verts.cpu().detach()
    }, os.path.join(FLAGS.out_dir, 'tets/dmtet.pt'.format(global_index)))

    # ==============================================================================================


    if FLAGS.local_rank == 0 and FLAGS.validate:
        validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, f"val_viz/dmtet"), FLAGS)
|
||||
|
Ładowanie…
Reference in New Issue