Coverage for pyrc \ core \ solver \ handler.py: 70%
253 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-29 14:14 +0200
1# -------------------------------------------------------------------------------
2# Copyright (C) 2026 Joel Kimmich, Tim Jourdan
3# ------------------------------------------------------------------------------
4# License
5# This file is part of PyRC, distributed under GPL-3.0-or-later.
6# ------------------------------------------------------------------------------
8from __future__ import annotations
10import os.path
11import time
12import warnings
13from collections import deque
14from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
15from datetime import datetime
16from multiprocessing import cpu_count
17from typing import Any, Callable
19import numpy as np
20from scipy.integrate import solve_ivp
21from scipy.sparse import spmatrix, sparray
23from pyrc.core.settings import initial_settings, Settings, SolveSettings
24from pyrc.core.solver.symbolic import SparseSymbolicEvaluator
25from pyrc.core.components.templates import RCSolution
26from pyrc.tools.science import get_free_ram_gb
29class HomogeneousSystemHandler:
30 def __init__(
31 self,
32 system_matrix: SparseSymbolicEvaluator | spmatrix | sparray,
33 rc_solution: RCSolution,
34 settings: Settings,
35 print_progress=True,
36 print_points: list | np.ndarray = None,
37 batch_end=None,
38 time_dependent_function=None,
39 _initialize_temperature_function=True,
40 **kwargs,
41 ):
42 """
43 System Handler that is called ("as function") by ``solve_ivp``.
45 Parameters
46 ----------
47 system_matrix : spmatrix | sparray | SparseSymbolicEvaluator
48 The system matrix.
49 If constant, the type must be a sparse scipy matrix.
50 If time dependent symbols are contained, the system matrix must be given in a `SparseSymbolicEvaluator`
51 object.
52 rc_solution
53 settings
54 print_progress
55 print_points
56 batch_end
57 time_dependent_function : Callable, optional
58 A function that calculates all time dependent variables within the time step and returns them in the same order as
59 self.get_time_dependent_symbols().
60 It gets parameters like this:\n
61 ``time_dependent_function(time, temperature_vector)`` \n
62 or\n
63 ``time_dependent_function(time, temperature_vector, input_vector)``\\.\n
64 This function is required if time dependent symbols exist.
65 It must return an iterable (e.g. list).
66 To not run into Errors just use ``*args``\\, ``**kwargs`` at the end in case more values are passed then
67 needed.
68 _initialize_temperature_function : bool, optional
69 If ``False``, the ``dtemperature_dt_function`` attribute is not set using the method `_init_temperature_equation()`
70 Used in subclasses to prevent an early call of attributes that are created in the subclasses later on.
71 kwargs : dict, optional
72 Not used: Just to not raise an error if too much information is passed.
73 """
74 self.system_matrix: spmatrix | sparray | SparseSymbolicEvaluator = system_matrix
75 self.rc_solution = rc_solution
76 self.settings = settings
78 self.time_dependent_function: Callable = time_dependent_function # must return the value(s) as iterable
79 self.time_dependent_active: bool = True
80 if time_dependent_function is None:
81 self.time_dependent_active = False
82 self.time_dependent_function = lambda *args, **keyword_args: []
84 if print_progress and batch_end is None:
85 import warnings
87 warnings.warn("Print progress might not work as expected because batch_end value is not given.")
88 batch_end = np.inf
89 self.batch_end = batch_end
91 # print progress
92 if print_points is None:
93 print_points = [0]
94 if isinstance(print_points, np.ndarray):
95 print_points = print_points.tolist()
96 self.print_points: deque = deque(print_points)
97 self.next_printed_time_step = self.print_points.popleft()
98 self.print_progress = print_progress
100 # initialize the temperature equation using a lambda
101 self.dtemperature_dt_function = lambda *args, **keyword_args: None
102 if _initialize_temperature_function:
103 # this can be switched off because it always have to run at the end of all subclasses inits
104 self.dtemperature_dt_function = self._init_temperature_equation()
106 def _init_temperature_equation(self) -> Callable:
107 if isinstance(self.system_matrix, SparseSymbolicEvaluator):
108 assert self.time_dependent_active
109 system_matrix = lambda v: self.system_matrix.evaluate(v)
110 else:
111 if self.time_dependent_active:
112 warnings.warn(
113 "Function to calculate time dependent values passed, but no time dependent system matrix detected.\n"
114 "The passed function has no effect."
115 )
116 return lambda temperature: self.system_matrix @ temperature
117 return lambda temperature, v: system_matrix(v) @ temperature
119 def __call__(self, t, temperature):
120 """
121 The function the solve_ivp is going to solve.
123 Parameters
124 ----------
125 t
126 temperature
128 Notes
129 -----
130 The input vector that is saved during the iteration of the solver at the t_eval
132 Returns
133 -------
134 np.ndarray :
135 The resulting vector of the temperature derivative.
136 """
137 if self.print_progress:
138 self._update_progress_print(t)
139 temperature = temperature.reshape(-1, 1)
141 if self.time_dependent_active:
142 time_dependent_values = self.time_dependent_function(t, temperature)
143 dtemperature_dt = self.dtemperature_dt_function(temperature, time_dependent_values)
144 else:
145 dtemperature_dt = self.dtemperature_dt_function(temperature)
147 return dtemperature_dt.flatten()
149 def set_new_t_eval(self, new_t_eval):
150 """
151 Change the current t_eval.
153 Parameters
154 ----------
155 new_t_eval : array_like
156 The new t_eval.
157 """
158 # Currently only used in the child class
159 pass
161 def _update_progress_print(self, t):
162 if t >= self.next_printed_time_step:
163 # prevent initialization print out (is not fail save, but okay)
164 if t != self.batch_end or (self.print_points and t == self.print_points[0]):
165 if len(self.print_points) > 0:
166 self.next_printed_time_step = self.print_points.popleft()
167 self.print_out_progress(t)
168 else:
169 self.next_printed_time_step = np.inf
170 # deactivate printing for better performance
171 self.print_progress = False
173 def print_out_progress(self, t):
174 """
175 Prints a formatted time stamp.
177 Parameters
178 ----------
179 t : float | int
180 The time to format/print.
181 """
182 print(f"Progress: t = {self.format_t(t)}")
184 @staticmethod
185 def format_t(t):
186 days = int(t) // 86400
187 hours = (int(t) % 86400) // 3600
188 minutes = (int(t) % 3600) // 60
189 seconds = int(t) % 60
190 return (
191 f"{days:>4} days, {hours:02}:{minutes:02}:{seconds:02} - current time: "
192 f"{datetime.now().strftime('%d %H:%M:%S')}"
193 )
class InhomogeneousSystemHandler(HomogeneousSystemHandler):
    def __init__(
        self,
        system_matrix,
        input_matrix,
        input_vector,
        rc_solution,
        functions_list,
        kwargs_functions: dict,
        t_eval=None,
        time_dependent_function=None,
        use_parallelization=False,
        core_count=cpu_count(),
        print_progress=True,
        print_points=3600,
        settings: Settings = initial_settings,
        first_time: float | int = 0,
        **kwargs,
    ):
        """
        System handler for ``dT/dt = A T + B u`` whose input vector ``u`` is recomputed on every call.

        Parameters
        ----------
        system_matrix : spmatrix | sparray | SparseSymbolicEvaluator
            The system matrix ``A`` (symbolic matrices are evaluated with the time dependent values).
        input_matrix : spmatrix | sparray | SparseSymbolicEvaluator
            The input matrix ``B`` (symbolic matrices are evaluated with the time dependent values).
        input_vector : np.ndarray
            The initial input vector ``u``; it is stored as a column vector.
        rc_solution : RCSolution
            The solution object; the input vector is appended to it at the requested ``t_eval`` times.
        functions_list : list[Callable]
            One function per input entry, called as ``f(t, temperature, input_vector, **kwargs)``.
        kwargs_functions : dict[str, Callable]
            Functions ``f(t, temperature, input_vector)`` whose results are passed as keyword arguments to every
            function of ``functions_list``.
        t_eval : array_like, optional
            The times at which the input vector is saved. If None, it is saved on every call.
        time_dependent_function : Callable, optional
            See `HomogeneousSystemHandler`. Called as ``f(t, temperature, input_vector)``.
        use_parallelization : bool, optional
            Currently ignored (the parallel input updates are not in use).
        core_count : int, optional
            Number of workers for the parallel input updates; clamped to ``[1, cpu_count()]``.
        print_progress : bool, optional
            If True, the progress is printed.
        print_points : list | np.ndarray | float | int, optional
            Times at which progress is printed. A scalar is passed on as a single print point.
        settings : Settings, optional
            The simulation settings.
        first_time : float | int, optional
            The solver time of the first call (used to detect solver steps backwards in time).
        kwargs : dict, optional
            Not used.
        """
        if print_points is not None:
            # The default is a scalar; pass a flat array, which the base class accepts.
            print_points = np.ravel(print_points)
        # t_eval is optional (see set_new_t_eval), so it cannot be indexed unconditionally.
        batch_end = None if t_eval is None or len(t_eval) == 0 else t_eval[-1]
        super().__init__(
            system_matrix,
            rc_solution,
            settings,
            time_dependent_function=time_dependent_function,
            print_progress=print_progress,
            print_points=print_points,
            batch_end=batch_end,
            _initialize_temperature_function=False,
        )
        self.input_matrix = input_matrix
        self.input_vector = input_vector.reshape(-1, 1)
        self.functions_list = functions_list

        self.after_iteration = np.inf
        self.next_eval_time = 0
        self.t_eval_iter = iter(())
        self.set_new_t_eval(t_eval)

        self.last_t = first_time
        # Time of the last saved input vector (restored when the solver steps back behind it).
        self.last_eval_time = 0
        # Pending next_eval_time that is resumed after the input at last_eval_time was saved again.
        self.current_eval_time = 0
        self.use_current_eval_time = False

        # NOTE: use_parallelization is currently ignored. The thread/process updaters below are not used
        # because the input functions cannot be pickled for multiprocessing.
        self.input_update_function = self._update_input_vector
        self.core_count = max(1, min(core_count, cpu_count()))
        self.kwargs_functions: dict = kwargs_functions

        # initialize the temperature equation using a lambda
        self.dtemperature_dt_function = self._init_temperature_equation()

    def _init_temperature_equation(self) -> Callable:
        """
        Build the right-hand side ``f(temperature, input_v, v) = A(v) @ temperature + B(v) @ input_v``.

        Constant (non-symbolic) matrices ignore the time dependent values ``v``.
        """
        if isinstance(self.system_matrix, SparseSymbolicEvaluator):
            system_matrix_fun = lambda v: self.system_matrix.evaluate(v)
        else:
            system_matrix_fun = lambda *args, **kwargs: self.system_matrix
        if isinstance(self.input_matrix, SparseSymbolicEvaluator):
            input_matrix_fun = lambda v: self.input_matrix.evaluate(v)
        else:
            input_matrix_fun = lambda *args, **kwargs: self.input_matrix
        return lambda temperature, input_v, v: system_matrix_fun(v) @ temperature + input_matrix_fun(v) @ input_v

    def set_new_t_eval(self, new_t_eval):
        """
        Change the current t_eval.

        Parameters
        ----------
        new_t_eval : array_like | None
            The new t_eval. If None, the input vector is saved on every call.
        """
        if new_t_eval is None:
            self.t_eval_iter = iter([-np.inf])
            self.after_iteration = -np.inf
        else:
            self.t_eval_iter = iter(new_t_eval)
            self.after_iteration = np.inf
        self.next_eval_time = next(self.t_eval_iter, self.after_iteration)
        self.use_current_eval_time = False

    def calculate_kwargs(self, tau, temp_vector, _input_vector, **kwargs):
        """Evaluate every function of ``kwargs_functions`` and return the results by name."""
        return {name: fun(tau, temp_vector, _input_vector, **kwargs) for name, fun in self.kwargs_functions.items()}

    def _update_input_vector_threads(self, t, temperature, **kwargs):
        """Recompute the input vector using a thread pool with ``core_count`` workers."""
        with ThreadPoolExecutor(max_workers=self.core_count) as executor:
            results = list(executor.map(lambda f: f(t, temperature, self.input_vector, **kwargs), self.functions_list))
        self.input_vector = np.array(results).reshape(-1, 1)

    @staticmethod
    def _worker_call(args):
        """Unpack one work item of `_update_input_vector_processes` and call its function."""
        f, t, temperature, input_vector, kwargs = args
        return f(t, temperature, input_vector, **kwargs)

    def _update_input_vector_processes(self, t, temperature, **kwargs):
        """Recompute the input vector using a process pool (requires picklable functions)."""
        args_iter = [(f, t, temperature, self.input_vector.copy(), kwargs) for f in self.functions_list]
        with ProcessPoolExecutor(max_workers=self.core_count) as executor:
            results = list(executor.map(self._worker_call, args_iter))
        self.input_vector = np.array(results).reshape(-1, 1)

    def _update_input_vector(self, t, temperature, **kwargs):
        """Recompute the input vector sequentially (one function per entry)."""
        # TODO: Speedup by creating one big lambda function returning a vector instead of looping over all
        self.input_vector = np.array(
            [f(t, temperature, self.input_vector, **kwargs) for f in self.functions_list]
        ).reshape(-1, 1)

    def __call__(self, t, temperature: np.ndarray):
        """
        The function the solve_ivp is going to solve.

        Parameters
        ----------
        t : float
            The current solver time.
        temperature : np.ndarray
            The current state (temperature) vector.

        Notes
        -----
        The input vector is saved to ``rc_solution`` during the iteration of the solver at the t_eval times.

        Returns
        -------
        np.ndarray :
            The resulting (flat) vector of the temperature derivative.
        """
        # Only revert if an input was saved since the last revert; otherwise an older, valid input would be deleted.
        if t < self.last_t and not self.use_current_eval_time:
            if t == self.batch_end or (self.last_t >= self.last_eval_time > t):
                # Delete the last result because the solver is in initialization or one time step was saved too
                # early. Revert the next_eval_time to the last value and remember the pending one to resume it.
                self.current_eval_time = self.next_eval_time
                self.next_eval_time = self.last_eval_time
                self.use_current_eval_time = True
                self.rc_solution.delete_last_input()
        if self.print_progress:
            self._update_progress_print(t)
        temperature = temperature.reshape(-1, 1)

        kwargs = self.calculate_kwargs(t, temperature, self.input_vector)

        self.input_update_function(t, temperature, **kwargs)

        time_dependent_values = self.time_dependent_function(t, temperature, self.input_vector)

        # Check if the input vector has to be saved.
        if t >= self.next_eval_time:
            # NOTE: This is not perfect, because the input vector of the next time step is saved for the result of
            # the last time step. However, the solver will perform such small iteration steps that this will not
            # become a problem.
            self.rc_solution.append_to_input(self.input_vector)
            self.last_eval_time = self.next_eval_time
            if self.use_current_eval_time:
                self.next_eval_time = self.current_eval_time
                self.use_current_eval_time = False
            else:
                self.next_eval_time = next(self.t_eval_iter, self.after_iteration)

        dtemperature_dt = self.dtemperature_dt_function(temperature, self.input_vector, time_dependent_values)

        self.last_t = t
        return dtemperature_dt.flatten()
class SolveIVPHandler:
    def __init__(
        self,
        system_handler: HomogeneousSystemHandler | InhomogeneousSystemHandler,
        max_saved_steps=None,
        method=None,
        max_step=None,
        rtol=None,
        atol=None,
        save_interval=None,
        save_path=None,
        save_prefix="",
        minimize_ram_usage=None,
        solve_settings=None,
        **kwargs,
    ):
        """
        Handler of the solving process with solve_ivp.

        Solver parameters that are ``None`` are taken from ``solve_settings`` (see `set_initial_values`).

        Parameters
        ----------
        system_handler : HomogeneousSystemHandler | InhomogeneousSystemHandler
            The system handler that holds the function to solve in __call__().
        max_saved_steps : int, optional
            The maximum number of used seconds during one solve_ivp call.
            It defines the batch size in seconds.
            Using this prevents long solving time because the matrices become very big.
        method : str, optional
            The method to use to solve the system (see scipy.solve_ivp).
        max_step : float, optional
            Passed to solve_ivp.
        rtol : float, optional
            Passed to solve_ivp.
        atol : float, optional
            Passed to solve_ivp.
        save_interval : int, optional
            Number of batches that added new time steps after which an intermediate save is made.
        save_path : str, optional
            Folder for the saves. Defaults to ``system_handler.settings.save_folder_path``.
            If it is None, nothing is saved during solving.
        save_prefix : str, optional
            This is the beginning of the name of the pickle file that is saved during the solving.
        minimize_ram_usage : bool, optional
            If True, the solution is deleted if saved to minimize the RAM usage during runtime.
        solve_settings : SolveSettings, optional
            Default values of the solver parameters. A new ``SolveSettings()`` is used if None.
        kwargs
            Passed on to solve_ivp.
        """
        if solve_settings is None:
            solve_settings = SolveSettings()
        self.solve_settings = solve_settings
        max_saved_steps, method, max_step, rtol, atol, save_interval, minimize_ram_usage = self.set_initial_values(
            max_saved_steps=max_saved_steps,
            method=method,
            max_step=max_step,
            rtol=rtol,
            atol=atol,
            save_interval=save_interval,
            minimize_ram_usage=minimize_ram_usage,
        )
        self.system_handler: HomogeneousSystemHandler | InhomogeneousSystemHandler = system_handler
        self.max_saved_steps = int(max_saved_steps)
        self.method = method
        self.max_step = max_step
        self.rtol = rtol
        self.atol = atol

        self.save_interval = save_interval
        self.save_counter = 0
        if save_path is None:
            save_path = self.system_handler.settings.save_folder_path
        if save_path is None:
            self.save_path = None
        else:
            self.save_path = os.path.normpath(save_path)
        self.save_prefix = save_prefix

        if self.save_path is None and minimize_ram_usage:
            print("Minimize RAM usage deactivated because no save_path is declared in Settings.")
            print("This also deactivates every save during solving.")
            minimize_ram_usage = False
        self.minimize_ram_usage = minimize_ram_usage

        self.kwargs = kwargs

    def set_initial_values(self, **kwargs):
        """
        Replace every ``None`` keyword value by the value of the same key in ``self.solve_settings.dict``.

        Values that are not None are kept and reported with a print.

        Returns
        -------
        list :
            The resolved values in the order the keyword arguments were passed.
        """
        settings_dict: dict = self.solve_settings.dict
        result = []
        for key, value in kwargs.items():
            if value is None:
                result.append(settings_dict[key])
            else:
                result.append(value)
                print(f"Solve setting {key} is overwritten by manual/coded value: {value}")
        return result

    def get_batches(self, t_span):
        """
        Split ``t_span`` into batch boundaries that are ``max_saved_steps`` apart.

        Returns
        -------
        np.ndarray :
            The boundaries, starting with ``t_span[0]`` and ending with ``t_span[1]``.
            For an empty span (start == end) only the start is returned, i.e. there is no batch.
        """
        start, end = t_span
        splits = np.arange(start, end, self.max_saved_steps)
        if splits.size == 0:
            # arange is empty for start >= end; keep the start so splits[-1] is defined.
            splits = np.array([start])
        if splits[-1] != end:
            splits = np.append(splits, end)
        return splits

    @staticmethod
    def _t_eval_of_batch(t_eval, batch_start, batch_end, keep_start):
        """Return the t_eval entries inside the batch (None if t_eval is None); batch_start only if keep_start."""
        if t_eval is None:
            return None
        lower = (t_eval >= batch_start) if keep_start else (t_eval > batch_start)
        return t_eval[lower & (t_eval <= batch_end)]

    @staticmethod
    def _wait_for_free_ram(required_gb, max_loops=360, wait_seconds=10) -> bool:
        """Return True as soon as ``required_gb`` of RAM are free; give up after ``max_loops`` waits."""
        for _ in range(max_loops):
            if get_free_ram_gb() >= required_gb:
                return True
            time.sleep(wait_seconds)  # waiting for more free RAM
        return get_free_ram_gb() >= required_gb

    def solve(
        self,
        t_span,
        y0,
        t_eval=None,
        continued_simulation: bool = False,
        expected_solution_size_mb=5000,
    ):
        """
        Runs the solve_ivp in batches.

        Parameters
        ----------
        t_span : tuple[int | float, int| float] :
            The start and end of the simulation in seconds.
            Usually it starts at 0: (0, end_seconds)
            See also: scipy.integrate.solve_ivp()
        y0 : np.ndarray | Iterable | float | int
            The initial state of the system.
            See also: scipy.integrate.solve_ivp()
        t_eval : np.ndarray | Iterable | float | int, optional
            Times at which to store the computed solution, must be sorted and lie within `t_span`\\.
            If None (default), use points selected by the solver.
            See also: scipy.integrate.solve_ivp()
        continued_simulation : bool, optional
            If True, the first value of the first batch is not kept, because the simulation is continued.
        expected_solution_size_mb : int | float, optional
            The expected solution size in Megabytes. Is used for RAM-Management: if not enough free memory is
            available, the method waits for 10 seconds and tries again up to 360 times. If there is still not enough
            free memory, a message is printed and the intermediate saves are not combined.
            The solution size depends on the size of the RC network and number of saved time steps.

        Raises
        ------
        RuntimeError
            If solve_ivp fails in one of the batches.
        """
        batches = self.get_batches(t_span)
        if t_eval is not None:
            # Accept any array_like (also a scalar) so the boolean masks work.
            t_eval = np.atleast_1d(np.asarray(t_eval))

        list_t, list_y = [], []
        current_y0 = y0

        if self.save_path is not None:
            intermediate_save_prefix = os.path.join(self.save_path, self.save_prefix + f"_{int(t_span[-1])}_")
        else:
            intermediate_save_prefix = "dummy_path"

        for i in range(len(batches) - 1):
            batch_start, batch_end = float(batches[i]), float(batches[i + 1])
            # exclude the batch_start for every batch except the first one to prevent duplicates
            keep_start = i == 0 and not continued_simulation
            t_eval_batch = self._t_eval_of_batch(t_eval, batch_start, batch_end, keep_start)

            self.system_handler.set_new_t_eval(t_eval_batch)

            last_step_not_in_solution = False
            if t_eval_batch is not None and (t_eval_batch.size == 0 or t_eval_batch[-1] != batch_end):
                # The last step must be returned by solve_ivp anyway to pass it as new y0 in the next batch
                t_eval_batch = np.append(t_eval_batch, batch_end)
                last_step_not_in_solution = True

            # update end of batch (used for correct print out and saves)
            self.system_handler.batch_end = batch_end

            sol: Any = solve_ivp(
                fun=self.system_handler,
                t_span=(batch_start, batch_end),
                y0=current_y0,
                method=self.method,
                t_eval=t_eval_batch,
                max_step=self.max_step,
                rtol=self.rtol,
                atol=self.atol,
                **self.kwargs,
            )
            if not sol.success:
                # A failed batch does not end at batch_end, so its last state is no valid y0 for the next batch.
                raise RuntimeError(f"solve_ivp failed in the batch ({batch_start}, {batch_end}): {sol.message}")
            current_y0 = sol.y[:, -1]

            if last_step_not_in_solution:
                # delete last solution if not requested in settings (but it was needed for y0)
                self.system_handler.rc_solution.delete_last_input()
                new_time_steps = sol.t[:-1]
                new_y_values = sol.y[:, :-1]
            else:
                new_time_steps = sol.t
                new_y_values = sol.y
            if t_eval is None and not keep_start and new_time_steps.size != 0:
                # Without t_eval, solve_ivp always returns batch_start, which is already part of the solution.
                new_time_steps = new_time_steps[1:]
                new_y_values = new_y_values[:, 1:]

            if new_time_steps.size != 0:
                list_t.append(new_time_steps)
                list_y.append(new_y_values)
                # increase counter only if new solution was added. Otherwise it saves the same solution / an empty one
                self.save_counter += 1

            if self.save_counter // self.save_interval > 0 or (
                batch_end == t_span[-1]
                and (self.system_handler.rc_solution.last_saved_timestep_index > 0 or self.minimize_ram_usage)
            ):
                save_path = f"{intermediate_save_prefix}{float(batch_end):09.0f}_s.pickle"
                if self.minimize_ram_usage:
                    # save the solution and delete it afterward
                    assert self.save_path is not None
                    if list_t:  # np.concatenate cannot handle an empty list
                        self.system_handler.rc_solution.save_to_file_only(
                            t=np.concatenate(list_t),
                            y=np.concatenate(list_y, axis=1).T,
                            path_with_name_and_ending=save_path,
                        )
                else:
                    self.system_handler.rc_solution.add_to_solution(list_t, list_y)
                    if self.save_path is not None:
                        self.system_handler.rc_solution.save_new_solution(save_path)

                list_y = []
                list_t = []
                self.save_counter = 0

        # Make print out for the last time step that else will be missed
        if self.system_handler.print_progress:
            if self.system_handler.next_printed_time_step <= t_span[-1]:
                self.system_handler.print_out_progress(t_span[-1])

        save_path = f"{intermediate_save_prefix}result.pickle"
        if self.minimize_ram_usage:
            # load solution in and save it as one big solution
            # it will load all intermediate saves because it will not find the requested path with "_result.pickle"
            # first check, if enough RAM is available
            assert self.save_path is not None
            if self._wait_for_free_ram(expected_solution_size_mb / 1000):
                self.system_handler.rc_solution.load_solution(save_path)
                # Now save the accumulated result in one pickle with the "_result" ending
                self.system_handler.rc_solution.save_solution(save_path)
            else:
                print("Not enough RAM for over one hour. Incremental solutions are not combined.")
        else:
            self.system_handler.rc_solution.add_to_solution(list_t, list_y)
            if self.save_path is not None:
                self.system_handler.rc_solution.save_solution(save_path)