import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import fft, fftfreq
import re

# === Parse U_probe.dat for 6 probes === #
def parse_U_probes(filename, num_probes=6):
    """Read per-probe Ux and Uy time series from a probe output file.

    A data line is a time value followed by one parenthesised vector per
    probe, e.g. ``0.01 (1 0 0) (0.9 0.1 0)``. Lines that start with ``#``,
    blank lines, and lines holding fewer than ``num_probes`` vectors are
    skipped. Vectors beyond the first ``num_probes`` are ignored. Only the
    first two components of each vector are kept.

    Parameters
    ----------
    filename : str or path-like
        Path to the probe file.
    num_probes : int
        Number of probe vectors to read from each line.

    Returns
    -------
    tuple
        ``(times, Ux, Uy)``. ``times`` is a 1-D array. ``Ux`` and ``Uy`` are
        lists holding one 1-D array per probe.
    """
    vector_pattern = re.compile(r'\(([^)]+)\)')
    sample_times = []
    ux_series = [[] for _ in range(num_probes)]
    uy_series = [[] for _ in range(num_probes)]

    with open(filename, 'r') as handle:
        for raw in handle:
            if raw.startswith('#') or raw.strip() == '':
                continue

            groups = vector_pattern.findall(raw)
            if len(groups) < num_probes:
                continue  # incomplete line: not enough probe vectors

            sample_times.append(float(raw.split()[0]))
            for probe, group in enumerate(groups[:num_probes]):
                values = [float(tok) for tok in group.split()]
                ux_series[probe].append(values[0])
                uy_series[probe].append(values[1])

    return (
        np.array(sample_times),
        list(map(np.array, ux_series)),
        list(map(np.array, uy_series)),
    )


# === Load the data === #
times, Ux_probes, Uy_probes = parse_U_probes("U_probe.dat", num_probes=6)

# === Plot Ux, then Uy === #
# Each entry: (component tag for legend labels, per-probe series,
#              y-axis label, figure title). One figure per component.
velocity_panels = (
    ("Ux", Ux_probes, "Ux [m/s]", "Ux Velocity at All Probes"),
    ("Uy", Uy_probes, "Uy [m/s]", "Uy Velocity at All Probes"),
)
for tag, probe_series, y_label, panel_title in velocity_panels:
    plt.figure(figsize=(12, 5))
    for idx, series in enumerate(probe_series):
        plt.plot(times, series, label=f"Probe {idx} {tag}")
    plt.xlabel("Time [s]")
    plt.ylabel(y_label)
    plt.title(panel_title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# === FFT on Uy at selected probe === #
probe_index = 0  # Change this to any probe from 0 to 5
min_freq = 1.0   # [Hz] bins at or below this are treated as low-frequency noise

# The FFT needs uniformly spaced samples. np.unique sorts the time stamps and
# drops duplicates (np.interp requires increasing abscissae). If the recorded
# steps still vary (e.g. output from an adjustable time step), resample onto a
# uniform grid at the median step. Taking times[1] - times[0] alone would
# mis-scale the frequency axis for such data.
t_sorted, first_idx = np.unique(times, return_index=True)
if len(t_sorted) < 2:
    raise ValueError("At least two distinct probe time samples are needed for an FFT")
uy_sorted = Uy_probes[probe_index][first_idx]
steps = np.diff(t_sorted)
dt = float(np.median(steps))
if np.allclose(steps, dt, rtol=1e-3, atol=0.0):
    uy_uniform = uy_sorted  # already (near-)uniformly sampled
else:
    n_uniform = int(np.floor((t_sorted[-1] - t_sorted[0]) / dt)) + 1
    t_uniform = t_sorted[0] + dt * np.arange(n_uniform)
    uy_uniform = np.interp(t_uniform, t_sorted, uy_sorted)

# Remove the mean so the DC bin does not dominate the spectrum.
uy_signal = uy_uniform - np.mean(uy_uniform)
n = len(uy_signal)

frequencies = fftfreq(n, dt)
fft_values = np.abs(fft(uy_signal))[:n//2]  # keep the non-negative half only
frequencies = frequencies[:n//2]

# Filter out low frequency noise
freq_mask = frequencies > min_freq
if not np.any(freq_mask):
    # Record too short or too coarse to resolve anything above min_freq.
    # Search all non-zero bins rather than failing inside argmax.
    print(f"Warning: no FFT bins above {min_freq} Hz; searching all non-zero frequencies")
    freq_mask = frequencies > 0.0
if not np.any(freq_mask):
    raise ValueError("Probe signal is too short to resolve any non-zero frequency")
dominant_freq = frequencies[freq_mask][np.argmax(fft_values[freq_mask])]

# === Plot FFT === #
plt.figure(figsize=(10, 4))
plt.plot(frequencies, fft_values)
plt.xlabel("Frequency [Hz]")
plt.ylabel("Amplitude")
plt.title(f"FFT of Uy at Probe {probe_index}")
plt.grid(True)
plt.tight_layout()
plt.show()

# === Compute Strouhal Number === #
# St = f * D / U_inf, using the dominant Uy frequency computed earlier.
# NOTE(review): D and U_inf are presumably in SI units (m, m/s) to match the
# [m/s] axis labels; confirm against the case setup.
D = 0.132      # Characteristic length (e.g. cylinder diameter)
U_inf = 11.06  # Freestream velocity
St = (dominant_freq * D) / U_inf

print(f"Dominant frequency (Probe {probe_index}): {dominant_freq:.4f} Hz")
print(f"Strouhal number (Probe {probe_index}): {St:.4f}")

# === Plot Uy time signal === #
# Open a dedicated figure. Without it, in interactive mode (where plt.show()
# neither blocks nor closes figures) this trace would be drawn onto the
# still-current FFT figure.
plt.figure(figsize=(12, 5))
plt.plot(times, Uy_probes[probe_index])
plt.xlabel("Time [s]")
plt.ylabel(f"Uy [m/s] (Probe {probe_index})")
plt.title(f"Transverse Velocity vs Time (Probe {probe_index})")
plt.grid(True)
plt.tight_layout()
plt.show()

# === Parse coefficient.dat === # -------------------------------------
def parse_force_coefficients(filename):
    """Read time, Cd and Cl columns from an OpenFOAM force-coefficient file.

    Cd and Cl column positions come from the comment header line whose first
    field is ``Time`` (e.g. ``# Time Cm Cd Cl Cl(f) Cl(r)``). This means files
    whose columns are ordered differently are read correctly. If no such
    header is found, columns 2 and 3 (0-based) are assumed to hold Cd and Cl,
    matching the legacy ``Time Cm Cd Cl ...`` layout.

    Blank lines and comment lines (``#``, leading whitespace allowed) are
    skipped. So are data rows shorter than six columns or too short to
    contain the Cd/Cl columns.

    Parameters
    ----------
    filename : str or path-like
        Path to the coefficient file.

    Returns
    -------
    tuple of numpy.ndarray
        ``(times, Cd, Cl)`` as equal-length 1-D arrays.
    """
    times = []
    Cl = []
    Cd = []

    # Default to the legacy layout: Time Cm Cd Cl Cl(f) Cl(r)
    cd_col, cl_col = 2, 3

    with open(filename, 'r') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            if stripped.startswith('#'):
                names = stripped.lstrip('#').split()
                # Take Cd/Cl positions from the column header rather than
                # assuming one fixed layout for every OpenFOAM version.
                if names and names[0] == 'Time' and 'Cd' in names and 'Cl' in names:
                    cd_col = names.index('Cd')
                    cl_col = names.index('Cl')
                continue
            parts = stripped.split()
            if len(parts) < max(6, cd_col + 1, cl_col + 1):
                continue
            times.append(float(parts[0]))
            Cd.append(float(parts[cd_col]))
            Cl.append(float(parts[cl_col]))

    return np.array(times), np.array(Cd), np.array(Cl)

# === Load coefficient.dat === #
coeff_time, Cd_vals, Cl_vals = parse_force_coefficients("coefficient.dat")

# === Plot Cl and Cd vs time (starting after 0.1 s) === #
# Samples at or before start_time are excluded from the plot AND from the
# statistics below, so both describe the same time window.
start_time = 0.1
mask = coeff_time > start_time
if not np.any(mask):
    raise ValueError(f"coefficient.dat has no samples with t > {start_time} s")

plt.figure(figsize=(12, 5))
plt.plot(coeff_time[mask], Cl_vals[mask], label="Cl")
plt.plot(coeff_time[mask], Cd_vals[mask], label="Cd")
plt.xlabel("Time [s]")
plt.ylabel("Coefficient")
plt.title(f"Lift and Drag Coefficients Over Time (t > {start_time}s)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# === Compute Mean and Fluctuation === #
# Use the same window as the plot. Early (start-up) samples could otherwise
# dominate the mean and especially the peak-to-peak range.
Cl_window = Cl_vals[mask]
Cd_window = Cd_vals[mask]
Cl_mean = np.mean(Cl_window)
Cd_mean = np.mean(Cd_window)
Cl_amp = np.max(Cl_window) - np.min(Cl_window)
Cd_amp = np.max(Cd_window) - np.min(Cd_window)

print(f"Mean Cl: {Cl_mean:.4f}")
print(f"Mean Cd: {Cd_mean:.4f}")
print(f"Peak-to-Peak Cl fluctuation: {Cl_amp:.4f}")
print(f"Peak-to-Peak Cd fluctuation: {Cd_amp:.4f}")
