diff --git a/README.md b/README.md
index 938fa20..d58131a 100644
--- a/README.md
+++ b/README.md
@@ -57,6 +57,7 @@ sh scripts/download_model.sh
 ## Instance Prediction
 Please follow the command below to predict all the bounding boxes of the images in the `example` folder.
 ```
+./bbox.sh or
 python inference_bbox.py --test_img_dir example
 ```
 All the prediction results will be saved in the `example_bbox` folder.
@@ -64,6 +65,7 @@ All the prediction results will be saved in the `example_bbox` folder.
 ## Colorize Images
 Please follow the command below to colorize all the images in the `example` folder.
 ```
+./color.sh or
 python test_fusion.py --name test_fusion --sample_p 1.0 --model fusion --fineSize 256 --test_img_dir example --results_img_dir results
 ```
 All the colorized results will be saved in the `results` folder.
@@ -89,3 +91,6 @@ If you find our code/models useful, please consider citing our paper:
 
 ## Acknowledgments
 Our code borrows heavily from the amazing [colorization-pytorch](https://github.com/richzhang/colorization-pytorch) repository.
+
+
+ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation
diff --git a/bbox.sh b/bbox.sh
new file mode 100755
index 0000000..41aa925
--- /dev/null
+++ b/bbox.sh
@@ -0,0 +1,2 @@
+python inference_bbox.py --test_img_dir example
+
diff --git a/color.sh b/color.sh
new file mode 100755
index 0000000..41f310c
--- /dev/null
+++ b/color.sh
@@ -0,0 +1,7 @@
+python test_fusion.py \
+    --name test_fusion \
+    --sample_p 1.0 \
+    --model fusion \
+    --fineSize 256 \
+    --test_img_dir example \
+    --results_img_dir results
diff --git a/inference_bbox.py b/inference_bbox.py
index 0e289a1..e990fe2 100644
--- a/inference_bbox.py
+++ b/inference_bbox.py
@@ -19,10 +19,16 @@ import torch
 from tqdm import tqdm
 
+import pdb
+
 cfg = get_cfg()
 cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"))
+
 cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
 cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")
+# https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x/139653917/model_final_2d9806.pkl
+cfg.MODEL.WEIGHTS = "checkpoints/model_final_2d9806.pkl"
+
 predictor = DefaultPredictor(cfg)
 
 parser = ArgumentParser()
@@ -42,6 +48,9 @@
     lab_image = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
     l_channel, a_channel, b_channel = cv2.split(lab_image)
     l_stack = np.stack([l_channel, l_channel, l_channel], axis=2)
+    # l_channel.shape -- (320, 480)
+    # l_stack.shape -- (320, 480, 3)
+
     outputs = predictor(l_stack)
     save_path = join(output_npz_dir, image_path.split('.')[0])
     pred_bbox = outputs["instances"].pred_boxes.to(torch.device('cpu')).tensor.numpy()
@@ -50,4 +59,7 @@
         print('delete {0}'.format(image_path))
         os.remove(join(input_dir, image_path))
         continue
+
+    # (Pdb) pred_bbox.shape -- (10, 4)
+    # (Pdb) pred_scores.shape -- (10,)
     np.savez(save_path, bbox = pred_bbox, scores = pred_scores)
\ No newline at end of file
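For context, a minimal standalone sketch of the grayscale-input detection that the `inference_bbox.py` hunk above configures, assuming detectron2 is installed, the checkpoint has been downloaded to `checkpoints/model_final_2d9806.pkl`, and `example/your_image.jpg` stands in for a real test image:

```python
import cv2
import numpy as np
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(
    "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
# Load weights from the local file instead of fetching them from the model zoo.
cfg.MODEL.WEIGHTS = "checkpoints/model_final_2d9806.pkl"
# cfg.MODEL.DEVICE = "cpu"  # uncomment to run without a GPU
predictor = DefaultPredictor(cfg)

img = cv2.imread("example/your_image.jpg")                 # hypothetical path
l_channel = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)[:, :, 0]  # lightness only
l_stack = np.stack([l_channel] * 3, axis=2)                # (H, W) -> (H, W, 3)
outputs = predictor(l_stack)                               # detect on the grayscale stack
boxes = outputs["instances"].pred_boxes.tensor.cpu().numpy()  # (N, 4)
scores = outputs["instances"].scores.cpu().numpy()            # (N,)
```

Feeding the replicated L channel rather than the color image keeps detection consistent with the colorization task, whose input is grayscale.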
diff --git a/models/fusion_model.py b/models/fusion_model.py
index 0f6ef6b..dcac393 100644
--- a/models/fusion_model.py
+++ b/models/fusion_model.py
@@ -94,20 +94,32 @@ def save_current_imgs(self, path):
 
     def setup_to_test(self, fusion_weight_path):
         GF_path = 'checkpoints/{0}/latest_net_GF.pth'.format(fusion_weight_path)
         print('load Fusion model from %s' % GF_path)
-        GF_state_dict = torch.load(GF_path)
+        GF_state_dict = torch.load(GF_path, map_location=torch.device('cpu'))
         # G_path = 'checkpoints/coco_finetuned_mask_256/latest_net_G.pth' # fine tuned on cocostuff
         G_path = 'checkpoints/{0}/latest_net_G.pth'.format(fusion_weight_path)
-        G_state_dict = torch.load(G_path)
+        G_state_dict = torch.load(G_path, map_location=torch.device('cpu'))
         # GComp_path = 'checkpoints/siggraph_retrained/latest_net_G.pth' # original net
         # GComp_path = 'checkpoints/coco_finetuned_mask_256/latest_net_GComp.pth' # fine tuned on cocostuff
         GComp_path = 'checkpoints/{0}/latest_net_GComp.pth'.format(fusion_weight_path)
-        GComp_state_dict = torch.load(GComp_path)
-
-        self.netGF.load_state_dict(GF_state_dict, strict=False)
-        self.netG.module.load_state_dict(G_state_dict, strict=False)
-        self.netGComp.module.load_state_dict(GComp_state_dict, strict=False)
+        GComp_state_dict = torch.load(GComp_path, map_location=torch.device('cpu'))
+
+        if (len(self.gpu_ids) > 0):
+            self.netGF.load_state_dict(GF_state_dict, strict=False)
+            self.netG.module.load_state_dict(G_state_dict, strict=False)
+            self.netGComp.module.load_state_dict(GComp_state_dict, strict=False)
+        else:
+            # self.netGF: strip the DataParallel 'module.' prefix and copy weights by name
+            target_state_dict = self.netGF.state_dict()
+            for n, p in GF_state_dict.items():
+                n = n.replace('module.', '')
+                if n in target_state_dict.keys():
+                    target_state_dict[n].copy_(p)
+                else:
+                    raise KeyError(n)
+            self.netG.load_state_dict(G_state_dict, strict=False)
+            self.netGComp.load_state_dict(GComp_state_dict, strict=False)
         self.netGF.eval()
         self.netG.eval()
-        self.netGComp.eval()
\ No newline at end of file
+        self.netGComp.eval()
diff --git a/project/README.md b/project/README.md
new file mode 100644
index 0000000..319fa16
--- /dev/null
+++ b/project/README.md
@@ -0,0 +1,3 @@
+
+
+ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation
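The CPU branch in the `fusion_model.py` hunk above hinges on remapping checkpoint keys saved from a `nn.DataParallel` wrapper onto a plain module. A minimal sketch of that remapping in isolation; the helper name is illustrative, and the logic mirrors what the patched `setup_to_test()` does for `netGF`:

```python
import torch

def load_dataparallel_checkpoint(model: torch.nn.Module, ckpt_path: str) -> None:
    """Copy weights saved from an nn.DataParallel model into a plain module.

    Keys in a DataParallel state_dict are prefixed with 'module.'; on a
    CPU-only machine the wrapper is absent, so strip the prefix and copy
    tensors by name.
    """
    state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
    target = model.state_dict()
    for name, param in state_dict.items():
        name = name.replace('module.', '')
        if name in target:
            target[name].copy_(param)  # in-place copy keeps the target's dtype/device
        else:
            raise KeyError(name)       # fail loudly on unexpected keys
```

`map_location=torch.device('cpu')` is what makes `torch.load` safe on machines without CUDA; without it, loading a GPU-saved checkpoint raises a deserialization error.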
diff --git a/test_fusion.py b/test_fusion.py
index ea571e9..7f0885d 100644
--- a/test_fusion.py
+++ b/test_fusion.py
@@ -16,6 +16,9 @@
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 import numpy as np
 import multiprocessing
+
+import pdb
+
 multiprocessing.set_start_method('spawn', True)
 
 torch.backends.cudnn.benchmark = True
@@ -36,26 +39,36 @@
     model = create_model(opt)
     # model.setup_to_test('coco_finetuned_mask_256')
     model.setup_to_test('coco_finetuned_mask_256_ffs')
+    model.eval()
 
     count_empty = 0
     for data_raw in tqdm(dataset_loader, dynamic_ncols=True):
         # if os.path.isfile(join(save_img_path, data_raw['file_id'][0] + '.png')) is True:
         #     continue
-        data_raw['full_img'][0] = data_raw['full_img'][0].cuda()
+        if (len(opt.gpu_ids) > 0):
+            data_raw['full_img'][0] = data_raw['full_img'][0].cuda()
+
         if data_raw['empty_box'][0] == 0:
-            data_raw['cropped_img'][0] = data_raw['cropped_img'][0].cuda()
+            if (len(opt.gpu_ids) > 0):
+                data_raw['cropped_img'][0] = data_raw['cropped_img'][0].cuda()
             box_info = data_raw['box_info'][0]
             box_info_2x = data_raw['box_info_2x'][0]
             box_info_4x = data_raw['box_info_4x'][0]
             box_info_8x = data_raw['box_info_8x'][0]
             cropped_data = util.get_colorization_data(data_raw['cropped_img'], opt, ab_thresh=0, p=opt.sample_p)
             full_img_data = util.get_colorization_data(data_raw['full_img'], opt, ab_thresh=0, p=opt.sample_p)
+
             model.set_input(cropped_data)
             model.set_fusion_input(full_img_data, [box_info, box_info_2x, box_info_4x, box_info_8x])
-            model.forward()
+
+            with torch.no_grad():
+                model.forward()
         else:
             count_empty += 1
             full_img_data = util.get_colorization_data(data_raw['full_img'], opt, ab_thresh=0, p=opt.sample_p)
-            model.set_forward_without_box(full_img_data)
+
+            with torch.no_grad():
+                model.set_forward_without_box(full_img_data)
+
         model.save_current_imgs(join(save_img_path, data_raw['file_id'][0] + '.png'))
     print('{0} images without bounding boxes'.format(count_empty))
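A condensed sketch of the inference-loop pattern the `test_fusion.py` hunk converges on: move tensors to the GPU only when one is configured, and run the forward pass under `torch.no_grad()`. The names `model`, `loader`, and `gpu_ids` are placeholders, not the repository's API:

```python
import torch

def run_inference(model, loader, gpu_ids):
    """Device-agnostic evaluation loop, mirroring the patched test_fusion.py."""
    model.eval()                       # freeze dropout / batch-norm statistics
    use_gpu = len(gpu_ids) > 0
    for batch in loader:
        if use_gpu:                    # only touch .cuda() when a GPU is configured
            batch = {k: v.cuda() if torch.is_tensor(v) else v
                     for k, v in batch.items()}
        with torch.no_grad():          # skip autograd bookkeeping at test time
            output = model(batch)
        # ... save or post-process `output` here ...
```

Wrapping only the forward call in `torch.no_grad()`, as the patch does, avoids accumulating gradient buffers and noticeably reduces memory use during batch colorization.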