Co-SLAM
How do tcnn's hash grid encoding and mapping interact?
- Forward (learning)
forward in scene_rep.py
- Renders the rays needed to compute the losses and stores the results in rend_dict.
- Returns the rendered rgb and depth values together with the losses.
```python
def forward(self, rays_o, rays_d, target_rgb, target_d, global_step=0):
    '''
    Params:
        rays_o: ray origins (Bs, 3)
        rays_d: ray directions (Bs, 3)
        frame_ids: use for pose correction (Bs, 1)
        target_rgb: rgb value (Bs, 3)
        target_d: depth value (Bs, 1)
        c2w_array: poses (N, 4, 4)
            r r r tx
            r r r ty
            r r r tz
    '''

    # Get render results
    rend_dict = self.render_rays(rays_o, rays_d, target_d=target_d)

    if not self.training:
        return rend_dict

    # Get depth and rgb weights for loss
    valid_depth_mask = (target_d.squeeze() > 0.) * (target_d.squeeze() < self.config['cam']['depth_trunc'])
    rgb_weight = valid_depth_mask.clone().unsqueeze(-1)
    rgb_weight[rgb_weight==0] = self.config['training']['rgb_missing']

    # Get render loss
    rgb_loss = compute_loss(rend_dict["rgb"]*rgb_weight, target_rgb*rgb_weight)
    psnr = mse2psnr(rgb_loss)
    depth_loss = compute_loss(rend_dict["depth"].squeeze()[valid_depth_mask], target_d.squeeze()[valid_depth_mask])

    if 'rgb0' in rend_dict:
        rgb_loss += compute_loss(rend_dict["rgb0"]*rgb_weight, target_rgb*rgb_weight)
        depth_loss += compute_loss(rend_dict["depth0"][valid_depth_mask], target_d.squeeze()[valid_depth_mask])

    # Get sdf loss
    z_vals = rend_dict['z_vals']      # [N_rand, N_samples + N_importance]
    sdf = rend_dict['raw'][..., -1]   # [N_rand, N_samples + N_importance]
    truncation = self.config['training']['trunc'] * self.config['data']['sc_factor']
    fs_loss, sdf_loss = get_sdf_loss(z_vals, target_d, sdf, truncation, 'l2', grad=None)

    ret = {
        "rgb": rend_dict["rgb"],
        "depth": rend_dict["depth"],
        "rgb_loss": rgb_loss,
        "depth_loss": depth_loss,
        "sdf_loss": sdf_loss,
        "fs_loss": fs_loss,
        "psnr": psnr,
    }

    return ret
```
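- forward delegates the SDF supervision to get_sdf_loss, which is not shown in this note. Below is a minimal sketch, assuming the free-space / truncation-band formulation described in the Co-SLAM paper; approx_sdf_losses and its mask logic are illustrative, not the repo's exact implementation.

```python
import torch

def approx_sdf_losses(z_vals, target_d, predicted_sdf, truncation):
    """Illustrative free-space / SDF losses (not the repo's exact get_sdf_loss).

    z_vals:        [N_rays, N_samples] sample depths along each ray
    target_d:      [N_rays, 1] sensor depth per ray
    predicted_sdf: [N_rays, N_samples] SDF predicted by the decoder
    truncation:    scalar truncation distance
    """
    # Samples well in front of the observed surface: SDF should saturate at +1.
    front_mask = (z_vals < target_d - truncation).float()
    # Samples inside the truncation band around the surface.
    sdf_mask = ((z_vals > target_d - truncation) & (z_vals < target_d + truncation)).float()

    fs_loss = ((predicted_sdf - 1.0) ** 2 * front_mask).sum() / (front_mask.sum() + 1e-8)
    # Inside the band, z + sdf * truncation should land on the observed depth.
    sdf_loss = (((z_vals + predicted_sdf * truncation) - target_d) ** 2 * sdf_mask).sum() / (sdf_mask.sum() + 1e-8)
    return fs_loss, sdf_loss
```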
- Render Rays
render_rays in scene_rep.py
- Samples depths along each ray (guided by the target depth when available) and returns the rendered results.
```python
def render_rays(self, rays_o, rays_d, target_d=None):
    '''
    Params:
        rays_o: [N_rays, 3]
        rays_d: [N_rays, 3]
        target_d: [N_rays, 1]
    '''
    n_rays = rays_o.shape[0]

    # Sample depth
    if target_d is not None:
        z_samples = torch.linspace(-self.config['training']['range_d'], self.config['training']['range_d'], steps=self.config['training']['n_range_d']).to(target_d)
        z_samples = z_samples[None, :].repeat(n_rays, 1) + target_d
        z_samples[target_d.squeeze()<=0] = torch.linspace(self.config['cam']['near'], self.config['cam']['far'], steps=self.config['training']['n_range_d']).to(target_d)

        if self.config['training']['n_samples_d'] > 0:
            z_vals = torch.linspace(self.config['cam']['near'], self.config['cam']['far'], self.config['training']['n_samples_d'])[None, :].repeat(n_rays, 1).to(rays_o)
            z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        else:
            z_vals = z_samples
    else:
        z_vals = torch.linspace(self.config['cam']['near'], self.config['cam']['far'], self.config['training']['n_samples']).to(rays_o)
        z_vals = z_vals[None, :].repeat(n_rays, 1)  # [n_rays, n_samples]

    # Perturb sampling depths
    if self.config['training']['perturb'] > 0.:
        mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        upper = torch.cat([mids, z_vals[...,-1:]], -1)
        lower = torch.cat([z_vals[...,:1], mids], -1)
        z_vals = lower + (upper - lower) * torch.rand(z_vals.shape).to(rays_o)

    # Run rendering pipeline
    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]  # [N_rays, N_samples, 3]

    raw = self.run_network(pts)
    rgb_map, disp_map, acc_map, weights, depth_map, depth_var = self.raw2outputs(raw, z_vals, self.config['training']['white_bkgd'])

    # Importance sampling
    if self.config['training']['n_importance'] > 0:
        rgb_map_0, disp_map_0, acc_map_0, depth_map_0, depth_var_0 = rgb_map, disp_map, acc_map, depth_map, depth_var

        z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], self.config['training']['n_importance'], det=(self.config['training']['perturb']==0.))
        z_samples = z_samples.detach()

        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]  # [N_rays, N_samples + N_importance, 3]

        raw = self.run_network(pts)
        rgb_map, disp_map, acc_map, weights, depth_map, depth_var = self.raw2outputs(raw, z_vals, self.config['training']['white_bkgd'])

    # Return rendering outputs
    ret = {'rgb': rgb_map, 'depth': depth_map,
           'disp_map': disp_map, 'acc_map': acc_map,
           'depth_var': depth_var}
    ret = {**ret, 'z_vals': z_vals}

    ret['raw'] = raw

    if self.config['training']['n_importance'] > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
        ret['depth0'] = depth_map_0
        ret['depth_var0'] = depth_var_0
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)

    return ret
```
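- render_rays composites the network output through raw2outputs, which is not reproduced here. As a rough sketch, and under the assumption that, as in the Co-SLAM paper, the SDF is turned into rendering weights by a product of two sigmoids, the helper below illustrates the idea; sdf_to_render is not the repo's raw2outputs.

```python
import torch

def sdf_to_render(raw, z_vals, truncation):
    """Illustrative SDF-based compositing (simplified stand-in for raw2outputs).

    raw:    [N_rays, N_samples, 4] -> rgb (3 channels) + sdf (1 channel)
    z_vals: [N_rays, N_samples]
    """
    rgb = torch.sigmoid(raw[..., :3])
    sdf = raw[..., -1]

    # Bell-shaped weights that peak where the SDF crosses zero along the ray.
    weights = torch.sigmoid(sdf / truncation) * torch.sigmoid(-sdf / truncation)
    weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)

    rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)                          # [N_rays, 3]
    depth_map = torch.sum(weights * z_vals, dim=-1)                                # [N_rays]
    depth_var = torch.sum(weights * (z_vals - depth_map[..., None]) ** 2, dim=-1)  # [N_rays]
    return rgb_map, depth_map, depth_var
```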
- Run Network
run_network in scene_rep.py
- Flattens and normalizes the input points, encodes them with the tcnn encoder, and feeds the encodings through the decoder to produce the outputs.
```python
def run_network(self, inputs):
    """
    Run the network on a batch of inputs.

    Params:
        inputs: [N_rays, N_samples, 3]
    Returns:
        outputs: [N_rays, N_samples, 4]
    """
    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])

    # Normalize the input to [0, 1] (TCNN convention)
    if self.config['grid']['tcnn_encoding']:
        inputs_flat = (inputs_flat - self.bounding_box[:, 0]) / (self.bounding_box[:, 1] - self.bounding_box[:, 0])

    outputs_flat = batchify(self.query_color_sdf, None)(inputs_flat)
    outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])

    return outputs
```
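- The normalization to [0, 1] above is needed because tcnn hash grid encoders expect inputs inside the unit cube. A sketch of how such an encoder is typically built with tiny-cuda-nn follows; the hyperparameter values are common defaults, not necessarily the ones in Co-SLAM's config files.

```python
import torch
import tinycudann as tcnn

# Hypothetical hyperparameters; Co-SLAM reads its values from the config files.
hash_grid = tcnn.Encoding(
    n_input_dims=3,
    encoding_config={
        "otype": "HashGrid",
        "n_levels": 16,
        "n_features_per_level": 2,
        "log2_hashmap_size": 19,
        "base_resolution": 16,
        "per_level_scale": 2.0,
    },
)

# Points must already be normalized to the unit cube, as done in run_network.
pts = torch.rand(4096, 3, device="cuda")
features = hash_grid(pts)   # [4096, n_levels * n_features_per_level] = [4096, 32]
```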
- run_network then goes through the following steps:
1) Query RGB and SDF
query_color_sdf in scene_rep.py
```python
def query_color_sdf(self, query_points):
    '''
    Query the color and sdf at query_points.

    Params:
        query_points: [N_rays, N_samples, 3]
    Returns:
        raw: [N_rays, N_samples, 4]
    '''
    inputs_flat = torch.reshape(query_points, [-1, query_points.shape[-1]])
    # print(inputs_flat.shape)
    embed = self.embed_fn(inputs_flat)        # parametric (hash grid) encoding
    embe_pos = self.embedpos_fn(inputs_flat)  # coordinate (OneBlob) encoding
    if not self.config['grid']['oneGrid']:
        embed_color = self.embed_fn_color(inputs_flat)
        return self.decoder(embed, embe_pos, embed_color)
    return self.decoder(embed, embe_pos)
```
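- embed_fn is the hash grid encoding and embedpos_fn the coordinate (OneBlob) encoding; the decoder that consumes them is defined elsewhere in the repo and is not shown in this note. Below is a minimal sketch with the same call signature, a toy two-branch MLP; the class name, layer sizes, and feature dimensions are made up for illustration and differ from the actual decoder.

```python
import torch
import torch.nn as nn

class ToyColorSDFDecoder(nn.Module):
    """Toy decoder with the same interface as self.decoder(embed, embed_pos)."""

    def __init__(self, grid_dim=32, pos_dim=48, hidden=64, geo_feat_dim=15):
        super().__init__()
        # SDF branch: hash-grid + positional features -> sdf + geometric feature.
        self.sdf_net = nn.Sequential(
            nn.Linear(grid_dim + pos_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 1 + geo_feat_dim),
        )
        # Color branch: positional + geometric features -> rgb.
        self.color_net = nn.Sequential(
            nn.Linear(pos_dim + geo_feat_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 3),
        )

    def forward(self, embed, embed_pos):
        h = self.sdf_net(torch.cat([embed, embed_pos], dim=-1))
        sdf, geo_feat = h[..., :1], h[..., 1:]
        rgb = self.color_net(torch.cat([embed_pos, geo_feat], dim=-1))
        return torch.cat([rgb, sdf], dim=-1)   # [N, 4] = rgb + sdf, matching 'raw'
```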
2) Batchify
batchify in model/utils.py
```python
def batchify(fn, chunk=1024*64):
    """Constructs a version of 'fn' that applies to smaller batches.
    """
    if chunk is None:
        return fn
    def ret(inputs, inputs_dir=None):
        if inputs_dir is not None:
            return torch.cat([fn(inputs[i:i+chunk], inputs_dir[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
        return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
    return ret
```
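- A small usage example of batchify with a toy per-point function; toy_fn and the point count are made up for illustration.

```python
import torch

toy_fn = lambda x: x.norm(dim=-1, keepdim=True)   # stands in for query_color_sdf
pts = torch.rand(200_000, 3)                      # flattened sample points

chunked_fn = batchify(toy_fn, chunk=1024 * 64)    # evaluate 65536 points per call
out = chunked_fn(pts)                             # [200000, 1], same values as toy_fn(pts)
assert torch.allclose(out, toy_fn(pts))
```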