import math
from scipy import stats
import numpy as np

# Q1 (Jan 1 - Mar 31) event counts, keyed by year.
# Each row below is (year, total, e25, e50, e100); the report labels the
# e25 / e50 / e100 columns as events with 25+ / 50+ / 100+ witnesses.
data = {
    year: dict(zip(('total', 'e25', 'e50', 'e100'), counts))
    for year, *counts in (
        (2005, 69, 1, 0, 0),
        (2006, 100, 0, 0, 0),
        (2007, 110, 0, 0, 0),
        (2008, 138, 0, 0, 0),
        (2009, 199, 2, 1, 0),
        (2010, 141, 0, 0, 0),
        (2011, 371, 6, 4, 1),
        (2012, 480, 11, 6, 2),
        (2013, 812, 20, 11, 2),
        (2014, 852, 12, 7, 5),
        (2015, 840, 18, 8, 5),
        (2016, 1253, 29, 16, 10),
        (2017, 1230, 26, 13, 6),
        (2018, 1305, 41, 18, 8),
        (2019, 1557, 42, 21, 8),
        (2020, 1575, 36, 17, 7),
        (2021, 2065, 53, 30, 15),
        (2022, 2168, 48, 15, 6),
        (2023, 1865, 47, 20, 9),
        (2024, 1712, 39, 21, 9),
        (2025, 1905, 41, 17, 5),
        (2026, 2322, 70, 40, 16),
    )
}

def _sigma_from_p(p):
    """Return the one-sided normal sigma equivalent of tail probability *p*."""
    # A p-value that underflowed to 0 maps to an unbounded significance.
    return stats.norm.isf(p) if p > 0 else float('inf')


def _format_odds(p):
    """Render tail probability *p* as "1 in N", capped at one billion."""
    # 1e-9 is exactly 1 in a billion, so the cutoff agrees with the cap label.
    return f"1 in {1/p:,.0f}" if p > 1e-9 else ">1 billion"


def poisson_analysis(baseline_years, test_year, metric, label, *, dataset=None):
    """Test whether one year's count is an excess over a baseline, and print a report.

    Three one-sided significance estimates are produced:

    * Poisson exact: P(X >= observed) with the baseline mean taken as the
      known rate.
    * Combined uncertainty: a normal approximation whose error adds the
      sqrt(N) error on the observed count and on the baseline mean in
      quadrature.
    * Negative binomial (method-of-moments fit), only when the baseline
      variance/mean ratio exceeds 1.5.

    Args:
        baseline_years: Iterable of years whose counts form the baseline.
        test_year: Year whose count is tested against the baseline.
        metric: Key of the count to read from each year's record.
        label: Human-readable name used in the report and in the result.
        dataset: Optional mapping ``year -> {metric: count}``. Defaults to
            the module-level ``data`` table.

    Returns:
        dict with keys ``metric`` (the label), ``lambda``, ``observed``,
        ``excess``, ``p_poisson``, ``sigma_poisson``, ``sigma_combined``,
        ``p_combined``, ``one_in_n_poisson``, ``one_in_n_combined``,
        ``dispersion`` (NaN with fewer than two baseline years), and
        ``p_nbinom`` / ``sigma_nbinom`` / ``one_in_n_nbinom`` (``None``
        unless the negative binomial test was run).

    Raises:
        ValueError: If ``baseline_years`` is empty.
        KeyError: If a year or the metric is missing from the dataset.
    """
    if dataset is None:
        dataset = data
    baseline_years = list(baseline_years)
    if not baseline_years:
        raise ValueError("baseline_years must contain at least one year")

    baseline_vals = [dataset[y][metric] for y in baseline_years]
    observed = dataset[test_year][metric]

    lam = np.mean(baseline_vals)
    if len(baseline_vals) > 1:
        std_baseline = np.std(baseline_vals, ddof=1)
        variance = np.var(baseline_vals, ddof=1)
        dispersion = variance / lam if lam > 0 else 0
    else:
        # A sample variance needs two points; mark it undefined explicitly
        # instead of letting NumPy emit warnings and a NaN that later
        # compares as "not overdispersed".
        std_baseline = variance = dispersion = float('nan')

    # Poisson exact: P(X >= observed)
    p_value = stats.poisson.sf(observed - 1, lam)
    sigma = _sigma_from_p(p_value)

    # Combined uncertainty (commenter's method)
    err_observed = math.sqrt(observed)
    err_mean = math.sqrt(sum(baseline_vals)) / len(baseline_vals)
    combined_err = math.sqrt(err_observed**2 + err_mean**2)
    sigma_combined = (observed - lam) / combined_err if combined_err > 0 else 0
    p_combined = stats.norm.sf(sigma_combined)

    one_in_n_poisson = _format_odds(p_value)
    one_in_n_combined = _format_odds(p_combined)

    # Negative binomial fallback for overdispersed baselines.
    nb_pvalue = nb_sigma = nb_odds = None
    if dispersion > 1.5:
        # dispersion > 1.5 is only reachable with lam > 0, which implies
        # variance > lam, so the method-of-moments parameters are well-defined.
        n_param = lam**2 / (variance - lam)
        p_param = lam / variance
        nb_pvalue = stats.nbinom.sf(observed - 1, n_param, p_param)
        nb_sigma = _sigma_from_p(nb_pvalue)
        nb_odds = _format_odds(nb_pvalue)

    print(f"\n{'='*70}")
    print(f"  {label}")
    print(f"  Baseline: {baseline_years[0]}-{baseline_years[-1]} | Test: {test_year}")
    print(f"{'='*70}")
    print(f"  Baseline values: {baseline_vals}")
    print(f"  Baseline mean (λ):     {lam:.2f}")
    print(f"  Baseline std dev:      {std_baseline:.2f}")
    print(f"  Observed {test_year}:          {observed}")
    print(f"  Excess over mean:      {observed - lam:.1f}")
    print()
    print("  --- Poisson exact test ---")
    print(f"  P(X ≥ {observed} | λ={lam:.2f}):  {p_value:.8f}")
    print(f"  Equivalent sigma:      {sigma:.2f}σ")
    print(f"  Odds:                  {one_in_n_poisson}")
    print()
    print("  --- Conservative combined uncertainty ---")
    print(f"  Error on observed:     ±{err_observed:.2f}")
    print(f"  Error on baseline mean:±{err_mean:.2f}")
    print(f"  Combined uncertainty:  ±{combined_err:.2f}")
    print(f"  Sigma (combined):      {sigma_combined:.2f}σ")
    print(f"  P-value:               {p_combined:.8f}")
    print(f"  Odds:                  {one_in_n_combined}")
    print()
    print("  --- Overdispersion check ---")
    print(f"  Variance/mean ratio:   {dispersion:.2f}")
    if nb_pvalue is not None:
        print("  ⚠ Overdispersed — trying negative binomial")
        print(f"  NB P(X ≥ {observed}):     {nb_pvalue:.8f}")
        print(f"  NB sigma:             {nb_sigma:.2f}σ")
        print(f"  NB odds:              {nb_odds}")
    elif math.isnan(dispersion):
        print("  ? Dispersion undefined — need at least 2 baseline years")
    else:
        print("  ✓ Poisson is appropriate")

    return {
        'metric': label,
        'lambda': lam,
        'observed': observed,
        'excess': observed - lam,
        'p_poisson': p_value,
        'sigma_poisson': sigma,
        'sigma_combined': sigma_combined,
        'p_combined': p_combined,
        'one_in_n_poisson': one_in_n_poisson,
        'one_in_n_combined': one_in_n_combined,
        'dispersion': dispersion,
        'p_nbinom': nb_pvalue,
        'sigma_nbinom': nb_sigma,
        'one_in_n_nbinom': nb_odds,
    }


# ---- Report driver: runs the analyses and prints the full report ----

# Title box.
print("\n".join([
    "╔══════════════════════════════════════════════════════════════════════╗",
    "║  POISSON ANALYSIS: Q1 2026 FIREBALL EVENTS (END OF QUARTER)       ║",
    "║  American Meteor Society Database — Full Q1 (Jan 1 – Mar 31)      ║",
    "╚══════════════════════════════════════════════════════════════════════╝",
]))

_SECTION_RULE = "▓" * 70

# (data key, report label) for every metric, in report order.
_METRICS = (
    ('e25', '25+ witness events'),
    ('e50', '50+ witness events'),
    ('e100', '100+ witness events'),
    ('total', 'Total events'),
)

# Primary analysis: 2026 against the 2018-2025 baseline.
print(
    "\n\n" + _SECTION_RULE,
    "  PRIMARY BASELINE: 2018-2025 (stable platform era)",
    _SECTION_RULE,
    sep="\n",
)

stable_years = [*range(2018, 2026)]
results = [
    poisson_analysis(stable_years, 2026, key, title)
    for key, title in _METRICS
]

# Cross-check: test 2021 against the same window with 2021 itself removed.
print(
    "\n\n" + _SECTION_RULE,
    "  COMPARISON: Was 2021 anomalous? (2018-2020 + 2022-2025 baseline)",
    _SECTION_RULE,
    sep="\n",
)
no2021 = [*range(2018, 2021), *range(2022, 2026)]
for key, title in _METRICS[1:3]:  # the 50+ and 100+ witness metrics only
    poisson_analysis(no2021, 2021, key, title)

# Summary table of the primary analysis.
_TABLE_RULE = "=" * 90
print(
    "\n\n" + _TABLE_RULE,
    "  SUMMARY TABLE — Q1 2026 vs 2018-2025 baseline (full quarter, Jan 1 – Mar 31)",
    _TABLE_RULE,
    sep="\n",
)

# First column is left-aligned; the rest are right-aligned to these widths.
_RIGHT_COLUMNS = (
    ('λ', 6),
    ('2026', 5),
    ('Excess', 7),
    ('Poisson', 9),
    ('Conserv.', 9),
    ('Odds (Poisson)', 16),
    ('Odds (Conserv.)', 16),
)
print("  " + " ".join(
    [f"{'Metric':<22}"] + [f"{name:>{width}}" for name, width in _RIGHT_COLUMNS]
))
print("  " + " ".join(
    ['-' * 22] + ['-' * width for _, width in _RIGHT_COLUMNS]
))
for row in results:
    print(f"  {row['metric']:<22} {row['lambda']:>6.1f} {row['observed']:>5} {row['excess']:>+7.1f} {row['sigma_poisson']:>8.2f}σ {row['sigma_combined']:>8.2f}σ {row['one_in_n_poisson']:>16} {row['one_in_n_combined']:>16}")

# Footnotes explaining the two sigma columns.
print(
    "\n  Poisson σ: exact test assuming baseline mean is known",
    "  Conservative σ: accounts for uncertainty in both observed value and baseline mean",
    "  The true significance lies between these two values",
    "  Total events uses negative binomial (overdispersed) — see detailed output above",
    sep="\n",
)
