Co-SLAM

How do tcnn's hash grid encoding and mapping interact?

In short: forward → render_rays → run_network → batchify(query_color_sdf) → embed_fn / embedpos_fn (the tcnn encodings) → decoder; the rendered rgb/depth then drive the mapping losses.

  1. Forward (learning)
    • forward in scene_rep.py
    • Renders the rays and stores the results in rend_dict in order to compute the losses.
    • Returns the rendered rgb and depth values together with the losses (a sketch of the free-space/SDF losses follows the function below).
     def forward(self, rays_o, rays_d, target_rgb, target_d, global_step=0):
         '''
         Params:
             rays_o: ray origins (Bs, 3)
             rays_d: ray directions (Bs, 3)
             frame_ids: use for pose correction (Bs, 1)
             target_rgb: rgb value (Bs, 3)
             target_d: depth value (Bs, 1)
             c2w_array: poses (N, 4, 4) 
              r r r tx
              r r r ty
              r r r tz
         '''
        
         # Get render results
         **rend_dict = self.render_rays(rays_o, rays_d, target_d=target_d)**
        
         if not self.training:
             return rend_dict
            
         # Get depth and rgb weights for loss
         valid_depth_mask = (target_d.squeeze() > 0.) * (target_d.squeeze() < self.config['cam']['depth_trunc'])
         rgb_weight = valid_depth_mask.clone().unsqueeze(-1)
         rgb_weight[rgb_weight==0] = self.config['training']['rgb_missing']
        
         # Get render loss
         rgb_loss = compute_loss(rend_dict["rgb"]*rgb_weight, target_rgb*rgb_weight)
         psnr = mse2psnr(rgb_loss)
         depth_loss = compute_loss(rend_dict["depth"].squeeze()[valid_depth_mask], target_d.squeeze()[valid_depth_mask])
        
         if 'rgb0' in rend_dict:
             rgb_loss += compute_loss(rend_dict["rgb0"]*rgb_weight, target_rgb*rgb_weight)
             depth_loss += compute_loss(rend_dict["depth0"][valid_depth_mask], target_d.squeeze()[valid_depth_mask])
            
         # Get sdf loss
         z_vals = rend_dict['z_vals']  # [N_rand, N_samples + N_importance]
         sdf = rend_dict['raw'][..., -1]  # [N_rand, N_samples + N_importance]
         truncation = self.config['training']['trunc'] * self.config['data']['sc_factor']
         fs_loss, sdf_loss = get_sdf_loss(z_vals, target_d, sdf, truncation, 'l2', grad=None)         
            
        
         ret = {
             "rgb": rend_dict["rgb"],
             "depth": rend_dict["depth"],
             "rgb_loss": rgb_loss,
             "depth_loss": depth_loss,
             "sdf_loss": sdf_loss,
             "fs_loss": fs_loss,
             "psnr": psnr,
         }
        
         return ret
    
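    get_sdf_loss is called above but not shown. The sketch below illustrates the usual truncation-based formulation: a free-space term for samples well in front of the measured depth and an SDF term inside the truncation band. The masking and weighting here are assumptions for illustration, not the repository's implementation.

     import torch

     def sdf_losses_sketch(z_vals, target_d, sdf, truncation):
         """Hedged sketch of truncation-based SDF supervision.

         z_vals:   [N_rays, N_samples]  sample depths along each ray
         target_d: [N_rays, 1]          measured depth per ray
         sdf:      [N_rays, N_samples]  predicted (normalized) SDF per sample
         """
         # Samples clearly in front of the surface, i.e. before the truncation band
         front_mask = (z_vals < (target_d - truncation)).float()
         # Samples inside the truncation band around the measured depth
         band_mask = (torch.abs(z_vals - target_d) <= truncation).float()

         # Free-space: the predicted SDF should saturate at 1 in front of the surface
         fs_loss = ((sdf - 1.0) ** 2 * front_mask).sum() / front_mask.sum().clamp(min=1.0)
         # SDF band: z + sdf * truncation should reproduce the measured depth
         sdf_loss = (((z_vals + sdf * truncation) - target_d) ** 2 * band_mask).sum() / band_mask.sum().clamp(min=1.0)
         return fs_loss, sdf_loss
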
  2. Render Rays
    • render_rays in scene_rep.py
    • Samples depths along each ray (guided by target_d when it is available), runs the network, and returns the rendering results (a sketch of the raw2outputs step follows the function below).
     def render_rays(self, rays_o, rays_d, target_d=None):
         '''
         Params:
             rays_o: [N_rays, 3]
             rays_d: [N_rays, 3]
             target_d: [N_rays, 1]
        
         '''
         n_rays = rays_o.shape[0]
        
         # Sample depth
         if target_d is not None:
             z_samples = torch.linspace(-self.config['training']['range_d'], self.config['training']['range_d'], steps=self.config['training']['n_range_d']).to(target_d)
             z_samples = z_samples[None, :].repeat(n_rays, 1) + target_d
             z_samples[target_d.squeeze()<=0] = torch.linspace(self.config['cam']['near'], self.config['cam']['far'], steps=self.config['training']['n_range_d']).to(target_d)
        
             if self.config['training']['n_samples_d'] > 0:
                 z_vals = torch.linspace(self.config['cam']['near'], self.config['cam']['far'], self.config['training']['n_samples_d'])[None, :].repeat(n_rays, 1).to(rays_o)
                 z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
             else:
                 z_vals = z_samples
         else:
             z_vals = torch.linspace(self.config['cam']['near'], self.config['cam']['far'], self.config['training']['n_samples']).to(rays_o)
             z_vals = z_vals[None, :].repeat(n_rays, 1) # [n_rays, n_samples]
        
         # Perturb sampling depths
         if self.config['training']['perturb'] > 0.:
             mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
             upper = torch.cat([mids, z_vals[...,-1:]], -1)
             lower = torch.cat([z_vals[...,:1], mids], -1)
             z_vals = lower + (upper - lower) * torch.rand(z_vals.shape).to(rays_o)
        
         # Run rendering pipeline
         pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]
         raw = self.run_network(pts)
         rgb_map, disp_map, acc_map, weights, depth_map, depth_var = self.raw2outputs(raw, z_vals, self.config['training']['white_bkgd'])
        
         # Importance sampling
         if self.config['training']['n_importance'] > 0:
        
             rgb_map_0, disp_map_0, acc_map_0, depth_map_0, depth_var_0 = rgb_map, disp_map, acc_map, depth_map, depth_var
        
             z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
             z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], self.config['training']['n_importance'], det=(self.config['training']['perturb']==0.))
             z_samples = z_samples.detach()
        
             z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
             pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]
        
             raw = self.run_network(pts)
             rgb_map, disp_map, acc_map, weights, depth_map, depth_var = self.raw2outputs(raw, z_vals, self.config['training']['white_bkgd'])
        
         # Return rendering outputs
         ret = {'rgb' : rgb_map, 'depth' :depth_map,
                'disp_map' : disp_map, 'acc_map' : acc_map,
                'depth_var':depth_var,}
         ret = {**ret, 'z_vals': z_vals}
        
         ret['raw'] = raw
        
         if self.config['training']['n_importance'] > 0:
             ret['rgb0'] = rgb_map_0
             ret['disp0'] = disp_map_0
             ret['acc0'] = acc_map_0
             ret['depth0'] = depth_map_0
             ret['depth_var0'] = depth_var_0
             ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)
        
          return ret
    
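    raw2outputs is used above but not shown. The sketch below illustrates one way an SDF-style renderer can turn the per-sample [rgb, sdf] predictions into rendered color, depth, and depth variance, assuming the bell-shaped sigmoid-product weighting common in neural RGB-D renderers; the repository's exact weighting, masking, and disp/acc handling may differ.

     import torch

     def raw2outputs_sketch(raw, z_vals, truncation):
         """Hedged sketch: per-sample [rgb, sdf] -> rendered rgb/depth along each ray.

         raw:    [N_rays, N_samples, 4]  (rgb in [..., :3], sdf in [..., -1])
         z_vals: [N_rays, N_samples]
         """
         rgb = torch.sigmoid(raw[..., :3])
         sdf = raw[..., -1]

         # Bell-shaped weights that peak at the SDF zero crossing (assumed form)
         weights = torch.sigmoid(sdf / truncation) * torch.sigmoid(-sdf / truncation)
         weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)

         rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)               # [N_rays, 3]
         depth_map = torch.sum(weights * z_vals, dim=-1)                     # [N_rays]
         depth_var = torch.sum(weights * (z_vals - depth_map[..., None]) ** 2, dim=-1)
         acc_map = weights.sum(dim=-1)
         return rgb_map, depth_map, depth_var, acc_map, weights
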
  3. Run Network
    • run_network in scene_rep.py
    • Flattens the inputs, normalizes them to tcnn's [0, 1] convention, then queries the encoder/decoder pipeline (query_color_sdf) to produce the outputs (an encoder-construction sketch follows the function below).
     def run_network(self, inputs):
         """
         Run the network on a batch of inputs.
        
         Params:
             inputs: [N_rays, N_samples, 3]
         Returns:
             outputs: [N_rays, N_samples, 4]
         """
         inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
            
         # Normalize the input to [0, 1] (TCNN convention)
         if self.config['grid']['tcnn_encoding']:
             inputs_flat = (inputs_flat - self.bounding_box[:, 0]) / (self.bounding_box[:, 1] - self.bounding_box[:, 0])
        
         **outputs_flat = batchify(self.query_color_sdf, None)(inputs_flat)**
         outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
        
         return outputs
    
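    The [0, 1] normalization above exists because tcnn's HashGrid encoding expects inputs inside the unit cube. The sketch below shows how an encoder pair such as self.embed_fn (multiresolution hash grid) and self.embedpos_fn (assumed here to be a OneBlob coordinate encoding) could be built with tiny-cuda-nn; the configuration values are illustrative, not the repository's config.

     import torch
     import tinycudann as tcnn  # tiny-cuda-nn python bindings

     # Hedged sketch: the numbers below are illustrative, not Co-SLAM's actual config.
     embed_fn = tcnn.Encoding(
         n_input_dims=3,
         encoding_config={
             "otype": "HashGrid",           # multiresolution hash grid
             "n_levels": 16,
             "n_features_per_level": 2,
             "log2_hashmap_size": 19,
             "base_resolution": 16,
             "per_level_scale": 2.0,
         },
     )
     embedpos_fn = tcnn.Encoding(
         n_input_dims=3,
         encoding_config={"otype": "OneBlob", "n_bins": 16},  # smooth coordinate encoding
     )

     # Inputs must already be normalized to [0, 1], which is exactly what run_network does above.
     x = torch.rand(1024, 3, device="cuda")
     feat = embed_fn(x)      # [1024, n_levels * n_features_per_level] = [1024, 32]
     pos = embedpos_fn(x)    # [1024, 3 * n_bins] = [1024, 48]
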
  4. Downstream calls

    1) Query RGB and SDF

    query_color_sdf in scene_rep.py

     def query_color_sdf(self, query_points):
         '''
         Query the color and sdf at query_points.
        
         Params:
             query_points: [N_rays, N_samples, 3]
         Returns:
             raw: [N_rays, N_samples, 4]
         '''
         inputs_flat = torch.reshape(query_points, [-1, query_points.shape[-1]])
         # print(inputs_flat.shape)
        
         **embed = self.embed_fn(inputs_flat)
         embe_pos = self.embedpos_fn(inputs_flat)**
         if not self.config['grid']['oneGrid']:
             embed_color = self.embed_fn_color(inputs_flat)
             return self.decoder(embed, embe_pos, embed_color)
         return **self.decoder(embed, embe_pos)**
    
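    self.decoder is what finally maps the embeddings to the 4-channel [rgb, sdf] output consumed by raw2outputs. The sketch below assumes a small SDF head fed by the hash-grid and positional encodings plus a color head fed by the positional encoding and a geometric feature; the layer sizes and wiring are assumptions, not the repository's actual decoder.

     import torch
     import torch.nn as nn

     class ColorSDFDecoderSketch(nn.Module):
         """Hedged sketch of a joint color/SDF decoder; not the repository's exact network."""
         def __init__(self, grid_dim=32, pos_dim=48, hidden=32, geo_feat_dim=15):
             super().__init__()
             # SDF head: hash-grid features + positional encoding -> [sdf, geometric feature]
             self.sdf_net = nn.Sequential(
                 nn.Linear(grid_dim + pos_dim, hidden), nn.ReLU(),
                 nn.Linear(hidden, 1 + geo_feat_dim),
             )
             # Color head: positional encoding + geometric feature -> rgb
             self.color_net = nn.Sequential(
                 nn.Linear(pos_dim + geo_feat_dim, hidden), nn.ReLU(),
                 nn.Linear(hidden, 3),
             )

         def forward(self, embed, embed_pos):
             h = self.sdf_net(torch.cat([embed, embed_pos], dim=-1))
             sdf, geo_feat = h[..., :1], h[..., 1:]
             rgb = self.color_net(torch.cat([embed_pos, geo_feat], dim=-1))
             # Layout expected downstream: [rgb, sdf] -> 4 channels
             return torch.cat([rgb, sdf], dim=-1)
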

    2) Batchify

    batchify in model/utils.py

     def batchify(fn, chunk=1024*64):
           """Constructs a version of 'fn' that applies to smaller batches.
           """
           if chunk is None:
               return fn
           def ret(inputs, inputs_dir=None):
               if inputs_dir is not None:
                   return torch.cat([fn(inputs[i:i+chunk], inputs_dir[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
               return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
           return ret
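
    In run_network above, batchify is called with chunk=None, so query_color_sdf receives the whole flattened batch at once; passing a chunk size instead bounds memory by slicing the batch. A small usage sketch of the batchify defined above, with an illustrative stand-in for query_color_sdf:

      import torch

      def fake_query_color_sdf(pts):                  # illustrative stand-in: [N, 3] -> [N, 4]
          return torch.cat([torch.sigmoid(pts), pts.norm(dim=-1, keepdim=True)], dim=-1)

      pts = torch.rand(200_000, 3)

      # chunk=None: identical to calling the function directly (run_network's usage)
      out_all = batchify(fake_query_color_sdf, None)(pts)

      # chunk=1024*64: processes the points in slices and concatenates the results
      out_chunked = batchify(fake_query_color_sdf, 1024 * 64)(pts)

      assert torch.allclose(out_all, out_chunked)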