Coverage for pyrc \ core \ solver \ handler.py: 71%

262 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-13 16:59 +0200

1# ------------------------------------------------------------------------------- 

2# Copyright (C) 2026 Joel Kimmich, Tim Jourdan 

3# ------------------------------------------------------------------------------ 

4# License 

5# This file is part of PyRC, distributed under GPL-3.0-or-later. 

6# ------------------------------------------------------------------------------ 

7 

8from __future__ import annotations 

9 

10import os.path 

11import time 

12import warnings 

13from collections import deque 

14from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor 

15from datetime import datetime 

16from multiprocessing import cpu_count 

17from typing import Any, Callable 

18 

19import numpy as np 

20from scipy.integrate import solve_ivp 

21from scipy.sparse import spmatrix, sparray 

22 

23from pyrc.core.settings import initial_settings, Settings, SolveSettings 

24from pyrc.core.solver.symbolic import SparseSymbolicEvaluator 

25from pyrc.core.components.templates import RCSolution 

26from pyrc.tools.science import get_free_ram_gb 

27 

28 

29class HomogeneousSystemHandler: 

30 def __init__( 

31 self, 

32 system_matrix: SparseSymbolicEvaluator | spmatrix | sparray, 

33 rc_solution: RCSolution, 

34 settings: Settings, 

35 print_progress=True, 

36 print_points: list | np.ndarray = None, 

37 batch_end=None, 

38 time_dependent_function=None, 

39 _initialize_temperature_function=True, 

40 **kwargs, 

41 ): 

42 """ 

43 System Handler that is called ("as function") by `solve_ivp`. 

44 

45 Parameters 

46 ---------- 

47 system_matrix : spmatrix | sparray | SparseSymbolicEvaluator 

48 The system matrix. 

49 If constant, the type must be a sparse scipy matrix. 

50 If time dependent symbols are contained, the system matrix must be given in a `SparseSymbolicEvaluator` 

51 object. 

52 rc_solution 

53 settings 

54 print_progress 

55 print_points 

56 batch_end 

57 time_dependent_function : Callable, optional 

58 A function that calculates all time dependent variables within the time step and returns them in the correct 

59 order as iterable (e.g. list). 

60 It gets parameters like this:\n 

61 ``time_dependent_function(time, temperature_vector)`` \n 

62 or\n 

63 ``time_dependent_function(time, temperature_vector, input_vector)``\.\n 

64 To not run into Errors just use ``*args``\, ``**kwargs`` at the end in case more values are passed then 

65 needed. 

66 _initialize_temperature_function : bool, optional 

67 If ``False``, the ``dtemperature_dt_function`` attribute is not set using the method `_init_temperature_equation()` 

68 Used in subclasses to prevent an early call of attributes that are created in the subclasses later on. 

69 kwargs : dict, optional 

70 Not used: Just to not raise an error if too much information is passed. 

71 """ 

72 self.system_matrix: spmatrix | sparray | SparseSymbolicEvaluator = system_matrix 

73 self.rc_solution = rc_solution 

74 self.settings = settings 

75 

76 self.time_dependent_function: Callable = time_dependent_function # must return the value(s) as iterable 

77 self.time_dependent_active: bool = True 

78 if time_dependent_function is None: 

79 self.time_dependent_active = False 

80 self.time_dependent_function = lambda *args, **keyword_args: [] 

81 

82 if print_progress and batch_end is None: 

83 import warnings 

84 

85 warnings.warn("Print progress might not work as expected because batch_end value is not given.") 

86 batch_end = np.inf 

87 self.batch_end = batch_end 

88 

89 # print progress 

90 if print_points is None: 

91 print_points = [0] 

92 if isinstance(print_points, np.ndarray): 

93 print_points = print_points.tolist() 

94 self.print_points: deque = deque(print_points) 

95 self.next_printed_time_step = self.print_points.popleft() 

96 self.print_progress = print_progress 

97 

98 # initialize the temperature equation using a lambda 

99 self.dtemperature_dt_function = lambda *args, **keyword_args: None 

100 if _initialize_temperature_function: 

101 # this can be switched off because it always have to run at the end of all subclasses inits 

102 self.dtemperature_dt_function = self._init_temperature_equation() 

103 

104 def _init_temperature_equation(self) -> Callable: 

105 if isinstance(self.system_matrix, SparseSymbolicEvaluator): 

106 assert self.time_dependent_active 

107 system_matrix = lambda v: self.system_matrix.evaluate(v) 

108 else: 

109 if self.time_dependent_active: 

110 warnings.warn( 

111 "Function to calculate time dependent values passed, but no time dependent system matrix detected.\n" 

112 "The passed function has no effect." 

113 ) 

114 return lambda temperature: self.system_matrix @ temperature 

115 return lambda temperature, v: system_matrix(v) @ temperature 

116 

117 def __call__(self, t, temperature): 

118 """ 

119 The function the solve_ivp is going to solve. 

120 

121 Parameters 

122 ---------- 

123 t 

124 temperature 

125 

126 Notes 

127 ----- 

128 The input vector that is saved during the iteration of the solver at the t_eval 

129 

130 Returns 

131 ------- 

132 np.ndarray : 

133 The resulting vector of the temperature derivative. 

134 """ 

135 if self.print_progress: 

136 self._update_progress_print(t) 

137 temperature = temperature.reshape(-1, 1) 

138 

139 if self.time_dependent_active: 

140 time_dependent_values = self.time_dependent_function(t, temperature) 

141 dtemperature_dt = self.dtemperature_dt_function(temperature, time_dependent_values) 

142 else: 

143 dtemperature_dt = self.dtemperature_dt_function(temperature) 

144 

145 return dtemperature_dt.flatten() 

146 

147 def set_new_t_eval(self, new_t_eval): 

148 """ 

149 Change the current t_eval. 

150 

151 Parameters 

152 ---------- 

153 new_t_eval : array_like 

154 The new t_eval. 

155 """ 

156 # Currently only used in the child class 

157 pass 

158 

159 def _update_progress_print(self, t): 

160 if t >= self.next_printed_time_step: 

161 # prevent initialization print out (is not fail save, but okay) 

162 if t != self.batch_end or (self.print_points and t == self.print_points[0]): 

163 if len(self.print_points) > 0: 

164 self.next_printed_time_step = self.print_points.popleft() 

165 self.print_out_progress(t) 

166 else: 

167 self.next_printed_time_step = np.inf 

168 # deactivate printing for better performance 

169 self.print_progress = False 

170 

171 def print_out_progress(self, t): 

172 """ 

173 Prints a formatted time stamp. 

174 

175 Parameters 

176 ---------- 

177 t : float | int 

178 The time to format/print. 

179 """ 

180 print(f"Progress: t = {self.format_t(t)}") 

181 

182 

183 @staticmethod 

184 def format_t(t): 

185 days = int(t) // 86400 

186 hours = (int(t) % 86400) // 3600 

187 minutes = (int(t) % 3600) // 60 

188 seconds = int(t) % 60 

189 return ( 

190 f"{days:>4} days, {hours:02}:{minutes:02}:{seconds:02} - current time: " 

191 f"{datetime.now().strftime('%d %H:%M:%S')}" 

192 ) 

193 

194 

class InhomogeneousSystemHandler(HomogeneousSystemHandler):
    """
    Right-hand side of the inhomogeneous system ``dT/dt = A @ T + B @ u`` in the form expected by `solve_ivp`.

    In addition to the homogeneous handler, the input vector ``u`` is recomputed in every call from
    ``functions_list`` and recorded in ``rc_solution`` whenever the next evaluation time is reached.
    """

    def __init__(
        self,
        system_matrix,
        input_matrix,
        input_vector,
        rc_solution,
        functions_list,
        kwargs_functions: dict,
        t_eval=None,
        time_dependent_function=None,
        use_parallelization=False,
        core_count=cpu_count(),
        print_progress=True,
        print_points=3600,
        settings: Settings = initial_settings,
        first_time: float | int = 0,
        **kwargs,
    ):
        """
        Parameters
        ----------
        system_matrix : spmatrix | sparray | SparseSymbolicEvaluator
            The system matrix ``A``; a `SparseSymbolicEvaluator` is evaluated with the time dependent values.
        input_matrix : spmatrix | sparray | SparseSymbolicEvaluator
            The input matrix ``B``; a `SparseSymbolicEvaluator` is evaluated with the time dependent values.
        input_vector : array_like
            The initial input vector ``u``; stored as column vector.
        rc_solution : RCSolution
            Receives the input vector at the evaluation times (``append_to_input`` / ``delete_last_input``).
        functions_list : list[Callable]
            One function per input entry, called as ``f(t, temperature, input_vector, **kwargs)``.
        kwargs_functions : dict[str, Callable]
            Functions called as ``fun(t, temperature, input_vector)``. Their results are passed (under their dict
            key) as keyword arguments to every function of `functions_list`.
        t_eval : array_like, optional
            The evaluation times of the first batch. If None, the input vector is recorded at every call.
        time_dependent_function : Callable, optional
            Called as ``time_dependent_function(t, temperature, input_vector)``; its result is passed to the
            evaluation of symbolic matrices.
        use_parallelization : bool, optional
            Currently ignored (the input vector is always updated sequentially).
        core_count : int, optional
            Worker count for the (currently unused) parallel updaters; clamped to ``[1, cpu_count()]``.
        print_progress : bool, optional
            See `HomogeneousSystemHandler`.
        print_points : list | np.ndarray | float | int, optional
            Simulation times at which the progress is printed. A scalar is treated as one single print time.
        settings : Settings, optional
            See `HomogeneousSystemHandler`.
        first_time : float | int, optional
            The start time; used to detect when the solver steps back in time.
        kwargs : dict, optional
            Not used.
        """
        # t_eval may be None (or empty): then no batch end is known and the base class falls back to np.inf
        batch_end = t_eval[-1] if t_eval is not None and np.size(t_eval) > 0 else None
        if np.isscalar(print_points):
            # NOTE(review): a scalar (like the default 3600) is treated as one single print time because a plain
            # scalar is not iterable for the print deque - confirm that no print *interval* was intended.
            print_points = [print_points]
        super().__init__(
            system_matrix,
            rc_solution,
            settings,
            time_dependent_function=time_dependent_function,
            print_progress=print_progress,
            print_points=print_points,
            batch_end=batch_end,
            _initialize_temperature_function=False,
        )
        self.input_matrix = input_matrix
        self.input_vector = np.asarray(input_vector).reshape(-1, 1)
        self.functions_list = functions_list

        # bookkeeping of the evaluation times at which the input vector is recorded
        self.after_iteration = np.inf
        self.next_eval_time = 0
        self.t_eval_iter = iter(())
        self.set_new_t_eval(t_eval)

        self.last_t = first_time
        self.last_eval_time = 0
        # evaluation time that was pending when the solver stepped back; restored after re-recording (see __call__)
        self.current_eval_time = 0
        self.use_current_eval_time = False

        # `use_parallelization` is ignored for now: the thread/process updaters below exist, but the process pool
        # could not pickle the input functions, so the sequential update is always used.
        self.input_update_function = self._update_input_vector
        self.core_count = max(1, min(core_count, cpu_count()))
        self.kwargs_functions: dict = kwargs_functions

        # must run last: the equation needs input_matrix, which is unknown to the base class __init__
        self.dtemperature_dt_function = self._init_temperature_equation()

    def _init_temperature_equation(self) -> Callable:
        """
        Build ``f(temperature, input_v, v) = A(v) @ temperature + B(v) @ input_v``.

        Constant (non-symbolic) matrices ignore the time dependent values ``v``.
        """
        if isinstance(self.system_matrix, SparseSymbolicEvaluator):
            system_matrix_fun = lambda v: self.system_matrix.evaluate(v)
        else:
            system_matrix_fun = lambda *args, **kwargs: self.system_matrix
        if isinstance(self.input_matrix, SparseSymbolicEvaluator):
            input_matrix_fun = lambda v: self.input_matrix.evaluate(v)
        else:
            input_matrix_fun = lambda *args, **kwargs: self.input_matrix
        return lambda temperature, input_v, v: system_matrix_fun(v) @ temperature + input_matrix_fun(v) @ input_v

    def set_new_t_eval(self, new_t_eval):
        """
        Change the current t_eval.

        Parameters
        ----------
        new_t_eval : array_like | None
            The new t_eval. If None, the input vector is recorded at every call (next evaluation time is -inf).
        """
        if new_t_eval is None:
            self.t_eval_iter = iter([-np.inf])
            self.after_iteration = -np.inf
        else:
            self.t_eval_iter = iter(new_t_eval)
            self.after_iteration = np.inf
        self.next_eval_time = next(self.t_eval_iter, self.after_iteration)
        self.use_current_eval_time = False

    def calculate_kwargs(self, tau, temp_vector, _input_vector, **kwargs):
        """Evaluate all `kwargs_functions` and return their results keyed by name."""
        return {name: fun(tau, temp_vector, _input_vector, **kwargs) for name, fun in self.kwargs_functions.items()}

    def _update_input_vector_threads(self, t, temperature, **kwargs):
        """Recompute the input vector with a thread pool of at most `core_count` workers."""
        with ThreadPoolExecutor(max_workers=self.core_count) as executor:
            results = list(executor.map(lambda f: f(t, temperature, self.input_vector, **kwargs), self.functions_list))
        self.input_vector = np.array(results).reshape(-1, 1)

    @staticmethod
    def _worker_call(args):
        """Unpack ``(f, t, temperature, input_vector, kwargs)`` and call ``f`` (process pool helper)."""
        f, t, temperature, input_vector, kwargs = args
        return f(t, temperature, input_vector, **kwargs)

    def _update_input_vector_processes(self, t, temperature, **kwargs):
        """Recompute the input vector with a process pool (requires picklable input functions)."""
        args_iter = [(f, t, temperature, self.input_vector.copy(), kwargs) for f in self.functions_list]
        with ProcessPoolExecutor(max_workers=self.core_count) as executor:
            results = list(executor.map(self._worker_call, args_iter))
        self.input_vector = np.array(results).reshape(-1, 1)

    def _update_input_vector(self, t, temperature, **kwargs):
        """Recompute the input vector sequentially; every function sees the previous input vector."""
        # TODO: Speedup by creating one big lambda function returning a vector instead of looping over all
        self.input_vector = np.array(
            [f(t, temperature, self.input_vector, **kwargs) for f in self.functions_list]
        ).reshape(-1, 1)

    def __call__(self, t, temperature: np.ndarray):
        """
        The function the solve_ivp is going to solve.

        Parameters
        ----------
        t : float
            The current simulation time.
        temperature : np.ndarray
            The current temperature vector (1-D, as passed by solve_ivp).

        Notes
        -----
        Side effect: the updated input vector is appended to `rc_solution` when `t` reaches the next evaluation
        time. If the solver steps back behind an already recorded evaluation time, that record is deleted and the
        evaluation time is recorded again once it is reached; afterwards the evaluation time that was pending
        before the step back is restored.

        Returns
        -------
        np.ndarray :
            The resulting vector of the temperature derivative (flattened).
        """
        if t < self.last_t:
            if t == self.batch_end or (self.last_t >= self.last_eval_time > t):
                # Delete the last result because the solver is in initialization or one time step was saved too early.
                if not self.use_current_eval_time:
                    # Remember the pending evaluation time (already taken from the iterator) to restore it after
                    # last_eval_time has been recorded again - otherwise it would be lost.
                    self.current_eval_time = self.next_eval_time
                # Revert the next_eval_time to the last value.
                self.next_eval_time = self.last_eval_time
                self.use_current_eval_time = True
                self.rc_solution.delete_last_input()
        if self.print_progress:
            self._update_progress_print(t)
        temperature = temperature.reshape(-1, 1)

        kwargs = self.calculate_kwargs(t, temperature, self.input_vector)

        self.input_update_function(t, temperature, **kwargs)

        time_dependent_values = self.time_dependent_function(t, temperature, self.input_vector)

        # Check if the input vector has to be saved.
        if t >= self.next_eval_time:
            # NOTE: This is not perfect, because the input vector of the next time step is saved for the result of
            # the last time step. However, the solver will perform such small iteration steps that this will not
            # become a problem.
            self.rc_solution.append_to_input(self.input_vector)
            self.last_eval_time = self.next_eval_time
            if self.use_current_eval_time:
                self.next_eval_time = self.current_eval_time
                self.use_current_eval_time = False
            else:
                self.next_eval_time = next(self.t_eval_iter, self.after_iteration)

        dtemperature_dt = self.dtemperature_dt_function(temperature, self.input_vector, time_dependent_values)

        self.last_t = t
        return dtemperature_dt.flatten()

362 

363 

class SolveIVPHandler:
    """
    Runs `scipy.integrate.solve_ivp` in time batches and collects (and optionally saves) the results.

    The results are stored in ``system_handler.rc_solution`` (``t`` as 1-D array, ``y`` as ``(time, state)``).
    """

    def __init__(
        self,
        system_handler: HomogeneousSystemHandler | InhomogeneousSystemHandler,
        max_saved_steps=None,
        method=None,
        max_step=None,
        rtol=None,
        atol=None,
        save_interval=None,
        save_path=None,
        save_prefix="",
        minimize_ram_usage=None,
        solve_settings=None,
        **kwargs,
    ):
        """
        Handler of the solving process with solve_ivp.

        Every argument that is None is taken from `solve_settings` (explicit values are announced by a print).

        Parameters
        ----------
        system_handler : HomogeneousSystemHandler | InhomogeneousSystemHandler
            The system handler that holds the function to solve in __call__().
        max_saved_steps : int, optional
            The batch length in seconds, i.e. the maximum simulated time span of one solve_ivp call.
            Using this prevents long solving time because the matrices become very big.
        method : str, optional
            The method to use to solve the system (see scipy.solve_ivp).
        max_step : float, optional
            Passed to solve_ivp.
        rtol : float, optional
            Passed to solve_ivp.
        atol : float, optional
            Passed to solve_ivp.
        save_interval : int, optional
            Number of batches with new time steps after which an intermediate save is made.
        save_path : str, optional
            Folder of the pickle files. Defaults to ``system_handler.settings.save_folder_path``.
            If that is None as well, nothing is saved.
        save_prefix : str, optional
            This is the beginning of the name of the pickle file that is saved during the solving.
        minimize_ram_usage : bool, optional
            If True, the solution is deleted if saved to minimize the RAM usage during runtime.
            Requires a save path (otherwise it is switched off).
        solve_settings : SolveSettings, optional
            Source of the defaults (via its ``dict`` attribute). Defaults to ``SolveSettings()``.
        kwargs
            Passed on to solve_ivp.
        """
        if solve_settings is None:
            solve_settings = SolveSettings()
        self.solve_settings = solve_settings
        max_saved_steps, method, max_step, rtol, atol, save_interval, minimize_ram_usage = self.set_initial_values(
            max_saved_steps=max_saved_steps,
            method=method,
            max_step=max_step,
            rtol=rtol,
            atol=atol,
            save_interval=save_interval,
            minimize_ram_usage=minimize_ram_usage,
        )
        self.system_handler: HomogeneousSystemHandler | InhomogeneousSystemHandler = system_handler
        self.max_saved_steps = int(max_saved_steps)
        self.method = method
        self.max_step = max_step
        self.rtol = rtol
        self.atol = atol

        self.save_interval = save_interval
        self.save_counter = 0
        if save_path is None:
            save_path = self.system_handler.settings.save_folder_path
        self.save_path = None if save_path is None else os.path.normpath(save_path)
        self.save_prefix = save_prefix

        if self.save_path is None and minimize_ram_usage:
            print("Minimize RAM usage deactivated because no save_path is declared in Settings.")
            print("This also deactivates every save during solving.")
            minimize_ram_usage = False
        self.minimize_ram_usage = minimize_ram_usage

        self.kwargs = kwargs

    def set_initial_values(self, **kwargs):
        """
        Resolve the given keyword arguments against the solve settings.

        Returns
        -------
        list
            One value per keyword (in the given order): the settings value if the argument is None, else the argument.
        """
        settings_dict: dict = self.solve_settings.dict
        result = []
        for key, arg in kwargs.items():
            if arg is None:
                result.append(settings_dict[key])
            else:
                result.append(arg)
                print(f"Solve setting {key} is overwritten by manual/coded value: {arg}")
        return result

    def get_batches(self, t_span):
        """
        Split ``t_span`` into batch boundaries spaced `max_saved_steps` apart.

        Returns
        -------
        np.ndarray
            The boundaries, always ending with ``t_span[-1]`` (``[end]`` alone if the span is empty).
        """
        start, end = t_span
        splits = np.arange(start, end, self.max_saved_steps)
        # np.arange excludes `end` and is empty for start >= end; the end point must always be present
        if splits.size == 0 or splits[-1] != end:
            splits = np.append(splits, end)
        return splits

    def solve(
        self,
        t_span,
        y0,
        t_eval=None,
        continued_simulation: bool = False,
        expected_solution_size_mb=5000,
    ):
        """
        Runs the solve_ivp in batches.

        Parameters
        ----------
        t_span : tuple[int | float, int| float] :
            The start and end of the simulation in seconds.
            Usually it starts at 0: (0, end_seconds)
            See also: scipy.integrate.solve_ivp()
        y0 : np.ndarray | Iterable | float | int
            The initial state of the system.
            See also: scipy.integrate.solve_ivp()
        t_eval : np.ndarray | Iterable | float | int, optional
            Times at which to store the computed solution, must be sorted and lie within `t_span`\.
            If None (default), use points selected by the solver (batch boundaries are stored only once).
            See also: scipy.integrate.solve_ivp()
        continued_simulation : bool, optional
            If True, the first value of the first batch is not kept, because the simulation is continued.
        expected_solution_size_mb : int | float, optional
            The expected solution size in Megabytes. Is used for RAM-Management (only with `minimize_ram_usage`):
            if not enough free memory is available, the method waits for 10 seconds and checks again, up to 360 times.
            If the memory is still insufficient, the intermediate saves are not combined into one result file
            (a message is printed, no error is raised).
            The solution size depends on the size of the RC network and number of saved time steps.
        """
        batches = self.get_batches(t_span)
        if t_eval is not None:
            # boolean masking below needs an array (lists/tuples/scalars are documented inputs)
            t_eval = np.atleast_1d(np.asarray(t_eval))

        list_t, list_y = [], []
        current_y0 = y0

        if self.save_path is not None:
            intermediate_save_prefix = os.path.join(self.save_path, self.save_prefix + f"_{int(t_span[-1])}_")
        else:
            intermediate_save_prefix = "dummy_path"

        for i in range(len(batches) - 1):
            batch_start, batch_end = float(batches[i]), float(batches[i + 1])
            # Only a fresh simulation keeps its very first point; every other batch start was already stored as the
            # end of the previous batch (or of the simulation that is continued).
            keep_batch_start = i == 0 and not continued_simulation
            if t_eval is not None:
                if keep_batch_start:
                    mask = (t_eval >= batch_start) & (t_eval <= batch_end)
                else:
                    mask = (t_eval > batch_start) & (t_eval <= batch_end)
                t_eval_batch = t_eval[mask]
            else:
                t_eval_batch = None

            self.system_handler.set_new_t_eval(t_eval_batch)

            last_step_not_in_solution = False
            if t_eval_batch is not None and (t_eval_batch.size == 0 or t_eval_batch[-1] != batch_end):
                # The last step must be returned by solve_ivp anyway to pass it as new y0 in the next batch
                t_eval_batch = np.append(t_eval_batch, batch_end)
                last_step_not_in_solution = True

            # update end of batch (used for correct print out and saves)
            self.system_handler.batch_end = batch_end

            sol: Any = solve_ivp(
                fun=self.system_handler,
                t_span=(batch_start, batch_end),
                y0=current_y0,
                method=self.method,
                t_eval=t_eval_batch,
                max_step=self.max_step,
                rtol=self.rtol,
                atol=self.atol,
                **self.kwargs,
            )
            if not sol.success:
                warnings.warn(f"solve_ivp did not succeed in batch ({batch_start}, {batch_end}): {sol.message}")
            current_y0 = sol.y[:, -1]

            if last_step_not_in_solution:
                # delete last solution if not requested in settings (but it was needed for y0)
                self.system_handler.rc_solution.delete_last_input()
                new_time_steps = sol.t[:-1]
                new_y_values = sol.y[:, :-1]
            elif t_eval_batch is None and not keep_batch_start:
                # solver-chosen points always start at batch_start, which is already stored
                new_time_steps = sol.t[1:]
                new_y_values = sol.y[:, 1:]
            else:
                new_time_steps = sol.t
                new_y_values = sol.y

            if new_time_steps.size != 0:
                list_t.append(new_time_steps)
                list_y.append(new_y_values)
                # increase counter only if new solution was added. Otherwise it saves the same solution / an empty one
                self.save_counter += 1

            if self.save_counter // self.save_interval > 0 or (
                batch_end == t_span[-1]
                and (self.system_handler.rc_solution.last_saved_timestep_index > 0 or self.minimize_ram_usage)
            ):
                save_path = f"{intermediate_save_prefix}{float(batch_end):09.0f}_s.pickle"
                if self.minimize_ram_usage:
                    # save the solution and delete it afterward
                    assert self.save_path is not None
                    if list_t:  # nothing new since the last save: nothing to write (concatenate needs arrays)
                        self.system_handler.rc_solution.save_to_file_only(
                            t=np.concatenate(list_t),
                            y=np.concatenate(list_y, axis=1).T,
                            path_with_name_and_ending=save_path,
                        )
                else:
                    self.add_to_solution(list_t, list_y)
                    if self.save_path is not None:
                        self.system_handler.rc_solution.save_new_solution(save_path)

                list_y = []
                list_t = []
                self.save_counter = 0

        # Make print out for the last time step that else will be missed
        if self.system_handler.print_progress:
            if self.system_handler.next_printed_time_step <= t_span[-1]:
                self.system_handler.print_out_progress(t_span[-1])

        save_path = f"{intermediate_save_prefix}result.pickle"
        if self.minimize_ram_usage:
            # load solution in and save it as one big solution
            # it will load all intermediate saves because it will not find the requested path with "_result.pickle"
            assert self.save_path is not None
            if self._wait_for_free_ram(expected_solution_size_mb / 1000):
                self.system_handler.rc_solution.load_solution(save_path)
                # Now save the accumulated result in one pickle with the "_result" ending
                self.system_handler.rc_solution.save_solution(save_path)
            else:
                print("Not enough RAM for over one hour. Incremental solutions are not combined.")
        else:
            self.add_to_solution(list_t, list_y)
            if self.save_path is not None:
                self.system_handler.rc_solution.save_solution(save_path)

    @staticmethod
    def _wait_for_free_ram(required_gb, max_loops=360, wait_seconds=10) -> bool:
        """Poll the free RAM (waiting `wait_seconds` between checks); True as soon as `required_gb` are free."""
        for _ in range(max_loops):
            if get_free_ram_gb() >= required_gb:
                return True
            time.sleep(wait_seconds)  # waiting for more free RAM
        # final check after the last wait, so RAM freed during it is not reported as a failure
        return get_free_ram_gb() >= required_gb

    def add_to_solution(self, new_t: list, new_y: list):
        """
        Append batches of results to ``rc_solution``.

        Parameters
        ----------
        new_t : list[np.ndarray]
            1-D arrays of time steps.
        new_y : list[np.ndarray]
            Arrays of shape ``(state, time)`` (as returned by solve_ivp); stored transposed as ``(time, state)``.
        """
        if new_t is None or len(new_t) == 0:
            # nothing to append (np.concatenate raises on an empty list)
            return
        rc_solution = self.system_handler.rc_solution
        if rc_solution.t is None:
            rc_solution.t = np.concatenate(new_t)
        else:
            rc_solution.t = np.concatenate([rc_solution.t, *new_t])
        if rc_solution.y is None:
            rc_solution.y = np.concatenate(new_y, axis=1).T
        else:
            rc_solution.y = np.concatenate([rc_solution.y.T, *new_y], axis=1).T