Coverage for pyrc \ core \ solver \ handler.py: 70%

253 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-29 14:14 +0200

1# ------------------------------------------------------------------------------- 

2# Copyright (C) 2026 Joel Kimmich, Tim Jourdan 

3# ------------------------------------------------------------------------------ 

4# License 

5# This file is part of PyRC, distributed under GPL-3.0-or-later. 

6# ------------------------------------------------------------------------------ 

7 

8from __future__ import annotations 

9 

10import os.path 

11import time 

12import warnings 

13from collections import deque 

14from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor 

15from datetime import datetime 

16from multiprocessing import cpu_count 

17from typing import Any, Callable 

18 

19import numpy as np 

20from scipy.integrate import solve_ivp 

21from scipy.sparse import spmatrix, sparray 

22 

23from pyrc.core.settings import initial_settings, Settings, SolveSettings 

24from pyrc.core.solver.symbolic import SparseSymbolicEvaluator 

25from pyrc.core.components.templates import RCSolution 

26from pyrc.tools.science import get_free_ram_gb 

27 

28 

29class HomogeneousSystemHandler: 

30 def __init__( 

31 self, 

32 system_matrix: SparseSymbolicEvaluator | spmatrix | sparray, 

33 rc_solution: RCSolution, 

34 settings: Settings, 

35 print_progress=True, 

36 print_points: list | np.ndarray = None, 

37 batch_end=None, 

38 time_dependent_function=None, 

39 _initialize_temperature_function=True, 

40 **kwargs, 

41 ): 

42 """ 

43 System Handler that is called ("as function") by ``solve_ivp``. 

44 

45 Parameters 

46 ---------- 

47 system_matrix : spmatrix | sparray | SparseSymbolicEvaluator 

48 The system matrix. 

49 If constant, the type must be a sparse scipy matrix. 

50 If time dependent symbols are contained, the system matrix must be given in a `SparseSymbolicEvaluator` 

51 object. 

52 rc_solution 

53 settings 

54 print_progress 

55 print_points 

56 batch_end 

57 time_dependent_function : Callable, optional 

58 A function that calculates all time dependent variables within the time step and returns them in the same order as 

59 self.get_time_dependent_symbols(). 

60 It gets parameters like this:\n 

61 ``time_dependent_function(time, temperature_vector)`` \n 

62 or\n 

63 ``time_dependent_function(time, temperature_vector, input_vector)``\\.\n 

64 This function is required if time dependent symbols exist. 

65 It must return an iterable (e.g. list). 

66 To not run into Errors just use ``*args``\\, ``**kwargs`` at the end in case more values are passed then 

67 needed. 

68 _initialize_temperature_function : bool, optional 

69 If ``False``, the ``dtemperature_dt_function`` attribute is not set using the method `_init_temperature_equation()` 

70 Used in subclasses to prevent an early call of attributes that are created in the subclasses later on. 

71 kwargs : dict, optional 

72 Not used: Just to not raise an error if too much information is passed. 

73 """ 

74 self.system_matrix: spmatrix | sparray | SparseSymbolicEvaluator = system_matrix 

75 self.rc_solution = rc_solution 

76 self.settings = settings 

77 

78 self.time_dependent_function: Callable = time_dependent_function # must return the value(s) as iterable 

79 self.time_dependent_active: bool = True 

80 if time_dependent_function is None: 

81 self.time_dependent_active = False 

82 self.time_dependent_function = lambda *args, **keyword_args: [] 

83 

84 if print_progress and batch_end is None: 

85 import warnings 

86 

87 warnings.warn("Print progress might not work as expected because batch_end value is not given.") 

88 batch_end = np.inf 

89 self.batch_end = batch_end 

90 

91 # print progress 

92 if print_points is None: 

93 print_points = [0] 

94 if isinstance(print_points, np.ndarray): 

95 print_points = print_points.tolist() 

96 self.print_points: deque = deque(print_points) 

97 self.next_printed_time_step = self.print_points.popleft() 

98 self.print_progress = print_progress 

99 

100 # initialize the temperature equation using a lambda 

101 self.dtemperature_dt_function = lambda *args, **keyword_args: None 

102 if _initialize_temperature_function: 

103 # this can be switched off because it always have to run at the end of all subclasses inits 

104 self.dtemperature_dt_function = self._init_temperature_equation() 

105 

106 def _init_temperature_equation(self) -> Callable: 

107 if isinstance(self.system_matrix, SparseSymbolicEvaluator): 

108 assert self.time_dependent_active 

109 system_matrix = lambda v: self.system_matrix.evaluate(v) 

110 else: 

111 if self.time_dependent_active: 

112 warnings.warn( 

113 "Function to calculate time dependent values passed, but no time dependent system matrix detected.\n" 

114 "The passed function has no effect." 

115 ) 

116 return lambda temperature: self.system_matrix @ temperature 

117 return lambda temperature, v: system_matrix(v) @ temperature 

118 

119 def __call__(self, t, temperature): 

120 """ 

121 The function the solve_ivp is going to solve. 

122 

123 Parameters 

124 ---------- 

125 t 

126 temperature 

127 

128 Notes 

129 ----- 

130 The input vector that is saved during the iteration of the solver at the t_eval 

131 

132 Returns 

133 ------- 

134 np.ndarray : 

135 The resulting vector of the temperature derivative. 

136 """ 

137 if self.print_progress: 

138 self._update_progress_print(t) 

139 temperature = temperature.reshape(-1, 1) 

140 

141 if self.time_dependent_active: 

142 time_dependent_values = self.time_dependent_function(t, temperature) 

143 dtemperature_dt = self.dtemperature_dt_function(temperature, time_dependent_values) 

144 else: 

145 dtemperature_dt = self.dtemperature_dt_function(temperature) 

146 

147 return dtemperature_dt.flatten() 

148 

149 def set_new_t_eval(self, new_t_eval): 

150 """ 

151 Change the current t_eval. 

152 

153 Parameters 

154 ---------- 

155 new_t_eval : array_like 

156 The new t_eval. 

157 """ 

158 # Currently only used in the child class 

159 pass 

160 

161 def _update_progress_print(self, t): 

162 if t >= self.next_printed_time_step: 

163 # prevent initialization print out (is not fail save, but okay) 

164 if t != self.batch_end or (self.print_points and t == self.print_points[0]): 

165 if len(self.print_points) > 0: 

166 self.next_printed_time_step = self.print_points.popleft() 

167 self.print_out_progress(t) 

168 else: 

169 self.next_printed_time_step = np.inf 

170 # deactivate printing for better performance 

171 self.print_progress = False 

172 

173 def print_out_progress(self, t): 

174 """ 

175 Prints a formatted time stamp. 

176 

177 Parameters 

178 ---------- 

179 t : float | int 

180 The time to format/print. 

181 """ 

182 print(f"Progress: t = {self.format_t(t)}") 

183 

184 @staticmethod 

185 def format_t(t): 

186 days = int(t) // 86400 

187 hours = (int(t) % 86400) // 3600 

188 minutes = (int(t) % 3600) // 60 

189 seconds = int(t) % 60 

190 return ( 

191 f"{days:>4} days, {hours:02}:{minutes:02}:{seconds:02} - current time: " 

192 f"{datetime.now().strftime('%d %H:%M:%S')}" 

193 ) 

194 

195 

class InhomogeneousSystemHandler(HomogeneousSystemHandler):
    def __init__(
        self,
        system_matrix,
        input_matrix,
        input_vector,
        rc_solution,
        functions_list,
        kwargs_functions: dict,
        t_eval=None,
        time_dependent_function=None,
        use_parallelization=False,
        core_count=cpu_count(),
        print_progress=True,
        print_points=3600,
        settings: Settings = initial_settings,
        first_time: float | int = 0,
        **kwargs,
    ):
        """
        System Handler for the inhomogeneous system ``dT/dt = A(v) @ T + B(v) @ u``, called by ``solve_ivp``.

        Parameters
        ----------
        system_matrix : spmatrix | sparray | SparseSymbolicEvaluator
            The system matrix ``A``. A `SparseSymbolicEvaluator` is evaluated with the time dependent values.
        input_matrix : spmatrix | sparray | SparseSymbolicEvaluator
            The input matrix ``B``. A `SparseSymbolicEvaluator` is evaluated with the time dependent values.
        input_vector : np.ndarray
            The initial input vector ``u``. It is reshaped to a column vector.
        rc_solution : RCSolution
            Receives the input vector at every evaluation time (``append_to_input``/``delete_last_input``).
        functions_list : list[Callable]
            One callable per input entry, called as ``f(t, temperature, input_vector, **kwargs)``; the results form
            the new input vector.
        kwargs_functions : dict
            Maps a keyword name to ``fun(t, temperature, input_vector)``. The results are passed as keyword
            arguments to every function in `functions_list`.
        t_eval : array_like, optional
            Times at which the input vector is saved. If None, it is saved at every call.
        time_dependent_function : Callable, optional
            Called as ``time_dependent_function(t, temperature, input_vector)``; returns the time dependent values.
        use_parallelization : bool, optional
            Currently not used.
        core_count : int, optional
            Number of workers for the parallel input updates; clamped to ``[1, cpu_count()]``.
        print_progress, print_points, settings
            See `HomogeneousSystemHandler`.
        first_time : float | int, optional
            Solver time before the first call (used to detect solver rewinds).
        kwargs : dict, optional
            Not used.
        """
        has_t_eval = t_eval is not None and len(t_eval) > 0
        super().__init__(
            system_matrix,
            rc_solution,
            settings,
            time_dependent_function=time_dependent_function,
            print_progress=print_progress,
            print_points=print_points,
            # t_eval is optional: without it the base class falls back to an unknown batch end.
            batch_end=t_eval[-1] if has_t_eval else None,
            _initialize_temperature_function=False,
        )
        self.input_matrix = input_matrix
        self.input_vector = input_vector.reshape(-1, 1)
        self.functions_list = functions_list

        self.after_iteration = np.inf
        self.next_eval_time = 0
        self.t_eval_iter = iter(())
        self.set_new_t_eval(t_eval)

        self.last_t = first_time
        self.last_eval_time = 0
        # Pending eval time that is restored after a solver rewind (set in __call__).
        self.current_eval_time = 0
        self.use_current_eval_time = False

        # Parallel variants (_update_input_vector_threads/_processes) exist but are not selected yet:
        # the process variant fails because the functions cannot be pickled. use_parallelization is ignored.
        self.input_update_function = self._update_input_vector
        self.core_count = max(1, min(core_count, cpu_count()))
        self.kwargs_functions: dict = kwargs_functions

        # initialize the temperature equation using a lambda
        self.dtemperature_dt_function = self._init_temperature_equation()

    def _init_temperature_equation(self) -> Callable:
        """Return ``f(temperature, input_vector, time_dependent_values) -> A(v) @ T + B(v) @ u``."""
        if isinstance(self.system_matrix, SparseSymbolicEvaluator):
            system_matrix_fun = lambda v: self.system_matrix.evaluate(v)
        else:
            system_matrix_fun = lambda *args, **kwargs: self.system_matrix
        if isinstance(self.input_matrix, SparseSymbolicEvaluator):
            input_matrix_fun = lambda v: self.input_matrix.evaluate(v)
        else:
            input_matrix_fun = lambda *args, **kwargs: self.input_matrix
        return lambda temperature, input_v, v: system_matrix_fun(v) @ temperature + input_matrix_fun(v) @ input_v

    def set_new_t_eval(self, new_t_eval):
        """
        Change the current t_eval.

        Parameters
        ----------
        new_t_eval : array_like | None
            The new t_eval. If None, the input vector is saved at every call.
        """
        if new_t_eval is None:
            self.t_eval_iter = iter([-np.inf])
            self.after_iteration = -np.inf
        else:
            self.t_eval_iter = iter(new_t_eval)
            self.after_iteration = np.inf
        self.next_eval_time = next(self.t_eval_iter, self.after_iteration)
        self.use_current_eval_time = False

    def calculate_kwargs(self, tau, temp_vector, _input_vector, **kwargs):
        """Evaluate every function in ``kwargs_functions`` and return the results keyed by name."""
        return {name: fun(tau, temp_vector, _input_vector, **kwargs) for name, fun in self.kwargs_functions.items()}

    def _update_input_vector_threads(self, t, temperature, **kwargs):
        """Thread-parallel variant of `_update_input_vector` (currently not selected)."""
        with ThreadPoolExecutor(max_workers=self.core_count) as executor:
            results = list(executor.map(lambda f: f(t, temperature, self.input_vector, **kwargs), self.functions_list))
        self.input_vector = np.array(results).reshape(-1, 1)

    @staticmethod
    def _worker_call(args):
        """Unpack a work item for the process pool and call the input function."""
        f, t, temperature, input_vector, kwargs = args
        return f(t, temperature, input_vector, **kwargs)

    def _update_input_vector_processes(self, t, temperature, **kwargs):
        """Process-parallel variant of `_update_input_vector` (currently not selected; requires picklable functions)."""
        args_iter = [(f, t, temperature, self.input_vector.copy(), kwargs) for f in self.functions_list]
        with ProcessPoolExecutor(max_workers=self.core_count) as executor:
            results = list(executor.map(self._worker_call, args_iter))
        self.input_vector = np.array(results).reshape(-1, 1)

    def _update_input_vector(self, t, temperature, **kwargs):
        """Recompute the input vector by calling every function in ``functions_list``."""
        # TODO: Speedup by creating one big lambda function returning a vector instead of looping over all
        self.input_vector = np.array(
            [f(t, temperature, self.input_vector, **kwargs) for f in self.functions_list]
        ).reshape(-1, 1)

    def __call__(self, t, temperature: np.ndarray):
        """
        The function the solve_ivp is going to solve.

        Parameters
        ----------
        t : float
            The current solver time.
        temperature : np.ndarray
            The current state (temperature) vector.

        Notes
        -----
        The input vector is saved to ``rc_solution`` the first time the solver reaches each t_eval entry. If the
        solver steps back behind the last saved eval time, that saved input is deleted and saved again later.

        Returns
        -------
        np.ndarray :
            The resulting vector of the temperature derivative (flattened).
        """
        if t < self.last_t:
            if t == self.batch_end or (self.last_t >= self.last_eval_time > t):
                # Delete the last result because the solver is in initialization or one time step was saved to early.
                if not self.use_current_eval_time:
                    # Remember the pending eval time so it can be restored once the reverted one was saved again.
                    self.current_eval_time = self.next_eval_time
                # Revert the next_eval_time to the last value.
                self.next_eval_time = self.last_eval_time
                self.use_current_eval_time = True
                self.rc_solution.delete_last_input()
        if self.print_progress:
            self._update_progress_print(t)
        temperature = temperature.reshape(-1, 1)

        kwargs = self.calculate_kwargs(t, temperature, self.input_vector)
        self.input_update_function(t, temperature, **kwargs)
        time_dependent_values = self.time_dependent_function(t, temperature, self.input_vector)

        # Check if the input vector has to be saved.
        if t >= self.next_eval_time:
            # NOTE: This is not perfect, because the input vector of the next time step is saved for the result of
            # the last time step. However, the solver will perform such small iteration steps that this will not
            # become a problem.
            self.rc_solution.append_to_input(self.input_vector)
            self.last_eval_time = self.next_eval_time
            if self.use_current_eval_time:
                self.next_eval_time = self.current_eval_time
                self.use_current_eval_time = False
            else:
                self.next_eval_time = next(self.t_eval_iter, self.after_iteration)

        dtemperature_dt = self.dtemperature_dt_function(temperature, self.input_vector, time_dependent_values)

        self.last_t = t
        return dtemperature_dt.flatten()

class SolveIVPHandler:
    def __init__(
        self,
        system_handler: HomogeneousSystemHandler | InhomogeneousSystemHandler,
        max_saved_steps=None,
        method=None,
        max_step=None,
        rtol=None,
        atol=None,
        save_interval=None,
        save_path=None,
        save_prefix="",
        minimize_ram_usage=None,
        solve_settings=None,
        **kwargs,
    ):
        """
        Handler of the solving process with solve_ivp.

        Every parameter that is None is taken from `solve_settings`.

        Parameters
        ----------
        system_handler : HomogeneousSystemHandler | InhomogeneousSystemHandler
            The system handler that holds the function to solve in __call__().
        max_saved_steps : int, optional
            The maximum number of used seconds during one solve_ivp call.
            It defines the batch size in seconds.
            Using this prevents long solving time because the matrices become very big.
        method : str, optional
            The method to use to solve the system (see scipy.solve_ivp).
        max_step, rtol, atol : float, optional
            Passed to scipy.solve_ivp.
        save_interval : int, optional
            Number of batches with new results after which an intermediate save is written.
        save_path : str, optional
            Folder for the (intermediate) pickle files. Defaults to ``system_handler.settings.save_folder_path``.
            If None, nothing is saved.
        save_prefix : str, optional
            This is the beginning of the name of the pickle file that is saved during the solving.
        minimize_ram_usage : bool, optional
            If True, the solution is deleted if saved to minimize the RAM usage during runtime.
            Requires a save path.
        solve_settings : SolveSettings, optional
            Source of the default values. A new ``SolveSettings()`` if not given.
        kwargs
            Passed to scipy.solve_ivp.
        """
        if solve_settings is None:
            solve_settings = SolveSettings()
        self.solve_settings = solve_settings
        max_saved_steps, method, max_step, rtol, atol, save_interval, minimize_ram_usage = self.set_initial_values(
            max_saved_steps=max_saved_steps,
            method=method,
            max_step=max_step,
            rtol=rtol,
            atol=atol,
            save_interval=save_interval,
            minimize_ram_usage=minimize_ram_usage,
        )
        self.system_handler: HomogeneousSystemHandler | InhomogeneousSystemHandler = system_handler
        self.max_saved_steps = int(max_saved_steps)
        self.method = method
        self.max_step = max_step
        self.rtol = rtol
        self.atol = atol

        self.save_interval = save_interval
        self.save_counter = 0
        if save_path is None:
            save_path = self.system_handler.settings.save_folder_path
        self.save_path = None if save_path is None else os.path.normpath(save_path)
        self.save_prefix = save_prefix

        if self.save_path is None and minimize_ram_usage:
            print("Minimize RAM usage deactivated because no save_path is declared in Settings.")
            print("This also deactivates every save during solving.")
            minimize_ram_usage = False
        self.minimize_ram_usage = minimize_ram_usage

        self.kwargs = kwargs

    def set_initial_values(self, **kwargs):
        """
        Replace every None value by the value of the same key in ``solve_settings.dict``.

        Returns
        -------
        list :
            The resolved values in the order the keyword arguments were passed.
        """
        settings_dict: dict = self.solve_settings.dict
        result = []
        for key, arg in kwargs.items():
            if arg is None:
                result.append(settings_dict[key])
            else:
                result.append(arg)
                print(f"Solve setting {key} is overwritten by manual/coded value: {arg}")
        return result

    def get_batches(self, t_span):
        """
        Split ``t_span`` into batch borders of at most ``max_saved_steps`` seconds.

        Returns
        -------
        np.ndarray :
            Sorted batch borders starting with ``t_span[0]`` and ending with ``t_span[1]``.
        """
        start, end = t_span
        splits = np.arange(start, end, self.max_saved_steps)
        # arange is empty for an empty span, so check the size before indexing
        if splits.size == 0 or splits[-1] != end:
            splits = np.append(splits, end)
        return splits

    def solve(
        self,
        t_span,
        y0,
        t_eval=None,
        continued_simulation: bool = False,
        expected_solution_size_mb=5000,
    ):
        """
        Runs the solve_ivp in batches.

        Parameters
        ----------
        t_span : tuple[int | float, int| float] :
            The start and end of the simulation in seconds.
            Usually it starts at 0: (0, end_seconds)
            See also: scipy.integrate.solve_ivp()
        y0 : np.ndarray | Iterable | float | int
            The initial state of the system.
            See also: scipy.integrate.solve_ivp()
        t_eval : np.ndarray | Iterable | float | int, optional
            Times at which to store the computed solution, must be sorted and lie within `t_span`\\.
            If None (default), use points selected by the solver.
            See also: scipy.integrate.solve_ivp()
        continued_simulation : bool, optional
            If True, the first value of the first batch is not kept, because the simulation is continued.
        expected_solution_size_mb : int | float, optional
            The expected solution size in Megabytes. Is used for RAM-Management: if not enough free memory is available,
            the method waits for 10 seconds and tries again for a total of 360 times before giving up on combining
            the intermediate solutions.
            The solution size depends on the size of the RC network and number of saved time steps.
        """
        batches = self.get_batches(t_span)
        if t_eval is not None:
            # Lists and scalars are documented as valid; the boolean masks below need an array.
            t_eval = np.atleast_1d(np.asarray(t_eval))

        list_t, list_y = [], []
        current_y0 = y0

        if self.save_path is not None:
            intermediate_save_prefix = os.path.join(self.save_path, self.save_prefix + f"_{int(t_span[-1])}_")
        else:
            intermediate_save_prefix = "dummy_path"

        for i in range(len(batches) - 1):
            batch_start, batch_end = float(batches[i]), float(batches[i + 1])
            # batch_start is only kept for the first batch of a fresh simulation; otherwise it duplicates the last
            # time step of the previous batch.
            keep_batch_start = i == 0 and not continued_simulation
            if t_eval is not None:
                lower_mask = (t_eval >= batch_start) if keep_batch_start else (t_eval > batch_start)
                t_eval_batch = t_eval[lower_mask & (t_eval <= batch_end)]
            else:
                t_eval_batch = None

            self.system_handler.set_new_t_eval(t_eval_batch)

            last_step_not_in_solution = False
            if t_eval_batch is not None and (t_eval_batch.size == 0 or t_eval_batch[-1] != batch_end):
                # The last step must be returned by solve_ivp anyway to pass it as new y0 in the next batch
                t_eval_batch = np.append(t_eval_batch, batch_end)
                last_step_not_in_solution = True

            # update end of batch (used for correct print out and saves)
            self.system_handler.batch_end = batch_end

            sol: Any = solve_ivp(
                fun=self.system_handler,
                t_span=(batch_start, batch_end),
                y0=current_y0,
                method=self.method,
                t_eval=t_eval_batch,
                max_step=self.max_step,
                rtol=self.rtol,
                atol=self.atol,
                **self.kwargs,
            )
            current_y0 = sol.y[:, -1]

            if last_step_not_in_solution:
                # delete last solution if not requested in settings (but it was needed for y0)
                # NOTE(review): only InhomogeneousSystemHandler appends inputs, and it received t_eval_batch before
                # batch_end was appended; confirm that delete_last_input removes the intended entry here.
                self.system_handler.rc_solution.delete_last_input()
                new_time_steps = sol.t[:-1]
                new_y_values = sol.y[:, :-1]
            elif t_eval_batch is None and not keep_batch_start:
                # Solver-selected steps always start with batch_start, which was already stored.
                new_time_steps = sol.t[1:]
                new_y_values = sol.y[:, 1:]
            else:
                new_time_steps = sol.t
                new_y_values = sol.y

            if new_time_steps.size != 0:
                list_t.append(new_time_steps)
                list_y.append(new_y_values)
                # increase counter only if new solution was added. Otherwise it saves the same solution / an empty one
                self.save_counter += 1

            if self.save_counter // self.save_interval > 0 or (
                batch_end == t_span[-1]
                and (self.system_handler.rc_solution.last_saved_timestep_index > 0 or self.minimize_ram_usage)
            ):
                save_path = f"{intermediate_save_prefix}{float(batch_end):09.0f}_s.pickle"
                if self.minimize_ram_usage:
                    # save the solution and delete it afterward
                    assert self.save_path is not None
                    if list_t:
                        # np.concatenate fails on an empty list (last batch without new time steps)
                        self.system_handler.rc_solution.save_to_file_only(
                            t=np.concatenate(list_t),
                            y=np.concatenate(list_y, axis=1).T,
                            path_with_name_and_ending=save_path,
                        )
                else:
                    self.system_handler.rc_solution.add_to_solution(list_t, list_y)
                    if self.save_path is not None:
                        self.system_handler.rc_solution.save_new_solution(save_path)

                list_y = []
                list_t = []
                self.save_counter = 0

        # Make print out for the last time step that else will be missed
        if self.system_handler.print_progress:
            if self.system_handler.next_printed_time_step <= t_span[-1]:
                self.system_handler.print_out_progress(t_span[-1])

        save_path = f"{intermediate_save_prefix}result.pickle"
        if self.minimize_ram_usage:
            # load solution in and save it as one big solution
            # it will load all intermediate saves because it will not find the requested path with "_result.pickle"
            # first check, if enough RAM is available
            assert self.save_path is not None
            required_ram_gb = expected_solution_size_mb / 1000
            max_loops = 360
            loop_counter = 0
            enough_ram = get_free_ram_gb() >= required_ram_gb
            while not enough_ram and loop_counter < max_loops:
                loop_counter += 1
                time.sleep(10)  # waiting for more free RAM
                # re-check after every wait so RAM freed during the last wait is not missed
                enough_ram = get_free_ram_gb() >= required_ram_gb
            if enough_ram:
                self.system_handler.rc_solution.load_solution(save_path)
                # Now save the accumulated result in one pickle with the "_result" ending
                self.system_handler.rc_solution.save_solution(save_path)
            else:
                print("Not enough RAM for over one hour. Incremental solutions are not combined.")
        else:
            self.system_handler.rc_solution.add_to_solution(list_t, list_y)
            if self.save_path is not None:
                self.system_handler.rc_solution.save_solution(save_path)