import sys
import numpy as np

"""
Example usage:
python main.py hw5-nn-train-10000.dat hw5-nn-test.dat # train on the first file, create predictions on the second file
python eval.py hw5-nn-test-pred.dat hw5-nn-test.dat
"""


def load_labels(path, column=2):
    """Read integer class labels from one column of a whitespace-delimited file.

    Args:
        path: File path (or file-like object) accepted by ``np.loadtxt``.
        column: Zero-based index of the label column. Defaults to 2, the
            third column.

    Returns:
        A 1-D ``np.int64`` array with one label per row of the file. An
        empty file yields an empty array.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the file contains non-numeric or ragged rows.
        IndexError: if rows have fewer than ``column + 1`` fields.
    """
    # ndmin=2 keeps a single-row file 2-D; without it loadtxt returns a 1-D
    # array and column indexing with ``data[:, column]`` raises IndexError.
    data = np.loadtxt(path, ndmin=2)
    if data.size == 0:
        # An empty file has no column to index into; report "no labels".
        return np.empty(0, dtype=np.int64)
    # astype truncates toward zero, so labels are expected to be whole
    # numbers (e.g. "1" or "1.0").
    return data[:, column].astype(np.int64)


def main():
    """Print the accuracy of a predictions file against a ground-truth file.

    Expects exactly two command-line arguments: the predictions file and
    the ground-truth file. Labels are read from each with ``load_labels``
    and compared row by row, and the percentage of matching rows is printed
    to stdout.

    On bad usage, unreadable input, mismatched row counts, or empty input,
    an error is written to stderr and the process exits with status 1.
    """
    if len(sys.argv) != 3:
        # Report the name the script was actually invoked under; a
        # hard-coded name can drift from the real filename.
        prog = (sys.argv[0] if sys.argv else "") or "compute_accuracy.py"
        print(
            f"Usage: python {prog} <predictions.dat> <ground_truth.dat>",
            file=sys.stderr,
        )
        sys.exit(1)
    pred_path = sys.argv[1]
    truth_path = sys.argv[2]

    loaded = []
    for path in (pred_path, truth_path):
        try:
            loaded.append(load_labels(path))
        except (OSError, ValueError, IndexError) as err:
            # Missing file, non-numeric data, or too few columns: report
            # it as a CLI error instead of a raw traceback.
            print(f"Error: could not read labels from {path}: {err}", file=sys.stderr)
            sys.exit(1)
    preds, actual = loaded

    if len(preds) != len(actual):
        print("Error: file row counts differ", file=sys.stderr)
        sys.exit(1)
    if len(actual) == 0:
        # np.mean of an empty array is nan (with a RuntimeWarning), which
        # would print a meaningless "nan%".
        print("Error: no labels to compare", file=sys.stderr)
        sys.exit(1)

    accuracy = np.mean(preds == actual)
    print(f"{accuracy * 100}%")


if __name__ == "__main__":
    main()
