smhh24's picture
Upload 90 files
560b597 verified
"""Real spherical harmonics in Cartesian form for PyTorch.
This is an autogenerated file. See
https://github.com/cheind/torch-spherical-harmonics
for more information.
"""
import torch
def rsh_cart_0(xyz: torch.Tensor):
    """Evaluates the real spherical harmonics of degree 0 in Cartesian form.

    Adapted from the autogenerated code of
    https://github.com/cheind/torch-spherical-harmonics

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere
    Returns:
        rsh: (N,...,1) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    batch_shape = xyz.shape[:-1]
    # Degree 0 is a constant; only the batch shape, dtype and device of
    # `xyz` influence the result.
    y00 = xyz.new_tensor(0.282094791773878).expand(batch_shape)
    return torch.stack([y00], -1)
def rsh_cart_1(xyz: torch.Tensor):
    """Evaluates the real spherical harmonics of degrees 0..1 in Cartesian form.

    Adapted from the autogenerated code of
    https://github.com/cheind/torch-spherical-harmonics

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere
    Returns:
        rsh: (N,...,4) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]

    terms = [
        # degree 0
        xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
        # degree 1, m = -1, 0, 1
        -0.48860251190292 * y,
        0.48860251190292 * z,
        -0.48860251190292 * x,
    ]
    return torch.stack(terms, -1)
def rsh_cart_2(xyz: torch.Tensor):
    """Evaluates the real spherical harmonics of degrees 0..2 in Cartesian form.

    Adapted from the autogenerated code of
    https://github.com/cheind/torch-spherical-harmonics

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere
    Returns:
        rsh: (N,...,9) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    x2, y2, z2 = x**2, y**2, z**2
    xy, xz, yz = x * y, x * z, y * z

    terms = [
        # degree 0
        xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
        # degree 1, m = -1..1
        -0.48860251190292 * y,
        0.48860251190292 * z,
        -0.48860251190292 * x,
        # degree 2, m = -2..2
        1.09254843059208 * xy,
        -1.09254843059208 * yz,
        0.94617469575756 * z2 - 0.31539156525252,
        -1.09254843059208 * xz,
        0.54627421529604 * x2 - 0.54627421529604 * y2,
    ]
    return torch.stack(terms, -1)
def rsh_cart_3(xyz: torch.Tensor):
    """Evaluates the real spherical harmonics of degrees 0..3 in Cartesian form.

    Adapted from the autogenerated code of
    https://github.com/cheind/torch-spherical-harmonics

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere
    Returns:
        rsh: (N,...,16) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    x2, y2, z2 = x**2, y**2, z**2
    xy, xz, yz = x * y, x * z, y * z

    # Named factors. They are inserted into the terms below at exactly the
    # position of the parenthesised expression they stand for, so the order
    # of floating-point operations is unchanged.
    x2_m_y2 = x2 - y2
    y3f = 3.0 * x2 - y2
    x3f = x2 - 3.0 * y2
    p2 = 1.5 * z2 - 0.5
    q13 = 1.5 - 7.5 * z2

    terms = [
        # degree 0
        xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
        # degree 1, m = -1..1
        -0.48860251190292 * y,
        0.48860251190292 * z,
        -0.48860251190292 * x,
        # degree 2, m = -2..2
        1.09254843059208 * xy,
        -1.09254843059208 * yz,
        0.94617469575756 * z2 - 0.31539156525252,
        -1.09254843059208 * xz,
        0.54627421529604 * x2 - 0.54627421529604 * y2,
        # degree 3, m = -3..3
        -0.590043589926644 * y * y3f,
        2.89061144264055 * xy * z,
        0.304697199642977 * y * q13,
        1.24392110863372 * z * p2 - 0.497568443453487 * z,
        0.304697199642977 * x * q13,
        1.44530572132028 * z * x2_m_y2,
        -0.590043589926644 * x * x3f,
    ]
    return torch.stack(terms, -1)
def rsh_cart_4(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 4.

    Adapted from the autogenerated code of
    https://github.com/cheind/torch-spherical-harmonics
    Repeated sub-expressions are evaluated once and the unused `z4`
    intermediate of the generated code is no longer computed.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere
    Returns:
        rsh: (N,...,25) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    x2, y2, z2 = x**2, y**2, z**2
    xy, xz, yz = x * y, x * z, y * z
    x4, y4 = x2**2, y2**2

    # Polynomials in x and y that are shared between terms.
    x2_m_y2 = x2 - y2
    y3f = 3.0 * x2 - y2
    x3f = x2 - 3.0 * y2

    # Polynomials in z only. Each factor is inserted below at exactly the
    # position of the parenthesised expression it replaces, so the order of
    # floating-point operations matches the generated code.
    p2 = 1.5 * z2 - 0.5
    p3 = 1.66666666666667 * z * p2 - 0.666666666666667 * z
    q13 = 1.5 - 7.5 * z2
    q14 = 2.33333333333333 * z * q13 + 4.0 * z
    q24 = 52.5 * z2 - 7.5

    terms = [
        # degree 0
        xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
        # degree 1, m = -1..1
        -0.48860251190292 * y,
        0.48860251190292 * z,
        -0.48860251190292 * x,
        # degree 2, m = -2..2
        1.09254843059208 * xy,
        -1.09254843059208 * yz,
        0.94617469575756 * z2 - 0.31539156525252,
        -1.09254843059208 * xz,
        0.54627421529604 * x2 - 0.54627421529604 * y2,
        # degree 3, m = -3..3
        -0.590043589926644 * y * y3f,
        2.89061144264055 * xy * z,
        0.304697199642977 * y * q13,
        1.24392110863372 * z * p2 - 0.497568443453487 * z,
        0.304697199642977 * x * q13,
        1.44530572132028 * z * x2_m_y2,
        -0.590043589926644 * x * x3f,
        # degree 4, m = -4..4
        2.5033429417967 * xy * x2_m_y2,
        -1.77013076977993 * yz * y3f,
        0.126156626101008 * xy * q24,
        0.267618617422916 * y * q14,
        1.48099765681286 * z * p3 - 0.952069922236839 * z2 + 0.317356640745613,
        0.267618617422916 * x * q14,
        0.063078313050504 * x2_m_y2 * q24,
        -1.77013076977993 * xz * x3f,
        (
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4
        ),
    ]
    return torch.stack(terms, -1)
def rsh_cart_5(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 5.

    Adapted from the autogenerated code of
    https://github.com/cheind/torch-spherical-harmonics
    Repeated sub-expressions are evaluated once and the unused `z4`
    intermediate of the generated code is no longer computed.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere
    Returns:
        rsh: (N,...,36) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    x2, y2, z2 = x**2, y**2, z**2
    xy, xz, yz = x * y, x * z, y * z
    x4, y4 = x2**2, y2**2

    # Polynomials in x and y that are shared between terms.
    x2_m_y2 = x2 - y2
    y3f = 3.0 * x2 - y2
    x3f = x2 - 3.0 * y2
    c4 = -6.0 * x2 * y2 + x4 + y4
    y5f = -10.0 * x2 * y2 + 5.0 * x4 + y4
    x5f = -10.0 * x2 * y2 + x4 + 5.0 * y4

    # Polynomials in z only. Each factor is inserted below at exactly the
    # position of the parenthesised expression it replaces, so the order of
    # floating-point operations matches the generated code.
    p2 = 1.5 * z2 - 0.5
    p3 = 1.66666666666667 * z * p2 - 0.666666666666667 * z
    p4 = 1.75 * z * p3 - 1.125 * z2 + 0.375
    q13 = 1.5 - 7.5 * z2
    q14 = 2.33333333333333 * z * q13 + 4.0 * z
    q15 = 2.25 * z * q14 + 9.375 * z2 - 1.875
    q24 = 52.5 * z2 - 7.5
    q25 = 3.0 * z * q24 - 30.0 * z
    q35 = 52.5 - 472.5 * z2

    terms = [
        # degree 0
        xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
        # degree 1, m = -1..1
        -0.48860251190292 * y,
        0.48860251190292 * z,
        -0.48860251190292 * x,
        # degree 2, m = -2..2
        1.09254843059208 * xy,
        -1.09254843059208 * yz,
        0.94617469575756 * z2 - 0.31539156525252,
        -1.09254843059208 * xz,
        0.54627421529604 * x2 - 0.54627421529604 * y2,
        # degree 3, m = -3..3
        -0.590043589926644 * y * y3f,
        2.89061144264055 * xy * z,
        0.304697199642977 * y * q13,
        1.24392110863372 * z * p2 - 0.497568443453487 * z,
        0.304697199642977 * x * q13,
        1.44530572132028 * z * x2_m_y2,
        -0.590043589926644 * x * x3f,
        # degree 4, m = -4..4
        2.5033429417967 * xy * x2_m_y2,
        -1.77013076977993 * yz * y3f,
        0.126156626101008 * xy * q24,
        0.267618617422916 * y * q14,
        1.48099765681286 * z * p3 - 0.952069922236839 * z2 + 0.317356640745613,
        0.267618617422916 * x * q14,
        0.063078313050504 * x2_m_y2 * q24,
        -1.77013076977993 * xz * x3f,
        (
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4
        ),
        # degree 5, m = -5..5
        -0.65638205684017 * y * y5f,
        8.30264925952416 * xy * z * x2_m_y2,
        0.00931882475114763 * y * q35 * y3f,
        0.0913054625709205 * xy * q25,
        0.241571547304372 * y * q15,
        (
            -1.24747010616985 * z * p2
            + 1.6840846433293 * z * p4
            + 0.498988042467941 * z
        ),
        0.241571547304372 * x * q15,
        0.0456527312854602 * x2_m_y2 * q25,
        0.00931882475114763 * x * q35 * x3f,
        2.07566231488104 * z * c4,
        -0.65638205684017 * x * x5f,
    ]
    return torch.stack(terms, -1)
def rsh_cart_6(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 6.

    Adapted from the autogenerated code of
    https://github.com/cheind/torch-spherical-harmonics
    Repeated sub-expressions are evaluated once and the unused `z4`
    intermediate of the generated code is no longer computed.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere
    Returns:
        rsh: (N,...,49) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    x2, y2, z2 = x**2, y**2, z**2
    xy, xz, yz = x * y, x * z, y * z
    x4, y4 = x2**2, y2**2

    # Polynomials in x and y that are shared between terms.
    x2_m_y2 = x2 - y2
    y3f = 3.0 * x2 - y2
    x3f = x2 - 3.0 * y2
    c4 = -6.0 * x2 * y2 + x4 + y4
    y5f = -10.0 * x2 * y2 + 5.0 * x4 + y4
    x5f = -10.0 * x2 * y2 + x4 + 5.0 * y4

    # Polynomials in z only. Each factor is inserted below at exactly the
    # position of the parenthesised expression it replaces, so the order of
    # floating-point operations matches the generated code.
    p2 = 1.5 * z2 - 0.5
    p3 = 1.66666666666667 * z * p2 - 0.666666666666667 * z
    p4 = 1.75 * z * p3 - 1.125 * z2 + 0.375
    p5 = -1.33333333333333 * z * p2 + 1.8 * z * p4 + 0.533333333333333 * z
    q13 = 1.5 - 7.5 * z2
    q14 = 2.33333333333333 * z * q13 + 4.0 * z
    q15 = 2.25 * z * q14 + 9.375 * z2 - 1.875
    q16 = -2.8 * z * q13 + 2.2 * z * q15 - 4.8 * z
    q24 = 52.5 * z2 - 7.5
    q25 = 3.0 * z * q24 - 30.0 * z
    q26 = 2.75 * z * q25 - 91.875 * z2 + 13.125
    q35 = 52.5 - 472.5 * z2
    q36 = 3.66666666666667 * z * q35 + 280.0 * z
    q46 = 5197.5 * z2 - 472.5

    terms = [
        # degree 0
        xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
        # degree 1, m = -1..1
        -0.48860251190292 * y,
        0.48860251190292 * z,
        -0.48860251190292 * x,
        # degree 2, m = -2..2
        1.09254843059208 * xy,
        -1.09254843059208 * yz,
        0.94617469575756 * z2 - 0.31539156525252,
        -1.09254843059208 * xz,
        0.54627421529604 * x2 - 0.54627421529604 * y2,
        # degree 3, m = -3..3
        -0.590043589926644 * y * y3f,
        2.89061144264055 * xy * z,
        0.304697199642977 * y * q13,
        1.24392110863372 * z * p2 - 0.497568443453487 * z,
        0.304697199642977 * x * q13,
        1.44530572132028 * z * x2_m_y2,
        -0.590043589926644 * x * x3f,
        # degree 4, m = -4..4
        2.5033429417967 * xy * x2_m_y2,
        -1.77013076977993 * yz * y3f,
        0.126156626101008 * xy * q24,
        0.267618617422916 * y * q14,
        1.48099765681286 * z * p3 - 0.952069922236839 * z2 + 0.317356640745613,
        0.267618617422916 * x * q14,
        0.063078313050504 * x2_m_y2 * q24,
        -1.77013076977993 * xz * x3f,
        (
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4
        ),
        # degree 5, m = -5..5
        -0.65638205684017 * y * y5f,
        8.30264925952416 * xy * z * x2_m_y2,
        0.00931882475114763 * y * q35 * y3f,
        0.0913054625709205 * xy * q25,
        0.241571547304372 * y * q15,
        (
            -1.24747010616985 * z * p2
            + 1.6840846433293 * z * p4
            + 0.498988042467941 * z
        ),
        0.241571547304372 * x * q15,
        0.0456527312854602 * x2_m_y2 * q25,
        0.00931882475114763 * x * q35 * x3f,
        2.07566231488104 * z * c4,
        -0.65638205684017 * x * x5f,
        # degree 6, m = -6..6
        (
            4.09910463115149 * x**4 * xy
            - 13.6636821038383 * xy**3
            + 4.09910463115149 * xy * y**4
        ),
        -2.36661916223175 * yz * y5f,
        0.00427144889505798 * xy * x2_m_y2 * q46,
        0.00584892228263444 * y * y3f * q36,
        0.0701870673916132 * xy * q26,
        0.221950995245231 * y * q16,
        (
            -1.48328138624466 * z * p3
            + 1.86469659985043 * z * p5
            + 0.953538034014426 * z2
            - 0.317846011338142
        ),
        0.221950995245231 * x * q16,
        0.0350935336958066 * x2_m_y2 * q26,
        0.00584892228263444 * x * x3f * q36,
        0.0010678622237645 * q46 * c4,
        -2.36661916223175 * xz * x5f,
        (
            0.683184105191914 * x2**3
            + 10.2477615778787 * x2 * y4
            - 10.2477615778787 * x4 * y2
            - 0.683184105191914 * y2**3
        ),
    ]
    return torch.stack(terms, -1)
def rsh_cart_7(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 7.

    Adapted from the autogenerated code of
    https://github.com/cheind/torch-spherical-harmonics
    Repeated sub-expressions are evaluated once and the unused `z4`
    intermediate of the generated code is no longer computed.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere
    Returns:
        rsh: (N,...,64) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    x2, y2, z2 = x**2, y**2, z**2
    xy, xz, yz = x * y, x * z, y * z
    x4, y4 = x2**2, y2**2

    # Polynomials in x and y that are shared between terms.
    x2_m_y2 = x2 - y2
    y3f = 3.0 * x2 - y2
    x3f = x2 - 3.0 * y2
    c4 = -6.0 * x2 * y2 + x4 + y4
    y5f = -10.0 * x2 * y2 + 5.0 * x4 + y4
    x5f = -10.0 * x2 * y2 + x4 + 5.0 * y4
    s6 = 6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4
    c6 = x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3
    y7f = 7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3
    x7f = x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3

    # Polynomials in z only. Each factor is inserted below at exactly the
    # position of the parenthesised expression it replaces, so the order of
    # floating-point operations matches the generated code.
    p2 = 1.5 * z2 - 0.5
    p3 = 1.66666666666667 * z * p2 - 0.666666666666667 * z
    p4 = 1.75 * z * p3 - 1.125 * z2 + 0.375
    p5 = -1.33333333333333 * z * p2 + 1.8 * z * p4 + 0.533333333333333 * z
    p6 = (
        -1.45833333333333 * z * p3
        + 1.83333333333333 * z * p5
        + 0.9375 * z2
        - 0.3125
    )
    q13 = 1.5 - 7.5 * z2
    q14 = 2.33333333333333 * z * q13 + 4.0 * z
    q15 = 2.25 * z * q14 + 9.375 * z2 - 1.875
    q16 = -2.8 * z * q13 + 2.2 * z * q15 - 4.8 * z
    q17 = (
        -2.625 * z * q14
        + 2.16666666666667 * z * q16
        - 10.9375 * z2
        + 2.1875
    )
    q24 = 52.5 * z2 - 7.5
    q25 = 3.0 * z * q24 - 30.0 * z
    q26 = 2.75 * z * q25 - 91.875 * z2 + 13.125
    q27 = -4.8 * z * q24 + 2.6 * z * q26 + 48.0 * z
    q35 = 52.5 - 472.5 * z2
    q36 = 3.66666666666667 * z * q35 + 280.0 * z
    q37 = 3.25 * z * q36 + 1063.125 * z2 - 118.125
    q46 = 5197.5 * z2 - 472.5
    q47 = 4.33333333333333 * z * q46 - 3150.0 * z
    q57 = 5197.5 - 67567.5 * z2

    terms = [
        # degree 0
        xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
        # degree 1, m = -1..1
        -0.48860251190292 * y,
        0.48860251190292 * z,
        -0.48860251190292 * x,
        # degree 2, m = -2..2
        1.09254843059208 * xy,
        -1.09254843059208 * yz,
        0.94617469575756 * z2 - 0.31539156525252,
        -1.09254843059208 * xz,
        0.54627421529604 * x2 - 0.54627421529604 * y2,
        # degree 3, m = -3..3
        -0.590043589926644 * y * y3f,
        2.89061144264055 * xy * z,
        0.304697199642977 * y * q13,
        1.24392110863372 * z * p2 - 0.497568443453487 * z,
        0.304697199642977 * x * q13,
        1.44530572132028 * z * x2_m_y2,
        -0.590043589926644 * x * x3f,
        # degree 4, m = -4..4
        2.5033429417967 * xy * x2_m_y2,
        -1.77013076977993 * yz * y3f,
        0.126156626101008 * xy * q24,
        0.267618617422916 * y * q14,
        1.48099765681286 * z * p3 - 0.952069922236839 * z2 + 0.317356640745613,
        0.267618617422916 * x * q14,
        0.063078313050504 * x2_m_y2 * q24,
        -1.77013076977993 * xz * x3f,
        (
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4
        ),
        # degree 5, m = -5..5
        -0.65638205684017 * y * y5f,
        8.30264925952416 * xy * z * x2_m_y2,
        0.00931882475114763 * y * q35 * y3f,
        0.0913054625709205 * xy * q25,
        0.241571547304372 * y * q15,
        (
            -1.24747010616985 * z * p2
            + 1.6840846433293 * z * p4
            + 0.498988042467941 * z
        ),
        0.241571547304372 * x * q15,
        0.0456527312854602 * x2_m_y2 * q25,
        0.00931882475114763 * x * q35 * x3f,
        2.07566231488104 * z * c4,
        -0.65638205684017 * x * x5f,
        # degree 6, m = -6..6
        (
            4.09910463115149 * x**4 * xy
            - 13.6636821038383 * xy**3
            + 4.09910463115149 * xy * y**4
        ),
        -2.36661916223175 * yz * y5f,
        0.00427144889505798 * xy * x2_m_y2 * q46,
        0.00584892228263444 * y * y3f * q36,
        0.0701870673916132 * xy * q26,
        0.221950995245231 * y * q16,
        (
            -1.48328138624466 * z * p3
            + 1.86469659985043 * z * p5
            + 0.953538034014426 * z2
            - 0.317846011338142
        ),
        0.221950995245231 * x * q16,
        0.0350935336958066 * x2_m_y2 * q26,
        0.00584892228263444 * x * x3f * q36,
        0.0010678622237645 * q46 * c4,
        -2.36661916223175 * xz * x5f,
        (
            0.683184105191914 * x2**3
            + 10.2477615778787 * x2 * y4
            - 10.2477615778787 * x4 * y2
            - 0.683184105191914 * y2**3
        ),
        # degree 7, m = -7..7
        -0.707162732524596 * y * y7f,
        2.6459606618019 * z * s6,
        9.98394571852353e-5 * y * q57 * y5f,
        0.00239614697244565 * xy * x2_m_y2 * q47,
        0.00397356022507413 * y * y3f * q37,
        0.0561946276120613 * xy * q27,
        0.206472245902897 * y * q17,
        (
            1.24862677781952 * z * p2
            - 1.68564615005635 * z * p4
            + 2.02901851395672 * z * p6
            - 0.499450711127808 * z
        ),
        0.206472245902897 * x * q17,
        0.0280973138060306 * x2_m_y2 * q27,
        0.00397356022507413 * x * x3f * q37,
        0.000599036743111412 * q47 * c4,
        9.98394571852353e-5 * x * q57 * x5f,
        2.6459606618019 * z * c6,
        -0.707162732524596 * x * x7f,
    ]
    return torch.stack(terms, -1)
# @torch.jit.script
def rsh_cart_8(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 8.

    Adapted from the autogenerated code of
    https://github.com/cheind/torch-spherical-harmonics
    Repeated sub-expressions are evaluated once. The constant degree-0
    term is created with the dtype, device and batch shape of the input,
    so the result keeps the input precision and a single unbatched point
    of shape (3,) is accepted, as in `rsh_cart_0` .. `rsh_cart_7`.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere
    Returns:
        rsh: (N,...,81) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    x2, y2, z2 = x**2, y**2, z**2
    xy, xz, yz = x * y, x * z, y * z
    x4, y4 = x2**2, y2**2

    # Polynomials in x and y that are shared between terms.
    x2_m_y2 = x2 - y2
    y3f = 3.0 * x2 - y2
    x3f = x2 - 3.0 * y2
    c4 = -6.0 * x2 * y2 + x4 + y4
    y5f = -10.0 * x2 * y2 + 5.0 * x4 + y4
    x5f = -10.0 * x2 * y2 + x4 + 5.0 * y4
    s6 = 6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4
    c6 = x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3
    y7f = 7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3
    x7f = x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3

    # Polynomials in z only. Each factor is inserted below at exactly the
    # position of the parenthesised expression it replaces, so the order of
    # floating-point operations matches the generated code.
    p2 = 1.5 * z2 - 0.5
    p3 = 1.66666666666667 * z * p2 - 0.666666666666667 * z
    p4 = 1.75 * z * p3 - 1.125 * z2 + 0.375
    p5 = -1.33333333333333 * z * p2 + 1.8 * z * p4 + 0.533333333333333 * z
    p6 = (
        -1.45833333333333 * z * p3
        + 1.83333333333333 * z * p5
        + 0.9375 * z2
        - 0.3125
    )
    p7 = (
        1.14285714285714 * z * p2
        - 1.54285714285714 * z * p4
        + 1.85714285714286 * z * p6
        - 0.457142857142857 * z
    )
    q13 = 1.5 - 7.5 * z2
    q14 = 2.33333333333333 * z * q13 + 4.0 * z
    q15 = 2.25 * z * q14 + 9.375 * z2 - 1.875
    q16 = -2.8 * z * q13 + 2.2 * z * q15 - 4.8 * z
    q17 = (
        -2.625 * z * q14
        + 2.16666666666667 * z * q16
        - 10.9375 * z2
        + 2.1875
    )
    q18 = (
        3.2 * z * q13
        - 2.51428571428571 * z * q15
        + 2.14285714285714 * z * q17
        + 5.48571428571429 * z
    )
    q24 = 52.5 * z2 - 7.5
    q25 = 3.0 * z * q24 - 30.0 * z
    q26 = 2.75 * z * q25 - 91.875 * z2 + 13.125
    q27 = -4.8 * z * q24 + 2.6 * z * q26 + 48.0 * z
    q28 = -4.125 * z * q25 + 2.5 * z * q27 + 137.8125 * z2 - 19.6875
    q35 = 52.5 - 472.5 * z2
    q36 = 3.66666666666667 * z * q35 + 280.0 * z
    q37 = 3.25 * z * q36 + 1063.125 * z2 - 118.125
    q38 = -7.33333333333333 * z * q35 + 3.0 * z * q37 - 560.0 * z
    q46 = 5197.5 * z2 - 472.5
    q47 = 4.33333333333333 * z * q46 - 3150.0 * z
    q48 = 3.75 * z * q47 - 14293.125 * z2 + 1299.375
    q57 = 5197.5 - 67567.5 * z2
    q58 = 5.0 * z * q57 + 41580.0 * z
    q68 = 1013512.5 * z2 - 67567.5

    terms = [
        # degree 0: `full_like` takes dtype, device and batch shape from `x`
        # (the previous float32 `torch.ones(1)` ignored the input dtype and
        # could not be expanded to an empty batch shape).
        torch.full_like(x, 0.282094791773878),
        # degree 1, m = -1..1
        -0.48860251190292 * y,
        0.48860251190292 * z,
        -0.48860251190292 * x,
        # degree 2, m = -2..2
        1.09254843059208 * xy,
        -1.09254843059208 * yz,
        0.94617469575756 * z2 - 0.31539156525252,
        -1.09254843059208 * xz,
        0.54627421529604 * x2 - 0.54627421529604 * y2,
        # degree 3, m = -3..3
        -0.590043589926644 * y * y3f,
        2.89061144264055 * xy * z,
        0.304697199642977 * y * q13,
        1.24392110863372 * z * p2 - 0.497568443453487 * z,
        0.304697199642977 * x * q13,
        1.44530572132028 * z * x2_m_y2,
        -0.590043589926644 * x * x3f,
        # degree 4, m = -4..4
        2.5033429417967 * xy * x2_m_y2,
        -1.77013076977993 * yz * y3f,
        0.126156626101008 * xy * q24,
        0.267618617422916 * y * q14,
        1.48099765681286 * z * p3 - 0.952069922236839 * z2 + 0.317356640745613,
        0.267618617422916 * x * q14,
        0.063078313050504 * x2_m_y2 * q24,
        -1.77013076977993 * xz * x3f,
        (
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4
        ),
        # degree 5, m = -5..5
        -0.65638205684017 * y * y5f,
        8.30264925952416 * xy * z * x2_m_y2,
        0.00931882475114763 * y * q35 * y3f,
        0.0913054625709205 * xy * q25,
        0.241571547304372 * y * q15,
        (
            -1.24747010616985 * z * p2
            + 1.6840846433293 * z * p4
            + 0.498988042467941 * z
        ),
        0.241571547304372 * x * q15,
        0.0456527312854602 * x2_m_y2 * q25,
        0.00931882475114763 * x * q35 * x3f,
        2.07566231488104 * z * c4,
        -0.65638205684017 * x * x5f,
        # degree 6, m = -6..6
        (
            4.09910463115149 * x**4 * xy
            - 13.6636821038383 * xy**3
            + 4.09910463115149 * xy * y**4
        ),
        -2.36661916223175 * yz * y5f,
        0.00427144889505798 * xy * x2_m_y2 * q46,
        0.00584892228263444 * y * y3f * q36,
        0.0701870673916132 * xy * q26,
        0.221950995245231 * y * q16,
        (
            -1.48328138624466 * z * p3
            + 1.86469659985043 * z * p5
            + 0.953538034014426 * z2
            - 0.317846011338142
        ),
        0.221950995245231 * x * q16,
        0.0350935336958066 * x2_m_y2 * q26,
        0.00584892228263444 * x * x3f * q36,
        0.0010678622237645 * q46 * c4,
        -2.36661916223175 * xz * x5f,
        (
            0.683184105191914 * x2**3
            + 10.2477615778787 * x2 * y4
            - 10.2477615778787 * x4 * y2
            - 0.683184105191914 * y2**3
        ),
        # degree 7, m = -7..7
        -0.707162732524596 * y * y7f,
        2.6459606618019 * z * s6,
        9.98394571852353e-5 * y * q57 * y5f,
        0.00239614697244565 * xy * x2_m_y2 * q47,
        0.00397356022507413 * y * y3f * q37,
        0.0561946276120613 * xy * q27,
        0.206472245902897 * y * q17,
        (
            1.24862677781952 * z * p2
            - 1.68564615005635 * z * p4
            + 2.02901851395672 * z * p6
            - 0.499450711127808 * z
        ),
        0.206472245902897 * x * q17,
        0.0280973138060306 * x2_m_y2 * q27,
        0.00397356022507413 * x * x3f * q37,
        0.000599036743111412 * q47 * c4,
        9.98394571852353e-5 * x * q57 * x5f,
        2.6459606618019 * z * c6,
        -0.707162732524596 * x * x7f,
        # degree 8, m = -8..8
        5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3),
        -2.91570664069932 * yz * y7f,
        7.87853281621404e-6 * q68 * s6,
        5.10587282657803e-5 * y * q58 * y5f,
        0.00147275890257803 * xy * x2_m_y2 * q48,
        0.0028519853513317 * y * y3f * q38,
        0.0463392770473559 * xy * q28,
        0.193851103820053 * y * q18,
        (
            1.48417251362228 * z * p3
            - 1.86581687426801 * z * p5
            + 2.1808249179756 * z * p7
            - 0.954110901614325 * z2
            + 0.318036967204775
        ),
        0.193851103820053 * x * q18,
        0.0231696385236779 * x2_m_y2 * q28,
        0.0028519853513317 * x * x3f * q38,
        0.000368189725644507 * c4 * q48,
        5.10587282657803e-5 * x * q58 * x5f,
        7.87853281621404e-6 * q68 * c6,
        -2.91570664069932 * xz * x7f,
        (
            -20.4099464848952 * x2**3 * y2
            - 20.4099464848952 * x2 * y2**3
            + 0.72892666017483 * x4**2
            + 51.0248662122381 * x4 * y4
            + 0.72892666017483 * y4**2
        ),
    ]
    return torch.stack(terms, -1)
# Names exported by `from <module> import *`: one evaluator per maximum
# degree, 0 through 8.
__all__ = [
"rsh_cart_0",
"rsh_cart_1",
"rsh_cart_2",
"rsh_cart_3",
"rsh_cart_4",
"rsh_cart_5",
"rsh_cart_6",
"rsh_cart_7",
"rsh_cart_8",
]
from typing import Optional
import torch
class SphHarm(torch.nn.Module):
def __init__(self, m, n, dtype=torch.float32) -> None:
    """Enumerates the (order, degree) pairs and registers all buffers.

    Args:
        m: orders are drawn from `-m + 1 .. m - 1`.
        n: degrees are drawn from `0 .. n - 1`.
        dtype: floating-point type used for the internal computations.
    """
    super().__init__()
    self.dtype = dtype
    self.is_normalized = False

    order_range = torch.tensor(list(range(-m + 1, m)))
    degree_range = torch.tensor(list(range(n)))
    # All (order, degree) combinations, laid out as a 2 x K tensor.
    pairs = torch.cartesian_prod(order_range, degree_range).T
    # Keep the pairs with order <= degree.
    # NOTE(review): the comparison uses the signed order, so negative orders
    # with |order| > degree are kept as well - confirm this is intended.
    keep = pairs[0] <= pairs[1]
    orders, degrees = pairs[:, keep].unbind(0)

    self.register_buffer("m", orders)
    self.register_buffer("n", degrees)
    self.register_buffer("l_max", torch.max(self.n))

    f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre()
    # Registration order is kept as before (it fixes the buffer ordering).
    legendre_buffers = (
        ("f_a", f_a),
        ("f_b", f_b),
        ("d0_mask_3d", d0_mask_3d),
        ("d1_mask_3d", d1_mask_3d),
        ("initial_value", initial_value),
    )
    for buffer_name, buffer_value in legendre_buffers:
        self.register_buffer(buffer_name, buffer_value)
@property
def device(self):
return next(self.buffers()).device
def forward(self, points: torch.Tensor) -> torch.Tensor:
    """Computes the spherical harmonics.

    Args:
        points: (B, N, 2) tensor of angles. The first component is
            multiplied by the order inside the complex exponential; the
            cosine of the second component is passed to the associated
            Legendre functions.
    Returns:
        (B, N, K) tensor with one entry per registered (m, n) pair, cast
        back to the dtype of `points`.
    """
    # Y_l^m = (-1) ^ m c_l^m P_l^m(cos(theta)) exp(i m phi)
    B, N, D = points.shape
    dtype = points.dtype
    theta, phi = points.view(-1, D).to(self.dtype).unbind(-1)
    num_points = theta.shape[0]
    cos_colatitude = torch.cos(phi)
    legendre = self._gen_associated_legendre(cos_colatitude)

    # Gather legendre[|m_k|, n_k, p] for every (pair k, point p) directly
    # into a (K, P) layout via broadcasting. The previous implementation
    # gathered the values in point-major order and then reshaped them as
    # (K, P), which paired values with the wrong order/point whenever
    # K != P.
    abs_m = self.m.abs()
    point_idx = torch.arange(num_points, device=theta.device)
    legendre_vals = legendre[abs_m[:, None], self.n[:, None], point_idx[None, :]]

    angle = torch.outer(abs_m, theta)
    vandermonde = torch.complex(torch.cos(angle), torch.sin(angle))
    harmonics = torch.complex(
        legendre_vals * torch.real(vandermonde),
        legendre_vals * torch.imag(vandermonde),
    )
    # Negative order.
    m = self.m.unsqueeze(-1)
    harmonics = torch.where(
        m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics
    )
    # NOTE(review): `harmonics` is complex here; casting to a real `dtype`
    # keeps only the real part - confirm this is intended.
    harmonics = harmonics.permute(1, 0).reshape(B, N, -1).to(dtype)
    return harmonics
def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Generates mask for recurrence relation on the remaining entries.
The remaining entries are with respect to the diagonal and offdiagonal
entries.
Args:
l_max: see `gen_normalized_legendre`.
Returns:
torch.Tensors representing the mask used by the recurrence relations.
"""
# Computes all coefficients.
m_mat, l_mat = torch.meshgrid(
torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
indexing="ij",
)
if self.is_normalized:
c0 = l_mat * l_mat
c1 = m_mat * m_mat
c2 = 2.0 * l_mat
c3 = (l_mat - 1.0) * (l_mat - 1.0)
d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
else:
d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)
d0_mask_indices = torch.triu_indices(self.l_max + 1, 1)
d1_mask_indices = torch.triu_indices(self.l_max + 1, 2)
d_zeros = torch.zeros(
(self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
)
d_zeros[d0_mask_indices] = d0[d0_mask_indices]
d0_mask = d_zeros
d_zeros = torch.zeros(
(self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
)
d_zeros[d1_mask_indices] = d1[d1_mask_indices]
d1_mask = d_zeros
# Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
i = torch.arange(self.l_max + 1, device=self.device)[:, None, None]
j = torch.arange(self.l_max + 1, device=self.device)[None, :, None]
k = torch.arange(self.l_max + 1, device=self.device)[None, None, :]
mask = (i + j - k == 0).to(self.dtype)
d0_mask_3d = torch.einsum("jk,ijk->ijk", d0_mask, mask)
d1_mask_3d = torch.einsum("jk,ijk->ijk", d1_mask, mask)
return (d0_mask_3d, d1_mask_3d)
def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
coeff_0 = self.d0_mask_3d[i]
coeff_1 = self.d1_mask_3d[i]
h = torch.einsum(
"ij,ijk->ijk",
coeff_0,
torch.einsum("ijk,k->ijk", torch.roll(p_val, shifts=1, dims=1), x),
) - torch.einsum("ij,ijk->ijk", coeff_1, torch.roll(p_val, shifts=2, dims=1))
p_val = p_val + h
return p_val
def _init_legendre(self):
a_idx = torch.arange(1, self.l_max + 1, dtype=self.dtype, device=self.device)
b_idx = torch.arange(self.l_max, dtype=self.dtype, device=self.device)
if self.is_normalized:
# The initial value p(0,0).
initial_value: torch.Tensor = torch.tensor(
0.5 / (torch.pi**0.5), device=self.device
)
f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0)
f_b = torch.sqrt(2.0 * b_idx + 3.0)
else:
# The initial value p(0,0).
initial_value = torch.tensor(1.0, device=self.device)
f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0)
f_b = 2.0 * b_idx + 1.0
d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask()
return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d
def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor:
r"""Computes associated Legendre functions (ALFs) of the first kind.
The ALFs of the first kind are used in spherical harmonics. The spherical
harmonic of degree `l` and order `m` can be written as
`Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
normalization factor and θ and φ are the colatitude and longitude,
repectively. `N_l^m` is chosen in the way that the spherical harmonics form
a set of orthonormal basis function of L^2(S^2). For the computational
efficiency of spherical harmonics transform, the normalization factor is
used in the computation of the ALFs. In addition, normalizing `P_l^m`
avoids overflow/underflow and achieves better numerical stability. Three
recurrence relations are used in the computation.
Args:
l_max: The maximum degree of the associated Legendre function. Both the
degrees and orders are `[0, 1, 2, ..., l_max]`.
x: A vector of type `float32`, `float64` containing the sampled points in
spherical coordinates, at which the ALFs are computed; `x` is essentially
`cos(θ)`. For the numerical integration used by the spherical harmonics
transforms, `x` contains the quadrature points in the interval of
`[-1, 1]`. There are several approaches to provide the quadrature points:
Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
method (`scipy.special.roots_chebyu`), and Driscoll & Healy
method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
transforms and convolutions on the 2-sphere." Advances in applied
mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
points are nearly equal-spaced along θ and provide exact discrete
orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose
operation, `W` is a diagonal matrix containing the quadrature weights,
and `I` is the identity matrix. The Gauss-Chebyshev points are equally
spaced, which only provide approximate discrete orthogonality. The
Driscoll & Healy qudarture points are equally spaced and provide the
exact discrete orthogonality. The number of sampling points is required to
be twice as the number of frequency points (modes) in the Driscoll & Healy
approach, which enables FFT and achieves a fast spherical harmonics
transform.
is_normalized: True if the associated Legendre functions are normalized.
With normalization, `N_l^m` is applied such that the spherical harmonics
form a set of orthonormal basis functions of L^2(S^2).
Returns:
The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
of the ALFs at `x`; the dimensions in the sequence of order, degree, and
evalution points.
"""
p = torch.zeros(
(self.l_max + 1, self.l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device
)
p[0, 0] = self.initial_value
# Compute the diagonal entries p(l,l) with recurrence.
y = torch.cumprod(
torch.broadcast_to(torch.sqrt(1.0 - x * x), (self.l_max, x.shape[0])), dim=0
)
p_diag = self.initial_value * torch.einsum("i,ij->ij", self.f_a, y)
# torch.diag_indices(l_max + 1)
diag_indices = torch.stack(
[torch.arange(0, self.l_max + 1, device=x.device)] * 2, dim=0
)
p[(diag_indices[0][1:], diag_indices[1][1:])] = p_diag
diag_indices = torch.stack(
[torch.arange(0, self.l_max, device=x.device)] * 2, dim=0
)
# Compute the off-diagonal entries with recurrence.
p_offdiag = torch.einsum(
"ij,ij->ij",
torch.einsum("i,j->ij", self.f_b, x),
p[(diag_indices[0], diag_indices[1])],
) # p[torch.diag_indices(l_max)])
p[(diag_indices[0][: self.l_max], diag_indices[1][: self.l_max] + 1)] = (
p_offdiag
)
# Compute the remaining entries with recurrence.
if self.l_max > 1:
for i in range(2, self.l_max + 1):
p = self._recursive(i, p, x)
return p