Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Axis from colmap to nvdiffrast #204

Open
Zerui-Yu opened this issue Oct 11, 2024 · 0 comments
Open

Axis from colmap to nvdiffrast #204

Zerui-Yu opened this issue Oct 11, 2024 · 0 comments

Comments

@Zerui-Yu
Copy link

Hello! I have recently been trying to train my mesh to ground truth by using the normals of some images in the COLMAP dataset, specifically, the camera information that provides each view in the dataset: cameras.txt contains the camera's internal information, images.txt contains the camera's external parameters. The dataset can be downloaded through Google Drive. I calculated the view and projection matrices using the following functions.

import numpy as np
from scipy.spatial.transform import Rotation as R
import torch

def get_mv_proj(path):
    with open((path + '/cameras.txt'), 'r') as cam_file:
        for line in cam_file:
            if line.startswith('#') or len(line.strip()) == 0:
                continue
            parts = line.split()
            camera_id = int(parts[0])
            model = parts[1]
            width = int(parts[2])
            height = int(parts[3])
            f_x, f_y, c_x, c_y = map(float, parts[4:])
            break

    image_data = []
    with open((path + '/images.txt'), 'r') as img_file:
        for line in img_file:
            if line.endswith('.png\n'):
                parts = line.split()
                image_id = int(parts[0])
                qw, qx, qy, qz = map(float, parts[1:5])
                tx, ty, tz = map(float, parts[5:8])
                camera_id = int(parts[8])
                name = parts[9]
                image_data.append((name, qw, qx, qy, qz, tx, ty, tz))
            else:
                continue

    image_data.sort(key=lambda x: x[0])

    mv = []
    proj = []

    for data in image_data:
        _, qw, qx, qy, qz, tx, ty, tz = data

        q = [qw, qx, qy, qz]
        rotation_matrix = R.from_quat(q).as_matrix()

        view_matrix = np.eye(4)
        view_matrix[:3, :3] = rotation_matrix.T
        view_matrix[:3, 3] = -rotation_matrix.T @ np.array([tx, ty, tz])

        mv.append(view_matrix)

        near, far = 0.1, 100.0
        projection_matrix = np.zeros((4, 4))
        projection_matrix[0, 0] = 2 * f_x / width
        projection_matrix[1, 1] = 2 * f_y / height
        projection_matrix[0, 2] = 2 * c_x / width - 1
        projection_matrix[1, 2] = 2 * c_y / height - 1
        projection_matrix[2, 2] = -(far + near) / (far - near)
        projection_matrix[2, 3] = -(2 * far * near) / (far - near)
        projection_matrix[3, 2] = -1
        proj.append(projection_matrix)

    view_matrices_tensor = torch.tensor(mv, dtype=torch.float32, device='cuda')
    projection_matrices_tensor = torch.tensor(proj, dtype=torch.float32, device='cuda')

    return view_matrices_tensor, projection_matrices_tensor

When I pass the resulting matrix into an nvdiffrast-based renderer, I seem to have gotten the wrong view (It was either black, and all pixels are zero, or only part of the mesh was visible). The renderer is constructed as follows:

import torch
from typing import Tuple


def _warmup(glctx):
    def tensor(*args, **kwargs):
        return torch.tensor(*args, device='cuda', **kwargs)
    pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32)
    tri = tensor([[0, 1, 2]], dtype=torch.int32)
    dr.rasterize(glctx, pos, tri, resolution=[256, 256])

class NormalsRenderer:
    
    _glctx:dr.RasterizeGLContext = None
    
    def __init__(
            self,
            mv: torch.Tensor, #C,4,4
            proj: torch.Tensor, #C,4,4
            image_size: Tuple[int,int],
            ):
        self._mvp = proj @ mv #C,4,4
        self._image_size = image_size
        self._glctx = dr.RasterizeGLContext()
        _warmup(self._glctx)

    def render(self,
            vertices: torch.Tensor, #V,3 float
            normals: torch.Tensor, #V,3 float
            faces: torch.Tensor, #F,3 long
            ) ->torch.Tensor: #C,H,W,4

        V = vertices.shape[0]
        faces = faces.type(torch.int32)
        vert_hom = torch.cat((vertices, torch.ones(V,1,device=vertices.device)),axis=-1) #V,3 -> V,4
        vertices_clip = vert_hom @ self._mvp.transpose(-2,-1) #C,V,4
        rast_out,_ = dr.rasterize(self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False) #C,H,W,4
        vert_col = (normals+1)/2 #V,3
        col,_ = dr.interpolate(vert_col, rast_out, faces) #C,H,W,3
        alpha = torch.clamp(rast_out[..., -1:], max=1) #C,H,W,1
        col = torch.concat((col,alpha),dim=-1) #C,H,W,4
        col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4
        return col #C,H,W,4

I wonder if this is because the coordinates of nvdiffrast are different from those in colmap? How Do I get the right projection matrix and view matrix?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant