import numpy as np
from scipy import fftpack, stats
from PIL import Image
import os

def load_and_prepare_image(file_path):
    """Load *file_path* and return it as a grayscale NumPy array.

    The image is converted to Pillow mode ``'L'`` (8-bit single-channel
    luminance), so the result is a 2-D ``uint8`` array of shape
    (height, width).
    """
    # Use a context manager so the file handle is released deterministically;
    # Pillow may keep it open (lazy loading / multi-frame formats) until the
    # Image object is closed or garbage-collected.
    with Image.open(file_path) as img:
        return np.array(img.convert('L'))

def compute_fft(image):
    """Return the centred magnitude spectrum of a 2-D image.

    The 2-D DFT is computed, the zero-frequency term is shifted to the
    middle of the array with ``fftshift``, and the absolute value
    (magnitude, not squared power) is returned.
    """
    return np.abs(np.fft.fftshift(np.fft.fft2(image)))

def radial_profile(data, center):
    """Average *data* over rings of integer radius around *center*.

    Parameters
    ----------
    data : 2-D array
        Values to average (real-valued; they are passed to ``np.bincount``
        as weights).
    center : pair
        Ring centre given as ``(x, y)``, i.e. ``(column, row)`` -- note the
        order is the reverse of ``data.shape``.

    Returns
    -------
    numpy.ndarray
        1-D float array whose element ``k`` is the mean of all elements
        whose distance from *center*, truncated to an int, equals ``k``.
        Radii that contain no elements are reported as 0.0.
    """
    y, x = np.indices(data.shape)
    r = np.sqrt((x - center[0])**2 + (y - center[1])**2).astype(int)
    tbin = np.bincount(r.ravel(), data.ravel())
    nr = np.bincount(r.ravel())
    # Without ``out=``, np.divide leaves the slots where ``nr == 0`` as
    # uninitialized memory, so empty radii would contain arbitrary values.
    radialprofile = np.zeros(tbin.shape, dtype=float)
    np.divide(tbin, nr, out=radialprofile, where=nr != 0)
    return radialprofile

def compute_spectral_slope(fft_image):
    """Return the log-log slope of the radially averaged spectrum.

    *fft_image* is the centred (``fftshift``-ed) 2-D spectrum produced by
    ``compute_fft`` -- a magnitude spectrum, so the slope describes
    magnitude versus frequency, not squared power.

    Returns NaN when fewer than two usable (finite, positive) radial
    samples remain, since no line can be fitted.
    """
    rows, cols = fft_image.shape
    # fftshift puts the zero-frequency bin at (row, col) = (rows//2, cols//2);
    # radial_profile expects its centre as (x, y) = (column, row).
    center = (cols // 2, rows // 2)
    radial_prof = radial_profile(fft_image, center)

    half = len(radial_prof) // 2
    # Skip radius 0 (DC term; its frequency would be log(0)) and keep radii
    # 1 .. half-1. fftfreq only rescales the radius index, which shifts the
    # fitted intercept but not the slope.
    frequencies = np.fft.fftfreq(len(radial_prof))[1:half]
    amplitude = radial_prof[1:half]

    mask = np.isfinite(amplitude) & (amplitude > 0)
    frequencies, amplitude = frequencies[mask], amplitude[mask]
    if frequencies.size < 2:
        # e.g. tiny or constant/all-black images: polyfit would fail.
        return float("nan")

    slope, _ = np.polyfit(np.log(frequencies), np.log(amplitude), 1)
    return slope

def compute_entropy(image):
    """Return the Shannon entropy, in bits, of the intensity histogram.

    Values are binned into 256 equal-width bins over the image's own
    min..max range (NumPy's default). The 1e-10 added inside the log keeps
    empty bins finite; their term is 0 * log2(1e-10) = 0.
    """
    counts, _edges = np.histogram(image, bins=256)
    probabilities = counts / counts.sum()
    return -(probabilities * np.log2(probabilities + 1e-10)).sum()

def compute_kurtosis(image):
    """Return SciPy's default (Fisher, excess) kurtosis of all pixel values."""
    pixels = np.ravel(image)
    return stats.kurtosis(pixels)

def compute_moments(fft_image):
    """Return summary moments of the power spectrum ``|fft_image| ** 2``.

    The list is ``[mean, 2nd, 3rd, 4th central moment]``: the raw mean of
    the power values followed by the un-normalized central moments of
    orders 2-4 (the 2nd being the population variance).
    """
    power = np.abs(fft_image) ** 2
    mean_power = np.mean(power)
    deviation = power - mean_power
    return [mean_power] + [np.mean(deviation ** order) for order in (2, 3, 4)]

def _save_fft_preview(fft_image, file_path):
    """Save a log-scaled 8-bit preview of *fft_image* as ``<stem>_fft.png``.

    Returns the path of the written PNG, placed next to *file_path*.
    """
    fft_filename = os.path.splitext(file_path)[0] + "_fft.png"
    # log1p compresses the spectrum's (typically very large) dynamic range.
    log_mag = np.log1p(fft_image)
    peak = log_mag.max()
    if peak > 0:
        scaled = log_mag * 255 / peak
    else:
        # All-zero spectrum (e.g. an all-black image): avoid 0/0 -> NaN,
        # whose cast to uint8 is undefined.
        scaled = np.zeros_like(log_mag)
    Image.fromarray(np.clip(scaled, 0, 255).astype(np.uint8)).save(fft_filename)
    return fft_filename


def process_image(file_path):
    """Analyse one image and save a preview of its spectrum.

    Returns ``(fft_filename, slope, entropy, kurtosis, moments)``:
    the written ``<stem>_fft.png`` preview path, the spectral slope,
    the intensity entropy, the pixel kurtosis and the power-spectrum
    moments.
    """
    image = load_and_prepare_image(file_path)
    fft_image = compute_fft(image)

    slope = compute_spectral_slope(fft_image)
    entropy = compute_entropy(image)
    kurtosis = compute_kurtosis(image)
    moments = compute_moments(fft_image)

    fft_filename = _save_fft_preview(fft_image, file_path)
    return fft_filename, slope, entropy, kurtosis, moments

def generate_html(results, output_path="powerSpectra.html"):
    """Write an HTML table summarizing each processed image.

    Parameters
    ----------
    results : iterable of tuples
        Each item is ``(original_path, fft_path, slope, entropy, kurtosis,
        moments)``, where ``moments`` is an iterable of numbers.
    output_path : str, optional
        Destination file, ``"powerSpectra.html"`` (relative to the current
        working directory) by default.
    """
    import html

    parts = ["""
    <html>
    <head>
        <meta charset="utf-8">
        <style>
            table, th, td { border: 1px solid black; border-collapse: collapse; }
            th, td { padding: 5px; text-align: center; }
            img { max-width: 300px; max-height: 300px; }
        </style>
    </head>
    <body>
        <table>
            <tr>
                <th>Original Image</th>
                <th>2D FFT</th>
                <th>Spectral Slope</th>
                <th>Entropy</th>
                <th>Kurtosis</th>
                <th>Moments</th>
            </tr>
    """]

    for original_path, fft_path, slope, entropy, kurtosis, moments in results:
        # Escape paths so characters such as '&' or '"' in a filename cannot
        # break the attribute or inject markup.
        original_src = html.escape(str(original_path), quote=True)
        fft_src = html.escape(str(fft_path), quote=True)
        moments_text = ', '.join(f'{m:.2e}' for m in moments)
        parts.append(f"""
            <tr>
                <td><img src="{original_src}"></td>
                <td><img src="{fft_src}"></td>
                <td>{slope:.4f}</td>
                <td>{entropy:.4f}</td>
                <td>{kurtosis:.4f}</td>
                <td>{moments_text}</td>
            </tr>
        """)

    parts.append("""
        </table>
    </body>
    </html>
    """)

    # Explicit encoding: the platform default may not be UTF-8.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("".join(parts))

# Input images, opened relative to the current working directory.
image_files = [
    '1722307680.png',
    '1722307715.png',
    '1722307729.png',
    '1722307742.png',
    'forest.jpg',
    'streetviewKingStreet.jpeg',
    'streetviewNorfolkStreet.jpeg'
]

if __name__ == "__main__":
    # Process each image. A file that cannot be opened/decoded, or whose
    # preview cannot be written, is reported and skipped so one bad input
    # does not abort the whole report.
    results = []
    for file_path in image_files:
        try:
            fft_path, slope, entropy, kurtosis, moments = process_image(file_path)
        except OSError as exc:
            print(f"Skipped {file_path}: {exc}")
            continue
        results.append((file_path, fft_path, slope, entropy, kurtosis, moments))
        print(f"Processed {file_path}")

    # Generate HTML
    generate_html(results)

    print("powerSpectra.html has been generated.")