[Twisted-Python] Postgres client
Sune Kirkeby
sune at mel.interspace.dk
Sun Jul 27 05:28:28 MDT 2003
Hello twisted people.
People have expressed interest in the postgres client I wrote for
twisted waaay back, so here it is.
It's got the packet parsing/formatting, and a weird sort of
interface, so afaik all it really needs is integration with the rest
of twisted (enterprise?), and postgres -> python type mappings
(/usr/include/postgresql/server/catalog/pg_type.h has the postgres
type identifiers listed).
The code is license-less as it is in the patch, but you may consider
it licensed under the LGPL, and I give Glyph non-exclusive
copyrights to the code. There, that should make it twisted-friendly.
I hope.
Enjoy,
/s
--
Sune Kirkeby | If humans were supposed to fly they'd
| be born with stewardesses.
-------------- next part --------------
diff --exclude-from=diffignore -Naur Twisted-0.12.3/doc/examples/postgresql.py Twisted/doc/examples/postgresql.py
--- Twisted-0.12.3/doc/examples/postgresql.py Thu Jan 1 01:00:00 1970
+++ Twisted/doc/examples/postgresql.py Tue Jan 1 22:27:04 2002
@@ -0,0 +1,32 @@
+from twisted.enterprise import spock
+from twisted.protocols import postgresql
+from twisted.internet import main
+from twisted.internet import tcp
+
+class Consumer:
+    """Example row consumer that counts rows and prints progress.
+
+    Every line printed starts with ``prefix`` so that the output of
+    several concurrent queries can be told apart.
+    """
+
+    def __init__(self, prefix):
+        self.prefix = prefix
+        self.count = 0
+        self.done = 0
+
+    def receivedRow(self, types, data):
+        """Count one row and report after every 100th one.
+
+        ``types`` and ``data`` must have the same length, and no row may
+        arrive after queryDone() has been called.
+        """
+        assert len(types) == len(data)
+        assert not self.done
+
+        self.count += 1
+        if not self.count % 100:
+            print self.prefix, 'Got 100 rows'
+
+    def queryDone(self):
+        """Print the final row count and mark this consumer as finished."""
+        print self.prefix, 'Done (%d rows)' % self.count
+        self.done = 1
+
+# Build clients that log in as user 'sune' and hand the factory to a pool
+# for localhost:5432.  NOTE(review): spock.ConnectionPool is not part of
+# this patch; its signature is taken from this call only.
+pf = postgresql.PostgreSQLClientFactory('sune')
+pool = spock.ConnectionPool('localhost', 5432, pf)
+
+# Issue eight queries, each with its own consumer labelled by its number.
+for i in range(8):
+    pool.query('SELECT * FROM foo WHERE id=%d' % i, Consumer('%d' % i))
+
+main.run()
diff --exclude-from=diffignore -Naur Twisted-0.12.3/twisted/protocols/postgresql.py Twisted/twisted/protocols/postgresql.py
--- Twisted-0.12.3/twisted/protocols/postgresql.py Thu Jan 1 01:00:00 1970
+++ Twisted/twisted/protocols/postgresql.py Wed Jan 2 13:54:53 2002
@@ -0,0 +1,336 @@
+from socket import htons, htonl, ntohs, ntohl
+import struct
+import string
+import array
+
+from twisted.protocols import protocol
+
+class IncompletePacket(Exception):
+    """Raised by the parse* helpers when the input ends in mid-value.
+
+    Derives from Exception, like UnknownPacket, so that it can be raised
+    and caught as a regular class-based exception.
+    """
+
+def parseBytes(l, s):
+    """Split the first ``l`` bytes off the front of ``s``.
+
+    Returns a ``(head, rest)`` tuple.  Raises IncompletePacket when ``s``
+    holds fewer than ``l`` bytes.
+    """
+    if l > len(s):
+        raise IncompletePacket()
+    head = s[:l]
+    rest = s[l:]
+    return head, rest
+
+def formatInt16(i):
+    """Pack ``i`` as a two-byte signed integer in network byte order."""
+    # '!' selects network byte order with standard sizes.  Packing the
+    # result of htons() with native 'h' breaks on little-endian hosts as
+    # soon as the swapped value exceeds the signed 16-bit range (e.g. 255).
+    return struct.pack('!h', i)
+
+def parseInt16(s):
+    """Read a two-byte network-order signed integer from the front of ``s``.
+
+    Returns ``(value, rest)``.  Raises IncompletePacket when ``s`` holds
+    fewer than two bytes.
+    """
+    l = struct.calcsize('!h')
+    if len(s) < l:
+        raise IncompletePacket()
+    # Decode big-endian directly.  Unpacking with native 'h' and then
+    # calling ntohs() passes ntohs() a negative number whenever the
+    # native-order value has its top bit set.
+    return struct.unpack('!h', s[0:l])[0], s[l:]
+
+def formatInt32(i):
+    """Pack ``i`` as a four-byte signed integer in network byte order."""
+    # '!' selects network byte order with standard sizes.  Native 'i' has a
+    # platform-defined size, and htonl() can return values above the signed
+    # 32-bit range (e.g. for 128 on little-endian hosts).
+    return struct.pack('!i', i)
+
+def parseInt32(s):
+    """Read a four-byte network-order signed integer from the front of ``s``.
+
+    Returns ``(value, rest)``.  Raises IncompletePacket when ``s`` holds
+    fewer than four bytes.
+    """
+    l = struct.calcsize('!i')
+    if len(s) < l:
+        raise IncompletePacket()
+    # Decode big-endian directly.  Unpacking with native 'i' and then
+    # calling ntohl() passes ntohl() a negative number whenever the
+    # native-order value has its top bit set.
+    return struct.unpack('!i', s[0:l])[0], s[l:]
+
+def formatLimString(l, s):
+    """Return ``s`` as a field of exactly ``l`` bytes.
+
+    Shorter strings are padded on the right with NUL bytes; longer ones
+    are cut off after ``l`` bytes.
+    """
+    # A non-positive repeat count yields '', so the padding vanishes for
+    # strings that are already long enough.
+    padded = s + '\0' * (l - len(s))
+    return padded[:l]
+
+def formatString(s):
+    """Return ``s`` followed by a terminating NUL byte."""
+    return ''.join((s, '\0'))
+
+def parseString(s):
+    """Read a NUL-terminated string from the front of ``s``.
+
+    Returns ``(string, rest)`` with the terminator removed from both
+    parts.  Raises IncompletePacket when ``s`` contains no NUL byte.
+    """
+    # The str method replaces the legacy string.find() module function.
+    end = s.find('\0')
+    if end < 0:
+        raise IncompletePacket()
+    return s[:end], s[end + 1:]
+
+
+class StartupPacket:
+    """First packet sent by the client; announces protocol version 2.0."""
+
+    def __init__(self, user, database='', args='', tty=''):
+        self.user, self.database = user, database
+        self.args, self.tty = args, tty
+
+    def send(self, transport):
+        """Serialise the packet and hand it to ``transport`` in one write."""
+        # The fixed-width fields add up to the announced length:
+        # 4 + 2 + 2 + 64 + 32 + 64 + 64 + 64 = 296 bytes.
+        pieces = [
+            formatInt32(296),                       # packet length
+            formatInt16(2),                         # protocol version, major
+            formatInt16(0),                         # protocol version, minor
+            formatLimString(64, self.database),
+            formatLimString(32, self.user),
+            formatLimString(64, self.args),
+            formatLimString(64, ''),                # unused
+            formatLimString(64, self.tty),
+        ]
+        transport.write(''.join(pieces))
+
+class TerminatePacket:
+    """Client packet consisting of the single tag byte 'X'."""
+
+    def send(self, transport):
+        """Write the packet to ``transport``."""
+        tag = 'X'
+        transport.write(tag)
+
+class QueryPacket:
+    """Client packet carrying one query string."""
+
+    def __init__(self, query):
+        self.query = query
+
+    def send(self, transport):
+        """Write tag 'Q' followed by the NUL-terminated query text."""
+        payload = formatString(self.query)
+        transport.write('Q' + payload)
+
+class CursorResponsePacket:
+    """Backend packet with tag 'P'; ``name`` is the string it carries."""
+
+    def __init__(self, name):
+        self.name = name
+
+class EmptyQueryResponsePacket:
+    """Backend packet with tag 'I'; its string payload is discarded."""
+
+class CompletedResponsePacket:
+    """Backend packet with tag 'C'; ``command`` is the string it carries."""
+
+    def __init__(self, cmd):
+        self.command = cmd
+
+class AuthenticationPacket:
+    """Backend packet with tag 'R'.
+
+    ``authentication`` is the 32-bit code from the packet; the client in
+    this module accepts only code 0.
+    """
+
+    def __init__(self, auth):
+        self.authentication = auth
+
+class BackendKeyDataPacket:
+    """Backend packet with tag 'K', carrying two 32-bit integers."""
+
+    def __init__(self, pid, key):
+        self.process_id, self.key = pid, key
+
+class ReadyForQueryPacket:
+    """Backend packet with tag 'Z'; it has no payload."""
+
+class RowDescriptionPacket:
+    """Backend packet with tag 'T'.
+
+    parsePacket() builds ``columns`` as a tuple of
+    ``(name, (type_oid, type_size, type_modifier))`` entries.
+    """
+
+    def __init__(self, columns):
+        self.columns = columns
+
+class AsciiRowPacket:
+    """Backend packet with tag 'D'.
+
+    parsePacket() builds ``columns`` as a list with one entry per field:
+    the field's raw value, or None when its bitmap bit is clear.
+    """
+
+    def __init__(self, columns):
+        self.columns = columns
+
+class ErrorPacket:
+    """Backend packet with tag 'E'; ``message`` is the string it carries."""
+
+    def __init__(self, message):
+        self.message = message
+
+class UnknownPacket(Exception):
+    """Raised by parsePacket() for an unrecognised tag byte.
+
+    The unparsed data is passed as the exception's only argument.
+    """
+
+
+def parsePacket(client, data):
+    """Parse one backend packet from the front of ``data``.
+
+    ``client`` must provide ``row_description``, the columns of the most
+    recent 'T' packet; its length is needed to decode 'D' rows.
+
+    Returns ``(packet, rest)``.  When ``data`` does not hold a complete
+    packet, including when it is empty, the result is ``(None, data)`` so
+    that the caller can retry once more bytes have arrived.  Raises
+    UnknownPacket, carrying the unparsed data, for an unrecognised tag.
+    """
+    orig_data = data
+
+    try:
+        # Parsed inside the try block so that an empty buffer is reported
+        # as incomplete, like any other truncation.
+        tag, data = parseBytes(1, data)
+
+        if tag == 'E':
+            error, data = parseString(data)
+            return ErrorPacket(error), data
+
+        if tag == 'R':
+            auth, data = parseInt32(data)
+            return AuthenticationPacket(auth), data
+
+        if tag == 'K':
+            pid, data = parseInt32(data)
+            key, data = parseInt32(data)
+            return BackendKeyDataPacket(pid, key), data
+
+        if tag == 'Z':
+            return ReadyForQueryPacket(), data
+
+        if tag == 'P':
+            name, data = parseString(data)
+            return CursorResponsePacket(name), data
+
+        if tag == 'I':
+            unused, data = parseString(data)
+            return EmptyQueryResponsePacket(), data
+
+        if tag == 'T':
+            count, data = parseInt16(data)
+            columns = []
+            for i in range(count):
+                name, data = parseString(data)
+                type_oid, data = parseInt32(data)
+                type_size, data = parseInt16(data)
+                type_modifier, data = parseInt32(data)
+                columns.append((name, (type_oid, type_size, type_modifier)))
+            return RowDescriptionPacket(tuple(columns)), data
+
+        if tag == 'D':
+            field_count = len(client.row_description)
+            # One bit per field, most significant bit first, rounded up
+            # to a whole number of bytes.
+            bitmap, data = parseBytes((field_count + 7) >> 3, data)
+            bitmap = array.array('B', bitmap)
+
+            fields = []
+            for i in range(field_count):
+                if bitmap[i >> 3] & (0x80 >> (i & 7)):
+                    # The size word counts its own four bytes.
+                    size, data = parseInt32(data)
+                    value, data = parseBytes(size - 4, data)
+                    fields.append(value)
+                else:
+                    # A clear bit means the field has no value.
+                    fields.append(None)
+
+            return AsciiRowPacket(fields), data
+
+        if tag == 'C':
+            cmd, data = parseString(data)
+            return CompletedResponsePacket(cmd), data
+
+    except IncompletePacket:
+        return None, orig_data
+
+    raise UnknownPacket(orig_data)
+
+
+class SilentObserver:
+    """Observer that ignores every event.
+
+    It is the default observer of PostgreSQLClient and shows the methods
+    an observer has to provide.
+    """
+
+    def connectionMade(self, client):
+        pass
+
+    def backendError(self, client, message):
+        pass
+
+    def protocolError(self, client, message):
+        pass
+
+    def readyForQuery(self, client):
+        pass
+
+    def terminated(self, client):
+        # PostgreSQLClient.terminate() calls this on its observer; without
+        # it the default observer raised AttributeError.
+        pass
+
+    def connectionLost(self, client):
+        pass
+
+class SilentConsumer:
+    """Consumer that discards rows and only checks the order of calls."""
+
+    def __init__(self):
+        self.types = None
+        self.done = 0
+
+    def receivedHead(self, types):
+        """Store the column descriptions; allowed once per consumer."""
+        assert self.types is None
+        self.types = types
+
+    def receivedRow(self, data):
+        """Check one row: a head was received, widths match, not done."""
+        assert self.types is not None
+        assert len(self.types) == len(data)
+        assert not self.done
+
+    def queryDone(self):
+        """Mark the query as finished; later rows fail the assertion."""
+        self.done = 1
+
+class PostgreSQLClientFactory(protocol.ClientFactory):
+    """Factory that builds PostgreSQLClient protocols with fixed arguments."""
+
+    def __init__(self, user, password='', args='', tty=''):
+        # Kept as one tuple: it doubles as the pickled state and as the
+        # positional arguments of every PostgreSQLClient that gets built.
+        self.args = (user, password, args, tty)
+
+    def __getstate__(self):
+        return self.args
+
+    def __setstate__(self, state):
+        self.args = state
+
+    def buildProtocol(self, conn):
+        """Return a new PostgreSQLClient whose ``factory`` is this object."""
+        client = PostgreSQLClient(*self.args)
+        client.factory = self
+        return client
+
+class PostgreSQLClient(protocol.Protocol):
+    """Client side of the PostgreSQL frontend/backend protocol.
+
+    Connection-level events are reported to an observer, on which this
+    class calls connectionMade, backendError, protocolError, readyForQuery
+    and terminated.  The results of a query go to the consumer passed to
+    query(), on which it calls receivedHead, receivedRow and queryDone.
+    """
+
+    def __init__(self, user, password='', args='', tty=''):
+        # NOTE: only ``user`` is sent to the backend by this class; the
+        # other three values are stored but not used here.
+        self.user = user
+        self.password = password
+        self.backend_args = args
+        self.backend_tty = tty
+
+        # Set when a packet may be sent (on connect and on ReadyForQuery);
+        # cleared again by sendPacket().
+        self.ready = 0
+        # Bytes received but not yet parsed into a complete packet.
+        self.buffer = ''
+        # Columns of the latest RowDescriptionPacket; read by parsePacket().
+        self.row_description = None
+
+        self.observer = SilentObserver()
+        self.consumer = SilentConsumer()
+
+    def setObserver(self, ob):
+        """Install ``ob`` as the observer; None installs a SilentObserver."""
+        if ob is None:
+            ob = SilentObserver()
+        self.observer = ob
+
+    def connectionMade(self):
+        """Send the startup packet once the transport is connected."""
+        self.ready = 1
+        self.sendPacket(StartupPacket(self.user))
+
+    def query(self, query, consumer=None):
+        """Send ``query`` and deliver its results to ``consumer``.
+
+        Without a consumer a fresh SilentConsumer is used for this query.
+        Raises RuntimeError when the client is not ready for a query.
+        """
+        # A SilentConsumer() default argument would be one shared instance
+        # whose assertions fail from the second query onwards.
+        if consumer is None:
+            consumer = SilentConsumer()
+        # Check before replacing the consumer so that a rejected call does
+        # not take the results away from a query that is still running.
+        if not self.ready:
+            raise RuntimeError('Not ready for query.')
+        self.consumer = consumer
+        self.sendPacket(QueryPacket(query))
+
+    def terminate(self):
+        """Send a terminate packet, drop the connection, tell the observer."""
+        # Written directly instead of through sendPacket(): termination is
+        # also requested from the error paths of dataReceived(), where the
+        # client is typically not ready and sendPacket() would raise.
+        self.ready = 0
+        # Nothing that is still buffered is dispatched after this point.
+        self.buffer = ''
+        TerminatePacket().send(self.transport)
+        self.transport.loseConnection()
+        self.observer.terminated(self)
+
+    def sendPacket(self, p):
+        """Write packet ``p`` to the transport and clear the ready flag.
+
+        Raises RuntimeError when the client is not ready.
+        """
+        if not self.ready:
+            # A class-based exception replaces the former string exception.
+            raise RuntimeError('Not ready for query.')
+
+        self.ready = 0
+        p.send(self.transport)
+
+    def dataReceived(self, data):
+        """Buffer ``data`` and dispatch every complete packet it contains."""
+        self.buffer = self.buffer + data
+        while len(self.buffer) > 0:
+            try:
+                packet, self.buffer = parsePacket(self, self.buffer)
+            except UnknownPacket, e:
+                s = 'Unknown packet: %s' % repr(e.args[0])
+                self.observer.protocolError(self, s)
+                self.terminate()
+                self.buffer = ''
+                break
+
+            if packet is None:
+                # Incomplete packet: keep the bytes and wait for more.
+                break
+
+            self._handlePacket(packet)
+
+    def _handlePacket(self, packet):
+        """React to one parsed backend packet."""
+        kind = packet.__class__
+
+        if kind is AuthenticationPacket:
+            if packet.authentication == 0:
+                self.observer.connectionMade(self)
+            else:
+                s = 'Got request for unsupported ' + \
+                    'authentication: %d' % packet.authentication
+                self.observer.protocolError(self, s)
+                self.terminate()
+
+        elif kind is BackendKeyDataPacket:
+            self.backend_key = packet
+
+        elif kind is ReadyForQueryPacket:
+            self.ready = 1
+            self.consumer.queryDone()
+            # queryDone() may already have sent the next query, which
+            # clears the flag again.
+            if self.ready:
+                self.observer.readyForQuery(self)
+
+        elif kind is RowDescriptionPacket:
+            self.row_description = packet.columns
+            self.consumer.receivedHead(self.row_description)
+
+        elif kind is ErrorPacket:
+            self.observer.backendError(self, packet.message)
+            self.consumer = SilentConsumer()
+
+        elif kind is CursorResponsePacket:
+            pass
+
+        elif kind is AsciiRowPacket:
+            self.consumer.receivedRow(packet.columns)
+
+        elif kind is CompletedResponsePacket:
+            pass
+
+        else:
+            s = 'Got a "%s" I do not know what to do with!' % \
+                packet.__class__.__name__
+            self.observer.protocolError(self, s)
+
+# Public names of this module.  The factory is listed because it is how
+# PostgreSQLClient instances are normally created.
+__all__ = ['PostgreSQLClient', 'PostgreSQLClientFactory']
More information about the Twisted-Python
mailing list