#include #include #include #include #include #include #include #include "dat.h" #include "fns.h" typedef struct TTLS { uchar tp; uchar flags; uchar tln[4]; //optional, present if L flag set } TTLS; enum { TtlsFlagL = 1<<7, // header contains tln field TtlsFlagM = 1<<6, // more fragment(s) will follow for current msg TtlsFlagS = 1<<5, // start of tls session TtlsVersion = (1<<2)|(1<<1)|(1<<0), TtlsShortHlen = 2, // without tln field TtlsLongHlen = TtlsShortHlen+4, // with tln field Idle = 0, Start, Waiting, Timeout, Sending, RecvAck, Receiving, SendAck, Received, }; char *snames[] = { [Idle] "Idle", [Start] "Start", [Waiting] "Waiting", [Timeout] "Timeout", [Sending] "Sending", [RecvAck] "RecvAck", [Receiving] "Receiving", [SendAck] "SendAck", [Received] "Received", }; typedef struct TTLSstate { TLSconn tlsconn; // our handle to the tls connection int tlspipe[2]; // double pipe over which we talk with our tls // the stuff we read from it has to be fragmented, encapsulated and sent // the fragments we receive have to be reassembled and then written to it Channel *tlsfdc; // used to send file desc we get from tlsClient Channel *readc; // contains index in rbuf containing last msg read from tlspipe Channel *eofc; // confirm eof on tlspipe Channel *startclientc; // start new clientclient session Channel *startreadc; // start new readproc session Channel *timerstart; Channel *timerc; int tlsfd; int ttlsWhile; int ttlsPeriod; int ttslTxLen; // length of frame we prepared for sending int ttlsDone; // done processing the frame (and, if needed, preparing the response)? int ttlsState; // ttls state we are in uint ttlsVersion; Buf rbuf[Nbuf]; // msg read from the tls pipe, to be sent (possibly in fragments) int ridx; // index of first free rbuf int sendT; // total length of msg to be sent uint sendL; // length remaining to be sent uchar*sendP; // pointer in rbuf[...] pointing to stuff remaining to be sent int sendS; // still have to send first frame (fragment) for current msg? Buf wbuf; // receive buffer in which we reassemble fragments, and then write to tlspipe uint recvT; // total length we want to receive (and reassemble) uint recvL; // length received (and reassembled) so far uchar*recvP; // first free position (reassembly insert point) in recv buffer Thumbprint *thumbTable; int inuse; int clientid; uchar*theSessionCert; int theSessionCertlen; uchar* theSessionID; int theSessionIDlen; } TTLSstate; static TTLSstate theTTLSstate; static char errbuf[256]; static void cleanup(TTLSstate* s) { int idx, readdone; Alt a[] = { /* c v op */ {s->readc, &idx, CHANRCV}, {s->eofc, nil, CHANRCV}, {nil, nil, CHANEND}, }; syslog(0, logname, "cleanup pre tlsfd=%d tlspipe[0]=%d tlspipe[1]=%d", s->tlsfd, s->tlspipe[0], s->tlspipe[1]); if (!s->inuse) return; readdone = 0; if (s->tlsfd < 0 && s->tlspipe[0] < 0 && s->tlspipe[1] < 0) readdone = 1; if (s->tlsfd >= 0) { syslog(0, logname, "\tcleanup: closing tlsfd: %d", s->tlsfd); close(s->tlsfd); // should make devtls close s->tlspipe[1], causing eof on s->tlspipe[0] in readproc s->tlsfd = -1; } if (s->tlspipe[0] >= 0) { syslog(0, logname, "\tcleanup: closing tlspipe[0]: %d", s->tlspipe[0]); close(s->tlspipe[0]); s->tlspipe[0] = -1; } if (s->clientid != 0) threadkill(s->clientid); s->clientid = 0; syslog(0, logname, "\tcleanup middle readdone=%d tlsfd=%d tlspipe[0]=%d tlspipe[1]=%d", readdone, s->tlsfd, s->tlspipe[0], s->tlspipe[1]); while(!readdone) { syslog(0, logname, "\tcleanup receiving..."); switch(alt(a)){ case 0: syslog(0, logname, "\t\toops... cleanup recv from readc: %d", idx); // is this the close assert . if so, should we write this to ether? break; case 1: syslog(0, logname, "\t\tcleanup: confirmed eof from readproc"); readdone = 1; break; } } if (s->tlspipe[1] >= 0) { syslog(0, logname, "\tcleanup: closing tlspipe[1]: %d", s->tlspipe[1]); close(s->tlspipe[1]); s->tlspipe[1] = -1; } syslog(0, logname, "cleanup post tlsfd=%d tlspipe[0]=%d tlspipe[1]=%d", s->tlsfd, s->tlspipe[0], s->tlspipe[1]); } // ====================== static void tick(TTLSstate *s) { if (s->ttlsWhile >= 0) s->ttlsWhile--; } static void clockproc(void *arg) { TTLSstate *s; s = arg; for(;;) { recv(s->timerstart, nil); for (;;) { if (s->ttlsWhile == -1) break; sleep(100); if (s->ttlsWhile == -1) break; tick(s); if (s->ttlsWhile == 0) { send(s->timerc, nil); break; } } } } static char* timerName(TTLSstate *s, int *p) { if (p == &s->ttlsWhile) return "ttlsWhile"; return "unknown"; } static void startTimer(TTLSstate *s, int *p, int val) { syslog(0, logname, "startTimer %s to %d", timerName(s, p), val); if (s->ttlsWhile >= 0) syslog(0, logname, "startTimer oops: %s runs, val=%d", timerName(s, &s->ttlsWhile), s->ttlsWhile); *p = val * 10; if (nbsend(s->timerstart, nil) == 0) syslog(0, logname, "startTimer oops: could not timerstart"); } static void resetTimer(TTLSstate *s, int *p) { syslog(0, logname, "resetTimer %s (val was %d)", timerName(s, p), *p); *p = -1; syslog(0, logname, "\tresetTimer %s (val is %d)", timerName(s, p), *p); } // ====================== static void readproc(void *arg) { TTLSstate *s; Buf *r; s = arg; syslog(0, logname, "readproc starts: %d", threadid()); while(recvul(s->startreadc)) { syslog(0, logname, "readproc monitoring pipe: %d", s->tlspipe[0]); for(;;) { if (s->tlspipe[0] < 0) { syslog(0, logname, "readproc pipe not active: %d", s->tlspipe[0]); break; } r = &s->rbuf[s->ridx]; r->n = read(s->tlspipe[0], r->b, Buflen); syslog(0, logname, "readproc read from %d:%d", s->tlspipe[0], r->n); if (r->n <= 0) { syslog(0, logname, "readproc eof on pipe: %d", s->tlspipe[0]); break; } if (s->tlspipe[0] < 0) { syslog(0, logname, "readproc pipe no longer active: %d", s->tlspipe[0]); break; } // syslog(0, logname, "readproc sending..."); sendul(s->readc, s->ridx); s->ridx = (s->ridx+1)%Nbuf; } syslog(0, logname, "readproc sending eofc: %d", s->tlspipe[0]); sendul(s->eofc, 0); syslog(0, logname, "readproc restarts: %d", s->tlspipe[0]); } syslog(0, logname, "readproc exits: %d", threadid()); threadexits(nil); } static void clientproc(void *arg) { TTLSstate *s; int fd; uchar hash[SHA1dlen]; s = arg; syslog(0, logname, "clientproc starts: %d", threadid()); s->clientid = threadid(); syslog(0, logname, "clientproc (re)starting: tlspipe[1]=%d", s->tlspipe[1]); if (s->tlspipe[1] <= 0) { snprint(errbuf, sizeof(errbuf), "clientproc: no fd for tlsClient:%d", s->tlspipe[1]); syslog(0, logname, "%s", errbuf); fprint(2, "%s\n", errbuf); threadexitsall(errbuf); } syslog(0, logname, "calling tlsClient"); fd = tlsClient(s->tlspipe[1], &s->tlsconn); syslog(0, logname, "tlsClient result: fd=%d", fd); if (debug) print("clientproc: fd %d\n", fd); if (fd < 0) { syslog(0, logname, "tlsClient failed: %r"); fprint(2, "tlsClient failed: %r\n"); } else { syslog(0, logname, "tlsClient ok fd=%d", fd); if (s->tlsconn.cert==nil || s->tlsconn.certlen<=0) { syslog(0, logname, "server did not provide TLS certificate"); fprint(2, "server did not provide TLS certificate\n"); } else { // X509dump(s->tlsconn.cert, s->tlsconn.certlen); if (s->thumbTable != nil) { sha1(s->tlsconn.cert, s->tlsconn.certlen, hash, nil); if(!okThumbprint(hash, s->thumbTable)) { syslog(0, logname, "server certificate %.*H not recognized", SHA1dlen, hash); fprint(2, "server certificate %.*H not recognized\n", SHA1dlen, hash); } } else { syslog(0, logname, "no thumbprint to check server certificate"); } } } // clean up before we (implicitly) yield if (s->tlsconn.sessionID != nil) free(s->tlsconn.sessionID); s->tlsconn.sessionID = nil; s->tlsconn.sessionIDlen = 0; if (s->tlsconn.cert) free(s->tlsconn.cert); s->tlsconn.cert = nil; s->tlsconn.certlen = 0; sendul(s->tlsfdc, fd); syslog(0, logname, "clientproc ... finished: fd=%d", fd); syslog(0, logname, "clientproc exits: %d", threadid()); s->clientid = 0; threadexits(nil); } static void setupTls(TTLSstate *s) { syslog(0, logname, "setupTls pre tlspipe[0]=%d tlspipe[1]=%d", s->tlspipe[0], s->tlspipe[1]); if (s->tlspipe[0] >= 0 || s->tlspipe[1] >= 0) { snprint(errbuf, sizeof(errbuf), "setupTls: pipe already open? %d %d", s->tlspipe[0], s->tlspipe[1]); fprint(2, "%s\n", errbuf); syslog(0, logname, "%s", errbuf); threadexitsall(errbuf); } if (pipe(s->tlspipe) < 0) { fprint(2, "pipe failed: %r\n"); syslog(0, logname, "pipe failed: %r"); threadexitsall("pipe failed"); } // call tlsClient and wait for result syslog(0, logname, "setupTls startclientc..."); s->clientid = proccreate(clientproc, s, STACK); // signal reader to restart syslog(0, logname, "setupTls startreadc..."); sendul(s->startreadc, 1); s->inuse = 1; syslog(0, logname, "setupTls post tlspipe[0]=%d tlspipe[1]=%d", s->tlspipe[0], s->tlspipe[1]); } static int buildFrameStart(TTLSstate *s, uchar*b, int mtu) { TTLS *t; if (mtu <= TtlsLongHlen) print("buildFrameStart error: mtu much too small: mtu=%d, longhdr=%d\n", mtu, TtlsLongHlen); if (s->sendL <= mtu-TtlsLongHlen) print("buildFrameStart error: small enough, no framing needed: sz=%d, space=%d\n", s->sendL, mtu-TtlsLongHlen); t = (TTLS*)b; memset(t, 0, TtlsLongHlen); t->tp = EapTpTtls; t->flags = TtlsFlagM | TtlsFlagL; hnputl(t->tln, s->sendL); memcpy(b+TtlsLongHlen, s->sendP, mtu-TtlsLongHlen); s->ttslTxLen = mtu; s->sendP += mtu-TtlsLongHlen; s->sendL -= mtu-TtlsLongHlen; return mtu; } static int buildFrameMiddle(TTLSstate *s, uchar*b, int mtu) { TTLS *t; if (mtu <= TtlsShortHlen) print("buildFrameMiddle error: mtu much too small: mtu=%d, longhdr=%d\n", mtu, TtlsShortHlen); if (s->sendL <= mtu-TtlsShortHlen) print("buildFrameMiddle error: small enough, no framing needed: sz=%d, space=%d\n", s->sendL, mtu-TtlsShortHlen); t = (TTLS*)b; memset(t, 0, TtlsShortHlen); t->tp = EapTpTtls; t->flags = TtlsFlagM; memcpy(b+TtlsShortHlen, s->sendP, mtu-TtlsShortHlen); s->ttslTxLen = mtu; s->sendP += mtu-TtlsShortHlen; s->sendL -= mtu-TtlsShortHlen; return mtu; } static int buildMsg(TTLSstate *s, uchar*b, int mtu) { TTLS *t; int res; if (mtu <= TtlsShortHlen) print("buildMsg error: mtu much too small: mtu=%d, longhdr=%d\n", mtu, TtlsShortHlen); if (s->sendL > mtu-TtlsShortHlen) print("buildMsg error: too big, framing needed: sz=%d, space=%d\n", s->sendL, mtu-TtlsShortHlen); t = (TTLS*)b; memset(t, 0, TtlsShortHlen); t->tp = EapTpTtls; memcpy(b+TtlsShortHlen, s->sendP, s->sendL); s->ttslTxLen = TtlsShortHlen + s->sendL; res = s->sendL; s->sendP = 0; s->sendL = 0; return res; } static void buildAck(TTLSstate *s, uchar*b, int mtu) { TTLS *t; USED(mtu); t = (TTLS*)b; memset(t, 0, TtlsShortHlen); t->tp = EapTpTtls; s->ttslTxLen = TtlsShortHlen; } static void trans(TTLSstate *s, int new) { syslog(0, logname, "ttls trans: %s -> %s", (s->ttlsState>=0) ? snames[s->ttlsState] : "-", snames[new]); switch(new){ case RecvAck: s->ttlsDone = 1; break; case Receiving: s->ttlsDone = 1; break; case Idle: s->ttlsDone = 1; break; } s->ttlsState = new; } static void ttls(TTLSstate *s, uchar*rcvp, uint rcvl, uchar*txp, uint mtu, int*ttlsSuccess, int*ttlsFail) { int fd; int i; Alt a[] = { /* c v op */ {s->tlsfdc, &fd, CHANRCV}, {s->readc, &i, CHANRCV}, {s->timerc, nil, CHANRCV}, {nil, nil, CHANEND}, }; TTLS *t; uchar *p; uint l; int n; int olen, flen; // print("ttls %s\n", snames[s->ttlsState]); if (debug) print("ttls %s; recvL=%d; recvT=%d\n", snames[s->ttlsState], s->recvL, s->recvT); switch(s->ttlsState){ case Idle: trans(s, Idle); break; case Start: setupTls(s); // new session trans(s, Waiting); break; case Waiting: while(s->ttlsState == Waiting) { startTimer(s, &s->ttlsWhile, s->ttlsPeriod); switch(alt(a)){ case 0: // the tlsClient call returned syslog(0, logname, "ttls Waiting tlsClient return %d", fd); s->tlsfd = fd; if (fd < 0) { *ttlsFail = 1; trans(s, Idle); } else { doTTLSphase2(fd); } break; case 1: // something read from tlspipe: encapsulate and send syslog(0, logname, "ttls Waiting read from tlspipe"); s->sendP = s->rbuf[i].b; s->sendL = s->rbuf[i].n; s->sendT = s->sendL; if (debug) print("ttls readc: i=%d sendP=%p sendL=%d\n", i, s->sendP, s->sendL); s->sendS = 1; trans(s, Sending); break; case 2: /* timer expiration event */ syslog(0, logname, "ttls Waiting tlsClient timer expired"); if (s->ttlsWhile == 0) trans(s, Timeout); else { fprint(2, "ttls Waiting 2: should not happen\n"); syslog(0, logname, "ttls Waiting 2: should not happen"); threadexitsall("ttls Waiting 2: should not happen"); } break; } resetTimer(s, &s->ttlsWhile); } break; case Timeout: trans(s, Receiving); // seems we need more stuff to satisfy tlsClient break; case Sending: if (s->sendS && s->sendL > mtu-TtlsShortHlen) { olen = s->sendL; flen = buildFrameStart(s, txp, mtu); if (debug) print("ttls sendS and framed %d of %d, total %d, remains %d\n", flen, olen, s->sendT, s->sendL); s->sendS = 0; trans(s, RecvAck); } else if (s->sendL > mtu-TtlsShortHlen) { olen = s->sendL; flen = buildFrameMiddle(s, txp, mtu); if (debug) print("ttls framed %d of %d, total %d, remains %d\n", flen, olen, s->sendT, s->sendL); trans(s, RecvAck); } else { olen = s->sendL; flen = buildMsg(s, txp, mtu); if (debug) print("ttls framed %d of %d, total %d, remains %d\n", flen, olen, s->sendT, s->sendL); s->recvP = s->wbuf.b; s->recvL = 0; s->recvT = 0; trans(s, Receiving); } break; case RecvAck: t = (TTLS*)rcvp; if (t->flags&TtlsFlagS) print("tls: unexpected TtlsFlagS in %s\n", snames[s->ttlsState]); if (t->flags&TtlsFlagM) print("tls: unexpected TtlsFlagM in %s\n", snames[s->ttlsState]); if (t->flags&TtlsFlagL) print("tls: unexpected TtlsFlagL in %s\n", snames[s->ttlsState]); trans(s, Sending); break; case Receiving: t = (TTLS*)rcvp; if (t->flags&TtlsFlagS) print("tls: unexpected TtlsFlagS in %s\n", snames[s->ttlsState]); if (t->flags&TtlsFlagL && s->recvT > 0) print("tls: TtlsFlagL when recvT=%d\n", s->recvT); if (t->flags&TtlsFlagL) { s->recvT = nhgetl(t->tln); if (debug) print("ttls: TtlsFlagL len=%d\n", s->recvT); p = rcvp+TtlsLongHlen; l = rcvl-TtlsLongHlen; if (s->recvP != s->wbuf.b) print("ttls %s: recvP != wbuf.b recvP=%p wbuf.b=%p \n", snames[s->ttlsState], s->recvP, s->wbuf.b); if (s->recvL != 0) print("ttls %s: recvL != 0 recvL=%d\n", snames[s->ttlsState], s->recvL); } else { p = rcvp+TtlsShortHlen; l = rcvl-TtlsShortHlen; if (s->recvP != s->wbuf.b + s->recvL) print("ttls %s: recvP != wbuf.b + s->recvL recvP=%p wbuf.b=%p recvL=%d\n", snames[s->ttlsState], s->recvP, s->wbuf.b, s->recvL); } memcpy(s->recvP, p, l); s->recvP += l; s->recvL += l; if (debug) print("ttls %s: received %d; recvL=%d; recvT=%d\n", snames[s->ttlsState], l, s->recvL, s->recvT); if (t->flags&TtlsFlagM) trans(s, SendAck); else { if (s->recvT > 0 && s->recvT != s->recvL) print("ttls : recvT=%d != recvL=%d\n", s->recvT, s->recvL); if (s->recvL > 0) trans(s, Received); else trans(s, Waiting); } break; case SendAck: buildAck(s, txp, mtu); trans(s, Receiving); break; case Received: if (debug) print("ttls %s: writing tlspipe[0]: %s\n", snames[s->ttlsState], hexprefix(s->wbuf.b, s->recvL, 5)); n = write(s->tlspipe[0], s->wbuf.b, s->recvL); if (n<0) print("ttls %s: error writing tlspipe[0]: %r\n", snames[s->ttlsState]); syslog(0, logname, "writeproc written %d", n); if (n != s->recvL) print("ttls %s: writing tlspipe[0]: n != recvL n=%d recvL=%d\n", snames[s->ttlsState], n, s->recvL); if (debug) print("ttls %s: written to tlspipe[0] : %d\n", snames[s->ttlsState], s->recvL); trans(s, Waiting); break; } if (debug) print("ttls %s; recvL=%d; recvT=%d\n", snames[s->ttlsState], s->recvL, s->recvT); // print("ttls .... %s\n", snames[s->ttlsState]); } void initTTLS(char *file, char *filex) { TTLSstate *s; syslog(0, logname, "initTTLS"); s = &theTTLSstate; memset(s, 0, sizeof(TTLSstate)); s->ttlsState = Idle; s->tlsfdc = chancreate(sizeof(int), 0); s->readc = chancreate(sizeof(int), 0); s->eofc = chancreate(sizeof(int), 0); s->timerc = chancreate(sizeof(int), 0); s->timerstart = chancreate(sizeof(int), 0); s->startclientc = chancreate(sizeof(int), 0); s->startreadc = chancreate(sizeof(int), 0); s->tlsfd = -1; s->tlspipe[0] = -1; s->tlspipe[1] = -1; s->ttlsWhile = -1; s->ttlsPeriod = 5; //seconds s->tlsconn.sessionType = "ttls"; s->tlsconn.sessionConst = "ttls keying material"; s->tlsconn.sessionKey = theSessionKey; s->tlsconn.sessionKeylen = sizeof(theSessionKey); if (debugTLS) s->tlsconn.trace = print; fmtinstall('H', encodefmt); if (file) { s->thumbTable = initThumbprints(file, filex); if (s->thumbTable == nil) { snprint(errbuf, sizeof(errbuf), "initThumbprints: %r"); syslog(0, logname, "%s", errbuf); fprint(2, "%s\n", errbuf); threadexitsall(errbuf); } } proccreate(readproc, s, STACK); proccreate(clockproc, s, STACK); } void abortTTLS(void) { TTLSstate *s; syslog(0, logname, "abortTTLS"); s = &theTTLSstate; if (s->tlspipe[0] >= 0) { close(s->tlspipe[0]); s->tlspipe[0] = -1; } } static void run(TTLSstate *s, uchar*rcvp, uint rcvl, uchar*txp, uint mtu, int*success, int*failed) { s->ttlsDone = 0; while (!s->ttlsDone) ttls(s, rcvp, rcvl, txp, mtu, success, failed); } int processTTLS(uchar*rcvp, uint rcvl, int expectStart, uchar*txp, uint mtu, int*success, int*failed) { TTLS *hr; uchar flags, version; TTLSstate *s; // if (debug) print("processTTLS br=%p txp=%p mtu=%d bl=%d\n", br, txp, mtu, bl); s = &theTTLSstate; hr = (TTLS*)rcvp; if (hr->tp != EapTpTtls) return 0; // flag error?? // first thing should be EAP-TTLS start packet flags = rcvp[1]; // check length version = flags & TtlsVersion; if (debug) print("processTTLS flags=%s%s%s ver=%d mtu=%d bl=%d\n", (flags&TtlsFlagS ? "S":""), (flags&TtlsFlagM ? "M":""), (flags&TtlsFlagL ? "L":""), version, mtu, rcvl); if (expectStart && !flags&TtlsFlagS) { fprint(2, "expected EAP-TTLS start packet\n"); syslog(0, logname, "expected EAP-TTLS start packet"); threadexitsall("expected EAP-TTLS start packet"); } if (flags & TtlsFlagS) { cleanup(s); // previous session // ack?? // look for piggy-backed stuff? s->ttlsVersion = version; s->ttlsState = Start; s->ttlsDone = 0; s->sendP = 0; s->sendL = 0; s->sendS = 0; s->sendT = 0; s->recvP = 0; s->recvL = 0; s->recvT = 0; // we don't have a client certificate s->tlsconn.cert = nil; s->tlsconn.certlen = 0; // avoid trying session resumption - tlsClient does not support it s->tlsconn.sessionID = nil; s->tlsconn.sessionIDlen = 0; // if (debug) print("processTTLS TtlsFlagS version=%d \n", version); } run(s, rcvp, rcvl, txp, mtu, success, failed); return s->ttslTxLen; }