#!/usr/bin/env python3
# Usage: python3 a3q3vis.py <X_file.dat> <y_file.dat> <beta_history.dat> <grad_history.dat> <ll_history.dat>

# No other libraries are allowed
import numpy as np
import matplotlib.pyplot as plt
import sys
from hw3q3 import load_data

def load_histories(beta_hist_file, grad_hist_file, ll_hist_file):
    '''
    Load beta, gradient, and ll histories from .dat files.

    Parameters
    ----------
    beta_hist_file, grad_hist_file, ll_hist_file : str, path-like, or list of str
        Whitespace-delimited text sources accepted by np.loadtxt.

    Returns
    -------
    beta_hist : ndarray, 2-D
        One row per recorded iteration, one column per coefficient.
    grad_hist : ndarray, 2-D
        One row per recorded iteration, one column per gradient component.
    ll_hist : ndarray, at least 1-D
        Recorded log-likelihood values.

    Without ``ndmin``, np.loadtxt squeezes a one-row file down to 1-D.
    In that case ``beta_hist[-1]`` would be a single scalar rather than
    the final coefficient vector. Forcing 2-D keeps rows == iterations.
    '''
    beta_hist = np.loadtxt(beta_hist_file, ndmin=2)
    grad_hist = np.loadtxt(grad_hist_file, ndmin=2)
    # ndmin=1 keeps a single recorded value as a length-1 array rather than 0-d.
    ll_hist = np.loadtxt(ll_hist_file, ndmin=1)
    return beta_hist, grad_hist, ll_hist

def plot_ll(ll_hist):
    '''
    Plot the log-likelihood and its change per iteration; save to plot_ll.pdf.

    Parameters
    ----------
    ll_hist : array_like
        Log-likelihood recorded at each iteration (flattened to 1-D).

    The left panel shows ll[t] against t. The right panel shows
    ll[t] - ll[t-1] for t >= 1, a quick visual check on convergence.
    '''
    ll = np.ravel(np.asarray(ll_hist, dtype=float))
    iterations = np.arange(ll.size)

    _, (ax_ll, ax_delta) = plt.subplots(1, 2, figsize=(10, 4))

    ax_ll.plot(iterations, ll, marker='o', markersize=3)
    ax_ll.set_xlabel('Iteration')
    ax_ll.set_ylabel('Log-likelihood')
    ax_ll.set_title('Log-likelihood over iterations')
    ax_ll.grid(True, alpha=0.3)

    # np.diff has one fewer entry than ll; plot each change at the iteration
    # it leads to. With fewer than two values this panel is simply empty.
    delta = np.diff(ll)
    ax_delta.plot(iterations[1:], delta, marker='o', markersize=3,
                  color='tab:orange')
    ax_delta.axhline(0.0, color='gray', linestyle='--', linewidth=0.8)
    ax_delta.set_xlabel('Iteration')
    ax_delta.set_ylabel('Change in log-likelihood')
    ax_delta.set_title('Change in log-likelihood per iteration')
    ax_delta.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('plot_ll.pdf')
    print("Saved plot_ll.pdf")
    plt.close()

def plot_betas(beta_hist):
    '''
    Plot every beta coefficient's trajectory over iterations; save to plot_betas.pdf.

    Parameters
    ----------
    beta_hist : array_like
        Coefficient history with one row per iteration and one column per
        coefficient. A 1-D input is treated as a single coefficient tracked
        over iterations.
    '''
    betas = np.asarray(beta_hist, dtype=float)
    if betas.ndim < 2:
        # NOTE(review): 1-D is ambiguous. np.loadtxt squeezes both a
        # single-column file and a single-row file to 1-D. Here it is read
        # as one coefficient over many iterations.
        betas = betas.reshape(-1, 1)
    iterations = np.arange(betas.shape[0])

    plt.figure(figsize=(8, 5))
    for j, trajectory in enumerate(betas.T):
        plt.plot(iterations, trajectory, marker='o', markersize=3,
                 label=rf'$\beta_{{{j}}}$')
    plt.xlabel('Iteration')
    plt.ylabel('Coefficient value')
    plt.title('Beta coefficients over iterations')
    plt.grid(True, alpha=0.3)
    plt.legend(loc='best')

    plt.tight_layout()
    plt.savefig('plot_betas.pdf')
    print("Saved plot_betas.pdf")
    plt.close()

def plot_gradients(grad_hist):
    '''
    Plot gradient components and the gradient norm over iterations.

    Writes plot_gradients.pdf with two panels. The left panel shows each
    gradient component per iteration. The right panel shows the Euclidean
    norm of the gradient per iteration, log-scaled when every norm is
    strictly positive.

    Parameters
    ----------
    grad_hist : array_like
        Gradient history with one row per iteration and one column per
        component. A 1-D input is treated as a single component's history.
    '''
    grads = np.asarray(grad_hist, dtype=float)
    if grads.ndim < 2:
        # NOTE(review): ambiguous like beta_hist. A one-row file also loads
        # as 1-D; here it is read as one component over many iterations.
        grads = grads.reshape(-1, 1)
    iterations = np.arange(grads.shape[0])
    grad_norm = np.linalg.norm(grads, axis=1)

    _, (ax_comp, ax_norm) = plt.subplots(1, 2, figsize=(11, 4))

    for j, component in enumerate(grads.T):
        ax_comp.plot(iterations, component, marker='o', markersize=3,
                     label=f'component {j}')
    ax_comp.axhline(0.0, color='gray', linestyle='--', linewidth=0.8)
    ax_comp.set_xlabel('Iteration')
    ax_comp.set_ylabel('Gradient value')
    ax_comp.set_title('Gradient components over iterations')
    ax_comp.grid(True, alpha=0.3)
    ax_comp.legend(loc='best')

    ax_norm.plot(iterations, grad_norm, marker='o', markersize=3,
                 color='tab:red')
    # A log axis makes the approach to zero visible but cannot show zeros.
    if grad_norm.size and np.all(grad_norm > 0):
        ax_norm.set_yscale('log')
    ax_norm.set_xlabel('Iteration')
    ax_norm.set_ylabel('Gradient norm (L2)')
    ax_norm.set_title('Gradient norm over iterations')
    ax_norm.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('plot_gradients.pdf')
    print("Saved plot_gradients.pdf")
    plt.close()

def _padded_range(values, frac=0.05):
    '''Return (lo, hi) spanning values with a small margin (1.0 if constant).'''
    lo, hi = float(np.min(values)), float(np.max(values))
    pad = frac * (hi - lo) if hi > lo else 1.0
    return lo - pad, hi + pad


def plot_decision_boundary(X, y, final_beta):
    '''
    Scatter the data by class and overlay the linear decision boundary.

    The boundary is the zero set of the linear score
    b0 + w1*x1 + w2*x2. Presumably this corresponds to a predicted
    probability of 0.5 if the model is logistic regression; confirm
    against hw3q3.

    The intercept layout is not visible from this file, so three layouts
    are accepted:

    * X has a leading all-ones column and len(final_beta) == X.shape[1]:
      that column is the intercept feature and is not plotted.
    * len(final_beta) == X.shape[1] + 1: final_beta[0] is the intercept.
    * otherwise: no intercept (b0 = 0).

    Parameters
    ----------
    X : array_like, (n_samples, n_columns)
    y : array_like, (n_samples,)
        Class labels; each distinct value gets its own marker colour.
    final_beta : array_like
        Final coefficient vector.

    Raises
    ------
    ValueError
        If, after removing the intercept, there are not exactly two features
        with matching weights (a 2-D boundary cannot be drawn otherwise).
    '''
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    y = np.ravel(y)
    beta = np.ravel(np.asarray(final_beta, dtype=float))

    n_cols = X.shape[1]
    if beta.size == n_cols and n_cols > 0 and np.allclose(X[:, 0], 1.0):
        features, intercept, weights = X[:, 1:], beta[0], beta[1:]
    elif beta.size == n_cols + 1:
        features, intercept, weights = X, beta[0], beta[1:]
    else:
        features, intercept, weights = X, 0.0, beta

    if features.shape[1] != 2 or weights.size != 2:
        raise ValueError(
            "plot_decision_boundary needs exactly 2 features matching "
            f"final_beta; got X with {n_cols} column(s) and final_beta of "
            f"length {beta.size}")

    x1, x2 = features[:, 0], features[:, 1]
    x1_lo, x1_hi = _padded_range(x1)
    x2_lo, x2_hi = _padded_range(x2)

    plt.figure(figsize=(7, 6))
    for label in np.unique(y):
        mask = y == label
        plt.scatter(x1[mask], x2[mask], s=20, alpha=0.7, label=f'y = {label}')

    grid_x1, grid_x2 = np.meshgrid(np.linspace(x1_lo, x1_hi, 200),
                                   np.linspace(x2_lo, x2_hi, 200))
    score = intercept + weights[0] * grid_x1 + weights[1] * grid_x2
    # Only draw when the zero level crosses the visible region; otherwise
    # contour() warns and draws nothing (also covers all-zero weights).
    if score.min() < 0.0 < score.max():
        plt.contour(grid_x1, grid_x2, score, levels=[0.0], colors='k',
                    linewidths=1.5)
        # Contour lines do not feed the legend, so add a proxy entry.
        plt.plot([], [], color='k', linewidth=1.5, label='Decision boundary')

    plt.xlim(x1_lo, x1_hi)
    plt.ylim(x2_lo, x2_hi)
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')
    plt.title('Decision boundary (final beta)')
    plt.legend(loc='best')

    plt.tight_layout()
    plt.savefig('plot_decision_boundary.pdf')
    print("Saved plot_decision_boundary.pdf")
    plt.close()

# Do not change the code below this line
if __name__ == "__main__":
    # Require exactly five file arguments (sys.argv[0] is the script itself).
    # NOTE(review): the usage message below names plot_results.py while the
    # header comment names a3q3vis.py — confirm which filename is current.
    if len(sys.argv) != 6:
        print("Usage: python3 plot_results.py <X_file.dat> <y_file.dat> <beta_history.dat> <grad_history.dat> <ll_history.dat>")
        sys.exit(1)

    # load_data is imported from hw3q3 (not visible here); presumably it
    # returns the feature matrix and label vector — confirm against hw3q3.
    X, y = load_data(sys.argv[1], sys.argv[2])
    beta_hist, grad_hist, ll_hist = load_histories(sys.argv[3], sys.argv[4], sys.argv[5])

    # Generate Plots (each function writes its own PDF to the current directory)
    plot_ll(ll_hist)
    plot_betas(beta_hist)
    plot_gradients(grad_hist)
    
    # For decision boundary, we use the LAST beta found
    # (the last row of the beta history).
    final_beta = beta_hist[-1]
    plot_decision_boundary(X, y, final_beta)