#!/usr/bin/env python3

# Creating the figure requires Python, TeX Live, and the following steps:
#
#   python3 -m venv venv
#   source venv/bin/activate
#   python3 -m pip install storylines==0.12 numpy==1.25.1
#   python3 figure.py
#   deactivate
#
# Written by Jan Berges on 17 August 2023.
# Figure copyright (C) 2023 American Physical Society.

import numpy as np
import storylines

target = 'Physics'

# Output settings per publication target, as (aspect, width, pixels):
#   aspect -- width-to-height ratio of the visible plot window,
#   width  -- horizontal extent of that window in plot coordinates,
#   pixels -- width of the saved PNG in pixels.
_targets = {
    'Physics': (2 / 1, 1.2, 800),
    'PRX': (1 / 1, 1.0, 800),
    'U Bremen': (4 / 3, 1.0, 800),
    'MAPEX': (15.4 / 11.1, 1.1, 3080),
}

if target not in _targets:
    raise ValueError('Invalid target')

aspect, width, pixels = _targets[target]

# Drawing and geometry parameters (lengths presumably in angstrom, like the
# Materials Project structure below).
thick = 0.03       # line width of the white diagram
thin = 0.025       # line width of the interatomic bonds
cutoff = 10.0      # keep atoms closer than this to the origin
max_Si_O = 1.7     # longest Si-O distance counted as a bond
max_O_O = 3.0      # longest O-O distance counted as a bond
radius = 0.3       # radius of the atom balls

# Structure of alpha quartz from the Materials Project:
# https://legacy.materialsproject.org/materials/mp-6930
# https://doi.org/10.17188/1272701

# Lattice vectors as rows, in Cartesian coordinates.
a = np.array([
    [ 5.027782,  0.000000,  0.000000],
    [-2.513891,  4.354187,  0.000000],
    [ 0.000000,  0.000000,  5.518918]])

# Basis positions, given in fractional coordinates and converted straight
# to Cartesian coordinates.
Si = np.dot(np.array([
    [ 0.522710,  0.522710,  0.000000],
    [ 0.477290,  0.000000,  0.666667],
    [ 0.000000,  0.477290,  0.333333]]), a)

O2 = np.dot(np.array([
    [ 0.584963,  0.839260,  0.870667],
    [ 0.160740,  0.745703,  0.537334],
    [ 0.254297,  0.415037,  0.204000],
    [ 0.839260,  0.584963,  0.129333],
    [ 0.745703,  0.160740,  0.462666],
    [ 0.415037,  0.254297,  0.796000]]), a)

# Lattice translations i*a1 + j*a2 + k*a3 with -2 <= i, j, k <= 2, in
# i-major order. The resulting atom order matters, because individual O
# atoms are later picked by index.
_shifts = range(-2, 3)
R = [i * a[0] + j * a[1] + k * a[2]
    for i in _shifts for j in _shifts for k in _shifts]

# Replicate the basis: every basis atom shifted by every translation
# (translation-major order).
Si = [shift + site for shift in R for site in Si]
O2 = [shift + site for shift in R for site in O2]


def _within_cutoff(sites):
    """Return the sites closer than `cutoff` to the origin, in input order."""
    return [site for site in sites if np.linalg.norm(site) < cutoff]


Si = _within_cutoff(Si)
O2 = _within_cutoff(O2)

# Keep only Si atoms with exactly four O atoms within max_Si_O. This
# presumably drops Si near the sphere surface whose tetrahedron was cut off.
# Then keep only O atoms that are still within max_Si_O of some Si.
_O2_positions = np.array(O2)
Si = [site for site in Si
    if np.count_nonzero(
        np.linalg.norm(site - _O2_positions, axis=1) < max_Si_O) == 4]

_Si_positions = np.array(Si)
O2 = [site for site in O2
    if np.any(np.linalg.norm(site - _Si_positions, axis=1) < max_Si_O)]

def draw_arc(r, r0, phi1, phi2, nphi=500,
        wiggly=False, wiggles=7, amplitude=0.07):
    """Return sample points of a circular arc obtained by rotating r about r0.

    The vector r - r0 is rotated (right-hand rule) about the axis through r0
    that is perpendicular both to r - r0 and to the z axis. Rotation angle 0
    corresponds to r itself; the arc runs from angle phi1 to angle phi2.

    Parameters:
        r, r0: 3D points, the start point and the centre of the arc.
        phi1, phi2: first and last rotation angle in radians.
        nphi: number of sample points, taken at np.linspace(phi1, phi2, nphi).
        wiggly: if true, the distance from r0 becomes
            |r - r0| - amplitude * sin(w * phi), which gives a wavy arc.
            phi2 must differ from phi1 in this case.
        wiggles: oscillations per unit arc length. The frequency w is snapped
            so that [phi1, phi2] spans a half-integer number of periods.
        amplitude: radial amplitude of the oscillation.

    Returns:
        Array of shape (nphi, 3) with the points of the arc.
    """
    rel = r - r0

    x, y, z = rel
    azimuth = np.arctan2(y, x)
    polar = np.arctan2(np.sqrt(x ** 2 + y ** 2), z)

    sin_t, cos_t = np.sin(polar), np.cos(polar)
    sin_p, cos_p = np.sin(azimuth), np.cos(azimuth)

    # Orthonormal frame whose columns are the radial, polar and azimuthal
    # unit vectors of rel. The third column is the rotation axis.
    frame = np.array([
        [sin_t * cos_p, cos_t * cos_p, -sin_p],
        [sin_t * sin_p, cos_t * sin_p,  cos_p],
        [        cos_t,        -sin_t,      0]])

    # rel expressed in that frame (it lies in the local xy plane).
    local = frame.T.dot(rel)

    if wiggly:
        dist = np.linalg.norm(local)
        # Oscillations per full turn for the requested density ...
        freq = wiggles * (2 * np.pi * dist)
        # ... rounded to a half-integer number of periods over the arc.
        scale = (2 * np.pi) / (phi2 - phi1)
        freq = (round(freq / scale - 0.5) + 0.5) * scale

    def about_axis(angle):
        """Rotation matrix about the local z axis (the rotation axis)."""
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        return np.array([
            [cos_a, -sin_a, 0],
            [sin_a,  cos_a, 0],
            [    0,      0, 1]])

    samples = []

    for angle in np.linspace(phi1, phi2, nphi):
        point = about_axis(angle).dot(local)

        if wiggly:
            point -= amplitude * np.sin(freq * angle) * point / dist

        samples.append(point)

    # Back to global coordinates and shift to the centre.
    return r0 + np.einsum('xy,ny->nx', frame, samples)

objects = []

# Common style of every atom: a shaded ball of radius `radius`, marks only.
atom = dict(mark='ball', mark_size=radius, only_marks=True)

objects.extend(([site], dict(ball_color='Si', **atom)) for site in Si)

# O atom 14 is tagged 'above' so that the reordering step after projection
# can put it behind the object tagged 'below' in the drawing list. (To look
# up atom indices, temporarily render the O atoms as text marks showing
# their index.)
for index, site in enumerate(O2):
    style = dict(ball_color='O', **atom)
    if index == 14:
        style['label'] = 'above'
    objects.append(([site], style))

# Si-O and O-O bonds, drawn via storylines.spring as thin gray lines tinted
# with 10% of the respective atom colour.
objects.extend(
    (storylines.spring(*np.array(bond)),
        dict(color='Si!10!gray', line_width=thin))
    for bond in storylines.bonds(Si, O2, d1=radius, dmax=max_Si_O))

objects.extend(
    (storylines.spring(*np.array(bond)),
        dict(color='O!10!gray', line_width=thin))
    for bond in storylines.bonds(O2, d1=radius, dmax=max_O_O))

# Geometry of the white line diagram drawn on top of the crystal. It is built
# on the segment between O atoms 14 and 3 (indices into the filtered O2 list,
# so they depend on the atom order produced above). storylines.bonds is asked
# for the bond of just this pair, with dmax large enough that it is always
# found. A and B are its end points, presumably pulled back by d1 = radius to
# the surface of the balls (TODO confirm against storylines.bonds).
A, B = np.array(storylines.bonds([O2[14], O2[3]], d1=radius, dmax=1e3)[0])
# Points on the segment A-B. Fraction of the way from A to B:
# c: 0.31, C: 0.357, O: 0.5 (midpoint), D: 0.643, d: 0.69.
C = (1.93 * A + 1.07 * B) / 3
D = (1.07 * A + 1.93 * B) / 3
c = (2.07 * A + 0.93 * B) / 3
d = (0.93 * A + 2.07 * B) / 3
O = (A + B) / 2

# c1, c2: c rotated by 90 and 270 degrees about the midpoint of c-C, i.e. two
# points on either side of the segment, each |c - C| / 2 from that midpoint.
# They form the shared base of the triangles c-c1-c2 and C-c1-c2 drawn
# further below. d1 and d2 are the same construction for d and D.
c1, c2 = draw_arc(c, (c + C) / 2, 0.5 * np.pi, 1.5 * np.pi, nphi=2)
d1, d2 = draw_arc(d, (d + D) / 2, 0.5 * np.pi, 1.5 * np.pi, nphi=2)
# Points just farther from O (ca) and just closer to O (cb) than C on the
# segment. They are the base corners of the small triangles placed along the
# circle around O.
ca = (3 * C + c) / 4
cb = (5 * C - c) / 4

# Circle through C (and through D, its mirror image about O), centred at O,
# split into seven arcs. Angles are rotation angles of C, see draw_arc. The
# start points of some arcs are reused as anchors below.
electron0 = draw_arc(C, O, 0.00 * np.pi, 0.50 * np.pi)
electron1 = draw_arc(C, O, 0.50 * np.pi, 0.75 * np.pi)
electron2 = draw_arc(C, O, 0.75 * np.pi, 1.00 * np.pi)
electron3 = draw_arc(C, O, 1.00 * np.pi, 1.25 * np.pi)
electron4 = draw_arc(C, O, 1.25 * np.pi, 1.50 * np.pi)
electron5 = draw_arc(C, O, 1.50 * np.pi, 1.75 * np.pi)
electron6 = draw_arc(C, O, 1.75 * np.pi, 2.00 * np.pi)

# E and F: the points of the circle at rotation angles 1.5 pi and 0.5 pi
# (diametrically opposite each other).
E = electron5[0]
F = electron1[0]
# G lies on the line E-F beyond F, 2/3 |F - E| farther out.
G = (5 * F - 2 * E) / 3
# f and g: F shifted by c1 - c2 and G shifted by c2 - c1. Since
# |c1 - c2| = |c - C|, the F/f and G/g triangles below match the c/C one.
f = F + c1 - c2
g = G - c1 + c2
# o: centre of a small circle through G, a quarter of |F - E| beyond G.
o = G + (F - E) / 4

# Base corners for the triangle pairs at F/f and G/g (same construction as
# c1/c2), and radial base points for the triangle on the small circle around
# o (analogous to ca/cb).
f1, f2 = draw_arc(f, (f + F) / 2, 0.5 * np.pi, 1.5 * np.pi, nphi=2)
g1, g2 = draw_arc(g, (g + G) / 2, 0.5 * np.pi, 1.5 * np.pi, nphi=2)
ga = (3 * G + g) / 4
gb = (5 * G - g) / 4

# Every part of the diagram is a white line of width `thick`. The same dict
# object is shared by many entries; dict(..., **diagram) makes a copy where
# extra keys (fill, label) are needed.
diagram = dict(color='white', line_width=thick)

# The seven arcs that together form the circle around O.
objects.append((electron0, diagram))
objects.append((electron1, diagram))
objects.append((electron2, diagram))
objects.append((electron3, diagram))
objects.append((electron4, diagram))
objects.append((electron5, diagram))
objects.append((electron6, diagram))

# Full small circle through G, centred at o.
objects.append((draw_arc(G, o, 0, 2 * np.pi), diagram))

# Filled triangles along the big circle at phi = n * pi / 8. The base runs
# from ca to cb (just outside and just inside the circle) rotated to
# phi - dphi, and the apex is C rotated to phi + dphi, so each triangle
# points in the direction of increasing phi.
dphi = 0.025 * np.pi

for n in [2, 5, 7, 9, 11, 13, 15]:
    phi = 0.125 * n * np.pi

    t1, = draw_arc(ca, O, phi - dphi, phi - dphi, nphi=1)
    t2, = draw_arc(cb, O, phi - dphi, phi - dphi, nphi=1)
    t0, = draw_arc(C, O, phi + dphi, phi + dphi, nphi=1)

    objects.append(([t0, t1, t2, t0], dict(fill=True, **diagram)))

# The same kind of triangle, with twice the angular half-width, on the small
# circle around o at phi = pi.
dphi = 0.05 * np.pi

for n in [8]:
    phi = 0.125 * n * np.pi

    t1, = draw_arc(ga, o, phi - dphi, phi - dphi, nphi=1)
    t2, = draw_arc(gb, o, phi - dphi, phi - dphi, nphi=1)
    t0, = draw_arc(G, o, phi + dphi, phi + dphi, nphi=1)

    objects.append(([t0, t1, t2, t0], dict(fill=True, **diagram)))

# Lines A-c and d-B on the A-B segment, and f-g, built with
# storylines.spring. A-c is tagged 'below' for the reordering after
# projection.
objects.append((storylines.spring(A, c), dict(label='below', **diagram)))
objects.append((storylines.spring(d, B), diagram))
objects.append((storylines.spring(f, g), diagram))

# Wavy arc starting at E and rotating by 0.875 pi about the point of the big
# circle at rotation angle 1.625 pi.
phi = 1.625 * np.pi

objects.append((draw_arc(E, draw_arc(C, O, phi, phi, nphi=1)[0],
    0, 0.875 * np.pi, wiggly=True), diagram))

# Wavy quarter arc starting at the circle point at 1.25 pi (electron4[0]),
# centred at electron2[0] + electron4[0] - O.
objects.append((draw_arc(electron4[0], electron2[0] + electron4[0] - O,
    0, 0.5 * np.pi, wiggly=True), diagram))

# Triangle pairs sharing a base: an outline triangle with apex c (d, f, g)
# and a filled one with apex C (D, F, G).
objects.append(([c, c1, c2, c], diagram))
objects.append(([C, c1, c2, C], dict(fill=True, **diagram)))

objects.append(([d, d1, d2, d], diagram))
objects.append(([D, d1, d2, D], dict(fill=True, **diagram)))

objects.append(([f, f1, f2, f], diagram))
objects.append(([F, f1, f2, F], dict(fill=True, **diagram)))

objects.append(([g, g1, g2, g], diagram))
objects.append(([G, g1, g2, G], dict(fill=True, **diagram)))

# Project the 3D scene to the plot plane. R is presumably the viewpoint
# (TODO confirm against storylines.project).
objects = storylines.project(objects, R=[-cutoff, -cutoff, -cutoff])

# Remove the helper 'label' keys so they are not forwarded to plot.line as
# keyword arguments later, and remember which objects carried them.
labels = [style.pop('label', None) for points, style in objects]

# Ensure the O atom tagged 'above' comes after the diagram line tagged
# 'below' in the drawing list. Assuming later objects are painted on top,
# the atom then covers the end of the line. Dedicated names avoid
# overwriting the lattice-vector matrix `a` with an index.
above = labels.index('above')
below = labels.index('below')

if above < below:
    objects[above], objects[below] = objects[below], objects[above]
# Visible window: `width` wide, centred at the origin, with the height
# given by the aspect ratio. No axes, black canvas. The preamble defines
# the colour names used by ball_color and by the bond tints.
_half_width = width / 2
_half_height = width / 2 / aspect

plot = storylines.Plot(
    preamble=r'''
\definecolor{Si}{RGB}{240, 200, 160}
\definecolor{O}{RGB}{255, 13, 13}
''',

    xyaxes=False,

    height=0.0,
    margin=0.0,

    xmin=-_half_width,
    xmax=_half_width,
    ymin=-_half_height,
    ymax=_half_height,

    canvas='black')

# Depth cue. Each non-white object is rated by the mean of its third
# coordinate after projection. White diagram objects are pinned to the
# maximum. The ratings are then rescaled to [0, 1].
_depths = []
for points, style in objects:
    if style.get('color') == 'white':
        _depths.append(np.nan)
    else:
        _depths.append(np.average(list(zip(*points))[2]))

opacities = np.array(_depths)
opacities[np.isnan(opacities)] = np.nanmax(opacities)

opacities -= opacities.min()
opacities /= opacities.max()

for (points, style), opacity in zip(objects, opacities):
    xy = list(zip(*points))[:2]

    if 'color' in style and style['color'] != 'white':
        # xcolor mix: objects with a low rating are blended towards black.
        style['color'] += '!%.1f!black' % (100 * opacity)

    if 'ball_color' in style:
        style['opacity'] = opacity
        # Opaque black disk just inside the ball, presumably so that the
        # semi-transparent ball does not let the bonds behind it show.
        plot.line(*xy, mark_size=style['mark_size'] - 0.001,
            mark='*', color='black', draw='none', only_marks=True)

    plot.line(*xy, **style)

plot.save('figure.png', width=pixels)