"""Real spherical harmonics in Cartesian form for PyTorch. This is an autogenerated file. See https://github.com/cheind/torch-spherical-harmonics for more information. """ import torch def rsh_cart_0(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 0. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,1) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), ], -1, ) def rsh_cart_1(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 1. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,4) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, ], -1, ) def rsh_cart_2(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 2. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,9) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] x2 = x**2 y2 = y**2 z2 = z**2 xy = x * y xz = x * z yz = y * z return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, 1.09254843059208 * xy, -1.09254843059208 * yz, 0.94617469575756 * z2 - 0.31539156525252, -1.09254843059208 * xz, 0.54627421529604 * x2 - 0.54627421529604 * y2, ], -1, ) def rsh_cart_3(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 3. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,16) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] x2 = x**2 y2 = y**2 z2 = z**2 xy = x * y xz = x * z yz = y * z return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, 1.09254843059208 * xy, -1.09254843059208 * yz, 0.94617469575756 * z2 - 0.31539156525252, -1.09254843059208 * xz, 0.54627421529604 * x2 - 0.54627421529604 * y2, -0.590043589926644 * y * (3.0 * x2 - y2), 2.89061144264055 * xy * z, 0.304697199642977 * y * (1.5 - 7.5 * z2), 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, 0.304697199642977 * x * (1.5 - 7.5 * z2), 1.44530572132028 * z * (x2 - y2), -0.590043589926644 * x * (x2 - 3.0 * y2), ], -1, ) def rsh_cart_4(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 4. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,25) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] x2 = x**2 y2 = y**2 z2 = z**2 xy = x * y xz = x * z yz = y * z x4 = x2**2 y4 = y2**2 z4 = z2**2 return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, 1.09254843059208 * xy, -1.09254843059208 * yz, 0.94617469575756 * z2 - 0.31539156525252, -1.09254843059208 * xz, 0.54627421529604 * x2 - 0.54627421529604 * y2, -0.590043589926644 * y * (3.0 * x2 - y2), 2.89061144264055 * xy * z, 0.304697199642977 * y * (1.5 - 7.5 * z2), 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, 0.304697199642977 * x * (1.5 - 7.5 * z2), 1.44530572132028 * z * (x2 - y2), -0.590043589926644 * x * (x2 - 3.0 * y2), 2.5033429417967 * xy * (x2 - y2), -1.77013076977993 * yz * (3.0 * x2 - y2), 0.126156626101008 * xy * (52.5 * z2 - 7.5), 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 1.48099765681286 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 0.952069922236839 * z2 + 0.317356640745613, 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), -1.77013076977993 * xz * (x2 - 3.0 * y2), -3.75501441269506 * x2 * y2 + 0.625835735449176 * x4 + 0.625835735449176 * y4, ], -1, ) def rsh_cart_5(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 5. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,36) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] x2 = x**2 y2 = y**2 z2 = z**2 xy = x * y xz = x * z yz = y * z x4 = x2**2 y4 = y2**2 z4 = z2**2 return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, 1.09254843059208 * xy, -1.09254843059208 * yz, 0.94617469575756 * z2 - 0.31539156525252, -1.09254843059208 * xz, 0.54627421529604 * x2 - 0.54627421529604 * y2, -0.590043589926644 * y * (3.0 * x2 - y2), 2.89061144264055 * xy * z, 0.304697199642977 * y * (1.5 - 7.5 * z2), 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, 0.304697199642977 * x * (1.5 - 7.5 * z2), 1.44530572132028 * z * (x2 - y2), -0.590043589926644 * x * (x2 - 3.0 * y2), 2.5033429417967 * xy * (x2 - y2), -1.77013076977993 * yz * (3.0 * x2 - y2), 0.126156626101008 * xy * (52.5 * z2 - 7.5), 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 1.48099765681286 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 0.952069922236839 * z2 + 0.317356640745613, 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), -1.77013076977993 * xz * (x2 - 3.0 * y2), -3.75501441269506 * x2 * y2 + 0.625835735449176 * x4 + 0.625835735449176 * y4, -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 8.30264925952416 * xy * z * (x2 - y2), 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.241571547304372 * y * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), -1.24747010616985 * z * (1.5 * z2 - 0.5) + 1.6840846433293 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.498988042467941 * z, 0.241571547304372 * x * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), ], -1, ) def rsh_cart_6(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 6. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,49) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] x2 = x**2 y2 = y**2 z2 = z**2 xy = x * y xz = x * z yz = y * z x4 = x2**2 y4 = y2**2 z4 = z2**2 return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, 1.09254843059208 * xy, -1.09254843059208 * yz, 0.94617469575756 * z2 - 0.31539156525252, -1.09254843059208 * xz, 0.54627421529604 * x2 - 0.54627421529604 * y2, -0.590043589926644 * y * (3.0 * x2 - y2), 2.89061144264055 * xy * z, 0.304697199642977 * y * (1.5 - 7.5 * z2), 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, 0.304697199642977 * x * (1.5 - 7.5 * z2), 1.44530572132028 * z * (x2 - y2), -0.590043589926644 * x * (x2 - 3.0 * y2), 2.5033429417967 * xy * (x2 - y2), -1.77013076977993 * yz * (3.0 * x2 - y2), 0.126156626101008 * xy * (52.5 * z2 - 7.5), 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 1.48099765681286 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 0.952069922236839 * z2 + 0.317356640745613, 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), -1.77013076977993 * xz * (x2 - 3.0 * y2), -3.75501441269506 * x2 * y2 + 0.625835735449176 * x4 + 0.625835735449176 * y4, -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 8.30264925952416 * xy * z * (x2 - y2), 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.241571547304372 * y * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), -1.24747010616985 * z * (1.5 * z2 - 0.5) + 1.6840846433293 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.498988042467941 * z, 0.241571547304372 * x * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 4.09910463115149 * x**4 * xy - 13.6636821038383 * xy**3 + 4.09910463115149 * xy * y**4, -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), 0.00584892228263444 * y * (3.0 * x2 - y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), 0.0701870673916132 * xy * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ), 0.221950995245231 * y * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ), -1.48328138624466 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + 1.86469659985043 * z * ( -1.33333333333333 * z * (1.5 * z2 - 0.5) + 1.8 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.533333333333333 * z ) + 0.953538034014426 * z2 - 0.317846011338142, 0.221950995245231 * x * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ), 0.0350935336958066 * (x2 - y2) * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ), 0.00584892228263444 * x * (x2 - 3.0 * y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 0.683184105191914 * x2**3 + 10.2477615778787 * x2 * y4 - 10.2477615778787 * x4 * y2 - 0.683184105191914 * y2**3, ], -1, ) def rsh_cart_7(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 7. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,64) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] x2 = x**2 y2 = y**2 z2 = z**2 xy = x * y xz = x * z yz = y * z x4 = x2**2 y4 = y2**2 z4 = z2**2 return torch.stack( [ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, 1.09254843059208 * xy, -1.09254843059208 * yz, 0.94617469575756 * z2 - 0.31539156525252, -1.09254843059208 * xz, 0.54627421529604 * x2 - 0.54627421529604 * y2, -0.590043589926644 * y * (3.0 * x2 - y2), 2.89061144264055 * xy * z, 0.304697199642977 * y * (1.5 - 7.5 * z2), 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, 0.304697199642977 * x * (1.5 - 7.5 * z2), 1.44530572132028 * z * (x2 - y2), -0.590043589926644 * x * (x2 - 3.0 * y2), 2.5033429417967 * xy * (x2 - y2), -1.77013076977993 * yz * (3.0 * x2 - y2), 0.126156626101008 * xy * (52.5 * z2 - 7.5), 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 1.48099765681286 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 0.952069922236839 * z2 + 0.317356640745613, 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), -1.77013076977993 * xz * (x2 - 3.0 * y2), -3.75501441269506 * x2 * y2 + 0.625835735449176 * x4 + 0.625835735449176 * y4, -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 8.30264925952416 * xy * z * (x2 - y2), 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.241571547304372 * y * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), -1.24747010616985 * z * (1.5 * z2 - 0.5) + 1.6840846433293 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.498988042467941 * z, 0.241571547304372 * x * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 4.09910463115149 * x**4 * xy - 13.6636821038383 * xy**3 + 4.09910463115149 * xy * y**4, -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), 0.00584892228263444 * y * (3.0 * x2 - y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), 0.0701870673916132 * xy * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ), 0.221950995245231 * y * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ), -1.48328138624466 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + 1.86469659985043 * z * ( -1.33333333333333 * z * (1.5 * z2 - 0.5) + 1.8 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.533333333333333 * z ) + 0.953538034014426 * z2 - 0.317846011338142, 0.221950995245231 * x * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ), 0.0350935336958066 * (x2 - y2) * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ), 0.00584892228263444 * x * (x2 - 3.0 * y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 0.683184105191914 * x2**3 + 10.2477615778787 * x2 * y4 - 10.2477615778787 * x4 * y2 - 0.683184105191914 * y2**3, -0.707162732524596 * y * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), 9.98394571852353e-5 * y * (5197.5 - 67567.5 * z2) * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 0.00239614697244565 * xy * (x2 - y2) * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z), 0.00397356022507413 * y * (3.0 * x2 - y2) * ( 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125 ), 0.0561946276120613 * xy * ( -4.8 * z * (52.5 * z2 - 7.5) + 2.6 * z * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ) + 48.0 * z ), 0.206472245902897 * y * ( -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 2.16666666666667 * z * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ) - 10.9375 * z2 + 2.1875 ), 1.24862677781952 * z * (1.5 * z2 - 0.5) - 1.68564615005635 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 2.02901851395672 * z * ( -1.45833333333333 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + 1.83333333333333 * z * ( -1.33333333333333 * z * (1.5 * z2 - 0.5) + 1.8 * z * ( 1.75 * z * ( 1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z ) - 1.125 * z2 + 0.375 ) + 0.533333333333333 * z ) + 0.9375 * z2 - 0.3125 ) - 0.499450711127808 * z, 0.206472245902897 * x * ( -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 2.16666666666667 * z * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ) - 10.9375 * z2 + 2.1875 ), 0.0280973138060306 * (x2 - y2) * ( -4.8 * z * (52.5 * z2 - 7.5) + 2.6 * z * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ) + 48.0 * z ), 0.00397356022507413 * x * (x2 - 3.0 * y2) * ( 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125 ), 0.000599036743111412 * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) * (-6.0 * x2 * y2 + x4 + y4), 9.98394571852353e-5 * x * (5197.5 - 67567.5 * z2) * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), -0.707162732524596 * x * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), ], -1, ) # @torch.jit.script def rsh_cart_8(xyz: torch.Tensor): """Computes all real spherical harmonics up to degree 8. This is an autogenerated method. See https://github.com/cheind/torch-spherical-harmonics for more information. Params: xyz: (N,...,3) tensor of points on the unit sphere Returns: rsh: (N,...,81) real spherical harmonics projections of input. Ynm is found at index `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`. """ x = xyz[..., 0] y = xyz[..., 1] z = xyz[..., 2] x2 = x**2 y2 = y**2 z2 = z**2 xy = x * y xz = x * z yz = y * z x4 = x2**2 y4 = y2**2 # z4 = z2**2 return torch.stack( [ 0.282094791773878 * torch.ones(1, device=xyz.device).expand(xyz.shape[:-1]), -0.48860251190292 * y, 0.48860251190292 * z, -0.48860251190292 * x, 1.09254843059208 * xy, -1.09254843059208 * yz, 0.94617469575756 * z2 - 0.31539156525252, -1.09254843059208 * xz, 0.54627421529604 * x2 - 0.54627421529604 * y2, -0.590043589926644 * y * (3.0 * x2 - y2), 2.89061144264055 * xy * z, 0.304697199642977 * y * (1.5 - 7.5 * z2), 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, 0.304697199642977 * x * (1.5 - 7.5 * z2), 1.44530572132028 * z * (x2 - y2), -0.590043589926644 * x * (x2 - 3.0 * y2), 2.5033429417967 * xy * (x2 - y2), -1.77013076977993 * yz * (3.0 * x2 - y2), 0.126156626101008 * xy * (52.5 * z2 - 7.5), 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 1.48099765681286 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 0.952069922236839 * z2 + 0.317356640745613, 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), -1.77013076977993 * xz * (x2 - 3.0 * y2), -3.75501441269506 * x2 * y2 + 0.625835735449176 * x4 + 0.625835735449176 * y4, -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 8.30264925952416 * xy * z * (x2 - y2), 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.241571547304372 * y * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), -1.24747010616985 * z * (1.5 * z2 - 0.5) + 1.6840846433293 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.498988042467941 * z, 0.241571547304372 * x * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ), 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 4.09910463115149 * x**4 * xy - 13.6636821038383 * xy**3 + 4.09910463115149 * xy * y**4, -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), 0.00584892228263444 * y * (3.0 * x2 - y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), 0.0701870673916132 * xy * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ), 0.221950995245231 * y * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ), -1.48328138624466 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + 1.86469659985043 * z * ( -1.33333333333333 * z * (1.5 * z2 - 0.5) + 1.8 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.533333333333333 * z ) + 0.953538034014426 * z2 - 0.317846011338142, 0.221950995245231 * x * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ), 0.0350935336958066 * (x2 - y2) * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ), 0.00584892228263444 * x * (x2 - 3.0 * y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 0.683184105191914 * x2**3 + 10.2477615778787 * x2 * y4 - 10.2477615778787 * x4 * y2 - 0.683184105191914 * y2**3, -0.707162732524596 * y * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), 9.98394571852353e-5 * y * (5197.5 - 67567.5 * z2) * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 0.00239614697244565 * xy * (x2 - y2) * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z), 0.00397356022507413 * y * (3.0 * x2 - y2) * ( 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125 ), 0.0561946276120613 * xy * ( -4.8 * z * (52.5 * z2 - 7.5) + 2.6 * z * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ) + 48.0 * z ), 0.206472245902897 * y * ( -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 2.16666666666667 * z * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ) - 10.9375 * z2 + 2.1875 ), 1.24862677781952 * z * (1.5 * z2 - 0.5) - 1.68564615005635 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 2.02901851395672 * z * ( -1.45833333333333 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + 1.83333333333333 * z * ( -1.33333333333333 * z * (1.5 * z2 - 0.5) + 1.8 * z * ( 1.75 * z * ( 1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z ) - 1.125 * z2 + 0.375 ) + 0.533333333333333 * z ) + 0.9375 * z2 - 0.3125 ) - 0.499450711127808 * z, 0.206472245902897 * x * ( -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 2.16666666666667 * z * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ) - 10.9375 * z2 + 2.1875 ), 0.0280973138060306 * (x2 - y2) * ( -4.8 * z * (52.5 * z2 - 7.5) + 2.6 * z * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ) + 48.0 * z ), 0.00397356022507413 * x * (x2 - 3.0 * y2) * ( 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125 ), 0.000599036743111412 * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) * (-6.0 * x2 * y2 + x4 + y4), 9.98394571852353e-5 * x * (5197.5 - 67567.5 * z2) * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), -0.707162732524596 * x * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), 5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3), -2.91570664069932 * yz * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), 7.87853281621404e-6 * (1013512.5 * z2 - 67567.5) * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), 5.10587282657803e-5 * y * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z) * (-10.0 * x2 * y2 + 5.0 * x4 + y4), 0.00147275890257803 * xy * (x2 - y2) * ( 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) - 14293.125 * z2 + 1299.375 ), 0.0028519853513317 * y * (3.0 * x2 - y2) * ( -7.33333333333333 * z * (52.5 - 472.5 * z2) + 3.0 * z * ( 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125 ) - 560.0 * z ), 0.0463392770473559 * xy * ( -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + 2.5 * z * ( -4.8 * z * (52.5 * z2 - 7.5) + 2.6 * z * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ) + 48.0 * z ) + 137.8125 * z2 - 19.6875 ), 0.193851103820053 * y * ( 3.2 * z * (1.5 - 7.5 * z2) - 2.51428571428571 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) + 2.14285714285714 * z * ( -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 2.16666666666667 * z * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ) - 10.9375 * z2 + 2.1875 ) + 5.48571428571429 * z ), 1.48417251362228 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.86581687426801 * z * ( -1.33333333333333 * z * (1.5 * z2 - 0.5) + 1.8 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 0.533333333333333 * z ) + 2.1808249179756 * z * ( 1.14285714285714 * z * (1.5 * z2 - 0.5) - 1.54285714285714 * z * ( 1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375 ) + 1.85714285714286 * z * ( -1.45833333333333 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + 1.83333333333333 * z * ( -1.33333333333333 * z * (1.5 * z2 - 0.5) + 1.8 * z * ( 1.75 * z * ( 1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z ) - 1.125 * z2 + 0.375 ) + 0.533333333333333 * z ) + 0.9375 * z2 - 0.3125 ) - 0.457142857142857 * z ) - 0.954110901614325 * z2 + 0.318036967204775, 0.193851103820053 * x * ( 3.2 * z * (1.5 - 7.5 * z2) - 2.51428571428571 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) + 2.14285714285714 * z * ( -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 2.16666666666667 * z * ( -2.8 * z * (1.5 - 7.5 * z2) + 2.2 * z * ( 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875 ) - 4.8 * z ) - 10.9375 * z2 + 2.1875 ) + 5.48571428571429 * z ), 0.0231696385236779 * (x2 - y2) * ( -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + 2.5 * z * ( -4.8 * z * (52.5 * z2 - 7.5) + 2.6 * z * ( 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125 ) + 48.0 * z ) + 137.8125 * z2 - 19.6875 ), 0.0028519853513317 * x * (x2 - 3.0 * y2) * ( -7.33333333333333 * z * (52.5 - 472.5 * z2) + 3.0 * z * ( 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125 ) - 560.0 * z ), 0.000368189725644507 * (-6.0 * x2 * y2 + x4 + y4) * ( 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) - 14293.125 * z2 + 1299.375 ), 5.10587282657803e-5 * x * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z) * (-10.0 * x2 * y2 + x4 + 5.0 * y4), 7.87853281621404e-6 * (1013512.5 * z2 - 67567.5) * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), -2.91570664069932 * xz * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), -20.4099464848952 * x2**3 * y2 - 20.4099464848952 * x2 * y2**3 + 0.72892666017483 * x4**2 + 51.0248662122381 * x4 * y4 + 0.72892666017483 * y4**2, ], -1, ) __all__ = [ "rsh_cart_0", "rsh_cart_1", "rsh_cart_2", "rsh_cart_3", "rsh_cart_4", "rsh_cart_5", "rsh_cart_6", "rsh_cart_7", "rsh_cart_8", ] from typing import Optional import torch class SphHarm(torch.nn.Module): def __init__(self, m, n, dtype=torch.float32) -> None: super().__init__() self.dtype = dtype m = torch.tensor(list(range(-m + 1, m))) n = torch.tensor(list(range(n))) self.is_normalized = False vals = torch.cartesian_prod(m, n).T vals = vals[:, vals[0] <= vals[1]] m, n = vals.unbind(0) self.register_buffer("m", tensor=m) self.register_buffer("n", tensor=n) self.register_buffer("l_max", tensor=torch.max(self.n)) f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre() self.register_buffer("f_a", tensor=f_a) self.register_buffer("f_b", tensor=f_b) self.register_buffer("d0_mask_3d", tensor=d0_mask_3d) self.register_buffer("d1_mask_3d", tensor=d1_mask_3d) self.register_buffer("initial_value", tensor=initial_value) @property def device(self): return next(self.buffers()).device def forward(self, points: torch.Tensor) -> torch.Tensor: """Computes the spherical harmonics.""" # Y_l^m = (-1) ^ m c_l^m P_l^m(cos(theta)) exp(i m phi) B, N, D = points.shape dtype = points.dtype theta, phi = points.view(-1, D).to(self.dtype).unbind(-1) cos_colatitude = torch.cos(phi) legendre = self._gen_associated_legendre(cos_colatitude) vals = torch.stack([self.m.abs(), self.n], dim=0) vals = torch.cat( [ vals.repeat(1, theta.shape[0]), torch.arange(theta.shape[0], device=theta.device) .unsqueeze(0) .repeat_interleave(vals.shape[1], dim=1), ], dim=0, ) legendre_vals = legendre[vals[0], vals[1], vals[2]] legendre_vals = legendre_vals.reshape(-1, theta.shape[0]) angle = torch.outer(self.m.abs(), theta) vandermonde = torch.complex(torch.cos(angle), torch.sin(angle)) harmonics = torch.complex( legendre_vals * torch.real(vandermonde), legendre_vals * torch.imag(vandermonde), ) # Negative order. m = self.m.unsqueeze(-1) harmonics = torch.where( m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics ) harmonics = harmonics.permute(1, 0).reshape(B, N, -1).to(dtype) return harmonics def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]: """Generates mask for recurrence relation on the remaining entries. The remaining entries are with respect to the diagonal and offdiagonal entries. Args: l_max: see `gen_normalized_legendre`. Returns: torch.Tensors representing the mask used by the recurrence relations. """ # Computes all coefficients. m_mat, l_mat = torch.meshgrid( torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype), torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype), indexing="ij", ) if self.is_normalized: c0 = l_mat * l_mat c1 = m_mat * m_mat c2 = 2.0 * l_mat c3 = (l_mat - 1.0) * (l_mat - 1.0) d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1)) d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1))) else: d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat) d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat) d0_mask_indices = torch.triu_indices(self.l_max + 1, 1) d1_mask_indices = torch.triu_indices(self.l_max + 1, 2) d_zeros = torch.zeros( (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device ) d_zeros[d0_mask_indices] = d0[d0_mask_indices] d0_mask = d_zeros d_zeros = torch.zeros( (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device ) d_zeros[d1_mask_indices] = d1[d1_mask_indices] d1_mask = d_zeros # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere. i = torch.arange(self.l_max + 1, device=self.device)[:, None, None] j = torch.arange(self.l_max + 1, device=self.device)[None, :, None] k = torch.arange(self.l_max + 1, device=self.device)[None, None, :] mask = (i + j - k == 0).to(self.dtype) d0_mask_3d = torch.einsum("jk,ijk->ijk", d0_mask, mask) d1_mask_3d = torch.einsum("jk,ijk->ijk", d1_mask, mask) return (d0_mask_3d, d1_mask_3d) def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor: coeff_0 = self.d0_mask_3d[i] coeff_1 = self.d1_mask_3d[i] h = torch.einsum( "ij,ijk->ijk", coeff_0, torch.einsum("ijk,k->ijk", torch.roll(p_val, shifts=1, dims=1), x), ) - torch.einsum("ij,ijk->ijk", coeff_1, torch.roll(p_val, shifts=2, dims=1)) p_val = p_val + h return p_val def _init_legendre(self): a_idx = torch.arange(1, self.l_max + 1, dtype=self.dtype, device=self.device) b_idx = torch.arange(self.l_max, dtype=self.dtype, device=self.device) if self.is_normalized: # The initial value p(0,0). initial_value: torch.Tensor = torch.tensor( 0.5 / (torch.pi**0.5), device=self.device ) f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0) f_b = torch.sqrt(2.0 * b_idx + 3.0) else: # The initial value p(0,0). initial_value = torch.tensor(1.0, device=self.device) f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0) f_b = 2.0 * b_idx + 1.0 d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask() return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor: r"""Computes associated Legendre functions (ALFs) of the first kind. The ALFs of the first kind are used in spherical harmonics. The spherical harmonic of degree `l` and order `m` can be written as `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the normalization factor and θ and φ are the colatitude and longitude, repectively. `N_l^m` is chosen in the way that the spherical harmonics form a set of orthonormal basis function of L^2(S^2). For the computational efficiency of spherical harmonics transform, the normalization factor is used in the computation of the ALFs. In addition, normalizing `P_l^m` avoids overflow/underflow and achieves better numerical stability. Three recurrence relations are used in the computation. Args: l_max: The maximum degree of the associated Legendre function. Both the degrees and orders are `[0, 1, 2, ..., l_max]`. x: A vector of type `float32`, `float64` containing the sampled points in spherical coordinates, at which the ALFs are computed; `x` is essentially `cos(θ)`. For the numerical integration used by the spherical harmonics transforms, `x` contains the quadrature points in the interval of `[-1, 1]`. There are several approaches to provide the quadrature points: Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev method (`scipy.special.roots_chebyu`), and Driscoll & Healy method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier transforms and convolutions on the 2-sphere." Advances in applied mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature points are nearly equal-spaced along θ and provide exact discrete orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose operation, `W` is a diagonal matrix containing the quadrature weights, and `I` is the identity matrix. The Gauss-Chebyshev points are equally spaced, which only provide approximate discrete orthogonality. The Driscoll & Healy qudarture points are equally spaced and provide the exact discrete orthogonality. The number of sampling points is required to be twice as the number of frequency points (modes) in the Driscoll & Healy approach, which enables FFT and achieves a fast spherical harmonics transform. is_normalized: True if the associated Legendre functions are normalized. With normalization, `N_l^m` is applied such that the spherical harmonics form a set of orthonormal basis functions of L^2(S^2). Returns: The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values of the ALFs at `x`; the dimensions in the sequence of order, degree, and evalution points. """ p = torch.zeros( (self.l_max + 1, self.l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device ) p[0, 0] = self.initial_value # Compute the diagonal entries p(l,l) with recurrence. y = torch.cumprod( torch.broadcast_to(torch.sqrt(1.0 - x * x), (self.l_max, x.shape[0])), dim=0 ) p_diag = self.initial_value * torch.einsum("i,ij->ij", self.f_a, y) # torch.diag_indices(l_max + 1) diag_indices = torch.stack( [torch.arange(0, self.l_max + 1, device=x.device)] * 2, dim=0 ) p[(diag_indices[0][1:], diag_indices[1][1:])] = p_diag diag_indices = torch.stack( [torch.arange(0, self.l_max, device=x.device)] * 2, dim=0 ) # Compute the off-diagonal entries with recurrence. p_offdiag = torch.einsum( "ij,ij->ij", torch.einsum("i,j->ij", self.f_b, x), p[(diag_indices[0], diag_indices[1])], ) # p[torch.diag_indices(l_max)]) p[(diag_indices[0][: self.l_max], diag_indices[1][: self.l_max] + 1)] = ( p_offdiag ) # Compute the remaining entries with recurrence. if self.l_max > 1: for i in range(2, self.l_max + 1): p = self._recursive(i, p, x) return p