Coverage for pyrc \ core \ solver \ symbolic.py: 100%
37 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-13 16:59 +0200
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-13 16:59 +0200
1# -------------------------------------------------------------------------------
2# Copyright (C) 2026 Joel Kimmich, Tim Jourdan
3# ------------------------------------------------------------------------------
4# License
5# This file is part of PyRC, distributed under GPL-3.0-or-later.
6# ------------------------------------------------------------------------------
8from sympy import SparseMatrix, lambdify, Symbol
9import numpy as np
10from scipy.sparse import csr_matrix, coo_matrix
class SparseSymbolicEvaluator:
    """Repeatedly evaluate a sparse, partly symbolic matrix as a SciPy CSR matrix.

    The sparsity pattern and all constant entries are extracted once in
    ``__init__``. ``evaluate`` only recomputes the symbolic entries and writes
    them into the pre-built CSR matrix, which is reused between calls.
    """

    def __init__(self, semi_symbolic_matrix: SparseMatrix, time_symbols: list[Symbol]):
        """
        Parameters
        ----------
        semi_symbolic_matrix : SparseMatrix
            Sparse matrix whose stored entries are either numeric constants or
            expressions with free symbols. It must provide ``shape`` and
            ``todok()``.
        time_symbols : list[Symbol]
            Ordered list of time-dependent symbols. They become the positional
            arguments of the lambdified function, so presumably every free
            symbol of the matrix has to be listed here (not checked).
        """
        self.time_symbols: list[Symbol] = time_symbols
        self.shape = semi_symbolic_matrix.shape

        symbolic_expressions = []
        symbolic_rows, symbolic_cols = [], []
        constant_data, constant_rows, constant_cols = [], [], []

        # todok() already yields each stored entry together with its value,
        # so the matrix does not need to be indexed a second time.
        for (i, j), expr in semi_symbolic_matrix.todok().items():
            if getattr(expr, "free_symbols", None):
                symbolic_expressions.append(expr)
                symbolic_rows.append(i)
                symbolic_cols.append(j)
            else:
                constant_data.append(float(expr))
                constant_rows.append(i)
                constant_cols.append(j)

        self._n_symbolic = len(symbolic_expressions)
        n_constant = len(constant_data)

        if self._n_symbolic > 0:
            # One function returning all symbolic entries at once.
            self._symbolic_func = lambdify(time_symbols, symbolic_expressions, "numpy")

        all_rows = np.array(constant_rows + symbolic_rows, dtype=np.int32)
        all_cols = np.array(constant_cols + symbolic_cols, dtype=np.int32)

        # Layout of _all_data: constants first, symbolic entries afterwards.
        # Zero-initialised so the matrix never exposes uninitialised memory
        # before the first call to evaluate().
        self._all_data = np.zeros(n_constant + self._n_symbolic, dtype=np.float64)
        self._all_data[:n_constant] = constant_data
        self._symbolic_slice = slice(n_constant, n_constant + self._n_symbolic)

        # Permutation from the (constants, symbolic) layout to row-major order
        # (row is the primary key, column the secondary key).
        self._perm = np.lexsort((all_cols, all_rows))

        # Build the CSR arrays directly from that permutation, so the order of
        # the CSR data is by construction the order evaluate() writes in.
        row_counts = np.bincount(all_rows, minlength=int(self.shape[0]))
        indptr = np.concatenate(([0], np.cumsum(row_counts)))
        self._matrix = csr_matrix(
            (self._all_data[self._perm], all_cols[self._perm], indptr),
            shape=self.shape,
        )
        # Reference to the matrix' own data buffer for in-place updates.
        self._csr_data = self._matrix.data

    def evaluate(self, time_values=None) -> csr_matrix:
        """
        Parameters
        ----------
        time_values : array-like
            Values for time-dependent symbols in same order as time_symbols.
            May be omitted only if the matrix has no symbolic entries.

        Returns
        -------
        csr_matrix
            Evaluated sparse matrix. The same matrix object is returned on
            every call and its data is overwritten in place; copy it if an
            earlier result has to be kept.

        Raises
        ------
        TypeError
            If the matrix has symbolic entries and ``time_values`` is None.
        """
        if self._n_symbolic > 0:
            if time_values is None:
                raise TypeError(
                    "time_values must be provided because the matrix contains symbolic entries"
                )
            self._all_data[self._symbolic_slice] = self._symbolic_func(*time_values)
            # Constants are unchanged, but one permuted copy refreshes all entries.
            self._csr_data[:] = self._all_data[self._perm]
        return self._matrix