import numpy as np
from scipy import fftpack
import matplotlib.pyplot as plt
from PIL import Image

def load_and_prepare_image(file_path):
    """Load an image file and return it as a greyscale NumPy array.

    Args:
        file_path: path of the image file to open with ``PIL.Image.open``.

    Returns:
        ``numpy.ndarray`` holding the image converted to PIL mode ``'L'``
        (8-bit greyscale), i.e. a 2-D array of pixel intensities.
    """
    # Use Image.open as a context manager so the underlying file is closed
    # deterministically; Pillow only auto-closes single-frame images after
    # loading.  convert() loads the pixel data into a new image, and
    # np.array() copies it, so the result does not depend on the closed file.
    with Image.open(file_path) as img:
        return np.array(img.convert('L'))

def compute_fft(image):
    """Return the centred log-magnitude spectrum of *image*, for display.

    The 2-D FFT is shifted so the zero-frequency term sits in the middle of
    the array and then mapped through ``20 * ln(|F| + 1e-10)`` (natural log,
    so the values are not decibels).

    Args:
        image: 2-D array of pixel values.

    Returns:
        Array with the same shape as *image*.
    """
    centred_spectrum = np.fft.fftshift(np.fft.fft2(image))
    # The tiny epsilon keeps the logarithm finite where the magnitude is 0.
    return 20 * np.log(np.abs(centred_spectrum) + 1e-10)

def radial_profile(data, center):
    """Return the mean of *data* over rings of equal integer radius.

    Args:
        data: 2-D array.
        center: ``(x, y)`` ring centre, where ``x`` is the column index and
            ``y`` the row index.

    Returns:
        1-D float array whose element ``r`` is the mean of all elements whose
        distance from *center*, truncated to an integer, equals ``r``.  Radii
        on which no element falls yield 0.
    """
    y, x = np.indices(data.shape)
    r = np.sqrt((x - center[0])**2 + (y - center[1])**2).astype(int)
    ring_sums = np.bincount(r.ravel(), data.ravel())
    ring_counts = np.bincount(r.ravel())
    # Without an explicit ``out`` array, np.divide leaves the entries masked
    # off by ``where`` uninitialised (arbitrary memory); start from zeros so
    # empty rings have a well-defined value.
    return np.divide(ring_sums, ring_counts,
                     out=np.zeros_like(ring_sums), where=ring_counts != 0)

def compute_spectral_slope(fft_image):
    """Estimate the log-log spectral slope of a centred 2-D spectrum.

    The spectrum is averaged over rings of equal integer radius around the
    zero-frequency element (see ``radial_profile``).  A straight line is then
    fitted to log(ring average) against log(radius) over radial bins
    1 .. len(profile)//2 - 1 (the DC bin is excluded), and its gradient is
    returned.

    Args:
        fft_image: 2-D spectrum whose zero frequency sits at
            ``(rows // 2, cols // 2)``, as produced by ``np.fft.fftshift``.
            The slope is only a meaningful power-law exponent when the values
            are linear (e.g. power ``|F|**2``) rather than already log-scaled.

    Returns:
        The fitted slope as a Python ``float``, or NaN when fewer than two
        usable (finite, strictly positive) radial points remain.
    """
    rows, cols = fft_image.shape
    # radial_profile takes the centre as (x, y) = (column, row), while
    # fftshift puts zero frequency at row rows // 2, column cols // 2.
    # Passing (rows // 2, cols // 2) would be wrong for non-square images.
    center = (cols // 2, rows // 2)
    radial_prof = radial_profile(fft_image, center)

    # The bin index is proportional to spatial frequency.  A constant scale
    # factor only shifts the intercept of a log-log fit, so using the index
    # directly gives the same slope as any physical frequency unit.
    half = len(radial_prof) // 2
    frequencies = np.arange(1, half, dtype=float)
    power = radial_prof[1:half]

    # Only finite, strictly positive averages can be log-transformed.
    mask = np.isfinite(power) & (power > 0)
    frequencies, power = frequencies[mask], power[mask]
    if frequencies.size < 2:
        # A line fit needs at least two points; report "no estimate".
        return float('nan')

    slope, _intercept = np.polyfit(np.log(frequencies), np.log(power), 1)
    return float(slope)

def process_image(file_path):
    """Load an image and compute its display spectrum and spectral slope.

    Args:
        file_path: path of the image file to analyse.

    Returns:
        A ``(fft_image, slope)`` tuple: the log-magnitude spectrum from
        ``compute_fft`` (for display) and the spectral slope fitted on the
        linear power spectrum by ``compute_spectral_slope``.
    """
    image = load_and_prepare_image(file_path)
    fft_image = compute_fft(image)
    # Fit the slope on the linear power spectrum, not on fft_image.
    # fft_image is already log-scaled (20 * ln(|F| + 1e-10)), so fitting
    # log(ring mean of fft_image) would take a log of logs.  The power > 0
    # filter in the fit could then also silently drop rings whose mean
    # log-magnitude is negative.  The extra FFT is cheap compared with
    # image loading.
    power_spectrum = np.abs(np.fft.fftshift(np.fft.fft2(image))) ** 2
    slope = compute_spectral_slope(power_spectrum)
    return fft_image, slope

# List of image file paths (relative paths resolve against the working directory)
image_files = [
    '1722307680.png',
    '1722307715.png',
    '1722307729.png',
    '1722307742.png',
    'forest.jpg',
    'streetviewKingStreet.jpeg',
    'streetviewNorfolkStreet.jpeg'
]

# Process each image: one (fft_image, slope) tuple per file, in list order
results = [process_image(file_path) for file_path in image_files]

# Print results and plot FFTs, three per row
num_images = len(results)
if num_images:
    rows = (num_images + 2) // 3  # ceil(num_images / 3)
    # squeeze=False always returns a 2-D array of Axes, so ravel() is safe
    # for any grid size.
    fig, axs = plt.subplots(rows, 3, figsize=(15, 5*rows), squeeze=False)
    axs = axs.ravel()  # Flatten the 2D array of axes

    for i, ((fft_image, slope), ax) in enumerate(zip(results, axs)):
        print(f"Image {i+1} spectral slope: {slope:.4f}")

        ax.imshow(fft_image, cmap='gray')
        ax.set_title(f"FFT Image {i+1}")
        ax.axis('off')

    # Hide any unused subplots.  Index by the image count rather than the
    # loop variable, which would be undefined if the loop never ran.
    for ax in axs[num_images:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Print all slopes for easy comparison.  Convert to a built-in float first:
# under NumPy >= 2 a list of np.float64 prints as "np.float64(...)" reprs.
print("\nAll spectral slopes:")
print([round(float(slope), 4) for _, slope in results])