Coverage for crip\lowdose.py: 93%
91 statements
« prev ^ index » next coverage.py v7.5.2, created at 2024-07-16 01:15 +0800
« prev ^ index » next coverage.py v7.5.2, created at 2024-07-16 01:15 +0800
1'''
2 Low Dose CT module of crip.
4 https://github.com/SEU-CT-Recon/crip
5'''
7import numpy as np
9from ._typing import *
10from .shared import *
11from .utils import *
@ConvertListNDArray
def injectGaussianNoise(projections: TwoOrThreeD,
                        sigma: float,
                        mu: float = 0,
                        clipMinMax: Or[None, Tuple[float, float]] = None) -> TwoOrThreeD:
    ''' Add i.i.d. Gaussian noise drawn from `N(mu, sigma^2)` to 2D or 3D `projections`.
        `sigma` is the standard deviation of the noise (must be positive), `mu` is its mean.
        When `clipMinMax = [min, max]` is given, the noisy result is clipped into that range.
    '''
    cripAssert(is2or3D(projections), f'`projections` should be 2 or 3-D, but got {projections.ndim}-D.')
    cripAssert(sigma > 0, 'sigma should be greater than 0.')

    # One draw over the whole shape consumes the global RNG in the same C-order sequence
    # as drawing one slice at a time, so 2D and 3D inputs share a single code path.
    noise = np.random.randn(*projections.shape) * sigma + mu
    noisy = noise + projections

    if clipMinMax is None:
        return noisy

    cripAssert(len(clipMinMax) == 2, 'Invalid `clipMinMax`.')
    return np.clip(noisy, *clipMinMax)
@ConvertListNDArray
def injectPoissonNoise(
    projections: TwoOrThreeD,
    type_: str = 'postlog',
    nPhoton: float = 1e5,
) -> TwoOrThreeD:
    ''' Inject Poisson noise ~ `P(lambda)` where `lambda` is the ground-truth quanta deduced from arguments.
        `type_` [postlog or raw] gives the content type of `projections`, usually you use
        postlog as input and get postlog as output. `nPhoton` is the photon count for each pixel
        (the unattenuated flux `N0`) and must be positive.
        Pixels that receive zero photons after noise injection are set to one photon, so the
        post-log output stays finite.
    '''
    cripAssert(type_ in ['postlog', 'raw'], f'Invalid type_: {type_}.')
    cripAssert(is2or3D(projections), '`projections` should be 2D or 3D.')
    cripAssert(nPhoton > 0, 'nPhoton should be greater than 0.')

    img = projections
    if type_ == 'postlog':
        img = np.exp(-img)  # back to transmission, i.e., I / N0

    # Ground-truth quanta N0 * exp(-\sum \mu L). Keep it fractional: truncating to an integer
    # biases the Poisson mean downward (and overflows/wraps for huge or negative values).
    # Negative rates are invalid for a Poisson draw, so they are clipped to zero.
    lam = np.clip(nPhoton * np.asarray(img, dtype=np.float64), 0, None)
    img = np.random.poisson(lam).astype(DefaultFloatDType)  # noisy quanta
    img[img <= 0] = 1  # photon-starved pixels: avoid log(0) below
    img /= nPhoton  # cancel the rescaling from N0

    if type_ == 'postlog':
        img = -np.log(img)

    return img
@ConvertListNDArray
def totalVariation(img: TwoOrThreeD) -> Or[float, NDArray]:
    ''' Computes the (anisotropic) Total Variation (TV) of images, i.e., the sum of absolute
        differences between horizontally and vertically adjacent pixels.
        For 2D image, it returns a scalar.
        For 3D image, it returns an array of TV values for each slice.
        Integer and boolean images are promoted to float64 first, so neighbour differences
        cannot wrap around (e.g., `uint8`) or fail (boolean subtraction).
    '''
    cripAssert(is2or3D(img), 'img should be 2 or 3D.')

    if not np.issubdtype(img.dtype, np.inexact):
        img = img.astype(np.float64)

    vX = np.diff(img, axis=-1)  # differences along the last (column) axis
    vY = np.diff(img, axis=-2)  # differences along the row axis

    axis = (-2, -1)
    return np.sum(np.abs(vX), axis=axis) + np.sum(np.abs(vY), axis=axis)
def nps2D(roi: TwoOrThreeD,
          pixelSize: float,
          detrend: Or[str, None] = 'individual',
          n: Or[int, None] = None,
          fftshift: bool = True,
          normalize: Or[None, str] = None) -> TwoD:
    ''' Compute the noise power spectrum (NPS) of a 2D square ROI using DFT.
        `roi` is either one square ROI [H, W] or multiple realizations of it [M, H, W].
        `pixelSize` is the pixel size of reconstructed image ([mm] recommended).
        `detrend` method can be `individual` (by mean value subtraction), `mutual` (by foreground subtraction) or None.
        `normalize` method can be `sum` (by amp. sum), `max` (by amp. max) or None.
        `fftshift` is used to shift the zero frequency to the center.
        `n` is the number of dots in DFT, if not provided, it will be the next power of 2 of the ROI size.
        Usually, the ROI should be a uniform region, and multiple realizations are recommended.
        The NPS is the ensemble average of the per-realization periodograms:
        `NPS = pixelSize^2 / (H * W) * mean_i |DFT(detrended_i)|^2`, scaled by 1/2 for `mutual`
        because the difference of two realizations carries twice the noise variance.
        The output NPS unit is usually recognized as [a.u.], and x,y-coordinates correspond to
        physical location `coord*pixelSize`.

        [1] https://amos3.aapm.org/abstracts/pdf/99-28842-359478-110263-658667764.pdf
    '''
    cripAssert(detrend in ['individual', 'mutual', None], f'Invalid detrend method: {detrend}.')
    cripAssert(normalize in ['sum', 'max', None], f'Invalid normalize method: {normalize}.')
    cripWarning(is3D(roi), "It's highly recommended to provide multiple realizations of the ROI.")
    if detrend == 'mutual':
        cripAssert(is3D(roi), '`mutual` detrend method requires multiple realizations of the ROI.')

    roi = as3D(roi)
    h, w = getHnW(roi)
    cripAssert(h == w, 'ROI should be square.')
    if detrend == 'mutual':
        # A 3D stack with a single slice has no pair to subtract.
        cripAssert(roi.shape[0] >= 2, '`mutual` detrend method requires multiple realizations of the ROI.')
    dots = n or nextPow2(h)

    # Work in float64 so integer ROIs are neither truncated by the mean subtraction
    # nor wrapped around by the difference of realizations.
    roi = np.asarray(roi, dtype=np.float64)

    # de-trend the signal
    if detrend == 'individual':
        detrended = roi - roi.mean(axis=(-2, -1), keepdims=True)  # (DC+noise)-DC
        s = 1
    elif detrend == 'mutual':
        detrended = np.diff(roi, axis=0)  # (DC+noise)-(DC+noise) of consecutive realizations
        s = 1 / 2
    else:  # detrend is None
        detrended = roi
        s = 1

    dft = np.fft.fft2(detrended, s=(dots, dots))  # per-realization 2D DFT
    # Average the power over realizations, not the complex spectra: averaging the complex DFTs
    # first would let independent noise realizations cancel each other out.
    mod2 = np.mean(np.abs(dft)**2, axis=0)
    if fftshift:
        mod2 = np.fft.fftshift(mod2)

    nps = (pixelSize * pixelSize) / (h * w) * mod2 * s
    if normalize == 'sum':
        nps /= nps.sum()
    elif normalize == 'max':
        nps /= nps.max()

    return nps
def nps2DRadAvg(nps: TwoD, fftshifted: bool = True, normalize: Or[str, None] = None) -> NDArray:
    ''' Compute the radially averaged noise power spectrum (NPS) where `nps` can be the output from
        `nps2D` (unnormalized, fftshifted). Do not normalize the input `nps` before using this function.
        The output is a 1D array of length N (N = `nps.shape[0]`) of NPS values [a.u.]. Entry `k` is the
        mean of `nps` over the annulus `r - 0.5 <= R < r + 0.5` with `r = k + 1`, where `R` is the distance
        (in frequency bins) from the zero-frequency bin. With DFT size N and pixel size p, radius `r`
        corresponds to spatial frequency `r / (N * p)` [1/[unit of pixelSize]].
        Radii that contain no pixels (beyond the corners of `nps`) are NaN and are ignored by `normalize`.
    '''
    cripAssert(is2D(nps), '`nps` should be 2D.')
    cripAssert(nps.shape[0] == nps.shape[1], '`nps` should be square.')
    cripAssert(normalize in ['sum', 'max', None], f'Invalid normalize method: {normalize}.')
    if nps.max() <= 1:
        cripWarning(False, 'The input `nps` looks to be normalized already, which is not expected.')
    if not fftshifted:
        nps = np.fft.fftshift(nps)

    N = nps.shape[0]
    # After fftshift, zero frequency sits at index N // 2 along both axes (for even and odd N),
    # so radii are measured from there rather than from the array corner.
    c = N // 2
    y, x = np.indices((N, N))
    R = np.hypot(x - c, y - c)

    # Radius bin r collects pixels with r - 0.5 <= R < r + 0.5, i.e., floor(R + 0.5) == r.
    bins = np.floor(R + 0.5).astype(np.intp).ravel()
    sums = np.bincount(bins, weights=np.asarray(nps, dtype=np.float64).ravel(), minlength=N + 1)
    counts = np.bincount(bins, minlength=N + 1)
    sums, counts = sums[1:N + 1], counts[1:N + 1]  # radii 1..N (DC bin r = 0 excluded)

    res = np.full(N, np.nan)
    np.divide(sums, counts, out=res, where=counts > 0)

    if normalize == 'sum':
        res /= np.nansum(res)
    elif normalize == 'max':
        res /= np.nanmax(res)

    return res