Source code for trinidi.resolution

""" Resolution Operator """

import numpy as np

from jax import device_put

from scico.linop import Convolve, LinearOperator

from trinidi.util import time2energy

class ResolutionOperator:
    """ResolutionOperator Class.

    Blends several convolution kernels along the last axis: each kernel is
    applied to the whole input, the results are weighted bin-by-bin with
    triangular weights ``W[k]`` and summed, and buffer bins at both ends are
    cropped. The input shape is ``projection_shape + (N_F,)`` and the output
    shape is ``projection_shape + (N_A,)``.
    """

    def __repr__(self):
        return f"""{type(self)}
    input_shape = {self.input_shape} = projection_shape + (N_F,)
    output_shape = {self.output_shape} = projection_shape + (N_A,)
    projection_shape = {self.projection_shape}
    N_F = {self.input_shape[-1]}
    N_A = {self.output_shape[-1]}
    """

    def __init__(self, output_shape, t_A, kernels=None):
        """Initialize a ResolutionOperator object.

        Args:
            output_shape: Output shape of operator, i.e. measurement shape.
                The last axis has size ``N_A``.
            t_A: Time-of-arrival array, or ``None``. When given, the matching
                time-of-flight array is computed and stored as ``self.t_F``
                (see :meth:`compute_t_F`).
            kernels (list of nd-arrays): list of convolution kernels.
                ``None`` results in identity operator. Each kernel must be
                non-negative and sum to 1.

        Raises:
            ValueError: If ``kernels`` is empty, or a kernel has negative
                entries, or a kernel does not sum to 1 (tolerance ``1e-3``).
        """
        # Accept any sequence (e.g. a list) so that shape concatenation works.
        self.output_shape = tuple(output_shape)
        self.t_A = t_A

        if kernels is None:
            kernels = [np.array([1])]
        # Normalize to arrays so list-valued kernels support ``< 0`` and ``.size``.
        self.kernels = [np.asarray(k) for k in kernels]

        if not self.kernels:
            raise ValueError("Number of kernels must be at least 1.")
        for k in self.kernels:
            if np.any(k < 0):
                raise ValueError("Kernels must be non-negative")
            if np.abs(np.sum(k) - 1) > 1e-3:
                raise ValueError("Kernels must sum to 1.")

        self.projection_shape = self.output_shape[:-1]
        self.N_A = self.output_shape[-1]

        # Buffer bins on each side: half the widest of the first (resp. last)
        # two kernels, presumably because those are the kernels carrying
        # weight nearest the low (resp. high) end of the axis.
        kernel_sizes = [k.size for k in self.kernels]
        N_buffer_lo = int((max(kernel_sizes[:2]) - 1) / 2)
        N_buffer_hi = int((max(kernel_sizes[-2:]) - 1) / 2)
        self.N_F = N_buffer_lo + self.N_A + N_buffer_hi

        # Creating operators
        self.input_shape = self.projection_shape + (self.N_F,)
        self.W = self._get_weights(self.kernels, self.N_F)

        # H = sum_k W[k] * (kernel_k convolved with x), one convolution per kernel.
        self.Hks = []
        for k, kernel in enumerate(self.kernels):
            Hk = self._get_Hk(kernel, self.input_shape)
            self.Hks.append(Hk)
            HWk = LinearOperator(
                input_shape=self.input_shape,
                # Default arguments bind the current Hk / W[k] (no late binding).
                eval_fn=lambda x, Hk=Hk, Wk=self.W[k]: Hk(x) * Wk,
            )
            self.H = HWk if k == 0 else self.H + HWk

        # Crop the buffer bins from the last axis (N_F -> N_A for H's output).
        self.G = lambda x: x[..., N_buffer_lo : x.shape[-1] - N_buffer_hi]
        self.R = lambda x: self.G(self.H(x))

        if self.projection_shape == (1,):
            # Already a single-projection operator: reuse it. This also stops
            # the recursive construction below.
            self.R_single = self
        else:
            single_output_shape = (1, self.N_A)
            self.R_single = self.__class__(single_output_shape, None, kernels=self.kernels)

        if self.t_A is not None:
            self.t_F = self.compute_t_F(self.t_A)

    def __call__(self, x):
        """Apply the resolution operator to ``x`` of shape ``input_shape``."""
        return self.R(x)

    def single(self, x):
        r"""Call Resolution operator on an array of size `(1, N_F)`."""
        return self.R_single(x)

    def call_on_any_array(self, array):
        """Call ResolutionOperator on an array of size `(..., N_F)`.

        Note:
            A new ResolutionOperator object is created every time this
            function is called. Thus, it is not recommended to use this
            function when fast performance is required.

        Raises:
            ValueError: If ``array.shape[-1]`` differs from ``N_F``.
        """
        if array.shape[-1] != self.input_shape[-1]:
            raise ValueError(
                f"array shape not compatible. array.shape[-1] ({array.shape[-1]}) != input_shape[-1] ({self.input_shape[-1]})"
            )
        output_shape = array.shape[:-1] + (self.output_shape[-1],)
        R_ = self.__class__(output_shape, None, kernels=self.kernels)
        return R_(array)

    def compute_t_F(self, t_A):
        r"""Finds time-of-flight array so that :math:`t_F^\top R \approx t_A^\top`.

        The operator is applied to the index ramp ``0, ..., N_F - 1``. The
        response is summarized by its first value (offset) and mean slope,
        and that affine map is inverted and rescaled onto the spacing and
        origin of ``t_A``.

        Args:
            t_A (array): Time-of-arrival equi-spaced increasing array
                (needs at least two entries).

        Returns:
            t_F (array): Time-of-flight equi-spaced increasing array of
            length ``N_F``.
        """
        x = np.arange(self.N_F)
        u = np.array(self.call_on_any_array(x))
        slope = (u[-1] - u[0]) / (u.size - 1)
        offset = u[0]
        Δt = t_A[1] - t_A[0]  # desired slope
        t0 = t_A[0]  # desired offset
        return ((x - offset) / slope) * Δt + t0

    def _get_weights(self, kernels, N_F):
        """Return triangular blending weights of shape ``(len(kernels), N_F)``.

        With ``K > 1`` kernels, kernel ``i`` gets a triangle of half-width
        ``(N_F - 1) / (K - 1)`` centred at ``i`` times that spacing, so
        neighbouring triangles overlap and sum to 1 in each bin. With a
        single kernel the weight is 1 everywhere.
        """
        K = len(kernels)
        if K <= 1:
            return np.ones([K, N_F])
        spacing = (N_F - 1) / (K - 1)
        positions = np.arange(N_F)
        centers = spacing * np.arange(K)[:, None]
        return np.maximum(1 - np.abs(positions - centers) / spacing, 0)

    def _get_Hk(self, kernel, input_shape):
        """Build a jitted ``same``-mode convolution of ``kernel`` along the last axis."""
        h = np.array(kernel, dtype=np.float32)  # float32 copy of the kernel
        # Singleton axes for every projection axis so only the last axis is convolved.
        h = h.reshape((1,) * len(self.projection_shape) + h.shape)
        h = device_put(h)
        return Convolve(h, input_shape=input_shape, mode="same", jit=True)

    def plot_kernel_weights(self, ax):
        """Plot kernel weights as a function of t_F.

        Raises:
            ValueError: If the operator was constructed with ``t_A=None``.
        """
        if self.t_A is None:
            raise ValueError("Can only plot weights if t_A is not None at construction.")
        for i, w in enumerate(self.W):
            ax.plot(
                self.t_F[w > 0],
                w[w > 0],
                linestyle="--",
                label=f"w{i}",
                alpha=0.6,
                linewidth=1.3,
            )
        ax.legend(prop={"size": 8})
        ax.set_title("Kernel Weights")
from scipy.special import gamma
def lansce_fp5_kernel(t_A, Δt, flight_path_length):
    r"""Resolution function kernel based on :cite:`lynn2002neutron`.

    The kernel is a normalized two-component mixture (weights ``0.65`` and
    ``0.35``) of shifted gamma-shaped pulses. Their delays and scales are
    proportional to :math:`1/\sqrt{E}` and expressed in units of time bins.
    The kernel is sampled on integer bin offsets measured from the mode of
    the first component.

    Args:
        t_A (scalar): time-of-arrival of the neutron in :math:`\mathrm{μs}`.
        Δt (scalar): time sampling bin width in :math:`\mathrm{μs}`.
        flight_path_length (scalar): flight path length in :math:`\mathrm{m}`.

    Returns:
        Kernel array (normalized to sum to 1).
    """
    root_energy = np.sqrt(time2energy(t_A, flight_path_length))

    # (shape v, delay t, scale T) for each component, in bins.
    v1, t1, T1 = 6, 0.49 / root_energy / Δt, 0.74 / root_energy / Δt
    v2, t2, T2 = 4.3, 2.2 / root_energy / Δt, 5.1 / root_energy / Δt
    w1 = 0.65

    # Mode of the first component, t1 + (v1/2 - 1) * T1. Per the original
    # author, this choice and the grid half-width below are only valid
    # for these particular values of v, T, t.
    peak = t1 + T1 * v1 / 2 - T1
    half_width = np.ceil(t2 + v2 * T2)
    offsets = np.arange(-half_width, half_width)

    def pulse(shape, delay, scale):
        """Shifted gamma density evaluated at ``offsets + peak``; zero where ``<= delay``."""
        values = np.zeros_like(offsets)
        support = offsets + peak > delay
        s = offsets[support] + peak
        values[support] = (
            (s - delay) ** (shape / 2 - 1) / (gamma(shape / 2) * scale ** (shape / 2))
        ) * np.exp(-(s - delay) / scale)
        return values

    mixture = pulse(v1, t1, T1) * w1 + pulse(v2, t2, T2) * (1 - w1)
    return mixture / np.sum(mixture)
def equispaced_kernels(t_A, num_kernels, kernel_generator):
    r"""Generate a list of resolution function kernels.

    Generate a list of resolution function kernels that correspond to
    equispaced times-of-arrival spanning ``t_A[0]`` to ``t_A[-1]``, using
    the given ``kernel_generator``.

    Args:
        t_A (array): time-of-arrival array of the neutrons in :math:`\mathrm{μs}`.
        num_kernels (int): Number of kernels to be generated (at least 1).
            If `1`, a single kernel at the mean time-of-arrival is generated.
        kernel_generator (function): Function with single argument that
            generates a kernel array based on the time-of-arrival.

    Returns:
        tuple: ``(kernels, t_As)``. ``kernels`` is the list of kernel arrays.
        ``t_As`` holds the time-of-arrival used for each kernel: an ndarray
        when ``num_kernels > 1``, otherwise a one-element list.

    Raises:
        ValueError: If ``num_kernels < 1``.
    """
    # Mirror ResolutionOperator's requirement instead of silently producing one kernel.
    if num_kernels < 1:
        raise ValueError("Number of kernels must be at least 1.")
    if num_kernels > 1:
        t_As = np.linspace(t_A[0], t_A[-1], num=num_kernels)
    else:
        t_As = [np.mean(t_A)]
    return [kernel_generator(t) for t in t_As], t_As