#!/usr/bin/env python3
"""
CS 451 Data-Intensive Distributed Computing:
Assignment 4 public check script

Sample usage:
$ ./check_assignment4_public.py 
"""

import sys
import os
import socket
from subprocess import run
import argparse
import re
import time

def _parse_du_sizes(du_output):
    """Return the byte sizes listed in ``hdfs dfs -du`` output.

    The first whitespace-separated field of every line is taken as a size in
    bytes. Blank lines and lines whose first field is not a non-negative
    integer (i.e. anything that is not a file entry) are skipped, instead of
    aborting the whole check with an ``IndexError``/``ValueError``.

    :param du_output: decoded stdout of ``hdfs dfs -du``.
    :return: one size per listed file, in output order.
    """
    return [
        int(fields[0])
        for fields in map(str.split, du_output.splitlines())
        if fields and fields[0].isdecimal()
    ]


def _run_q1(spark_cmd, collection_path, index_path, partitions):
    """Run Question 1 (build the compressed inverted index) and report on it.

    Spark's stdout/stderr are written to ``q1.out``/``q1.err``. On success,
    this prints the elapsed time and the total size of the ``part*`` files
    under *index_path* (measured with ``hdfs dfs -du``). It also prints the
    expected vs. actual number of partitions.

    :return: True if spark-submit exited with status 0, False otherwise.
    """
    print("# Running Question 1 - Building Inverted Index")
    # perf_counter is monotonic, so the timing is immune to wall-clock changes.
    start = time.perf_counter()
    with open("q1.out", "w") as of, open("q1.err", "w") as ef:
        result = run(spark_cmd + [
            "--class", "ca.uwaterloo.cs451.a3.BuildInvertedIndexCompressed",
            "target/assignments-1.0.jar", "--input", collection_path,
            "--output", index_path,
            "--partitions", str(partitions)], stdout=of, stderr=ef)
    elapsed = time.perf_counter() - start
    if result.returncode != 0:
        print("Q1 failed - see q1.out and q1.err")
        return False

    # No shell is involved, so the 'part*' glob reaches hdfs unexpanded;
    # the hdfs CLI is relied upon to expand it.
    du = run(['hdfs', 'dfs', '-du', index_path + '/part*'],
             capture_output=True, text=True)
    if du.returncode != 0:
        # Without this, a failed du would silently report a 0 B index.
        print(f"Warning: 'hdfs dfs -du' failed: {du.stderr.strip()}")
    sizes = _parse_du_sizes(du.stdout)
    index_size = sum(sizes)
    print(f"Time Taken: {elapsed:.3f} seconds")
    print(f"Index Size: {index_size} B, {index_size / 1024.0 / 1024.0:.1f} MiB")
    print(f"Expected partitions: {partitions}  Actual Partitions: {len(sizes)}")
    return True


def _run_q2(spark_cmd, collection_path, index_path, queries):
    """Run Question 2 (boolean retrieval) feeding *queries* on stdin.

    Spark's stdout/stderr are written to ``q2.out``/``q2.err``. On success,
    only the lines of ``q2.out`` that start with ``Hits:``, ``Query time`` or
    a digit are echoed, so students can spot stray debug prints.

    :return: True if spark-submit exited with status 0, False otherwise.
    """
    print("# Running Question 2 - Boolean Retrieval")
    with open("q2.out", "w") as of, open("q2.err", "w") as ef:
        result = run(spark_cmd + [
            "--class", "ca.uwaterloo.cs451.a3.BooleanRetrievalCompressed",
            "target/assignments-1.0.jar", "--index", index_path,
            "--collection", collection_path],
            stdout=of, stderr=ef, input='\n'.join(queries), text=True)
    if result.returncode != 0:
        print("Q2 failed - see q2.out and q2.err")
        return False
    # `grep -E` is the portable spelling; `egrep` is an obsolescent alias
    # that newer GNU grep releases warn about.
    filtered_output = run(['grep', '-E', '^Hits:|^Query time|^[0-9]+', "q2.out"],
                          capture_output=True, text=True).stdout
    print(filtered_output)
    print("---\nNote -- the above should ONLY contain the query times, the number of hits, and the hits themselves - if you see more, get rid of your extra prints!")
    return True


def check_a3(env, questions, partitions=None):
    """Build the project with Maven and run the selected public A3 checks.

    :param env: ``'linux'`` selects a local run on ``data/Shakespeare.txt``.
        Any other value selects the cluster configuration: the Wikipedia
        sample, extra spark-submit resource flags and different queries.
    :param questions: collection of question numbers to run (1 and/or 2).
    :param partitions: number of index partitions for Q1. A falsy value
        means the environment default (4 locally, 8 on the cluster).

    Returns early (None) if the Maven build or any selected question fails.
    """
    on_cluster = env != 'linux'
    if on_cluster:
        spark_cmd = ["spark-submit", "--num-executors", "2", "--executor-cores", "4", "--executor-memory", "24G"]
        collection_path = '/data/cs451/enwiki-20180901-sentences-0.1sample.txt'
        index_path = "cs451-a3-index-wiki"
        queries = ['waterloo stanford OR cheriton AND', 'big data AND hadoop spark OR AND']
    else:
        spark_cmd = ["spark-submit"]
        collection_path = "data/Shakespeare.txt"
        index_path = "cs451-a3-index-shakespeare"
        queries = ['outrageous fortune AND', 'white red OR rose AND pluck AND']
    partitions = partitions or (8 if on_cluster else 4)

    # Flush our buffered output so it is not reordered after Maven's
    # (unbuffered, directly inherited) output when stdout is a pipe.
    sys.stdout.flush()
    build = run(["mvn", "clean", "package"])
    if build.returncode != 0:
        # Continuing would spark-submit a stale or missing jar.
        print("Build failed - fix the Maven build (see output above) before running the checks")
        return

    if 1 in questions and not _run_q1(spark_cmd, collection_path, index_path, partitions):
        return
    if 2 in questions:
        _run_q2(spark_cmd, collection_path, index_path, queries)


    


if __name__ == "__main__":
    env = 'datasci' if socket.gethostname() == 'datasci-login' else 'linux'
    parser = argparse.ArgumentParser(description="CS 451 2026 Winter Assignment 3 Public Test Script")
    parser.add_argument('-q', '--questions', help='Question(s) to evaluate', type=int, nargs='+', default=[1,2])
    parser.add_argument('-p', '--partitions', help='Number of partitions (defaults to 4 on student.cs, 8 on datasci)', type=int, default=None)
    args=parser.parse_args()
    print(f"Running on {env}")
    try:
        check_a3(env, args.questions, args.partitions)
    except Exception as e:
        print(e)
