|
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | 7 | from collections import defaultdict |
8 | | -from dataclasses import dataclass, field, fields |
| 8 | +from dataclasses import dataclass |
9 | 9 | from typing import ( |
10 | | - Any, |
11 | 10 | ClassVar, |
12 | 11 | Dict, |
13 | 12 | Iterable, |
14 | 13 | Iterator, |
15 | 14 | List, |
16 | | - Mapping, |
17 | 15 | Optional, |
18 | 16 | Sequence, |
19 | 17 | Tuple, |
20 | 18 | Type, |
21 | | - Union, |
22 | 19 | ) |
23 | 20 |
|
24 | | -import numpy as np |
25 | 21 | import torch |
26 | | -from pytorch3d.renderer.camera_utils import join_cameras_as_batch |
27 | | -from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras |
28 | | -from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds |
29 | 22 |
|
30 | | - |
31 | | -@dataclass |
32 | | -class FrameData(Mapping[str, Any]): |
33 | | - """ |
34 | | - A type of the elements returned by indexing the dataset object. |
35 | | - It can represent both individual frames and batches of thereof; |
36 | | - in this documentation, the sizes of tensors refer to single frames; |
37 | | - add the first batch dimension for the collation result. |
38 | | -
|
39 | | - Args: |
40 | | - frame_number: The number of the frame within its sequence. |
41 | | - 0-based continuous integers. |
42 | | - sequence_name: The unique name of the frame's sequence. |
43 | | - sequence_category: The object category of the sequence. |
44 | | - frame_timestamp: The time elapsed since the start of a sequence in sec. |
45 | | - image_size_hw: The size of the image in pixels; (height, width) tensor |
46 | | - of shape (2,). |
47 | | - image_path: The qualified path to the loaded image (with dataset_root). |
48 | | - image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image |
49 | | - of the frame; elements are floats in [0, 1]. |
50 | | - mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image |
51 | | - regions. Regions can be invalid (mask_crop[i,j]=0) in case they |
52 | | - are a result of zero-padding of the image after cropping around |
53 | | - the object bounding box; elements are floats in {0.0, 1.0}. |
54 | | - depth_path: The qualified path to the frame's depth map. |
55 | | - depth_map: A float Tensor of shape `(1, H, W)` holding the depth map |
56 | | - of the frame; values correspond to distances from the camera; |
57 | | - use `depth_mask` and `mask_crop` to filter for valid pixels. |
58 | | - depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the |
59 | | - depth map that are valid for evaluation, they have been checked for |
60 | | - consistency across views; elements are floats in {0.0, 1.0}. |
61 | | - mask_path: A qualified path to the foreground probability mask. |
62 | | - fg_probability: A Tensor of `(1, H, W)` denoting the probability of the |
63 | | - pixels belonging to the captured object; elements are floats |
64 | | - in [0, 1]. |
65 | | - bbox_xywh: The bounding box tightly enclosing the foreground object in the |
66 | | - format (x0, y0, width, height). The convention assumes that |
67 | | - `x0+width` and `y0+height` includes the boundary of the box. |
68 | | - I.e., to slice out the corresponding crop from an image tensor `I` |
69 | | - we execute `crop = I[..., y0:y0+height, x0:x0+width]` |
70 | | - crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb` |
71 | | - in the original image coordinates in the format (x0, y0, width, height). |
72 | | - The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs |
73 | | - from `bbox_xywh` due to padding (which can happen e.g. due to |
74 | | - setting `JsonIndexDataset.box_crop_context > 0`) |
75 | | - camera: A PyTorch3D camera object corresponding the frame's viewpoint, |
76 | | - corrected for cropping if it happened. |
77 | | - camera_quality_score: The score proportional to the confidence of the |
78 | | - frame's camera estimation (the higher the more accurate). |
79 | | - point_cloud_quality_score: The score proportional to the accuracy of the |
80 | | - frame's sequence point cloud (the higher the more accurate). |
81 | | - sequence_point_cloud_path: The path to the sequence's point cloud. |
82 | | - sequence_point_cloud: A PyTorch3D Pointclouds object holding the |
83 | | - point cloud corresponding to the frame's sequence. When the object |
84 | | - represents a batch of frames, point clouds may be deduplicated; |
85 | | - see `sequence_point_cloud_idx`. |
86 | | - sequence_point_cloud_idx: Integer indices mapping frame indices to the |
87 | | - corresponding point clouds in `sequence_point_cloud`; to get the |
88 | | - corresponding point cloud to `image_rgb[i]`, use |
89 | | - `sequence_point_cloud[sequence_point_cloud_idx[i]]`. |
90 | | - frame_type: The type of the loaded frame specified in |
91 | | - `subset_lists_file`, if provided. |
92 | | - meta: A dict for storing additional frame information. |
93 | | - """ |
94 | | - |
95 | | - frame_number: Optional[torch.LongTensor] |
96 | | - sequence_name: Union[str, List[str]] |
97 | | - sequence_category: Union[str, List[str]] |
98 | | - frame_timestamp: Optional[torch.Tensor] = None |
99 | | - image_size_hw: Optional[torch.Tensor] = None |
100 | | - image_path: Union[str, List[str], None] = None |
101 | | - image_rgb: Optional[torch.Tensor] = None |
102 | | - # masks out padding added due to cropping the square bit |
103 | | - mask_crop: Optional[torch.Tensor] = None |
104 | | - depth_path: Union[str, List[str], None] = None |
105 | | - depth_map: Optional[torch.Tensor] = None |
106 | | - depth_mask: Optional[torch.Tensor] = None |
107 | | - mask_path: Union[str, List[str], None] = None |
108 | | - fg_probability: Optional[torch.Tensor] = None |
109 | | - bbox_xywh: Optional[torch.Tensor] = None |
110 | | - crop_bbox_xywh: Optional[torch.Tensor] = None |
111 | | - camera: Optional[PerspectiveCameras] = None |
112 | | - camera_quality_score: Optional[torch.Tensor] = None |
113 | | - point_cloud_quality_score: Optional[torch.Tensor] = None |
114 | | - sequence_point_cloud_path: Union[str, List[str], None] = None |
115 | | - sequence_point_cloud: Optional[Pointclouds] = None |
116 | | - sequence_point_cloud_idx: Optional[torch.Tensor] = None |
117 | | - frame_type: Union[str, List[str], None] = None # known | unseen |
118 | | - meta: dict = field(default_factory=lambda: {}) |
119 | | - |
120 | | - def to(self, *args, **kwargs): |
121 | | - new_params = {} |
122 | | - for f in fields(self): |
123 | | - value = getattr(self, f.name) |
124 | | - if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)): |
125 | | - new_params[f.name] = value.to(*args, **kwargs) |
126 | | - else: |
127 | | - new_params[f.name] = value |
128 | | - return type(self)(**new_params) |
129 | | - |
130 | | - def cpu(self): |
131 | | - return self.to(device=torch.device("cpu")) |
132 | | - |
133 | | - def cuda(self): |
134 | | - return self.to(device=torch.device("cuda")) |
135 | | - |
136 | | - # the following functions make sure **frame_data can be passed to functions |
137 | | - def __iter__(self): |
138 | | - for f in fields(self): |
139 | | - yield f.name |
140 | | - |
141 | | - def __getitem__(self, key): |
142 | | - return getattr(self, key) |
143 | | - |
144 | | - def __len__(self): |
145 | | - return len(fields(self)) |
146 | | - |
147 | | - @classmethod |
148 | | - def collate(cls, batch): |
149 | | - """ |
150 | | - Given a list objects `batch` of class `cls`, collates them into a batched |
151 | | - representation suitable for processing with deep networks. |
152 | | - """ |
153 | | - |
154 | | - elem = batch[0] |
155 | | - |
156 | | - if isinstance(elem, cls): |
157 | | - pointcloud_ids = [id(el.sequence_point_cloud) for el in batch] |
158 | | - id_to_idx = defaultdict(list) |
159 | | - for i, pc_id in enumerate(pointcloud_ids): |
160 | | - id_to_idx[pc_id].append(i) |
161 | | - |
162 | | - sequence_point_cloud = [] |
163 | | - sequence_point_cloud_idx = -np.ones((len(batch),)) |
164 | | - for i, ind in enumerate(id_to_idx.values()): |
165 | | - sequence_point_cloud_idx[ind] = i |
166 | | - sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud) |
167 | | - assert (sequence_point_cloud_idx >= 0).all() |
168 | | - |
169 | | - override_fields = { |
170 | | - "sequence_point_cloud": sequence_point_cloud, |
171 | | - "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(), |
172 | | - } |
173 | | - # note that the pre-collate value of sequence_point_cloud_idx is unused |
174 | | - |
175 | | - collated = {} |
176 | | - for f in fields(elem): |
177 | | - list_values = override_fields.get( |
178 | | - f.name, [getattr(d, f.name) for d in batch] |
179 | | - ) |
180 | | - collated[f.name] = ( |
181 | | - cls.collate(list_values) |
182 | | - if all(list_value is not None for list_value in list_values) |
183 | | - else None |
184 | | - ) |
185 | | - return cls(**collated) |
186 | | - |
187 | | - elif isinstance(elem, Pointclouds): |
188 | | - return join_pointclouds_as_batch(batch) |
189 | | - |
190 | | - elif isinstance(elem, CamerasBase): |
191 | | - # TODO: don't store K; enforce working in NDC space |
192 | | - return join_cameras_as_batch(batch) |
193 | | - else: |
194 | | - return torch.utils.data._utils.collate.default_collate(batch) |
195 | | - |
196 | | - |
197 | | -class _GenericWorkaround: |
198 | | - """ |
199 | | - OmegaConf.structured has a weirdness when you try to apply |
200 | | - it to a dataclass whose first base class is a Generic which is not |
201 | | - Dict. The issue is with a function called get_dict_key_value_types |
202 | | - in omegaconf/_utils.py. |
203 | | - For example this fails: |
204 | | -
|
205 | | - @dataclass(eq=False) |
206 | | - class D(torch.utils.data.Dataset[int]): |
207 | | - a: int = 3 |
208 | | -
|
209 | | - OmegaConf.structured(D) |
210 | | -
|
211 | | - We avoid the problem by adding this class as an extra base class. |
212 | | - """ |
213 | | - |
214 | | - pass |
| 23 | +from pytorch3d.implicitron.dataset.frame_data import FrameData |
| 24 | +from pytorch3d.implicitron.dataset.utils import GenericWorkaround |
215 | 25 |
|
216 | 26 |
|
217 | 27 | @dataclass(eq=False) |
218 | | -class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]): |
| 28 | +class DatasetBase(GenericWorkaround, torch.utils.data.Dataset[FrameData]): |
219 | 29 | """ |
220 | 30 | Base class to describe a dataset to be used with Implicitron. |
221 | 31 |
|
|
0 commit comments