[Twisted-Python] Updated TLS patch
Skinny Puppy
skin_pup-twisted at damnable.happypoo.com
Fri May 2 21:02:26 MDT 2003
Glyph Lefkowitz [glyph at twistedmatrix.com] wrote:
> -----BEGIN PGP SIGNED MESSAGE-----
> Hash: SHA1
>
> Ugly as it is, this looks like the right answer to me...
>
> On Friday, May 2, 2003, at 07:30 AM, Skinny Puppy wrote:
>
> >The branch/function call can be avoided by replacing the
> >doRead/doWrite/etc
> >methods in startTLS. While this is still not very pretty ;)
> -----BEGIN PGP SIGNATURE-----
> Version: GnuPG v1.2.1 (Darwin)
>
> iD8DBQE+svG0vVGR4uSOE2wRAqo+AJ40/0hBnDnEh1267vYe7hAJV0TEUwCeNklv
> Qya3OyfpjxoexyNSb3iLPqc=
> =24qI
> -----END PGP SIGNATURE-----
OK, done, though I still don't like it. I have not run any real-world tests
yet, but I have used echoserv_tls.py/echoclient_tls.py and watched the
traffic with tcpdump to verify the encryption. And, of course, I have run
the unit tests.
Jeremy
-------------- next part --------------
? doc/examples/echoclient_tls.py
? doc/examples/echoserv_tls.py
Index: twisted/internet/ssl.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/ssl.py,v
retrieving revision 1.40
diff -u -r1.40 ssl.py
--- twisted/internet/ssl.py 2 Apr 2003 04:11:32 -0000 1.40
+++ twisted/internet/ssl.py 3 May 2003 06:28:42 -0000
@@ -95,116 +95,13 @@
return SSL.Context(SSL.SSLv3_METHOD)
-class Connection(tcp.Connection):
- """I am an SSL connection.
- """
-
- __implements__ = tcp.Connection.__implements__, interfaces.ISSLTransport
-
- writeBlockedOnRead = 0
- readBlockedOnWrite= 0
- sslShutdown = 0
-
- def getPeerCertificate(self):
- """Return the certificate for the peer."""
- return self.socket.get_peer_certificate()
-
- def _postLoseConnection(self):
- """Gets called after loseConnection(), after buffered data is sent.
-
- We close the SSL transport layer, and if the other side hasn't
- closed it yet we start reading, waiting for a ZeroReturnError
- which will indicate the SSL shutdown has completed.
- """
- try:
- done = self.socket.shutdown()
- self.sslShutdown = 1
- except SSL.Error:
- return main.CONNECTION_LOST
- if done:
- return main.CONNECTION_DONE
- else:
- # we wait for other side to close SSL connection -
- # this will be signaled by SSL.ZeroReturnError when reading
- # from the socket
- self.stopWriting()
- self.startReading()
- return None # don't close socket just yet
-
- def doRead(self):
- """See tcp.Connection.doRead for details.
- """
- if self.writeBlockedOnRead:
- self.writeBlockedOnRead = 0
- return self.doWrite()
- try:
- return tcp.Connection.doRead(self)
- except SSL.ZeroReturnError:
- # close SSL layer, since other side has done so, if we haven't
- if not self.sslShutdown:
- try:
- self.socket.shutdown()
- self.sslShutdown = 1
- except SSL.Error:
- pass
- return main.CONNECTION_DONE
- except SSL.WantReadError:
- return
- except SSL.WantWriteError:
- self.readBlockedOnWrite = 1
- self.startWriting()
- return
- except SSL.Error:
- return main.CONNECTION_LOST
-
- def doWrite(self):
- if self.readBlockedOnWrite:
- self.readBlockedOnWrite = 0
- if not self.dataBuffer: self.stopWriting()
- return self.doRead()
- return tcp.Connection.doWrite(self)
-
- def writeSomeData(self, data):
- """See tcp.Connection.writeSomeData for details.
- """
- if not data:
- return 0
-
- try:
- return tcp.Connection.writeSomeData(self, data)
- except SSL.WantWriteError:
- return 0
- except SSL.WantReadError:
- self.writeBlockedOnRead = 1
- return 0
- except SSL.Error:
- return main.CONNECTION_LOST
-
- def _closeSocket(self):
- """Called to close our socket."""
- try:
- self.socket.sock_shutdown(2)
- except socket.error:
- try:
- self.socket.close()
- except socket.error:
- log.deferr()
-
-
-
-class Client(Connection, tcp.Client):
+class Client(tcp.Client):
"""I am an SSL client."""
def __init__(self, host, port, bindAddress, ctxFactory, connector, reactor=None):
# tcp.Client.__init__ depends on self.ctxFactory being set
self.ctxFactory = ctxFactory
tcp.Client.__init__(self, host, port, bindAddress, connector, reactor)
- def createInternetSocket(self):
- """(internal) create an SSL socket
- """
- sock = tcp.Client.createInternetSocket(self)
- return SSL.Connection(self.ctxFactory.getContext(), sock)
-
def getHost(self):
"""Returns a tuple of ('SSL', hostname, port).
@@ -219,16 +116,14 @@
"""
return ('SSL',)+self.addr
+ def _finishInit(self, whenDone, skt, error, reactor):  # after plain-TCP setup, switch this transport to TLS
+ tcp.Client._finishInit(self, whenDone, skt, error, reactor)
+ self.startTLS(self.ctxFactory)  # NOTE(review): also runs on the connect-failure path and rebinds doRead/doWrite; confirm this does not clobber tcp.Client's connect-phase handlers
-class Server(Connection, tcp.Server):
+class Server(tcp.Server):
"""I am an SSL server.
"""
-
- def __init__(*args, **kw):
- # We don't want Connection's __init__
- tcp.Server.__init__(*args, **kw)
-
def getHost(self):
"""Returns a tuple of ('SSL', hostname, port).
@@ -257,33 +152,12 @@
"""
sock = tcp.Port.createInternetSocket(self)
return SSL.Connection(self.ctxFactory.getContext(), sock)
-
- def doRead(self):
- """Called when my socket is ready for reading.
- This accepts a connection and calls self.protocol() to handle the
- wire-level protocol.
- """
- try:
- try:
- skt, addr = self.socket.accept()
- except socket.error, e:
- if e.args[0] == tcp.EWOULDBLOCK:
- return
- raise
- except SSL.Error:
- log.deferr()
- return
- protocol = self.factory.buildProtocol(addr)
- if protocol is None:
- skt.close()
- return
- s = self.sessionno
- self.sessionno = s+1
- transport = self.transport(skt, protocol, addr, self, s)
- protocol.makeConnection(transport)
- except:
- log.deferr()
+ def _preMakeConnection(self, transport):
+ # *Don't* call startTLS here: sockets accepted from the SSL.Connection
+ # listener built in createInternetSocket above are already SSL.Connection objects
+ transport._startTLS()  # only set the TLS flag and install the _TLS_* I/O methods
+ return tcp.Port._preMakeConnection(self, transport)
class Connector(base.BaseConnector):
Index: twisted/internet/tcp.py
===================================================================
RCS file: /cvs/Twisted/twisted/internet/tcp.py,v
retrieving revision 1.118
diff -u -r1.118 tcp.py
--- twisted/internet/tcp.py 2 May 2003 04:31:14 -0000 1.118
+++ twisted/internet/tcp.py 3 May 2003 06:28:50 -0000
@@ -39,6 +39,11 @@
except ImportError:
fcntl = None
+try:
+ from OpenSSL import SSL
+except ImportError:
+ SSL = None
+
if os.name == 'nt':
# we hardcode these since windows actually wants e.g.
# WSAEALREADY rather than EALREADY. Possibly we should
@@ -88,14 +93,36 @@
__implements__ = abstract.FileDescriptor.__implements__, interfaces.ITCPTransport
+ TLS = 0  # set unconditionally: _postLoseConnection reads self.TLS even without SSL
+ if SSL:
+ writeBlockedOnRead = 0
+ readBlockedOnWrite= 0
+ sslShutdown = 0
+
def __init__(self, skt, protocol, reactor=None):
abstract.FileDescriptor.__init__(self, reactor=reactor)
self.socket = skt
self.socket.setblocking(0)
self.fileno = skt.fileno
self.protocol = protocol
+
+ def startTLS(self, ctx):
+ if not SSL:
+ raise RuntimeError("No SSL support available")
+ assert not self.TLS
- def doRead(self):
+ self.socket = SSL.Connection(ctx.getContext(), self.socket)  # wrap first: if this raises, stay in plain-TCP mode
+ self.fileno = self.socket.fileno
+ self._startTLS()
+
+ def _startTLS(self):
+ self.TLS = 1
+ self.doRead = self._TLS_doRead
+ self.writeSomeData = self._TLS_writeSomeData
+ self.doWrite = self._TLS_doWrite
+ self._closeSocket = self._TLS_closeSocket
+
+ def _NOTLS_doRead(self):
"""Calls self.protocol.dataReceived with all available data.
This reads up to self.bufferSize bytes of data from its socket, then
@@ -114,7 +141,42 @@
return main.CONNECTION_LOST
return self.protocol.dataReceived(data)
- def writeSomeData(self, data):
+ doRead = _NOTLS_doRead
+
+ def _TLS_doRead(self):
+ if self.writeBlockedOnRead:
+ self.writeBlockedOnRead = 0
+ return self.doWrite()
+ try:
+ return self._NOTLS_doRead()
+ except SSL.ZeroReturnError:
+ # close SSL layer, since other side has done so, if we haven't
+ if not self.sslShutdown:
+ try:
+ self.socket.shutdown()
+ self.sslShutdown = 1
+ except SSL.Error:
+ pass
+ return main.CONNECTION_DONE
+ except SSL.WantReadError:
+ return
+ except SSL.WantWriteError:
+ self.readBlockedOnWrite = 1
+ self.startWriting()
+ return
+ except SSL.Error:
+ return main.CONNECTION_LOST
+
+ def _TLS_doWrite(self):
+ if self.readBlockedOnWrite:
+ self.readBlockedOnWrite = 0
+ # XXX - This is touching internal guts bad bad bad
+ if not self.dataBuffer:
+ self.stopWriting()
+ return self.doRead()
+ return abstract.FileDescriptor.doWrite(self)
+
+ def _NOTLS_writeSomeData(self, data):
"""Connection.writeSomeData(data) -> #of bytes written | CONNECTION_LOST
This writes as much data as possible to the socket and returns either
the number of bytes read (which is positive) or a connection error code
@@ -128,7 +190,22 @@
else:
return main.CONNECTION_LOST
- def _closeSocket(self):
+ writeSomeData = _NOTLS_writeSomeData
+
+ def _TLS_writeSomeData(self, data):
+ if not data:
+ return 0
+ try:
+ return self._NOTLS_writeSomeData(data)
+ except SSL.WantWriteError:
+ return 0
+ except SSL.WantReadError:
+ self.writeBlockedOnRead = 1
+ return 0  # nothing written; _TLS_doRead retries the write once readable
+ except SSL.Error:
+ return main.CONNECTION_LOST
+
+ def _NOTLS_closeSocket(self):
"""Called to close our socket."""
# This used to close() the socket, but that doesn't *really* close if
# there's another reference to it in the TCP/IP stack, e.g. if it was
@@ -139,6 +216,17 @@
except socket.error:
pass
+ _closeSocket = _NOTLS_closeSocket
+
+ def _TLS_closeSocket(self):
+ try:
+ self.socket.sock_shutdown(2)
+ except socket.error:
+ try:
+ self.socket.close()
+ except socket.error:
+ log.deferr()
+
def connectionLost(self, reason):
"""See abstract.FileDescriptor.connectionLost().
"""
@@ -173,6 +261,33 @@
def setTcpNoDelay(self, enabled):
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, enabled)
+
+ def _postLoseConnection(self):
+ """Gets called after loseConnection(), after buffered data is sent.
+
+ We close the SSL transport layer, and if the other side hasn't
+ closed it yet we start reading, waiting for a ZeroReturnError
+ which will indicate the SSL shutdown has completed.
+ """
+ if not self.TLS:
+ return abstract.FileDescriptor._postLoseConnection(self)
+
+ try:
+ done = self.socket.shutdown()
+ self.sslShutdown = 1
+ except SSL.Error:
+ return main.CONNECTION_LOST
+ if done:
+ return main.CONNECTION_DONE
+ else:
+ # we wait for other side to close SSL connection -
+ # this will be signaled by SSL.ZeroReturnError when reading
+ # from the socket
+ self.stopWriting()
+ self.startReading()
+
+ # don't close socket just yet
+ return None
class BaseClient(Connection):
@@ -191,6 +306,11 @@
else:
reactor.callLater(0, self.failIfNotConnected, error)
+ def startTLS(self, ctx):
+ holder = Connection.startTLS(self, ctx)
+ self.socket.set_connect_state()
+ return holder
+
def stopConnecting(self):
"""Stop attempt to connect."""
self.failIfNotConnected(error.UserError())
@@ -360,6 +480,11 @@
"""
return self.repstr
+ def startTLS(self, ctx):
+ holder = Connection.startTLS(self, ctx)
+ self.socket.set_accept_state()
+ return holder
+
def getHost(self):
"""Returns a tuple of ('INET', hostname, port).
@@ -458,6 +583,7 @@
elif e.args[0] == EPERM:
continue
raise
+
protocol = self.factory.buildProtocol(addr)
if protocol is None:
skt.close()
@@ -465,11 +591,22 @@
s = self.sessionno
self.sessionno = s+1
transport = self.transport(skt, protocol, addr, self, s)
+ transport = self._preMakeConnection(transport)
protocol.makeConnection(transport)
else:
self.numberAccepts = self.numberAccepts+20
except:
+ # Note that in TLS mode, this will possibly catch SSL.Errors
+ # raised by self.socket.accept()
+ #
+ # There is no "except SSL.Error:" above because SSL may be
+ # None if there is no SSL support. In any case, all the
+ # "except SSL.Error:" suite would probably do is log.deferr()
+ # and return, so handling it here works just as well.
log.deferr()
+
+ def _preMakeConnection(self, transport):  # hook: subclasses may adjust/replace the transport before makeConnection
+ return transport
def loseConnection(self, connDone=failure.Failure(main.CONNECTION_DONE)):
"""Stop accepting connections on this port.
Index: twisted/test/test_ssl.py
===================================================================
RCS file: /cvs/Twisted/twisted/test/test_ssl.py,v
retrieving revision 1.9
diff -u -r1.9 test_ssl.py
--- twisted/test/test_ssl.py 3 May 2003 02:03:54 -0000 1.9
+++ twisted/test/test_ssl.py 3 May 2003 06:28:55 -0000
@@ -17,19 +17,23 @@
from __future__ import nested_scopes
from twisted.trial import unittest
from twisted.internet import protocol, reactor
+from twisted.protocols import basic
+
try:
- import OpenSSL
+ from OpenSSL import SSL
from twisted.internet import ssl
except ImportError:
- OpenSSL = None
+ SSL = None
+
import os
import test_tcp
+certPath = os.path.join(os.path.split(test_tcp.__file__)[0], "server.pem")
+
class StolenTCPTestCase(test_tcp.ProperlyCloseFilesTestCase, test_tcp.WriteDataTestCase):
def setUp(self):
- certPath = os.path.join(os.path.split(test_tcp.__file__)[0], "server.pem")
f = protocol.ServerFactory()
f.protocol = protocol.Protocol
self.listener = reactor.listenSSL(
@@ -49,5 +53,127 @@
self.totalConnections = 0
-if not OpenSSL:
- del StolenTCPTestCase
+# Resolve base classes up front so this module still imports without OpenSSL
+# (the class statements below would otherwise raise NameError on "ssl").
+if SSL:
+ _ClientContextBase = ssl.ClientContextFactory
+ _ServerContextBase = ssl.DefaultOpenSSLContextFactory
+else:
+ _ClientContextBase = _ServerContextBase = object
+
+class ClientTLSContext(_ClientContextBase):
+ isClient = 1
+ def getContext(self):
+ return SSL.Context(SSL.TLSv1_METHOD)
+
+class UnintelligentProtocol(basic.LineReceiver):
+ pretext = [
+ "first line",
+ "last thing before tls starts",
+ "STARTTLS",
+ ]
+
+ posttext = [
+ "first thing after tls started",
+ "last thing ever",
+ ]
+
+ def connectionMade(self):
+ for l in self.pretext:
+ self.sendLine(l)
+
+ def lineReceived(self, line):
+ if line == "READY":
+ self.transport.startTLS(ClientTLSContext())
+ for l in self.posttext:
+ self.sendLine(l)
+ self.transport.loseConnection()
+
+class ServerTLSContext(_ServerContextBase):
+ isClient = 0
+ def __init__(self, *args, **kw):
+ kw['sslmethod'] = SSL.TLSv1_METHOD
+ _ServerContextBase.__init__(self, *args, **kw)
+
+class LineCollector(basic.LineReceiver):
+ def __init__(self, doTLS):
+ self.doTLS = doTLS
+
+ def connectionMade(self):
+ self.factory.rawdata = ''
+ self.factory.lines = []
+
+ def lineReceived(self, line):
+ self.factory.lines.append(line)
+ if line == 'STARTTLS':
+ self.sendLine('READY')
+ if self.doTLS:
+ ctx = ServerTLSContext(
+ privateKeyFileName=certPath,
+ certificateFileName=certPath,
+ )
+ self.transport.startTLS(ctx)  # NOTE(review): READY may still be buffered here; confirm it goes out in plaintext
+ else:
+ self.setRawMode()
+
+ def rawDataReceived(self, data):
+ self.factory.rawdata += data
+ self.factory.done = 1
+
+ def connectionLost(self, reason):
+ self.factory.done = 1
+
+class TLSTestCase(unittest.TestCase):
+ def testTLS(self):
+ cf = protocol.ClientFactory()
+ cf.protocol = UnintelligentProtocol
+
+ sf = protocol.ServerFactory()
+ sf.protocol = lambda: LineCollector(1)
+ sf.done = 0
+
+ port = reactor.listenTCP(0, sf)
+ portNo = port.getHost()[2]
+
+ reactor.connectTCP('127.0.0.1', portNo, cf)
+
+ i = 0
+ while i < 5000 and not sf.done:
+ reactor.iterate(0.01)
+ i += 1
+
+ port.loseConnection()
+ self.failUnless(sf.done, "Never finished reading all lines")
+ self.assertEquals(
+ sf.lines,
+ UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
+ )
+
+ def testUnTLS(self):
+ cf = protocol.ClientFactory()
+ cf.protocol = UnintelligentProtocol
+
+ sf = protocol.ServerFactory()
+ sf.protocol = lambda: LineCollector(0)
+ sf.done = 0
+
+ port = reactor.listenTCP(0, sf)
+ portNo = port.getHost()[2]
+
+ reactor.connectTCP('127.0.0.1', portNo, cf)
+
+ i = 0
+ while i < 5000 and not sf.done:
+ reactor.iterate(0.01)
+ i += 1
+
+ port.loseConnection()
+ self.failUnless(sf.done, "Never finished reading all lines")
+ self.assertEquals(
+ sf.lines,
+ UnintelligentProtocol.pretext
+ )
+ self.failUnless(sf.rawdata, "No encrypted bytes received")
+
+if not SSL:
+ del StolenTCPTestCase, TLSTestCase
More information about the Twisted-Python
mailing list