-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathMeshDataset.py
27 lines (22 loc) · 930 Bytes
/
MeshDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from torch.utils.data import DataLoader, Dataset
import fnmatch
import os
from pytorch3d.io import load_objs_as_meshes, load_obj
class MeshDataset(Dataset):
    """Dataset of meshes loaded eagerly from ``.obj`` files in a single directory.

    Every regular ``*.obj`` file directly inside ``mesh_dir`` is selected in
    sorted filename order, capped at ``max_num``. Each file is loaded once at
    construction with pytorch3d's ``load_objs_as_meshes``. Indexing returns
    the pre-loaded object for that file.

    Args:
        mesh_dir: Directory scanned (non-recursively) for ``*.obj`` files.
        device: Passed through to ``load_objs_as_meshes`` for every file.
        shuffle: Stored as ``self.shuffle`` but not used by this class.
            NOTE(review): presumably read by a caller; otherwise pass
            ``shuffle`` to the DataLoader instead. Confirm.
        max_num: Maximum number of meshes to load; ``None`` means no limit.

    Raises:
        ValueError: If ``max_num`` is negative.
    """

    def __init__(self, mesh_dir, device, shuffle=True, max_num=9999):
        self.mesh_dir = mesh_dir
        self.shuffle = shuffle
        self.mesh_filenames = self._find_obj_files(mesh_dir, max_num)
        self.len = len(self.mesh_filenames)
        self.mesh_files = [os.path.join(mesh_dir, name) for name in self.mesh_filenames]
        print('Meshes: ', self.mesh_files)
        # One entry per file: the result of loading a single-element path list,
        # with a 1x1 per-face texture atlas as in the original configuration.
        self.meshes = [
            load_objs_as_meshes(
                [path],
                device=device,
                create_texture_atlas=True,
                texture_atlas_size=1,
            )
            for path in self.mesh_files
        ]

    @staticmethod
    def _find_obj_files(mesh_dir, max_num):
        """Return sorted names of regular ``*.obj`` files in ``mesh_dir``, at most ``max_num``.

        Raises:
            ValueError: If ``max_num`` is negative.
        """
        if max_num is not None and max_num < 0:
            raise ValueError(f"max_num must be non-negative or None, got {max_num}")
        # os.listdir order is arbitrary; sort so the dataset order (and which
        # files survive the max_num cap) is reproducible across runs/machines.
        # Directories named like '*.obj' are skipped since they cannot be loaded.
        names = sorted(
            name
            for name in fnmatch.filter(os.listdir(mesh_dir), '*.obj')
            if os.path.isfile(os.path.join(mesh_dir, name))
        )
        return names if max_num is None else names[:max_num]

    def __len__(self):
        """Return the number of loaded meshes."""
        return self.len

    def __getitem__(self, idx):
        """Return the pre-loaded mesh object at position ``idx``."""
        return self.meshes[idx]