import matplotlib.pyplot as plt
import numpy as np
from pyfeyn2.render.render import Render
def dotted(p1, p2, points=200):
    """Plot a black dotted straight segment from ``p1`` to ``p2``.

    Args:
        p1: Start point; indexed as ``p1[0]`` (x) and ``p1[1]`` (y).
        p2: End point; indexed as ``p2[0]`` (x) and ``p2[1]`` (y).
        points: Number of samples taken along the segment.
    """
    # Sample positions expressed as fractions of the way from p1 to p2.
    frac = np.linspace(0, points, points) / points
    xs = p1[0] + (p2[0] - p1[0]) * frac
    ys = p1[1] + (p2[1] - p1[1]) * frac
    plt.plot(xs, ys, "k:")
def dashed(p1, p2, points=200):
    """Plot a black dashed straight segment from ``p1`` to ``p2``.

    Args:
        p1: Start point; indexed as ``p1[0]`` (x) and ``p1[1]`` (y).
        p2: End point; indexed as ``p2[0]`` (x) and ``p2[1]`` (y).
        points: Number of samples taken along the segment.
    """
    # Interpolation parameter running from 0 at p1 to 1 at p2.
    t = np.linspace(0, points, points) / points
    start_x, start_y = p1[0], p1[1]
    span_x = p2[0] - start_x
    span_y = p2[1] - start_y
    plt.plot(start_x + span_x * t, start_y + span_y * t, "k--")
def line(p1, p2, points=200):
    """Plot a black solid straight segment from ``p1`` to ``p2``.

    Args:
        p1: Start point; indexed as ``p1[0]`` (x) and ``p1[1]`` (y).
        p2: End point; indexed as ``p2[0]`` (x) and ``p2[1]`` (y).
        points: Number of samples taken along the segment.
    """
    steps = np.linspace(0, points, points)
    # One coordinate array per axis, each interpolated linearly from p1 to p2.
    coords = [p1[axis] + (p2[axis] - p1[axis]) * (steps / points) for axis in (0, 1)]
    plt.plot(coords[0], coords[1], "k-")
def spring(xp1, xp2, points=200, rot=3, amp=0.15, line_frac=0.2):
    """Plot a black coiled line from ``xp1`` to ``xp2``.

    The first and last ``line_frac`` of the connection are drawn straight;
    the coil is traced between the two resulting inner points.

    Args:
        xp1: Start point; indexed as ``xp1[0]`` (x) and ``xp1[1]`` (y).
        xp2: End point; indexed as ``xp2[0]`` (x) and ``xp2[1]`` (y).
        points: Number of samples along the coiled part.
        rot: Number of full turns; an extra half turn is always added.
        amp: Radius of the circular motion superimposed on the straight path.
        line_frac: Fraction of the connection left straight at each end.
    """
    # Inner endpoints between which the coil itself is drawn.
    p1 = [
        xp1[0] + (xp2[0] - xp1[0]) * line_frac,
        xp1[1] + (xp2[1] - xp1[1]) * line_frac,
    ]
    p2 = [
        xp2[0] - (xp2[0] - xp1[0]) * line_frac,
        xp2[1] - (xp2[1] - xp1[1]) * line_frac,
    ]
    n = np.linspace(0, points, points)
    # arctan2 gives the direction angle in every quadrant without dividing
    # by the x-extent, so vertical connections no longer trigger a
    # divide-by-zero warning and coincident points no longer produce NaNs.
    alpha = np.arctan2(p2[1] - p1[1], p2[0] - p1[0])
    # Angular step per sample: ``rot`` full turns plus one half turn in total.
    w = rot / points * (2 * np.pi) + np.pi / points
    phase = w * n - alpha
    coil_x = (
        p1[0]
        + (p2[0] - p1[0]) * (n / points)
        + amp * (-np.cos(phase) + np.cos(-alpha))
    )
    coil_y = (
        p1[1]
        + (p2[1] - p1[1]) * (n / points)
        + amp * (np.sin(phase) - np.sin(-alpha))
    )
    # Attach the true endpoints so the straight lead-in/lead-out is drawn.
    x = np.concatenate(([xp1[0]], coil_x, [xp2[0]]))
    y = np.concatenate(([xp1[1]], coil_y, [xp2[1]]))
    plt.plot(x, y, "k-")
def wave(xp1, xp2, points=200, rot=3, amp=0.15, line_frac=0.2):
    """Plot a black sinusoidal line from ``xp1`` to ``xp2``.

    The first and last ``line_frac`` of the connection are drawn straight;
    the sine wave is traced between the two resulting inner points and
    oscillates perpendicular to the connection.

    Args:
        xp1: Start point; indexed as ``xp1[0]`` (x) and ``xp1[1]`` (y).
        xp2: End point; indexed as ``xp2[0]`` (x) and ``xp2[1]`` (y).
        points: Number of samples along the wavy part.
        rot: Number of full sine periods.
        amp: Amplitude of the oscillation.
        line_frac: Fraction of the connection left straight at each end.
    """
    # Inner endpoints between which the wave itself is drawn.
    p1 = [
        xp1[0] + (xp2[0] - xp1[0]) * line_frac,
        xp1[1] + (xp2[1] - xp1[1]) * line_frac,
    ]
    p2 = [
        xp2[0] - (xp2[0] - xp1[0]) * line_frac,
        xp2[1] - (xp2[1] - xp1[1]) * line_frac,
    ]
    n = np.linspace(0, points, points)
    dx = p2[0] - p1[0]
    dy = p2[1] - p1[1]
    # Slope angle in [-pi/2, pi/2], i.e. arctan(dy / dx), computed without the
    # division: mirroring the direction into the right half-plane keeps the
    # slope unchanged, avoids the divide-by-zero warning for vertical
    # connections and yields 0 instead of NaN for coincident points.
    sign = -1 if dx < 0 else 1
    alpha = np.arctan2(sign * dy, sign * dx)
    w = rot / points * (2 * np.pi)
    # Unit vector perpendicular to the connection.
    normal_x = -np.sin(alpha)
    normal_y = np.cos(alpha)
    wave_x = p1[0] + dx * (n / points) + amp * np.sin(w * n) * normal_x
    wave_y = p1[1] + dy * (n / points) + amp * np.sin(w * n) * normal_y
    # Attach the true endpoints so the straight lead-in/lead-out is drawn.
    x = np.concatenate(([xp1[0]], wave_x, [xp2[0]]))
    y = np.concatenate(([xp1[1]], wave_y, [xp2[1]]))
    plt.plot(x, y, "k-")
def combine_lines(lines):
    """Build a drawer that forwards its arguments to every drawer in ``lines``.

    Args:
        lines: Callables to invoke, in order, with identical arguments.

    Returns:
        A callable taking arbitrary positional and keyword arguments; it calls
        each entry of ``lines`` in order and returns the list of their results.
    """

    def draw_all(*args, **kwargs):
        results = []
        for draw in lines:
            results.append(draw(*args, **kwargs))
        return results

    return draw_all
# Lookup table from a line type name to the function that draws it.
# Every drawer is called as ``drawer(start_point, end_point)``.
namedlines = {
    # Single-stroke styles.
    "straight": line,
    "gluon": spring,
    "photon": wave,
    "boson": wave,
    "vector": wave,
    "ghost": dotted,
    "fermion": line,
    "higgs": dashed,
    # Composite styles: a decorated stroke with a straight line on top.
    "gluino": combine_lines([spring, line]),
    "gaugino": combine_lines([wave, line]),
    # Draws nothing at all.
    "phantom": lambda *args, **kwargs: None,
}
class MPLRender(Render):
    """Render a Feynman diagram through matplotlib's ``pyplot`` interface.

    Vertices and legs of ``self.fd`` supply the coordinates (their ``x`` and
    ``y`` attributes); each propagator and leg is drawn with the function
    registered for its type in ``namedlines``.
    """

    def __init__(self, fd=None, *args, **kwargs):
        super().__init__(fd, *args, **kwargs)

    def render(
        self,
        file=None,
        show=True,
        width=None,
        height=None,
        resolution=100,
        clean_up=False,
    ):
        """Draw the diagram on the current pyplot figure.

        Args:
            file: If not ``None``, passed to ``plt.savefig`` to store the figure.
            show: If true, ``plt.show()`` is called.
            width: Accepted but not used by this method.
            height: Accepted but not used by this method.
            resolution: Accepted but not used by this method.
            clean_up: If true, ``plt.close()`` is called at the end.

        Raises:
            ValueError: If a leg's ``sense`` starts with neither ``"in"`` nor
                ``"out"``.
            KeyError: If a propagator or leg type has no entry in
                ``namedlines``.
        """
        # Map every vertex and leg id to its (x, y) position.
        idtopos = {}
        for v in self.fd.vertices:
            idtopos[v.id] = (v.x, v.y)
        for l in self.fd.legs:
            idtopos[l.id] = (l.x, l.y)

        # Type names are lower-cased before the lookup so that every name
        # accepted by ``valid_type`` (which is case-insensitive) can be drawn.
        for p in self.fd.propagators:
            namedlines[p.type.lower()](idtopos[p.source], idtopos[p.target])
        for l in self.fd.legs:
            draw = namedlines[l.type.lower()]
            if l.sense.startswith("in"):
                # Incoming legs run from the leg position to its target.
                draw(idtopos[l.id], idtopos[l.target])
            elif l.sense.startswith("out"):
                # Outgoing legs run from the target to the leg position.
                draw(idtopos[l.target], idtopos[l.id])
            else:
                raise ValueError(f"Unknown sense: {l.sense!r}")

        plt.axis("off")
        # Save before showing: with a blocking backend the figure may already
        # be closed once ``plt.show()`` returns, which would leave the saved
        # file blank.
        if file is not None:
            plt.savefig(file)
        if show:
            plt.show()
        if clean_up:
            plt.close()

    @staticmethod
    def valid_attribute(attr: str) -> bool:
        """Return whether ``attr`` is accepted by the base renderer or is ``"x"``/``"y"``."""
        return super(MPLRender, MPLRender).valid_attribute(attr) or attr in ["x", "y"]

    @staticmethod
    def valid_type(typ: str) -> bool:
        """Return whether ``typ`` (compared case-insensitively) is a key of ``namedlines``."""
        return typ.lower() in namedlines