Source code for linear_operator.operators.psd_sum_linear_operator

#!/usr/bin/env python3
from jaxtyping import Float
from torch import Tensor

from linear_operator.operators._linear_operator import LinearOperator
from linear_operator.operators.sum_linear_operator import SumLinearOperator


class PsdSumLinearOperator(SumLinearOperator):
    """
    A :class:`SumLinearOperator` in which each summand is positive semi-definite.
    """

    def zero_mean_mvn_samples(
        self: Float[LinearOperator, "*batch N N"], num_samples: int
    ) -> Float[Tensor, "num_samples *batch N"]:
        """
        Draw zero-mean multivariate-normal samples whose covariance is this sum.

        Each component operator in ``self.linear_ops`` is asked for
        ``num_samples`` zero-mean samples of its own. The per-component draws
        are then added together. If every component returns independent
        draws from N(0, K_i), the result is a draw from N(0, sum_i K_i).

        :param num_samples: Number of samples requested from each component.
        :return: The element-wise sum of all components' samples.
        """
        # Same semantics as the builtin ``sum``: begin at the integer 0 and
        # add each component's samples in iteration order.
        total = 0
        for component in self.linear_ops:
            total = total + component.zero_mean_mvn_samples(num_samples)
        return total