#include "RecoPixelVZero/PixelVZeroFinding/interface/PixelVZeroFinder.h"

#include "MagneticField/Engine/interface/MagneticField.h"
#include "MagneticField/Records/interface/IdealMagneticFieldRecord.h"

#include "Geometry/Vector/interface/GlobalVector.h"

/*****************************************************************************/
// Set up the finder: read all selection cuts from the parameter set, take the
// z component of the magnetic field at the origin, and convert every track
// that passes its charge-dependent impact parameter cut into a helix.
PixelVZeroFinder::PixelVZeroFinder
  (const edm::EventSetup& es,
   const edm::ParameterSet& pset,
   reco::TrackCollection tracks)
{
  // Cuts applied to single tracks (lower bound on |d0|, per charge sign)
  minImpactPositiveDaughter =
    pset.getParameter<double>("minImpactPositiveDaughter");
  minImpactNegativeDaughter =
    pset.getParameter<double>("minImpactNegativeDaughter");

  // Cuts applied to track pairs (upper bounds used in checkIntersection)
  maxDcaR = pset.getParameter<double>("maxDcaR");
  maxDcaZ = pset.getParameter<double>("maxDcaZ");

  // Cuts applied to the mother candidate
  minCreationRadius = pset.getParameter<double>("minCreationRadius");
  maxCreationRadius = pset.getParameter<double>("maxCreationRadius");
  maxImpactMother   = pset.getParameter<double>("maxImpactMother");

  // Magnetic field: keep the inverse of |z component| taken at (0,0,0);
  // getHelix() multiplies pT by this number to obtain the helix radius
  edm::ESHandle<MagneticField> magfield;
  es.get<IdealMagneticFieldRecord>().get(magfield);
  fieldInInvGeV = 1./fabs(magfield->inInverseGeV(GlobalPoint(0,0,0)).z());

  // Build the helix list; tracks with zero charge satisfy neither condition
  // below and are therefore skipped
  helices.clear();

  typedef reco::TrackCollection::const_iterator TrackIter;
  const TrackIter lastTrack = tracks.end();

  for(TrackIter trk = tracks.begin(); trk != lastTrack; ++trk)
  {
    const bool positiveOk =
      trk->charge() > 0 && fabs(trk->d0()) > minImpactPositiveDaughter;
    const bool negativeOk =
      trk->charge() < 0 && fabs(trk->d0()) > minImpactNegativeDaughter;

    if(positiveOk || negativeOk)
      helices.push_back(getHelix(*trk));
  }
}

/*****************************************************************************/
// Nothing to release explicitly: the destructor body is intentionally empty.
PixelVZeroFinder::~PixelVZeroFinder() {}

/*****************************************************************************/
// Translate a reconstructed track into the Helix representation used by the
// pair finder.
//
// Filled from the track: charge (q), cot(theta), pT, charged impact
// parameter b = q * d0, the angle chi = phi0 + q * pi/2, the helix radius
// R = pT * fieldInInvGeV and the helix center (X, Y, Z).
//
// fieldInInvGeV must already be set (the constructor does this before
// calling getHelix).
Helix PixelVZeroFinder::getHelix(const reco::Track& t)
{
  Helix h;

  // charge
  h.q   = t.charge();

  // momentum
  h.cotTheta = 1/tan(t.theta());
  h.pT       = t.pt();

  // charged impact parameter
  h.b = h.q * t.d0();

  // azimuthal angle
  h.chi = t.phi0() + t.charge() * M_PI_2;

  // helix radius, center
  h.R = t.pt() * fieldInInvGeV;
  h.X = -(h.R + h.b) * cos(h.chi);
  h.Y = -(h.R + h.b) * sin(h.chi);
  h.Z = t.dz();

  // psi is assigned for every pair in checkPair(); give it a defined value
  // here so that the helix is never returned or copied with an
  // indeterminate member
  h.psi = 0.;

  return h;
}

/*****************************************************************************/
// Compute the z coordinate of the helix at the angle h.psi.
//
// The angular difference psi - chi is wrapped into [-pi, pi]. Only when
// q * (psi - chi) is negative is *z written and true returned; otherwise *z
// is left untouched and false is returned.
bool PixelVZeroFinder::getZ(Helix& h, float *z)
{
  const double twoPi = 2*M_PI;

  // Wrap the turning angle into [-pi, pi]
  float turn = h.psi - h.chi;
  while(turn < -M_PI) turn += twoPi;
  while(turn >  M_PI) turn -= twoPi;

  // Reject points reached by turning in the wrong direction
  if(!(h.q * turn < 0))
    return false;

  *z = h.Z - h.q * h.R * h.cotTheta * turn;
  return true;
}

/*****************************************************************************/
// Apply the pair and mother cuts to one closest-approach configuration of
// two helices (h1.psi and h2.psi must already be set) and, if every cut is
// passed, append a reco::VZero to vZeros.
//
// Cuts, in order: dcaR < maxDcaR, both getZ() calls succeed, the transverse
// radius of the vertex lies strictly between minCreationRadius and
// maxCreationRadius, projected dcaZ < maxDcaZ, and the transverse impact
// parameter of the mother is below maxImpactMother.
void PixelVZeroFinder::checkIntersection(Helix& h1,Helix& h2,float dcaR)
{
  // Transverse distance of closest approach
  if(!(dcaR < maxDcaR))
    return;

  // z of both helices at the chosen angles; the second is evaluated only
  // if the first one is accepted
  float zFirst, zSecond;
  if(!getZ(h1,&zFirst))  return;
  if(!getZ(h2,&zSecond)) return;

  // Points on the two helices
  GlobalVector point1(h1.X + h1.R * cos(h1.psi),
                      h1.Y + h1.R * sin(h1.psi), zFirst);
  GlobalVector point2(h2.X + h2.R * cos(h2.psi),
                      h2.Y + h2.R * sin(h2.psi), zSecond);

  // Vertex estimate: midpoint of the two points
  GlobalVector vertex = 0.5*(point1 + point2);

  // Creation radius window
  if(!(vertex.perp() > minCreationRadius &&
       vertex.perp() < maxCreationRadius))
    return;

  // Longitudinal distance, scaled by sin(theta) of the mean cot(theta)
  float meanCotTheta = 0.5*(h1.cotTheta + h2.cotTheta);
  float meanSinTheta = 1/sqrt(1 + meanCotTheta*meanCotTheta);
  float dcaZ = fabs(zFirst-zSecond) * meanSinTheta;

  if(!(dcaZ < maxDcaZ))
    return;

  // Daughter momenta at the chosen angles and their sum
  GlobalVector mom1( h1.pT * h1.q*sin(h1.psi),
                    -h1.pT * h1.q*cos(h1.psi), h1.pT * h1.cotTheta);
  GlobalVector mom2( h2.pT * h2.q*sin(h2.psi),
                    -h2.pT * h2.q*cos(h2.psi), h2.pT * h2.cotTheta);
  GlobalVector momSum = mom1 + mom2;

  // Transverse impact parameter of the mother: component of the transverse
  // vertex position perpendicular to the transverse momentum
  GlobalVector vertexT(vertex.x(),vertex.y(),0);
  GlobalVector momT(momSum.x(),momSum.y(),0);
  GlobalVector impactT = vertexT  - (vertexT*momT)*momT / (momT*momT);

  float impact = sqrt(impactT*impactT);

  if(!(impact < maxImpactMother))
    return;

  // Armenteros variables
  float armPt    = (mom1.cross(mom2)).mag() / sqrt(momSum*momSum);
  float armAlpha = h1.q * (mom1*mom1 - mom2*mom2)/(momSum*momSum); 

  // All cuts passed, keep the candidate
  vZeros.push_back(reco::VZero(h1.b, h2.b, dcaR,dcaZ,
                               vertex,mom1,mom2, impact, armPt,armAlpha));
}

/*****************************************************************************/
// Find the point(s) of closest approach of two helices in the transverse
// plane, store the corresponding angles in h1.psi / h2.psi and hand each
// configuration to checkIntersection() together with the transverse
// distance dcaR.
//
// Three geometrical cases for the two circles (radii R1, R2, distance of
// centers R12):
//  - crossing  (|R1-R2| < R12 < R1+R2): both crossing points, dcaR = 0
//  - outside   (R12 >= R1+R2)         : dcaR = R12 - (R1+R2)
//  - inside    (R12 <= |R1-R2|)       : dcaR = |R1-R2| - R12
void PixelVZeroFinder::checkPair(Helix& h1,Helix& h2)
{
  float dcaR;

  // Distance and relative direction of the centers
  float dx = h2.X - h1.X;
  float dy = h2.Y - h1.Y;
 
  float R12  = sqrt(dx*dx + dy*dy);
  float psi0 = atan2(dy, dx);
 
  // Check triangle inequalities
  if(R12 < h1.R + h2.R && R12 > fabs(h1.R - h2.R))
  {
    // intersection; for nearly tangent circles rounding can push the cosine
    // marginally outside [-1,1], which would make acos return NaN and the
    // pair would be silently dropped, hence the clamp
    float cosGamm = (h1.R*h1.R - h2.R*h2.R + R12*R12)/(2*h1.R*R12);
    if(cosGamm >  1) cosGamm =  1;
    if(cosGamm < -1) cosGamm = -1;

    float gamm = acos(cosGamm);
 
    for(int j=0; j<2; j++)
    {
      int sign = 2*j-1;
      h1.psi = psi0 + sign*gamm;
      h2.psi = atan2(h1.Y - h2.Y + h1.R*sin(h1.psi),
                     h1.X - h2.X + h1.R*cos(h1.psi));
   
      dcaR = 0.;
   
      checkIntersection(h1,h2,dcaR);
    }
  }
  else
  {
    // '>=' so that exact external tangency (R12 == R1+R2) is handled here
    // with dcaR = 0, instead of falling into the 'inside' branch where it
    // would get a negative dcaR and the wrong angles
    if(R12 >= h1.R + h2.R)
    {
      // outside, transverse distance R12-(R1+R2)
      h1.psi = psi0;
      h2.psi = psi0 + M_PI;
   
      dcaR = R12 - (h1.R + h2.R);
    }
    else
    {
      // inside, transverse distance |R1-R2|-R12
      if(h2.R < h1.R) h1.psi = psi0;
                 else h1.psi = psi0 + M_PI;
      h2.psi = h1.psi;
   
      dcaR = fabs(h1.R - h2.R) - R12;
    }
  
    checkIntersection(h1,h2,dcaR);
  }
}

/*****************************************************************************/
// Run the finder: clear the previous result, pass every (positive, negative)
// helix combination to checkPair() and return a copy of the collected
// candidates.
reco::VZeroCollection PixelVZeroFinder::doIt()
{
  vZeros.clear();

  typedef std::vector<Helix>::iterator HelixIter;

  // With fewer than two helices no opposite-charge pair exists, so the
  // loops below simply never reach checkPair()
  for(HelixIter pos = helices.begin(); pos != helices.end(); ++pos)
  {
    if(!(pos->q > 0)) continue; // first leg must be positive

    for(HelixIter neg = helices.begin(); neg != helices.end(); ++neg)
    {
      if(neg->q < 0) // second leg must be negative
        checkPair(*pos,*neg);
    }
  }

  return vZeros;
}

