#include <string.h>
#include <spu_mfcio.h>
#include "common.h"

// Use inlined SIMD math functions
#include <simdmath/sqrtf4.h>
#define sqrtf4 _sqrtf4

// Masses of all bodies, one row per section of bodies. main() fills all 6 rows
// by DMA during start-up.
float mass[6][SPU_BODIES] CACHE_ALIGNED;

// Pointer into mass array to section of bodies this SPU is responsible for
// (set to mass[id] at the start of main())
float *own_mass;
// Position vectors for bodies this SPU is responsible for. Fetched by DMA at
// start-up and written back by DMA after every step.
VEC3D own_pos[SPU_BODIES] CACHE_ALIGNED;
// Velocity vectors for bodies this SPU is responsible for. Fetched by DMA at
// start-up; nothing in this file writes them back.
VEC3D own_vel[SPU_BODIES] CACHE_ALIGNED;
// Acceleration vectors for bodies this SPU is responsible for. Cleared and
// recomputed every step; never used as a DMA source or target in this file.
VEC3D own_acc[SPU_BODIES];

// 2 sections of position vectors for double-buffering: one buffer is processed
// while the next section is being transferred into the other
VEC3D pos_buf[2][SPU_BODIES] CACHE_ALIGNED;

// Control block, fetched by DMA from the address main() receives as argp.
// This file reads its mass_addr, pos_addr and vel_addr fields.
CONTROL_BLOCK cb;
// One entry per section owned by another SPU, filled in by main():
//   id       - index of that section (stored but not read back in this file)
//   pos_addr - address the section's positions are fetched from by DMA
//   mass     - row of the mass array holding the section's masses
struct {
  int id;
  uintptr32_t pos_addr;
  float *mass;
} sections[5];

// Which section of bodies this SPU is responsible for (0-5); taken from the
// envp argument of main()
int id;

// Calculates square root of a float.
// Scalar wrapper around the SIMD sqrtf4 (mapped to _sqrtf4 by the #define near
// the top of this file): the argument is replicated with spu_splats, the
// vector root is taken, and element 0 of the result is extracted.
static inline float
sqrtf(float a)
{
  return spu_extract(sqrtf4(spu_splats(a)), 0);
}

// Updates acceleration vectors taking into account only interactions between
// the bodies this SPU is responsible for.
//
// Every unordered pair (i, j) with i < j is visited once, and both bodies are
// updated from the same displacement term with opposite signs. Reads own_pos
// and own_mass and accumulates into own_acc; it does not clear own_acc first.
// Coincident bodies are not special-cased, so a zero separation makes the
// division below a division by zero.
void
process_own(void)
{
  for (int i = 0; i < SPU_BODIES - 1; i++) {
    // Values that stay fixed for the whole inner loop
    const VEC3D pos_i = own_pos[i];
    const float mass_i = own_mass[i];
    // Body i's acceleration is accumulated locally and stored once at the
    // end; the inner loop only ever touches own_acc[j] with j > i, so the
    // order of additions is the same as updating own_acc[i] in place.
    VEC3D acc_i = own_acc[i];

    for (int j = i + 1; j < SPU_BODIES; j++) {
      const float mass_j = own_mass[j];
      VEC3D d;
      float t;

      // Calculate displacement from i to j
      d.x = own_pos[j].x - pos_i.x;
      d.y = own_pos[j].y - pos_i.y;
      d.z = own_pos[j].z - pos_i.z;

      // Calculate 1 / dist^2
      t = 1 / (fsqrf(d.x) + fsqrf(d.y) + fsqrf(d.z));

      // Multiply by 1 / dist to get 1 / dist^3, then scale the displacement
      // by it (displacement / dist^3)
      t *= sqrtf(t);
      d.x *= t;
      d.y *= t;
      d.z *= t;

      // Update acceleration of i (pulled towards j, weighted by j's mass)
      acc_i.x += d.x * mass_j;
      acc_i.y += d.y * mass_j;
      acc_i.z += d.z * mass_j;

      // Update acceleration of j (opposite direction, weighted by i's mass)
      own_acc[j].x -= d.x * mass_i;
      own_acc[j].y -= d.y * mass_i;
      own_acc[j].z -= d.z * mass_i;
    }

    own_acc[i] = acc_i;
  }
}

// Adds to own_acc the acceleration that each body this SPU is responsible for
// receives from every body of one other section.
//
//   other_pos  - position vectors of the other section (SPU_BODIES entries
//                are read)
//   other_mass - masses of the other section, indexed like other_pos
//
// Only own_acc is written; the other section's bodies are not updated here.
void
process_other(VEC3D *other_pos, float *other_mass)
{
  for (int own = 0; own < SPU_BODIES; own++) {
    VEC3D *here = &own_pos[own];
    VEC3D *acc = &own_acc[own];

    for (int oth = 0; oth < SPU_BODIES; oth++) {
      VEC3D *there = &other_pos[oth];
      VEC3D delta;
      float scale;

      // Displacement from our body to the other body
      delta.x = there->x - here->x;
      delta.y = there->y - here->y;
      delta.z = there->z - here->z;

      // 1 / dist^2
      scale = 1 / (fsqrf(delta.x) + fsqrf(delta.y) + fsqrf(delta.z));

      // 1 / dist^3, then weighted by the other body's mass
      scale *= sqrtf(scale);
      scale *= other_mass[oth];

      // Contribution along each axis
      delta.x *= scale;
      delta.y *= scale;
      delta.z *= scale;

      // Accumulate into our body's acceleration
      acc->x += delta.x;
      acc->y += delta.y;
      acc->z += delta.z;
    }
  }
}

// SPU entry point.
//
//   speid - not used
//   argp  - address the control block is fetched from by DMA
//   envp  - index of the section of bodies this SPU is responsible for
//
// After loading the control block, all masses and this SPU's own positions
// and velocities, the function loops once per simulation step: it waits for a
// mailbox message, accumulates accelerations from its own bodies and from the
// other sections (whose positions are streamed in through the two pos_buf
// buffers), advances positions and velocities by TIMESTEP, and writes the new
// positions back.
//
// Returns 1 without touching any buffer if envp is not a valid section index,
// and 0 when the step loop is ended by a mailbox message.
int
main(uint64_t speid, uint64_t argp, uint64_t envp)
{
  // Section counts are derived from the arrays they index so the loops below
  // cannot drift out of step with the array sizes.
  const int section_count = (int)(sizeof(mass) / sizeof(mass[0]));
  const int other_count = (int)(sizeof(sections) / sizeof(sections[0]));

  (void)speid;

  // An out-of-range section index would read outside mass[] and make the
  // section table below receive more entries than it has room for.
  if (envp >= (uint64_t)section_count) {
    return 1;
  }
  id = (int)envp;
  own_mass = mass[id];

  // DMA in control block and wait for completion
  mfc_get(&cb,
          argp,
          sizeof(cb),
          0,
          0, 0);
  mfc_write_tag_mask(1 << 0);
  mfc_read_tag_status_all();

  // Start DMA for masses of all bodies and initial positions/velocities of
  // bodies this SPU is responsible for
  for (int i = 0; i < section_count; i++) {
    mfc_get(mass[i],
            cb.mass_addr + i * sizeof(mass[0]),
            sizeof(mass[0]),
            0,
            0, 0);
  }
  mfc_get(own_pos,
          cb.pos_addr + id * sizeof(own_pos),
          sizeof(own_pos),
          0,
          0, 0);
  mfc_get(own_vel,
          cb.vel_addr + id * sizeof(own_vel),
          sizeof(own_vel),
          0,
          0, 0);

  // Calculate address of positions for sections owned by other SPUs
  // (done while the DMA transfers above are still in flight)
  for (int i = 0, j = 0; i < section_count; i++) {
    if (i != id && j < other_count) {
      sections[j].id = i;
      sections[j].pos_addr = cb.pos_addr + i * sizeof(pos_buf[0]);
      sections[j].mass = mass[i];
      j++;
    }
  }

  // Wait for initial DMA to complete (all of it was issued on tag 0, which is
  // still selected by the tag mask written above)
  mfc_read_tag_status_all();

  while (TRUE) {
    uint32_t tag;

    // Wait for message from PPU indicating SPU can compute next step.
    // Any non-zero message ends the loop.
    // NOTE(review): the earlier comment here said 0 means "stop", which is the
    // opposite of what this test does; confirm the protocol on the PPU side.
    if (spu_read_in_mbox() != 0) {
      break;
    }

    // Start DMA transfer for first section of positions. `tag` doubles as the
    // pos_buf index and the DMA tag of the transfer into that buffer.
    tag = 0;
    mfc_get(pos_buf[tag],
            sections[0].pos_addr,
            sizeof(pos_buf[0]),
            tag,
            0, 0);
    tag ^= 1;

    // Clear acceleration vectors
    memset(own_acc, 0, sizeof(own_acc));

    // Process interactions between the bodies this SPU is responsible for
    // (overlaps with the transfer started above)
    process_own();

    for (int i = 0; i < other_count - 1; i++) {
      // Start transfer for next section of positions into the idle buffer
      mfc_get(pos_buf[tag],
              sections[i + 1].pos_addr,
              sizeof(pos_buf[0]),
              tag,
              0,
              0);

      // Wait for current section of positions to finish transferring
      tag ^= 1;
      mfc_write_tag_mask(1 << tag);
      mfc_read_tag_status_all();

      // Process interactions with this section
      process_other(pos_buf[tag], sections[i].mass);
    }

    // Wait for last section of positions to finish transferring
    tag ^= 1;
    mfc_write_tag_mask(1 << tag);
    mfc_read_tag_status_all();

    // Notify PPU that positions have been read
    spu_write_out_mbox(0);

    // Process interactions with last section
    process_other(pos_buf[tag], sections[other_count - 1].mass);

    // Update positions (using the velocities from before this step's
    // velocity update)
    for (int i = 0; i < SPU_BODIES; i++) {
      own_pos[i].x += own_vel[i].x * TIMESTEP;
      own_pos[i].y += own_vel[i].y * TIMESTEP;
      own_pos[i].z += own_vel[i].z * TIMESTEP;
    }

    // Wait for message from PPU indicating it is safe to write back new
    // positions (the message value itself is ignored)
    spu_read_in_mbox();

    // Start writing back new positions for the bodies this SPU is responsible
    // for; completion is awaited after the velocity update below
    mfc_put(own_pos,
            cb.pos_addr + id * sizeof(own_pos),
            sizeof(own_pos),
            0,
            0, 0);

    // Update velocities (overlaps with the write-back)
    for (int i = 0; i < SPU_BODIES; i++) {
      own_vel[i].x += own_acc[i].x * TIMESTEP;
      own_vel[i].y += own_acc[i].y * TIMESTEP;
      own_vel[i].z += own_acc[i].z * TIMESTEP;
    }

    // Wait for positions to be written back
    mfc_write_tag_mask(1 << 0);
    mfc_read_tag_status_all();

    // Notify PPU that new positions have been written
    spu_write_out_mbox(0);
  }

  return 0;
}
