Coverage for pyrc \ visualization \ plot.py: 20%

303 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-01 13:11 +0200

1# ------------------------------------------------------------------------------- 

2# Copyright (C) 2026 Joel Kimmich, Tim Jourdan 

3# ------------------------------------------------------------------------------ 

4# License 

5# This file is part of PyRC, distributed under GPL-3.0-or-later. 

6# ------------------------------------------------------------------------------ 

7from __future__ import annotations 

8from datetime import datetime, timedelta 

9from typing import TYPE_CHECKING 

10 

11import numpy as np 

12from matplotlib import pyplot as plt 

13import matplotlib.dates as mdates 

14import matplotlib.ticker as ticker 

15 

16from pyrc.tools.plotting import format_date_x_axis, custom_numeric_ticks_formatter 

17from pyrc.tools.science import cm_to_inch, is_numeric 

18 

19if TYPE_CHECKING: 

20 pass 

21 

22 

23# plt.style.use('tableau-colorblind10') 

24# print(plt.style.available) 

25# load style sheet 

26# plt.style.use(os.path.normpath(os.path.join(package_dir, "visualization", "plotsettings.mplstyle"))) 

27# plt.rcParams["axes.prop_cycle"] = plt.cycler("color", ) 

28 

29 

class PlotMixin(object):
    """
    Common base of the plot classes in this module.

    Creates a single-axes matplotlib figure of the requested size (``self.fig``, ``self.ax``) and
    provides cycling color / line-style / marker iterators, bookkeeping of artist handles and legend
    labels (``self.lines``, ``self.labels``) as well as axis formatting, legend, show, save and close
    helpers.
    """

    def __init__(self, x, ys, y_title="Heat Flux / W", marker_size=6, x_title=None, width_mm=160, height_mm=90):
        """
        Parameters
        ----------
        x : sequence
            x values shared by all series. ``x[0]`` decides whether the x-axis is treated as a
            datetime axis.
        ys : sequence
            y data; how it is interpreted is up to the subclasses.
        y_title : str | None, optional
            Label of the y-axis. ``None`` leaves it unset.
        marker_size : int | float, optional
            Marker size used by subclasses that draw markers.
        x_title : str | None, optional
            Label of the x-axis. ``None`` leaves it unset.
        width_mm, height_mm : int | float, optional
            Figure size in millimetres.
        """
        self.fig, self.ax = plt.subplots(layout="constrained")
        # mm / 10 -> cm, converted by cm_to_inch for the inch-based set_size_inches
        self.fig.set_size_inches(cm_to_inch(width_mm / 10), cm_to_inch(height_mm / 10))

        self.x = x
        self.ys = ys
        self.x_is_datetime = isinstance(self.x[0], datetime)

        self.y_title = y_title
        self.x_title = x_title
        self.marker_size = marker_size

        # Handles of all registered artists (None until the first one is added) and their legend labels.
        self.lines = None
        self.labels = []

        self.next_color = self.color_iter()
        self.next_line_style = self.line_style_iter()
        self.next_marker = self.marker_iter()

    def __del__(self):
        # getattr: __init__ may have failed before ``fig`` was assigned.
        fig = getattr(self, "fig", None)
        if fig is not None:
            plt.close(fig)

    @property
    def markers(self) -> list[str]:
        """Marker symbols cycled through by ``marker_iter``."""
        return ["o", "s", "^", "v", "<", ">", "d", "p", "*", "h"]

    def marker_iter(self):
        """Return an endless iterator over ``self.markers``."""
        import itertools

        return itertools.cycle(self.markers)

    @property
    def colors(self) -> list[tuple[float, float, float]]:
        """RGB colors (floats in 0..1) cycled through by ``color_iter``."""
        return [
            (0.0051932, 0.098238, 0.34984),
            (0.98135, 0.80041, 0.98127),
            (0.51125, 0.5109, 0.1933),
            (0.1333, 0.37528, 0.3794),
            (0.94661, 0.61422, 0.41977),
            (0.066899, 0.26319, 0.37759),
            (0.9929, 0.70485, 0.70411),
            (0.30238, 0.45028, 0.30012),
            (0.75427, 0.56503, 0.21176),
            (0.40297, 0.48047, 0.24473),
        ]

    def line_style_iter(self):
        """
        Return an endless iterator over line styles.

        Each style is repeated ``len(self.colors)`` times before the next one follows. Used in
        lockstep with ``color_iter``, every color is therefore first combined with "-", then with
        "--", ":" and "-.".
        """
        import itertools

        styles = ["-", "--", ":", "-."]
        repeats = len(self.colors)
        return itertools.cycle([style for style in styles for _ in range(repeats)])

    def color_iter(self):
        """Return an endless iterator over ``self.colors`` (each yielded as a tuple)."""
        import itertools

        return itertools.cycle([tuple(color) for color in self.colors])

    def _add_line(self, line):
        """Append the artist list ``line`` (e.g. the return value of ``Axes.plot``) to ``self.lines``."""
        if self.lines is None:
            self.lines = line
        else:
            self.lines = self.lines + line

    def _add_line_and_label(self, line, label=None):
        """Register ``line`` like ``_add_line`` and append ``label`` to ``self.labels``."""
        self._add_line(line)
        self.labels.append(label)

    def format(self):
        """Set the axis titles, enable the grid, fit the x-limits to the data and format date x-axes."""
        if self.x_title is not None:
            self.ax.set_xlabel(self.x_title)
        if self.y_title is not None:
            self.ax.set_ylabel(self.y_title)
        self.ax.grid(True)
        self.ax.set_xlim(left=self.x[0], right=self.x[-1])
        if self.x_is_datetime:
            self.formate_x_datetime()

    def formate_x_datetime(self, start=None, end=None):
        """
        Format the x-axis as a date axis using ``format_date_x_axis``.

        Parameters
        ----------
        start : datetime-like | None, optional
            Start of the date range. Defaults to ``self.x[0]``.
        end : datetime-like | None, optional
            End of the date range. Defaults to ``self.x[-1]``.
        """
        if start is None:
            start = self.x[0]
        if end is None:
            end = self.x[-1]
        format_date_x_axis(start, end, self.ax, return_version=False)

    def format_numeric_ticks(self):
        """
        Apply ``custom_numeric_ticks_formatter`` to the tick labels of every axes in the figure.

        Axes that already use a ``DateFormatter`` are left untouched. The x-axis is skipped entirely
        when the x data are datetimes or strings. Formatting is best effort: a failure on one axis
        does not stop the others from being formatted.
        """
        formatter = ticker.FuncFormatter(custom_numeric_ticks_formatter)
        # The x data are shared by all axes, so this decision is loop-invariant.
        format_x = not self.x_is_datetime and not isinstance(self.x[0], str)

        for ax in self.fig.get_axes():
            axes_to_format = [ax.yaxis, ax.xaxis] if format_x else [ax.yaxis]
            for axis in axes_to_format:
                try:
                    if not isinstance(axis.get_major_formatter(), mdates.DateFormatter):
                        axis.set_major_formatter(formatter)
                except Exception:
                    # Deliberately best effort, but no longer swallows KeyboardInterrupt / SystemExit.
                    pass

    def show_legend(self, **kwargs):
        """
        Add a figure legend built from ``self.lines`` and ``self.labels``.

        By default the legend is placed outside the upper right corner with at most four columns;
        ``kwargs`` override these defaults and are passed on to ``Figure.legend``.
        """
        # ncols must be at least 1; self.labels is empty when no labelled artist was registered.
        initial_kwargs = {"loc": "outside upper right", "ncols": max(1, min(4, len(self.labels)))}
        initial_kwargs.update(kwargs)
        self.fig.legend(handles=self.lines, labels=self.labels, **initial_kwargs)

    def show(self):
        """Format the numeric tick labels and show the figure via ``plt.show()``."""
        self.format_numeric_ticks()
        plt.show()

    def save(self, path):
        """Format the numeric tick labels and save the figure to ``path`` at 600 dpi."""
        self.format_numeric_ticks()
        self.fig.savefig(path, dpi=600)

    def close(self):
        """Close the figure in pyplot."""
        if self.fig is not None:
            plt.close(self.fig)

182 

183 

class DoubleY(PlotMixin):
    """
    Marker plot with two y-axes that share one x-axis.

    ``ys_left`` are drawn on the left axis (``self.ax``) and ``ys_right`` on a twin axis
    (``self.ax_right``). Every series gets the next color and marker from the cycles of
    ``PlotMixin`` and is plotted without connecting lines.
    """

    def __init__(
        self,
        x,
        ys_left: list,
        ys_right: list,
        labels_left=None,
        labels_right=None,
        y_title_left="Heat Flux / W",
        y_title_right="Temperature / K",
        marker_size=6,
        **kwargs,
    ):
        """
        Parameters
        ----------
        x : sequence
            Shared x values (converted to a numpy array).
        ys_left, ys_right : list | array-like
            Series for the left / right axis. A non-list argument is treated as a single series.
        labels_left, labels_right : list | str | None, optional
            Legend labels per series. A single non-list label is wrapped in a list.
        y_title_left, y_title_right : str | None, optional
            Titles of the left and right y-axis.
        marker_size : int | float, optional
            Marker size of all series.
        **kwargs
            Forwarded to ``PlotMixin.__init__`` (e.g. ``x_title``, ``width_mm``, ``height_mm``).
        """
        if not isinstance(ys_left, list):
            ys_left = [ys_left]
        if not isinstance(ys_right, list):
            ys_right = [ys_right]
        ys_left = [np.array(y) for y in ys_left]
        ys_right = [np.array(y) for y in ys_right]

        super().__init__(x=np.array(x), ys=ys_left, y_title=y_title_left, marker_size=marker_size, **kwargs)

        self.ys_left = ys_left
        self.ys_right = ys_right
        self.y_title_right = y_title_right

        self.ax_right = self.ax.twinx()

        if labels_left is None:
            labels_left = [None] * len(ys_left)
        if labels_right is None:
            labels_right = [None] * len(ys_right)
        if not isinstance(labels_left, list):
            labels_left = [labels_left]
        if not isinstance(labels_right, list):
            labels_right = [labels_right]

        self.labels_left = labels_left
        self.labels_right = labels_right

    @property
    def ax_left(self):
        """Alias of ``self.ax`` for symmetry with ``self.ax_right``."""
        return self.ax

    def scale_right_axis(self):
        """
        Rescale the right y-axis so that it has as many major ticks as the left one.

        The tick spacing is the smallest "nice" value ``base * 10**n`` (base 1..6) for which
        ``num_ticks - 1`` intervals, starting at the multiple of the spacing just below the current
        lower limit, still reach the current upper limit. This keeps all right-axis data visible.
        Nothing is changed when the left axis has fewer than two ticks.

        Raises
        ------
        ValueError
            If the right axis range is too large for every available spacing.
        """
        num_ticks = len(self.ax.get_yticks())
        if num_ticks < 2:
            # Nothing to align with, and num_ticks - 1 intervals would divide by zero.
            return
        n_intervals = num_ticks - 1

        right_data_min, right_data_max = self.ax_right.get_ylim()
        right_range = right_data_max - right_data_min

        # Nice spacings: [1, 2, 3, 4, 5, 6] * 10^n
        nice_spacings = [base * 10 ** n for n in range(-10, 10) for base in (1, 2, 3, 4, 5, 6)]
        candidates = sorted(s for s in nice_spacings if s >= right_range / n_intervals)
        if not candidates:
            raise ValueError("The right axis range is too large for the available tick spacings.")

        # Flooring the lower limit can shift the tick range down, so a spacing that is large enough
        # for the range alone may still end below the data maximum; take the first one that covers it.
        for right_tick_spacing in candidates:
            right_min = np.floor(right_data_min / right_tick_spacing) * right_tick_spacing
            if right_min + n_intervals * right_tick_spacing >= right_data_max:
                break
        right_max = right_min + n_intervals * right_tick_spacing

        self.ax_right.set_ylim(right_min, right_max)
        self.ax_right.set_yticks(np.linspace(right_min, right_max, num_ticks))

    def _plot_series(self, ax, ys, labels):
        """Draw every series of ``ys`` as unconnected markers on ``ax`` and register handle and label."""
        for i, y in enumerate(ys):
            # Missing labels are treated as None instead of raising IndexError.
            label = labels[i] if i < len(labels) else None
            self._add_line_and_label(
                ax.plot(
                    self.x,
                    y,
                    label=label,
                    color=next(self.next_color),
                    marker=next(self.next_marker),
                    markersize=self.marker_size,
                    linestyle="None",
                ),
                label,
            )

    def plot(self):
        """Plot all left series, then all right series, and format the figure."""
        self._plot_series(self.ax, self.ys_left, self.labels_left)
        self._plot_series(self.ax_right, self.ys_right, self.labels_right)
        self.format()

    def format(self):
        """
        Set the axis titles and grid, pad the x-limits and apply the tick formatting.

        The x-limits are extended by half of the first sampling step on both sides so that the
        outermost markers are not cut. String (categorical) x values keep matplotlib's limits.
        """
        if self.x_title is not None:
            self.ax.set_xlabel(self.x_title)
        if self.y_title is not None:
            self.ax.set_ylabel(self.y_title)
        if self.y_title_right is not None:
            self.ax_right.set_ylabel(self.y_title_right)

        self.ax.grid(True)

        if not isinstance(self.x[0], str):
            # x[0] - x[0] is a zero of the matching type (0 or timedelta(0)) for single-point data;
            # a plain 0 would make ``datetime - 0.0`` fail.
            dx = self.x[1] - self.x[0] if len(self.x) > 1 else self.x[0] - self.x[0]
            left, right = self.x[0] - dx / 2, self.x[-1] + dx / 2
            self.ax.set_xlim(left=left, right=right)

            if self.x_is_datetime:
                self.formate_x_datetime(left, right)

        self.format_numeric_ticks()

309 

310 

class DoubleYSeparated(DoubleY):
    """
    Double-y marker plot of paired series, distinguished by axis color.

    The i-th left series is drawn in black on the left axis and the i-th right series in the first
    palette color on the right axis. The axis labels and tick labels are colored the same way.
    Left and right series are paired by index (``zip``), so surplus series on either side are not
    plotted.
    """

    def __init__(
        self,
        x,
        ys_left: list,
        ys_right: list,
        labels=None,
        y_title_left="Heat Flux / W",
        y_title_right="Temperature / K",
        marker_size=6,
        same_marker=False,
        **kwargs,
    ):
        """
        Parameters
        ----------
        x : sequence
            Shared x values.
        ys_left, ys_right : list | array-like
            Left / right series; a non-list argument counts as a single series.
        labels : list | str | None, optional
            One legend label per series pair. ``None`` or missing entries get no legend entry.
        y_title_left, y_title_right : str | None, optional
            Titles of the left and right y-axis.
        marker_size : int | float, optional
            Marker size of all series.
        same_marker : bool, optional
            If True, both series of a pair share one marker. Otherwise the right series uses the
            next marker of the cycle. Legend proxies are only created when this is False.
        **kwargs
            Forwarded to ``DoubleY.__init__`` (e.g. ``x_title``, ``width_mm``, ``height_mm``).
        """
        super().__init__(
            x,
            ys_left,
            ys_right,
            y_title_left=y_title_left,
            y_title_right=y_title_right,
            marker_size=marker_size,
            **kwargs,
        )

        if labels is None:
            # self.ys_left is the normalised list of series (a single array counts as one series).
            labels = [None] * len(self.ys_left)
        if not isinstance(labels, list):
            labels = [labels]
        self.labels = labels
        self.same_marker = same_marker

    def plot(self):
        """
        Plot every (left, right) series pair and format the figure.

        NOTE(review): the data artists are registered via ``_add_line`` (two per pair) while
        ``self.labels`` holds one label per pair, so ``show_legend`` matches them by position only.
        The labelled proxy artists created here are the ones ``ax.legend()`` would pick up;
        confirm which legend path callers use.
        """
        for i, (y_left, y_right) in enumerate(zip(self.ys_left, self.ys_right)):
            marker = next(self.next_marker)

            self._add_line(
                self.ax.plot(
                    self.x,
                    y_left,
                    color="black",
                    marker=marker,
                    markersize=self.marker_size,
                    linestyle="None",
                )
            )

            if not self.same_marker:
                marker = next(self.next_marker)

            self._add_line(
                self.ax_right.plot(
                    self.x,
                    y_right,
                    color=self.colors[0],
                    marker=marker,
                    markersize=self.marker_size,
                    linestyle="None",
                )
            )

            # Missing labels are treated as None instead of raising IndexError.
            label = self.labels[i] if i < len(self.labels) else None
            if label is not None and not self.same_marker:
                # Empty proxy artist that only carries the legend entry (black, right-series marker).
                self.ax.plot(
                    [],
                    [],
                    color="black",
                    marker=marker,
                    linestyle="None",
                    markersize=self.marker_size,
                    label=label,
                )

        self.format()

    def format(self):
        """Apply the ``DoubleY`` formatting, then color the left axis black and the right one in the first palette color."""
        super().format()
        self.ax.yaxis.label.set_color("black")
        self.ax.tick_params(axis="y", colors="black")
        self.ax_right.yaxis.label.set_color(self.colors[0])
        self.ax_right.tick_params(axis="y", colors=self.colors[0])

383 

384 

class LinePlot(PlotMixin):
    """
    Line plot of one or several series over a shared x-axis.

    Every series gets the next color and the next line style from the ``PlotMixin`` cycles. The
    y values of the series given at construction are multiplied by ``y_scale`` when plotted.
    """

    def __init__(
        self,
        x,
        ys: list | np.ndarray,
        labels=None,
        y_scale=1,
        y_title="Values",
        linewidth=1.8,
        x_title=None,
        width_mm=160,
        height_mm=90,
    ):
        """
        Parameters
        ----------
        x : sequence
            Shared x values (converted to a numpy array).
        ys : list | np.ndarray
            A list of series, or a 2D array whose rows are series (only if it has more than one row
            and more than one column). Anything else is treated as a single series.
        labels : list | str | None, optional
            Legend labels per series. A single non-list label is wrapped in a list.
        y_scale : int | float, optional
            Factor applied to the y values in ``plot``, ``plot_marker`` and ``plot_stack``.
        y_title, x_title : str | None, optional
            Axis titles.
        linewidth : float, optional
            Line width used by ``plot``.
        width_mm, height_mm : int | float, optional
            Figure size in millimetres.
        """
        if not isinstance(ys, list):
            if not (isinstance(ys, np.ndarray) and len(ys.shape) > 1 and ys.shape[0] > 1 and ys.shape[1] > 1):
                ys = [np.array(ys)]
        ys = [np.array(y) for y in ys]
        super().__init__(x=np.array(x), ys=ys, y_title=y_title, x_title=x_title, width_mm=width_mm, height_mm=height_mm)

        if labels is None:
            labels = [None] * len(ys)
        if not isinstance(labels, list):
            labels = [labels]
        self.labels = labels

        self.y_scale = y_scale
        self.line_width = linewidth

    @staticmethod
    def _label_at(labels, i):
        """Return ``labels[i]``, or None when the label list is shorter than the series list."""
        return labels[i] if i < len(labels) else None

    def plot(self, x=None, ys=None, labels=None):
        """
        Draw the series as scaled lines and format the figure.

        Parameters
        ----------
        x : sequence | None, optional
            x values; defaults to ``self.x``.
        ys : sequence | None, optional
            Series to plot; defaults to ``self.ys``.
        labels : list | str | None, optional
            Labels per series; defaults to ``self.labels``. Missing labels are treated as None.
        """
        if x is None:
            x = self.x
        if ys is None:
            ys = self.ys
        if labels is None:
            labels = self.labels
        if not isinstance(labels, list):
            labels = [labels]
        for i, y in enumerate(ys):
            self._add_line(
                self.ax.plot(
                    x,
                    np.array(y) * self.y_scale,
                    label=self._label_at(labels, i),
                    color=next(self.next_color),
                    linewidth=self.line_width,
                    linestyle=next(self.next_line_style),
                )
            )
        self.format()

    def add_points(self, x, ys, labels=None):
        """
        Add points to the plot (e.g. measurement data to interpolated lines).

        The ys values are not scaled with self.y_scale!

        Parameters
        ----------
        x : float | int | np.ndarray | list
            x values
        ys : float | int | np.ndarray | list
            y values. If two-dimensional, several point plots are added with different colors.
        labels : list | str, optional
            The labels of the point groups, accordingly

        Raises
        ------
        ValueError
            If the number of labels does not match the number of point groups.
        """
        if not isinstance(x, np.ndarray):
            x = np.array(x)
        if not isinstance(ys, np.ndarray):
            ys = np.array(ys)
        ys = np.atleast_2d(ys)
        if labels is None:
            labels = [None] * len(ys)
        if not isinstance(labels, list):
            labels = [labels]
        if len(labels) != len(ys):
            raise ValueError(f"Got {len(labels)} labels for {len(ys)} point groups.")

        for i, y in enumerate(ys):
            self._add_line_and_label(
                self.ax.plot(
                    x,
                    np.array(y),
                    label=labels[i],
                    color=next(self.next_color),
                    linestyle="none",
                    marker=next(self.next_marker),
                ),
                labels[i],
            )
        self.format()

    def plot_marker(self):
        """Draw the series as scaled lines with markers and format the figure."""
        for i, y in enumerate(self.ys):
            # Registered like in ``plot`` so that ``show_legend`` sees these handles as well.
            self._add_line(
                self.ax.plot(
                    self.x,
                    np.array(y) * self.y_scale,
                    label=self._label_at(self.labels, i),
                    color=next(self.next_color),
                    marker=next(self.next_marker),
                )
            )
        self.format()

    def plot_stack(self):
        """Draw the scaled series as a stacked area plot and format the figure."""
        polys = self.ax.stackplot(
            self.x,
            *[y * self.y_scale for y in self.ys],
            labels=self.labels,
            colors=self.colors,
        )
        # Registered like in ``plot`` so that ``show_legend`` sees these handles as well.
        self._add_line(list(polys))
        self.format()

    def format(self):
        """Apply the ``PlotMixin`` formatting and the numeric tick formatter."""
        super().format()
        self.format_numeric_ticks()

498 

499 

class TimePlot(LinePlot):
    """``LinePlot`` whose x-axis is always formatted as a date axis."""

    def format(self):
        """Run the ``LinePlot`` formatting, then date-format the x-axis and re-apply the numeric tick formatter."""
        super().format()
        # Order matters: the date formatting comes first, then the numeric formatter, which leaves
        # axes with a DateFormatter untouched.
        for step in (self.formate_x_datetime, self.format_numeric_ticks):
            step()

505 

506 

class BarPlot(PlotMixin):
    """Bar chart of ``bar_values`` at ``bar_positions`` (default 0, 1, 2, ...) on a single axis."""

    def __init__(self, bar_values, bar_positions=None, y_title="Heat / kWh", **kwargs):
        """
        Parameters
        ----------
        bar_values : sequence
            Bar heights; at least two are required.
        bar_positions : sequence | None, optional
            x positions of the bars (numeric, or strings for categorical bars). Defaults to
            ``range(len(bar_values))``.
        y_title : str | None, optional
            Title of the y-axis.
        **kwargs
            Forwarded to ``PlotMixin.__init__``.

        Raises
        ------
        ValueError
            If fewer than two bar values are given.
        """
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if len(bar_values) < 2:
            raise ValueError(f"BarPlot needs at least two bar values, got {len(bar_values)}.")
        if bar_positions is None:
            bar_positions = range(len(bar_values))
        super().__init__(x=bar_positions, ys=bar_values, y_title=y_title, **kwargs)

    @property
    def bar_width(self):
        """
        Width of each bar: 80 % of the smallest gap between consecutive positions.

        Falls back to 1 for a single position, and to 0.8 for string (categorical) positions. Those
        cannot be differenced, and matplotlib places categories one unit apart.
        """
        if len(self.x) < 2:
            return 1
        if isinstance(self.x[0], str):
            return 0.8
        return np.min(np.diff(self.x)) * 0.8

    def plot(self):
        """Draw all bars in the next cycle color and format the figure."""
        self.ax.bar(self.x, self.ys, width=self.bar_width, color=next(self.next_color))
        self.format()

    def format(self):
        """Apply the ``PlotMixin`` formatting, pad numeric x-limits by half a bar slot and format the ticks."""
        super().format()
        if is_numeric(self.x[0]):
            # bar_width / 0.8 is the full slot per bar; half of it on each side.
            half_slot = self.bar_width / 0.8 / 2
            self.ax.set_xlim(left=self.x[0] - half_slot, right=self.x[-1] + half_slot)
        self.format_numeric_ticks()

531 

532 

class TimeBarPlot(BarPlot):
    """``BarPlot`` with datetime bar positions and a date-formatted x-axis."""

    def __init__(self, bar_values, bar_positions, y_title="Heat / kWh", **kwargs):
        """
        Parameters
        ----------
        bar_values : sequence
            Bar heights.
        bar_positions : sequence of datetime
            Position of each bar (required, unlike in ``BarPlot``).
        y_title : str | None, optional
            Title of the y-axis.
        **kwargs
            Forwarded to ``BarPlot.__init__``.
        """
        super().__init__(bar_values, bar_positions, y_title=y_title, **kwargs)

    @property
    def bar_width(self) -> float:
        """
        Bar width in days (matplotlib date units).

        It is 80 % of the smallest gap between consecutive positions, consistent with
        ``BarPlot.bar_width``, so neighbouring bars never overlap. A single bar is 1 hour wide.
        """
        if len(self.x) > 1:
            time_diffs = np.diff(mdates.date2num(self.x))
            return float(np.min(time_diffs) * 0.8)
        return 1 / 24

    def format(self):
        """Apply the ``BarPlot`` formatting, date-format the x-axis and pad it by half a bar slot."""
        super().format()
        format_date_x_axis(self.x[0], self.x[-1], self.ax, return_version=False)
        # bar_width / 0.8 is the full slot per bar (in days); half of it on each side.
        half_slot = timedelta(days=self.bar_width / 0.8 / 2)
        self.ax.set_xlim(left=self.x[0] - half_slot, right=self.x[-1] + half_slot)
        self.format_numeric_ticks()

555 

556 

557def seconds_to_dates(seconds: list, start_date=datetime(2023, 1, 1), return_array=False) -> list | np.ndarray: 

558 result = [start_date + timedelta(seconds=int(s)) for s in seconds] 

559 if return_array: 

560 return np.array(result) 

561 return result 

562 

563 

def format_x_axis_to_date(fig, ax):
    """
    Format the x-axis of ``ax`` for dates at day resolution.

    Major ticks sit on the 1st, 7th, 14th, 21st and 28th of each month and are labelled
    "day.month."; minor ticks mark every day. The grid is enabled, and ``fig.autofmt_xdate()``
    lays out the date tick labels.

    Returns
    -------
    tuple
        The same ``(fig, ax)`` pair that was passed in.
    """
    x_axis = ax.xaxis
    x_axis.set_major_locator(mdates.DayLocator(bymonthday=(1, 7, 14, 21, 28)))
    x_axis.set_minor_locator(mdates.DayLocator(interval=1))
    x_axis.set_major_formatter(mdates.DateFormatter("%d.%m."))
    ax.grid(True)
    ax.tick_params(axis="x", which="minor", bottom=True)
    fig.autofmt_xdate()
    return fig, ax