Coverage for crip\lowdose.py: 93%

91 statements  

« prev     ^ index     » next       coverage.py v7.5.2, created at 2024-07-16 01:15 +0800

1''' 

2 Low Dose CT module of crip. 

3 

4 https://github.com/SEU-CT-Recon/crip 

5''' 

6 

7import numpy as np 

8 

9from ._typing import * 

10from .shared import * 

11from .utils import * 

12 

13 

@ConvertListNDArray
def injectGaussianNoise(projections: TwoOrThreeD,
                        sigma: float,
                        mu: float = 0,
                        clipMinMax: Or[None, Tuple[float, float]] = None) -> TwoOrThreeD:
    ''' Inject additive Gaussian noise ~ `N(mu, sigma^2)` where `sigma` is the standard deviation and `mu` is the mean.
        Use `clipMinMax = [min, max]` to clip the noisy projections.
        Noise is drawn independently for every pixel (and every slice of a 3D stack) from
        NumPy's global random state.
    '''
    # Validate every argument before touching the random state or doing any work.
    cripAssert(is2or3D(projections), f'`projections` should be 2 or 3-D, but got {projections.ndim}-D.')
    cripAssert(sigma > 0, 'sigma should be greater than 0.')
    if clipMinMax is not None:
        cripAssert(len(clipMinMax) == 2, 'Invalid `clipMinMax`.')

    # One vectorized draw covers 2D and 3D inputs alike; samples are consumed in C order,
    # which matches drawing slice by slice.
    noise = np.random.randn(*projections.shape) * sigma + mu
    res = noise + projections

    if clipMinMax is not None:
        res = np.clip(res, *clipMinMax)

    return res

36 

37 

@ConvertListNDArray
def injectPoissonNoise(
    projections: TwoOrThreeD,
    type_: str = 'postlog',
    nPhoton: float = 1e5,
) -> TwoOrThreeD:
    ''' Inject Poisson noise ~ `P(lambda)` where `lambda` is the ground-truth quanta deduced from arguments.
        `type_` [postlog or raw] gives the content type of `projections`; usually you use
        postlog as input and get postlog as output. `postlog` inputs are converted to transmission
        by `exp(-p)`, while `raw` inputs are used as transmission directly (i.e., intensity already
        divided by the flat field). `nPhoton` is the (flat-field) photon count for each pixel.
        Pixels that receive zero photons are set to one photon to keep the log finite.
    '''
    cripAssert(type_ in ['postlog', 'raw'], f'Invalid type_: {type_}.')
    cripAssert(is2or3D(projections), '`projections` should be 2D or 3D.')
    cripAssert(nPhoton > 0, '`nPhoton` should be greater than 0.')

    img = projections
    if type_ == 'postlog':
        img = np.exp(-img)  # line integrals -> transmission

    # Expected quanta N0 exp(-\sum \mu L). Keep it real-valued: truncating to an integer biases
    # the Poisson mean downward, and an unsigned cast would wrap negative or huge values.
    # Negative expectations are non-physical and are clipped to 0.
    lam = np.clip(nPhoton * img, 0, None)
    img = np.random.poisson(lam).astype(DefaultFloatDType)  # noisy quanta
    img[img <= 0] = 1  # photon starvation: avoid log(0) below
    img /= nPhoton  # cancel the rescaling from N0

    if type_ == 'postlog':
        img = -np.log(img)

    return img

64 

65 

@ConvertListNDArray
def totalVariation(img: TwoOrThreeD) -> Or[float, NDArray]:
    ''' Computes the (anisotropic) Total Variation (TV) of images, i.e., the sum of absolute
        differences between horizontally and vertically adjacent pixels.
        For 2D image, it returns a scalar.
        For 3D image, it returns an array of TV values for each slice.
    '''
    cripAssert(is2or3D(img), 'img should be 2 or 3D.')

    img = np.asarray(img)
    # Differences of unsigned-integer pixels wrap around, and boolean arrays cannot be
    # subtracted, so compute in floating point.
    if not np.issubdtype(img.dtype, np.inexact):
        img = img.astype(np.float64)

    vX = np.diff(img, axis=-1)  # horizontal neighbours
    vY = np.diff(img, axis=-2)  # vertical neighbours

    axis = (-2, -1)
    tv = np.sum(np.abs(vX), axis=axis) + np.sum(np.abs(vY), axis=axis)

    return tv

81 

82 

def nps2D(roi: TwoOrThreeD,
          pixelSize: float,
          detrend: Or[str, None] = 'individual',
          n: Or[int, None] = None,
          fftshift: bool = True,
          normalize: Or[None, str] = None) -> TwoD:
    ''' Compute the noise power spectrum (NPS) of a 2D square ROI using DFT.
        `pixelSize` is the pixel size of reconstructed image ([mm] recommended).
        `detrend` method can be `individual` (by mean value subtraction), `mutual` (by subtracting
        consecutive realizations; requires at least 2 realizations) or None.
        `normalize` method can be `sum` (by amp. sum), `max` (by amp. max) or None.
        `fftshift` is used to shift the zero frequency to the center.
        `n` is the number of dots in DFT, if not provided, it will be the next power of 2 of the ROI size.
        Usually, the ROI should be a uniform region, and multiple realizations are recommended.
        The NPS is the ensemble average over realizations of
        `pixelSize^2 / (h * w) * |DFT(detrended ROI)|^2`.
        The output NPS unit is usually recognized as [a.u.], and x,y-coordinates correspond to
        physical location `coord*pixelSize`.

        [1] https://amos3.aapm.org/abstracts/pdf/99-28842-359478-110263-658667764.pdf
    '''
    cripAssert(detrend in ['individual', 'mutual', None], f'Invalid detrend method: {detrend}.')
    cripAssert(normalize in ['sum', 'max', None], f'Invalid normalize method: {normalize}.')
    cripWarning(is3D(roi), "It's highly recommended to provide multiple realizations of the ROI.")
    if detrend == 'mutual':
        cripAssert(is3D(roi), '`mutual` detrend method requires multiple realizations of the ROI.')

    roi = as3D(roi)
    h, w = getHnW(roi)
    cripAssert(h == w, 'ROI should be square.')
    if detrend == 'mutual':
        cripAssert(roi.shape[0] >= 2, '`mutual` detrend method requires at least 2 realizations of the ROI.')
    dots = n or nextPow2(h)

    # Integer ROIs would be truncated by mean subtraction and wrap when differenced.
    if not np.issubdtype(roi.dtype, np.floating):
        roi = roi.astype(np.float64)

    # de-trend the signal
    if detrend == 'individual':
        detrended = roi - roi.mean(axis=(-2, -1), keepdims=True)  # (DC+noise)-DC
        scale = 1
    elif detrend == 'mutual':
        # (DC+noise)-(DC+noise): the difference doubles the noise variance, hence the 1/2.
        detrended = np.diff(roi, axis=0)
        scale = 1 / 2
    else:
        detrended = roi
        scale = 1

    # Periodogram of each realization, then average the *power* over realizations.
    # Averaging the complex DFTs first would let independent noise cancel out.
    mod2 = np.mean(np.abs(np.fft.fft2(detrended, s=(dots, dots)))**2, axis=0)
    if fftshift:
        mod2 = np.fft.fftshift(mod2)

    nps = (pixelSize * pixelSize) / (h * w) * mod2 * scale
    if normalize == 'sum':
        nps /= nps.sum()
    elif normalize == 'max':
        nps /= nps.max()

    return nps

140 

141 

def nps2DRadAvg(nps: TwoD, fftshifted: bool = True, normalize: Or[str, None] = None) -> NDArray:
    ''' Compute the radially averaged noise power spectrum (NPS) where `nps` can be the output from
        `nps2D` (unnormalized, fftshifted). Do not normalize the input `nps` before using this function.
        Radii are measured from the zero frequency (index `N // 2` of the fftshifted spectrum).
        The output is a 1D array of length N whose i-th entry is the mean NPS over the annulus
        `i + 1 - 0.5 <= r < i + 1 + 0.5` (in DFT bins; DC excluded). Annuli containing no pixel,
        i.e., beyond the spectrum corners, are NaN.
        The output values are NPS [a.u.]; bin r corresponds to spatial frequency r / (N * pixelSize)
        [1/[unit of pixelSize]].
    '''
    cripAssert(is2D(nps), '`nps` should be 2D.')
    cripAssert(nps.shape[0] == nps.shape[1], '`nps` should be square.')
    cripAssert(normalize in ['sum', 'max', None], f'Invalid normalize method: {normalize}.')
    if nps.max() <= 1:
        cripWarning(False, 'The input `nps` looks to be normalized already, which is not expected.')
    if not fftshifted:
        nps = np.fft.fftshift(nps)

    N = nps.shape[0]
    # After fftshift the zero frequency sits at N // 2 along both axes (even and odd N),
    # so radii must be measured from there rather than from the array corner.
    center = N // 2
    y, x = np.indices((N, N))
    R = np.hypot(x - center, y - center)

    # Annulus r covers r - 0.5 <= R < r + 0.5, i.e., floor(R + 0.5) == r.
    ring = np.floor(R + 0.5).astype(np.intp).ravel()
    sums = np.bincount(ring, weights=np.asarray(nps, dtype=np.float64).ravel(), minlength=N + 1)
    counts = np.bincount(ring, minlength=N + 1)

    # Keep radii 1..N as the output bins; empty annuli stay NaN.
    sums, counts = sums[1:N + 1], counts[1:N + 1]
    res = np.full(N, np.nan)
    np.divide(sums, counts, out=res, where=counts > 0)

    if normalize == 'sum':
        res /= np.nansum(res)
    elif normalize == 'max':
        res /= np.nanmax(res)

    return res