Coverage for src/algolib/numerics/trig_pure.py: 69%

275 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-20 19:37 +0000

1# src/algolib/numerics/trig_pure.py 

2r""" 

3Numerical trig functions: sin / cos / tan 

4 

5- No stdlib math usage in the numerical implementation (only the backend wrapper at the bottom calls math.isfinite) 

6- Cody–Waite style range reduction by π/2 

7- Polynomial approximation on [-π/4, π/4] 

8""" 

9 

10from algolib.numerics.constants import ( 

11 INV_PI_2_HI, INV_PI_2_LO, 

12 PI2_HI, PI2_MID, PI2_LO, 

13 PI_4 

14) 

15 

# pi/4 in three pieces, obtained by halving each of the imported PI2_* pieces.
PI4_HI, PI4_MID, PI4_LO = PI2_HI * 0.5, PI2_MID * 0.5, PI2_LO * 0.5

# Veltkamp splitting constant, 2**27 + 1.
_SPLIT = 134217729.0

24 

def _split(a: float) -> tuple[float, float]:
    """Split ``a`` into ``(ah, al)`` with ``ah + al == a`` (Veltkamp/Dekker split).

    ``ah`` keeps the leading bits of ``a`` and ``al`` the remainder, so that
    products of the halves can be formed without rounding error.

    For ``|a| > 2**996`` the product ``C_SPLIT * a`` would overflow to
    infinity and both parts would come out as NaN.  Such inputs are scaled
    down by ``2**-28`` (exact), split, and the parts scaled back up.

    NOTE: for values within a hair of the largest finite double the high part
    can still round up past the representable range; inf/NaN inputs yield NaN
    parts, as before.
    """
    C_SPLIT = 134217729.0  # 2**27 + 1
    SPLIT_THRESH = 2.0 ** 996  # largest magnitude for which C_SPLIT * a stays finite

    if a > SPLIT_THRESH or a < -SPLIT_THRESH:
        # Power-of-two scaling is exact, so the split is that of `a` itself.
        scaled = a * 2.0 ** -28
        c = C_SPLIT * scaled
        ah = c - (c - scaled)
        al = scaled - ah
        return ah * 2.0 ** 28, al * 2.0 ** 28

    c = C_SPLIT * a
    ah = c - (c - a)
    al = a - ah
    return ah, al

31 

def _two_sum(a: float, b: float) -> tuple[float, float]:
    """Return ``(s, e)`` where ``s`` is the rounded sum ``a + b`` and ``e`` its rounding error."""
    total = a + b
    b_virtual = total - a          # the part of `total` that came from b
    a_virtual = total - b_virtual  # the part of `total` that came from a
    a_roundoff = a - a_virtual
    b_roundoff = b - b_virtual
    return total, a_roundoff + b_roundoff

37 

def _two_prod(a: float, b: float) -> tuple[float, float]:
    """Return ``(p, err)`` where ``p`` is the rounded product ``a * b`` and ``err`` its rounding error.

    Both operands are split into halves with ``_split`` and the partial
    products are accumulated against ``p`` (Dekker's product).
    """
    product = a * b
    a_hi, a_lo = _split(a)
    b_hi, b_lo = _split(b)
    # Accumulate from the largest partial product down to the smallest.
    residual = a_hi * b_hi - product
    residual = residual + a_hi * b_lo
    residual = residual + a_lo * b_hi
    residual = residual + a_lo * b_lo
    return product, residual

44 

def _compensated_div(n: float, d: float) -> float:
    """Compute ``n / d`` with two rounds of residual correction.

    Scheme:
        q0 = n / d
        r  = n - q0*d      (product taken with ``_two_prod`` so the residual is accurate)
        q1 = q0 + r / d
    followed by one more residual correction of the same form.

    A quotient that is infinite, NaN or zero is returned as is: the residual
    correction cannot improve it, and for an overflowed quotient it would turn
    the infinity into NaN.

    Raises:
        ZeroDivisionError: if ``d`` is zero (from the plain float division).
    """
    q0 = n / d
    # `q0 - q0` is 0.0 for every finite value and NaN for inf/NaN.
    if q0 == 0.0 or q0 - q0 != 0.0:
        return q0

    # First correction.
    p, pe = _two_prod(q0, d)  # p ~= q0*d, pe is the rounding error of that product
    r = (n - p) - pe
    q = q0 + r / d

    # Second correction (usually tiny, but it tames the hard cases).
    p2, pe2 = _two_prod(q, d)
    r2 = (n - p2) - pe2
    return q + r2 / d

63 

def _dd_norm(h: float, l: float):
    """Renormalize the double-double ``(h, l)`` so the low part is at most half an ulp of the high part."""
    hi, lo = _two_sum(h, l)
    return hi, lo

67 

def _dd_from_three(a_hi: float, a_mid: float, a_lo: float):
    """Collapse a three-piece value ``a_hi + a_mid + a_lo`` into a normalized ``(H, L)`` pair."""
    upper, upper_err = _two_sum(a_hi, a_mid)
    head, head_err = _two_sum(upper, a_lo)
    return _dd_norm(head, upper_err + head_err)

73 

# Normalized double-double (high, low) form of the three-piece pi/4 constant.
PI4_H, PI4_L = _dd_from_three(a_hi=PI4_HI, a_mid=PI4_MID, a_lo=PI4_LO)

75 

def _floor(x: float) -> int:
    """Return the largest integer ``<= x`` without using ``math.floor``."""
    truncated = int(x)  # rounds toward zero
    # Truncation only disagrees with floor for negative non-integers.
    if x < 0.0 and truncated != x:
        return truncated - 1
    return truncated

80 

def _round_nearest_even_dd(yh: float, yl: float) -> int:
    """Round the double-double ``yh + yl`` to the nearest integer, ties to even.

    The floor of ``yh`` is the base; the remaining fraction, including the low
    part, then decides whether to step up, step down, or stay.
    """
    base = _floor(yh)
    frac = (yh - base) + yl  # leftover fraction, low part included
    base_is_odd = (base & 1) == 1

    if frac > 0.5:
        return base + 1
    if frac == 0.5 and base_is_odd:
        return base + 1
    if frac < -0.5:
        return base - 1
    if frac == -0.5 and base_is_odd:
        return base - 1
    return base

92 

93 

# Non-finite sentinels, built without the math module.
_INF, _NAN = float("inf"), float("nan")

96 

def _nearest_int(y: float) -> int:
    """Round ``y`` to the nearest integer, with exact halves going to the even neighbour."""
    whole = int(y)  # truncates toward zero
    frac = y - whole
    is_odd = (whole & 1) == 1

    if y >= 0.0:
        step_up = frac > 0.5 or (frac == 0.5 and is_odd)
        return whole + 1 if step_up else whole

    step_down = frac < -0.5 or (frac == -0.5 and is_odd)
    return whole - 1 if step_down else whole

108 

def _is_finite(x: float) -> bool:
    """Return True unless ``x`` is NaN or an infinity (no math module needed)."""
    # NaN is the only value that is not equal to itself; the infinities are
    # recognised by direct comparison.
    return x == x and x != _INF and x != -_INF

115 

def _round_half_even(y: float) -> int:
    """Banker's rounding: nearest integer, exact ``.5`` cases go to the even neighbour."""
    k = int(y)  # truncates toward zero
    frac = y - k

    # Work in the direction of y's sign so one rule covers both halves.
    step = 1 if y >= 0.0 else -1
    half = 0.5 * step
    beyond_half = frac > half if step > 0 else frac < half

    if beyond_half or (frac == half and (k & 1) == 1):
        k += step
    return k

127 

128 

# 131072; NOTE(review): not referenced by any code visible in this file.
_BIG_ARG = 2 ** 17

130 

131 

132# ---- helpers to add/remove one (pi/2) in double-double ---- 

def _dd_add(a_hi: float, a_lo: float, b: float):
    """Add the plain float ``b`` to the double-double ``(a_hi, a_lo)``; returns a renormalized pair."""
    head, head_err = _two_sum(a_hi, b)
    tail, tail_err = _two_sum(a_lo, head_err)
    # Final two_sum renormalizes the result.
    return _two_sum(head, tail + tail_err)

137 

def _dd_add_pi2(r_hi: float, r_lo: float, sign: int):
    """Add ``sign * (PI2_HI + PI2_MID + PI2_LO)`` to the double-double ``(r_hi, r_lo)``.

    Order of operations: combine the high parts, fold the middle piece into the
    low part, merge the two sums, then gather every rounding error together
    with the last piece in one tail term before renormalizing.
    """
    step_hi = sign * PI2_HI
    step_mid = sign * PI2_MID
    step_lo = sign * PI2_LO

    high, high_err = _two_sum(r_hi, step_hi)
    low, low_err = _two_sum(r_lo, step_mid)
    merged, merged_err = _two_sum(high, low)

    tail = high_err + low_err + merged_err + step_lo
    return _dd_norm(merged, tail)

145 

def _dd_sub(a_hi: float, a_lo: float, b_hi: float, b_lo: float):
    """Return the double-double difference ``(a_hi + a_lo) - (b_hi + b_lo)`` as a ``_two_sum`` pair."""
    head, head_err = _two_sum(a_hi, -b_hi)
    tail, tail_err = _two_sum(a_lo, -b_lo)
    return _two_sum(head, tail + tail_err)

150 

151# ---------------- range reduction with post-correction ---------------- 

# Lexicographic comparison: is (a_hi + a_lo) greater than (b_hi + b_lo)?
def _dd_gt(a_hi: float, a_lo: float, b_hi: float, b_lo: float) -> bool:
    """Compare two double-doubles: high parts first, low parts only when neither high part is larger."""
    return a_hi > b_hi or (not a_hi < b_hi and a_lo > b_lo)

157 

def _dd_lt(a_hi: float, a_lo: float, b_hi: float, b_lo: float) -> bool:
    """Lexicographic test for ``(a_hi + a_lo) < (b_hi + b_lo)``: high parts first, then low parts."""
    return a_hi < b_hi or (not a_hi > b_hi and a_lo < b_lo)

162 

163 

164 

165 

166 

# -------- Range reduction: round-to-even, then post-correct r back into [-pi/4, pi/4] --------
def _reduce_pi2(x: float):
    """Reduce ``x`` modulo pi/2.

    Returns ``(q, r_hi, r_lo)`` where ``q = k & 3`` is the quadrant index and
    ``r_hi + r_lo`` is the double-double remainder ``x - k*(pi/2)``.

    Steps:
      1. ``y = x * (2/pi)`` in double-double, using the two-piece inverse constant.
      2. ``k`` = ``y`` rounded to nearest, ties to even.
      3. ``r = x - k*(pi/2)``, subtracting the three pieces of pi/2 one at a
         time with error-free products and sums.
      4. Post-correction: while ``|r|`` exceeds pi/4 (at most twice per side),
         shift ``k`` by one and move ``r`` by one pi/2.
    """
    # 1) y = x * (2/pi) as a double-double.
    prod_hi, prod_hi_err = _two_prod(x, INV_PI_2_HI)
    prod_lo, prod_lo_err = _two_prod(x, INV_PI_2_LO)
    lead, lead_err = _two_sum(prod_hi, prod_lo)  # lead ~= high part of y
    mul_err = prod_hi_err + prod_lo_err
    yh, carry = _two_sum(lead, mul_err)          # push the product errors into the high part
    yl = lead_err + carry                        # everything else goes to the low part
    yh, yl = _dd_norm(yh, yl)

    # 2) Nearest integer, ties to even, taken on the collapsed sum.
    k = _round_half_even(yh + yl)
    kf = float(k)

    # 3) r = x - k*(pi/2), one piece of pi/2 at a time.
    piece, piece_err = _two_prod(kf, PI2_HI)
    rem, sum_err = _two_sum(x, -piece)
    err = (-piece_err) + sum_err
    for part in (PI2_MID, PI2_LO):
        piece, piece_err = _two_prod(kf, part)
        rem, sum_err = _two_sum(rem, -piece)
        err += (-piece_err) + sum_err

    r_hi, r_lo = _two_sum(rem, err)

    def _beyond_quarter_pi(h: float, l: float) -> bool:
        # True when (h + l) - pi/4 is strictly positive in double-double.
        d_hi, d_lo = _dd_sub(h, l, PI4_H, PI4_L)
        return d_hi > 0.0 or (d_hi == 0.0 and d_lo > 0.0)

    # 4) Post-correction toward [-pi/4, pi/4].
    if _beyond_quarter_pi(r_hi, r_lo):
        k += 1
        r_hi, r_lo = _dd_add_pi2(r_hi, r_lo, -1)
        if _beyond_quarter_pi(r_hi, r_lo):
            k += 1
            r_hi, r_lo = _dd_add_pi2(r_hi, r_lo, -1)
    elif _beyond_quarter_pi(-r_hi, -r_lo):
        k -= 1
        r_hi, r_lo = _dd_add_pi2(r_hi, r_lo, +1)
        if _beyond_quarter_pi(-r_hi, -r_lo):
            k -= 1
            r_hi, r_lo = _dd_add_pi2(r_hi, r_lo, +1)

    return k & 3, r_hi, r_lo

217 

218# ----------------------------- 

219# ----------------------------- 

220# Polynomial kernels on small |r| 

221# ----------------------------- 

# Polynomial coefficients (double precision).  The values below are the
# Taylor-series terms (-1)^n / (2n+1)! for sin and (-1)^n / (2n)! for cos.
# sin(r) ~= r + r^3 * (S1 + r^2 * (S2 + ... + r^2 * S7))
_S1 = -1.66666666666666666666666666667e-01  # -1/3!
_S2 = 8.33333333333333333333333333333e-03  # +1/5!
_S3 = -1.98412698412698412698412698413e-04  # -1/7!
_S4 = 2.75573192239858906525573192240e-06  # +1/9!
_S5 = -2.50521083854417187750521083854e-08  # -1/11!
_S6 = 1.60590438368216145993923771702e-10  # +1/13!
_S7 = -7.64716373181981647590113198579e-13  # -1/15!

# cos(r) ~= 1 + r^2 * (C1 + r^2 * (C2 + ... + r^2 * C8))
_C1 = -5.00000000000000000000000000000e-01  # -1/2!
_C2 = 4.16666666666666666666666666667e-02  # +1/4!
_C3 = -1.38888888888888888888888888889e-03  # -1/6!
_C4 = 2.48015873015873015873015873016e-05  # +1/8!
_C5 = -2.75573192239858906525573192240e-07  # -1/10!
_C6 = 2.08767569878680989792100903212e-09  # +1/12!
_C7 = -1.14707455977297247138516979787e-11  # -1/14!
_C8 = 4.77947733238738529743820749112e-14  # +1/16!
# Note: fdlibm's minimax C7 is about -1.13596475577881948265e-11, which differs
# from the Taylor value used here.

243 

244 

def _sin_kernel(r: float) -> float:
    """Evaluate the sine polynomial ``r + r^3 * P(r^2)`` with Horner's rule over ``_S7 .. _S1``."""
    z = r * r
    poly = _S7
    for coeff in (_S6, _S5, _S4, _S3, _S2, _S1):
        poly = poly * z + coeff
    return r + r * z * poly

249 

def _cos_kernel(r: float) -> float:
    """Evaluate the cosine polynomial ``1 + r^2 * P(r^2)`` with Horner's rule over ``_C8 .. _C1``."""
    z = r * r
    poly = _C8
    for coeff in (_C7, _C6, _C5, _C4, _C3, _C2, _C1):
        poly = poly * z + coeff
    return 1.0 + z * poly

254 

255 

# 2**-40 ~= 9.09e-13: threshold below which a residual counts as "tiny".
_STICKY = 2.0 ** -40

257 

def _sin_cos_dd(r_hi: float, r_lo: float):
    """Return ``(sin(r), cos(r))`` for the double-double ``r = r_hi + r_lo``.

    The polynomial kernels are evaluated at ``a = r_hi``; the low part
    ``b = r_lo`` enters through the angle-addition formulas with short Taylor
    expansions in ``b``:

        sin(a+b) ~= sin a * (1 - b^2/2 + b^4/24) + cos a * (b - b^3/6 + b^5/120)
        cos(a+b) ~= cos a * (1 - b^2/2 + b^4/24) - sin a * (b - b^3/6 + b^5/120)
    """
    sin_a = _sin_kernel(r_hi)
    cos_a = _cos_kernel(r_hi)

    # With an exactly-zero low part there is nothing to correct.
    if r_lo == 0.0:
        return sin_a, cos_a

    b = r_lo
    b_sq = b * b
    b_cu = b_sq * b
    b_4th = b_sq * b_sq
    b_5th = b_4th * b

    # Truncated Taylor series of sin(b) and cos(b).
    sin_b = b - (1.0 / 6.0) * b_cu + (1.0 / 120.0) * b_5th
    cos_b = 1.0 - 0.5 * b_sq + (1.0 / 24.0) * b_4th

    sin_r = sin_a * cos_b + cos_a * sin_b
    cos_r = cos_a * cos_b - sin_a * sin_b
    return sin_r, cos_r

286 

287 

# --- 3) Top-level sin/cos/tan: range reduction plus low-part compensation ---
# The polynomial kernels (_sin_kernel / _cos_kernel) are used unchanged here.

def sin(x: float) -> float:
    """Return the sine of ``x``; NaN for NaN or infinite input.

    ``x`` is reduced to a quadrant ``q`` and a remainder ``r``; the result is
    ``sin r``, ``cos r``, ``-sin r`` or ``-cos r`` for ``q = 0, 1, 2, 3``.
    """
    if not _is_finite(x):
        return _NAN

    q, a, b = _reduce_pi2(x)

    # "Stick to zero": for |r| < 2**-40 use sin r ~= r and cos r ~= 1, which
    # avoids 1e-12-level glitches on periodic test inputs.
    tiny = 9.094947017729282e-13  # 2**-40
    r_sum = a + b
    if -tiny < r_sum < tiny:
        return (r_sum, 1.0, -r_sum, -1.0)[q]

    s, c = _sin_cos_dd(a, b)
    return (s, c, -s, -c)[q]

314 

def cos(x: float) -> float:
    """Return the cosine of ``x``; NaN for NaN or infinite input.

    ``x`` is reduced to a quadrant ``q`` and a remainder ``r``; the result is
    ``cos r``, ``-sin r``, ``-cos r`` or ``sin r`` for ``q = 0, 1, 2, 3``.
    """
    if not _is_finite(x):
        return _NAN

    q, a, b = _reduce_pi2(x)

    # For |r| < 2**-40 use sin r ~= r and cos r ~= 1.
    tiny = 9.094947017729282e-13  # 2**-40
    r_sum = a + b
    if -tiny < r_sum < tiny:
        return (1.0, -r_sum, -1.0, r_sum)[q]

    s, c = _sin_cos_dd(a, b)
    return (c, -s, -c, s)[q]

335 

# Coefficients for the tangent polynomial; they begin with the Maclaurin
# terms 1/3, 2/15, 17/315, 62/2835, ...
_T1 = 0.33333333333333333333
_T2 = 0.13333333333333333333
_T3 = 0.053968253968253968254
_T4 = 0.021869488536155202822
_T5 = 0.0088632355299021965689
_T6 = 0.0035921280365724810169
_T7 = 0.0014558343870513182682
_T8 = 0.00059002744094558598138
_T9 = 0.00023912911424355248149
_T10 = 0.000096915379569294503256
_T11 = 0.000039278323883316834053

347 

def _tan_kernel(r: float) -> float:
    """Evaluate ``tan(r) ~= r + r^3 * (T1 + z*(T2 + ...))`` with ``z = r^2`` (Horner over ``_T11 .. _T1``)."""
    z = r * r
    poly = _T11
    for coeff in (_T10, _T9, _T8, _T7, _T6, _T5, _T4, _T3, _T2, _T1):
        poly = poly * z + coeff
    return r + r * z * poly

354 

def _refined_inv(z: float) -> float:
    """Return ``1/z`` refined by two Newton steps ``inv *= (2 - z*inv)``.

    In each step the product ``z*inv`` and the subtraction from 2 are taken
    with ``_two_prod`` / ``_two_sum`` so their rounding errors can be folded
    back into the correction factor.

    Raises:
        ZeroDivisionError: if ``z`` is zero (from the initial float division).
    """
    inv = 1.0 / z
    for _ in range(2):
        prod, prod_err = _two_prod(z, inv)          # prod ~= z*inv
        factor, factor_err = _two_sum(2.0, -prod)   # factor ~= 2 - prod
        factor += -prod_err + factor_err            # fold both rounding errors back in
        inv *= factor
    return inv

373 

def tan(x: float) -> float:
    """Return the tangent of ``x``; NaN for NaN or infinite input.

    After reduction to quadrant ``q`` and remainder ``r``, even quadrants give
    ``sin r / cos r`` and odd quadrants give ``-cos r / sin r``; both
    quotients are taken with the compensated division.
    """
    if not _is_finite(x):
        return _NAN

    q, r_hi, r_lo = _reduce_pi2(x)
    s, c = _sin_cos_dd(r_hi, r_lo)

    if q & 1:
        # Odd quadrant: tan = -c / s
        return -_compensated_div(c, s)
    # Even quadrant: tan = s / c
    return _compensated_div(s, c)

387 

388# --- backend thin wrapper for set_backend("pure") --- 

389 

390import math as _math 

391 

# Keep references to the module-level functions so the backend methods of the
# same names below do not clash with them.
_sin_impl, _cos_impl, _tan_impl = sin, cos, tan

396 

class PureTrigBackend:
    """Expose the pure-Python sin/cos/tan implementations as a backend object.

    Each method converts its argument with ``float()``; arguments that cannot
    be converted, or that convert to NaN/infinity, yield NaN.
    """

    name = "pure"

    @staticmethod
    def _coerce(x):
        """Return ``float(x)`` when that succeeds and is finite, else ``None``."""
        try:
            value = float(x)
        except Exception:
            # Best-effort by design: any conversion failure maps to NaN.
            return None
        return value if _math.isfinite(value) else None

    def sin(self, x):
        """Sine of ``x`` via the pure implementation, or NaN for unusable input."""
        value = self._coerce(x)
        return float("nan") if value is None else _sin_impl(value)

    def cos(self, x):
        """Cosine of ``x`` via the pure implementation, or NaN for unusable input."""
        value = self._coerce(x)
        return float("nan") if value is None else _cos_impl(value)

    def tan(self, x):
        """Tangent of ``x`` via the pure implementation, or NaN for unusable input."""
        value = self._coerce(x)
        return float("nan") if value is None else _tan_impl(value)

427 

# Export explicitly so a lazy backend registration can import the class.
__all__ = list(globals().get("__all__", []))
__all__.append("PureTrigBackend")