# An example-driven introduction to automatic differentiation
Author: Daniel Lim

## 0. Package installation
Make sure you have started a new environment and installed the necessary packages; for example: 
```{bash}
conda create --name autodiff python=3.9 -y
conda activate autodiff
pip install --upgrade pip
pip install --upgrade numpy autograd notebook
pip install --upgrade matplotlib scipy
```

## 1. Quick introduction to Autograd
Autograd is a lightweight Python package that performs automatic differentiation. It is a wrapper around NumPy: you import `autograd.numpy` in place of `numpy`, your calculations give the same answers, and — crucially — Autograd can also compute their gradients efficiently.

This might sound unfamiliar, so let's start with a simple case: use Autograd to calculate the derivative of $y=x^2$, which we expect to be $dy/dx=2x$.

In [None]:
import autograd.numpy as np
from autograd import grad

# the scalar function we want to differentiate: f(x) = x^2
def f(x):
    """Return the square of x."""
    squared = x ** 2
    return squared

# grad(f) builds a new function that evaluates df/dx at a given point
grad_f = grad(f)

In [None]:
# evaluate f and its derivative at a few sample points (expect f'(x) = 2x)
for x_val in (0.0, 1.0, -2.0):
    print(f"At x={x_val}, f(x)={f(x_val)}, and f'(x)={grad_f(x_val)}")

It works! Now let's extend this to several dimensions.

In [None]:
# define the function f(x) = x0 * x1 * x2 of a three-component input
def f(x):
    """Return the product of the first three components of x."""
    return x[0] * x[1] * x[2]

# get the gradient function: grad_f(x) returns the partial derivatives
# (df/dx0, df/dx1, df/dx2) evaluated at x
grad_f = grad(f)

# test this on various values; each point is passed as a NumPy array,
# matching the description of x as a three-element array
for point in ([0.0, 1.0, 2.0], [0.0, 0.0, 0.0], [10.0, 20.0, 30.0]):
    x0 = np.array(point)
    print(f"At x={point}, f(x)={f(x0)}, and f'(x)={grad_f(x0)}")

Note that we need to supply a three-element array $x$ to both `f` and `grad_f`. `grad_f` returns a three-element array: the gradient of $f$, whose entries are the partial derivatives $\partial f/\partial x_i$ with respect to each of the three components of $x$.

## 2. Timing demonstration

In [None]:
import autograd.numpy as np
from autograd import grad
import time

# the function to time: the sum of every entry of x
def f(x):
    """Return the sum of all elements of x."""
    total = np.sum(x)
    return total

# gradient function (for a plain sum, every partial derivative equals 1)
grad_f = grad(f)

# timing parameters
N_dims = 1_000_000  # number of tunable parameters (length of each input x)
N_iter = 100        # number of repetitions in each timing run


Let's time how long it takes to run the forward pass (just computing $f$ at random x-positions) N_iter times:

In [None]:
# Time N_iter forward passes. time.perf_counter() is a monotonic,
# high-resolution clock meant for measuring elapsed durations, whereas
# time.time() follows the system clock and can jump if it is adjusted.
# Note: x0 is drawn inside the timed loop, so the reported time also
# includes generating N_dims random numbers on every iteration.
start_time = time.perf_counter()
for _ in range(N_iter):
    x0 = np.random.rand(N_dims)  # position to evaluate the function at
    f_val = f(x0)
end_time_single_pass = time.perf_counter()
duration_single_pass = end_time_single_pass - start_time
print(f"Time taken for {N_iter} forward passes: {duration_single_pass:.6f} seconds.")

We will time the gradient calculation N_iter times in the same way. Note that the output of each gradient calculation is an N_dims-long vector.

In [None]:
# Time N_iter gradient evaluations with the monotonic, high-resolution
# time.perf_counter() clock (time.time() can jump if the system clock is
# adjusted). As in the loop body below, x0 is drawn inside the timed region,
# so the reported time also includes random-number generation.
start_time = time.perf_counter()
for _ in range(N_iter):
    x0 = np.random.rand(N_dims)  # position to evaluate the gradient at
    grad_f_val = grad_f(x0)      # an N_dims-long gradient vector
end_time_grad = time.perf_counter()
duration_grad = end_time_grad - start_time
print(f"Time taken for {N_iter} gradient calculations: {(duration_grad):.6f} seconds.")

Let's see how much longer it took to compute the gradient relative to the forward pass:

In [None]:
grad_to_forward_ratio = duration_grad / duration_single_pass  # how many forward passes one gradient "costs"
print(f"Ratio of gradient calculation time to forward pass time: {grad_to_forward_ratio:.2f}")

## 3. Brachistochrone curve calculation

We will now numerically approximate the Brachistochrone curve. The Brachistochrone curve is the curve connecting two points (the second lower than the first, but not directly below it) along which a particle starting from rest slides without friction under uniform gravity in the shortest time possible. Contrary to expectations, it is not a straight line connecting these two points!

In [None]:
import autograd.numpy as np
from autograd import grad
import matplotlib.pyplot as plt
from scipy.optimize import fsolve

#---------------------------
# GEOMETRY SETUP
#---------------------------

g = 9.81 # [m/s^2] gravitational acceleration
x_start = 0.0 # [m]
x_end = 1.0 # [m]
y_start = 1.0 # [m]
y_end = 0.0 # [m]
N = 25 # number of interior discretization points (N+2 points including both endpoints)

# the y-positions will be sampled uniformly from top to bottom
y = np.linspace(y_start, y_end, N+2)

#---------------------------
# FUNCTION AND GRADIENT SETUP
#---------------------------

def construct_x(x_interior):
    """
    Reconstruct the full x-array by including boundary x-values.

    The endpoints x_start and x_end are fixed; only the N interior values
    are free parameters. Returns an array of length len(x_interior) + 2.
    """
    return np.concatenate(([x_start], x_interior, [x_end]))

def total_time(x_interior):
    """
    Computes the total descent time along the discretized curve.

    The curve is the polyline through (x_full[i], y[i]) for i = 0..N+1,
    with x_full = construct_x(x_interior). The time for each of the N+1
    segments is approximated by:
        Δt = ds / sqrt(2*g*(y_start - y_avg))
    where ds is the segment length and y_avg is the average y-value of the
    segment; sqrt(2*g*(y_start - y)) is the speed of a particle released
    from rest at height y_start. Returns the sum of all segment times [s].

    The segments are processed with whole-array operations rather than a
    Python loop over scalars, because this function (and its autograd
    gradient) is evaluated many times during optimization.
    """
    x_full = construct_x(x_interior)
    dx = x_full[1:] - x_full[:-1]
    dy = y[1:] - y[:-1]
    ds = np.sqrt(dx**2 + dy**2)
    # Using the segment midpoint height keeps the first segment's speed
    # positive: at y[0] == y_start itself the speed would be zero.
    y_avg = (y[:-1] + y[1:]) / 2.0
    speed = np.sqrt(2 * g * (y_start - y_avg))
    return np.sum(ds / speed)

# Get the gradient function using autograd: grad_total_time(x_interior)
# returns dT/dx_interior, i.e. one partial derivative of the total descent
# time per movable interior point (the endpoints are fixed in construct_x).
grad_total_time = grad(total_time)

In [None]:
#---------------------------
# OPTIMIZATION SETUP
#---------------------------

# starting guess: the straight line between the endpoints, sampled at N+2
# points; only the N interior x-values are free, so both ends are dropped
x_interior = np.linspace(x_start, x_end, N + 2)[1:N + 1]

# gradient descent parameters
learning_rate = 2e-2  # step size multiplying the gradient at each update
num_iters = 5000      # number of gradient descent iterations to run
obj_history = []      # total descent time recorded after every step

In [None]:
#---------------------------
# RUN OPTIMIZATION
#---------------------------
# Optimize via gradient descent: step the interior x-values against the
# gradient of the total descent time, recording the objective after each step.
for i in range(num_iters):
    grad_val = grad_total_time(x_interior)
    x_interior = x_interior - learning_rate * grad_val
    # Evaluate the objective once per step and reuse it for both the history
    # and the periodic progress report.
    descent_time = total_time(x_interior)
    obj_history.append(descent_time)
    if i % 500 == 0:
        print(f"Iteration {i}, total time for particle to slide down: {descent_time:.4f} seconds")

# total_time is re-evaluated here (rather than reading obj_history[-1]) so
# this still works when num_iters == 0 and the history is empty.
print(f"Optimized total time for particle to slide down: {total_time(x_interior)} seconds.")
x_opt = construct_x(x_interior)

In [None]:
#---------------------------
# PLOTS
#---------------------------
# 1) Plot the convergence of the total descent time.
plt.figure()
plt.plot(obj_history)
plt.xlabel("Iteration")
plt.ylabel("Total descent time [s]")
plt.title("Particle travel time against iteration")
plt.show()

# 2) Plot the optimized curve on equal axes over the unit square.
plt.figure()
plt.plot(x_opt, y, 'b.-')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Optimized Brachistochrone curve')
# Equal aspect with adjustable='box' resizes the axes box instead of the data
# limits, so the explicit (0, 1) limits below are honoured. axis('equal')
# uses adjustable='datalim', which can override explicitly set limits.
plt.gca().set_aspect('equal', adjustable='box')
plt.xlim((0, 1))
plt.ylim((0, 1))
plt.show()

In [None]:
#---------------------------
# THEORETICAL CYCLOID CALCULATION
#---------------------------
#
# For a particle sliding from (0,1) to (1,0), the exact brachistochrone is a
# cycloid with parameter θ, generated by a circle of diameter R:
#
#   x(θ) = (R/2) * (θ - sin θ)
#   y(θ) = 1 - (R/2) * (1 - cos θ)
#
# Requiring the curve to end at (x, y) = (1, 0) when θ = θ_f gives:
#   y(θ_f) = 0  =>  R = 2/(1 - cos θ_f)
#   x(θ_f) = 1  =>  (R/2) * (θ_f - sin θ_f) = 1
#
# Eliminating R leaves a single transcendental equation for θ_f:
#   θ_f - sin θ_f = 1 - cos θ_f

def equation(theta):
    """
    Residual of θ - sin θ = 1 - cos θ.

    It vanishes at the end-point parameter θ_f (and also at the trivial θ = 0).
    """
    lhs = theta - np.sin(theta)
    rhs = 1 - np.cos(theta)
    return lhs - rhs

# Use fsolve to compute θ_f; the initial guess of 2.0 steers the solver away
# from the trivial root θ = 0, which satisfies the equation but would make
# R = 2/(1 - cos θ) divide by zero.
theta_f_guess = 2.0
theta_f_roots, fsolve_info, fsolve_ier, fsolve_msg = fsolve(
    equation, theta_f_guess, full_output=True
)
theta_f_solution = theta_f_roots[0]
# ier == 1 is fsolve's "solution found" flag; fail loudly instead of plotting
# a curve built from an unconverged or degenerate θ_f.
if fsolve_ier != 1 or theta_f_solution <= 0:
    raise RuntimeError(
        f"fsolve did not find a valid theta_f (got {theta_f_solution}): {fsolve_msg}"
    )
R_solution = 2 / (1 - np.cos(theta_f_solution))

print(f"Theoretical solution: theta_f = {theta_f_solution:.6f}, R = {R_solution:.6f}")

# Generate the theoretical cycloid curve.
theta_vals = np.linspace(0, theta_f_solution, 200)
x_cycloid = (R_solution / 2.0) * (theta_vals - np.sin(theta_vals))
y_cycloid = 1.0 - (R_solution / 2.0) * (1.0 - np.cos(theta_vals))

plt.figure()
plt.plot(x_opt, y, 'b.-', label='Optimized (Discretized) Curve')
plt.plot(x_cycloid, y_cycloid, 'r--', label='Theoretical Cycloid')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Optimized Brachistochrone vs. Theoretical Cycloid')
plt.legend()
# Equal aspect via adjustable='box' keeps the explicit (0, 1) limits below;
# axis('equal') uses adjustable='datalim', which can override them.
plt.gca().set_aspect('equal', adjustable='box')
plt.xlim((0, 1))
plt.ylim((0, 1))
plt.show()
