[Twisted-Python] SMTP patch take 2
Anders Hammarquist
iko at strakt.com
Mon Oct 7 07:18:27 MDT 2002
Hello,
Here is the promised revised SMTP patch. This adds a unit test for the
change in LineReceiver (the test can probably be improved upon, since
lineLengthExceeded is more of an exception), and a test to make sure
the '.' to finish mail data is on its own line in SMTPClient. Plus
misc. cleanups. Comments are welcome as always.
/Anders
-------------- next part --------------
Index: twisted/protocols/basic.py
===================================================================
RCS file: /cvs/Twisted/twisted/protocols/basic.py,v
retrieving revision 1.24
diff -u -u -r1.24 basic.py
--- twisted/protocols/basic.py 1 Oct 2002 15:09:25 -0000 1.24
+++ twisted/protocols/basic.py 7 Oct 2002 13:04:01 -0000
@@ -153,12 +153,15 @@
line, self.__buffer = self.__buffer.split(self.delimiter, 1)
except ValueError:
if len(self.__buffer) > self.MAX_LENGTH:
- self.transport.loseConnection()
+ line, self.__buffer = self.__buffer, ''
+ self.lineLengthExceeded(line)
return
break
else:
- if len(line) > self.MAX_LENGTH:
- self.transport.loseConnection()
+ linelength = len(line)
+ if linelength > self.MAX_LENGTH:
+ line, self.__buffer = self.__buffer, ''
+ self.lineLengthExceeded(line)
return
self.lineReceived(line)
if self.transport.disconnecting:
@@ -200,6 +203,12 @@
"""Sends a line to the other end of the connection.
"""
self.transport.write(line + self.delimiter)
+
+ def lineLengthExceeded(self, line):
+ """Called when the maximum line length has been reached.
+ Override if it needs to be dealt with in some special way.
+ """
+ self.transport.loseConnection()
class Int32StringReceiver(protocol.Protocol):
Index: twisted/protocols/smtp.py
===================================================================
RCS file: /cvs/Twisted/twisted/protocols/smtp.py,v
retrieving revision 1.24
diff -u -u -r1.24 smtp.py
--- twisted/protocols/smtp.py 2 Oct 2002 12:38:24 -0000 1.24
+++ twisted/protocols/smtp.py 7 Oct 2002 13:04:01 -0000
@@ -18,10 +18,10 @@
"""
from twisted.protocols import basic
-from twisted.internet import protocol, defer
+from twisted.internet import protocol, defer, reactor
from twisted.python import log
-import os, time, string, operator
+import os, time, string, operator, re
class SMTPError(Exception):
pass
@@ -49,20 +49,104 @@
self.deferred.errback(arg)
self.done = 1
+class AddressError(SMTPError):
+ "Parse error in address"
+ pass
+
+# Character classes for parsing addresses
+atom = r"-A-Za-z0-9!#$%&'*+/=?^_`{|}~"
+
+class Address:
+ """Parse and hold an RFC 2821 address.
+
+ Source routes are stipped and ignored, UUCP-style bang-paths
+ and %-style routing are not parsed.
+ """
+
+ tstring = re.compile(r'''( # A string of
+ (?:"[^"]*" # quoted string
+ |\\.i # backslash-escaped characted
+ |[''' + string.replace(atom,'#',r'\#')
+ + r'''] # atom character
+ )+|.) # or any single character''',re.X)
+
+ def __init__(self, addr):
+ self.local = ''
+ self.domain = ''
+ self.addrstr = addr
+
+ # Tokenize
+ atl = filter(None,self.tstring.split(addr))
+
+ print atl
+
+ local = []
+ domain = []
+
+ while atl:
+ if atl[0] == '<':
+ if atl[-1] != '>':
+ raise AddressError, "Unbalanced <>"
+ atl = atl[1:-1]
+ elif atl[0] == '@':
+ atl = atl[1:]
+ if not local:
+ # Source route
+ while atl and atl[0] != ':':
+ # remove it
+ atl = atl[1:]
+ if not atl:
+ raise AddressError, "Malformed source route"
+ atl = atl[1:] # remove :
+ elif domain:
+ raise AddressError, "Too many @"
+ else:
+ # Now in domain
+ domain = ['']
+ elif len(atl[0]) == 1 and atl[0] not in atom + '.':
+ raise AddressError, "Parse error at " + atl[0]
+ else:
+ if not domain:
+ local.append(atl[0])
+ else:
+ domain.append(atl[0])
+ atl = atl[1:]
+
+ self.local = ''.join(local)
+ self.domain = ''.join(domain)
+
+ dequotebs = re.compile(r'\\(.)')
+ def dequote(self, addr):
+ "Remove RFC-2821 quotes from address"
+ res = []
+
+ atl = filter(None,self.tstring.split(addr))
+
+ for t in atl:
+ if t[0] == '"' and t[-1] == '"':
+ res.append(t[1:-1])
+ elif '\\' in t:
+ res.append(self.dequotebs.sub(r'\1',t))
+ else:
+ res.append(t)
+
+ return ''.join(res)
+
+ def __str__(self):
+ return '%s%s' % (self.local, self.domain and ("@" + self.domain) or "")
+
+ def __repr__(self):
+ return "%s.%s(%s)" % (self.__module__, self.__class__.__name__,
+ repr(str(self)))
class User:
def __init__(self, destination, helo, protocol, orig):
- try:
- self.name, self.domain = string.split(destination, '@', 1)
- except ValueError:
- self.name = destination
- self.domain = ''
+ self.dest = Address(destination)
self.helo = helo
self.protocol = protocol
self.orig = orig
-
class IMessage:
def lineReceived(self, line):
@@ -83,23 +167,40 @@
class SMTP(basic.LineReceiver):
- def __init__(self):
+ def __init__(self, domain, timeout=600):
self.mode = COMMAND
self.__from = None
self.__helo = None
- self.__to = ()
+ self.__to = []
+ self.timeout = timeout
+ self.host = domain
+
+ def timedout(self):
+ self.sendCode(421, '%s Timeout. Try talking faster next time!' %
+ self.host)
+ self.transport.loseConnection()
def connectionMade(self):
- self.sendCode(220, 'Spammers beware, your ass is on fire')
+ self.sendCode(220, '%s Spammers beware, your ass is on fire' %
+ self.host)
+ if self.timeout:
+ self.timeoutID = reactor.callLater(self.timeout, self.timedout)
def sendCode(self, code, message=''):
"Send an SMTP code with a message."
self.transport.write('%d %s\r\n' % (code, message))
def lineReceived(self, line):
+ if self.timeout:
+ self.timeoutID.cancel()
+ self.timeoutID = reactor.callLater(self.timeout, self.timedout)
+
if self.mode is DATA:
return self.dataLineReceived(line)
- command = string.split(line, None, 1)[0]
+ if line:
+ command = string.split(line, None, 1)[0]
+ else:
+ command = ''
method = getattr(self, 'do_'+string.upper(command), None)
if method is None:
method = self.do_UNKNOWN
@@ -107,21 +208,65 @@
line = line[len(command):]
return method(string.strip(line))
+ def lineLengthExceeded(self, line):
+ if self.mode is DATA:
+ for message in self.__messages:
+ message.connectionLost()
+ self.mode = COMMAND
+ del self.__messages
+ self.sendCode(500, 'Line too long')
+
+ def rawDataReceived(self, data):
+ "Throw away rest of long line"
+ rest = string.split(data, '\r\n', 1)
+ if len(rest) == 2:
+ self.setLineMode(self.rest[1])
+
def do_UNKNOWN(self, rest):
- self.sendCode(502, 'Command not implemented')
+ self.sendCode(500, 'Command not implemented')
def do_HELO(self, rest):
- self.__helo = rest
- self.sendCode(250, 'Nice to meet you')
+ peer = self.transport.getPeer()[1]
+ self.__helo = (rest, peer)
+ self.sendCode(250, '%s Hello %s, nice to meet you' % (self.host, peer))
def do_QUIT(self, rest):
self.sendCode(221, 'See you later')
self.transport.loseConnection()
+ # A string of quoted strings, backslash-escaped character or
+ # atom characters + '@.,:'
+ qstring = r'("[^"]*"|\\.|[' + string.replace(atom,'#',r'\#') + r'@.,:])+'
+
+ mail_re = re.compile(r'''\s*FROM:\s*(?P<path><> # Empty <>
+ |<''' + qstring + r'''> # <addr>
+ |''' + qstring + r''' # addr
+ )\s*(\s(?P<opts>.*))? # Optional WS + ESMTP options
+ $''',re.I|re.X)
+ rcpt_re = re.compile(r'\s*TO:\s*(?P<path><' + qstring + r'''> # <addr>
+ |''' + qstring + r''' # addr
+ )\s*(\s(?P<opts>.*))? # Optional WS + ESMTP options
+ $''',re.I|re.X)
+
def do_MAIL(self, rest):
- from_ = rest[len("MAIL:<"):-len(">")]
- self.validateFrom(self.__helo, from_, self._fromValid,
- self._fromInvalid)
+ if self.__from:
+ self.sendCode(503,"Only one sender per message, please")
+ return
+ # Clear old recipient list
+ self.__to = []
+ m = self.mail_re.match(rest)
+ if not m:
+ self.sendCode(501, "Syntax error")
+ return
+
+ try:
+ addr = Address(m.group('path'))
+ except AddressError, e:
+ self.sendCode(553, str(e))
+ return
+
+ self.validateFrom(self.__helo, addr, self._fromValid,
+ self._fromInvalid)
def _fromValid(self, from_):
self.__from = from_
@@ -131,12 +276,24 @@
self.sendCode(550, 'No mail for you!')
def do_RCPT(self, rest):
- to = rest[len("TO:<"):-len(">")]
- user = User(to, self.__helo, self, self.__from)
+ if not self.__from:
+ self.sendCode(503, "Must have sender before recpient")
+ return
+ m = self.rcpt_re.match(rest)
+ if not m:
+ self.sendCode(501, "Syntax error")
+ return
+
+ try:
+ user = User(m.group('path'), self.__helo, self, self.__from)
+ except AddressError, e:
+ self.sendCode(553, str(e))
+ return
+
self.validateTo(user, self._toValid, self._toInvalid)
def _toValid(self, to):
- self.__to = self.__to + (to,)
+ self.__to.append(to)
self.sendCode(250, 'Address recognized')
def _toInvalid(self, to):
@@ -144,22 +301,30 @@
def do_DATA(self, rest):
if self.__from is None or not self.__to:
- self.sendCode(550, 'Must have valid receiver and originator')
+ self.sendCode(503, 'Must have valid receiver and originator')
return
self.mode = DATA
helo, origin, recipients = self.__helo, self.__from, self.__to
self.__from = None
- self.__to = ()
+ self.__to = []
self.__messages = self.startMessage(recipients)
+ for message in self.__messages:
+ message.lineReceived(self.receivedHeader(helo, origin, recipients))
self.sendCode(354, 'Continue')
def connectionLost(self, reason):
+ # self.sendCode(421, 'Dropping connection.') # This does nothing...
+ # Ideally, if we (rather than the other side) lose the connection,
+ # we should be able to tell the other side that we are going away.
+ # RFC-2821 requires that we try.
if self.mode is DATA:
for message in self.__messages:
message.connectionLost()
+ del self.__messages
def do_RSET(self, rest):
- self.__init__()
+ self.__from = None
+ self.__to = []
self.sendCode(250, 'I remember nothing.')
def dataLineReceived(self, line):
@@ -177,6 +342,7 @@
deferred = message.eomReceived()
deferred.addCallback(ndeferred.callback)
deferred.addErrback(ndeferred.errback)
+ del self.__messages
return
line = line[1:]
for message in self.__messages:
@@ -189,7 +355,14 @@
self.sendCode(550, 'Could not send e-mail')
# overridable methods:
+ def receivedHeader(self, helo, origin, recipents):
+ return "Received: From %s ([%s]) by %s; %s" % (
+ helo[0], helo[1], self.host,
+ time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime()))
+
def validateFrom(self, helo, origin, success, failure):
+ if not self.__helo:
+ self.sendCode(503,"Who are you? Say HELO first");
success(origin)
def validateTo(self, user, success, failure):
@@ -265,6 +438,7 @@
def smtpCode_354_data(self, line):
self.mailFile = self.getMailData()
+ self.lastsent = ''
self.transport.registerProducer(self, 0)
def smtpCode_250_afterData(self, line):
@@ -277,11 +451,18 @@
chunk = self.mailFile.read(8192)
if not chunk:
self.transport.unregisterProducer()
- self.sendLine('.')
+ if self.lastsent != '\n':
+ line = '\r\n.'
+ else:
+ line = '.'
+ self.sendLine(line)
self.state = 'afterData'
+ return
chunk = string.replace(chunk, "\n", "\r\n")
+ chunk = string.replace(chunk, "\r\n.", "\r\n..")
self.transport.write(chunk)
+ self.lastsent = chunk[-1]
def pauseProducing(self):
pass
Index: twisted/test/test_protocols.py
===================================================================
RCS file: /cvs/Twisted/twisted/test/test_protocols.py,v
retrieving revision 1.17
diff -u -u -r1.17 test_protocols.py
--- twisted/test/test_protocols.py 23 Sep 2002 08:51:29 -0000 1.17
+++ twisted/test/test_protocols.py 7 Oct 2002 13:04:02 -0000
@@ -33,11 +33,13 @@
class LineTester(basic.LineReceiver):
delimiter = '\n'
+ MAX_LENGTH = 64
def connectionMade(self):
self.received = []
def lineReceived(self, line):
+ print self.MAX_LENGTH, len(line)
self.received.append(line)
if line == '':
self.setRawMode()
@@ -51,6 +53,10 @@
if self.length == 0:
self.setLineMode(rest)
+ def lineLengthExceeded(self, line):
+ if len(line) > self.MAX_LENGTH+1:
+ self.setLineMode(line[self.MAX_LENGTH+1:])
+
class WireTestCase(unittest.TestCase):
def testEcho(self):
@@ -103,13 +109,14 @@
012345678len 0
foo 5
+1234567890123456789012345678901234567890123456789012345678901234567890
len 1
a'''
output = ['len 10', '0123456789', 'len 5', '1234\n',
'len 20', 'foo 123', '0123456789\n012345678',
- 'len 0', 'foo 5', '', 'len 1', 'a']
+ 'len 0', 'foo 5', '', '67890', 'len 1', 'a']
def testBuffer(self):
for packet_size in range(1, 10):
@@ -175,3 +182,6 @@
r.dataReceived(s)
if not r.brokenPeer:
raise AssertionError("connection wasn't closed on illegal netstring %s" % repr(s))
+
+if __name__ == '__main__':
+ unittest.main()
-------------- next part --------------
--
-- Of course I'm crazy, but that doesn't mean I'm wrong.
Anders Hammarquist | iko at strakt.com
AB Strakt | Tel: +46 31 749 08 80
Göteborg, Sweden. RADIO: SM6XMM and N2JGL | Fax: +46 31 749 08 81
More information about the Twisted-Python
mailing list