/* $Id: agentrex.C,v 1.23 2001/08/24 23:03:01 ericp Exp $ */

/*
 *
 * Copyright (C) 2000 Michael Kaminsky (kaminsky@lcs.mit.edu)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include "aios.h"
#include "sfsmisc.h"
#include "sfsconnect.h"
#include "list.h"
#include "agent.h"
#include "rex_prot.h"
#include "rex.h"

class agentstartfd: public rexfd
{

  str schost;

  // schost where first rexconnect came from
  

  cbi succeedcb;
  bool waitnewfd;

  void
  agentstarted (ref<int> resp, str schost, cbi succeedcb, clnt_stat err)
  {
    if (*resp || err) {
      warn << "could not start agent on "<< schost << " : ";
      if (err)
	warn << err << "\n";
      else
	warn << strerror (*resp) << "\n";
      // should still succeed even if there's an agent already running
      succeedcb (1);
    }
    else {
      succeedcb (0);
      warn << "agent forwarding connection started\n";
    }
  }


  
public:

  agentstartfd (rexchannel *pch, int fd, cbi succeedcb, str schost):
    rexfd (pch, fd), schost (schost), succeedcb (succeedcb), waitnewfd (true)  {}
    
  virtual void
  newfd (svccb *sbp)
  {
    rexcb_newfd_arg *argp = sbp->template getarg<rexcb_newfd_arg> ();

    waitnewfd = false;
    
    int s[2];

    if (socketpair(AF_UNIX, SOCK_STREAM, 0, s)) {
      warn << "error creating socketpair for agent forwarding";
      sbp->replyref(false);
      return;
    }

    make_async (s[1]);
    make_async (s[0]);

    sfsagent *a = New sfsagent (s[1]);
    a->setname (schost);
    a->cs = NULL;

    ref <int> resp = New refcounted <int>;
    a->ac->call (AGENT_START, NULL, resp, wrap (this,
						&agentstartfd::agentstarted,
						resp, schost, succeedcb)); 

    vNew refcounted<unixfd> (pch, argp->newfd, s[0]);

    sbp->replyref (true);
  }

  virtual void data (svccb *sbp) {
    rex_payload *argp = sbp->template getarg<rex_payload> ();
    if (waitnewfd && !argp->data.size ()) {
      warn ("agent forward channel failure: EOF from suidconnect agent\n");
      succeedcb (1);
      rexfd::data (sbp);
      return;
    }
  }
};


// Rex channel that runs the given command (from rexsess::succeed this is
// "suidconnect agent") on the remote side.  Once the channel exists, fd 0 is
// handed to an agentstartfd; succeedcb is called with nonzero if the channel
// could not be made.
class agentchannel: public rexchannel
{
  str schost;       // session destination path, passed on to agentstartfd
  cbi succeedcb;    // completion callback (0 = success)
public:
  agentchannel (rexsession *sess, vec<str> command, str schost,
                cbi succeedcb)
    : rexchannel (sess, 1, command), schost (schost), succeedcb (succeedcb)
    {}

  void madechannel (int error) {
    if (!error) {
      vNew refcounted<agentstartfd> (this, 0, succeedcb, schost);
      return;
    }
    // should probably still succeed even if we can't run "suidconnect agent"
    succeedcb (1);
  }

};

// State for one rex session to a remote host (`path`).
//
// Setup is an asynchronous chain started by the constructor:
//   connected -> dologin -> loggedin -> spawn -> spawned -> attach -> attached
// and ends in sessinit_done.  Any failure calls fail(), which reports an error
// to every queued caller and deletes the object.  While setup is in progress,
// succeed() queues callers in cbq.  Once attached() inserts the session into
// sesstab, later rex_connect calls for the same path reuse it.
class rexsess {
  ptr<sfscon> sessconn;   // connection obtained by sfs_connect_path
  ptr<aclnt> sessclnt;    // rexd RPC client over sessconn
  ptr<aclnt> sfsclnt;     // SFS RPC client over sessconn (used for login)
  ptr<aclnt> rexclnt;     // NOTE(review): not referenced in this file
  sfs_sessinfo sessinfo;  // keys used to encrypt the transport in attached()
  rex_sesskeydat kscdat;  // input hashed by seq2sessinfo into the ksc key
  rex_sesskeydat kcsdat;  // input hashed by seq2sessinfo into the kcs key
  u_int32_t myauthno;     // authno from SFSPROC_LOGIN, sent with REXD_SPAWN
  sfs_seqno seqno;        // counter consumed by attach() and reply()

  bool inprogress;        // true until sessinit_done(); succeed() queues meanwhile
  
  
  bool afpending;       //  agent forwarding channel is being made
  vec<cb_rex::ptr> afcbq; // callers waiting for agent forwarding setup
  vec<cb_rex::ptr> cbq;   // callers waiting for session setup
  rexsession *sess;       // not initialized by the constructor; set in attached()

  void fail ();
  void sessinit_done (int error);
  void attached (rexd_attach_res *resp, clnt_stat err);
  void attach ();
  void spawned (rexd_spawn_res *resp, clnt_stat err);
  void spawn ();
  void loggedin (sfs_loginres *lresp, clnt_stat err);
  void dologin (ptr<sfsagent_auth_res> ares, clnt_stat err);
  ptr<sfsagent_auth_res> signauthreq (sfsagent_authinit_arg *aa);
  void connected (ptr<sfscon> sc, str err);
  void seq2sessinfo (u_int64_t seqno, sfs_hash *sidp, sfs_sessinfo *sip);
  // End callback installed on sess by sessinit_done: log and destroy self.
  void eof () {
    // todo:  we need to be able to distinguish these
    warn ("either agent channel closed or EOF from proxy on %s\n",
	  path.cstr ());
    delete this;
  }

  
public:
  bool forwardagent;    //  rex client requested agent forwarding
  bool agentforwarded;  //  agent forwarding has been established
  str path;               // destination passed to sfs_connect_path; sesstab key
  str rexconnect_origin;  // where the first rex_connect came from (for listing)
  ptr<sfsagent_rex_res> cbres;  // reply template built in spawned()
  ihash_entry<rexsess> link;    // linkage for sesstab

  void reply (cb_rex::ptr cb);
  void agentf_init_done (int err);
  void succeed (cb_rex::ptr cb, int error = 0);
  void abort () { sess->abort (); }
  rexsess (str path, str frompath, bool forwardagent, cb_rex::ptr cb);
  ~rexsess ();
};

// Established rex sessions, keyed by rexsess::path.  Entries are inserted in
// rexsess::attached and removed in the rexsess destructor; rex_connect,
// list_rexsess and kill_rexsess look them up.
ihash<str, rexsess, &rexsess::path, &rexsess::link> sesstab;

// Derive session info for sequence number `seqno` by hashing the kcs/ksc
// key-derivation records (after stamping them with `seqno`).  Optionally
// store the hash of the whole sessinfo in *sidp and/or a copy in *sip; the
// temporary copy of the keys is zeroed before returning.
void
rexsess::seq2sessinfo (u_int64_t seqno, sfs_hash *sidp, sfs_sessinfo *sip)
{
  kcsdat.seqno = seqno;
  kscdat.seqno = kcsdat.seqno;

  sfs_sessinfo info;
  info.type = SFS_SESSINFO;
  info.kcs.setsize (sha1::hashsize);
  info.ksc.setsize (sha1::hashsize);
  sha1_hashxdr (info.kcs.base (), kcsdat, true);
  sha1_hashxdr (info.ksc.base (), kscdat, true);

  if (sidp != NULL)
    sha1_hashxdr (sidp->base (), info, true);
  if (sip != NULL)
    *sip = info;

  // Scrub the key material held in the local temporary.
  bzero (info.kcs.base (), info.kcs.size ());
  bzero (info.ksc.base (), info.ksc.size ());
}

// Start a new session to `path`.  `firstcb` is queued until setup finishes;
// the setup chain begins with the asynchronous connect below.
rexsess::rexsess (str path, str pathfrom, bool forwardagent,
                  cb_rex::ptr firstcb)
  : myauthno (0), seqno (1), inprogress (true), afpending (false),
    forwardagent (forwardagent), agentforwarded (false),
    path (path), rexconnect_origin (pathfrom)
{
  cbq.push_back (firstcb);
  sfs_connect_path (path, SFS_REX, wrap (this, &rexsess::connected));
}

rexsess::~rexsess ()
{
  bzero (&kscdat, sizeof (kscdat));
  bzero (&kcsdat, sizeof (kcsdat));
  if (cbres) {
    bzero (&cbres->resok->kcs, sizeof (cbres->resok->kcs));
    bzero (&cbres->resok->ksc, sizeof (cbres->resok->ksc));
  }
  if (sesstab[path])
    sesstab.remove (this);
}

void
rexsess::fail ()
{
  sessinit_done (1);
}

void
rexsess::sessinit_done (int error)
{
  inprogress = false;

  int cbql = cbq.size ();
  for (int c = 0; c < cbql; c++)
    succeed (cbq[c], error);
  cbq.clear ();

  if (error)
    delete this;
  else
    sess->setendcb (wrap (this, &rexsess::eof));
}

// Hand the session result to `cb`, stamped with a fresh sequence number.
void
rexsess::reply (cb_rex::ptr cb) {
  seqno++;
  cbres->resok->seqno = seqno;
  cb (cbres);
}

void
rexsess::agentf_init_done (int err) {
  afpending = false;
  agentforwarded = (err == 0);

  int afcbql = afcbq.size ();
  for (int c = 0; c < afcbql; c++)
    reply (afcbq[c]);
  afcbq.clear ();
}
// Answer a rex_connect caller.
//  - while setup is in progress the caller is queued;
//  - on error it gets a failed result;
//  - if agent forwarding is wanted but not yet up, it waits in afcbq and,
//    unless a setup attempt is already pending, a "suidconnect agent"
//    channel is opened;
//  - otherwise it is answered immediately.
void
rexsess::succeed (cb_rex::ptr cb, int error)
{
  if (inprogress) {
    cbq.push_back (cb);
    return;
  }
  if (error) {
    cb (New refcounted<sfsagent_rex_res> (false));
    return;
  }
  if (!forwardagent || agentforwarded) {
    reply (cb);
    return;
  }

  afcbq.push_back (cb);
  if (afpending)
    return;

  afpending = true;
  vec<str> cmd;
  cmd.setsize (2);
  cmd[0] = "suidconnect";
  cmd[1] = "agent";
  sess->makechannel (New refcounted <agentchannel>
                     (sess, cmd, path,
                      wrap (this, &rexsess::agentf_init_done)));
}

// Completion of REXD_ATTACH.  On success, switch the connection to the
// encrypted transport using sessinfo's keys, register the session in
// sesstab, create the rexsession and release queued callers.  On failure,
// fail the session; previously errors were logged and setup carried on as
// if the attach had succeeded.
void
rexsess::attached (rexd_attach_res *resp, clnt_stat err)
{
  if (err) {
    warn << "FAILED (" << err << ")\n";
    delete resp;
    fail ();
    return;
  }
  if (*resp != SFS_OK) {
    // XXX
    warn << "FAILED (proxy attach err " << int (*resp) << ")\n";
    delete resp;
    fail ();
    return;
  }
  delete resp;
  warnx << "attached to proxy\n";

  sessconn->x = axprt_crypt::alloc (sessconn->x->reclaim ());
  sessconn->x->encrypt (sessinfo.kcs.base (), sessinfo.kcs.size (),
                        sessinfo.ksc.base (), sessinfo.ksc.size ());

  sesstab.insert (this);

  sess = New rexsession (path, sessconn->x);

  sessinit_done (0);
}

void
rexsess::attach ()
{
  rexd_attach_arg arg;

  arg.seqno = seqno++;
  seq2sessinfo (0, &arg.sessid, NULL);
  seq2sessinfo (arg.seqno, &arg.newsessid, &sessinfo);

  rexd_attach_res *resp = New rexd_attach_res;
  sessclnt->call (REXD_ATTACH, &arg, resp, 
		  wrap (this, &rexsess::attached, resp));
}

// Completion of REXD_SPAWN.  Records the server's key shares, builds the
// reply template (cbres) handed to rex clients, then attaches.  `resp` is
// owned here and is now freed on every path (it previously leaked on both
// error paths).
void
rexsess::spawned (rexd_spawn_res *resp, clnt_stat err)
{
  if (err) {
    warn << "REXD_SPAWN proxy RPC FAILED (" << err << ")\n";
    delete resp;
    fail ();
    return;
  }
  if (resp->err != SFS_OK) {
    // XXX
    warn << "FAILED (spawn proxy err " << int (resp->err) << ")\n";
    delete resp;
    fail ();
    return;
  }
  warnx << "spawned proxy\n";

  kcsdat.sshare = resp->resok->kmsg.kcs_share;
  kscdat.sshare = resp->resok->kmsg.ksc_share;
  delete resp;

  cbres = New refcounted<sfsagent_rex_res> (true);
  cbres->resok->kcs.kcs_share = kcsdat.cshare;
  cbres->resok->kcs.ksc_share = kcsdat.sshare;
  cbres->resok->ksc.kcs_share = kscdat.cshare;
  cbres->resok->ksc.ksc_share = kscdat.sshare;

  attach ();
}

void
rexsess::spawn ()
{
  rexd_spawn_arg arg;
  rnd.getbytes (arg.kmsg.kcs_share.base (),
		arg.kmsg.kcs_share.size ());
  rnd.getbytes (arg.kmsg.ksc_share.base (),
		arg.kmsg.ksc_share.size ());
  kcsdat.type = SFS_KCS;
  kcsdat.cshare = arg.kmsg.kcs_share;
  kscdat.type = SFS_KSC;
  kscdat.cshare = arg.kmsg.ksc_share;
  arg.command.setsize (1);
  arg.command[0] = "proxy";

  rexd_spawn_res *resp = New rexd_spawn_res;
  sessclnt->call (REXD_SPAWN, &arg, resp, wrap (this, &rexsess::spawned, resp),
		  authuint_create (myauthno));
}

// Completion of SFSPROC_LOGIN.  On SFSLOGIN_OK, remember the granted authno
// and spawn the proxy; any other outcome fails the session.  `lresp` is
// owned here and is now freed on every path (it previously leaked on the
// RPC-error and bad-status paths).
void
rexsess::loggedin (sfs_loginres *lresp, clnt_stat err)
{
  if (err) {
    warn << "loggedin: error\n";
    delete lresp;
    fail ();
    return;
  }
  if (!lresp) {
    warn << "loggedin: lresp is NULL error\n";
    fail ();
    return;
  }

  switch (lresp->status) {
  case SFSLOGIN_OK:
    myauthno = *lresp->authno;
    break;
#if 0
  case SFSLOGIN_MORE:
    {
      sfscd_agentreq_arg arg;
      arg.aid = aid;
      arg.agentreq.set_type (AGENTCB_AUTHMORE);
      arg.agentreq.more->authinfo = sp->authinfo;
      arg.agentreq.more->seqno = seqno;
      arg.agentreq.more->challenge = *sres.resmore;
      cbase = cdc->call (SFSCDCBPROC_AGENTREQ, &arg, &ares,
			 wrap (this, &userauth::aresult));
      break;
    }
  case SFSLOGIN_BAD:
    ntries++;
    sendreq ();
    break;
  case SFSLOGIN_ALLBAD:
    finish (0);
    break;
#endif
  default:
    warn << "userauth: bad status in loginres!\n";
    delete lresp;
    fail ();
    return;
  }

  delete lresp;
  spawn ();
}

// Log in to the remote SFS server with the certificate from signauthreq.
// Fails the session if signing failed or produced no certificate.
void
rexsess::dologin (ptr<sfsagent_auth_res> ares, clnt_stat err)
{
  if (err) {
    warn << "dologin: " << err << "\n";
    fail ();
    return;
  }
  if (!ares->authenticate) {
    warn << "dologin: no certificate\n";
    fail ();
    return;
  }

  sfs_loginarg larg;
  larg.seqno = 1;
  larg.certificate = *ares->certificate;

  sfs_loginres *lres = New sfs_loginres;
  sfsclnt->call (SFSPROC_LOGIN, &larg, lres,
                 wrap (this, &rexsess::loggedin, lres));
}

// Build a signed SFS_AUTHREQ certificate for the key returned by
// keynum (aa->ntries), binding it to aa->authinfo and aa->seqno.
//
// Returns NULL if keynum finds no key or aa->authinfo.type is not
// SFS_AUTHINFO.  If hashing or marshalling fails, the returned result has
// authenticate set to false.  The request is logged in either case.
ptr<sfsagent_auth_res>
rexsess::signauthreq (sfsagent_authinit_arg *aa)
{
  key *k = keynum (aa->ntries);
  if (!k || aa->authinfo.type != SFS_AUTHINFO) {
    warn ("signauthreq: couldn't find key\n");
    return NULL;
  }

  ref<sfsagent_auth_res> res = New refcounted<sfsagent_auth_res> (true);
  sfs_autharg ar (SFS_AUTHREQ);
  sfs_signed_authreq sar;
  str rawsar;

  ar.req->usrkey = k->k->n;
  sar.type = SFS_SIGNED_AUTHREQ;
  sar.seqno = aa->seqno;
  // No user info is sent; zero the field explicitly.
  bzero (sar.usrinfo.base (), sar.usrinfo.size ());

  // Short-circuit chain:
  //  1. hash aa->authinfo into sar.authid;
  //  2. marshal sar into rawsar;
  //  3. sign rawsar into ar.req->signed_req, then marshal ar into the
  //     certificate.  Step 3 uses the comma operator, so the sign_r
  //     assignment's value is discarded and only xdr2bytes's result is
  //     tested.  NOTE(review): a failed sign_r is not detected here unless
  //     it also makes xdr2bytes fail.
  if (!sha1_hashxdr (sar.authid.base (), aa->authinfo)
      || !(rawsar = xdr2str (sar))
      || !(ar.req->signed_req = k->k->sign_r (rawsar),
	   xdr2bytes (*res->certificate, ar))) {
    warn ("signauthreq: xdr failed\n");
    res->set_authenticate (false);
  }

  // Log: requestor, server name, armored host id and numeric service.
  warn << aa->requestor << ": " << aa->authinfo.name << ":"
       << armor32 (str (aa->authinfo.hostid.base (),
			aa->authinfo.hostid.size ()))
       << " (" << implicit_cast<int> (aa->authinfo.service) << ")\n";

  return res;
}

// Completion of sfs_connect_path.  Sets up the rexd and SFS RPC clients on
// the new connection, signs an authentication request locally and proceeds
// to login; fails the session if connecting or signing fails.
void
rexsess::connected (ptr<sfscon> sc, str err)
{
  if (!sc) {
    warn << path << ": FAILED (" << err << ")\n";
    fail ();
    return;
  }

  sessconn = sc;
  sessclnt = aclnt::alloc (sc->x, rexd_prog_1);
  sfsclnt = aclnt::alloc (sc->x, sfs_program_1);

  //  sessclnt->seteofcb (wrap (this, &rexsess::eof));

  // Describe the login target from the connection: service, server name,
  // host id and session id.
  sfsagent_authinit_arg aarg;
  aarg.authinfo.type = SFS_AUTHINFO;
  aarg.authinfo.service = SFS_REX;
  aarg.authinfo.name = sc->servinfo.host.hostname;
  aarg.authinfo.hostid = sc->hostid;
  aarg.authinfo.sessid = sc->sessid;
  aarg.seqno = 1;
  aarg.requestor = "";
  aarg.ntries = 0;

  if (ptr<sfsagent_auth_res> ares = signauthreq (&aarg))
    dologin (ares, RPC_SUCCESS);
  else
    fail ();
}

// Obtain a rex session to `path` for `cb`.  An established session in
// sesstab is reused (turning on agent forwarding if this caller wants it);
// otherwise a new rexsess is created.  `pathfrom` defaults to "localhost".
void
rex_connect (str path, str pathfrom, bool forwardagent, cb_rex::ptr cb)
{
  if (!pathfrom)
    pathfrom = "localhost";

  rexsess *existing = sesstab[path];
  if (!existing) {
    vNew rexsess (path, pathfrom, forwardagent, cb);
    return;
  }

  warn << "rexsess: hash lookup for " << path
       << " succeeded from " << pathfrom << "\n";
  if (forwardagent)
    existing->forwardagent = true;
  existing->succeed (cb);
}

// sesstab traversal callback: append a summary of `sp` to *psv.
static void
print_rexsess (vec<rex_sessentry> *psv, rexsess *sp) {
  rex_sessentry entry;
  entry.agentforwarded = sp->agentforwarded;
  entry.created_from = sp->rexconnect_origin;
  entry.to = sp->path;
  psv->push_back (entry);
}

// RPC handler: reply with a summary of every session in sesstab.
void
list_rexsess (svccb *sbp)
{
  vec<rex_sessentry> entries;
  sesstab.traverse (wrap (print_rexsess, &entries));

  const size_t count = entries.size ();
  rex_sessvec res;
  res.setsize (count);
  for (size_t i = 0; i < count; ++i)
    res[i] = entries[i];
  sbp->replyref (res);
}

// Tear down the established session to `path`, if any.  Returns true if a
// session was found and destroyed.
bool
kill_rexsess (str path) {
  rexsess *sp = sesstab[path];
  if (!sp) {
    warn <<
      "received request to remove nonexistant rexsession to " << path << "\n";
    return false;
  }

  warn << "removing rexsession connected to " << path << "\n";
  // NOTE(review): if rexsession::abort() synchronously runs the end callback
  // (rexsess::eof, which deletes the rexsess), the delete below would be a
  // double free -- confirm abort() semantics.
  sp->abort ();
  delete sp;
  return true;
}






