#include "RANSktau.hpp"

#ifdef __okl__

// Dirichlet boundary-condition hook (compiled only when __okl__ is defined).
// Writes a zero value into bc->sScalar when isField() reports "scalar k" or
// "scalar tau"; for any other field bc is left untouched.
// NOTE(review): isField presumably tests the field currently being processed —
// confirm against the OKL preamble that defines it.
void udfDirichlet(bcData *bc)
{
  if(isField("scalar k") || isField("scalar tau")) {
    bc->sScalar = 0;
  }
}

#endif

// Samples one velocity field and the "k" scalar at 200 points along the line
// x = 7, z = 0.5 between the mesh's global y-minimum and y-maximum, then writes
// the rows "time y U k" to "profile.dat" (the file is overwritten on each call).
//
// nrs  : solver handle; supplies the mesh, the velocity field and the scalars.
// time : simulation time, written as the first column of every output row.
//
// The sample points and the interpolator are created on the first call and
// reused afterwards (function-local statics). Only MPI rank 0 owns sample
// points and writes the file; every rank still executes the find/eval calls.
void extractLine(nrs_t *nrs, double time)
{
  const auto np = (platform->comm.mpiRank() == 0) ? 200 : 0;
  const auto offset = np;

  // NOTE(review): the interpolator is never deleted; presumably it is meant to
  // live for the whole run — confirm before changing its ownership.
  static pointInterpolation_t *interpolator = nullptr;
  static std::vector<dfloat> xp, yp, zp;
  static deviceMemory<dfloat> o_Up;
  static deviceMemory<dfloat> o_kp;

  if (!interpolator) {
    auto mesh = nrs->meshV;
    const auto yMin = platform->linAlg->min(mesh->Nlocal, mesh->o_y, platform->comm.mpiComm());
    const auto yMax = platform->linAlg->max(mesh->Nlocal, mesh->o_y, platform->comm.mpiComm());

    if (np) {
      const auto x0 = 7.0;
      const auto z0 = 0.5;

      // The final size is known up front, so avoid repeated reallocation.
      xp.reserve(np);
      yp.reserve(np);
      zp.reserve(np);

      // First point sits exactly on the lower y bound.
      xp.push_back(x0);
      yp.push_back(yMin);
      zp.push_back(z0);

      // Interior points are clustered with a tanh stretching.
      // NOTE(review): "i * dy - 1" assumes yMin == -1 — confirm for other meshes.
      const auto betaY = 2.8;
      const auto tanhBetaY = tanh(betaY);
      const auto dy = (yMax - yMin) / (np - 1);
      for (int i = 1; i < np - 1; i++) {
        xp.push_back(x0);
        yp.push_back(tanh(betaY * (i * dy - 1)) / tanhBetaY);
        zp.push_back(z0);
      }

      // Last point sits exactly on the upper y bound.
      xp.push_back(x0);
      yp.push_back(yMax);
      zp.push_back(z0);

      o_Up.resize(offset);
      o_kp.resize(offset);
    }

    interpolator = new pointInterpolation_t(mesh, platform->comm.mpiComm());
    interpolator->setPoints(xp, yp, zp);
    interpolator->find();
  }

  // One field each: the start of o_U (presumably the x-velocity) and scalar "k".
  interpolator->eval(1, nrs->fluid->fieldOffset, nrs->fluid->o_U, offset, o_Up);
  interpolator->eval(1, nrs->fluid->fieldOffset, nrs->scalar->o_solution("k"), offset, o_kp);

  if (platform->comm.mpiRank() == 0) {
    std::vector<dfloat> Up(np);
    std::vector<dfloat> kp(np);
    o_Up.copyTo(Up);
    o_kp.copyTo(kp);

    std::ofstream f("profile.dat");
    if (!f) {
      // Report the failure instead of silently producing no profile.
      printf("extractLine: unable to open profile.dat for writing\n");
      return;
    }

    // The float format is sticky, so set it once; '\n' avoids a flush per row
    // (the stream is flushed once when it is closed).
    f << std::scientific;
    for (int i = 0; i < np; i++) {
      f << time << " " << yp[i] << " " << Up[i] << " " << kp[i] << '\n';
    }
    f.close();
  }
}

// Property callback installed as nrs->userProperties in UDF_Setup.
// Forwards to RANSktau::updateProperties(); the time argument is not used.
void uservp(double time)
{
  static_cast<void>(time); // unused, kept for the callback signature
  RANSktau::updateProperties();
}

// Source-term callback installed as nrs->userSource in UDF_Setup.
// Forwards to RANSktau::updateSourceTerms(); the time argument is not used.
void userq(double time)
{
  static_cast<void>(time); // unused, kept for the callback signature
  RANSktau::updateSourceTerms();
}

// Kernel-loading hook. It currently does nothing: the only statements are
// compiled out by "#if 0".
// The disabled block shows how to copy kernelInfo, define "p_sigma_k" as 0.6 on
// the copy and pass it to RANSktau::buildKernel(); switch the guard to "#if 1"
// to enable that override.
void UDF_LoadKernels(deviceKernelProperties& kernelInfo)
{
#if 0
  {
    auto props = kernelInfo;
    props.define("p_sigma_k") = 0.6;
    RANSktau::buildKernel(props);
  }
#endif
}

void UDF_Setup()
{
  nrs->userProperties = &uservp;
  nrs->userSource = &userq;

  auto mesh = nrs->meshV;

  // Box mesh manipulation
  const auto beta = 2.8;
  auto [x, y, z] = mesh->xyzHost();

  for (int n = 0; n < mesh->Nlocal; n++) {
    y[n] = tanh(beta * y[n]) / tanh(beta);
  }
  mesh->o_y.copyFrom(y.data());

  //Initial Conditions
  std::vector<dfloat> U(mesh->dim * nrs->fluid->fieldOffset, 0.0);
  std::vector<dfloat> k(mesh->Nlocal, 0.0);
  std::vector<dfloat> tau(mesh->Nlocal, 0.0);

  if (platform->options.getArgs("RESTART FILE NAME").empty()) {
    auto& scalar = nrs->scalar;
    for(int n = 0; n < mesh->Nlocal; n++) {
      U[n + 0 * nrs->fluid->fieldOffset] = 1;
      U[n + 1 * nrs->fluid->fieldOffset] = 0;
      U[n + 2 * nrs->fluid->fieldOffset] = 0;
      k[n] = 0.01; 
      tau[n] = 0.1;
    }
    nrs->fluid->o_U.copyFrom(U.data(), U.size());
    nrs->scalar->o_solution("k").copyFrom(k.data(), k.size()); 
    nrs->scalar->o_solution("tau").copyFrom(tau.data(), tau.size()); 
  }

  std::string model = "ktau";
  platform->par->extract("casedata","model",model);

  RANSktau::setup(nrs->scalar->nameToIndex.find("k")->second, model);
}

// Computes the friction velocity from the tangential force on the
// zero-Dirichlet velocity boundaries and prints it, together with its relative
// deviation from a reference value, on MPI rank 0.
//
// nrs   : solver handle used for the mesh, strain rate and force evaluation.
// time  : unused.
// tstep : unused.
void computeUtau(nrs_t *nrs, double time, int tstep)
{
  auto mesh = nrs->meshV;

  // Gather the ids of every "fluid velocity" boundary typed as zero-Dirichlet.
  // NOTE(review): the "+ 1" presumably converts to one-based ids — confirm.
  std::vector<int> wallIds;
  for (auto &[key, typeId] : platform->app->bc->bIdToTypeId()) {
    const bool isWall = (key.first == "fluid velocity") && (typeId == bdryBase::bcType_zeroDirichlet);
    if (!isWall) {
      continue;
    }
    wallIds.push_back(key.second + 1);
  }
  auto o_wallIds = platform->device.malloc<int>(wallIds.size(), wallIds.data());

  // The strain rate is only needed for the force evaluation; release it right after.
  auto o_strain = nrs->strainRate();
  auto forces = nrs->aeroForces(o_wallIds, o_strain);
  o_strain.free();

  // Integrating a field of ones over the selected boundaries gives their area.
  auto o_ones = platform->deviceMemoryPool.reserve<dfloat>(mesh->Nlocal);
  platform->linAlg->fill(mesh->Nlocal, 1.0, o_ones);
  const auto wallArea = mesh->surfaceAreaMultiplyIntegrate(o_wallIds, o_ones);

  // Reference value taken from:
  // https://turbulence.oden.utexas.edu/channel2015/data/LM_Channel_2000_mean_prof.dat
  const auto utauRef = 4.58794e-02;
  const auto tangentialForceX = std::abs(forces->tangential()[0]);
  const auto utau = sqrt(tangentialForceX / wallArea);
  const auto utauErr = std::abs((utau - utauRef) / utauRef);

  if (platform->comm.mpiRank() == 0) {
    printf("utau: %.4e; utauErr: %.4e \n", utau, utauErr);
  }
}

// Per-time-step user hook. It currently performs no work: the diagnostics
// below are commented out.
//
// time  : current simulation time.
// tstep : current time-step index.
void UDF_ExecuteStep(double time, int tstep)
{
  // Disabled diagnostics: uncomment to run computeUtau() and extractLine()
  // when nrs->lastStep is set.
  //if (nrs->lastStep) {
  //  computeUtau(nrs, time, tstep);
  //  extractLine(nrs, time);
  //}
}
