Coverage for crip\plot.py: 52%
132 statements
« prev ^ index » next coverage.py v7.5.2, created at 2024-07-16 01:15 +0800
« prev ^ index » next coverage.py v7.5.2, created at 2024-07-16 01:15 +0800
1'''
2 Figure artist module of crip.
4 https://github.com/z0gSh1u/crip
5'''
7import cv2
8import numpy as np
9import matplotlib.pyplot as plt
10import matplotlib.figure
11import matplotlib.patches
12import matplotlib.axes
13from mpl_toolkits.axes_grid1 import ImageGrid as MplImageGrid
14from scipy.ndimage import uniform_filter
16from ._typing import *
17from .utils import *
18from .physics import Spectrum, DiagEnergyRange, Atten
19from .shared import resizeTo
# Display range [0, 1] as keyword arguments; presumably meant to be splatted into an
# imshow-like call, e.g. `ax.imshow(img, **VMIN0_VMAX1)` — confirm against callers.
VMIN0_VMAX1 = dict(vmin=0, vmax=1)
def smooth1D(data: NDArray, winSize: int = 5) -> NDArray:
    ''' Smooth a 1D array using a centered moving-average window of length `winSize`.

        Interior samples are plain `winSize`-point averages. Near both ends, where a full window
        does not fit, the window shrinks symmetrically (1, 3, 5, ... samples), so the output has
        the same length as `data`.
        The implementation is from https://stackoverflow.com/questions/40443020

        Args:
            data: 1D array with at least `winSize` samples.
            winSize: odd positive window length.

        Returns:
            The smoothed float array, same length as `data`.
    '''
    cripAssert(is1D(data), '`data` should be 1D array.')
    cripAssert(isInt(winSize) and winSize > 0 and winSize % 2 == 1, '`winSize` should be odd positive integer.')
    # With fewer samples than the window, 'valid' convolution and the edge terms no longer
    # line up and the output length would be wrong.
    cripAssert(len(data) >= winSize, '`data` should have at least `winSize` samples.')

    # Full-window averages: 'valid' yields len(data) - winSize + 1 samples.
    out0 = np.convolve(data, np.ones(winSize, dtype=int), 'valid') / winSize
    # Odd window lengths 1, 3, ..., winSize - 2 for the (winSize - 1) // 2 samples at each end.
    r = np.arange(1, winSize - 1, 2)
    start = np.cumsum(data[:winSize - 1])[::2] / r
    stop = (np.cumsum(data[:-winSize:-1])[::2] / r)[::-1]

    return np.concatenate((start, out0, stop))
def smoothZ(img: ThreeD, ksize=3) -> ThreeD:
    ''' Apply a `ksize`-wide moving-average (uniform) filter along the first (Z) axis of a 3D image.
        The kernel size is 1 along the other two axes, so they are not mixed. Borders are
        handled with `reflect` padding. A new array is returned.
    '''
    cripAssert(is3D(img), '`img` should be 3D array.')

    return uniform_filter(img, size=(ksize, 1, 1), mode='reflect')
def window(img: TwoOrThreeD,
           win: Tuple[float, float],
           style: str = 'lr',
           normalize: Or[str, None] = None) -> TwoOrThreeD:
    ''' Window `img` using `win` (WW, WL) with style `wwwl` or (left, right) with style `lr`.

        Pixel values are clipped to [left, right]. With style `wwwl`, left = WL - WW / 2 and
        right = left + WW.

        Set `normalize` to `0255` to convert to 8-bit image, or `01` to [0, 1] float image.
        Both map the window bounds linearly: left -> 0 and right -> 255 (or 1).
        A zero-width window (left == right) normalizes to all zeros.

        Args:
            img: 2D or 3D image.
            win: (WW, WL) or (left, right), depending on `style`.
            style: 'wwwl' or 'lr'.
            normalize: None (keep clipped values), '0255' or '01'.

        Returns:
            A new windowed array; `img` itself is copied before clipping, so it is not modified.
    '''
    cripAssert(len(win) == 2, '`win` should have length of 2.')
    cripAssert(style in ['wwwl', 'lr'], "`style` should be 'wwwl' or 'lr'")
    cripAssert(normalize in [None, '0255', '01'], "`normalize` should be None, '0255' or '01'")

    if style == 'wwwl':
        ww, wl = win
        l = wl - ww / 2
        r = l + ww
    else:
        l, r = win
    cripAssert(l <= r, 'The left bound of the window should not exceed the right bound.')

    res = asFloat(img.copy())
    res[res > r] = r
    res[res < l] = l

    if normalize is None:
        return res

    # Map [l, r] -> [0, 1]. A degenerate window has no range to stretch, so emit zeros
    # instead of dividing by zero.
    if r > l:
        unit = (res - l) / (r - l)
    else:
        unit = np.zeros_like(res)

    if normalize == '0255':
        # Scale to [0, 255]; astype truncates toward zero.
        return (unit * 255).astype(np.uint8)
    return unit
def windowFullRange(img: TwoOrThreeD, normalize='01') -> TwoOrThreeD:
    ''' Window `img` using its full dynamic range of pixel values, i.e. (min, max) in `lr` style.
        `normalize` is forwarded to `window` (`01` by default).
    '''
    # `lr` windows are (left, right), so the minimum must come first.
    return window(img, (np.min(img), np.max(img)), 'lr', normalize)
85def zoomIn(img: TwoD, row: int, col: int, h: int, w: int) -> TwoD:
86 ''' Crop a patch. `(row, col)` determines the left-top point. `(h, w)` gives height and width.
87 '''
88 return img[row:row + h, col:col + w]
91def stddev(img: TwoD) -> float:
92 ''' Compute the standard deviation of a image crop.
93 `(row, col)` determines the left-top point. (h, w) gives height and width.
94 '''
95 return np.std(img)
def meanstd(x: Any) -> Tuple[float, float]:
    ''' Return the pair `(mean, std)` of `x`, as computed by `np.mean` and `np.std`.
    '''
    mu = np.mean(x)
    sigma = np.std(x)
    return mu, sigma
def fontdict(family, weight, size):
    ''' Pack `family`, `weight` and `size` into a dict, e.g. for use as a `fontdict` argument.
    '''
    return dict(family=family, weight=weight, size=size)
class ImageGrid:
    ''' Lay out `nrow` * `ncol` 2D images on a gap-free grid.

        Optional extras are row/column titles, a per-image preprocessor and per-row
        zoomed-in crops. Typical use::

            grid = ImageGrid(imgs, nrow, ncol)
            grid.setTitles(rowTitles, colTitles)
            figure = grid.fig()
    '''
    subimgs: List[TwoD]
    nrow: int
    ncol: int
    # populated by `fig()`
    figure: matplotlib.figure.Figure
    grid: MplImageGrid

    # titles
    rowTitles: List[str] = None
    colTitles: List[str] = None
    # preprocessor: (index, subimage) -> new subimage
    preprocessor: Callable = None
    # fontdict for the title / label texts
    fontdict: Dict = None
    # crops: one (row, col, hw) entry (or a falsy value for "no crop") per grid row
    crops: List[Tuple[int, int, int]] = None
    cropLoc: str = 'bottom left'
    cropSize: int = 96 * 2

    def __init__(self, subimgs: List[TwoD], nrow: int, ncol: int) -> None:
        ''' Initialize the ImageGrid with `subimgs` in `nrow` * `ncol` layout.
            Subimages are laid out row by row.
        '''
        self.subimgs = subimgs
        self.nrow = nrow
        self.ncol = ncol
        cripAssert(len(subimgs) == nrow * ncol, 'Number of subimages should be equal to `nrow * ncol`.')

    def setTitles(self, rowTitles: List[str], colTitles: List[str]):
        ''' Set the row titles (y-labels of the first column) and the column titles
            (titles of the first images).
        '''
        self.rowTitles = rowTitles
        self.colTitles = colTitles

    def setPreprocessor(self, fn: Callable):
        ''' Set the preprocessor for the subimages.
            A preprocessor is a function that takes the index of a subimage and the subimage and returns a new one.
        '''
        self.preprocessor = fn

    def setFontdict(self, fontdict: Dict):
        ''' Set the fontdict for the texts in the figure.
        '''
        self.fontdict = fontdict

    def setCrops(self, crops, cropLoc='bottom left', cropSize=96 * 2):
        ''' Set zoomed-in crops, one entry per grid row.

            Each entry is `(row, col, hw)`: a `hw` x `hw` square whose top-left pixel is
            (`row`, `col`), or a falsy value to skip that row. The square is passed through
            `resizeTo` with target `(cropSize, cropSize)` and pasted at `cropLoc` of every image
            in that row. Only `bottom left` is supported.
        '''
        self.crops = crops
        self.cropLoc = cropLoc
        self.cropSize = cropSize

    def _overlayPatch(self, img, patch, loc):
        ''' Paste `patch` into `img` (in place) at `loc` and return a yellow Rectangle outlining it.
        '''
        cripAssert(loc == 'bottom left', 'Currently only loc at `bottom left` is supported.')
        ph, pw = patch.shape[0], patch.shape[1]
        img[-ph:, :pw] = patch
        return matplotlib.patches.Rectangle((0, img.shape[0] - ph),
                                            pw,
                                            ph,
                                            linewidth=1,
                                            edgecolor='yellow',
                                            facecolor='none')

    def _cropOf(self, row):
        ''' Return the crop entry for grid row `row`, or None if no crop is configured for it.
        '''
        if self.crops is None or row >= len(self.crops):
            return None
        return self.crops[row]

    @staticmethod
    def _hideAxisDecorations(ax):
        ''' Remove the ticks and hide every spine of `ax`.
        '''
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    def fig(self):
        ''' Execute the plot and return the figure.

            The figure and its `MplImageGrid` are also stored on `self.figure` and `self.grid`.
            `self.subimgs` is neither replaced nor modified, so `fig()` can be called again.
        '''
        # Preprocess into a local list so repeated calls do not preprocess twice.
        imgs = list(self.subimgs)
        if self.preprocessor is not None:
            imgs = [self.preprocessor(i, img) for i, img in enumerate(imgs)]

        self.figure = plt.figure(figsize=(self.ncol * 2, self.nrow * 2))
        self.grid = MplImageGrid(self.figure, 111, nrows_ncols=(self.nrow, self.ncol), axes_pad=0)

        for cur, (ax, img) in enumerate(zip(self.grid, imgs)):
            row, col = divmod(cur, self.ncol)
            self._hideAxisDecorations(ax)

            # Prepare the image crop.
            box = None
            crop = self._cropOf(row)
            if crop:
                r, c, hw = crop
                patch = resizeTo(zoomIn(img, r, c, hw, hw), (self.cropSize, self.cropSize))
                # Paste into a copy so the caller's (or the preprocessor's) array stays intact.
                img = np.array(img, copy=True)
                box = self._overlayPatch(img, patch, self.cropLoc)

            ax.imshow(img, cmap='gray', vmin=0, vmax=1)

            if box is not None:
                ax.add_patch(box)
                # Mark the crop's source region on the first image of the row only.
                if col == 0:
                    ax.add_patch(
                        matplotlib.patches.Rectangle((c, r),
                                                     hw,
                                                     hw,
                                                     linewidth=0.8,
                                                     edgecolor='yellow',
                                                     facecolor='none'))

            # Column titles go on the first `len(colTitles)` images.
            if self.colTitles and cur < len(self.colTitles):
                ax.set_title(self.colTitles[cur], loc='center', fontdict=self.fontdict)

        # Row titles become y-labels of the first image in each row.
        if self.rowTitles:
            for i, title in zip(range(self.nrow), self.rowTitles):
                self.grid[self.ncol * i].set_ylabel(title, fontdict=self.fontdict)

        return self.figure
def plotSpectrum(ax: matplotlib.axes.Axes, spec: Spectrum):
    ''' Draw `spec.omega` against `DiagEnergyRange` as a black line on `ax`, and label the axes
        'Energy (keV)' and 'Omega (a.u.)'.
    '''
    ax.plot(DiagEnergyRange, spec.omega, 'k')
    ax.set(xlabel='Energy (keV)', ylabel='Omega (a.u.)')
def plotMu(ax: matplotlib.axes.Axes, atten: Atten, startEnergy: int = 1, logScale=True):
    ''' Plot the LACs of `atten` from `startEnergy` keV in ax in `logScale` if true.

        The first `startEnergy` samples of `DiagEnergyRange` and `atten.mu` are skipped. This equals
        "from `startEnergy` keV" only if `DiagEnergyRange` starts at 0 keV in 1 keV steps
        (assumption — confirm in `physics`). `atten.mu` is presumably sampled on `DiagEnergyRange`.
    '''
    energies = list(DiagEnergyRange)[startEnergy:]
    ax.plot(energies, atten.mu[startEnergy:])

    if logScale:
        ax.set_yscale('log')

    # Axes objects expose set_xlabel/set_ylabel; xlabel/ylabel exist only in pyplot.
    ax.set_xlabel('Energy (keV)')
    ax.set_ylabel('LAC (1/mm)')
def savefigTight(fig, path, dpi=200, pad=0.05):
    ''' Apply `tight_layout` to `fig`, then save it to `path` at `dpi`. The saved image is
        cropped to the drawn content (`bbox_inches='tight'`) with `pad` inches of padding.
    '''
    saveOptions = {'dpi': dpi, 'bbox_inches': 'tight', 'pad_inches': pad}
    fig.tight_layout()
    fig.savefig(path, **saveOptions)