Coverage for crip\plot.py: 52%

132 statements  

« prev     ^ index     » next       coverage.py v7.5.2, created at 2024-07-16 01:15 +0800

1''' 

2 Figure artist module of crip. 

3 

4 https://github.com/z0gSh1u/crip 

5''' 

6 

7import cv2 

8import numpy as np 

9import matplotlib.pyplot as plt 

10import matplotlib.figure 

11import matplotlib.patches 

12import matplotlib.axes 

13from mpl_toolkits.axes_grid1 import ImageGrid as MplImageGrid 

14from scipy.ndimage import uniform_filter 

15 

16from ._typing import * 

17from .utils import * 

18from .physics import Spectrum, DiagEnergyRange, Atten 

19from .shared import resizeTo 

20 

# Shared `vmin`/`vmax` keyword arguments pinning a display range of [0, 1] (e.g. for Matplotlib `imshow`).
VMIN0_VMAX1 = dict(vmin=0, vmax=1)

22 

23 

def smooth1D(data: NDArray, winSize: int = 5) -> NDArray:
    ''' Smooth a 1D array using a centered moving-average window of length `winSize`.

        The interior uses the full window (`np.convolve` in `valid` mode). Near both ends,
        where a full window does not fit, progressively shorter odd-length centered windows
        (1, 3, ..., winSize - 2) are averaged instead, so the output has the same length
        as `data`.
        The implementation is from https://stackoverflow.com/questions/40443020

        Fails via `cripAssert` if `data` is not 1D, if `winSize` is not a positive odd
        integer, or if `data` is shorter than `winSize`.
    '''
    cripAssert(is1D(data), '`data` should be 1D array.')
    # `winSize % 2 == 1` alone also accepts negative odd values (e.g. -1 % 2 == 1).
    cripAssert(isInt(winSize) and winSize > 0 and winSize % 2 == 1, '`winSize` should be odd positive integer.')
    # With a shorter input, `np.convolve` swaps its operands and the output length would be wrong.
    cripAssert(len(data) >= winSize, '`data` should be at least as long as `winSize`.')

    # Full-window averages for every position where the whole window fits.
    out0 = np.convolve(data, np.ones(winSize, dtype=int), 'valid') / winSize
    # Border window lengths 1, 3, ..., winSize - 2.
    r = np.arange(1, winSize - 1, 2)
    # Head: prefix sums taken at odd lengths give centered averages near the start.
    start = np.cumsum(data[:winSize - 1])[::2] / r
    # Tail: the same on the reversed tail, flipped back into original order.
    stop = (np.cumsum(data[:-winSize:-1])[::2] / r)[::-1]

    return np.concatenate((start, out0, stop))

37 

38 

def smoothZ(img: ThreeD, ksize=3) -> ThreeD:
    ''' Apply a moving-average (uniform) filter of length `ksize` along axis 0 (Z) of a 3D image.

        Axes 1 and 2 get a size-1 kernel, i.e. they are not filtered. Borders are handled
        with `mode='reflect'`. A new array is returned.
    '''
    cripAssert(is3D(img), '`img` should be 3D array.')
    return uniform_filter(img, (ksize, 1, 1), mode='reflect')

48 

49 

def window(img: TwoOrThreeD,
           win: Tuple[float, float],
           style: str = 'lr',
           normalize: Or[str, None] = None) -> TwoOrThreeD:
    ''' Window `img` using `win` (WW, WL) with style `wwwl` or (left, right) with style `lr`.

        Pixel values are clipped to [left, right]; for `wwwl`, `left = WL - WW / 2` and
        `right = left + WW`.
        Set `normalize` to `0255` to convert to 8-bit image, or `01` to [0, 1] float image.
        With `normalize=None` (default) the clipped float image is returned unscaled.

        Note: `0255` min-max rescales the *clipped data* (`cv2.NORM_MINMAX`), whereas `01`
        rescales the window bounds [left, right]. A zero-width window with `01` yields zeros.

        Fails via `cripAssert` on a malformed `win`, unknown `style`/`normalize`, or left > right.
    '''
    cripAssert(len(win) == 2, '`win` should have length of 2.')
    cripAssert(style in ['wwwl', 'lr'], "`style` should be 'wwwl' or 'lr'")
    cripAssert(normalize in [None, '0255', '01'], "`normalize` should be None, '0255' or '01'")

    if style == 'wwwl':
        ww, wl = win
        l = wl - ww / 2
        r = l + ww
    else:
        l, r = win
    # Swapped bounds would make the two clipping steps collapse every pixel onto `l`.
    cripAssert(l <= r, 'Window left bound should not exceed the right bound.')

    res = asFloat(img.copy())
    res[res > r] = r
    res[res < l] = l

    if normalize == '0255':
        res = cv2.normalize(res, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    elif normalize == '01':
        width = r - l
        # Avoid 0/0 -> NaN for a zero-width window.
        res = (res - l) / width if width > 0 else np.zeros_like(res)

    return res

77 

78 

def windowFullRange(img: TwoOrThreeD, normalize='01') -> TwoOrThreeD:
    ''' Window `img` using its full dynamic range of pixel values, i.e. `(min, max)` in `lr` style.

        `normalize` is forwarded to `window` (`'01'` by default).
    '''
    # `lr` style expects (left, right) = (min, max); the reversed order would clip
    # every pixel to the same value.
    return window(img, (np.min(img), np.max(img)), 'lr', normalize)

83 

84 

def zoomIn(img: TwoD, row: int, col: int, h: int, w: int) -> TwoD:
    ''' Return the `h` x `w` patch of `img` whose top-left corner is at (`row`, `col`).

        Plain slicing is used, so for a NumPy array the patch is a view sharing memory
        with `img`, and it is truncated at the image border.
    '''
    rowSpan = slice(row, row + h)
    colSpan = slice(col, col + w)
    return img[rowSpan, colSpan]

89 

90 

def stddev(img: TwoD, ddof: int = 0) -> float:
    ''' Compute the standard deviation of all pixel values in `img` (e.g. a crop from `zoomIn`).

        `ddof` is forwarded to `np.std`: the default 0 gives the population standard
        deviation, 1 the sample standard deviation.
    '''
    return np.std(img, ddof=ddof)

96 

97 

def meanstd(x: Any) -> Tuple[float, float]:
    ''' Return the pair `(mean, std)` of all elements of `x` (via `np.mean` and `np.std`).
    '''
    average = np.mean(x)
    spread = np.std(x)
    return (average, spread)

102 

103 

def fontdict(family, weight, size):
    ''' Build a font-property dict with keys `family`, `weight` and `size` (e.g. for Matplotlib `fontdict=`).
    '''
    return dict(family=family, weight=weight, size=size)

106 

107 

class ImageGrid:
    ''' Lay out `nrow` * `ncol` 2D images in a gap-free grid, with optional row/column titles,
        a per-image preprocessor and enlarged crops pasted into each image.

        Usage: construct, optionally call `setTitles` / `setPreprocessor` / `setFontdict` /
        `setCrops`, then call `fig()` to render and obtain the Matplotlib figure.
    '''
    subimgs: List[TwoD]
    nrow: int
    ncol: int
    # Figure and axes grid created by the most recent `fig()` call.
    figure: matplotlib.figure.Figure = None
    grid: MplImageGrid = None

    # titles
    rowTitles: List[str] = None
    colTitles: List[str] = None
    # preprocessor, called as `fn(index, subimg)`; its return value replaces the subimage
    preprocessor: Callable = None
    # fontdict used for the row and column titles
    fontdict: Dict = None
    # crops: one `(row, col, hw)` entry per grid row (a falsy entry disables the crop)
    crops: List[Tuple[int, int, int]] = None
    cropLoc: str = 'bottom left'
    cropSize: int = 96 * 2

    def __init__(self, subimgs: List[TwoD], nrow: int, ncol: int) -> None:
        ''' Initialize the ImageGrid with `subimgs` in `nrow` * `ncol` layout.
            The subimages fill the grid row by row.
        '''
        cripAssert(len(subimgs) == nrow * ncol, 'Number of subimages should be equal to `nrow * ncol`.')
        self.subimgs = subimgs
        self.nrow = nrow
        self.ncol = ncol

    def setTitles(self, rowTitles: List[str], colTitles: List[str]):
        ''' Set the row titles (drawn as y-labels of the first column) and column titles
            (drawn as axes titles, indexed by flat image position).
        '''
        self.rowTitles = rowTitles
        self.colTitles = colTitles

    def setPreprocessor(self, fn: Callable):
        ''' Set the preprocessor for the subimages.
            A preprocessor is a function that takes the index of a subimage and the subimage and returns a new one.
        '''
        self.preprocessor = fn

    def setFontdict(self, fontdict: Dict):
        ''' Set the fontdict for the title texts in the figure.
        '''
        self.fontdict = fontdict

    def setCrops(self, crops, cropLoc='bottom left', cropSize=96 * 2):
        ''' Configure zoomed crops.

            `crops` holds one `(row, col, hw)` square-crop spec per grid row; a falsy entry
            disables the crop for that row. The crop is resized to `cropSize` x `cropSize`
            pixels and pasted at `cropLoc` (only `'bottom left'` is supported).
        '''
        self.crops = crops
        self.cropLoc = cropLoc
        self.cropSize = cropSize

    @staticmethod
    def _hideAxisDecorations(ax):
        ''' Remove the ticks and hide all spines of `ax`. '''
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    def _overlayPatch(self, img, patch, loc):
        ''' Paste `patch` into a corner of `img` IN PLACE and return a yellow outline
            `Rectangle` around the pasted region (not yet added to any axes).

            Only `loc == 'bottom left'` is supported; other values fail via `cripAssert`.
        '''
        if loc == 'bottom left':
            ph, pw = patch.shape[0], patch.shape[1]
            cripAssert(ph <= img.shape[0] and pw <= img.shape[1], 'Crop patch should not be larger than the image.')
            img[-ph:, :pw] = patch
            # Rectangle anchor is (x, y) = (column, row) of the patch's top-left corner.
            return matplotlib.patches.Rectangle((0, img.shape[0] - ph),
                                                pw,
                                                ph,
                                                linewidth=1,
                                                edgecolor='yellow',
                                                facecolor='none')
        else:
            cripAssert(False, 'Currently only loc at `bottom left` is supported.')

    def fig(self) -> matplotlib.figure.Figure:
        ''' Execute the plot and return the figure.

            The figure and axes grid are also stored on `self.figure` and `self.grid`.
            `self.subimgs` is left unmodified, so `fig()` may be called repeatedly.
        '''
        # Preprocess into a local list so stored subimages are never re-processed or mutated.
        subimgs = list(self.subimgs)
        if self.preprocessor is not None:
            subimgs = [self.preprocessor(i, img) for i, img in enumerate(subimgs)]

        # Create the figure (2 inches per cell) with no padding between axes.
        self.figure = plt.figure(figsize=(self.ncol * 2, self.nrow * 2))
        self.grid = MplImageGrid(self.figure, 111, nrows_ncols=(self.nrow, self.ncol), axes_pad=0)

        for cur, (ax, img) in enumerate(zip(self.grid, subimgs)):
            self._hideAxisDecorations(ax)

            # Prepare the image crop; paste into a copy so the caller's array is untouched.
            box = None
            crop = self.crops[cur // self.ncol] if self.crops is not None else None
            if crop:
                r, c, hw = crop
                patch = resizeTo(zoomIn(img, r, c, hw, hw), (self.cropSize, self.cropSize))
                img = np.array(img, copy=True)
                box = self._overlayPatch(img, patch, self.cropLoc)

            ax.imshow(img, cmap='gray', vmin=0, vmax=1)

            if box is not None:
                ax.add_patch(box)
                # In the first column, also outline the source region of the crop.
                if cur % self.ncol == 0:
                    ax.add_patch(
                        matplotlib.patches.Rectangle((c, r), hw, hw, linewidth=0.8, edgecolor='yellow',
                                                     facecolor='none'))

            # Column titles are indexed by flat position, so `ncol` titles land on the first row.
            if self.colTitles and cur < len(self.colTitles):
                ax.set_title(self.colTitles[cur], loc='center', fontdict=self.fontdict)

        # Row titles become y-labels of the first axes in each row.
        if self.rowTitles:
            for i, title in enumerate(self.rowTitles[:self.nrow]):
                self.grid[self.ncol * i].set_ylabel(title, fontdict=self.fontdict)

        return self.figure

217 

218 

def plotSpectrum(ax: matplotlib.axes.Axes, spec: Spectrum):
    ''' Draw `spec.omega` against `DiagEnergyRange` as a black line on `ax` and label both axes.
    '''
    ax.plot(DiagEnergyRange, spec.omega, 'k')
    ax.set_xlabel('Energy (keV)')
    ax.set_ylabel('Omega (a.u.)')

228 

229 

def plotMu(ax: matplotlib.axes.Axes, atten: Atten, startEnergy: int = 1, logScale=True):
    ''' Plot the LACs (`atten.mu`) of `atten` against energy on `ax`, using a log-scaled
        y axis if `logScale` is true.

        `startEnergy` is used as an index into both `DiagEnergyRange` and `atten.mu`; it matches
        a keV value only if the range starts at 0 keV in 1 keV steps (NOTE(review): confirm).
    '''
    x = list(DiagEnergyRange)[startEnergy:]
    ax.plot(x, atten.mu[startEnergy:])

    if logScale:
        ax.set_yscale('log')

    # `Axes` has no `xlabel`/`ylabel` methods (those are `pyplot` functions); use the setters.
    ax.set_xlabel('Energy (keV)')
    ax.set_ylabel('LAC (1/mm)')

241 

242 

def savefigTight(fig, path, dpi=200, pad=0.05):
    ''' Run `tight_layout` on `fig`, then save it to `path` cropped to its tight bounding box.

        `dpi` sets the output resolution and `pad` is forwarded as `pad_inches`.
    '''
    saveOptions = {'dpi': dpi, 'bbox_inches': 'tight', 'pad_inches': pad}
    fig.tight_layout()
    fig.savefig(path, **saveOptions)