#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   DatasetAnalyzer.py
@Time    :   2022/04/08 10:10:12
@Author  :   zzubqh 
@Version :   1.0
@Contact :   [email protected]
@License :   (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
@Desc    :   None
'''

# here put the import lib

import numpy as np
import os
import SimpleITK as sitk
from multiprocessing import Pool


class DatasetAnalyzer(object):
    """
    Takes an annotation file such as train.md, one case per line.
    Line format: **/ct_file.nii.gz, */seg_file.nii.gz
    """
    def __init__(self, annotation_file, num_processes=4):
        self.dataset = []
        self.num_processes = num_processes
        with open(annotation_file, 'r', encoding='utf-8') as rf:
            for line_item in rf:
                line_item = line_item.strip()
                if not line_item:  # skip blank lines
                    continue
                items = [item.strip() for item in line_item.split(',')]
                self.dataset.append({'ct': items[0], 'mask': items[1]})

        print('loaded {0} CT files in total'.format(len(self.dataset)))

    def _get_effective_data(self, dataset_item: dict):
        itk_img = sitk.ReadImage(dataset_item['ct'])
        itk_mask = sitk.ReadImage(dataset_item['mask'])

        img_np = sitk.GetArrayFromImage(itk_img)
        mask_np = sitk.GetArrayFromImage(itk_mask)

        # Keep only intensities of foreground voxels (mask > 0), then take
        # every 10th sample to keep the per-case data volume manageable.
        mask_index = mask_np > 0
        effective_data = img_np[mask_index][::10]
        return list(effective_data)

    def compute_stats(self):
        if len(self.dataset) == 0:
            return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan

        # Extract the masked voxel samples from every case in parallel.
        with Pool(self.num_processes) as process_pool:
            data_value = process_pool.map(self._get_effective_data, self.dataset)

        print('sub-processes finished, collected data for {0} cases'.format(len(data_value)))
        voxels = []
        for value in data_value:
            voxels += value

        median = np.median(voxels)
        mean = np.mean(voxels)
        sd = np.std(voxels)
        mn = np.min(voxels)
        mx = np.max(voxels)
        percentile_99_5 = np.percentile(voxels, 99.5)
        percentile_00_5 = np.percentile(voxels, 0.5)
        return median, mean, sd, mn, mx, percentile_99_5, percentile_00_5


if __name__ == '__main__':
    import tqdm
    annotation = r'/home/code/Dental/Segmentation/dataset/tooth_label.md'
    analyzer = DatasetAnalyzer(annotation, num_processes=8)
    out_dir = r'/data/Dental/SegTrainingClipdata'
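    # Ensure the output directory exists before writing (an added safeguard,
    # not in the original script).
    os.makedirs(out_dir, exist_ok=True)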
    # t = analyzer.compute_stats()
    # print(t)
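    # Intended workflow (a sketch; variable names are illustrative):
    #   median, mean, sd, mn, mx, p99_5, p00_5 = analyzer.compute_stats()
    # The 0.5 / 99.5 percentiles are natural candidates for the clip bounds
    # used in the rewrite loop below.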

    # new_annotation = r'/home/code/BoneSegLandmark/dataset/knee_clip_label_seg.md'
    # wf = open(new_annotation, 'w', encoding='utf-8')
    # with open(annotation, 'r', encoding='utf-8') as rf:
    #     for str_line in rf:
    #         items = str_line.strip().split(',')
    #         ct_name = os.path.basename(items[0])
    #         new_ct_path = os.path.join(out_dir, ct_name)
    #         label_file = items[1]
    #         wf.write('{0},{1}\n'.format(new_ct_path, label_file))
    # wf.close()

    # Regenerate new CT volumes with intensities clipped to the CT value range
    for item in tqdm.tqdm(analyzer.dataset):
        ct_file = item['ct']
        out_name = os.path.basename(ct_file)
        out_path = os.path.join(out_dir, out_name)
        itk_img = sitk.ReadImage(item['ct'])
        img_np = sitk.GetArrayFromImage(itk_img)
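        # Clip to the fixed window [181.0, 7578.0]; presumably the 0.5 / 99.5
        # percentiles from an earlier compute_stats() run (an assumption).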
        data = np.clip(img_np, 181.0, 7578.0)
        clip_img = sitk.GetImageFromArray(data)
        clip_img.CopyInformation(itk_img)
        sitk.WriteImage(clip_img, out_path)