Coverage for crip\utils.py: 83%

112 statements  

« prev     ^ index     » next       coverage.py v7.5.2, created at 2024-07-16 01:15 +0800

1''' 

2 Utilities of crip. 

3 

4 https://github.com/SEU-CT-Recon/crip 

5''' 

6 

7import os 

8import logging 

9import math 

10import numpy as np 

11import functools 

12 

13from ._typing import * 

14 

15 

def readFileText(path_: str, encoding=None) -> str:
    ''' Return the whole content of the text file at `path_`.

        `encoding` is forwarded to `open`; None selects the platform default.
    '''
    with open(path_, mode='r', encoding=encoding) as handle:
        text = handle.read()
    return text

23 

24 

### Exception ###

26 

27 

class CripException(BaseException):
    ''' The single exception type raised by crip's own checks (e.g. `cripAssert`).

        NOTE(review): this derives from `BaseException`, not `Exception`, so a
        generic `except Exception:` handler will NOT catch it.
    '''

    def __init__(self, *args) -> None:
        # Forward every positional argument unchanged; they end up in `self.args`.
        super(CripException, self).__init__(*args)

34 

35 

def cripAssert(cond: Any, hint=''):
    ''' Raise `CripException(hint)` when `cond` is falsy; otherwise do nothing.

        This is the assertion helper used throughout crip.
    '''
    if cond:
        return
    raise CripException(hint)

41 

42 

def cripWarning(ensure: Any, hint='', stackTrace=False):
    ''' Emit `hint` through the root logger at WARNING level when `ensure` is falsy.

        `stackTrace` is passed as `stack_info` so the call stack can be attached
        to the log record.
    '''
    if ensure:
        return
    logging.warning(hint, stack_info=stackTrace)

48 

49 

50### Type check ### 

51 

52 

def ConvertListNDArray(f):
    ''' Decorator: before calling `f`, replace every argument (positional or
        keyword) that is a List[NDArray] with `np.array(...)` of that list.

        Every other argument is passed through untouched.
    '''

    def convert(value):
        # isListNDArray only inspects the first element of a non-empty list.
        return np.array(value) if isListNDArray(value) else value

    @functools.wraps(f)
    def fn(*args, **kwargs):
        positional = [convert(item) for item in args]
        named = {key: convert(val) for key, val in kwargs.items()}
        return f(*positional, **named)

    return fn

74 

75 

def asFloat(arr: NDArray) -> NDArray:
    ''' Return `arr` cast to `DefaultFloatDType`.

        Raises CripException when `arr` is not an NDArray.
    '''
    cripAssert(isType(arr, NDArray), '`arr` should be NDArray.')
    converted = arr.astype(DefaultFloatDType)
    return converted

82 

83 

def is1D(x: NDArray) -> bool:
    ''' Return True when `x` is an NDArray with exactly one dimension.
    '''
    if not isType(x, NDArray):
        return False
    return len(x.shape) == 1

88 

89 

def is2D(x: NDArray) -> bool:
    ''' Return True when `x` is an NDArray with exactly two dimensions.
    '''
    if not isType(x, NDArray):
        return False
    return len(x.shape) == 2

94 

95 

def is3D(x: NDArray) -> bool:
    ''' Return True when `x` is an NDArray with exactly three dimensions.
    '''
    if not isType(x, NDArray):
        return False
    return len(x.shape) == 3

100 

101 

def is2or3D(x: NDArray) -> bool:
    ''' Return True when `x` is an NDArray with either two or three dimensions.
    '''
    if is2D(x):
        return True
    return is3D(x)

106 

107 

def as3D(x: NDArray) -> NDArray:
    ''' Ensure `x` is a 3D ndarray.

        A 3D input is returned as-is; a 2D input gets a new leading axis
        (shape [H, W] -> [1, H, W]). Any other input raises CripException.
    '''
    # Give the failure a message like the other shape checks in this module,
    # instead of raising an empty CripException.
    cripAssert(is2or3D(x), '`x` should be 2D or 3D.')

    return x if is3D(x) else x[np.newaxis, ...]

114 

115 

def isInt(n) -> bool:
    ''' Check whether the number `n` has an integral value (3 and 3.0 both qualify).

        Infinite and NaN values are reported as non-integral (False). Without
        this, `math.floor` would raise OverflowError / ValueError on them.
    '''
    try:
        return math.floor(n) == n
    except (OverflowError, ValueError):
        # math.floor(inf) -> OverflowError, math.floor(nan) -> ValueError.
        return False

120 

121 

def isIntDtype(dtype) -> bool:
    ''' Return True when `dtype` is a (signed or unsigned) integer NumPy type.
    '''
    integerKind = np.integer
    return np.issubdtype(dtype, integerKind)

126 

127 

def isFloatDtype(dtype) -> bool:
    ''' Return True when `dtype` is a real floating-point NumPy type.
    '''
    floatingKind = np.floating
    return np.issubdtype(dtype, floatingKind)

132 

133 

def hasIntDtype(arr: NDArray) -> bool:
    ''' Return True when the element type of `arr` is an integer type.
    '''
    elementType = arr.dtype
    return isIntDtype(elementType)

138 

139 

def isType(x, t) -> bool:
    ''' Return True when `x` is exactly of type `t` or an instance of `t`.

        When `t` is `Callable` itself, the check is `callable(x)` instead.
    '''
    if t is Callable:
        return callable(x)
    if type(x) == t:
        return True
    return isinstance(x, t)

146 

147 

def isListNDArray(x) -> bool:
    ''' Return True when `x` is a non-empty list whose first element is an NDArray.

        Only the first element is inspected; the rest of the list is not checked.
    '''
    if not isType(x, list) or len(x) == 0:
        return False
    return isType(x[0], NDArray)

152 

153 

def isOfSameShape(a: NDArray, b: NDArray) -> bool:
    ''' Return True when the `.shape` of `a` equals the `.shape` of `b`.
    '''
    shapeA, shapeB = a.shape, b.shape
    return np.array_equal(shapeA, shapeB)

158 

159 

def getAsset(folder: str, prefix='_asset') -> str:
    ''' Return the absolute path `<dir of this module>/<prefix>/<folder>`.

        The path is only built, not checked for existence.
    '''
    moduleDir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(moduleDir, f'{prefix}/{folder}')

164 

165 

def convertEnergyUnit(arr: Or[NDArray, float], from_: str, to: str) -> Or[NDArray, float]:
    ''' Convert an energy value (scalar or array) from unit `from_` to unit `to`.

        Supported units: eV, keV, MeV. Raises CripException for any other unit.
    '''
    supported = ['eV', 'keV', 'MeV']
    cripAssert(from_ in supported and to in supported, f'Invalid unit: from_={from_}, to={to}.')

    kilo = 1000
    milli = 1 / kilo
    # factors[src, dst] multiplies a value in unit `src` to express it in unit `dst`.
    factors = np.array([
        [1, milli, milli**2],  # from eV
        [kilo, 1, milli],  # from keV
        [kilo**2, kilo, 1],  # from MeV
    ])
    src, dst = supported.index(from_), supported.index(to)

    return arr * factors[src, dst]

183 

184 

def convertLengthUnit(arr: Or[NDArray, float], from_: str, to: str) -> Or[NDArray, float]:
    ''' Convert a length value (scalar or array) from unit `from_` to unit `to`.

        Supported units: um, mm, cm, m. Raises CripException for any other unit.
    '''
    supported = ['um', 'mm', 'cm', 'm']
    cripAssert(from_ in supported and to in supported, f'Invalid unit: from_={from_}, to={to}.')

    # factors[src, dst] multiplies a value in unit `src` to express it in unit `dst`.
    factors = np.array([
        [1, 1e-3, 1e-4, 1e-6],  # from um
        [1e3, 1, 1e-1, 1e-3],  # from mm
        [1e4, 1e1, 1, 1e-2],  # from cm
        [1e6, 1e3, 1e2, 1],  # from m
    ])
    src, dst = supported.index(from_), supported.index(to)

    return arr * factors[src, dst]

201 

202 

def convertMuUnit(arr: Or[NDArray, float], from_: str, to: str) -> Or[NDArray, float]:
    ''' Convert an attenuation coefficient (inverse length) between units.

        Supported units: um-1, mm-1, cm-1, m-1. Raises CripException otherwise.
    '''
    supported = ['um-1', 'mm-1', 'cm-1', 'm-1']
    cripAssert(from_ in supported and to in supported, f'Invalid unit: from_={from_}, to={to}.')

    # An inverse length scales opposite to a length (1 mm-1 == 10 cm-1 while
    # 1 mm == 0.1 cm), so the source/target length units are deliberately swapped.
    targetLength = to.replace('-1', '')
    sourceLength = from_.replace('-1', '')
    return convertLengthUnit(arr, targetLength, sourceLength)

210 

211 

def convertConcentrationUnit(arr: Or[NDArray, float], from_: str, to: str) -> Or[NDArray, float]:
    ''' Convert a concentration value (scalar or array) from unit `from_` to unit `to`.

        Supported units: g/mL, mg/mL. Raises CripException for any other unit.
    '''
    supported = ['g/mL', 'mg/mL']
    cripAssert(from_ in supported and to in supported, f'Invalid unit: from_={from_}, to={to}.')

    # factors[src, dst] multiplies a value in unit `src` to express it in unit `dst`.
    factors = np.array([
        [1, 1000],  # from g/mL
        [1 / 1000, 1],  # from mg/mL
    ])
    src, dst = supported.index(from_), supported.index(to)

    return arr * factors[src, dst]

227 

228 

def getHnW(img: NDArray) -> Tuple[int, int]:
    ''' Return `(height, width)` of `img`, taken from its last two axes.

        Accepts [HW] or [CHW] input; anything else raises CripException.
    '''
    cripAssert(is2or3D(img), 'img should be 2D or 3D.')

    *_, height, width = img.shape
    return height, width

235 

236 

def nextPow2(x: int) -> int:
    ''' Get the smallest power of 2 that is >= `x`; `nextPow2(0)` returns 1.

        Integer inputs (Python or NumPy integers) are handled with exact bit
        arithmetic, so results stay correct above 2**53, where `math.log2`
        would first round the integer to a float. Non-integer inputs (e.g.
        floats) use the original `2**ceil(log2(x))` formula.

        Raises ValueError for negative input.
    '''
    if x == 0:
        return 1
    if isinstance(x, (int, np.integer)):
        n = int(x)
        if n < 0:
            raise ValueError(f'nextPow2 expects a non-negative value, got {x}.')
        # For n >= 1, (n - 1).bit_length() is the exponent of the next power of 2.
        return 1 << (n - 1).bit_length()
    return 2**math.ceil(math.log2(x))

241 

242 

def getAttrKeysOfObject(obj: object) -> List[str]:
    ''' Get the names of the data attributes of `obj`.

        A name is kept only if all of the following hold:
        - it is not also an attribute of a plain `object`;
        - it does not start with a double underscore;
        - its value is neither callable nor None.
        Names with a single leading underscore are kept.

        The result is sorted, so its order is deterministic. Iterating a set of
        strings is not deterministic, because string hashing is randomized per
        process.
    '''
    keys = []
    for name in sorted(set(dir(obj)) - set(dir(object))):
        if name.startswith('__'):
            continue
        # Read the attribute once; properties may be costly or have side effects.
        value = getattr(obj, name)
        if not callable(value) and value is not None:
            keys.append(name)

    return keys

252 

253 

def chw2hwc(img: ThreeD) -> ThreeD:
    ''' Move the leading (channel) axis of a 3D [CHW] array to the end, giving [HWC].

        Raises CripException for non-3D input.
    '''
    cripAssert(is3D(img), 'img should be 3D.')

    return np.moveaxis(img, source=0, destination=-1)

260 

261 

def hwc2chw(img: ThreeD) -> ThreeD:
    ''' Move the trailing (channel) axis of a 3D [HWC] array to the front, giving [CHW].

        Raises CripException for non-3D input.
    '''
    cripAssert(is3D(img), 'img should be 3D.')

    return np.moveaxis(img, source=-1, destination=0)

268 

269 

def simpleValidate(conds: List[bool]):
    ''' Check every condition in `conds` with `cripAssert`.

        Stops at the first falsy condition. It raises CripException, and the
        message names that condition's zero-based position. Any iterable of
        conditions is accepted, not only a list.
    '''
    for index, cond in enumerate(conds):
        cripAssert(cond, f'Condition {index} validation failed.')

275 

276 

def identity(x: Any) -> Any:
    ''' Return the argument itself, unchanged (the same object).
    '''
    result = x
    return result