/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

/*
 * strsclnt: a multi-threaded SSL/TLS client that makes a configurable number
 * of connections to one server from a pool of worker threads, optionally
 * forcing a percentage of full handshakes, and checks the libssl session
 * cache statistics at exit.
 */

/* NOTE(review): the header names of several #include directives below are
 * empty in this copy of the file.  They must be restored before the file
 * will compile. */
#include
#include
#include "secutil.h"
#include "basicutil.h"

#if defined(XP_UNIX)
#include
#endif
#include
#include
#include
#include

#include "plgetopt.h"
#include "nspr.h"
#include "prio.h"
#include "prnetdb.h"
#include "prerror.h"
#include "pk11func.h"
#include "secitem.h"
#include "sslproto.h"
#include "nss.h"
#include "ssl.h"

/* Fallback definitions, used only if the NSS headers do not already define
 * these PORT_ wrappers. */
#ifndef PORT_Sprintf
#define PORT_Sprintf sprintf
#endif

#ifndef PORT_Strstr
#define PORT_Strstr strstr
#endif

#ifndef PORT_Malloc
#define PORT_Malloc PR_Malloc
#endif

/* Size of the receive buffer allocated for each connection. */
#define RD_BUF_SIZE (60 * 1024)

/* Include these cipher suite arrays to re-use tstclnt's
 * cipher selection code.
 */

/* SSL2 suites selected by the upper-case letters of -C ('A' picks the first
 * entry, 'B' the second, ...).  The 0 entry terminates the table. */
int ssl2CipherSuites[] = {
    SSL_EN_RC4_128_WITH_MD5,                    /* A */
    SSL_EN_RC4_128_EXPORT40_WITH_MD5,           /* B */
    SSL_EN_RC2_128_CBC_WITH_MD5,                /* C */
    SSL_EN_RC2_128_CBC_EXPORT40_WITH_MD5,       /* D */
    SSL_EN_DES_64_CBC_WITH_MD5,                 /* E */
    SSL_EN_DES_192_EDE3_CBC_WITH_MD5,           /* F */
    0
};

/* SSL3/TLS suites selected by the lower-case letters of -C.  Entries of -1
 * are placeholders; client_main rejects their letters as invalid.  The 0
 * entry terminates the table. */
int ssl3CipherSuites[] = {
    -1, /* SSL_FORTEZZA_DMS_WITH_FORTEZZA_CBC_SHA* a */
    -1, /* SSL_FORTEZZA_DMS_WITH_RC4_128_SHA     * b */
    TLS_RSA_WITH_RC4_128_MD5,                   /* c */
    TLS_RSA_WITH_3DES_EDE_CBC_SHA,              /* d */
    TLS_RSA_WITH_DES_CBC_SHA,                   /* e */
    TLS_RSA_EXPORT_WITH_RC4_40_MD5,             /* f */
    TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5,         /* g */
    -1, /* SSL_FORTEZZA_DMS_WITH_NULL_SHA        * h */
    TLS_RSA_WITH_NULL_MD5,                      /* i */
    SSL_RSA_FIPS_WITH_3DES_EDE_CBC_SHA,         /* j */
    SSL_RSA_FIPS_WITH_DES_CBC_SHA,              /* k */
    TLS_RSA_EXPORT1024_WITH_DES_CBC_SHA,        /* l */
    TLS_RSA_EXPORT1024_WITH_RC4_56_SHA,         /* m */
    TLS_RSA_WITH_RC4_128_SHA,                   /* n */
    TLS_DHE_DSS_WITH_RC4_128_SHA,               /* o */
    TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA,          /* p */
    TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA,          /* q */
    TLS_DHE_RSA_WITH_DES_CBC_SHA,               /* r */
    TLS_DHE_DSS_WITH_DES_CBC_SHA,               /* s */
    TLS_DHE_DSS_WITH_AES_128_CBC_SHA,           /* t */
    TLS_DHE_RSA_WITH_AES_128_CBC_SHA,           /* u */
    TLS_RSA_WITH_AES_128_CBC_SHA,               /* v */
    TLS_DHE_DSS_WITH_AES_256_CBC_SHA,           /* w */
    TLS_DHE_RSA_WITH_AES_256_CBC_SHA,           /* x */
    TLS_RSA_WITH_AES_256_CBC_SHA,               /* y */
    TLS_RSA_WITH_NULL_SHA,                      /* z */
    0
};

/* Value of fullhs meaning "-P was not given". */
#define NO_FULLHS_PERCENTAGE -1

/* This global string is so that client main can see
 * which ciphers to use.
 */
static const char *cipherString;

static PRInt32 certsTested;     /* times mySSLAuthCertificate ran the real check */
static int MakeCertOK;          /* number of -o flags given */
static int NoReuse;             /* -N: no session reuse */
static int fullhs = NO_FULLHS_PERCENTAGE; /* percentage of full handshakes to
                                          ** perform */
static PRInt32 globalconid = 0; /* atomically set */

static int total_connections;  /* total number of connections to perform */
static int total_connections_rounded_down_to_hundreds;
static int total_connections_modulo_100;

static PRBool NoDelay;
static PRBool QuitOnTimeout = PR_FALSE;
static PRBool ThrottleUp = PR_FALSE;

static PRLock * threadLock; /* protects the global variables below */
static PRTime lastConnectFailure;
static PRTime lastConnectSuccess;
static PRTime lastThrottleUp;
static PRInt32 remaining_connections;  /* number of connections left */
static int active_threads = 8; /* number of threads currently trying to
                               ** connect */
static PRInt32 numUsed;
/* end of variables protected by threadLock */

static SSL3Statistics * ssl3stats;

static int failed_already = 0;
static SSLVersionRange enabledVersions;
static PRBool enableSSL2 = PR_TRUE;
static PRBool bypassPKCS11 = PR_FALSE;
static PRBool disableLocking = PR_FALSE;
static PRBool ignoreErrors = PR_FALSE;
static PRBool enableSessionTickets = PR_FALSE;
static PRBool enableCompression = PR_FALSE;
static PRBool enableFalseStart = PR_FALSE;
static PRBool enableCertStatus = PR_FALSE;

/* Timeout passed to every PR_Send/PR_Recv (see NSS_DEBUG_TIMEOUT in main). */
PRIntervalTime maxInterval = PR_INTERVAL_NO_TIMEOUT;

char * progName;

secuPWData pwdata = { PW_NONE, 0 };

/* NOTE(review): "stopping" is not referenced anywhere else in this file. */
int stopping;
int verbose;
/* Data sent by do_writes() in full-duplex mode; filled by readBigFile(). */
SECItem bigBuf;

/* Print only when -v was given.
 * NOTE(review): these expand to a bare "if (verbose) ...", so a following
 * "else" would bind to the hidden if; only use them inside braces. */
#define PRINTF  if (verbose) printf
#define FPRINTF if (verbose) fprintf

/* Print the command-line synopsis to stderr and exit with status 1.
 * NOTE(review): the -i (ignore errors) option accepted by main() is not
 * listed here. */
static void
Usage(const char *progName)
{
    /* NOTE(review): in this copy the first string literal below is broken
     * across two lines ("[-c" / "connections]"), which is not valid C.  It
     * was presumably "[-c connections]" on one line -- confirm and rejoin. */
    fprintf(stderr,
"Usage: %s [-n nickname] [-p port] [-d dbdir] [-c 
connections]\n"
" [-BDNovqs] [-f filename] [-N | -P percentage]\n"
" [-w dbpasswd] [-C cipher(s)] [-t threads] [-W pwfile]\n"
" [-V [min-version]:[max-version]] [-a sniHostName] hostname\n"
" where -v means verbose\n"
" -o flag is interpreted as follows:\n"
" 1 -o means override the result of server certificate validation.\n"
" 2 -o's mean skip server certificate validation altogether.\n"
" -D means no TCP delays\n"
" -q means quit when server gone (timeout rather than retry forever)\n"
" -s means disable SSL socket locking\n"
" -N means no session reuse\n"
" -P means do a specified percentage of full handshakes (0-100)\n"
" -V [min]:[max] restricts the set of enabled SSL/TLS protocols versions.\n"
" All versions are enabled by default.\n"
" Possible values for min/max: ssl2 ssl3 tls1.0 tls1.1 tls1.2\n"
" Example: \"-V ssl3:\" enables SSL 3 and newer.\n"
" -U means enable throttling up threads\n"
" -B bypasses the PKCS11 layer for SSL encryption and MACing\n"
" -T enable the cert_status extension (OCSP stapling)\n"
" -u enable TLS Session Ticket extension\n"
" -z enable compression\n"
" -g enable false start\n",
        progName);
    exit(1);
}

/* Print funcString with the current NSPR error code, the OS error code and
 * the SECU_Strerror() text for the NSPR error to stderr. */
static void
errWarn(char * funcString)
{
    PRErrorCode  perr      = PR_GetError();
    PRInt32      oserr     = PR_GetOSError();
    const char * errString = SECU_Strerror(perr);

    fprintf(stderr, "strsclnt: %s returned error %d, OS error %d: %s\n",
            funcString, perr, oserr, errString);
}

/* errWarn(funcString), then terminate the process with exit status 1. */
static void
errExit(char * funcString)
{
    errWarn(funcString);
    exit(1);
}

/**************************************************************************
**
** Routines for disabling SSL ciphers.
** **************************************************************************/ void disableAllSSLCiphers(void) { const PRUint16 *cipherSuites = SSL_GetImplementedCiphers(); int i = SSL_GetNumImplementedCiphers(); SECStatus rv; /* disable all the SSL3 cipher suites */ while (--i >= 0) { PRUint16 suite = cipherSuites[i]; rv = SSL_CipherPrefSetDefault(suite, PR_FALSE); if (rv != SECSuccess) { printf("SSL_CipherPrefSetDefault didn't like value 0x%04x (i = %d)\n", suite, i); errWarn("SSL_CipherPrefSetDefault"); exit(2); } } } /* This invokes the "default" AuthCert handler in libssl. ** The only reason to use this one is that it prints out info as it goes. */ static SECStatus mySSLAuthCertificate(void *arg, PRFileDesc *fd, PRBool checkSig, PRBool isServer) { SECStatus rv; CERTCertificate * peerCert; const SECItemArray *csa; if (MakeCertOK>=2) { return SECSuccess; } peerCert = SSL_PeerCertificate(fd); PRINTF("strsclnt: Subject: %s\nstrsclnt: Issuer : %s\n", peerCert->subjectName, peerCert->issuerName); csa = SSL_PeerStapledOCSPResponses(fd); if (csa) { PRINTF("Received %d Cert Status items (OCSP stapled data)\n", csa->len); } /* invoke the "default" AuthCert handler. */ rv = SSL_AuthCertificate(arg, fd, checkSig, isServer); PR_ATOMIC_INCREMENT(&certsTested); if (rv == SECSuccess) { fputs("strsclnt: -- SSL: Server Certificate Validated.\n", stderr); } CERT_DestroyCertificate(peerCert); /* error, if any, will be displayed by the Bad Cert Handler. */ return rv; } static SECStatus myBadCertHandler( void *arg, PRFileDesc *fd) { PRErrorCode err = PR_GetError(); if (!MakeCertOK) fprintf(stderr, "strsclnt: -- SSL: Server Certificate Invalid, err %d.\n%s\n", err, SECU_Strerror(err)); return (MakeCertOK ? 
SECSuccess : SECFailure); } void printSecurityInfo(PRFileDesc *fd) { CERTCertificate * cert = NULL; SSL3Statistics * ssl3stats = SSL_GetStatistics(); SECStatus result; SSLChannelInfo channel; SSLCipherSuiteInfo suite; static int only_once; if (only_once && verbose < 2) return; only_once = 1; result = SSL_GetChannelInfo(fd, &channel, sizeof channel); if (result == SECSuccess && channel.length == sizeof channel && channel.cipherSuite) { result = SSL_GetCipherSuiteInfo(channel.cipherSuite, &suite, sizeof suite); if (result == SECSuccess) { FPRINTF(stderr, "strsclnt: SSL version %d.%d using %d-bit %s with %d-bit %s MAC\n", channel.protocolVersion >> 8, channel.protocolVersion & 0xff, suite.effectiveKeyBits, suite.symCipherName, suite.macBits, suite.macAlgorithmName); FPRINTF(stderr, "strsclnt: Server Auth: %d-bit %s, Key Exchange: %d-bit %s\n" " Compression: %s\n", channel.authKeyBits, suite.authAlgorithmName, channel.keaKeyBits, suite.keaTypeName, channel.compressionMethodName); } } cert = SSL_LocalCertificate(fd); if (!cert) cert = SSL_PeerCertificate(fd); if (verbose && cert) { char * ip = CERT_NameToAscii(&cert->issuer); char * sp = CERT_NameToAscii(&cert->subject); if (sp) { fprintf(stderr, "strsclnt: subject DN: %s\n", sp); PORT_Free(sp); } if (ip) { fprintf(stderr, "strsclnt: issuer DN: %s\n", ip); PORT_Free(ip); } } if (cert) { CERT_DestroyCertificate(cert); cert = NULL; } fprintf(stderr, "strsclnt: %ld cache hits; %ld cache misses, %ld cache not reusable\n" " %ld stateless resumes\n", ssl3stats->hsh_sid_cache_hits, ssl3stats->hsh_sid_cache_misses, ssl3stats->hsh_sid_cache_not_ok, ssl3stats->hsh_sid_stateless_resumes); } /************************************************************************** ** Begin thread management routines and data. 
**************************************************************************/ #define MAX_THREADS 128 typedef int startFn(void *a, void *b, int c); static PRInt32 numConnected; static int max_threads; /* peak threads allowed */ typedef struct perThreadStr { void * a; void * b; int tid; int rv; startFn * startFunc; PRThread * prThread; PRBool inUse; } perThread; perThread threads[MAX_THREADS]; void thread_wrapper(void * arg) { perThread * slot = (perThread *)arg; PRBool done = PR_FALSE; do { PRBool doop = PR_FALSE; PRBool dosleep = PR_FALSE; PRTime now = PR_Now(); PR_Lock(threadLock); if (! (slot->tid < active_threads)) { /* this thread isn't supposed to be running */ if (!ThrottleUp) { /* we'll never need this thread again, so abort it */ done = PR_TRUE; } else if (remaining_connections > 0) { /* we may still need this thread, so just sleep for 1s */ dosleep = PR_TRUE; /* the conditions to trigger a throttle up are : ** 1. last PR_Connect failure must have happened more than ** 10s ago ** 2. last throttling up must have happened more than 0.5s ago ** 3. 
there must be a more recent PR_Connect success than ** failure */ if ( (now - lastConnectFailure > 10 * PR_USEC_PER_SEC) && ( (!lastThrottleUp) || ( (now - lastThrottleUp) >= (PR_USEC_PER_SEC/2)) ) && (lastConnectSuccess > lastConnectFailure) ) { /* try throttling up by one thread */ active_threads = PR_MIN(max_threads, active_threads+1); fprintf(stderr,"active_threads set up to %d\n", active_threads); lastThrottleUp = PR_MAX(now, lastThrottleUp); } } else { /* no more connections left, we are done */ done = PR_TRUE; } } else { /* this thread should run */ if (--remaining_connections >= 0) { /* protected by threadLock */ doop = PR_TRUE; } else { done = PR_TRUE; } } PR_Unlock(threadLock); if (doop) { slot->rv = (* slot->startFunc)(slot->a, slot->b, slot->tid); PRINTF("strsclnt: Thread in slot %d returned %d\n", slot->tid, slot->rv); } if (dosleep) { PR_Sleep(PR_SecondsToInterval(1)); } } while (!done && (!failed_already || ignoreErrors)); } SECStatus launch_thread( startFn * startFunc, void * a, void * b, int tid) { PRUint32 i; perThread * slot; PR_Lock(threadLock); PORT_Assert(numUsed < MAX_THREADS); if (! 
(numUsed < MAX_THREADS)) { PR_Unlock(threadLock); return SECFailure; } i = numUsed++; slot = &threads[i]; slot->a = a; slot->b = b; slot->tid = tid; slot->startFunc = startFunc; slot->prThread = PR_CreateThread(PR_USER_THREAD, thread_wrapper, slot, PR_PRIORITY_NORMAL, PR_GLOBAL_THREAD, PR_JOINABLE_THREAD, 0); if (slot->prThread == NULL) { PR_Unlock(threadLock); printf("strsclnt: Failed to launch thread!\n"); return SECFailure; } slot->inUse = 1; PR_Unlock(threadLock); PRINTF("strsclnt: Launched thread in slot %d \n", i); return SECSuccess; } /* join all the threads */ int reap_threads(void) { int i; for (i = 0; i < MAX_THREADS; ++i) { if (threads[i].prThread) { PR_JoinThread(threads[i].prThread); threads[i].prThread = NULL; } } return 0; } void destroy_thread_data(void) { PORT_Memset(threads, 0, sizeof threads); if (threadLock) { PR_DestroyLock(threadLock); threadLock = NULL; } } void init_thread_data(void) { threadLock = PR_NewLock(); } /************************************************************************** ** End thread management routines. 
**************************************************************************/

/* NOTE(review): useModelSocket and stopCmd are not referenced anywhere else
 * in this file. */
PRBool useModelSocket = PR_TRUE;

static const char stopCmd[] = { "GET /stop " };

/* Response header that readBigFile() copies to the front of bigBuf, ahead
 * of the -f file contents. */
static const char outHeader[] = {
    "HTTP/1.0 200 OK\r\n"
    "Server: Netscape-Enterprise/2.0a\r\n"
    "Date: Tue, 26 Aug 1997 22:10:05 GMT\r\n"
    "Content-type: text/plain\r\n"
    "\r\n"
};

/* A lock-protected counter with a condition variable that is notified
 * whenever the counter becomes <= 0.  handle_fdx_connection() uses one to
 * wait for its writer thread. */
struct lockedVarsStr {
    PRLock *    lock;
    int         count;
    int         waiters;   /* NOTE(review): set to 0 but never read */
    PRCondVar * condVar;
};

typedef struct lockedVarsStr lockedVars;

/* Zero the counters and create the lock and condition variable.
 * NOTE(review): PR_NewLock()/PR_NewCondVar() results are not checked. */
void
lockedVars_Init( lockedVars * lv)
{
    lv->count   = 0;
    lv->waiters = 0;
    lv->lock    = PR_NewLock();
    lv->condVar = PR_NewCondVar(lv->lock);
}

/* Destroy the condition variable and lock made by lockedVars_Init(). */
void
lockedVars_Destroy( lockedVars * lv)
{
    PR_DestroyCondVar(lv->condVar);
    lv->condVar = NULL;

    PR_DestroyLock(lv->lock);
    lv->lock = NULL;
}

/* Block until lv->count is no longer positive. */
void
lockedVars_WaitForDone(lockedVars * lv)
{
    PR_Lock(lv->lock);
    while (lv->count > 0) {
        PR_WaitCondVar(lv->condVar, PR_INTERVAL_NO_TIMEOUT);
    }
    PR_Unlock(lv->lock);
}

/* Add addend to lv->count (addend 0 just reads it).  If the new value is
 * <= 0, PR_NotifyCondVar() is called.  Returns the new count. */
int /* returns count */
lockedVars_AddToCount(lockedVars * lv, int addend)
{
    int rv;

    PR_Lock(lv->lock);
    rv = lv->count += addend;
    if (rv <= 0) {
        PR_NotifyCondVar(lv->condVar);
    }
    PR_Unlock(lv->lock);
    return rv;
}

/* Writer half of a full-duplex connection, run via launch_thread().
 * a is the SSL socket, b the shared lockedVars, c is unused.
 * Sends all of bigBuf.  If the last send did not fail, it shuts down the
 * send side.  It then decrements the shared count so the reader can finish.
 * Returns SECSuccess only if every byte was sent. */
int
do_writes(
    void *       a,
    void *       b,
    int          c)
{
    PRFileDesc * ssl_sock = (PRFileDesc *)a;
    lockedVars * lv       = (lockedVars *)b;
    int          sent     = 0;
    int          count    = 0;

    while (sent < bigBuf.len) {

        count = PR_Send(ssl_sock, bigBuf.data + sent, bigBuf.len - sent,
                        0, maxInterval);
        if (count < 0) {
            errWarn("PR_Send bigBuf");
            break;
        }
        FPRINTF(stderr, "strsclnt: PR_Send wrote %d bytes from bigBuf\n",
                count );
        sent += count;
    }
    if (count >= 0) {   /* last write didn't fail. */
        PR_Shutdown(ssl_sock, PR_SHUTDOWN_SEND);
    }

    /* notify the reader that we're done. */
    lockedVars_AddToCount(lv, -1);
    return (sent < bigBuf.len) ? SECFailure : SECSuccess;
}

/* Reader half of a full-duplex connection.  Starts do_writes() on another
 * thread, then receives until PR_Recv fails or the writer has finished
 * (shared count <= 0).  Security info is printed after the first read.
 * Always returns SECSuccess; the caller closes ssl_sock.
 * NOTE(review): a PR_Recv result of 0 (EOF) does not end the loop; it keeps
 * calling PR_Recv until the writer finishes. */
int
handle_fdx_connection(
    PRFileDesc * ssl_sock,
    int          connection)
{
    SECStatus    result;
    int          firstTime = 1;
    int          countRead = 0;
    lockedVars   lv;
    char         *buf;


    lockedVars_Init(&lv);
    lockedVars_AddToCount(&lv, 1);

    /* Attempt to launch the writer thread. */
    result = launch_thread(do_writes, ssl_sock, &lv, connection);

    /* NOTE(review): this early exit skips lockedVars_Destroy(), so lv.lock
     * and lv.condVar are leaked. */
    if (result != SECSuccess)
        goto cleanup;

    buf = PR_Malloc(RD_BUF_SIZE);

    if (buf) {
        do {
            /* do reads here. */
            PRInt32 count;

            count = PR_Recv(ssl_sock, buf, RD_BUF_SIZE, 0, maxInterval);
            if (count < 0) {
                errWarn("PR_Recv");
                break;
            }
            countRead += count;
            FPRINTF(stderr,
                    "strsclnt: connection %d read %d bytes (%d total).\n",
                    connection, count, countRead );
            if (firstTime) {
                firstTime = 0;
                printSecurityInfo(ssl_sock);
            }
        } while (lockedVars_AddToCount(&lv, 0) > 0);
        PR_Free(buf);
        buf = 0;
    }

    /* Wait for writer to finish */
    lockedVars_WaitForDone(&lv);
    lockedVars_Destroy(&lv);

    FPRINTF(stderr,
    "strsclnt: connection %d read %d bytes total. -----------------------\n",
            connection, countRead);

cleanup:
    /* Caller closes the socket. */

    return SECSuccess;
}

/* Request sent on every half-duplex connection by handle_connection(). */
const char request[] = {"GET /abc HTTP/1.0\r\n\r\n" };

/* Half-duplex connection: send request, print security info, then read
 * until EOF or an error.  A send or receive error sets failed_already.
 * Returns SECFailure if the buffer cannot be allocated or the send fails,
 * otherwise SECSuccess.  The caller closes ssl_sock. */
SECStatus
handle_connection( PRFileDesc *ssl_sock, int tid)
{
    int     countRead = 0;
    PRInt32 rv;
    char    *buf;

    buf = PR_Malloc(RD_BUF_SIZE);
    if (!buf)
        return SECFailure;

    /* compose the http request here. */

    rv = PR_Send(ssl_sock, request, strlen(request), 0, maxInterval);
    if (rv <= 0) {
        errWarn("PR_Send");
        PR_Free(buf);
        buf = 0;
        failed_already = 1;
        return SECFailure;
    }
    printSecurityInfo(ssl_sock);

    /* read until EOF */
    while (1) {
        rv = PR_Recv(ssl_sock, buf, RD_BUF_SIZE, 0, maxInterval);
        if (rv == 0) {
            break;      /* EOF */
        }
        if (rv < 0) {
            errWarn("PR_Recv");
            failed_already = 1;
            break;
        }

        countRead += rv;
        FPRINTF(stderr,
                "strsclnt: connection on thread %d read %d bytes (%d total).\n",
                tid, rv, countRead );
    }
    PR_Free(buf);
    buf = 0;

    /* Caller closes the socket. */

    /* NOTE(review): in this copy the format string below is broken across
     * two lines ("total." / "---------"), which is not valid C.  It was
     * presumably a single line -- confirm and rejoin. */
    FPRINTF(stderr,
    "strsclnt: connection on thread %d read %d bytes total. 
---------\n", tid, countRead);

    return SECSuccess;  /* success */
}

#define USE_SOCK_PEER_ID 1

#ifdef USE_SOCK_PEER_ID

/* Peer-ID number stored (atomically) by the most recent call to
 * myHandshakeCallback().  do_connects() reuses it for connections that
 * should resume a session. */
PRInt32 lastFullHandshakePeerID;

/* Handshake callback registered by do_connects(); arg carries that
 * connection's peer-ID number.
 * NOTE(review): the number travels through void* with plain int/pointer
 * casts, which can warn on 64-bit targets. */
void
myHandshakeCallback(PRFileDesc *socket, void *arg)
{
    PR_ATOMIC_SET(&lastFullHandshakePeerID, (PRInt32) arg);
}

#endif

/* one copy of this function is launched in a separate thread for each
** connection to be made.
**
** a is the server PRNetAddr, b the configured model SSL socket, tid the
** worker's thread id.  Steps:
**  - open a blocking TCP socket and connect;
**  - on connection refused/reset, lower active_threads and retry with a
**    doubling delay;
**  - import the socket as SSL from the model socket;
**  - with -P, pick a socket peer ID to force a full or a resumed handshake;
**  - run handle_fdx_connection() (when bigBuf is set) or handle_connection().
**
** NOTE(review): every path returns SECSuccess; rv and result are never
** reported to the caller.
*/
int
do_connects(
    void *  a,
    void *  b,
    int     tid)
{
    PRNetAddr  *       addr          = (PRNetAddr *)  a;
    PRFileDesc *       model_sock    = (PRFileDesc *) b;
    PRFileDesc *       ssl_sock      = 0;
    PRFileDesc *       tcp_sock      = 0;
    PRStatus           prStatus;
    PRUint32           sleepInterval = 50; /* milliseconds */
    SECStatus          result;
    int                rv            = SECSuccess;
    PRSocketOptionData opt;

retry:

    tcp_sock = PR_OpenTCPSocket(addr->raw.family);
    if (tcp_sock == NULL) {
        errExit("PR_OpenTCPSocket");
    }

    opt.option             = PR_SockOpt_Nonblocking;
    opt.value.non_blocking = PR_FALSE;
    prStatus = PR_SetSocketOption(tcp_sock, &opt);
    if (prStatus != PR_SUCCESS) {
        errWarn("PR_SetSocketOption(PR_SockOpt_Nonblocking, PR_FALSE)");
        PR_Close(tcp_sock);
        return SECSuccess;
    }

    if (NoDelay) {
        opt.option         = PR_SockOpt_NoDelay;
        opt.value.no_delay = PR_TRUE;
        prStatus = PR_SetSocketOption(tcp_sock, &opt);
        if (prStatus != PR_SUCCESS) {
            errWarn("PR_SetSocketOption(PR_SockOpt_NoDelay, PR_TRUE)");
            PR_Close(tcp_sock);
            return SECSuccess;
        }
    }

    prStatus = PR_Connect(tcp_sock, addr, PR_INTERVAL_NO_TIMEOUT);
    if (prStatus != PR_SUCCESS) {
        PRErrorCode err = PR_GetError(); /* save error code */
        PRInt32 oserr = PR_GetOSError();
        if (ThrottleUp) {
            PRTime now = PR_Now();
            PR_Lock(threadLock);
            lastConnectFailure = PR_MAX(now, lastConnectFailure);
            PR_Unlock(threadLock);
            PR_SetError(err, oserr); /* restore error code */
        }
        if ((err == PR_CONNECT_REFUSED_ERROR) ||
            (err == PR_CONNECT_RESET_ERROR)      ) {
            int connections = numConnected;

            PR_Close(tcp_sock);
            PR_Lock(threadLock);
            /* cap running workers just below the number now connected */
            if (connections > 2 && active_threads >= connections) {
                active_threads = connections - 1;
                fprintf(stderr,"active_threads set down to %d\n",
                        active_threads);
            }
            PR_Unlock(threadLock);

            /* With -q, give up once the retry delay exceeds 40000 ms.
             * NOTE(review): without -q the delay keeps doubling and the
             * PRUint32 sleepInterval eventually wraps around. */
            if (QuitOnTimeout && sleepInterval > 40000) {
                fprintf(stderr,
                    "strsclnt: Client timed out waiting for connection to server.\n");
                exit(1);
            }
            PR_Sleep(PR_MillisecondsToInterval(sleepInterval));
            sleepInterval <<= 1;
            goto retry;
        }
        errWarn("PR_Connect");
        rv = SECFailure;
        goto done;
    } else {
        if (ThrottleUp) {
            PRTime now = PR_Now();
            PR_Lock(threadLock);
            lastConnectSuccess = PR_MAX(now, lastConnectSuccess);
            PR_Unlock(threadLock);
        }
    }

    ssl_sock = SSL_ImportFD(model_sock, tcp_sock);
    /* XXX if this import fails, close tcp_sock and return. */
    if (!ssl_sock) {
        PR_Close(tcp_sock);
        return SECSuccess;
    }
    if (fullhs != NO_FULLHS_PERCENTAGE) {
#ifdef USE_SOCK_PEER_ID
        char sockPeerIDString[512];
        static PRInt32 sockPeerID = 0; /* atomically incremented */
        PRInt32 thisPeerID;
#endif
        PRInt32 savid = PR_ATOMIC_INCREMENT(&globalconid);
        PRInt32 conid = 1 + (savid - 1) % 100;
        /* don't change peer ID on the very first handshake, which is always
           a full, so the session gets stored into the client cache */
        /* Within each complete block of 100 connections, the first fullhs
         * get a full handshake.  In the final partial block the share is
         * scaled by total_connections_modulo_100. */
        if ( (savid != 1) &&
            ( ( (savid <= total_connections_rounded_down_to_hundreds) &&
                (conid <= fullhs) ) ||
              (conid*100 <= total_connections_modulo_100*fullhs ) ) )
#ifdef USE_SOCK_PEER_ID
        {
            /* force a full handshake by changing the socket peer ID */
            thisPeerID = PR_ATOMIC_INCREMENT(&sockPeerID);
        } else {
            /* reuse previous sockPeerID for restart handhsake */
            thisPeerID = lastFullHandshakePeerID;
        }
        PR_snprintf(sockPeerIDString, sizeof(sockPeerIDString), "ID%d",
                    thisPeerID);
        SSL_SetSockPeerID(ssl_sock, sockPeerIDString);
        SSL_HandshakeCallback(ssl_sock, myHandshakeCallback, (void*)thisPeerID);
#else
            /* force a full handshake by setting the no cache option */
            SSL_OptionSet(ssl_sock, SSL_NO_CACHE, 1);
#endif
    }
    rv = SSL_ResetHandshake(ssl_sock, /* asServer */ 0);
    if (rv != SECSuccess) {
        errWarn("SSL_ResetHandshake");
        goto done;
    }

    PR_ATOMIC_INCREMENT(&numConnected);

    if (bigBuf.data != NULL) {
        result = handle_fdx_connection( ssl_sock, tid);
    } else {
        result = handle_connection( ssl_sock, tid);
    }

    PR_ATOMIC_DECREMENT(&numConnected);

done:
    if (ssl_sock) {
        PR_Close(ssl_sock);
    } else if (tcp_sock) {
        PR_Close(tcp_sock);
    }
    return SECSuccess;
}


/* Client-auth state shared by all workers.  nickname comes from -n.
 * StressClient_GetClientAuthData() holds lock while it reads or replaces
 * cert and key. */
typedef struct {
    PRLock*           lock;
    char*             nickname;
    CERTCertificate*  cert;
    SECKEYPrivateKey* key;
    void*             wincx;
} cert_and_key;

/* Look up the user certificate named Cert_And_Key->nickname and its private
 * key, storing both in the struct.  A NULL nickname or "none" needs no
 * lookup and returns PR_TRUE.  Otherwise returns PR_TRUE only if both the
 * cert and the key were found.
 * NOTE(review): a previously stored cert or key is overwritten without
 * being destroyed. */
PRBool
FindCertAndKey(cert_and_key* Cert_And_Key)
{
    if ( (NULL == Cert_And_Key->nickname) ||
         (0 == strcmp(Cert_And_Key->nickname,"none"))) {
        return PR_TRUE;
    }
    Cert_And_Key->cert = CERT_FindUserCertByUsage(CERT_GetDefaultCertDB(),
                            Cert_And_Key->nickname, certUsageSSLClient,
                            PR_FALSE, Cert_And_Key->wincx);
    if (Cert_And_Key->cert) {
        Cert_And_Key->key = PK11_FindKeyByAnyCert(Cert_And_Key->cert,
                                                  Cert_And_Key->wincx);
    }
    if (Cert_And_Key->cert && Cert_And_Key->key) {
        return PR_TRUE;
    } else {
        return PR_FALSE;
    }
}

/* PR_TRUE if the certificate's slot and the key's slot are both set and
 * both report PK11_IsLoggedIn().  Neither pointer may be NULL. */
PRBool
LoggedIn(CERTCertificate* cert, SECKEYPrivateKey* key)
{
    if ( (cert->slot) && (key->pkcs11Slot) &&
         (PR_TRUE == PK11_IsLoggedIn(cert->slot, NULL)) &&
         (PR_TRUE == PK11_IsLoggedIn(key->pkcs11Slot, NULL)) ) {
        return PR_TRUE;
    }

    return PR_FALSE;
}

/* Client-auth hook installed on the model socket by client_main(); arg is a
 * cert_and_key.
 *
 * If a nickname was configured, copies of the shared cert and key are
 * returned.  If their token is no longer logged in, the thread that gets the
 * lock destroys the shared objects and calls FindCertAndKey() once a second.
 * After 60 failed attempts it exits the process.  Other threads block on the
 * lock and then retry.  Returns SECFailure if either object cannot be copied.
 *
 * If no nickname was configured, it returns the first user certificate that
 * is currently valid, passes NSS_CmpCertChainWCANames() against caNames and
 * has a private key.
 *
 * NOTE(review): if Cert_And_Key->lock is NULL the while loop below never
 * exits. */
SECStatus
StressClient_GetClientAuthData(void * arg,
                      PRFileDesc * socket,
                      struct CERTDistNamesStr * caNames,
                      struct CERTCertificateStr ** pRetCert,
                      struct SECKEYPrivateKeyStr **pRetKey)
{
    cert_and_key* Cert_And_Key = (cert_and_key*) arg;

    if (!pRetCert || !pRetKey) {
        /* bad pointers, can't return a cert or key */
        return SECFailure;
    }

    *pRetCert = NULL;
    *pRetKey = NULL;

    if (Cert_And_Key && Cert_And_Key->nickname) {
        while (PR_TRUE) {
            if (Cert_And_Key && Cert_And_Key->lock) {
                int timeout = 0;
                PR_Lock(Cert_And_Key->lock);

                if (Cert_And_Key->cert) {
                    *pRetCert = CERT_DupCertificate(Cert_And_Key->cert);
                }

                if (Cert_And_Key->key) {
                    *pRetKey = SECKEY_CopyPrivateKey(Cert_And_Key->key);
                }
                PR_Unlock(Cert_And_Key->lock);
                if (!*pRetCert || !*pRetKey) {
                    /* one or both of them failed to copy. Either the source was NULL, or there was
                    ** an out of memory condition. Free any allocated copy and fail */
                    if (*pRetCert) {
                        CERT_DestroyCertificate(*pRetCert);
                        *pRetCert = NULL;
                    }
                    if (*pRetKey) {
                        SECKEY_DestroyPrivateKey(*pRetKey);
                        *pRetKey = NULL;
                    }
                    break;
                }
                /* now check if those objects are valid */
                if ( PR_FALSE == LoggedIn(*pRetCert, *pRetKey) ) {
                    /* token is no longer logged in, it was removed */

                    /* first, delete and clear our invalid local objects */
                    CERT_DestroyCertificate(*pRetCert);
                    SECKEY_DestroyPrivateKey(*pRetKey);
                    *pRetCert = NULL;
                    *pRetKey = NULL;

                    PR_Lock(Cert_And_Key->lock);
                    /* check if another thread already logged back in */
                    if (PR_TRUE == LoggedIn(Cert_And_Key->cert, Cert_And_Key->key)) {
                        /* yes : try again */
                        PR_Unlock(Cert_And_Key->lock);
                        continue;
                    }
                    /* this is the thread to retry */
                    CERT_DestroyCertificate(Cert_And_Key->cert);
                    SECKEY_DestroyPrivateKey(Cert_And_Key->key);
                    Cert_And_Key->cert = NULL;
                    Cert_And_Key->key = NULL;


                    /* now look up the cert and key again */
                    while (PR_FALSE == FindCertAndKey(Cert_And_Key) ) {
                        PR_Sleep(PR_SecondsToInterval(1));
                        timeout++;
                        if (timeout>=60) {
                            printf("\nToken pulled and not reinserted early enough : aborting.\n");
                            exit(1);
                        }
                    }
                    PR_Unlock(Cert_And_Key->lock);
                    continue;
                    /* try again to reduce code size */
                }
                return SECSuccess;
            }
        }
        *pRetCert = NULL;
        *pRetKey = NULL;
        return SECFailure;
    } else {
        /* no cert configured, automatically find the right cert. */
        CERTCertificate *   cert      = NULL;
        SECKEYPrivateKey *  privkey   = NULL;
        CERTCertNicknames * names;
        int                 i;
        void *              proto_win = NULL;
        SECStatus           rv        = SECFailure;

        if (Cert_And_Key) {
            proto_win = Cert_And_Key->wincx;
        }

        names = CERT_GetCertNicknames(CERT_GetDefaultCertDB(),
                                      SEC_CERT_NICKNAMES_USER, proto_win);
        if (names != NULL) {
            for (i = 0; i < names->numnicknames; i++) {
                cert = CERT_FindUserCertByUsage(CERT_GetDefaultCertDB(),
                            names->nicknames[i], certUsageSSLClient,
                            PR_FALSE, proto_win);
                if ( !cert )
                    continue;
                /* Only check unexpired certs */
                if (CERT_CheckCertValidTimes(cert, PR_Now(), PR_TRUE) !=
                      secCertTimeValid ) {
                    CERT_DestroyCertificate(cert);
                    continue;
                }
                rv = NSS_CmpCertChainWCANames(cert, caNames);
                if ( rv == SECSuccess ) {
                    privkey = PK11_FindKeyByAnyCert(cert, proto_win);
                    if ( privkey )
                        break;
                }
                rv = SECFailure;
                CERT_DestroyCertificate(cert);
            }
            CERT_FreeNicknames(names);
        }
        if (rv == SECSuccess) {
            *pRetCert = cert;
            *pRetKey  = privkey;
        }
        return rv;
    }
}

/* Convert one hex digit to its value (0..15).  For any other character it
 * sets failed_already and returns -1. */
int
hexchar_to_int(int c)
{
    if (((c) >= '0') && ((c) <= '9'))
        return (c) - '0';
    if (((c) >= 'a') && ((c) <= 'f'))
        return (c) - 'a' + 10;
    if (((c) >= 'A') && ((c) <= 'F'))
        return (c) - 'A' + 10;
    failed_already = 1;
    return -1;
}

/* Main client routine:
 *  - resolve hostName:port;
 *  - apply the -C cipher selection;
 *  - build and configure the model SSL socket from the command-line flags;
 *  - run the worker threads.
 * Unless -N (NoReuse) was given, a single connection is made and joined
 * first, then the remaining connections are started. */
void
client_main(
    unsigned short      port,
    int                 connections,
    cert_and_key* Cert_And_Key,
    const char *        hostName,
    const char *        sniHostName)
{
    PRFileDesc *model_sock = NULL;
    int         i;
    int         rv;
    PRStatus    status;
    PRNetAddr   addr;

    status = PR_StringToNetAddr(hostName, &addr);
    if (status == PR_SUCCESS) {
        addr.inet.port = PR_htons(port);
    } else {
        /* Lookup host */
        PRAddrInfo *addrInfo;
        void       *enumPtr = NULL;

        addrInfo = PR_GetAddrInfoByName(hostName, PR_AF_UNSPEC,
                                        PR_AI_ADDRCONFIG | PR_AI_NOCANONNAME);
        if (!addrInfo) {
            SECU_PrintError(progName, "error looking up host");
            return;
        }
        /* use the first IPv4 or IPv6 address returned */
        do {
            enumPtr = PR_EnumerateAddrInfo(enumPtr, addrInfo, port, &addr);
        } while (enumPtr != NULL &&
                 addr.raw.family != PR_AF_INET &&
                 addr.raw.family != PR_AF_INET6);
        PR_FreeAddrInfo(addrInfo);
        if (enumPtr == NULL) {
            SECU_PrintError(progName, "error looking up host address");
            return;
        }
    }

    /* all suites except RSA_NULL_MD5 are enabled by Domestic Policy */
    NSS_SetDomesticPolicy();

    /* all the SSL2 and SSL3 cipher suites are enabled by default. */
    if (cipherString) {
        int ndx;

        /* disable all the ciphers, then enable the ones we want. */
        disableAllSSLCiphers();

        /* Each element of cipherString is either ":XXXX" (a suite number in
         * four hex digits) or one letter.  An upper-case letter indexes
         * ssl2CipherSuites and a lower-case one ssl3CipherSuites, where
         * 'A'/'a' is the first entry.
         * NOTE(review): ":XXXX" always consumes four characters, so a string
         * ending early (e.g. ":00") is read past its terminating NUL.  Also,
         * left-shifting the -1 error value from hexchar_to_int() is
         * undefined behaviour. */
        while (0 != (ndx = *cipherString)) {
            const char * startCipher = cipherString++;
            int          cipher = 0;
            SECStatus    rv;

            if (ndx == ':') {
                cipher = hexchar_to_int(*cipherString++);
                cipher <<= 4;
                cipher |= hexchar_to_int(*cipherString++);
                cipher <<= 4;
                cipher |= hexchar_to_int(*cipherString++);
                cipher <<= 4;
                cipher |= hexchar_to_int(*cipherString++);
                if (cipher <= 0) {
                    fprintf(stderr, "strsclnt: Invalid cipher value: %-5.5s\n",
                            startCipher);
                    failed_already = 1;
                    return;
                }
            } else {
                if (isalpha(ndx)) {
                    const int *cptr;

                    cptr = islower(ndx) ? ssl3CipherSuites : ssl2CipherSuites;
                    for (ndx &= 0x1f; (cipher = *cptr++) != 0 && --ndx > 0; )
                        /* do nothing */;
                }
                if (cipher <= 0) {
                    fprintf(stderr, "strsclnt: Invalid cipher letter: %c\n",
                            *startCipher);
                    failed_already = 1;
                    return;
                }
            }
            rv = SSL_CipherPrefSetDefault(cipher, PR_TRUE);
            if (rv != SECSuccess) {
                fprintf(stderr,
                        "strsclnt: SSL_CipherPrefSetDefault(0x%04x) failed\n",
                        cipher);
                failed_already = 1;
                return;
            }
        }
    }

    /* configure model SSL socket. */

    model_sock = PR_OpenTCPSocket(addr.raw.family);
    if (model_sock == NULL) {
        errExit("PR_OpenTCPSocket for model socket");
    }

    model_sock = SSL_ImportFD(NULL, model_sock);
    if (model_sock == NULL) {
        errExit("SSL_ImportFD");
    }

    /* do SSL configuration. */

    rv = SSL_OptionSet(model_sock, SSL_SECURITY,
                       enableSSL2 || enabledVersions.min != 0);
    if (rv < 0) {
        errExit("SSL_OptionSet SSL_SECURITY");
    }

    rv = SSL_VersionRangeSet(model_sock, &enabledVersions);
    if (rv != SECSuccess) {
        errExit("error setting SSL/TLS version range ");
    }

    rv = SSL_OptionSet(model_sock, SSL_ENABLE_SSL2, enableSSL2);
    if (rv != SECSuccess) {
        errExit("error enabling SSLv2 ");
    }

    rv = SSL_OptionSet(model_sock, SSL_V2_COMPATIBLE_HELLO, enableSSL2);
    if (rv != SECSuccess) {
        errExit("error enabling SSLv2 compatible hellos ");
    }

    if (bigBuf.data) { /* doing FDX */
        rv = SSL_OptionSet(model_sock, SSL_ENABLE_FDX, 1);
        if (rv < 0) {
            errExit("SSL_OptionSet SSL_ENABLE_FDX");
        }
    }

    if (NoReuse) {
        rv = SSL_OptionSet(model_sock, SSL_NO_CACHE, 1);
        if (rv < 0) {
            errExit("SSL_OptionSet SSL_NO_CACHE");
        }
    }

    if (bypassPKCS11) {
        rv = SSL_OptionSet(model_sock, SSL_BYPASS_PKCS11, 1);
        if (rv < 0) {
            errExit("SSL_OptionSet SSL_BYPASS_PKCS11");
        }
    }

    if (disableLocking) {
        rv = SSL_OptionSet(model_sock, SSL_NO_LOCKS, 1);
        if (rv < 0) {
            errExit("SSL_OptionSet SSL_NO_LOCKS");
        }
    }

    if (enableSessionTickets) {
        rv = SSL_OptionSet(model_sock, SSL_ENABLE_SESSION_TICKETS, PR_TRUE);
        if (rv != SECSuccess)
            errExit("SSL_OptionSet SSL_ENABLE_SESSION_TICKETS");
    }

    if (enableCompression) {
        rv = SSL_OptionSet(model_sock, SSL_ENABLE_DEFLATE, PR_TRUE);
        if (rv != SECSuccess)
            errExit("SSL_OptionSet SSL_ENABLE_DEFLATE");
    }

    if (enableFalseStart) {
        rv = SSL_OptionSet(model_sock, SSL_ENABLE_FALSE_START, PR_TRUE);
        if (rv != SECSuccess)
            errExit("SSL_OptionSet SSL_ENABLE_FALSE_START");
    }

    if (enableCertStatus) {
        rv = SSL_OptionSet(model_sock, SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
        if (rv != SECSuccess)
            errExit("SSL_OptionSet SSL_ENABLE_OCSP_STAPLING");
    }

    SSL_SetPKCS11PinArg(model_sock, &pwdata);

    SSL_SetURL(model_sock, hostName);

    SSL_AuthCertificateHook(model_sock, mySSLAuthCertificate,
                            (void *)CERT_GetDefaultCertDB());
    SSL_BadCertHook(model_sock, myBadCertHandler, NULL);

    SSL_GetClientAuthDataHook(model_sock, StressClient_GetClientAuthData,
                              (void*)Cert_And_Key);

    /* -a: replaces the URL that was set from hostName above */
    if (sniHostName) {
        SSL_SetURL(model_sock, sniHostName);
    }
    /* I'm not going to set the HandshakeCallback function. */

    /* end of ssl configuration. */

    init_thread_data();

    remaining_connections = total_connections = connections;
    total_connections_modulo_100 = total_connections % 100;
    total_connections_rounded_down_to_hundreds =
        total_connections - total_connections_modulo_100;

    if (!NoReuse) {
        remaining_connections = 1;
        rv = launch_thread(do_connects, &addr, model_sock, 0);
        /* wait for the first connection to terminate, then launch the rest. */
        reap_threads();
        remaining_connections = total_connections - 1 ;
    }
    if (remaining_connections > 0) {
        active_threads = PR_MIN(active_threads, remaining_connections);
        /* Start up the threads */
        /* NOTE(review): this copy of the file is damaged here.  The line
         * below runs from this loop's condition straight into what appears
         * to be the file-open test of readBigFile().  The rest of
         * client_main() and the beginning of readBigFile() are missing,
         * including the declarations of info, count, hdrLen, local_file_fd
         * and rv used below.  The line is kept exactly as found and must be
         * restored before this file will compile. */
        for (i=0;i 0 && NULL != (local_file_fd = PR_Open(fileName, PR_RDONLY, 0))) {

        /* Tail of readBigFile(): bigBuf = outHeader followed by the file
         * contents.
         * NOTE(review): the allocation adds 4095 spare bytes for no reason
         * visible here.  If PR_Read comes up short, bigBuf.data stays
         * allocated and non-NULL while rv remains a failure.  main() ignores
         * that result, so client_main() would still enable FDX mode. */
        hdrLen      = PORT_Strlen(outHeader);
        bigBuf.len  = hdrLen + info.size;
        bigBuf.data = PORT_Malloc(bigBuf.len + 4095);
        if (!bigBuf.data) {
            errWarn("PORT_Malloc");
            goto done;
        }

        PORT_Memcpy(bigBuf.data, outHeader, hdrLen);

        count = PR_Read(local_file_fd, bigBuf.data + hdrLen, info.size);
        if (count != info.size) {
            errWarn("PR_Read local file");
            goto done;
        }
        rv = SECSuccess;
done:
        PR_Close(local_file_fd);
    }
    return rv;
}

/* Parse the command line, initialise NSPR/NSS, locate the client cert, run
 * client_main(), and print the session-cache statistics.  The exit status is
 * non-zero if failed_already was set, or if the cache hits/misses/stateless
 * resumes and the number of server certificates checked differ from what
 * the chosen reuse mode (-N, -u) should produce. */
int
main(int argc, char **argv)
{
    const char *   dir         = ".";
    const char *   fileName    = NULL;
    char *         hostName    = NULL;
    char *         nickName    = NULL;
    char *         tmp         = NULL;
    int            connections = 1;
    int            exitVal;
    int            tmpInt;
    unsigned short port        = 443;
    SECStatus      rv;
    PLOptState *   optstate;
    PLOptStatus    status;
    cert_and_key   Cert_And_Key;
    char *         sniHostName = NULL;

    /* Call the NSPR initialization routines */
    PR_Init( PR_SYSTEM_THREAD, PR_PRIORITY_NORMAL, 1);
    SSL_VersionRangeGetSupported(ssl_variant_stream, &enabledVersions);

    /* progName = argv[0] after its last '/' or '\\' */
    tmp      = strrchr(argv[0], '/');
    tmp      = tmp ? tmp + 1 : argv[0];
    progName = strrchr(tmp, '\\');
    progName = progName ? progName + 1 : tmp;


    optstate = PL_CreateOptState(argc, argv,
                                 "BC:DNP:TUV:W:a:c:d:f:gin:op:qst:uvw:z");
    while ((status = PL_GetNextOpt(optstate)) == PL_OPT_OK) {
        switch(optstate->option) {
        case 'B': bypassPKCS11 = PR_TRUE; break;

        case 'C': cipherString = optstate->value; break;

        case 'D': NoDelay = PR_TRUE; break;

        /* NOTE(review): 'I' is not in the option string above, so this case
         * is never reached. */
        case 'I': /* reserved for OCSP multi-stapling */ break;

        case 'N': NoReuse = 1; break;

        case 'P': fullhs = PORT_Atoi(optstate->value); break;

        case 'T': enableCertStatus = PR_TRUE; break;

        case 'U': ThrottleUp = PR_TRUE; break;

        case 'V':
            if (SECU_ParseSSLVersionRangeString(optstate->value,
                            enabledVersions, enableSSL2,
                            &enabledVersions, &enableSSL2) != SECSuccess) {
                Usage(progName);
            }
            break;

        case 'a': sniHostName = PL_strdup(optstate->value); break;

        case 'c': connections = PORT_Atoi(optstate->value); break;

        case 'd': dir = optstate->value; break;

        case 'f': fileName = optstate->value; break;

        case 'g': enableFalseStart = PR_TRUE; break;

        case 'i': ignoreErrors = PR_TRUE; break;

        case 'n': nickName = PL_strdup(optstate->value); break;

        case 'o': MakeCertOK++; break;

        case 'p': port = PORT_Atoi(optstate->value); break;

        case 'q': QuitOnTimeout = PR_TRUE; break;

        case 's': disableLocking = PR_TRUE; break;

        /* thread count: values outside 1..MAX_THREADS-1 are ignored */
        case 't':
            tmpInt = PORT_Atoi(optstate->value);
            if (tmpInt > 0 && tmpInt < MAX_THREADS)
                max_threads = active_threads = tmpInt;
            break;

        case 'u': enableSessionTickets = PR_TRUE; break;

        case 'v': verbose++; break;

        case 'w':
            pwdata.source = PW_PLAINTEXT;
            pwdata.data = PL_strdup(optstate->value);
            break;

        case 'W':
            pwdata.source = PW_FROMFILE;
            pwdata.data = PL_strdup(optstate->value);
            break;

        case 'z': enableCompression = PR_TRUE; break;

        case 0: /* positional parameter */
            if (hostName) {
                Usage(progName);
            }
            hostName = PL_strdup(optstate->value);
            break;

        default:
        case '?':
            Usage(progName);
            break;

        }
    }
    PL_DestroyOptState(optstate);

    if (!hostName || status == PL_OPT_BAD)
        Usage(progName);

    /* -P must be within 0..100 and cannot be combined with -N */
    if (fullhs!= NO_FULLHS_PERCENTAGE && (fullhs < 0 || fullhs>100 || NoReuse) )
        Usage(progName);

    if (port == 0)
        Usage(progName);

    /* NOTE(review): readBigFile()'s return value is ignored. */
    if (fileName)
        readBigFile(fileName);

    PK11_SetPasswordFunc(SECU_GetModulePassword);

    /* NSS_DEBUG_TIMEOUT (seconds, > 0) becomes the PR_Send/PR_Recv timeout */
    tmp = PR_GetEnv("NSS_DEBUG_TIMEOUT");
    if (tmp && tmp[0]) {
        int sec = PORT_Atoi(tmp);
        if (sec > 0) {
            maxInterval = PR_SecondsToInterval(sec);
        }
    }

    /* Call the NSS initialization routines */
    rv = NSS_Initialize(dir, "", "", SECMOD_DB, NSS_INIT_READONLY);
    if (rv != SECSuccess) {
        fputs("NSS_Init failed.\n", stderr);
        exit(1);
    }
    ssl3stats = SSL_GetStatistics();
    Cert_And_Key.lock = PR_NewLock();
    Cert_And_Key.nickname = nickName;
    Cert_And_Key.wincx = &pwdata;
    Cert_And_Key.cert = NULL;
    Cert_And_Key.key = NULL;

    if (PR_FALSE == FindCertAndKey(&Cert_And_Key)) {

        if (Cert_And_Key.cert == NULL) {
            fprintf(stderr, "strsclnt: Can't find certificate %s\n", Cert_And_Key.nickname);
            exit(1);
        }

        if (Cert_And_Key.key == NULL) {
            fprintf(stderr, "strsclnt: Can't find Private Key for cert %s\n",
                    Cert_And_Key.nickname);
            exit(1);
        }

    }

    client_main(port, connections, &Cert_And_Key, hostName,
                sniHostName);

    /* clean up */
    if (Cert_And_Key.cert) {
        CERT_DestroyCertificate(Cert_And_Key.cert);
    }
    if (Cert_And_Key.key) {
        SECKEY_DestroyPrivateKey(Cert_And_Key.key);
    }

    PR_DestroyLock(Cert_And_Key.lock);

    if (pwdata.data) {
        PL_strfree(pwdata.data);
    }
    if (Cert_And_Key.nickname) {
        PL_strfree(Cert_And_Key.nickname);
    }
    if (sniHostName) {
        PL_strfree(sniHostName);
    }

    PL_strfree(hostName);

    /* some final stats. */
    if (ssl3stats->hsh_sid_cache_hits +
        ssl3stats->hsh_sid_cache_misses +
        ssl3stats->hsh_sid_cache_not_ok +
        ssl3stats->hsh_sid_stateless_resumes == 0) {
        /* presumably we were testing SSL2. */
        printf("strsclnt: SSL2 - %d server certificates tested.\n",
               certsTested);
    } else {
        printf(
        "strsclnt: %ld cache hits; %ld cache misses, %ld cache not reusable\n"
        " %ld stateless resumes\n",
            ssl3stats->hsh_sid_cache_hits,
            ssl3stats->hsh_sid_cache_misses,
            ssl3stats->hsh_sid_cache_not_ok,
            ssl3stats->hsh_sid_stateless_resumes);
    }

    if (!NoReuse) {
        /* With reuse, the run fails on any of:
         *  - -u given but no stateless resumes;
         *  - no -u and more than one cache miss, or any stateless resume;
         *  - any non-reusable entry;
         *  - more than one server certificate checked. */
        if (enableSessionTickets)
            exitVal = (ssl3stats->hsh_sid_stateless_resumes == 0);
        else
            exitVal = (ssl3stats->hsh_sid_cache_misses > 1) ||
                (ssl3stats->hsh_sid_stateless_resumes != 0);
        if (!exitVal)
            exitVal = (ssl3stats->hsh_sid_cache_not_ok != 0) ||
                      (certsTested > 1);
    } else {
        printf("strsclnt: NoReuse - %d server certificates tested.\n",
               certsTested);
        if (ssl3stats->hsh_sid_cache_hits +
            ssl3stats->hsh_sid_cache_misses +
            ssl3stats->hsh_sid_cache_not_ok +
            ssl3stats->hsh_sid_stateless_resumes > 0) {
            exitVal = (ssl3stats->hsh_sid_cache_misses != connections) ||
                (ssl3stats->hsh_sid_stateless_resumes != 0) ||
                (certsTested != connections);
        } else {                /* ssl2 connections */
            exitVal = (certsTested != connections);
        }
    }

    exitVal = ( exitVal || failed_already );
    SSL_ClearSessionCache();
    if (NSS_Shutdown() != SECSuccess) {
        printf("strsclnt: NSS_Shutdown() failed.\n");
        exit(1);
    }

    PR_Cleanup();
    return exitVal;
}