Source code for cxroots.paths

import functools
from math import pi
from typing import Optional, TypeVar, Union, overload

import numpy as np
import numpy.typing as npt
import scipy.integrate

from .types import AnalyticFunc, Color, IntegrationMethod
from .util import integrate_quad_complex

ComplexPathType = TypeVar("ComplexPathType", bound="ComplexPath")


[docs] class ComplexPath(object): """A base class for paths in the complex plane.""" def __init__(self): self._integral_cache = {} self._trap_cache = {} # If the path is created during subvision then this will be set to the path # that is oppositely oriented to this path. This is done so that we can look # up the _integral_cache for the reverse path self._reverse_path: Optional[ComplexPath] = None @overload def __call__(self, t: float) -> complex: ... @overload def __call__(self, t: npt.NDArray[np.float_]) -> npt.NDArray[np.complex_]: ... def __call__( self, t: Union[float, npt.NDArray[np.float_]] ) -> Union[complex, npt.NDArray[np.complex_]]: r""" The parameterization of the path in the varaible :math:`t\in[0,1]`. Parameters ---------- t : float A real number :math:`0\leq t \leq 1`. Returns ------- complex A point on the path in the complex plane. """ raise NotImplementedError("__call__ must be implemented in a subclass") @overload def dzdt(self, t: float) -> complex: ... @overload def dzdt(self, t: npt.NDArray[np.float_]) -> npt.NDArray[np.complex_]: ... def dzdt( self, t: Union[float, npt.NDArray[np.float_]] ) -> Union[complex, npt.NDArray[np.complex_]]: """ The derivative of the parameterised curve in the complex plane, z, with respect to the parameterization parameter, t. """ raise NotImplementedError("dzdt must be implemented in a subclass") def distance(self, z: complex) -> float: """ Distance from the point z to the closest point on the line. """ raise NotImplementedError("distance must be implemented in a subclass")
[docs] def trap_values( self, f: AnalyticFunc, k: int, use_cache: bool = True, ) -> npt.NDArray[np.complex128]: """ Compute or retrieve (if cached) the values of the functions f at :math:`2^k+1` points along the contour which are evenly spaced with respect to the parameterisation of the contour. Parameters ---------- f : function A function of a single complex variable. k : int Defines the number of points along the curve that f is to be evaluated at as :math:`2^k+1`. use_cache : bool, optional If True then use, if available, the results of any previous calls to this function for the same f and save any new results so that they can be reused later. Returns ------- :class:`numpy.ndarray` The values of f at :math:`2^k+1` points along the contour which are evenly spaced with respect to the parameterisation of the contour. """ if f in self._trap_cache.keys() and use_cache: vals = self._trap_cache[f] vals_k = int(np.log2(len(vals) - 1)) if vals_k == k: return vals elif vals_k > k: return vals[:: 2 ** (vals_k - k)] else: t = np.linspace(0, 1, 2**k + 1) vals = np.empty(2**k + 1, dtype=np.complex128) vals.fill(np.nan) vals[:: 2 ** (k - vals_k)] = self._trap_cache[f] vals[np.isnan(vals)] = f(self(t[np.isnan(vals)])) # cache values self._trap_cache[f] = vals return vals else: t = np.linspace(0, 1, 2**k + 1) vals = f(self(t)) if not isinstance(vals, np.ndarray): # Handle case when f does not return a vector, perhaps # because the function is actually a constant vals = vals * np.ones(len(t), dtype=np.complex128) if use_cache: self._trap_cache[f] = vals return vals
def trap_product( self, k: int, f: AnalyticFunc, df: Optional[AnalyticFunc] = None, phi: Optional[AnalyticFunc] = None, psi: Optional[AnalyticFunc] = None, ) -> complex: r""" Use Romberg integration to estimate the symmetric bilinear form used in (1.12) of [KB]_ using 2**k+1 samples .. math:: <\phi,\psi> = \frac{1}{2\pi i} \oint_C \phi(z)\psi(z)\frac{f'(z)}{f(z)} dz """ # compute/retrieve function evaluations f_val = self.trap_values(f, k) t = np.linspace(0, 1, 2**k + 1) dt = t[1] - t[0] if df is None: # approximate df/dz with finite difference dfdt = np.gradient(f_val, dt) df_val = dfdt / self.dzdt(t) else: df_val = self.trap_values(df, k) segment_integrand = df_val / f_val * self.dzdt(t) if phi is not None: segment_integrand = self.trap_values(phi, k) * segment_integrand if psi is not None: segment_integrand = self.trap_values(psi, k) * segment_integrand return scipy.integrate.romb(segment_integrand, dx=dt, axis=-1) / (2j * pi)
[docs] def plot( self, num_points: int = 100, linecolor: Color = "C0", linestyle: str = "-", ) -> None: """ Uses matplotlib to plot, but not show, the path as a 2D plot in the Complex plane. Parameters ---------- num_points : int, optional The number of points to use when plotting the path. linecolor : optional The colour of the plotted path, passed to the :func:`matplotlib.pyplot.plot` function as the keyword argument of 'color'. See the matplotlib tutorial on specifying `colours <https://matplotlib.org/stable/tutorials/colors/colors#>`_. linestyle : str, optional The line style of the plotted path, passed to the :func:`matplotlib.pyplot.plot` function as the keyword argument of 'linestyle'. The default corresponds to a solid line. See :meth:`matplotlib.lines.Line2D.set_linestyle` for other acceptable arguments. """ import matplotlib.pyplot as plt t = np.linspace(0, 1, num_points) path = self(t) plt.plot(path.real, path.imag, color=linecolor, linestyle=linestyle) plt.xlabel("Re[$z$]", size=16) plt.ylabel("Im[$z$]", size=16) plt.gca().set_aspect(1) # add arrow to indicate direction of path arrow_start = self(0.5) arrow_end = self(0.51) plt.annotate( "", (arrow_end.real, arrow_end.imag), (arrow_start.real, arrow_start.imag), arrowprops=dict(arrowstyle="->", fc=linecolor, ec=linecolor), )
[docs] def integrate( self, f: AnalyticFunc, abs_tol: float = 1.49e-08, rel_tol: float = 1.49e-08, div_max: int = 15, int_method: IntegrationMethod = "quad", ) -> complex: r""" Integrate the function f along the path. .. math:: \oint_C f(z) dz The value of the integral is cached and will be reused if the method is called with same arguments. Parameters ---------- f : function A function of a single complex variable. abs_tol : float, optional The absolute tolerance for the integration. rel_tol : float, optional The realative tolerance for the integration. div_max : int, optional If the Romberg integration method is used then div_max is the maximum number of divisions before the Romberg integration routine of a path exits. int_method : {'quad', 'romb'}, optional If 'quad' then :func:`scipy.integrate.quad` is used to compute the integral. If 'romb' then Romberg integraion, using :func:`scipy.integrate.romberg`, is used instead. Returns ------- complex The integral of the function f along the path. """ args = (f, abs_tol, rel_tol, div_max, int_method) if args in self._integral_cache.keys(): return self._integral_cache[args] if ( self._reverse_path is not None and args in self._reverse_path._integral_cache ): # if we have already computed the reverse of this path return -self._reverse_path._integral_cache[args] @functools.lru_cache(maxsize=None) def integrand(t: float) -> complex: return f(self(t)) * self.dzdt(t) if int_method == "romb": integral = scipy.integrate.romberg( integrand, 0, 1, tol=abs_tol, rtol=rel_tol, divmax=div_max, ) elif int_method == "quad": integral = integrate_quad_complex( integrand, 0, 1, epsabs=abs_tol, epsrel=rel_tol ) else: raise ValueError("int_method must be either 'romb' or 'quad'") if np.isnan(integral): raise RuntimeError( f"The integral along the segment {self} is NaN. This is most " "likely due to a root being on or very close to the path of " "integration." ) self._integral_cache[args] = integral return integral
[docs] def show(self, save_file: Optional[str] = None, **plot_kwargs) -> None: """ Shows the path as a 2D plot in the complex plane. Requires Matplotlib. Parameters ---------- save_file : str (optional) If given then the plot will be saved to disk with name 'save_file'. If save_file=None the plot is shown on-screen. **plot_kwargs Other key word args are passed to :meth:`~cxroots.Paths.ComplexPath.plot` """ import matplotlib.pyplot as plt self.plot(**plot_kwargs) if save_file is not None: plt.savefig(save_file, bbox_inches="tight") plt.close() else: plt.show()
[docs] class ComplexLine(ComplexPath): r""" A straight line :math:`z` in the complex plane from points a to b parameterised by ..math:: z(t) = a + (b-a)t, \quad 0\leq t \leq 1 """ def __init__(self, a: complex, b: complex): self.a, self.b = a, b super(ComplexLine, self).__init__() def __str__(self): return "ComplexLine from %.3f+%.3fi to %.3f+%.3fi" % ( self.a.real, self.a.imag, self.b.real, self.b.imag, ) @overload def __call__(self, t: float) -> complex: ... @overload def __call__(self, t: npt.NDArray[np.float_]) -> npt.NDArray[np.complex_]: ...
[docs] def __call__( self, t: Union[float, npt.NDArray[np.float_]] ) -> Union[complex, npt.NDArray[np.complex_]]: r""" The function :math:`z(t) = a + (b-a)t`. Parameters ---------- t : float A real number :math:`0\leq t \leq 1`. Returns ------- complex A point on the line in the complex plane. """ return self.a + t * (self.b - self.a)
@overload def dzdt(self, t: float) -> complex: ... @overload def dzdt(self, t: npt.NDArray[np.float_]) -> npt.NDArray[np.complex_]: ... def dzdt( self, t: Union[float, npt.NDArray[np.float_]] ) -> Union[complex, npt.NDArray[np.complex_]]: """ The derivative of the parameterised curve in the complex plane, z, with respect to the parameterization parameter, t. """ return self.b - self.a
[docs] def distance(self, z: complex) -> float: """ Distance from the point z to the closest point on the line. Parameters ---------- z : complex Returns ------- float The distance from z to the point on the line which is closest to z. """ # the projection of the point z onto the line a -> b is where # the parameter t is d = self.b - self.a t = ((z - self.a) * d.conjugate()).real / (d * d.conjugate()) # but the line segment only has 0 <= t <= 1 t = np.clip(t.real, 0, 1) # so the point on the line segment closest to z is c = self(t) return abs(c - z)
[docs] class ComplexArc(ComplexPath): r""" A circular arc :math:`z` with center z0, radius R, initial angle t0 and change of angle dt. The arc is parameterised by ..math:: z(t) = R e^{i(t0 + t dt)} + z0, \quad 0\leq t \leq 1 Parameters ---------- z0 : complex R : float t0 : float dt : float """ def __init__(self, z0: complex, R: float, t0: float, dt: float): # noqa: N803 self.z0, self.R, self.t0, self.dt = z0, R, t0, dt super(ComplexArc, self).__init__() def __str__(self): return "ComplexArc: z0=%.3f, R=%.3f, t0=%.3f, dt=%.3f" % ( self.z0, self.R, self.t0, self.dt, ) @overload def __call__(self, t: float) -> complex: ... @overload def __call__(self, t: npt.NDArray[np.float_]) -> npt.NDArray[np.complex_]: ...
[docs] def __call__( self, t: Union[float, npt.NDArray[np.float_]] ) -> Union[complex, npt.NDArray[np.complex_]]: r""" The function :math:`z(t) = R e^{i(t_0 + t dt)} + z_0`. Parameters ---------- t : float A real number :math:`0\leq t \leq 1`. Returns ------- complex A point on the arc in the complex plane. """ return self.R * np.exp(1j * (self.t0 + t * self.dt)) + self.z0
@overload def dzdt(self, t: float) -> complex: ... @overload def dzdt(self, t: npt.NDArray[np.float_]) -> npt.NDArray[np.complex_]: ... def dzdt( self, t: Union[float, npt.NDArray[np.float_]] ) -> Union[complex, npt.NDArray[np.complex_]]: """ The derivative of the parameterised curve in the complex plane, z, with respect to the parameterization parameter, t. """ return 1j * self.dt * self.R * np.exp(1j * (self.t0 + t * self.dt))
[docs] def distance(self, z: complex) -> float: """ Distance from the point z to the closest point on the arc. Parameters ---------- z : complex Returns ------- float The distance from z to the point on the arc which is closest to z. """ theta = np.angle(z - self.z0) # np.angle maps to (-pi,pi] theta = (theta - self.t0) % (2 * pi) + self.t0 # put theta in [t0,t0+2pi) if (self.dt > 0 and self.t0 < theta < self.t0 + self.dt) or ( self.dt < 0 and self.t0 + self.dt < theta - 2 * pi < self.t0 ): # the closest point to z lies on the arc return abs(self.R * np.exp(1j * theta) + self.z0 - z) else: # the closest point to z is one of the endpoints return min(abs(self(0) - z), abs(self(1) - z))