/*
   Copyright 2024 Ondrej Lhotak. All rights reserved.

   Permission is granted for private study use by students registered in
   CS 241E in the Fall 2024 term.

   The contents of this file may not be published, in whole or in part,
   in print or electronic form.

   The contents of this file may be included in work submitted for CS
   241E assignments in Fall 2024. The contents of this file may not be
   submitted, in whole or in part, for credit in any other course.
*/
package cs241e.mips

import State.*
import cs241e.Utils.*
import cs241e.mips.implementation.Conversions.*

import java.io.ByteArrayOutputStream
import scala.annotation.tailrec

/** Specification of the MIPS processor used to execute CS 241E programs. */

object CPU {
  /** Reads register `regNum` of `state`.
    *
    * Register 0 always reads as zero, regardless of what `state` holds for it.
    */
  def readRegister(state: State, regNum: Long) = if(regNum == 0) Word.zero else state.reg(regNum)

  /** Returns true if `address` is the valid address of some memory location.
    *
    * A valid address is a multiple of 4 and, viewed as an unsigned number, strictly less than `maxAddr`.
    */
  def validAddress(address: Word): Boolean = {
    val addressAsNumber = asUnsigned(address)
    (addressAsNumber % 4) == 0 && addressAsNumber < asUnsigned(CPU.maxAddr)
  }

  /** Issues an error if `address` is not the valid address of any memory location. */
  def checkValidAddress(address: Word): Unit =
    if(!validAddress(address)) sys.error(s"attempt to dereference invalid address $address")

  /** Issues an error if `address` is not a valid value of the PC.
    *
    * The PC may hold any valid memory address (see `validAddress`) or the special value `terminationPC`.
    */
  def checkValidPC(address: Word): Unit =
    if(!validAddress(address) && address != terminationPC) sys.error(s"attempt to assign invalid address $address to PC")

  /** Increments `address` by the number of bytes specified by `bytes`.
    *
    * `bytes` may be negative (e.g. for backward branches or negative load/store offsets); the sum is
    * computed on the unsigned value of `address` and reduced with `mod2to32` before being encoded.
    */
  def incrementAddress(address: Word, bytes: Long = 4L): Word =
    encodeUnsigned(mod2to32(asUnsigned(address) + bytes))

  /** The transition function for one step of execution of the CPU.
    *
    * Fetches the instruction at the PC, advances the PC by one word, then decodes and executes the
    * instruction and returns the resulting state. A `sw` to `printAddr` writes a character to
    * `outputStream` instead of memory.
    *
    * Raises an error (via `sys.error`) for an undecodable instruction, an invalid memory address, or
    * an invalid new PC value, and throws `ArithmeticException` if `div`/`divu` divides by zero.
    */
  def step(state0: State): State = {

    /* Fetch the instruction at the PC. */
    val pc = readRegister(state0, PC)
    val instruction = state0.mem(pc)

    /* Helper function to issue an invalid instruction error. */
    def invalidInstruction = sys.error(s"Invalid instruction $instruction at PC = $pc")

    /* Helper function to report a `div`/`divu` whose divisor is zero. It throws the same exception
     * type that the raw Long division would (ArithmeticException), but identifies the offending
     * instruction and PC in the message. */
    def divisionByZero: Nothing =
      throw new ArithmeticException(s"Division by zero in instruction $instruction at PC = $pc")

    /* Helper function to increment the PC of `state` by the number of words specified by `words`
     * (negative for backward branches), checking that the new value is a valid PC. */
    def incrementPC(state: State, words: Long = 1): State = {
      val newPC = incrementAddress(readRegister(state, PC), bytes = words*4)
      checkValidPC(newPC)
      state.setReg(PC, newPC)
    }

    /* Increment the program counter before executing. Consequently branch offsets, the word loaded
     * by `lis`, and the return address saved by `jalr` are all relative to the next instruction. */
    val state1 = incrementPC(state0)

    /* Decode and execute the instruction. The word is split into a 6-bit opcode, two 5-bit register
     * numbers `s` and `t`, and the remaining 16 bits `iBits`: a signed immediate for lw/sw/beq/bne,
     * or the d register, a 5-bit field that must be zero, and a 6-bit function code for opcode 000000. */
    val List(op, sBits, tBits, iBits) = instruction.splitAt(List(6, 5, 5))
    val s: Long = asUnsigned(sBits)
    val t: Long = asUnsigned(tBits)
    op match {
      case Bits("000000") =>
        val List(dBits, zeros, function) = iBits.splitAt(List(5, 5))
        val d = asUnsigned(dBits)
        if (zeros != Bits("00000")) invalidInstruction
        else function match {

          case Bits("100000") => // add: computed unsigned and reduced with mod2to32, so overflow is not reported
            val result = asUnsigned(readRegister(state1, s)) + asUnsigned(readRegister(state1, t))
            state1.setReg(d, encodeUnsigned(mod2to32(result)))

          case Bits("100010") => // sub: computed unsigned and reduced with mod2to32, so underflow is not reported
            val result = asUnsigned(readRegister(state1, s)) - asUnsigned(readRegister(state1, t))
            state1.setReg(d, encodeUnsigned(mod2to32(result)))

          case Bits("011000") if d == 0 => // mult: HI gets the high 32 bits, LO the low 32 bits of the 64-bit product
            val result = asSigned(readRegister(state1, s)) * asSigned(readRegister(state1, t))
            val (bitshi, bitslo) = encodeSigned64(result).splitAt(32)
            state1.setReg(LO, Word(bitslo)).setReg(HI, Word(bitshi))

          case Bits("011001") if d == 0 => // multu: HI gets the high 32 bits, LO the low 32 bits of the 64-bit product
            val result = asUnsigned(readRegister(state1, s)) * asUnsigned(readRegister(state1, t))
            val (bitshi, bitslo) = encodeUnsigned64(result).splitAt(32)
            state1.setReg(LO, Word(bitslo)).setReg(HI, Word(bitshi))

          case Bits("011010") if d == 0 => // div: LO gets the quotient, HI the remainder
            val dividend = asSigned(readRegister(state1, s))
            val divisor = asSigned(readRegister(state1, t))
            if (divisor == 0) divisionByZero
            val bitslo = encodeSigned(mod2to32signed(dividend / divisor))
            val bitshi = encodeSigned(mod2to32signed(dividend % divisor))
            state1.setReg(LO, bitslo).setReg(HI, bitshi)

          case Bits("011011") if d == 0 => // divu: LO gets the quotient, HI the remainder
            val dividend = asUnsigned(readRegister(state1, s))
            val divisor = asUnsigned(readRegister(state1, t))
            if (divisor == 0) divisionByZero
            val bitslo = encodeUnsigned(mod2to32(dividend / divisor))
            val bitshi = encodeUnsigned(mod2to32(dividend % divisor))
            state1.setReg(LO, bitslo).setReg(HI, bitshi)

          case Bits("010000") if s == 0 && t == 0 => // mfhi
            state1.setReg(d, readRegister(state1, HI))

          case Bits("010010") if s == 0 && t == 0 => // mflo
            state1.setReg(d, readRegister(state1, LO))

          case Bits("010100") if s == 0 && t == 0 => // lis: load the word after this instruction into d, then skip over it
            val state2 = state1.setReg(d, state1.mem(readRegister(state1, PC)))
            incrementPC(state2)

          case Bits("101010") => // slt: d becomes 31 zero bits followed by the (signed) comparison result
            val result = asSigned(readRegister(state1, s)) < asSigned(readRegister(state1, t))
            state1.setReg(d,  Word(Bits("0" * 31) :+ result))

          case Bits("101011") => // sltu: d becomes 31 zero bits followed by the (unsigned) comparison result
            val result = asUnsigned(readRegister(state1, s)) < asUnsigned(readRegister(state1, t))
            state1.setReg(d,  Word(Bits("0" * 31) :+ result))

          case Bits("001000") if t == 0 && d == 0 => // jr
            val newAddress = readRegister(state1, s)
            checkValidPC(newAddress)
            state1.setReg(PC, newAddress)

          case Bits("001001") if t == 0 && d == 0 => // jalr
            // The target is read before register 31 is overwritten, so `jalr $31` jumps to the old $31.
            val newAddress = readRegister(state1, s)
            checkValidPC(newAddress)
            state1
              .setReg(31, readRegister(state1, PC))
              .setReg(PC, newAddress)

          case _ => invalidInstruction
        }

      case Bits("100011") => // lw: address is $s plus the signed 16-bit immediate
        val address = incrementAddress(readRegister(state1, s), asSigned(iBits))
        checkValidAddress(address)
        state1.setReg(t, state1.mem(address))

      case Bits("101011") => // sw: address is $s plus the signed 16-bit immediate
        val address = incrementAddress(readRegister(state1, s), asSigned(iBits))
        if(address == printAddr) {
          // Memory-mapped output: emit the low 8 bits of $t as a character; memory is unchanged.
          outputStream.print(asUnsigned(readRegister(state1, t).takeRight(8)).toChar)
          state1
        } else {
          checkValidAddress(address)
          state1.setMem(address, readRegister(state1, t))
        }

      case Bits("000100") => // beq: offset in words, relative to the already-incremented PC
        if(readRegister(state1, s) == readRegister(state1, t))
          incrementPC(state1, words = asSigned(iBits))
        else state1

      case Bits("000101") => // bne: offset in words, relative to the already-incremented PC
        if(readRegister(state1, s) != readRegister(state1, t))
          incrementPC(state1, words = asSigned(iBits))
        else state1

      case _ => invalidInstruction
    }
  }

  /** Runs steps of the CPU starting in `state` until it reaches a state in which the PC has the value
    * `terminationPC`, and returns that state.
    *
    * If the program never sets the PC to `terminationPC` and no step raises an error, this does not return.
    */
  @tailrec def run(state: State): State = {
    if(state.reg(PC) == terminationPC) state
    else run(step(state))
  }

  /** The PC value at which the CPU should halt execution.
    *
    * This value is not a multiple of 4, so it is never a valid memory address; `checkValidPC` accepts it explicitly.
    */
  val terminationPC = Word("11111110111000011101111010101101")

  /** The address that is one word beyond the last valid memory address (2^24). */
  val maxAddr =       Word("00000001000000000000000000000000")

  /** The special address for producing output. When a word is stored to this address, the low-order 8 bits of
    * that word are written to `outputStream` (standard output by default).
    */
  val printAddr =     Word("11111111111111110000000000001100")

  /** Destination of characters stored to `printAddr`. Defaults to `System.out`; `captureOutput`
    * temporarily replaces it.
    */
  private[cs241e] var outputStream: java.io.PrintStream = System.out

  /** Run the code passed in as `body` while redirecting the output of the MIPS machine to a string.
    * A typical use is: `val output: String = CPU.captureOutput(CPU.run(state))`
    *
    * The previous `outputStream` is restored even if `body` throws. In that case the exception propagates
    * and no string is returned.
    */
  def captureOutput(body: =>Unit): String = {
    val baos = new ByteArrayOutputStream
    val captureStream = new java.io.PrintStream(baos)
    val oldStream = outputStream
    outputStream = captureStream
    try { body } finally {
      // Make sure everything printed has reached `baos` before it is read.
      captureStream.flush()
      outputStream = oldStream
    }
    baos.toString
  }
}