|
|
|
|
|
''' |
|
@File : DatasetAnalyzer.py |
|
@Time : 2022/04/08 10:10:12 |
|
@Author : zzubqh |
|
@Version : 1.0 |
|
@Contact : [email protected] |
|
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA |
|
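@Desc : Foreground (mask > 0) intensity statistics for a CT dataset, plus a script that clips CT volumes to a fixed intensity window.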
|
''' |
|
|
|
|
|
|
|
import numpy as np |
|
import os |
|
import SimpleITK as sitk |
|
from multiprocessing import Pool |
|
|
|
|
|
class DatasetAnalyzer(object):
    """Collect foreground intensity statistics from a list of CT/mask pairs.

    The annotation file (similar to ``train.md``) lists one case per line::

        path/to/ct_file.nii.gz, path/to/seg_file.nii.gz

    Blank lines are skipped and whitespace around each path is stripped, so
    the ``", "`` separator shown above is accepted.
    """

    def __init__(self, annotation_file, num_processes=4):
        """Load the CT/mask path pairs listed in ``annotation_file``.

        Args:
            annotation_file: UTF-8 text file with one ``ct, mask`` pair per line.
            num_processes: number of worker processes used by ``compute_stats``.

        Raises:
            ValueError: if a non-blank line has fewer than two comma-separated
                fields.
        """
        self.dataset = []
        self.num_processes = num_processes

        with open(annotation_file, 'r', encoding='utf-8') as rf:
            for line_no, line_item in enumerate(rf, start=1):
                line_item = line_item.strip()
                if not line_item:
                    # Blank lines (e.g. a trailing newline) carry no case.
                    continue
                # Strip every field: "a.nii.gz, b.nii.gz" would otherwise yield
                # a mask path with a leading space that cannot be opened.
                items = [field.strip() for field in line_item.split(',')]
                if len(items) < 2:
                    raise ValueError(
                        '{0}:{1}: expected "ct_file, mask_file", got {2!r}'.format(
                            annotation_file, line_no, line_item))
                self.dataset.append({'ct': items[0], 'mask': items[1]})

        print('total load {0} ct files'.format(len(self.dataset)))

    def _get_effective_data(self, dataset_item: dict) -> np.ndarray:
        """Return a 1-in-10 subsample of CT intensities inside the mask.

        Args:
            dataset_item: dict with ``'ct'`` and ``'mask'`` image paths.

        Returns:
            1-D array holding every 10th CT value (in array order) taken at
            positions where the mask is > 0.

        Raises:
            ValueError: if the CT and mask arrays have different shapes.
        """
        img_np = sitk.GetArrayFromImage(sitk.ReadImage(dataset_item['ct']))
        mask_np = sitk.GetArrayFromImage(sitk.ReadImage(dataset_item['mask']))
        if img_np.shape != mask_np.shape:
            raise ValueError(
                'shape mismatch: ct {0} is {1}, mask {2} is {3}'.format(
                    dataset_item['ct'], img_np.shape,
                    dataset_item['mask'], mask_np.shape))
        # Return an ndarray rather than a list of NumPy scalars: it pickles
        # compactly when sent back from the worker and uses far less memory.
        return img_np[mask_np > 0][::10]

    def compute_stats(self):
        """Compute intensity statistics over the foreground voxels of all cases.

        Cases are read in parallel by ``self.num_processes`` worker processes
        via ``_get_effective_data``, so only every 10th foreground voxel of
        each case contributes.

        Returns:
            tuple ``(median, mean, sd, mn, mx, percentile_99_5, percentile_00_5)``.
            All seven entries are NaN when the dataset is empty or no mask
            contains a foreground voxel.
        """
        if len(self.dataset) == 0:
            return (np.nan,) * 7

        # The context manager shuts the workers down even if a case fails to
        # load; map() has already collected every result when the block exits.
        with Pool(self.num_processes) as process_pool:
            data_value = process_pool.map(self._get_effective_data, self.dataset)

        print('sub process end, get {0} case data'.format(len(data_value)))
        return self._summarize_voxels(data_value)

    @staticmethod
    def _summarize_voxels(data_value):
        """Pool per-case voxel samples and compute summary statistics.

        Args:
            data_value: sequence of 1-D arrays (or lists) of voxel values.

        Returns:
            the same 7-tuple as ``compute_stats``; all NaN when there are no
            voxels at all.
        """
        if len(data_value) == 0:
            return (np.nan,) * 7
        voxels = np.concatenate([np.ravel(v) for v in data_value])
        if voxels.size == 0:
            # np.min / np.max raise on empty input; report "no data" instead.
            return (np.nan,) * 7

        median = np.median(voxels)
        mean = np.mean(voxels)
        sd = np.std(voxels)
        mn = np.min(voxels)
        mx = np.max(voxels)
        percentile_99_5 = np.percentile(voxels, 99.5)
        percentile_00_5 = np.percentile(voxels, 0.5)
        return median, mean, sd, mn, mx, percentile_99_5, percentile_00_5
|
|
|
|
|
if __name__ == '__main__':
    import tqdm

    annotation = r'/home/code/Dental/Segmentation/dataset/tooth_label.md'
    analyzer = DatasetAnalyzer(annotation, num_processes=8)
    out_dir = r'/data/Dental/SegTrainingClipdata'

    # Fixed intensity window applied to every CT volume.
    # NOTE(review): these look like the percentile_00_5 / percentile_99_5
    # values reported by analyzer.compute_stats() for this dataset; confirm,
    # and recompute them if the dataset changes.
    clip_min, clip_max = 181.0, 7578.0

    # Ensure the output directory exists before writing any volume into it.
    os.makedirs(out_dir, exist_ok=True)

    for item in tqdm.tqdm(analyzer.dataset):
        ct_file = item['ct']
        out_path = os.path.join(out_dir, os.path.basename(ct_file))

        itk_img = sitk.ReadImage(ct_file)
        img_np = sitk.GetArrayFromImage(itk_img)
        # NOTE(review): with float bounds NumPy may promote an integer volume
        # to a float dtype, so the written image may not keep the input pixel
        # type; confirm this is intended.
        data = np.clip(img_np, clip_min, clip_max)

        clip_img = sitk.GetImageFromArray(data)
        # Carry over origin, spacing and direction from the source volume.
        clip_img.CopyInformation(itk_img)
        sitk.WriteImage(clip_img, out_path)
|
|
|
|
|
|
|
|